Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
refactor: (part 1) Add constraints doc for FusedMoE module. (#3882)

* Add docstring for FusedMoE module
* Address comments.

Signed-off-by: Hui Gao <huig@nvidia.com>
parent 06e76020d7
commit 8e6eead6a5
@@ -232,6 +232,45 @@ class FusedMoE(nn.Module):
        reduce_results (bool): Whether to reduce the results across devices.
        model_config (ModelConfig): Configuration object for the model.
        enable_alltoall (bool): Whether to enable alltoall instead of allgather/reducescatter.

    MoE torch custom op:
        cutlass backend:
            In min-latency mode:
                Quant:
                    fp8 block scales (SM90 Hopper only):
                        FusedMoE Op: dynamic quant + gemm1 + swiglu + gemm2 (returns a tensor list).
                    fp8 qdq, nvfp4:
                        FusedMoE Op: gemm1 + swiglu + gemm2 (returns a tensor list).

            In max-throughput mode:
                Quant:
                    fp8 block scales (SM90 Hopper only):
                        FusedMoE Op: dynamic quant + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (returns one tensor)
                    fp8 qdq, nvfp4:
                        FusedMoE Op: scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (returns one tensor)

        trtllm_gen backend:
            Only supports min-latency mode for now (SM100 Blackwell only).
            Quant: fp8 block scales quant and nvfp4 quant
                FusedMoE Op: routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute

    FusedMoE module:
        cutlass backend (moe_backend="CUTLASS"):
            min-latency mode:
                routing(topK, etc.) + FusedMoE Op
                equivalent to: routing(topK, etc.) [+ dynamic quant for fp8 qdq | optional dynamic quant for nvfp4] + gemm1 + swiglu + gemm2

            max-throughput mode:
                routing(topK, etc.) [+ dynamic quant for fp8 qdq and nvfp4] [+ fp4_allgather] + FusedMoE Op [no allreduce] + reducescatter, with AttentionDP on
                equivalent to: dynamic quant + routing(topK, etc.) [+ fp4_allgather] + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute [no allreduce] + reducescatter

        trtllm_gen backend (moe_backend="TRTLLM"):
            min-latency mode (the min_latency_mode flag of forward has no effect when trtllm_gen is used):
                dynamic quant + FusedMoE Op
                equivalent to: dynamic quant + routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute

    In min-latency mode, setting `reduce_results=False` disables the AllReduce in the FusedMoE module, so any necessary AllReduce operations must be added explicitly in the model definition.
    AttentionDP should be turned off for min-latency mode.
    """

    def __init__(
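The docstring above notes that, in min-latency mode with `reduce_results=False`, the module skips its internal AllReduce and the model definition has to reduce the partial expert outputs itself. Below is a minimal sketch of that pattern; the surrounding `MoEBlock` class is hypothetical, and `torch.distributed.all_reduce` stands in for TensorRT-LLM's own AllReduce op.

import torch
import torch.distributed as dist

class MoEBlock(torch.nn.Module):
    """Hypothetical decoder block that owns the cross-rank reduction explicitly."""

    def __init__(self, fused_moe: torch.nn.Module, tp_group: dist.ProcessGroup):
        super().__init__()
        # fused_moe is assumed to be a FusedMoE constructed elsewhere with
        # reduce_results=False, so it returns per-rank partial sums.
        self.fused_moe = fused_moe
        self.tp_group = tp_group

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        partial = self.fused_moe(hidden_states, router_logits)
        # The AllReduce that FusedMoE would normally perform internally is now
        # the model's responsibility: sum the partial results across TP ranks.
        dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=self.tp_group)
        return partial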
@@ -332,6 +371,31 @@ class FusedMoE(nn.Module):

        # If True, the router weight will be multiplied on the input rather than at the end of FC2
        self.apply_router_weight_on_input = apply_router_weight_on_input
        self._check_configs()

    def _check_configs(self):
        if self.enable_alltoall:
            assert self.use_dp and self.parallel_size > 1, \
                "alltoall should only be enabled with attention DP and parallel_size > 1"

        if self.is_trtllm():
            # The trtllm_gen backend only supports min-latency mode for now.
            assert not self.reduce_results
            assert self.quant_config and (
                self.quant_config.quant_mode.has_nvfp4()
                | self.quant_config.quant_mode.has_fp8_block_scales()
            ), "The TRTLLM backend of FusedMoE only supports fp8_block_scaling and nvfp4 dtypes."
        else:
            if self.apply_router_weight_on_input:
                assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
            if self.quant_config and self.quant_config.quant_mode.has_any_quant(
                    exclude_kv_cache=True):
                if not (self.quant_config.quant_mode.has_nvfp4()
                        | self.quant_config.quant_mode.has_fp8_block_scales()
                        | self.quant_config.quant_mode.has_fp8_qdq()):
                    raise ValueError(
                        f"unsupported quantization mode: {self.quant_config.quant_mode}"
                    )

    def setup_quant_scales(self):
        self.quant_scales = None
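The compatibility rules that `_check_configs` enforces can be summarized as a small standalone check. The helper below is hypothetical and only mirrors what the asserts above encode (TRTLLM: nvfp4 or fp8 block scales only; CUTLASS: nvfp4, fp8 block scales, or fp8 qdq); it is not part of the module.

# Hypothetical helper mirroring the quantization constraints in _check_configs.
def is_quant_supported(moe_backend: str, quant_mode: str) -> bool:
    """Return True if the (backend, quant mode) pair would pass the checks above."""
    trtllm_supported = {"nvfp4", "fp8_block_scales"}
    cutlass_supported = {"nvfp4", "fp8_block_scales", "fp8_qdq"}
    if moe_backend == "TRTLLM":
        return quant_mode in trtllm_supported
    return quant_mode in cutlass_supported

assert is_quant_supported("TRTLLM", "nvfp4")
assert not is_quant_supported("TRTLLM", "fp8_qdq")   # rejected by the TRTLLM assert
assert is_quant_supported("CUTLASS", "fp8_qdq")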
@@ -645,13 +709,11 @@ class FusedMoE(nn.Module):
        assert token_selected_experts.dtype == torch.int32

        if self.apply_router_weight_on_input:
            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
            x = x * token_final_scales.to(x.dtype)
            # TODO: remove this once we have correct fusedmoe kernel ready
            token_final_scales = None

-       token_count = x.fp4_tensor.shape[0] if isinstance(
-           x, Fp4QuantizedTensor) else x.shape[0]
+       token_count = x.shape[0]

        alltoall_info = None
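The block above folds the routing weight into the activations before the expert GEMMs and then clears `token_final_scales` so the kernel does not apply it a second time; the assert restricts this to top-1 routing. A toy sketch of the broadcasting involved, with illustrative shapes that are not taken from the diff:

import torch

num_tokens, hidden = 4, 8
x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16)

# top_k == 1: exactly one routing weight per token, shape (num_tokens, 1),
# so it broadcasts cleanly over the hidden dimension.
token_final_scales = torch.rand(num_tokens, 1)

x = x * token_final_scales.to(x.dtype)   # scale folded into the input
token_final_scales = None                # prevent the kernel from applying it again

# With top_k > 1 there would be several scales per token (num_tokens, top_k),
# one per selected expert, and they could no longer be folded into the single
# shared input tensor; that is why the assert limits this path to top-1.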
@@ -772,6 +834,9 @@ class FusedMoE(nn.Module):
        output_dtype: Optional[torch.dtype] = None,
        all_rank_num_tokens: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """
        min_latency_mode has no effect when the trtllm_gen backend is enabled.
        """
        if self.is_cutlass():
            return self.forward_cutlass(x, router_logits, min_latency_mode,
                                        output_dtype, all_rank_num_tokens)
@@ -797,10 +862,8 @@ class FusedMoE(nn.Module):
            assert all_rank_num_tokens is not None
            if not disable_fp4_allgather():
                max_chunk_size //= len(all_rank_num_tokens)
-       if isinstance(x, Fp4QuantizedTensor):
-           num_rows = x.fp4_tensor.shape[0]
-       else:
-           num_rows = x.shape[0]
+       num_rows = x.shape[0]

        num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size

        if min_latency_mode:
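The chunking above is plain ceiling division: the input rows are cut into `num_chunks` pieces of at most `max_chunk_size` rows each, and when the FP4 allgather path is active the per-rank budget is first divided by the number of ranks so the gathered tensor stays within the same limit. A small sketch of that arithmetic; the values and the initial `max_chunk_size` are made up for illustration, not the module's real defaults.

import torch

max_chunk_size = 16 * 1024                         # illustrative row budget
all_rank_num_tokens = [3000, 2800, 3100, 2900]     # e.g. 4 attention-DP ranks

# When FP4 allgather is on, every rank's chunk is gathered by all ranks,
# so the per-rank budget is divided by the number of ranks.
max_chunk_size //= len(all_rank_num_tokens)

x = torch.randn(9000, 64)                          # local tokens on this rank
num_rows = x.shape[0]
num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size   # ceil division

chunks = torch.chunk(x, num_chunks, dim=0)         # every chunk <= max_chunk_size rows
assert sum(c.shape[0] for c in chunks) == num_rows
print(num_chunks, [c.shape[0] for c in chunks])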
@@ -96,6 +96,10 @@ class Fp4QuantizedTensor:
    fp4_tensor: torch.Tensor
    scaling_factor: torch.Tensor

+   @property
+   def shape(self):
+       return self.fp4_tensor.shape


_disable_fp4_allgather = os.getenv("TLLM_DISABLE_FP4_ALLGATHER", "0") == "1"
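The `shape` property added above is what lets the call sites earlier in the diff collapse their `isinstance(x, Fp4QuantizedTensor)` branches into a plain `x.shape[0]`: both `torch.Tensor` and the wrapper now answer `.shape` the same way. A reduced sketch of the pattern; the dataclass below is a simplified stand-in, not the full class from the repository.

from dataclasses import dataclass

import torch


@dataclass
class Fp4QuantizedTensor:
    """Simplified stand-in: packed FP4 payload plus its scaling factors."""
    fp4_tensor: torch.Tensor
    scaling_factor: torch.Tensor

    @property
    def shape(self):
        # Delegate to the packed payload so callers can treat this wrapper
        # like a regular tensor when they only need its dimensions.
        return self.fp4_tensor.shape


def row_count(x) -> int:
    # Works for both torch.Tensor and Fp4QuantizedTensor, which is exactly
    # the simplification applied to token_count / num_rows in this commit.
    return x.shape[0]


plain = torch.randn(6, 4)
quant = Fp4QuantizedTensor(torch.zeros(6, 2, dtype=torch.uint8),
                           torch.ones(6, 1))
assert row_count(plain) == row_count(quant) == 6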