[Kernel] Porting fuse_minimax_qk_norm to manual fusion (#43410)

Signed-off-by: Jee Jee Li <jeejeelee@inferact.ai>
This commit is contained in:
Jee Jee Li
2026-05-27 04:16:03 +08:00
committed by GitHub
parent 49b4882779
commit 6e503868ca
12 changed files with 262 additions and 490 deletions
-34
View File
@@ -21,7 +21,6 @@ or just on the low or high end.
| Fusion | `PassConfig` flag | Fused operations | Default at | E2E Speedup | Fullgraph | `num_tokens` |
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
| [MiniMax QK Norm](#minimax-qk-norm-fuse_minimax_qk_norm) | `fuse_minimax_qk_norm` | Q/K variance all-reduce → Q/K RMSNorm | Off by default | 2-3% | No | Low |
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
| [MLA Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | MLA Attention output → FP8/NVFP4 quant | Off by default | TBD | Yes | Always |
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
@@ -42,7 +41,6 @@ The table below lists the quantization schemes supported by each fusion on each
| Fusion | SM100 (Blackwell) | SM90 (Hopper) | SM89 (Ada) | SM80 (Ampere) | ROCm |
| ---------------------------- | ---------------------------------------- | ---------------------------------------- | ---------------------------------------- | ------------- | ---------------------------------------- |
| `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — |
| `fuse_minimax_qk_norm`\* | FP16/BF16 | FP16/BF16 | FP16/BF16 | FP16/BF16 | — |
| `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* |
| `fuse_attn_quant` (MLA)\* | FP8 static\*, FP8 per-group\*, NVFP4\* | FP8 static\*, FP8 per-group\* | FP8 static\*, FP8 per-group\* | — | FP8 static\* (untested) |
| `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 |
@@ -58,9 +56,6 @@ The table below lists the quantization schemes supported by each fusion on each
fused quantization output. See the [`fuse_attn_quant` section](#attention--quantization-fuse_attn_quant)
for per-backend details.
\* `fuse_minimax_qk_norm` is a model-specific pass for `MiniMaxM2ForCausalLM`. It also requires
tensor parallelism (`tp_size > 1`) and the CUDA custom op `minimax_allreduce_rms_qk`.
`enable_sp` and `fuse_gemm_comms` are only autoconfigured for SM90 today;
other architectures support requires setting `PassConfig.sp_min_token_num` explicitly.
SM100 support also requires setting `VLLM_DISABLED_KERNELS=FlashInferFP8ScaledMMLinearKernel`.
@@ -191,35 +186,6 @@ If these conditions are set, the fusion is enabled automatically for optimizatio
- Pass: [`vllm/compilation/passes/fusion/rope_kvcache_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rope_kvcache_fusion.py)
### MiniMax QK Norm (`fuse_minimax_qk_norm`)
!!! info
This is a MiniMax-specific compile pass. It is currently only enabled when all of the following hold:
the model architecture is `MiniMaxM2ForCausalLM`, tensor parallelism is enabled (`tp_size > 1`),
and the CUDA custom op `minimax_allreduce_rms_qk` is available. It is not enabled by default at any
optimization level.
**What it fuses.** Fuses the MiniMax M2 Q/K normalization path that performs an all-reduce over the
per-token Q/K variances before applying RMS normalization to Q and K.
This pass is distinct from [`enable_qk_norm_rope_fusion`](#qk-norm--rope-enable_qk_norm_rope_fusion):
`fuse_minimax_qk_norm` targets MiniMax M2's tensor-parallel all-reduce + RMSNorm sequence, while
`enable_qk_norm_rope_fusion` targets the later Q/K RMSNorm + RoPE sequence used by several other models.
Example:
```bash
vllm serve MiniMaxAI/MiniMax-M2.5 \
--tensor-parallel-size 4 \
--compilation-config '{"mode": 3, "pass_config": {"fuse_minimax_qk_norm": true}}'
```
**Code locations.**
- Pass: [`vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py)
- CUDA op: [`csrc/minimax_reduce_rms_kernel.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/minimax_reduce_rms_kernel.cu) (`minimax_allreduce_rms_qk`)
- Workspace helper: [`vllm/model_executor/layers/mamba/lamport_workspace.py`](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/lamport_workspace.py)
### Sequence Parallelism (`enable_sp`)
**What it fuses.** Replaces all-reduce collectives with reduce-scatter + local RMSNorm + all-gather,