TensorRT-LLMs/tensorrt_llm/_torch/distributed/ops.py
Yukun He 0ae7017342
Unify two versions of AllReduce custom op (#3032)
* Rewrite unit test for unified allreduce op. Removing the legacy unit test.
* Revise formats, fusion_op bindings. Put all tensors as optional inputs.
* Move the MoeAllreduceOp to a separate custom op.
* Move all the fusion patterns to the new version of the AllReduce fusion kernel. Remove the AllReduce strategy config. Revise the AllReduce strategies and fusion pattern definitions.
* Add more TODOs, fixing minor bugs, and remove legacy code.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2025-04-22 21:58:42 +08:00

240 lines
8.8 KiB
Python

import threading
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AllReduceStrategy)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
# Thread-local storage: get_allreduce_workspace() keeps its cache of allocated
# workspaces on this object, so every thread holds its own cache.
_thread_local = threading.local()
def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
    """Return the all-reduce fusion workspace for ``mapping``, allocating on first use.

    Workspaces are cached per thread (on ``_thread_local``) and keyed by
    ``mapping``, so the allocator runs at most once per thread for a given
    mapping. The cache stores the ``(ipc_buffers, workspace)`` pair returned by
    the allocator; only the workspace is handed back to the caller.
    NOTE(review): the IPC buffers are presumably cached to keep them alive for
    the lifetime of the workspace -- confirm against CustomAllReduceHelper.

    Args:
        mapping (Mapping): The parallel mapping. Its ``tp_size`` is used to
            pick the workspace size.

    Returns:
        The workspace produced by
        ``CustomAllReduceHelper.allocate_allreduce_fusion_workspace``.
    """
    try:
        cache = _thread_local.allreduce_workspaces
    except AttributeError:
        # First call on this thread: create the per-thread cache.
        cache = {}
        _thread_local.allreduce_workspaces = cache

    if mapping in cache:
        return cache[mapping][1]

    max_size = CustomAllReduceHelper.max_workspace_size_auto(mapping.tp_size)
    ipc_buffers, workspace = (
        CustomAllReduceHelper.allocate_allreduce_fusion_workspace(
            mapping, max_size))
    cache[mapping] = (ipc_buffers, workspace)
    return workspace
def userbuffers_allreduce_finalize(
        input: torch.Tensor,
        force_applying_finalize: bool = False) -> torch.Tensor:
    """Invoke the ``trtllm.userbuffers_allreduce_finalize`` custom op.

    This is a thin pass-through; the op's semantics are defined by its
    implementation, which is not visible from this module.

    Args:
        input (Tensor): Tensor forwarded to the op.
        force_applying_finalize (bool): Flag forwarded to the op unchanged.
            By default False.

    Returns:
        The tensor returned by the op.
    """
    return torch.ops.trtllm.userbuffers_allreduce_finalize(
        input, force_applying_finalize)
def allgather(input: torch.Tensor,
              mapping: Mapping,
              gather_dim: int = -1) -> torch.Tensor:
    '''
    Perform a collective all-gather across the TP group.

    Every rank must pass an input of the same shape, and every rank receives
    the same gathered result. With 'section_size = input.shape[gather_dim]',
    rank 'r' contributes the section
    'r*section_size:(r+1)*section_size' of the output along 'gather_dim', so
    'output.shape[gather_dim] = input.shape[gather_dim] * tp_group_size'.

    The gather itself is done by a torch op wrapping the NCCL all-gather
    collective. See
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather
    for details.

    Args:
        input (Tensor): The input tensor.
        mapping (Mapping): The parallel mapping.
        gather_dim (int): Gather along given dimension. By default -1.
    Returns:
        The gathered tensor. When 'mapping.tp_size == 1' the input is returned
        as-is.
    '''
    tp_size = mapping.tp_size
    if tp_size == 1:
        return input

    gathered = torch.ops.trtllm.allgather(input, mapping.tp_group)

    # Normalize a negative dimension index to its non-negative equivalent.
    dim = gather_dim + input.ndim if gather_dim < 0 else gather_dim

    # Move the leading dimension of the op result next to the gather
    # dimension, then fold the two into a single dimension of size
    # tp_size * input.shape[dim].
    shape = input.size()
    merged_shape = shape[:dim] + (tp_size * shape[dim], ) + shape[dim + 1:]
    return torch.movedim(gathered, 0, dim).reshape(merged_shape)
