Signed-off-by: Tao Li
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Author: Tao Li @ NVIDIA, 2025-09-18 03:34:05 +08:00 (committed by Yanchao Lu)
parent 4a09be40f0
commit 44d7c3b245

@@ -11,8 +11,7 @@
 from transformers.modeling_utils import load_sharded_checkpoint
 from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, AllReduceStrategy,
-                                             MoEAllReduce)
+                                             AllReduceParams, MoEAllReduce)
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._utils import get_sm_version
@@ -650,12 +649,7 @@ class LlamaDecoderLayer(DecoderLayer):
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
-        # TODO: This is a temporary fix to disable oneshot kernel for pre-Blackwell arch to avoid perf regressions
-        self.all_reduce = AllReduce(
-            strategy=model_config.allreduce_strategy
-            if get_sm_version() >= 100 else AllReduceStrategy.NCCL,
-            mapping=model_config.mapping,
-        )
+        self.all_reduce = AllReduce(mapping=model_config.mapping)
         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None
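
For context, the deleted lines gated the all-reduce strategy on the GPU's SM version: pre-Blackwell parts (SM < 100) were pinned to NCCL so the oneshot fused kernel could not cause the perf regression noted in the removed TODO, while Blackwell and newer honored the configured strategy. Below is a minimal, self-contained sketch of that selection logic; AllReduceStrategy and get_sm_version are local stand-ins for the real tensorrt_llm symbols, not the library's API.

# Sketch of the architecture-gated strategy selection this commit removes.
# All names below are illustrative stand-ins, not tensorrt_llm imports.
from enum import Enum, auto


class AllReduceStrategy(Enum):
    AUTO = auto()  # let the runtime pick (e.g. fused oneshot kernels)
    NCCL = auto()  # plain NCCL all-reduce


def get_sm_version() -> int:
    # Stub for the real device query; 90 ~ Hopper, 100 ~ Blackwell
    # in CUDA SM numbering.
    return 90


def select_strategy(configured: AllReduceStrategy) -> AllReduceStrategy:
    # Pre-Blackwell (SM < 100): fall back to NCCL so the oneshot kernel
    # cannot regress performance; on SM >= 100, honor the configured value.
    if get_sm_version() >= 100:
        return configured
    return AllReduceStrategy.NCCL


if __name__ == "__main__":
    # On the stubbed SM 90 device this prints AllReduceStrategy.NCCL.
    print(select_strategy(AllReduceStrategy.AUTO))

After this commit, AllReduce is constructed from the mapping alone, so strategy selection is left entirely to AllReduce's own defaults rather than being overridden per architecture at the decoder-layer level.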