From d6e49542bd21a9fe5bb48d75108d6eae330974bf Mon Sep 17 00:00:00 2001
From: Leslie Fang
Date: Tue, 10 Feb 2026 20:09:00 +0800
Subject: [PATCH] [https://nvbugs/5848377][fix] fix DeepEP low latency with
 trtllm moe backend running fp8 DS_R1 (#11266)

Signed-off-by: leslie-fang25
Signed-off-by: Leslie Fang
Co-authored-by: Tailing Yuan
---
 .../modules/fused_moe/communication/deep_ep_low_latency.py | 4 ++++
 tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py  | 5 ++++-
 .../_torch/modules/fused_moe/fused_moe_trtllm_gen.py       | 4 ++++
 3 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py
index d7e96a656e..656f4957fe 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py
@@ -283,6 +283,10 @@ class DeepEPLowLatency(Communication):
             self.expert_size_per_partition, num_tokens_per_expert, self.hidden_size
         )
 
+        if deep_ep_topk_weights.dtype != torch.float32:
+            # DeepEP low latency combine requires fp32 weights
+            deep_ep_topk_weights = deep_ep_topk_weights.to(torch.float32)
+
         if self.use_low_precision_combine:
             if self._has_nvfp4():
                 precision = "nvfp4"
diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
index 1c47aa5afb..b99de2086d 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
@@ -756,7 +756,10 @@ class ConfigurableMoE(MoE):
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             # Use unified combine interface (reads dispatch state from strategy)
-            final_hidden_states = self.comm.combine(final_hidden_states)
+            all_rank_max_num_tokens = max(all_rank_num_tokens)
+            final_hidden_states = self.comm.combine(
+                final_hidden_states, all_rank_max_num_tokens=all_rank_max_num_tokens
+            )
         else:
             # For non-comm case, It should be attention TP or single rank.
             # only check if allreduce is needed
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
index 637c402f44..ea259cc162 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -531,6 +531,10 @@ class TRTLLMGenFusedMoE(MoE):
 
         routing_bias = routing_bias if router_logits is not None else None
 
+        if token_selected_experts is not None:
+            # For cases like DeepEP low latency where a fake top_k=1 might be used
+            top_k = token_selected_experts.shape[-1]
+
         # Ensure x_sf is 2D before flattening
         if x_sf is not None:
             assert len(
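
A minimal sketch of the dtype guard added in deep_ep_low_latency.py, assuming
fp32 is required only by DeepEP's low-latency combine path; the helper name
and tensor shapes below are illustrative, not part of the patch:

import torch

def ensure_fp32_topk_weights(topk_weights: torch.Tensor) -> torch.Tensor:
    # DeepEP's low-latency combine expects fp32 top-k weights, so cast
    # bf16/fp16 routing weights up front rather than failing in the kernel.
    if topk_weights.dtype != torch.float32:
        topk_weights = topk_weights.to(torch.float32)
    return topk_weights

# Routing weights often arrive in the activation dtype (e.g. bf16).
weights = torch.rand(8, 4, dtype=torch.bfloat16)  # [num_tokens, top_k]
assert ensure_fp32_topk_weights(weights).dtype == torch.float32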
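
A sketch of the combine-call change in configurable_moe.py, assuming
all_rank_num_tokens holds per-rank token counts gathered earlier in forward;
the Comm stand-in below is hypothetical and only mirrors the new keyword
argument:

from typing import List

class Comm:
    def combine(self, hidden_states, all_rank_max_num_tokens: int):
        # A real strategy would size its communication buffers to the
        # worst-case rank; this stub only shows the signature.
        return hidden_states

all_rank_num_tokens: List[int] = [3, 7, 5, 6]  # illustrative per-rank counts
all_rank_max_num_tokens = max(all_rank_num_tokens)
hidden = [[0.0]] * all_rank_num_tokens[0]
out = Comm().combine(hidden, all_rank_max_num_tokens=all_rank_max_num_tokens)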
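
A sketch of why the top_k override in fused_moe_trtllm_gen.py is needed,
assuming dispatch can hand the kernel a token_selected_experts tensor whose
last dimension disagrees with the configured top_k (DeepEP low latency uses a
fake top_k=1 layout); effective_top_k is a hypothetical helper:

import torch

def effective_top_k(token_selected_experts, configured_top_k: int) -> int:
    # Trust the tensor actually being routed, not the model config: after
    # DeepEP low-latency dispatch, each row may carry a single expert id.
    if token_selected_experts is not None:
        return token_selected_experts.shape[-1]
    return configured_top_k

selected = torch.zeros(16, 1, dtype=torch.int32)  # fake top_k=1 layout
assert effective_top_k(selected, configured_top_k=8) == 1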