def reducescatter(input: torch.Tensor,
                  mapping: Mapping,
                  scatter_dim: int = -1) -> torch.Tensor:
    '''
    Perform a collective reduce-scatter across the TP group.

    The reduction is delegated to the 'trtllm.reducescatter' torch op. Its
    result has dimension 0 moved to 'scatter_dim' and is then reshaped to the
    input shape with 'input.shape[scatter_dim] // tp_size' along 'scatter_dim'.
    NOTE(review): the layout of the op's raw output is not visible here, and
    the floor division presumably assumes 'input.shape[scatter_dim]' is a
    multiple of 'tp_size' -- confirm against the op implementation.

    Args:
        input (Tensor): The input tensor.
        mapping (Mapping): The parallel mapping.
        scatter_dim (int): Scatter along given dimension. By default -1.
    Returns:
        The reduce-scattered tensor. When 'mapping.tp_size == 1' the input is
        returned as-is.
    '''
    tp_size = mapping.tp_size
    if tp_size == 1:
        return input

    reduced = torch.ops.trtllm.reducescatter(input, mapping.tp_group)

    # Normalize a negative dimension index to its non-negative equivalent.
    dim = scatter_dim + input.ndim if scatter_dim < 0 else scatter_dim

    shape = input.size()
    scattered_shape = shape[:dim] + (shape[dim] // tp_size, ) + shape[dim + 1:]
    return torch.movedim(reduced, 0, dim).reshape(scattered_shape)
class AllReduce(nn.Module):
    """Module wrapper around the ``trtllm.allreduce`` custom op."""

    def __init__(self,
                 mapping: Mapping,
                 strategy: AllReduceStrategy = AllReduceStrategy.AUTO):
        """
        AllReduce is a module that performs an all-reduce operation on a tensor.

        Args:
            mapping (Mapping): The parallel mapping config.
            strategy (AllReduceStrategy):
                The following all-reduce strategies are supported:
                - UB: AllReduce uses user-buffer based all-reduce kernel. Supported ops:
                    - RESIDUAL_RMS_NORM
                    - RESIDUAL_RMS_NORM_QUANT_FP8
                    - RESIDUAL_RMS_NORM_QUANT_NVFP4
                - NCCL: AllReduce delegates the all-reduce to NCCL. Supported ops:
                    - NONE (AllReduce only)
                    - RESIDUAL_RMS_NORM
                - MIN_LATENCY: AllReduce uses MIN_LATENCY mode kernel. Supported ops:
                    - NONE (AllReduce only)
                    - RESIDUAL_RMS_NORM
                    - RESIDUAL_RMS_NORM_QUANT_FP8
                    - RESIDUAL_RMS_NORM_QUANT_NVFP4
                    - RESIDUAL_RMS_NORM_OUT_QUANT_FP8
                    - RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4
                - AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy.
        """
        super().__init__()
        self.mapping = mapping
        self.workspace = None
        self.strategy = strategy

        # A workspace is only needed when there is more than one TP rank.
        # When Strategy is UB, it is guaranteed that the workspace is not used.
        if (self.mapping.tp_size > 1
                and self.strategy != AllReduceStrategy.UB):
            self.workspace = get_allreduce_workspace(self.mapping)

    def forward(
        self,
        input: torch.Tensor,
        *,
        all_reduce_params: Optional[AllReduceParams] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        '''
        The input tensors in the different ranks must have the same shape.
        The output tensor will have that same shape with the input tensor.
        The output tensor will be replicated among the TP group.

        Note that it is not an in-place operation like torch.distributed.all_reduce.

        That operation is implemented using a torch op that wraps the NCCL all-reduce
        collective operation and custom one-shot/two-shot allreduce kernels. See
        https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce
        for details.

        Args:
            input (Tensor): The input tensor.
            all_reduce_params (AllReduceParams): The parameters for the fused ops into the allreduce op.
                When omitted, a default-constructed AllReduceParams is used.
        Returns:
            The input itself, untouched, when tp_size == 1 or when
            all_reduce_params.enable_allreduce is false.

            Otherwise the outputs of the op, which depend on the fusion_op.
            If the op yields a single output, that tensor is returned directly;
            if it yields several, the whole sequence is returned:
                NONE: [hidden_states]
                RESIDUAL_RMS_NORM: [hidden_states, residual]
                RESIDUAL_RMS_NORM_QUANT_FP8: [norm_quant, residual]
                RESIDUAL_RMS_NORM_OUT_QUANT_FP8: [norm, norm_quant, residual]
                RESIDUAL_RMS_NORM_QUANT_NVFP4: [norm_quant_fp4, scale_factor, residual]
                RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm, norm_quant_fp4, scale_factor, residual]
        '''
        # Nothing to reduce with a single TP rank.
        if self.mapping.tp_size == 1:
            return input

        if all_reduce_params is None:
            # Assume using no fusion allreduce here
            all_reduce_params = AllReduceParams()
        elif not all_reduce_params.enable_allreduce:
            # The caller explicitly disabled the all-reduce for this call.
            return input

        output = torch.ops.trtllm.allreduce(
            input=input,
            residual=all_reduce_params.residual,
            norm_weight=all_reduce_params.norm_weight,
            scale=all_reduce_params.scale,
            bias=all_reduce_params.bias,
            workspace=self.workspace,
            group=self.mapping.tp_group,
            strategy=self.strategy,
            op=all_reduce_params.fusion_op,
            eps=all_reduce_params.eps,
        )

        # Unwrap single-output results so callers get a plain tensor.
        return output if len(output) > 1 else output[0]
class DeepseekAllReduce(nn.Module):
    """Module wrapper around the ``trtllm.deepseek_allreduce_fusion`` custom op."""

    def __init__(self, mapping: Mapping):
        """
        Args:
            mapping (Mapping): The parallel mapping config. A workspace is
                obtained from get_allreduce_workspace only when tp_size > 1;
                otherwise the workspace stays None.
        """
        super().__init__()
        self.mapping = mapping
        self.workspace = (get_allreduce_workspace(mapping)
                          if self.mapping.tp_size > 1 else None)

    def forward(
        self,
        hidden_states: torch.Tensor,
        reduce_fusion_inputs: List[torch.Tensor],
        eps: float,
        fusion_op: AllReduceFusionOp,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Run the fused all-reduce op on the given hidden states.

        hidden_states: hidden_states of the model
        reduce_fusion_inputs: [residual, norm_weight, scale (if using FP4 quantization)]
        eps: epsilon for RMSNorm
        fusion_op: AllReduceFusionOp Type, currently supports RMSNorm:
          * RESIDUAL_RMS_NORM: allreduce + residual + Norm
          * RESIDUAL_RMS_NORM_QUANT_NVFP4: allreduce + residual + Norm + fp4 quantization

        output:
          * [hidden_states, residual] if using RESIDUAL_RMS_NORM fusion_op
          * [act_fp4, act_sf, residual] if using RESIDUAL_RMS_NORM_QUANT_NVFP4 fusion_op

        Raises ValueError when the op returns no outputs, which is reported as
        an unsupported fusion op.
        """
        tp_mapping = self.mapping
        fused_outputs = torch.ops.trtllm.deepseek_allreduce_fusion(
            input=hidden_states,
            workspace=self.workspace,
            reduce_fusion_inputs=reduce_fusion_inputs,
            rank=tp_mapping.tp_rank,
            nranks=tp_mapping.tp_size,
            eps=eps,
            fusion_op=fusion_op,
        )

        if len(fused_outputs) != 0:
            return fused_outputs

        # An empty result is how the op signals a fusion pattern it cannot handle.
        raise ValueError(f"Unsupported fusion op: {fusion_op}")