Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5517023][fix] Pass allreduce strategy and force NCCL on pre-Blackwell arch (#7768)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
parent edbe270198
commit ab26d21620
@@ -11,7 +11,8 @@ from transformers.modeling_utils import load_sharded_checkpoint
 from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, MoEAllReduce)
+                                             AllReduceParams, AllReduceStrategy,
+                                             MoEAllReduce)
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._utils import get_sm_version
@@ -652,7 +653,12 @@ class LlamaDecoderLayer(DecoderLayer):
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype)
 
-        self.all_reduce = AllReduce(mapping=model_config.mapping)
+        # TODO: This is a temporary fix to disable oneshot kernel for pre-Blackwell arch to avoid perf regressions
+        self.all_reduce = AllReduce(
+            strategy=model_config.allreduce_strategy
+            if get_sm_version() >= 100 else AllReduceStrategy.NCCL,
+            mapping=model_config.mapping,
+        )
 
         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None
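For context, the change boils down to a single conditional on the device's SM version: the configured allreduce strategy is kept on Blackwell (SM 100 and newer), and NCCL is forced everywhere older. The following is a minimal standalone sketch of that selection logic; the AllReduceStrategy stand-in enum and the pick_allreduce_strategy helper are illustrative names only, not the real tensorrt_llm API, which inlines the conditional directly in LlamaDecoderLayer.__init__ as shown in the diff above.

from enum import Enum


class AllReduceStrategy(Enum):
    # Stand-in for tensorrt_llm's AllReduceStrategy; only the members
    # needed for this illustration are defined here.
    AUTO = 0
    NCCL = 1


def pick_allreduce_strategy(configured: AllReduceStrategy,
                            sm_version: int) -> AllReduceStrategy:
    # Keep the configured strategy on Blackwell (SM 100+); force NCCL on
    # older architectures, mirroring the gating introduced by this commit.
    return configured if sm_version >= 100 else AllReduceStrategy.NCCL


print(pick_allreduce_strategy(AllReduceStrategy.AUTO, 90))   # -> AllReduceStrategy.NCCL (e.g. Hopper)
print(pick_allreduce_strategy(AllReduceStrategy.AUTO, 100))  # -> AllReduceStrategy.AUTO (Blackwell)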