mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
This issue is found for tp=ep=8 on the multi-node machine due to the inconsistent PP sizes. * Reform the workspace allocation implementation to avoid the list-out-of-range issues. * Disable min_latency_mode under the multi-node scenario to avoid the illegal memory access issue. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
264 lines
9.7 KiB
Python
264 lines
9.7 KiB
Python
import threading
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy
|
|
from tensorrt_llm.mapping import Mapping
|
|
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
|
|
|
|
# Per-thread storage. get_allreduce_workspace() attaches one
# `allreduce_workspaces_<pp_rank>` dict to this object per pipeline-parallel
# rank and uses it as a cache of already-allocated workspaces.
_thread_local = threading.local()
|
|
|
|
|
|
def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
    """Return the allreduce fusion workspace for `mapping`, allocating it on first use.

    Workspaces are cached on the module's thread-local object: one dict per
    `mapping.pp_rank` (stored under the attribute
    `allreduce_workspaces_<pp_rank>`), keyed by the mapping itself.

    Args:
        mapping (Mapping): The parallel mapping the workspace is allocated for.

    Returns:
        The workspace returned by
        `CustomAllReduceHelper.allocate_allreduce_fusion_workspace`.
    """
    cache_attr = f'allreduce_workspaces_{mapping.pp_rank}'

    # Lazily create the per-thread, per-pp_rank cache dict.
    if not hasattr(_thread_local, cache_attr):
        setattr(_thread_local, cache_attr, {})
    cache = getattr(_thread_local, cache_attr)

    if mapping not in cache:
        workspace_size = CustomAllReduceHelper.max_workspace_size_auto(
            mapping.tp_size, support_deterministic=False)
        ipc_buffers, workspace = CustomAllReduceHelper.allocate_allreduce_fusion_workspace(
            mapping,
            workspace_size,
        )
        # The IPC buffers are stored next to the workspace (presumably to keep
        # them alive for as long as the workspace is cached -- TODO confirm).
        cache[mapping] = (ipc_buffers, workspace)

    _, cached_workspace = cache[mapping]
    return cached_workspace
|
|
|
|
|
|
def userbuffers_allreduce_finalize(
        input: torch.Tensor,
        force_applying_finalize: bool = False) -> torch.Tensor:
    """Thin wrapper around `torch.ops.trtllm.userbuffers_allreduce_finalize`.

    Args:
        input (Tensor): The tensor handed to the op.
        force_applying_finalize (bool): Forwarded unchanged to the op.
            By default False.

    Returns:
        The tensor produced by the op.
    """
    return torch.ops.trtllm.userbuffers_allreduce_finalize(
        input, force_applying_finalize)
|
|
|
|
|
|
def allgather(input: torch.Tensor,
              mapping: Mapping,
              gather_dim: int = -1) -> torch.Tensor:
    '''
    Perform a collective all-gather across the TP group.

    Every rank must pass an input tensor of the same shape. The result is
    replicated among the ranks of the TP group.

    With 'section_size = input.shape[gather_dim]', each rank contributes the
    section 'rank*section_size:(rank+1)*section_size' of the output along
    'gather_dim', so that
    'output.shape[gather_dim] = input.shape[gather_dim] * tp_group_size'.

    The operation is implemented using a torch op that wraps the NCCL
    all-gather collective operation. See
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather
    for details.

    Args:
        input (Tensor): The input tensor.
        mapping (Mapping): The parallel mapping.
        gather_dim (int): Gather along given dimension. By default -1.
    Returns:
        The gathered tensor. When 'mapping.tp_size == 1' the input is
        returned as-is.
    '''
    # A single-rank TP group has nothing to gather.
    if mapping.tp_size == 1:
        return input

    gathered = torch.ops.trtllm.allgather(input, mapping.tp_group)

    # Turn a negative dimension index into its non-negative equivalent.
    if gather_dim < 0:
        gather_dim = input.ndim + gather_dim

    # Bring the op output's leading dimension next to the gathered dimension,
    # then fold the two together so the result has the same rank as the input.
    gathered = torch.movedim(gathered, 0, gather_dim)
    target_shape = list(input.size())
    target_shape[gather_dim] = mapping.tp_size * target_shape[gather_dim]
    return gathered.reshape(target_shape)
