mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Kernel] Porting fuse_minimax_qk_norm to manual fusion (#43410)
Signed-off-by: Jee Jee Li <jeejeelee@inferact.ai>
This commit is contained in:
@@ -21,7 +21,6 @@ or just on the low or high end.
|
||||
| Fusion | `PassConfig` flag | Fused operations | Default at | E2E Speedup | Fullgraph | `num_tokens` |
|
||||
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
|
||||
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
|
||||
| [MiniMax QK Norm](#minimax-qk-norm-fuse_minimax_qk_norm) | `fuse_minimax_qk_norm` | Q/K variance all-reduce → Q/K RMSNorm | Off by default | 2-3% | No | Low |
|
||||
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
|
||||
| [MLA Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | MLA Attention output → FP8/NVFP4 quant | Off by default | TBD | Yes | Always |
|
||||
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
|
||||
@@ -42,7 +41,6 @@ The table below lists the quantization schemes supported by each fusion on each
|
||||
| Fusion | SM100 (Blackwell) | SM90 (Hopper) | SM89 (Ada) | SM80 (Ampere) | ROCm |
|
||||
| ---------------------------- | ---------------------------------------- | ---------------------------------------- | ---------------------------------------- | ------------- | ---------------------------------------- |
|
||||
| `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — |
|
||||
| `fuse_minimax_qk_norm`\* | FP16/BF16 | FP16/BF16 | FP16/BF16 | FP16/BF16 | — |
|
||||
| `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* |
|
||||
| `fuse_attn_quant` (MLA)\* | FP8 static\*, FP8 per-group\*, NVFP4\* | FP8 static\*, FP8 per-group\* | FP8 static\*, FP8 per-group\* | — | FP8 static\* (untested) |
|
||||
| `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 |
|
||||
@@ -58,9 +56,6 @@ The table below lists the quantization schemes supported by each fusion on each
|
||||
fused quantization output. See the [`fuse_attn_quant` section](#attention--quantization-fuse_attn_quant)
|
||||
for per-backend details.
|
||||
|
||||
\* `fuse_minimax_qk_norm` is a model-specific pass for `MiniMaxM2ForCausalLM`. It also requires
|
||||
tensor parallelism (`tp_size > 1`) and the CUDA custom op `minimax_allreduce_rms_qk`.
|
||||
|
||||
† `enable_sp` and `fuse_gemm_comms` are only autoconfigured for SM90 today;
|
||||
other architectures support requires setting `PassConfig.sp_min_token_num` explicitly.
|
||||
SM100 support also requires setting `VLLM_DISABLED_KERNELS=FlashInferFP8ScaledMMLinearKernel`.
|
||||
@@ -191,35 +186,6 @@ If these conditions are set, the fusion is enabled automatically for optimizatio
|
||||
|
||||
- Pass: [`vllm/compilation/passes/fusion/rope_kvcache_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rope_kvcache_fusion.py)
|
||||
|
||||
### MiniMax QK Norm (`fuse_minimax_qk_norm`)
|
||||
|
||||
!!! info
|
||||
This is a MiniMax-specific compile pass. It is currently only enabled when all of the following hold:
|
||||
the model architecture is `MiniMaxM2ForCausalLM`, tensor parallelism is enabled (`tp_size > 1`),
|
||||
and the CUDA custom op `minimax_allreduce_rms_qk` is available. It is not enabled by default at any
|
||||
optimization level.
|
||||
|
||||
**What it fuses.** Fuses the MiniMax M2 Q/K normalization path that performs an all-reduce over the
|
||||
per-token Q/K variances before applying RMS normalization to Q and K.
|
||||
|
||||
This pass is distinct from [`enable_qk_norm_rope_fusion`](#qk-norm--rope-enable_qk_norm_rope_fusion):
|
||||
`fuse_minimax_qk_norm` targets MiniMax M2's tensor-parallel all-reduce + RMSNorm sequence, while
|
||||
`enable_qk_norm_rope_fusion` targets the later Q/K RMSNorm + RoPE sequence used by several other models.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
vllm serve MiniMaxAI/MiniMax-M2.5 \
|
||||
--tensor-parallel-size 4 \
|
||||
--compilation-config '{"mode": 3, "pass_config": {"fuse_minimax_qk_norm": true}}'
|
||||
```
|
||||
|
||||
**Code locations.**
|
||||
|
||||
- Pass: [`vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py)
|
||||
- CUDA op: [`csrc/minimax_reduce_rms_kernel.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/minimax_reduce_rms_kernel.cu) (`minimax_allreduce_rms_qk`)
|
||||
- Workspace helper: [`vllm/model_executor/layers/mamba/lamport_workspace.py`](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/lamport_workspace.py)
|
||||
|
||||
### Sequence Parallelism (`enable_sp`)
|
||||
|
||||
**What it fuses.** Replaces all-reduce collectives with reduce-scatter + local RMSNorm + all-gather,
|
||||
|
||||
Reference in New Issue
Block a user