TensorRT-LLMs/tensorrt_llm/_torch/distributed/ops.py
Dom Brown 8709fe8b53
chore: bump version to 0.19.0 (#3598) (#3841)
Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Co-authored-by: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com>
2025-04-29 16:57:22 +08:00

241 lines
8.8 KiB
Python

import threading
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                     AllReduceStrategy)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

_thread_local = threading.local()


def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
    # Workspaces are cached per thread, keyed by the parallel mapping.
    if not hasattr(_thread_local, 'allreduce_workspaces'):
        _thread_local.allreduce_workspaces = {}

    allreduce_workspaces = _thread_local.allreduce_workspaces
    if mapping not in allreduce_workspaces:
        ipc_buffers, workspace = CustomAllReduceHelper.allocate_allreduce_fusion_workspace(
            mapping,
            CustomAllReduceHelper.max_workspace_size_auto(
                mapping.tp_size, support_deterministic=False),
        )
        # Store the IPC buffers together with the workspace tensor;
        # only the workspace tensor is returned to the caller.
        allreduce_workspaces[mapping] = (ipc_buffers, workspace)
    return allreduce_workspaces[mapping][1]
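

# Illustration only (not part of the original module): a minimal, stdlib-only
# sketch of the per-thread caching pattern used by get_allreduce_workspace.
# The names _demo_local and _demo_cached are hypothetical, and the `factory`
# callable merely stands in for the real workspace allocation.
import threading

_demo_local = threading.local()


def _demo_cached(key, factory):
    # Each thread lazily creates its own dict, so threads never share entries.
    if not hasattr(_demo_local, 'cache'):
        _demo_local.cache = {}
    cache = _demo_local.cache
    # The factory runs at most once per (thread, key) pair.
    if key not in cache:
        cache[key] = factory()
    return cache[key]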


def userbuffers_allreduce_finalize(
        input: torch.Tensor,
        force_applying_finalize: bool = False) -> torch.Tensor:
    output = torch.ops.trtllm.userbuffers_allreduce_finalize(
        input, force_applying_finalize)
    return output


def allgather(input: torch.Tensor,
              mapping: Mapping,
              gather_dim: int = -1) -> torch.Tensor:
    '''
    Add an operation that performs a collective all-gather.

    The input tensors in the different ranks must have the same shape.
    The output tensor will be replicated among the TP group.

    Given the 'section_size = input.shape[gather_dim]', each rank
    contributes a section of its input tensor that corresponds to
    'rank*section_size:(rank+1)*section_size',
    and 'output.shape[gather_dim] = input.shape[gather_dim] * tp_group_size'.

    This operation is implemented using a torch op that wraps the NCCL all-gather
    collective operation. See
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather
    for details.

    Args:
        input (Tensor): The input tensor.
        mapping (Mapping): The parallel mapping.
        gather_dim (int): Gather along given dimension. By default -1.
    Returns:
        The gathered tensor.
    '''
    if mapping.tp_size == 1:
        return input

    output = torch.ops.trtllm.allgather(
        input,
        mapping.tp_group,
    )

    # Normalize a negative dimension index before using it for slicing.
    if gather_dim < 0:
        gather_dim += input.ndim

    output = torch.movedim(output, 0, gather_dim)
    input_shape = input.size()
    output = output.reshape(input_shape[:gather_dim] +
                            (mapping.tp_size * input_shape[gather_dim], ) +
                            input_shape[gather_dim + 1:])
    return output
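

# Illustration only (not part of the original module): a hypothetical helper,
# _demo_allgather_shape, that mirrors the final reshape arithmetic of
# allgather using plain tuples, so the shape rule from the docstring can be
# checked without GPUs or NCCL.
def _demo_allgather_shape(input_shape, tp_size, gather_dim=-1):
    input_shape = tuple(input_shape)
    # Normalize a negative dimension index, as allgather does.
    if gather_dim < 0:
        gather_dim += len(input_shape)
    # Only the gathered dimension grows, by a factor of tp_size.
    return (input_shape[:gather_dim] +
            (tp_size * input_shape[gather_dim], ) +
            input_shape[gather_dim + 1:])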
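

def reducescatter(input: torch.Tensor,
                  mapping: Mapping,
                  scatter_dim: int = -1) -> torch.Tensor:
    '''
    Perform a collective reduce-scatter across the TP group.

    The result is reshaped so that
    'output.shape[scatter_dim] = input.shape[scatter_dim] // tp_size';
    all other dimensions keep the input's size.
    '''
    if mapping.tp_size == 1:
        return input

    output = torch.ops.trtllm.reducescatter(
        input,
        mapping.tp_group,
    )

    # Normalize a negative dimension index before using it for slicing.
    if scatter_dim < 0:
        scatter_dim += input.ndim

    output = torch.movedim(output, 0, scatter_dim)
    input_shape = input.size()
    output = output.reshape(input_shape[:scatter_dim] +
                            (input_shape[scatter_dim] // mapping.tp_size, ) +
                            input_shape[scatter_dim + 1:])
    return output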
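

# Illustration only (not part of the original module): a hypothetical helper,
# _demo_reducescatter_shape, that mirrors the final reshape arithmetic of
# reducescatter using plain tuples. It assumes the scattered dimension is
# evenly divisible by tp_size; the integer division would silently truncate
# otherwise.
def _demo_reducescatter_shape(input_shape, tp_size, scatter_dim=-1):
    input_shape = tuple(input_shape)
    # Normalize a negative dimension index, as reducescatter does.
    if scatter_dim < 0:
        scatter_dim += len(input_shape)
    # Only the scattered dimension shrinks, by a factor of tp_size.
    return (input_shape[:scatter_dim] +
            (input_shape[scatter_dim] // tp_size, ) +
            input_shape[scatter_dim + 1:])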


class AllReduce(nn.Module):

    def __init__(self,
                 mapping: Mapping,
                 strategy: AllReduceStrategy = AllReduceStrategy.AUTO):
        """
        AllReduce is a module that performs an all-reduce operation on a tensor.

        Args:
            mapping (Mapping): The parallel mapping config.
            strategy (AllReduceStrategy):
                The following all-reduce strategies are supported:
                - UB: AllReduce uses user-buffer based all-reduce kernel. Supported ops:
                    - RESIDUAL_RMS_NORM
                    - RESIDUAL_RMS_NORM_QUANT_FP8
                    - RESIDUAL_RMS_NORM_QUANT_NVFP4

                - NCCL: AllReduce delegates all-reduce to NCCL. Supported ops:
                    - NONE (AllReduce only)
                    - RESIDUAL_RMS_NORM

                - MIN_LATENCY: AllReduce uses MIN_LATENCY mode kernel. Supported ops:
                    - NONE (AllReduce only)
                    - RESIDUAL_RMS_NORM
                    - RESIDUAL_RMS_NORM_QUANT_FP8
                    - RESIDUAL_RMS_NORM_QUANT_NVFP4
                    - RESIDUAL_RMS_NORM_OUT_QUANT_FP8
                    - RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4

                - AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy.
        """
        super().__init__()

        self.mapping = mapping
        self.workspace = None
        self.strategy = strategy

        if self.mapping.tp_size > 1:
            # When Strategy is UB, it is guaranteed that the workspace is not used.
            if self.strategy != AllReduceStrategy.UB:
                self.workspace = get_allreduce_workspace(self.mapping)

    def forward(
        self,
        input: torch.Tensor,
        *,
        all_reduce_params: Optional[AllReduceParams] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        '''
        The input tensors in the different ranks must have the same shape.
        The output tensor has the same shape as the input tensor.
        The output tensor will be replicated among the TP group.

        Note that this is not an in-place operation like torch.distributed.all_reduce.

        This operation is implemented using a torch op that wraps the NCCL all-reduce
        collective operation and custom one-shot/two-shot allreduce kernels. See
        https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce
        for details.

        Args:
            input (Tensor): The input tensor.
            all_reduce_params (AllReduceParams): The parameters for the ops fused into the allreduce op.
        Returns:
            A single tensor, or a sequence of tensors, depending on the fusion_op:
                NONE: hidden_states (returned as a single tensor)
                RESIDUAL_RMS_NORM: [hidden_states, residual]
                RESIDUAL_RMS_NORM_QUANT_FP8: [norm_quant, residual]
                RESIDUAL_RMS_NORM_OUT_QUANT_FP8: [norm, norm_quant, residual]
                RESIDUAL_RMS_NORM_QUANT_NVFP4: [norm_quant_fp4, scale_factor, residual]
                RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm, norm_quant_fp4, scale_factor, residual]
        '''
        # Nothing to reduce with a single rank, or when allreduce is explicitly disabled.
        if self.mapping.tp_size == 1 or (all_reduce_params is not None
                                         and all_reduce_params.enable_allreduce
                                         == False):
            return input

        # Assume using no fusion allreduce here
        if all_reduce_params is None:
            all_reduce_params = AllReduceParams()

        output = torch.ops.trtllm.allreduce(
            input=input,
            residual=all_reduce_params.residual,
            norm_weight=all_reduce_params.norm_weight,
            scale=all_reduce_params.scale,
            bias=all_reduce_params.bias,
            workspace=self.workspace,
            group=self.mapping.tp_group,
            strategy=self.strategy,
            op=all_reduce_params.fusion_op,
            eps=all_reduce_params.eps,
        )

        # Unwrap a single-element result so callers get a plain tensor.
        return output if len(output) > 1 else output[0]


class DeepseekAllReduce(nn.Module):

    def __init__(self, mapping: Mapping):
        super().__init__()
        self.mapping = mapping
        self.workspace = None
        if self.mapping.tp_size > 1:
            self.workspace = get_allreduce_workspace(mapping)

    def forward(
        self,
        hidden_states: torch.Tensor,
        reduce_fusion_inputs: List[torch.Tensor],
        eps: float,
        fusion_op: AllReduceFusionOp,
    ) -> Tuple[torch.Tensor, ...]:
        """
        hidden_states: hidden_states of the model
        reduce_fusion_inputs: [residual, norm_weight, scale (if using FP4 quantization)]
        eps: epsilon for RMSNorm
        fusion_op: AllReduceFusionOp Type, currently supports RMSNorm:
          * RESIDUAL_RMS_NORM: allreduce + residual + Norm
          * RESIDUAL_RMS_NORM_QUANT_NVFP4: allreduce + residual + Norm + fp4 quantization
        output:
          * [hidden_states, residual] if using RESIDUAL_RMS_NORM fusion_op
          * [act_fp4, act_sf, residual] if using RESIDUAL_RMS_NORM_QUANT_NVFP4 fusion_op
        """

        output = torch.ops.trtllm.deepseek_allreduce_fusion(
            input=hidden_states,
            workspace=self.workspace,
            reduce_fusion_inputs=reduce_fusion_inputs,
            rank=self.mapping.tp_rank,
            nranks=self.mapping.tp_size,
            eps=eps,
            fusion_op=fusion_op,
        )

        if len(output) == 0:
            raise ValueError(f"Unsupported fusion op: {fusion_op}")

        return output