refactor: (part1) Add constraints doc for FusedMoE module. (#3882)

* Add doc string for FusedMoe module
* Address comments.

Signed-off-by: Hui Gao <huig@nvidia.com>
HuiGao-NV 2025-04-29 22:23:02 +08:00 committed by GitHub
parent 06e76020d7
commit 8e6eead6a5
2 changed files with 74 additions and 7 deletions


@@ -232,6 +232,45 @@ class FusedMoE(nn.Module):
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
enable_alltoall (bool): Whether to enable alltoall instead of allgather/reducescatter.
MoE torch custom op:
    cutlass backend:
        In min-latency mode:
            Quant:
                fp8 block scales (SM90 Hopper only):
                    FusedMoE Op: dynamic quant + gemm1 + swiglu + gemm2 (returns a tensor list).
                fp8 qdq, nvfp4:
                    FusedMoE Op: gemm1 + swiglu + gemm2 (returns a tensor list).
        In max-throughput mode:
            Quant:
                fp8 block scales (SM90 Hopper only):
                    FusedMoE Op: dynamic quant + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (returns one tensor)
                fp8 qdq, nvfp4:
                    FusedMoE Op: scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (returns one tensor)
    trtllm_gen backend:
        Only supports min-latency mode for now (SM100 Blackwell only).
        Quant: fp8 block scales quant and nvfp4 quant
        FusedMoE Op: routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute

FusedMoE module:
    cutlass backend (moe_backend="CUTLASS"):
        min-latency mode:
            routing(topK, etc.) + FusedMoE Op
            equivalent to: routing(topK, etc.) [+ dynamic quant for fp8 qdq | optional dynamic quant for nvfp4] + gemm1 + swiglu + gemm2
        max-throughput mode:
            routing(topK, etc.) [+ dynamic quant for fp8 qdq and nvfp4] [+ fp4_allgather] + FusedMoE Op [no allreduce] + reducescatter, with AttentionDP on
            equivalent to: dynamic quant + routing(topK, etc.) [+ fp4_allgather] + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute [no allreduce] + reducescatter
    trtllm_gen backend (moe_backend="TRTLLM"):
        min-latency mode (the min_latency_mode flag of forward has no effect when trtllm_gen is used):
            dynamic quant + FusedMoE Op
            equivalent to: dynamic quant + routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute

In min-latency mode, setting `reduce_results=False` disables the AllReduce in the FusedMoE module, so any necessary AllReduce operations must be added explicitly in the model definition.
AttentionDP should be turned off for min-latency mode.
"""
def __init__(
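
To make the pipelines listed in the docstring concrete, below is a minimal, unfused reference sketch of the CUTLASS max-throughput composition in plain PyTorch. All names and shapes here are hypothetical illustrations, not the module's API; the real FusedMoE Op fuses these steps into a single custom kernel and adds quantization and communication.

import torch
import torch.nn.functional as F

def reference_moe(x, router_logits, w1, w2, top_k=2):
    # Unfused reference for: routing(topK) + scatter + gemm1 + swiglu
    # + gemm2 + finalizeMoeRoute. Hypothetical weight layout:
    # w1: [num_experts, hidden, 2 * intermediate], w2: [num_experts, intermediate, hidden].
    scales, experts = torch.topk(F.softmax(router_logits, dim=-1), top_k, dim=-1)
    out = torch.zeros_like(x)
    for e in range(w1.shape[0]):
        token_idx, slot = (experts == e).nonzero(as_tuple=True)  # scatter
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w1[e]           # gemm1
        gate, up = h.chunk(2, dim=-1)
        h = F.silu(gate) * up              # swiglu
        h = h @ w2[e]                      # gemm2
        # finalizeMoeRoute: apply routing scales and gather results back.
        out.index_add_(0, token_idx, h * scales[token_idx, slot].unsqueeze(-1))
    return out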
@@ -332,6 +371,31 @@ class FusedMoE(nn.Module):
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input
self._check_configs()
def _check_configs(self):
    if self.enable_alltoall:
        assert self.use_dp and self.parallel_size > 1,\
            "alltoall should only be enabled with attention dp and parallel_size > 1"

    if self.is_trtllm():
        # The trtllm_gen backend only supports min-latency mode for now.
        assert not self.reduce_results
        assert self.quant_config and (
            self.quant_config.quant_mode.has_nvfp4()
            | self.quant_config.quant_mode.has_fp8_block_scales()
        ), "The TRTLLM backend of FusedMoE only supports fp8_block_scaling and nvfp4 dtypes."
    else:
        if self.apply_router_weight_on_input:
            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
        if self.quant_config and self.quant_config.quant_mode.has_any_quant(
                exclude_kv_cache=True):
            if not (self.quant_config.quant_mode.has_nvfp4()
                    | self.quant_config.quant_mode.has_fp8_block_scales()
                    | self.quant_config.quant_mode.has_fp8_qdq()):
                raise ValueError(
                    f"unsupported quantization mode: {self.quant_config.quant_mode}"
                )
def setup_quant_scales(self):
    self.quant_scales = None
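
The constraint set enforced by _check_configs above can be read as a small decision table. Here is a hedged, standalone distillation with plain flags (simplified names invented for illustration; it omits the top-1 routing workaround check):

def is_valid_fused_moe_config(*, backend, reduce_results, enable_alltoall,
                              use_dp, parallel_size, quant):
    # Hypothetical distillation of FusedMoE._check_configs; `quant` is a
    # simplified string stand-in for quant_config.quant_mode.
    if enable_alltoall and not (use_dp and parallel_size > 1):
        return False  # alltoall needs attention DP and parallel_size > 1
    if backend == "TRTLLM":
        # trtllm_gen only runs min-latency mode: no in-module reduce, and
        # only nvfp4 / fp8 block-scales quantization.
        return not reduce_results and quant in ("nvfp4", "fp8_block_scales")
    return quant in (None, "nvfp4", "fp8_block_scales", "fp8_qdq")

# Accepted: TRTLLM backend with nvfp4 and reduce_results disabled.
assert is_valid_fused_moe_config(backend="TRTLLM", reduce_results=False,
                                 enable_alltoall=False, use_dp=False,
                                 parallel_size=1, quant="nvfp4")
# Rejected: the TRTLLM backend cannot reduce results inside the module.
assert not is_valid_fused_moe_config(backend="TRTLLM", reduce_results=True,
                                     enable_alltoall=False, use_dp=False,
                                     parallel_size=1, quant="nvfp4")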
@@ -645,13 +709,11 @@ class FusedMoE(nn.Module):
assert token_selected_experts.dtype == torch.int32

if self.apply_router_weight_on_input:
    assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
    x = x * token_final_scales.to(x.dtype)
    # TODO: remove this once the correct fusedmoe kernel is ready
    token_final_scales = None

-token_count = x.fp4_tensor.shape[0] if isinstance(
-    x, Fp4QuantizedTensor) else x.shape[0]
+token_count = x.shape[0]

alltoall_info = None
@@ -772,6 +834,9 @@ class FusedMoE(nn.Module):
    output_dtype: Optional[torch.dtype] = None,
    all_rank_num_tokens: Optional[List[int]] = None,
) -> torch.Tensor:
    """
    min_latency_mode has no effect when the trtllm_gen backend is enabled.
    """
    if self.is_cutlass():
        return self.forward_cutlass(x, router_logits, min_latency_mode,
                                    output_dtype, all_rank_num_tokens)
@@ -797,10 +862,8 @@ class FusedMoE(nn.Module):
assert all_rank_num_tokens is not None
if not disable_fp4_allgather():
    max_chunk_size //= len(all_rank_num_tokens)

-if isinstance(x, Fp4QuantizedTensor):
-    num_rows = x.fp4_tensor.shape[0]
-else:
-    num_rows = x.shape[0]
+num_rows = x.shape[0]

num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size
if min_latency_mode:
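
For intuition, the chunk computation above is a ceiling division, and the per-rank chunk budget is divided by the rank count when FP4 allgather is active, presumably because allgathered chunks grow by that factor. A worked example with hypothetical numbers (not from the source):

max_chunk_size = 16384
all_rank_num_tokens = [6250] * 8                # 8 ranks, hypothetical token counts
max_chunk_size //= len(all_rank_num_tokens)     # 2048 once fp4_allgather is enabled
num_rows = 6250
num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size
assert num_chunks == 4                          # ceil(6250 / 2048)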


@@ -96,6 +96,10 @@ class Fp4QuantizedTensor:
fp4_tensor: torch.Tensor
scaling_factor: torch.Tensor

@property
def shape(self):
    return self.fp4_tensor.shape


_disable_fp4_allgather = os.getenv("TLLM_DISABLE_FP4_ALLGATHER", "0") == "1"
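
The new shape property is what lets the call sites earlier in this commit replace their isinstance branches with a plain x.shape[0]. A minimal sketch of the pattern (hypothetical helper name):

def leading_token_count(x) -> int:
    # Hypothetical helper: works uniformly for torch.Tensor and
    # Fp4QuantizedTensor now that both expose `.shape`.
    return x.shape[0]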