|
|
|
|
|
|
def reducescatter(input: torch.Tensor,
                  mapping: Mapping,
                  scatter_dim: int = -1) -> torch.Tensor:
    '''
    Perform a collective reduce-scatter across the TP group.

    The work is delegated to 'torch.ops.trtllm.reducescatter'; its result is
    then rearranged so that the output has the input's shape except along
    'scatter_dim', where the size is 'input.shape[scatter_dim] // tp_size'.

    Args:
        input (Tensor): The input tensor.
        mapping (Mapping): The parallel mapping.
        scatter_dim (int): Scatter along given dimension. By default -1.
    Returns:
        The reduce-scattered tensor. When 'mapping.tp_size == 1' the input is
        returned as-is.
    '''
    # A single-rank TP group has nothing to reduce or scatter.
    if mapping.tp_size == 1:
        return input

    scattered = torch.ops.trtllm.reducescatter(input, mapping.tp_group)

    # Turn a negative dimension index into its non-negative equivalent.
    if scatter_dim < 0:
        scatter_dim = input.ndim + scatter_dim

    # Move the op output's leading dimension to scatter_dim, then view the
    # result with the per-rank extent along that dimension.
    scattered = torch.movedim(scattered, 0, scatter_dim)
    target_shape = list(input.size())
    target_shape[scatter_dim] = target_shape[scatter_dim] // mapping.tp_size
    return scattered.reshape(target_shape)
