[https://nvbugs/5517023][fix] Pass allreduce strategy and force NCCL on pre-Blackwell arch (#7768)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
Yukun He 2025-09-17 13:28:43 +08:00 committed by Yanchao Lu
parent edbe270198
commit ab26d21620

View File

@ -11,7 +11,8 @@ from transformers.modeling_utils import load_sharded_checkpoint
from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams, MoEAllReduce)
AllReduceParams, AllReduceStrategy,
MoEAllReduce)
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper
from tensorrt_llm._utils import get_sm_version
@ -652,7 +653,12 @@ class LlamaDecoderLayer(DecoderLayer):
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.all_reduce = AllReduce(mapping=model_config.mapping)
# TODO: This is a temporary fix to disable oneshot kernel for pre-Blackwell arch to avoid perf regressions
self.all_reduce = AllReduce(
strategy=model_config.allreduce_strategy
if get_sm_version() >= 100 else AllReduceStrategy.NCCL,
mapping=model_config.mapping,
)
self.next_layer_layernorm: RMSNorm = None
self.next_attn: LlamaAttention = None