[#8781][fix] Cache the AllReduce wrapper to avoid re-allocating its workspace, which caused a hang (#8803)

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Eran Geva committed on 2025-11-02 15:30:39 +02:00 (via GitHub)
parent da73410d3b
commit f8778230e3

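For context before the diff: the change replaces a per-call construction of the AllReduce wrapper with a process-wide cache keyed by configuration. Below is a minimal, self-contained sketch of that caching pattern; the class _FakeAllReduce and function cached_allreduce are illustrative stand-ins, not code from this repository.

# Illustrative sketch only (not from this commit): cache an expensive, stateful
# wrapper keyed by its configuration so repeated calls reuse one instance
# instead of re-allocating its workspace each time.
from typing import Dict, Tuple

class _FakeAllReduce:
    """Stand-in for a wrapper whose construction allocates a communication
    workspace (an operation involving a CPU synchronization)."""

    def __init__(self, rank: int, world_size: int, dtype: str):
        self.config = (rank, world_size, dtype)  # pretend a costly allocation happens here

    def __call__(self, x):
        return x  # the real wrapper would run the collective

_cache: Dict[Tuple[int, int, str], _FakeAllReduce] = {}

def cached_allreduce(x, rank: int, world_size: int, dtype: str):
    key = (rank, world_size, dtype)
    if key not in _cache:                 # construct once per configuration...
        _cache[key] = _FakeAllReduce(rank, world_size, dtype)
    return _cache[key](x)                 # ...and reuse it on every call

# Repeated calls (e.g. during CUDA graph warmup) hit the cache:
cached_allreduce([1.0, 2.0], rank=0, world_size=2, dtype="float16")
cached_allreduce([3.0, 4.0], rank=0, world_size=2, dtype="float16")
assert len(_cache) == 1  # only one wrapper was ever constructed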

@@ -8,6 +8,11 @@ try:
     from ...distributed import AllReduce, allgather
     from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
 
+    # Cache AllReduce modules to avoid recreating them on every call.
+    # This is critical for CUDA graph compatibility: recreating modules during
+    # warmup causes hangs due to workspace allocation with CPU synchronization.
+    _allreduce_cache = {}
+
     def trtllm_allgather(tensor, dim, sizes=None):
         rank, world_size = get_rank_world_size()
         p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
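The comment in the hunk above refers to CUDA graph capture: any operation that synchronizes with the host while a graph is being captured (such as allocating a fresh workspace for a newly constructed AllReduce) stalls the capture, which is why the wrapper must already exist before warmup. A small, hypothetical illustration of that constraint, assuming a CUDA-capable machine and standard PyTorch APIs:

import torch

# Hypothetical illustration (not code from this repository) of the hazard the cache avoids.
if torch.cuda.is_available():
    x = torch.ones(8, device="cuda")

    # Eager warmup on a side stream, as recommended before graph capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        y = x * 2
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):        # stream capture begins here
        y = x * 2                    # pure device work is capturable and replayable
        # torch.cuda.synchronize()   # a host synchronization at this point (like a
        #                            # fresh workspace allocation) would stall capture
    g.replay()                       # re-runs the captured multiply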
@@ -16,9 +21,17 @@ try:
     def trtllm_allreduce(tensor, op, all_reduce_params=None):
         rank, world_size = get_rank_world_size()
         assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
-        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-        # Use Strategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then change to Strategy.AUTO
-        torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL)
+
+        # Cache key includes rank, world_size, and dtype to handle different configurations.
+        cache_key = (rank, world_size, tensor.dtype)
+        if cache_key not in _allreduce_cache:
+            p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+            # Use Strategy.AUTO for optimal performance.
+            _allreduce_cache[cache_key] = AllReduce(
+                mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype
+            )
+        torch_op = _allreduce_cache[cache_key]
+
         return torch_op(tensor, all_reduce_params=all_reduce_params)
 
     @torch.library.custom_op(