mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Kernel] Porting fuse_minimax_qk_norm to manual fusion (#43410)
Signed-off-by: Jee Jee Li <jeejeelee@inferact.ai>
This commit is contained in:
@@ -21,7 +21,6 @@ or just on the low or high end.
|
||||
| Fusion | `PassConfig` flag | Fused operations | Default at | E2E Speedup | Fullgraph | `num_tokens` |
|
||||
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
|
||||
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
|
||||
| [MiniMax QK Norm](#minimax-qk-norm-fuse_minimax_qk_norm) | `fuse_minimax_qk_norm` | Q/K variance all-reduce → Q/K RMSNorm | Off by default | 2-3% | No | Low |
|
||||
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
|
||||
| [MLA Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | MLA Attention output → FP8/NVFP4 quant | Off by default | TBD | Yes | Always |
|
||||
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
|
||||
@@ -42,7 +41,6 @@ The table below lists the quantization schemes supported by each fusion on each
|
||||
| Fusion | SM100 (Blackwell) | SM90 (Hopper) | SM89 (Ada) | SM80 (Ampere) | ROCm |
|
||||
| ---------------------------- | ---------------------------------------- | ---------------------------------------- | ---------------------------------------- | ------------- | ---------------------------------------- |
|
||||
| `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — |
|
||||
| `fuse_minimax_qk_norm`\* | FP16/BF16 | FP16/BF16 | FP16/BF16 | FP16/BF16 | — |
|
||||
| `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* |
|
||||
| `fuse_attn_quant` (MLA)\* | FP8 static\*, FP8 per-group\*, NVFP4\* | FP8 static\*, FP8 per-group\* | FP8 static\*, FP8 per-group\* | — | FP8 static\* (untested) |
|
||||
| `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 |
|
||||
@@ -58,9 +56,6 @@ The table below lists the quantization schemes supported by each fusion on each
|
||||
fused quantization output. See the [`fuse_attn_quant` section](#attention--quantization-fuse_attn_quant)
|
||||
for per-backend details.
|
||||
|
||||
\* `fuse_minimax_qk_norm` is a model-specific pass for `MiniMaxM2ForCausalLM`. It also requires
|
||||
tensor parallelism (`tp_size > 1`) and the CUDA custom op `minimax_allreduce_rms_qk`.
|
||||
|
||||
† `enable_sp` and `fuse_gemm_comms` are only autoconfigured for SM90 today;
|
||||
other architectures support requires setting `PassConfig.sp_min_token_num` explicitly.
|
||||
SM100 support also requires setting `VLLM_DISABLED_KERNELS=FlashInferFP8ScaledMMLinearKernel`.
|
||||
@@ -191,35 +186,6 @@ If these conditions are set, the fusion is enabled automatically for optimizatio
|
||||
|
||||
- Pass: [`vllm/compilation/passes/fusion/rope_kvcache_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rope_kvcache_fusion.py)
|
||||
|
||||
### MiniMax QK Norm (`fuse_minimax_qk_norm`)
|
||||
|
||||
!!! info
|
||||
This is a MiniMax-specific compile pass. It is currently only enabled when all of the following hold:
|
||||
the model architecture is `MiniMaxM2ForCausalLM`, tensor parallelism is enabled (`tp_size > 1`),
|
||||
and the CUDA custom op `minimax_allreduce_rms_qk` is available. It is not enabled by default at any
|
||||
optimization level.
|
||||
|
||||
**What it fuses.** Fuses the MiniMax M2 Q/K normalization path that performs an all-reduce over the
|
||||
per-token Q/K variances before applying RMS normalization to Q and K.
|
||||
|
||||
This pass is distinct from [`enable_qk_norm_rope_fusion`](#qk-norm--rope-enable_qk_norm_rope_fusion):
|
||||
`fuse_minimax_qk_norm` targets MiniMax M2's tensor-parallel all-reduce + RMSNorm sequence, while
|
||||
`enable_qk_norm_rope_fusion` targets the later Q/K RMSNorm + RoPE sequence used by several other models.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
vllm serve MiniMaxAI/MiniMax-M2.5 \
|
||||
--tensor-parallel-size 4 \
|
||||
--compilation-config '{"mode": 3, "pass_config": {"fuse_minimax_qk_norm": true}}'
|
||||
```
|
||||
|
||||
**Code locations.**
|
||||
|
||||
- Pass: [`vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py)
|
||||
- CUDA op: [`csrc/minimax_reduce_rms_kernel.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/minimax_reduce_rms_kernel.cu) (`minimax_allreduce_rms_qk`)
|
||||
- Workspace helper: [`vllm/model_executor/layers/mamba/lamport_workspace.py`](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/lamport_workspace.py)
|
||||
|
||||
### Sequence Parallelism (`enable_sp`)
|
||||
|
||||
**What it fuses.** Replaces all-reduce collectives with reduce-scatter + local RMSNorm + all-gather,
|
||||
|
||||
@@ -10,7 +10,7 @@ from torch.multiprocessing import spawn
|
||||
from tests.kernels.utils import opcheck
|
||||
from tests.utils import ensure_current_vllm_config, init_test_distributed_environment
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
|
||||
from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
@@ -59,7 +59,7 @@ def _worker_forward_qk(
|
||||
|
||||
# Set up Lamport workspace.
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.model_executor.layers.mamba.lamport_workspace import (
|
||||
from vllm.model_executor.layers.minimax_rms_norm.lamport_workspace import (
|
||||
get_allreduce_workspace,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,340 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Fusion pass: replace MiniMax QK allreduce + RMS norm with the Lamport
|
||||
fused kernel (minimax_allreduce_rms_qk) for decode-size batches.
|
||||
|
||||
Pattern (inlined forward_qk in compiled graph):
|
||||
q, k, v = qkv.split([q_size, kv_size, kv_size], -1)
|
||||
q_fp32 = q.to(float32); k_fp32 = k.to(float32)
|
||||
q_var = q_fp32.pow(2).mean(-1, keepdim=True)
|
||||
k_var = k_fp32.pow(2).mean(-1, keepdim=True)
|
||||
qk_var = cat([q_var, k_var], -1)
|
||||
qk_var = allreduce(qk_var) / tp_world
|
||||
q_var, k_var = qk_var.chunk(2, -1)
|
||||
q_out = (q_fp32 * rsqrt(q_var + eps) * q_weight).to(orig_dtype)
|
||||
k_out = (k_fp32 * rsqrt(k_var + eps) * k_weight).to(orig_dtype)
|
||||
return q_out, k_out, v
|
||||
|
||||
Replacement (pure, no in-place on qkv/q/k):
|
||||
q_out, k_out = minimax_qk_norm_fused(qkv, q_weight, k_weight, workspace, ...)
|
||||
v = qkv.split([q_size, kv_size, kv_size], -1)[2]
|
||||
return q_out, k_out, v
|
||||
|
||||
is_applicable_for_range: only fires for compile_range.end <= max_decode_tokens
|
||||
so that large prefill batches fall through to the original forward_qk (= main).
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MAX_TOKEN_NUM = 2048
|
||||
|
||||
_MINIMAX_QK_NORM_FUSED_OP = None
|
||||
if hasattr(torch.ops._C, "minimax_allreduce_rms_qk"):
|
||||
|
||||
def _minimax_qk_norm_fused(
|
||||
qkv: torch.Tensor,
|
||||
norm_weight_q: torch.Tensor,
|
||||
norm_weight_k: torch.Tensor,
|
||||
q_size: int,
|
||||
kv_size: int,
|
||||
rank: int,
|
||||
nranks: int,
|
||||
eps: float,
|
||||
max_tokens: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.model_executor.layers.mamba.lamport_workspace import (
|
||||
get_allreduce_workspace,
|
||||
)
|
||||
|
||||
workspace = get_allreduce_workspace(
|
||||
rank=rank,
|
||||
world_size=nranks,
|
||||
max_tokens=max_tokens,
|
||||
process_group=get_tp_group().cpu_group,
|
||||
)
|
||||
return torch.ops._C.minimax_allreduce_rms_qk(
|
||||
qkv,
|
||||
norm_weight_q,
|
||||
norm_weight_k,
|
||||
workspace,
|
||||
q_size,
|
||||
kv_size,
|
||||
rank,
|
||||
nranks,
|
||||
eps,
|
||||
)
|
||||
|
||||
def _minimax_qk_norm_fused_fake(
|
||||
qkv: torch.Tensor,
|
||||
norm_weight_q: torch.Tensor,
|
||||
norm_weight_k: torch.Tensor,
|
||||
q_size: int,
|
||||
kv_size: int,
|
||||
rank: int,
|
||||
nranks: int,
|
||||
eps: float,
|
||||
max_tokens: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
T = qkv.shape[0]
|
||||
return (
|
||||
torch.empty([T, q_size], dtype=qkv.dtype, device=qkv.device),
|
||||
torch.empty([T, kv_size], dtype=qkv.dtype, device=qkv.device),
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="minimax_qk_norm_fused",
|
||||
op_func=_minimax_qk_norm_fused,
|
||||
fake_impl=_minimax_qk_norm_fused_fake,
|
||||
mutates_args=[],
|
||||
)
|
||||
_MINIMAX_QK_NORM_FUSED_OP = torch.ops.vllm.minimax_qk_norm_fused.default
|
||||
|
||||
|
||||
class MiniMaxQKNormPattern:
|
||||
"""
|
||||
Match the forward_qk allreduce+rms pattern and replace with Lamport kernel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
q_size: int,
|
||||
kv_size: int,
|
||||
eps: float,
|
||||
tp_world: int,
|
||||
tp_rank: int,
|
||||
max_tokens: int,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
) -> None:
|
||||
self.q_size = q_size
|
||||
self.kv_size = kv_size
|
||||
self.eps = eps
|
||||
self.tp_world = tp_world
|
||||
self.tp_rank = tp_rank
|
||||
self.max_tokens = max_tokens
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
T = 4
|
||||
qkv = torch.empty(
|
||||
[T, self.q_size + 2 * self.kv_size],
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
q_weight = torch.empty([self.q_size], device=self.device, dtype=self.dtype)
|
||||
k_weight = torch.empty([self.kv_size], device=self.device, dtype=self.dtype)
|
||||
return [qkv, q_weight, k_weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
q_size = self.q_size
|
||||
kv_size = self.kv_size
|
||||
eps = self.eps
|
||||
tp_world = self.tp_world
|
||||
max_tokens = self.max_tokens
|
||||
tp_rank = self.tp_rank
|
||||
dtype = self.dtype
|
||||
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
q_fp32 = q.to(torch.float32)
|
||||
k_fp32 = k.to(torch.float32)
|
||||
q_var = q_fp32.pow(2).mean(dim=-1, keepdim=True)
|
||||
k_var = k_fp32.pow(2).mean(dim=-1, keepdim=True)
|
||||
qk_var = torch.cat([q_var, k_var], dim=-1)
|
||||
qk_var = tensor_model_parallel_all_reduce(qk_var) / tp_world
|
||||
q_var, k_var = qk_var.chunk(2, dim=-1)
|
||||
q_out = (q_fp32 * torch.rsqrt(q_var + eps) * q_weight).to(dtype)
|
||||
k_out = (k_fp32 * torch.rsqrt(k_var + eps) * k_weight).to(dtype)
|
||||
return q_out, k_out, v
|
||||
|
||||
def replacement(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert _MINIMAX_QK_NORM_FUSED_OP is not None
|
||||
q_out, k_out = torch.ops.vllm.minimax_qk_norm_fused(
|
||||
qkv,
|
||||
q_weight,
|
||||
k_weight,
|
||||
q_size,
|
||||
kv_size,
|
||||
tp_rank,
|
||||
tp_world,
|
||||
eps,
|
||||
max_tokens,
|
||||
)
|
||||
_, _, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
return q_out, k_out, v
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
# Second pattern: three separate split_with_sizes nodes (one per output),
|
||||
# each with _users=1. This occurs when the QKV projection uses a
|
||||
# functional GEMM kernel (e.g. cutlass_scaled_mm via auto_functionalized),
|
||||
# which causes inductor to generate one split per consumer.
|
||||
def pattern_split3(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = qkv.split([q_size, kv_size, kv_size], dim=-1)[0]
|
||||
k = qkv.split([q_size, kv_size, kv_size], dim=-1)[1]
|
||||
v = qkv.split([q_size, kv_size, kv_size], dim=-1)[2]
|
||||
q_fp32 = q.to(torch.float32)
|
||||
k_fp32 = k.to(torch.float32)
|
||||
q_var = q_fp32.pow(2).mean(dim=-1, keepdim=True)
|
||||
k_var = k_fp32.pow(2).mean(dim=-1, keepdim=True)
|
||||
qk_var = torch.cat([q_var, k_var], dim=-1)
|
||||
qk_var = tensor_model_parallel_all_reduce(qk_var) / tp_world
|
||||
q_var, k_var = qk_var.chunk(2, dim=-1)
|
||||
q_out = (q_fp32 * torch.rsqrt(q_var + eps) * q_weight).to(dtype)
|
||||
k_out = (k_fp32 * torch.rsqrt(k_var + eps) * k_weight).to(dtype)
|
||||
return q_out, k_out, v
|
||||
|
||||
pm.register_replacement(
|
||||
pattern_split3, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class MiniMaxQKNormPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
Replace forward_qk allreduce+norm with the Lamport fused kernel.
|
||||
Only applied for decode-size compile ranges (small token counts).
|
||||
"""
|
||||
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.disabled = True
|
||||
|
||||
if _MINIMAX_QK_NORM_FUSED_OP is None:
|
||||
logger.warning_once(
|
||||
"minimax_allreduce_rms_qk op not found, MiniMaxQKNormPass disabled."
|
||||
)
|
||||
return
|
||||
|
||||
tp_world = get_tensor_model_parallel_world_size()
|
||||
if tp_world <= 1:
|
||||
logger.warning_once("MiniMaxQKNormPass disabled: tp_size <= 1.")
|
||||
return
|
||||
|
||||
if config.model_config is None:
|
||||
logger.warning_once("MiniMaxQKNormPass disabled: no model_config.")
|
||||
return
|
||||
|
||||
hf_cfg = config.model_config.hf_config
|
||||
|
||||
model_name = getattr(hf_cfg, "architectures", "")[0]
|
||||
if model_name != "MiniMaxM2ForCausalLM":
|
||||
return
|
||||
|
||||
num_attention_heads = getattr(hf_cfg, "num_attention_heads", 0)
|
||||
num_key_value_heads = getattr(hf_cfg, "num_key_value_heads", 0)
|
||||
hidden_size = getattr(hf_cfg, "hidden_size", 0)
|
||||
head_dim = getattr(hf_cfg, "head_dim", 0)
|
||||
eps: float = getattr(hf_cfg, "rms_norm_eps", 1e-6)
|
||||
|
||||
if (
|
||||
num_attention_heads != 48
|
||||
or num_key_value_heads != 8
|
||||
or hidden_size != 3072
|
||||
or head_dim != 128
|
||||
):
|
||||
logger.warning_once(
|
||||
"MiniMaxQKNormPass disabled: cannot infer model info from hf_config."
|
||||
)
|
||||
return
|
||||
|
||||
num_heads_per_rank = num_attention_heads // tp_world
|
||||
num_kv_heads_per_rank = max(1, num_key_value_heads // tp_world)
|
||||
q_size = num_heads_per_rank * head_dim
|
||||
kv_size = num_kv_heads_per_rank * head_dim
|
||||
|
||||
self.max_token_num = min(
|
||||
MAX_TOKEN_NUM, config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
# Allocate Lamport workspace first.
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.model_executor.layers.mamba.lamport_workspace import (
|
||||
get_allreduce_workspace,
|
||||
)
|
||||
|
||||
get_allreduce_workspace(
|
||||
rank=tp_rank,
|
||||
world_size=tp_world,
|
||||
max_tokens=self.max_token_num,
|
||||
process_group=get_tp_group().cpu_group,
|
||||
)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="minimax_qk_norm_pass"
|
||||
)
|
||||
self._register_patterns(q_size, kv_size, eps, tp_world, tp_rank)
|
||||
self.dump_patterns(config, self.patterns)
|
||||
self.disabled = False
|
||||
|
||||
@enable_fake_mode
|
||||
def _register_patterns(
|
||||
self,
|
||||
q_size: int,
|
||||
kv_size: int,
|
||||
eps: float,
|
||||
tp_world: int,
|
||||
tp_rank: int,
|
||||
) -> None:
|
||||
MiniMaxQKNormPattern(
|
||||
q_size=q_size,
|
||||
kv_size=kv_size,
|
||||
eps=eps,
|
||||
tp_world=tp_world,
|
||||
tp_rank=tp_rank,
|
||||
max_tokens=self.max_token_num,
|
||||
dtype=self.model_dtype,
|
||||
device=self.device,
|
||||
).register(self.patterns)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
if self.disabled:
|
||||
return False
|
||||
|
||||
return bool(compile_range.end <= self.max_token_num)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
if self.disabled:
|
||||
return
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("MiniMaxQKNormPass replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> str:
|
||||
return VllmInductorPass.hash_source(self, MiniMaxQKNormPattern)
|
||||
@@ -44,7 +44,6 @@ if current_platform.is_cuda_alike():
|
||||
if current_platform.is_cuda():
|
||||
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||
from .fusion.collective_fusion import AsyncTPPass
|
||||
from .fusion.minimax_qk_norm_fusion import MiniMaxQKNormPass
|
||||
|
||||
from .inductor_pass import (
|
||||
CustomGraphPass,
|
||||
@@ -154,9 +153,6 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
else:
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_minimax_qk_norm:
|
||||
self.passes += [MiniMaxQKNormPass(config)]
|
||||
|
||||
if self.pass_config.fuse_norm_quant:
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
self.passes += [
|
||||
|
||||
@@ -135,7 +135,9 @@ class PassConfig:
|
||||
fuse_allreduce_rms: bool = None # type: ignore[assignment]
|
||||
"""Enable flashinfer allreduce fusion."""
|
||||
fuse_minimax_qk_norm: bool = None # type: ignore[assignment]
|
||||
"""Enable fused allreduce+RMSNorm for MiniMax QK norm."""
|
||||
"""Deprecated. The MiniMax QK norm fusion is now applied automatically at
|
||||
runtime (see `MiniMaxText01RMSNormTP.forward_qkv`). This flag is kept for
|
||||
backward compatibility and has no effect; it will be removed in v0.23."""
|
||||
enable_qk_norm_rope_fusion: bool = None # type: ignore[assignment]
|
||||
"""Enable fused Q/K RMSNorm + RoPE pass."""
|
||||
fuse_rope_kvcache_cat_mla: bool = None # type: ignore[assignment]
|
||||
@@ -294,6 +296,13 @@ class PassConfig:
|
||||
"current platform is not CUDA or ROCm. The fusion will be disabled."
|
||||
)
|
||||
self.fuse_rope_kvcache_cat_mla = False
|
||||
if self.fuse_minimax_qk_norm is not None:
|
||||
logger.warning_once(
|
||||
"`fuse_minimax_qk_norm` is deprecated and has no effect; "
|
||||
"the MiniMax QK norm fusion is now applied automatically at "
|
||||
"runtime when its conditions are met. This flag will be "
|
||||
"removed in v0.23."
|
||||
)
|
||||
|
||||
def log_enabled_passes(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -1828,22 +1828,6 @@ class VllmConfig:
|
||||
compile_range_end,
|
||||
)
|
||||
|
||||
if compilation_config.pass_config.fuse_minimax_qk_norm:
|
||||
from vllm.compilation.passes.fusion.minimax_qk_norm_fusion import (
|
||||
MAX_TOKEN_NUM,
|
||||
)
|
||||
|
||||
max_token_num = min(
|
||||
MAX_TOKEN_NUM, self.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
if compile_range_end is not None and max_token_num < compile_range_end:
|
||||
computed_compile_ranges_endpoints.append(max_token_num)
|
||||
else:
|
||||
logger.debug(
|
||||
"Max num batched tokens below MiniMax QK norm fusion threshold, "
|
||||
"MiniMax QK norm fusion enabled for all num_tokens."
|
||||
)
|
||||
|
||||
if compilation_config.compile_ranges_endpoints is not None:
|
||||
for x in compilation_config.compile_ranges_endpoints:
|
||||
assert isinstance(x, int)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -11,13 +10,11 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.lightning_attn import (
|
||||
lightning_attention,
|
||||
linear_decode_forward_triton,
|
||||
@@ -28,6 +25,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
@@ -35,92 +33,6 @@ from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||
from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum
|
||||
|
||||
|
||||
@CustomOp.register("minimax_text01_rmsnorm_tp")
|
||||
class MiniMaxText01RMSNormTP(CustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
*,
|
||||
weight_shard_world_size: int | None = None,
|
||||
weight_shard_rank: int | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.tp_world = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.weight_shard_world = weight_shard_world_size or self.tp_world
|
||||
self.weight_shard_rank = (
|
||||
self.tp_rank if weight_shard_rank is None else weight_shard_rank
|
||||
)
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size // self.weight_shard_world))
|
||||
self.weight.weight_loader = partial(
|
||||
self.weight_loader,
|
||||
shard_world_size=self.weight_shard_world,
|
||||
shard_rank=self.weight_shard_rank,
|
||||
)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
@staticmethod
|
||||
def weight_loader(
|
||||
param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
shard_world_size: int | None = None,
|
||||
shard_rank: int | None = None,
|
||||
) -> None:
|
||||
if shard_world_size is None:
|
||||
shard_world_size = get_tensor_model_parallel_world_size()
|
||||
if shard_rank is None:
|
||||
shard_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
shard_size = loaded_weight.shape[0] // shard_world_size
|
||||
shard = slice(shard_rank * shard_size, (shard_rank + 1) * shard_size)
|
||||
param.data.copy_(loaded_weight[shard])
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
|
||||
if self.tp_world > 1:
|
||||
variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = (x * self.weight).to(orig_dtype)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert residual is None, "RMSNorm does not support residual connection."
|
||||
return self._forward(x)
|
||||
|
||||
@staticmethod
|
||||
def forward_qk(
|
||||
q_norm: "MiniMaxText01RMSNormTP",
|
||||
k_norm: "MiniMaxText01RMSNormTP",
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
orig_dtype = q.dtype
|
||||
q = q.to(torch.float32)
|
||||
k = k.to(torch.float32)
|
||||
q_var = q.pow(2).mean(dim=-1, keepdim=True)
|
||||
k_var = k.pow(2).mean(dim=-1, keepdim=True)
|
||||
if q_norm.tp_world > 1:
|
||||
qk_var = torch.cat([q_var, k_var], dim=-1)
|
||||
qk_var = tensor_model_parallel_all_reduce(qk_var) / q_norm.tp_world
|
||||
q_var, k_var = qk_var.chunk(2, dim=-1)
|
||||
q = q * torch.rsqrt(q_var + q_norm.variance_epsilon) * q_norm.weight
|
||||
k = k * torch.rsqrt(k_var + k_norm.variance_epsilon) * k_norm.weight
|
||||
q = q.to(orig_dtype)
|
||||
k = k.to(orig_dtype)
|
||||
return q, k
|
||||
|
||||
|
||||
def clear_linear_attention_cache_for_new_sequences(
|
||||
kv_cache: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.model_executor.layers.minimax_rms_norm.rms_norm_tp import (
|
||||
MiniMaxText01RMSNormTP,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MiniMaxText01RMSNormTP",
|
||||
]
|
||||
@@ -0,0 +1,234 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
# Max number of tokens supported by the Lamport fused allreduce+RMSNorm kernel.
|
||||
# Larger batches fall back to the eager allreduce + RMSNorm path.
|
||||
MINIMAX_QK_NORM_MAX_TOKEN_NUM = 2048
|
||||
|
||||
_MINIMAX_FUSED_AR_RMS_QK = getattr(torch.ops._C, "minimax_allreduce_rms_qk", None)
|
||||
|
||||
|
||||
@torch.compile(backend=current_platform.simple_compile_backend, dynamic=True)
|
||||
def _minimax_qk_norm_fallback(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_size: int,
|
||||
kv_size: int,
|
||||
tp_rank: int,
|
||||
tp_world: int,
|
||||
eps: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
q, k, _ = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
orig_dtype = q.dtype
|
||||
q = q.to(torch.float32)
|
||||
k = k.to(torch.float32)
|
||||
q_var = q.pow(2).mean(dim=-1, keepdim=True)
|
||||
k_var = k.pow(2).mean(dim=-1, keepdim=True)
|
||||
if tp_world > 1:
|
||||
qk_var = torch.cat([q_var, k_var], dim=-1)
|
||||
qk_var = tensor_model_parallel_all_reduce(qk_var) / tp_world
|
||||
q_var, k_var = qk_var.chunk(2, dim=-1)
|
||||
q = q * torch.rsqrt(q_var + eps) * q_weight
|
||||
k = k * torch.rsqrt(k_var + eps) * k_weight
|
||||
return q.to(orig_dtype), k.to(orig_dtype)
|
||||
|
||||
|
||||
def _minimax_qk_norm_fusion(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_size: int,
|
||||
kv_size: int,
|
||||
tp_rank: int,
|
||||
tp_world: int,
|
||||
eps: float,
|
||||
workspace: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert qkv.ndim == 2
|
||||
num_tokens = qkv.shape[0]
|
||||
if (
|
||||
workspace is not None
|
||||
and tp_world > 1
|
||||
and num_tokens <= MINIMAX_QK_NORM_MAX_TOKEN_NUM
|
||||
and _MINIMAX_FUSED_AR_RMS_QK is not None
|
||||
):
|
||||
return _MINIMAX_FUSED_AR_RMS_QK(
|
||||
qkv,
|
||||
q_weight,
|
||||
k_weight,
|
||||
workspace,
|
||||
q_size,
|
||||
kv_size,
|
||||
tp_rank,
|
||||
tp_world,
|
||||
eps,
|
||||
)
|
||||
return _minimax_qk_norm_fallback(
|
||||
qkv, q_weight, k_weight, q_size, kv_size, tp_rank, tp_world, eps
|
||||
)
|
||||
|
||||
|
||||
def _minimax_qk_norm_fusion_fake(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_size: int,
|
||||
kv_size: int,
|
||||
tp_rank: int,
|
||||
tp_world: int,
|
||||
eps: float,
|
||||
workspace: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert qkv.ndim == 2
|
||||
num_tokens = qkv.shape[0]
|
||||
return (
|
||||
torch.empty([num_tokens, q_size], dtype=qkv.dtype, device=qkv.device),
|
||||
torch.empty([num_tokens, kv_size], dtype=qkv.dtype, device=qkv.device),
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="minimax_qk_norm_fusion",
|
||||
op_func=_minimax_qk_norm_fusion,
|
||||
fake_impl=_minimax_qk_norm_fusion_fake,
|
||||
mutates_args=[],
|
||||
)
|
||||
|
||||
|
||||
@CustomOp.register("minimax_text01_rmsnorm_tp")
|
||||
class MiniMaxText01RMSNormTP(CustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
*,
|
||||
weight_shard_world_size: int | None = None,
|
||||
weight_shard_rank: int | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.tp_world = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.weight_shard_world = weight_shard_world_size or self.tp_world
|
||||
self.weight_shard_rank = (
|
||||
self.tp_rank if weight_shard_rank is None else weight_shard_rank
|
||||
)
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size // self.weight_shard_world))
|
||||
self.weight.weight_loader = partial(
|
||||
self.weight_loader,
|
||||
shard_world_size=self.weight_shard_world,
|
||||
shard_rank=self.weight_shard_rank,
|
||||
)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
self.workspace = None
|
||||
if _MINIMAX_FUSED_AR_RMS_QK is not None and self.tp_world > 1:
|
||||
from .lamport_workspace import (
|
||||
get_allreduce_workspace,
|
||||
)
|
||||
|
||||
self.workspace = get_allreduce_workspace(
|
||||
rank=self.tp_rank,
|
||||
world_size=self.tp_world,
|
||||
max_tokens=MINIMAX_QK_NORM_MAX_TOKEN_NUM,
|
||||
process_group=get_tp_group().cpu_group,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def weight_loader(
|
||||
param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
shard_world_size: int | None = None,
|
||||
shard_rank: int | None = None,
|
||||
) -> None:
|
||||
if shard_world_size is None:
|
||||
shard_world_size = get_tensor_model_parallel_world_size()
|
||||
if shard_rank is None:
|
||||
shard_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
shard_size = loaded_weight.shape[0] // shard_world_size
|
||||
shard = slice(shard_rank * shard_size, (shard_rank + 1) * shard_size)
|
||||
param.data.copy_(loaded_weight[shard])
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
|
||||
if self.tp_world > 1:
|
||||
variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = (x * self.weight).to(orig_dtype)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert residual is None, "RMSNorm does not support residual connection."
|
||||
return self._forward(x)
|
||||
|
||||
@staticmethod
|
||||
def forward_qk(
|
||||
q_norm: "MiniMaxText01RMSNormTP",
|
||||
k_norm: "MiniMaxText01RMSNormTP",
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
orig_dtype = q.dtype
|
||||
q = q.to(torch.float32)
|
||||
k = k.to(torch.float32)
|
||||
q_var = q.pow(2).mean(dim=-1, keepdim=True)
|
||||
k_var = k.pow(2).mean(dim=-1, keepdim=True)
|
||||
if q_norm.tp_world > 1:
|
||||
qk_var = torch.cat([q_var, k_var], dim=-1)
|
||||
qk_var = tensor_model_parallel_all_reduce(qk_var) / q_norm.tp_world
|
||||
q_var, k_var = qk_var.chunk(2, dim=-1)
|
||||
q = q * torch.rsqrt(q_var + q_norm.variance_epsilon) * q_norm.weight
|
||||
k = k * torch.rsqrt(k_var + k_norm.variance_epsilon) * k_norm.weight
|
||||
q = q.to(orig_dtype)
|
||||
k = k.to(orig_dtype)
|
||||
return q, k
|
||||
|
||||
@staticmethod
|
||||
def forward_qkv(
|
||||
q_norm: "MiniMaxText01RMSNormTP",
|
||||
k_norm: "MiniMaxText01RMSNormTP",
|
||||
qkv: torch.Tensor,
|
||||
q_size: int,
|
||||
kv_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert qkv.ndim == 2
|
||||
assert q_norm.variance_epsilon == k_norm.variance_epsilon
|
||||
q, k = torch.ops.vllm.minimax_qk_norm_fusion(
|
||||
qkv,
|
||||
q_norm.weight,
|
||||
k_norm.weight,
|
||||
q_size,
|
||||
kv_size,
|
||||
q_norm.tp_rank,
|
||||
q_norm.tp_world,
|
||||
q_norm.variance_epsilon,
|
||||
q_norm.workspace,
|
||||
)
|
||||
_, _, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
return q, k, v
|
||||
@@ -39,7 +39,6 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.linear_attn import (
|
||||
MiniMaxText01LinearAttention,
|
||||
MiniMaxText01LinearKernel,
|
||||
MiniMaxText01RMSNormTP,
|
||||
clear_linear_attention_cache_for_new_sequences,
|
||||
linear_attention_decode,
|
||||
linear_attention_prefill_and_mix,
|
||||
@@ -49,6 +48,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP
|
||||
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
@@ -50,7 +50,7 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
|
||||
from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -243,8 +243,9 @@ class MiniMaxM2Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = MiniMaxText01RMSNormTP.forward_qk(self.q_norm, self.k_norm, q, k)
|
||||
q, k, v = MiniMaxText01RMSNormTP.forward_qkv(
|
||||
self.q_norm, self.k_norm, qkv, self.q_size, self.kv_size
|
||||
)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
|
||||
Reference in New Issue
Block a user