|
|
|
|
|
|
class AllReduce(nn.Module):

    def __init__(self,
                 mapping: Mapping,
                 strategy: AllReduceStrategy = AllReduceStrategy.AUTO):
        """
        AllReduce is a module that performs an all-reduce operation on a tensor.

        Args:
            mapping (Mapping):  The parallel mapping config.
            strategy (AllReduceStrategy):
                The following all-reduce strategies are supported:
                - UB: AllReduce uses user-buffer based all-reduce kernel. Supported ops:
                    - RESIDUAL_RMS_NORM
                    - RESIDUAL_RMS_NORM_QUANT_FP8
                    - RESIDUAL_RMS_NORM_QUANT_NVFP4

                - NCCL: AllReduce delegates all-reduce to NCCL MIN_LATENCY mode kernel. Supported ops:
                    - NONE (AllReduce only)
                    - RESIDUAL_RMS_NORM

                - MIN_LATENCY: AllReduce uses MIN_LATENCY mode kernel. Supported ops:
                    - NONE (AllReduce only)
                    - RESIDUAL_RMS_NORM
                    - RESIDUAL_RMS_NORM_QUANT_FP8
                    - RESIDUAL_RMS_NORM_QUANT_NVFP4
                    - RESIDUAL_RMS_NORM_OUT_QUANT_FP8
                    - RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4

                - AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy.

        Note:
            For the reference implementation for each pattern, please refer to the following unit test:
            https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/multi_gpu/test_allreduce.py
        """
        # The docstring must be the first statement of the function to be
        # picked up as `__init__.__doc__`; the base-class init follows it.
        super().__init__()

        self.mapping = mapping
        self.workspace = None
        self.strategy = strategy

        # A workspace is only fetched for a multi-rank TP group.
        # When Strategy is UB, it is guaranteed that the workspace is not used.
        if self.mapping.tp_size > 1 and self.strategy != AllReduceStrategy.UB:
            self.workspace = get_allreduce_workspace(self.mapping)

    def forward(
        self,
        input: torch.Tensor,
        *,
        all_reduce_params: Optional[AllReduceParams] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        '''
        The input tensors in the different ranks must have the same shape.
        The output tensor will have that same shape with the input tensor.
        The output tensor will be replicated among the TP group.
        Note that it is not an in-place operation like torch.distributed.all_reduce.

        That operation is implemented using a torch op that wraps the NCCL all-reduce
        collective operation and custom one-shot/two-shot allreduce kernels. See
        https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce
        for details.

        Args:
            input (Tensor): The input tensor.
            all_reduce_params (AllReduceParams): The parameters for the fused ops into the allreduce op.
                When omitted, a default-constructed AllReduceParams is used.
        Returns:
            The input itself when the TP size is 1 or when
            all_reduce_params.enable_allreduce is False. Otherwise the
            outputs of the op, which depend on the fusion_op:
                NONE: [hidden_states]
                RESIDUAL_RMS_NORM: [hidden_states, residual]
                RESIDUAL_RMS_NORM_QUANT_FP8: [norm_quant, residual]
                RESIDUAL_RMS_NORM_OUT_QUANT_FP8: [norm, norm_quant, residual]
                RESIDUAL_RMS_NORM_QUANT_NVFP4: [norm_quant_fp4, scale_factor, residual]
                RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm, norm_quant_fp4, scale_factor, residual]
            When the op yields a single output, that tensor is returned
            directly rather than wrapped in a one-element sequence.
        '''
        # Nothing to reduce across a single-rank TP group.
        if self.mapping.tp_size == 1:
            return input

        # The caller may explicitly switch the all-reduce off. The equality
        # comparison is kept as-is so that only a value equal to False
        # (and not e.g. None) disables the op.
        if (all_reduce_params is not None
                and all_reduce_params.enable_allreduce == False):  # noqa: E712
            return input

        # Assume using no fusion allreduce here
        if all_reduce_params is None:
            all_reduce_params = AllReduceParams()

        output = torch.ops.trtllm.allreduce(
            input=input,
            residual=all_reduce_params.residual,
            norm_weight=all_reduce_params.norm_weight,
            scale=all_reduce_params.scale,
            bias=all_reduce_params.bias,
            workspace=self.workspace,
            group=self.mapping.tp_group,
            strategy=self.strategy,
            op=all_reduce_params.fusion_op,
            eps=all_reduce_params.eps,
        )

        # Unwrap a single-output result so callers get the tensor itself.
        return output if len(output) > 1 else output[0]
|
|
|
|
|
|
class MoEAllReduce(nn.Module):

    def __init__(self, mapping: Mapping):
        """
        Module performing a fused MoE reduction followed by a regular
        all-reduce (AR) + RMS norm.

        Args:
            mapping (Mapping): The parallel mapping config.

        Notes:
            Supported pattern: MoE Reduction + Add + AR + ADD_RMS. Torch
            reference implementation:

                expert_reduction = torch.sum(active_experts_token_input *
                                             scale.unsqueeze(-1),
                                             dim=0)
                output_add = expert_reduction + shared_expert_output
                output_residual = output_add + residual
                output_hidden_states = rms_norm(output_residual, norm_weight, eps)
        """
        super().__init__()
        self.mapping = mapping
        # Unlike a plain attribute default, the workspace is always fetched
        # here, whatever the TP size of the mapping is.
        self.workspace = get_allreduce_workspace(self.mapping)

    def forward(
        self,
        residual: torch.Tensor,
        norm_weight: torch.Tensor,
        device_num_experts: torch.Tensor,
        scale_input: torch.Tensor,
        active_experts_token_input: torch.Tensor,
        token_input: torch.Tensor,
        eps: float,
    ) -> torch.Tensor:
        """
        Run the fused `torch.ops.trtllm.moe_allreduce` op.

        Args:
            residual: residual tensor
            norm_weight: RMS norm weight
            device_num_experts: number of experts per device
            scale_input: experts to token score
            active_experts_token_input: per token per expert input
            token_input: per token input, shared expert output
            eps: epsilon for RMSNorm

        Output:
            hidden_states: hidden_states of the model
            residual: residual tensor
        """
        # Rank and group size come from the tensor-parallel part of the mapping.
        tp_rank = self.mapping.tp_rank
        tp_size = self.mapping.tp_size

        fused_output = torch.ops.trtllm.moe_allreduce(
            residual=residual,
            norm_weight=norm_weight,
            device_num_experts=device_num_experts,
            scale_input=scale_input,
            active_experts_token_input=active_experts_token_input,
            token_input=token_input,
            workspace=self.workspace,
            rank=tp_rank,
            nranks=tp_size,
            eps=eps,
        )
        return fused_output
|