diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py index d2c6a8164c..d7e96a656e 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py @@ -59,7 +59,9 @@ class DeepEPLowLatency(Communication): self.moe_max_num_tokens = moe_max_num_tokens self.expert_size_per_partition = expert_size_per_partition - self.use_low_precision_combine = use_low_precision_combine + self.use_low_precision_combine = ( + use_low_precision_combine and self.supports_low_precision_combine() + ) # Read from environment variable, same as wideEP self.enable_postquant_alltoall = ( os.environ.get("TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1" @@ -96,7 +98,12 @@ class DeepEPLowLatency(Communication): """ if not self.enable_postquant_alltoall: return False + return self._has_nvfp4() or self._has_fp8_qdq() or self._has_w4afp8() + def supports_low_precision_combine(self) -> bool: + """ + DeepEP Low Latency supports low-precision combine for: fp8_qdq, nvfp4, w4afp8 + """ return self._has_nvfp4() or self._has_fp8_qdq() or self._has_w4afp8() def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool: