From 8e6eead6a5a7aa8791daf00425e74c2e1be3aa83 Mon Sep 17 00:00:00 2001
From: HuiGao-NV
Date: Tue, 29 Apr 2025 22:23:02 +0800
Subject: [PATCH] refactor: (part1) Add constraints doc for FusedMoE module.
 (#3882)

* Add doc string for FusedMoE module

* Address comments.

Signed-off-by: Hui Gao
---
 tensorrt_llm/_torch/modules/fused_moe.py | 77 +++++++++++++++++++++---
 tensorrt_llm/_torch/utils.py             |  4 ++
 2 files changed, 74 insertions(+), 7 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py
index 773764f616..d2cfe81cb3 100755
--- a/tensorrt_llm/_torch/modules/fused_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe.py
@@ -232,6 +232,45 @@ class FusedMoE(nn.Module):
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
         enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
+
+    MoE torch custom op:
+        cutlass backend:
+            In min-latency mode:
+                Quant:
+                    fp8 block scales (SM90 Hopper only):
+                        FusedMoE Op: dynamic quant + gemm1 + swiglu + gemm2 (returns a tensor list).
+                    fp8 qdq, nvfp4:
+                        FusedMoE Op: gemm1 + swiglu + gemm2 (returns a tensor list).
+
+            In max-throughput mode:
+                Quant:
+                    fp8 block scales (SM90 Hopper only):
+                        FusedMoE Op: dynamic quant + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (returns one tensor)
+                    fp8 qdq, nvfp4:
+                        FusedMoE Op: scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (returns one tensor)
+
+        trtllm_gen backend:
+            Only supports min-latency mode for now (SM100 Blackwell only).
+            Quant: fp8 block scales quant and nvfp4 quant
+                FusedMoE Op: routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute
+
+    FusedMoE module:
+        cutlass backend (moe_backend="CUTLASS"):
+            min-latency mode:
+                routing(topK, etc.) + FusedMoE Op
+                equivalent to: routing(topK, etc.) [+ dynamic quant for fp8 qdq | optional dynamic quant for nvfp4] + gemm1 + swiglu + gemm2
+
+            max-throughput mode:
+                routing(topK, etc.) [+ dynamic quant for fp8 qdq and nvfp4] [+ fp4_allgather] + FusedMoE Op [no allreduce] + reducescatter, with AttentionDP on
+                equivalent to: dynamic quant + routing(topK, etc.) [+ fp4_allgather] + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute [no allreduce] + reducescatter
+
+        trtllm_gen backend (moe_backend="TRTLLM"):
+            min-latency mode (the min_latency_mode flag of forward has no effect when trtllm_gen is used):
+                dynamic quant + FusedMoE Op
+                equivalent to: dynamic quant + routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute
+
+    In min-latency mode, setting `reduce_results=False` disables the AllReduce in the FusedMoE module, so any necessary AllReduce operations must be added explicitly in the model definition.
+    AttentionDP should be turned off for min-latency mode.
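+
+    Example (an illustrative sketch, not lifted from this patch: constructor
+    arguments other than reduce_results and model_config, such as num_experts
+    and hidden_size, are assumed names and may not match the real signature):
+
+        # Hypothetical min-latency setup on the cutlass backend. With
+        # reduce_results=False the module performs no AllReduce, so the
+        # model definition must insert one itself where needed.
+        moe = FusedMoE(
+            routing_method=routing_method,  # assumed to expose top_k
+            num_experts=num_experts,  # assumed argument name
+            hidden_size=hidden_size,  # assumed argument name
+            intermediate_size=intermediate_size,  # assumed argument name
+            reduce_results=False,
+            model_config=model_config,
+        )
+        output = moe(x, router_logits, min_latency_mode=True)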
""" def __init__( @@ -332,6 +371,31 @@ class FusedMoE(nn.Module): # If True, the router weight will be multiplied on the input rather than at the end of FC2 self.apply_router_weight_on_input = apply_router_weight_on_input + self._check_configs() + + def _check_configs(self): + if self.enable_alltoall: + assert self.use_dp and self.parallel_size > 1,\ + "alltoall should only enabled with attention dp and parallel_size > 1" + + if self.is_trtllm(): + # trtllm_gen backend only support min-latency mode now + assert not self.reduce_results + assert self.quant_config and ( + self.quant_config.quant_mode.has_nvfp4() + | self.quant_config.quant_mode.has_fp8_block_scales() + ), "The TRTLLM backend of FusedMoE only supports fp8_block_scaling and nvfp4 dtypes." + else: + if self.apply_router_weight_on_input: + assert self.routing_method.top_k == 1, "Current walkaround only supports top-1 routing" + if self.quant_config and self.quant_config.quant_mode.has_any_quant( + exclude_kv_cache=True): + if not (self.quant_config.quant_mode.has_nvfp4() + | self.quant_config.quant_mode.has_fp8_block_scales() + | self.quant_config.quant_mode.has_fp8_qdq()): + raise ValueError( + f"unsupported quantization mode: {self.quant_config.quant_mode}" + ) def setup_quant_scales(self): self.quant_scales = None @@ -645,13 +709,11 @@ class FusedMoE(nn.Module): assert token_selected_experts.dtype == torch.int32 if self.apply_router_weight_on_input: - assert self.routing_method.top_k == 1, "Current walkaround only supports top-1 routing" x = x * token_final_scales.to(x.dtype) # TODO: remove this once we have correct fusedmoe kernel ready token_final_scales = None - token_count = x.fp4_tensor.shape[0] if isinstance( - x, Fp4QuantizedTensor) else x.shape[0] + token_count = x.shape[0] alltoall_info = None @@ -772,6 +834,9 @@ class FusedMoE(nn.Module): output_dtype: Optional[torch.dtype] = None, all_rank_num_tokens: Optional[List[int]] = None, ) -> torch.Tensor: + """ + min_latency_mode has no effect when trtllm_gen backend is enabled. + """ if self.is_cutlass(): return self.forward_cutlass(x, router_logits, min_latency_mode, output_dtype, all_rank_num_tokens) @@ -797,10 +862,8 @@ class FusedMoE(nn.Module): assert all_rank_num_tokens is not None if not disable_fp4_allgather(): max_chunk_size //= len(all_rank_num_tokens) - if isinstance(x, Fp4QuantizedTensor): - num_rows = x.fp4_tensor.shape[0] - else: - num_rows = x.shape[0] + + num_rows = x.shape[0] num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size if min_latency_mode: diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index dca8067139..98c34f8637 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -96,6 +96,10 @@ class Fp4QuantizedTensor: fp4_tensor: torch.Tensor scaling_factor: torch.Tensor + @property + def shape(self): + return self.fp4_tensor.shape + _disable_fp4_allgather = os.getenv("TLLM_DISABLE_FP4_ALLGATHER", "0") == "1"