Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[#8781][fix] Cache the AllReduce wrapper to avoid re-allocating workspace which caused a hang (#8803)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
This commit is contained in:
parent da73410d3b
commit f8778230e3
@@ -8,6 +8,11 @@ try:
     from ...distributed import AllReduce, allgather
     from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
 
+    # Cache AllReduce modules to avoid recreating on every call
+    # This is critical for CUDA graph compatibility - recreating modules during
+    # warmup causes hangs due to workspace allocation with CPU synchronization
+    _allreduce_cache = {}
+
     def trtllm_allgather(tensor, dim, sizes=None):
         rank, world_size = get_rank_world_size()
         p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
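The new comments describe the root cause: constructing a fresh AllReduce allocates a workspace, and that allocation path synchronizes with the CPU, which is unsafe while CUDA graphs are being warmed up or captured. Below is a minimal sketch of that hazard, not part of this commit, assuming a hypothetical get_workspace helper (the real workspace is managed inside TensorRT-LLM's AllReduce); it uses PyTorch's torch.cuda.is_current_stream_capturing() purely for illustration.

import torch

_workspace = None


def get_workspace(nbytes: int) -> torch.Tensor:
    """Hypothetical helper: return a cached buffer, refusing to allocate mid-capture."""
    global _workspace
    if _workspace is None or _workspace.numel() < nbytes:
        # Allocating (and CPU-synchronizing) while a CUDA graph is being
        # captured is exactly the kind of operation that can hang the capture.
        if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
            raise RuntimeError("workspace must be allocated before CUDA graph capture")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _workspace = torch.empty(nbytes, dtype=torch.uint8, device=device)
    return _workspace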
@@ -16,9 +21,17 @@ try:
     def trtllm_allreduce(tensor, op, all_reduce_params=None):
         rank, world_size = get_rank_world_size()
         assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
-        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-        # Use Strategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then change to Strategy.AUTO
-        torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL)
+
+        # Cache key includes rank, world_size, and dtype to handle different configurations
+        cache_key = (rank, world_size, tensor.dtype)
+        if cache_key not in _allreduce_cache:
+            p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+            # Use Strategy.AUTO for optimal performance
+            _allreduce_cache[cache_key] = AllReduce(
+                mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype
+            )
+
+        torch_op = _allreduce_cache[cache_key]
         return torch_op(tensor, all_reduce_params=all_reduce_params)
 
     @torch.library.custom_op(
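For reference, a self-contained sketch of the caching pattern this hunk introduces, with a hypothetical ExpensiveCollectiveOp standing in for TensorRT-LLM's AllReduce so it runs without the library: the module is constructed once per (rank, world_size, dtype) and the same instance is reused on every call, and therefore on every CUDA graph replay.

import torch

_op_cache = {}


class ExpensiveCollectiveOp:
    """Hypothetical stand-in for a module whose construction allocates a workspace."""

    def __init__(self, rank: int, world_size: int, dtype: torch.dtype):
        self.rank, self.world_size, self.dtype = rank, world_size, dtype
        # Imagine a costly, CPU-synchronizing workspace allocation happening here.

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor  # no-op; the real module performs the collective


def cached_collective(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Key on everything that affects construction; build once, reuse on every call.
    key = (rank, world_size, tensor.dtype)
    if key not in _op_cache:
        _op_cache[key] = ExpensiveCollectiveOp(rank, world_size, tensor.dtype)
    return _op_cache[key](tensor)


# Usage: repeated calls with the same (rank, world_size, dtype) hit the cache.
x = torch.ones(4)
cached_collective(x, rank=0, world_size=1)
cached_collective(x, rank=0, world_size=1)
assert len(_op_cache) == 1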