diff --git a/tensorrt_llm/_mnnvl_utils.py b/tensorrt_llm/_mnnvl_utils.py
index 7675d10531..dc60578872 100644
--- a/tensorrt_llm/_mnnvl_utils.py
+++ b/tensorrt_llm/_mnnvl_utils.py
@@ -599,6 +599,7 @@ class MnnvlMoe:
         top_k: int,
         token_count: int,
         use_low_precision_combine: bool = False,
+        do_reduce: bool = True,
     ):
         assert x.dim() == 2, "2D tensor supported, please reshape."
         output_tensors = torch.ops.trtllm.moe_comm(
@@ -614,7 +615,8 @@ class MnnvlMoe:
             [True],
             use_low_precision_combine,
         )
-        output_tensor = output_tensors[0]
-        return torch.sum(
-            output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False
-        )
+        output_tensor = output_tensors[0].reshape(token_count, top_k, x.shape[1])
+        if do_reduce:
+            return torch.sum(output_tensor, dim=1, keepdim=False)
+        else:
+            return output_tensor
diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index f34ff92ba9..0f886d0cd5 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -128,6 +128,12 @@ def weight_dequant(x: torch.Tensor,
     return y
 
 
+@torch.compile(dynamic=True)
+def moe_reduce_add_shared_output(routed_output, shared_output):
+    routed_output = torch.sum(routed_output, dim=1, keepdim=False)
+    return shared_output + routed_output
+
+
 class DeepseekV3MTPHead(nn.Module):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
@@ -585,6 +591,8 @@ class Deepseekv3MoE(nn.Module):
                 do_finalize)
             return routed_output
 
+        # NOTE: define compiled helpers at module scope to avoid defining decorators inside compiled frames
+
         routed_output, shared_output = maybe_execute_in_parallel(
             _compute_routed_output, _compute_shared_output,
             self.event_dict[EventType.Main],
@@ -593,9 +601,17 @@ class Deepseekv3MoE(nn.Module):
         if not do_finalize:
             return [shared_output, *routed_output]
         else:
-            assert shared_output.size() == routed_output.size(
-            ), f'unmatched tensor shape'
-            final_hidden_states = shared_output + routed_output
+            if routed_output.dim() == 3:
+                assert shared_output.numel(
+                ) * self.top_k == routed_output.numel(
+                ), 'unmatched tensor shape'
+                final_hidden_states = moe_reduce_add_shared_output(
+                    routed_output, shared_output)
+            else:
+                assert shared_output.size() == routed_output.size(
+                ), 'unmatched tensor shape'
+                final_hidden_states = shared_output + routed_output
+
         if not self.use_dp and self.mapping.tp_size > 1:
             final_hidden_states = self.allreduce(
                 final_hidden_states,
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index fb2cf56a81..5430141071 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -906,7 +906,8 @@ class WideEPMoE(MoE):
                 ep_size=self.ep_size,
                 top_k=top_k,
                 token_count=token_count,
-                use_low_precision_combine=self.use_low_precision_combine)
+                use_low_precision_combine=self.use_low_precision_combine,
+                do_reduce=False)
         return final_hidden_states