[Kernel] Porting fuse_minimax_qk_norm to manual fusion (#43410)

Signed-off-by: Jee Jee Li <jeejeelee@inferact.ai>
This commit is contained in:
Jee Jee Li
2026-05-27 04:16:03 +08:00
committed by GitHub
parent 49b4882779
commit 6e503868ca
12 changed files with 262 additions and 490 deletions
-34
View File
@@ -21,7 +21,6 @@ or just on the low or high end.
| Fusion | `PassConfig` flag | Fused operations | Default at | E2E Speedup | Fullgraph | `num_tokens` |
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
| [MiniMax QK Norm](#minimax-qk-norm-fuse_minimax_qk_norm) | `fuse_minimax_qk_norm` | Q/K variance all-reduce → Q/K RMSNorm | Off by default | 2-3% | No | Low |
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
| [MLA Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | MLA Attention output → FP8/NVFP4 quant | Off by default | TBD | Yes | Always |
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
@@ -42,7 +41,6 @@ The table below lists the quantization schemes supported by each fusion on each
| Fusion | SM100 (Blackwell) | SM90 (Hopper) | SM89 (Ada) | SM80 (Ampere) | ROCm |
| ---------------------------- | ---------------------------------------- | ---------------------------------------- | ---------------------------------------- | ------------- | ---------------------------------------- |
| `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — |
| `fuse_minimax_qk_norm`\* | FP16/BF16 | FP16/BF16 | FP16/BF16 | FP16/BF16 | — |
| `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* |
| `fuse_attn_quant` (MLA)\* | FP8 static\*, FP8 per-group\*, NVFP4\* | FP8 static\*, FP8 per-group\* | FP8 static\*, FP8 per-group\* | — | FP8 static\* (untested) |
| `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 |
@@ -58,9 +56,6 @@ The table below lists the quantization schemes supported by each fusion on each
fused quantization output. See the [`fuse_attn_quant` section](#attention--quantization-fuse_attn_quant)
for per-backend details.
\* `fuse_minimax_qk_norm` is a model-specific pass for `MiniMaxM2ForCausalLM`. It also requires
tensor parallelism (`tp_size > 1`) and the CUDA custom op `minimax_allreduce_rms_qk`.
`enable_sp` and `fuse_gemm_comms` are only autoconfigured for SM90 today;
on other architectures, support requires setting `PassConfig.sp_min_token_num` explicitly.
SM100 support also requires setting `VLLM_DISABLED_KERNELS=FlashInferFP8ScaledMMLinearKernel`.
@@ -191,35 +186,6 @@ If these conditions are set, the fusion is enabled automatically for optimizatio
- Pass: [`vllm/compilation/passes/fusion/rope_kvcache_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rope_kvcache_fusion.py)
### MiniMax QK Norm (`fuse_minimax_qk_norm`)
!!! info
This is a MiniMax-specific compile pass. It is currently only enabled when all of the following hold:
the model architecture is `MiniMaxM2ForCausalLM`, tensor parallelism is enabled (`tp_size > 1`),
and the CUDA custom op `minimax_allreduce_rms_qk` is available. It is not enabled by default at any
optimization level.
**What it fuses.** Fuses the MiniMax M2 Q/K normalization path that performs an all-reduce over the
per-token Q/K variances before applying RMS normalization to Q and K.
This pass is distinct from [`enable_qk_norm_rope_fusion`](#qk-norm--rope-enable_qk_norm_rope_fusion):
`fuse_minimax_qk_norm` targets MiniMax M2's tensor-parallel all-reduce + RMSNorm sequence, while
`enable_qk_norm_rope_fusion` targets the later Q/K RMSNorm + RoPE sequence used by several other models.
Example:
```bash
vllm serve MiniMaxAI/MiniMax-M2.5 \
--tensor-parallel-size 4 \
--compilation-config '{"mode": 3, "pass_config": {"fuse_minimax_qk_norm": true}}'
```
**Code locations.**
- Pass: [`vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py)
- CUDA op: [`csrc/minimax_reduce_rms_kernel.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/minimax_reduce_rms_kernel.cu) (`minimax_allreduce_rms_qk`)
- Workspace helper: [`vllm/model_executor/layers/mamba/lamport_workspace.py`](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/lamport_workspace.py)
### Sequence Parallelism (`enable_sp`)
**What it fuses.** Replaces all-reduce collectives with reduce-scatter + local RMSNorm + all-gather,
@@ -10,7 +10,7 @@ from torch.multiprocessing import spawn
from tests.kernels.utils import opcheck
from tests.utils import ensure_current_vllm_config, init_test_distributed_environment
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import set_random_seed
@@ -59,7 +59,7 @@ def _worker_forward_qk(
# Set up Lamport workspace.
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.mamba.lamport_workspace import (
from vllm.model_executor.layers.minimax_rms_norm.lamport_workspace import (
get_allreduce_workspace,
)
@@ -1,340 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fusion pass: replace MiniMax QK allreduce + RMS norm with the Lamport
fused kernel (minimax_allreduce_rms_qk) for decode-size batches.
Pattern (inlined forward_qk in compiled graph):
q, k, v = qkv.split([q_size, kv_size, kv_size], -1)
q_fp32 = q.to(float32); k_fp32 = k.to(float32)
q_var = q_fp32.pow(2).mean(-1, keepdim=True)
k_var = k_fp32.pow(2).mean(-1, keepdim=True)
qk_var = cat([q_var, k_var], -1)
qk_var = allreduce(qk_var) / tp_world
q_var, k_var = qk_var.chunk(2, -1)
q_out = (q_fp32 * rsqrt(q_var + eps) * q_weight).to(orig_dtype)
k_out = (k_fp32 * rsqrt(k_var + eps) * k_weight).to(orig_dtype)
return q_out, k_out, v
Replacement (pure, no in-place on qkv/q/k):
q_out, k_out = minimax_qk_norm_fused(qkv, q_weight, k_weight, workspace, ...)
v = qkv.split([q_size, kv_size, kv_size], -1)[2]
return q_out, k_out, v
is_applicable_for_range: the pass only fires when compile_range.end <= max_token_num,
so that large prefill batches fall through to the original forward_qk (= main).
"""
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.utils.torch_utils import direct_register_custom_op
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
# Module-level logger for this fusion pass.
logger = init_logger(__name__)
# Upper bound on the token count handled by the fused path; the pass caps its
# own limit at min(MAX_TOKEN_NUM, scheduler max_num_batched_tokens).
MAX_TOKEN_NUM = 2048
# Handle to the registered `minimax_qk_norm_fused` op. It stays None when
# `torch.ops._C.minimax_allreduce_rms_qk` is unavailable, which keeps the pass
# disabled.
_MINIMAX_QK_NORM_FUSED_OP = None
if hasattr(torch.ops._C, "minimax_allreduce_rms_qk"):

    def _minimax_qk_norm_fused(
        qkv: torch.Tensor,
        norm_weight_q: torch.Tensor,
        norm_weight_k: torch.Tensor,
        q_size: int,
        kv_size: int,
        rank: int,
        nranks: int,
        eps: float,
        max_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the fused Q/K all-reduce + RMSNorm kernel on a packed qkv tensor.

        The Lamport workspace is looked up here (keyed by rank, world size,
        max_tokens and the TP CPU group) and handed to
        ``torch.ops._C.minimax_allreduce_rms_qk`` together with the inputs.

        Returns:
            The pair ``(q_out, k_out)`` produced by the kernel.
        """
        # Imported lazily, exactly as the original implementation does.
        from vllm.distributed.parallel_state import get_tp_group
        from vllm.model_executor.layers.mamba.lamport_workspace import (
            get_allreduce_workspace,
        )

        tp_cpu_group = get_tp_group().cpu_group
        lamport_ws = get_allreduce_workspace(
            rank=rank,
            world_size=nranks,
            max_tokens=max_tokens,
            process_group=tp_cpu_group,
        )
        fused_kernel = torch.ops._C.minimax_allreduce_rms_qk
        return fused_kernel(
            qkv,
            norm_weight_q,
            norm_weight_k,
            lamport_ws,
            q_size,
            kv_size,
            rank,
            nranks,
            eps,
        )

    def _minimax_qk_norm_fused_fake(
        qkv: torch.Tensor,
        norm_weight_q: torch.Tensor,
        norm_weight_k: torch.Tensor,
        q_size: int,
        kv_size: int,
        rank: int,
        nranks: int,
        eps: float,
        max_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fake implementation: allocate outputs of the right shape only.

        Both outputs share ``qkv``'s dtype and device; their shapes are
        ``[qkv.shape[0], q_size]`` and ``[qkv.shape[0], kv_size]``.
        """
        num_rows = qkv.shape[0]
        alloc_kwargs = {"dtype": qkv.dtype, "device": qkv.device}
        q_out = torch.empty([num_rows, q_size], **alloc_kwargs)
        k_out = torch.empty([num_rows, kv_size], **alloc_kwargs)
        return q_out, k_out

    # Register the wrapper as a custom op; no argument is declared as mutated.
    direct_register_custom_op(
        op_name="minimax_qk_norm_fused",
        op_func=_minimax_qk_norm_fused,
        fake_impl=_minimax_qk_norm_fused_fake,
        mutates_args=[],
    )
    _MINIMAX_QK_NORM_FUSED_OP = torch.ops.vllm.minimax_qk_norm_fused.default
class MiniMaxQKNormPattern:
    """
    Match the forward_qk allreduce+rms pattern and replace with Lamport kernel.

    Two pattern variants are registered against the same replacement: one
    where a single split feeds q, k and v, and one where each of q, k and v
    comes from its own split node.
    """

    def __init__(
        self,
        q_size: int,
        kv_size: int,
        eps: float,
        tp_world: int,
        tp_rank: int,
        max_tokens: int,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        """Store the per-rank sizes and constants baked into the patterns."""
        self.q_size, self.kv_size = q_size, kv_size
        self.eps = eps
        self.tp_world, self.tp_rank = tp_world, tp_rank
        self.max_tokens = max_tokens
        self.dtype, self.device = dtype, device

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[qkv, q_weight, k_weight]`` tensors for tracing."""
        num_tokens = 4
        tensor_kwargs = {"device": self.device, "dtype": self.dtype}
        qkv_width = self.q_size + 2 * self.kv_size
        return [
            torch.empty([num_tokens, qkv_width], **tensor_kwargs),
            torch.empty([self.q_size], **tensor_kwargs),
            torch.empty([self.kv_size], **tensor_kwargs),
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register both pattern variants and their replacement on pm_pass."""
        # Bind everything the traced closures need to plain locals.
        q_size, kv_size = self.q_size, self.kv_size
        eps, dtype = self.eps, self.dtype
        tp_world, tp_rank = self.tp_world, self.tp_rank
        max_tokens = self.max_tokens

        def _allreduce_rms(
            q: torch.Tensor,
            k: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Shared tail of both patterns: fp32 variance, one all-reduce over
            # the concatenated variances, then the weighted rsqrt scaling.
            q_fp32 = q.to(torch.float32)
            k_fp32 = k.to(torch.float32)
            q_var = q_fp32.pow(2).mean(dim=-1, keepdim=True)
            k_var = k_fp32.pow(2).mean(dim=-1, keepdim=True)
            qk_var = torch.cat([q_var, k_var], dim=-1)
            qk_var = tensor_model_parallel_all_reduce(qk_var) / tp_world
            q_var, k_var = qk_var.chunk(2, dim=-1)
            q_out = (q_fp32 * torch.rsqrt(q_var + eps) * q_weight).to(dtype)
            k_out = (k_fp32 * torch.rsqrt(k_var + eps) * k_weight).to(dtype)
            return q_out, k_out

        def pattern(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
            q_out, k_out = _allreduce_rms(q, k, q_weight, k_weight)
            return q_out, k_out, v

        def replacement(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            assert _MINIMAX_QK_NORM_FUSED_OP is not None
            q_out, k_out = torch.ops.vllm.minimax_qk_norm_fused(
                qkv,
                q_weight,
                k_weight,
                q_size,
                kv_size,
                tp_rank,
                tp_world,
                eps,
                max_tokens,
            )
            # v is untouched by the fused op; re-derive it from qkv.
            _q, _k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
            return q_out, k_out, v

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Second pattern: three separate split_with_sizes nodes (one per output),
        # each with _users=1. This occurs when the QKV projection uses a
        # functional GEMM kernel (e.g. cutlass_scaled_mm via auto_functionalized),
        # which causes inductor to generate one split per consumer.
        def pattern_split3(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            q = qkv.split([q_size, kv_size, kv_size], dim=-1)[0]
            k = qkv.split([q_size, kv_size, kv_size], dim=-1)[1]
            v = qkv.split([q_size, kv_size, kv_size], dim=-1)[2]
            q_out, k_out = _allreduce_rms(q, k, q_weight, k_weight)
            return q_out, k_out, v

        pm.register_replacement(
            pattern_split3, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class MiniMaxQKNormPass(VllmPatternMatcherPass):
    """
    Replace forward_qk allreduce+norm with the Lamport fused kernel.
    Only applied for decode-size compile ranges (small token counts).

    The pass starts out disabled and is enabled only after every check in
    ``__init__`` has succeeded; a disabled pass is a no-op.
    """

    def __init__(self, config: VllmConfig) -> None:
        """Validate preconditions, allocate the workspace, register patterns.

        The pass stays disabled when the fused op is missing, when
        ``tp_size <= 1``, when there is no model config, when the first
        architecture is not ``MiniMaxM2ForCausalLM``, or when the head/hidden
        sizes differ from the hard-coded supported shape.
        """
        super().__init__(config)
        self.disabled = True

        if _MINIMAX_QK_NORM_FUSED_OP is None:
            logger.warning_once(
                "minimax_allreduce_rms_qk op not found, MiniMaxQKNormPass disabled."
            )
            return

        tp_world = get_tensor_model_parallel_world_size()
        if tp_world <= 1:
            logger.warning_once("MiniMaxQKNormPass disabled: tp_size <= 1.")
            return

        if config.model_config is None:
            logger.warning_once("MiniMaxQKNormPass disabled: no model_config.")
            return

        hf_cfg = config.model_config.hf_config
        # `architectures` can be absent, None or empty; indexing it blindly
        # would raise during pass construction instead of leaving the pass
        # disabled, so resolve the first entry defensively.
        architectures = getattr(hf_cfg, "architectures", None) or ()
        model_name = architectures[0] if architectures else None
        if model_name != "MiniMaxM2ForCausalLM":
            return

        num_attention_heads = getattr(hf_cfg, "num_attention_heads", 0)
        num_key_value_heads = getattr(hf_cfg, "num_key_value_heads", 0)
        hidden_size = getattr(hf_cfg, "hidden_size", 0)
        head_dim = getattr(hf_cfg, "head_dim", 0)
        eps: float = getattr(hf_cfg, "rms_norm_eps", 1e-6)
        # Only this exact attention shape is accepted.
        if (
            num_attention_heads != 48
            or num_key_value_heads != 8
            or hidden_size != 3072
            or head_dim != 128
        ):
            logger.warning_once(
                "MiniMaxQKNormPass disabled: cannot infer model info from hf_config."
            )
            return

        # Per-rank Q/K widths; at least one KV head is kept on every rank.
        num_heads_per_rank = num_attention_heads // tp_world
        num_kv_heads_per_rank = max(1, num_key_value_heads // tp_world)
        q_size = num_heads_per_rank * head_dim
        kv_size = num_kv_heads_per_rank * head_dim

        self.max_token_num = min(
            MAX_TOKEN_NUM, config.scheduler_config.max_num_batched_tokens
        )
        tp_rank = get_tensor_model_parallel_rank()

        # Allocate Lamport workspace first.
        from vllm.distributed.parallel_state import get_tp_group
        from vllm.model_executor.layers.mamba.lamport_workspace import (
            get_allreduce_workspace,
        )

        get_allreduce_workspace(
            rank=tp_rank,
            world_size=tp_world,
            max_tokens=self.max_token_num,
            process_group=get_tp_group().cpu_group,
        )

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="minimax_qk_norm_pass"
        )
        self._register_patterns(q_size, kv_size, eps, tp_world, tp_rank)
        self.dump_patterns(config, self.patterns)
        self.disabled = False

    @enable_fake_mode
    def _register_patterns(
        self,
        q_size: int,
        kv_size: int,
        eps: float,
        tp_world: int,
        tp_rank: int,
    ) -> None:
        """Build a MiniMaxQKNormPattern and register it on ``self.patterns``."""
        MiniMaxQKNormPattern(
            q_size=q_size,
            kv_size=kv_size,
            eps=eps,
            tp_world=tp_world,
            tp_rank=tp_rank,
            max_tokens=self.max_token_num,
            dtype=self.model_dtype,
            device=self.device,
        ).register(self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True only when enabled and the range end fits the token cap."""
        if self.disabled:
            return False
        return bool(compile_range.end <= self.max_token_num)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph`` (no-op when disabled)."""
        if self.disabled:
            return
        self.matched_count = self.patterns.apply(graph)
        logger.debug("MiniMaxQKNormPass replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Return a hash derived from this pass and the pattern class source."""
        return VllmInductorPass.hash_source(self, MiniMaxQKNormPattern)
-4
View File
@@ -44,7 +44,6 @@ if current_platform.is_cuda_alike():
if current_platform.is_cuda():
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
from .fusion.collective_fusion import AsyncTPPass
from .fusion.minimax_qk_norm_fusion import MiniMaxQKNormPass
from .inductor_pass import (
CustomGraphPass,
@@ -154,9 +153,6 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
else:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.fuse_minimax_qk_norm:
self.passes += [MiniMaxQKNormPass(config)]
if self.pass_config.fuse_norm_quant:
if rocm_aiter_ops.is_enabled():
self.passes += [
+10 -1
View File
@@ -135,7 +135,9 @@ class PassConfig:
fuse_allreduce_rms: bool = None # type: ignore[assignment]
"""Enable flashinfer allreduce fusion."""
fuse_minimax_qk_norm: bool = None # type: ignore[assignment]
"""Enable fused allreduce+RMSNorm for MiniMax QK norm."""
"""Deprecated. The MiniMax QK norm fusion is now applied automatically at
runtime (see `MiniMaxText01RMSNormTP.forward_qkv`). This flag is kept for
backward compatibility and has no effect; it will be removed in v0.23."""
enable_qk_norm_rope_fusion: bool = None # type: ignore[assignment]
"""Enable fused Q/K RMSNorm + RoPE pass."""
fuse_rope_kvcache_cat_mla: bool = None # type: ignore[assignment]
@@ -294,6 +296,13 @@ class PassConfig:
"current platform is not CUDA or ROCm. The fusion will be disabled."
)
self.fuse_rope_kvcache_cat_mla = False
if self.fuse_minimax_qk_norm is not None:
logger.warning_once(
"`fuse_minimax_qk_norm` is deprecated and has no effect; "
"the MiniMax QK norm fusion is now applied automatically at "
"runtime when its conditions are met. This flag will be "
"removed in v0.23."
)
def log_enabled_passes(self) -> None:
"""
-16
View File
@@ -1828,22 +1828,6 @@ class VllmConfig:
compile_range_end,
)
if compilation_config.pass_config.fuse_minimax_qk_norm:
from vllm.compilation.passes.fusion.minimax_qk_norm_fusion import (
MAX_TOKEN_NUM,
)
max_token_num = min(
MAX_TOKEN_NUM, self.scheduler_config.max_num_batched_tokens
)
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below MiniMax QK norm fusion threshold, "
"MiniMax QK norm fusion enabled for all num_tokens."
)
if compilation_config.compile_ranges_endpoints is not None:
for x in compilation_config.compile_ranges_endpoints:
assert isinstance(x, int)
@@ -3,7 +3,6 @@
import math
from collections.abc import Callable
from functools import partial
import torch
import torch.nn.functional as F
@@ -11,13 +10,11 @@ from einops import rearrange
from torch import nn
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.lightning_attn import (
lightning_attention,
linear_decode_forward_triton,
@@ -28,6 +25,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
@@ -35,92 +33,6 @@ from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum
@CustomOp.register("minimax_text01_rmsnorm_tp")
class MiniMaxText01RMSNormTP(CustomOp):
    """RMSNorm whose variance is averaged across tensor-parallel ranks.

    Each rank holds a shard of the weight; when ``tp_world > 1`` the per-token
    variance is all-reduced and divided by the TP world size before use.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        *,
        weight_shard_world_size: int | None = None,
        weight_shard_rank: int | None = None,
    ) -> None:
        """Create the sharded weight and remember the TP layout.

        Args:
            hidden_size: Unsharded hidden size; the local weight has
                ``hidden_size // weight_shard_world`` elements.
            eps: Value added to the variance before ``rsqrt``.
            weight_shard_world_size: Number of weight shards; a falsy value
                falls back to the TP world size.
            weight_shard_rank: Shard index; ``None`` falls back to the TP rank.
        """
        super().__init__()
        self.tp_world = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.weight_shard_world = weight_shard_world_size or self.tp_world
        if weight_shard_rank is None:
            self.weight_shard_rank = self.tp_rank
        else:
            self.weight_shard_rank = weight_shard_rank

        local_hidden = hidden_size // self.weight_shard_world
        self.weight = nn.Parameter(torch.ones(local_hidden))
        # Bind the shard layout so the loader does not need to re-derive it.
        self.weight.weight_loader = partial(
            self.weight_loader,
            shard_world_size=self.weight_shard_world,
            shard_rank=self.weight_shard_rank,
        )
        self.variance_epsilon = eps

    @staticmethod
    def weight_loader(
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        shard_world_size: int | None = None,
        shard_rank: int | None = None,
    ) -> None:
        """Copy this rank's contiguous slice of ``loaded_weight`` into ``param``."""
        if shard_world_size is None:
            shard_world_size = get_tensor_model_parallel_world_size()
        if shard_rank is None:
            shard_rank = get_tensor_model_parallel_rank()

        shard_len = loaded_weight.shape[0] // shard_world_size
        begin = shard_rank * shard_len
        param.data.copy_(loaded_weight[begin : begin + shard_len])

    def _forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize ``x`` in float32 and cast back to its original dtype."""
        input_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        variance = x_fp32.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
        if self.tp_world > 1:
            variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
        normed = x_fp32 * torch.rsqrt(variance + self.variance_epsilon)
        return (normed * self.weight).to(input_dtype)

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apply the norm to ``x``; a residual input is rejected."""
        assert residual is None, "RMSNorm does not support residual connection."
        return self._forward(x)

    @staticmethod
    def forward_qk(
        q_norm: "MiniMaxText01RMSNormTP",
        k_norm: "MiniMaxText01RMSNormTP",
        q: torch.Tensor,
        k: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize q and k together, sharing one all-reduce for both variances."""
        out_dtype = q.dtype
        q_fp32 = q.to(torch.float32)
        k_fp32 = k.to(torch.float32)
        q_var = q_fp32.pow(2).mean(dim=-1, keepdim=True)
        k_var = k_fp32.pow(2).mean(dim=-1, keepdim=True)
        if q_norm.tp_world > 1:
            # Concatenate so that a single collective covers both variances.
            packed_var = torch.cat([q_var, k_var], dim=-1)
            packed_var = tensor_model_parallel_all_reduce(packed_var) / q_norm.tp_world
            q_var, k_var = packed_var.chunk(2, dim=-1)
        q_scaled = q_fp32 * torch.rsqrt(q_var + q_norm.variance_epsilon) * q_norm.weight
        k_scaled = k_fp32 * torch.rsqrt(k_var + k_norm.variance_epsilon) * k_norm.weight
        return q_scaled.to(out_dtype), k_scaled.to(out_dtype)
def clear_linear_attention_cache_for_new_sequences(
kv_cache: torch.Tensor,
state_indices_tensor: torch.Tensor,
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.minimax_rms_norm.rms_norm_tp import (
MiniMaxText01RMSNormTP,
)
__all__ = [
"MiniMaxText01RMSNormTP",
]
@@ -0,0 +1,234 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import partial
import torch
from torch import nn
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
)
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
# Max number of tokens supported by the Lamport fused allreduce+RMSNorm kernel.
# Larger batches fall back to the unfused allreduce + RMSNorm path.
MINIMAX_QK_NORM_MAX_TOKEN_NUM = 2048
# Handle to the `minimax_allreduce_rms_qk` op, or None when `torch.ops._C`
# does not expose it; with None the fused path is never taken.
_MINIMAX_FUSED_AR_RMS_QK = getattr(torch.ops._C, "minimax_allreduce_rms_qk", None)
@torch.compile(backend=current_platform.simple_compile_backend, dynamic=True)
def _minimax_qk_norm_fallback(
    qkv: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_size: int,
    kv_size: int,
    tp_rank: int,
    tp_world: int,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Unfused Q/K RMSNorm over a packed qkv tensor.

    Splits ``qkv`` into q/k/v along the last dim, computes the per-token
    variance of q and k in float32, averages it across TP ranks when
    ``tp_world > 1``, and returns the weighted, normalized q and k cast back
    to the input dtype. ``tp_rank`` is accepted but not used here.
    """
    q_in, k_in, _v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    out_dtype = q_in.dtype
    q_fp32 = q_in.to(torch.float32)
    k_fp32 = k_in.to(torch.float32)
    q_var = q_fp32.pow(2).mean(dim=-1, keepdim=True)
    k_var = k_fp32.pow(2).mean(dim=-1, keepdim=True)
    if tp_world > 1:
        # One collective for both variances.
        packed_var = torch.cat([q_var, k_var], dim=-1)
        packed_var = tensor_model_parallel_all_reduce(packed_var) / tp_world
        q_var, k_var = packed_var.chunk(2, dim=-1)
    q_scaled = q_fp32 * torch.rsqrt(q_var + eps) * q_weight
    k_scaled = k_fp32 * torch.rsqrt(k_var + eps) * k_weight
    return q_scaled.to(out_dtype), k_scaled.to(out_dtype)
def _minimax_qk_norm_fusion(
    qkv: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_size: int,
    kv_size: int,
    tp_rank: int,
    tp_world: int,
    eps: float,
    workspace: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dispatch Q/K RMSNorm to the fused kernel or the unfused fallback.

    The fused kernel is used only when a workspace is supplied, tensor
    parallelism is active, the token count is within
    ``MINIMAX_QK_NORM_MAX_TOKEN_NUM`` and the kernel handle exists; every
    other case goes through ``_minimax_qk_norm_fallback``.
    """
    assert qkv.ndim == 2
    token_count = qkv.shape[0]
    fused_kernel = _MINIMAX_FUSED_AR_RMS_QK

    use_fused = (
        workspace is not None
        and tp_world > 1
        and token_count <= MINIMAX_QK_NORM_MAX_TOKEN_NUM
        and fused_kernel is not None
    )
    if not use_fused:
        return _minimax_qk_norm_fallback(
            qkv, q_weight, k_weight, q_size, kv_size, tp_rank, tp_world, eps
        )

    return fused_kernel(
        qkv,
        q_weight,
        k_weight,
        workspace,
        q_size,
        kv_size,
        tp_rank,
        tp_world,
        eps,
    )
def _minimax_qk_norm_fusion_fake(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_size: int,
kv_size: int,
tp_rank: int,
tp_world: int,
eps: float,
workspace: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert qkv.ndim == 2
num_tokens = qkv.shape[0]
return (
torch.empty([num_tokens, q_size], dtype=qkv.dtype, device=qkv.device),
torch.empty([num_tokens, kv_size], dtype=qkv.dtype, device=qkv.device),
)
# Register the dispatcher as the `minimax_qk_norm_fusion` custom op (called as
# `torch.ops.vllm.minimax_qk_norm_fusion` in this file). The fake impl supplies
# output shapes/dtypes, and no argument is declared as mutated.
direct_register_custom_op(
op_name="minimax_qk_norm_fusion",
op_func=_minimax_qk_norm_fusion,
fake_impl=_minimax_qk_norm_fusion_fake,
mutates_args=[],
)
@CustomOp.register("minimax_text01_rmsnorm_tp")
class MiniMaxText01RMSNormTP(CustomOp):
    """RMSNorm whose variance is averaged across tensor-parallel ranks.

    Each rank holds a shard of the weight. When the fused kernel handle is
    available and ``tp_world > 1``, a Lamport all-reduce workspace is obtained
    at construction time and later handed to the fused Q/K norm op.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        *,
        weight_shard_world_size: int | None = None,
        weight_shard_rank: int | None = None,
    ) -> None:
        """Create the sharded weight and, if possible, the fused-path workspace.

        Args:
            hidden_size: Unsharded hidden size; the local weight has
                ``hidden_size // weight_shard_world`` elements.
            eps: Value added to the variance before ``rsqrt``.
            weight_shard_world_size: Number of weight shards; a falsy value
                falls back to the TP world size.
            weight_shard_rank: Shard index; ``None`` falls back to the TP rank.
        """
        super().__init__()
        self.tp_world = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.weight_shard_world = weight_shard_world_size or self.tp_world
        if weight_shard_rank is None:
            self.weight_shard_rank = self.tp_rank
        else:
            self.weight_shard_rank = weight_shard_rank

        local_hidden = hidden_size // self.weight_shard_world
        self.weight = nn.Parameter(torch.ones(local_hidden))
        # Bind the shard layout so the loader does not need to re-derive it.
        self.weight.weight_loader = partial(
            self.weight_loader,
            shard_world_size=self.weight_shard_world,
            shard_rank=self.weight_shard_rank,
        )
        self.variance_epsilon = eps

        # The workspace stays None unless the fused kernel can actually be used.
        self.workspace = None
        fused_available = _MINIMAX_FUSED_AR_RMS_QK is not None
        if fused_available and self.tp_world > 1:
            from .lamport_workspace import (
                get_allreduce_workspace,
            )

            self.workspace = get_allreduce_workspace(
                rank=self.tp_rank,
                world_size=self.tp_world,
                max_tokens=MINIMAX_QK_NORM_MAX_TOKEN_NUM,
                process_group=get_tp_group().cpu_group,
            )

    @staticmethod
    def weight_loader(
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        shard_world_size: int | None = None,
        shard_rank: int | None = None,
    ) -> None:
        """Copy this rank's contiguous slice of ``loaded_weight`` into ``param``."""
        if shard_world_size is None:
            shard_world_size = get_tensor_model_parallel_world_size()
        if shard_rank is None:
            shard_rank = get_tensor_model_parallel_rank()

        shard_len = loaded_weight.shape[0] // shard_world_size
        begin = shard_rank * shard_len
        param.data.copy_(loaded_weight[begin : begin + shard_len])

    def _forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize ``x`` in float32 and cast back to its original dtype."""
        input_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        variance = x_fp32.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
        if self.tp_world > 1:
            variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
        normed = x_fp32 * torch.rsqrt(variance + self.variance_epsilon)
        return (normed * self.weight).to(input_dtype)

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apply the norm to ``x``; a residual input is rejected."""
        assert residual is None, "RMSNorm does not support residual connection."
        return self._forward(x)

    @staticmethod
    def forward_qk(
        q_norm: "MiniMaxText01RMSNormTP",
        k_norm: "MiniMaxText01RMSNormTP",
        q: torch.Tensor,
        k: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize q and k together, sharing one all-reduce for both variances."""
        out_dtype = q.dtype
        q_fp32 = q.to(torch.float32)
        k_fp32 = k.to(torch.float32)
        q_var = q_fp32.pow(2).mean(dim=-1, keepdim=True)
        k_var = k_fp32.pow(2).mean(dim=-1, keepdim=True)
        if q_norm.tp_world > 1:
            # Concatenate so that a single collective covers both variances.
            packed_var = torch.cat([q_var, k_var], dim=-1)
            packed_var = tensor_model_parallel_all_reduce(packed_var) / q_norm.tp_world
            q_var, k_var = packed_var.chunk(2, dim=-1)
        q_scaled = q_fp32 * torch.rsqrt(q_var + q_norm.variance_epsilon) * q_norm.weight
        k_scaled = k_fp32 * torch.rsqrt(k_var + k_norm.variance_epsilon) * k_norm.weight
        return q_scaled.to(out_dtype), k_scaled.to(out_dtype)

    @staticmethod
    def forward_qkv(
        q_norm: "MiniMaxText01RMSNormTP",
        k_norm: "MiniMaxText01RMSNormTP",
        qkv: torch.Tensor,
        q_size: int,
        kv_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Normalize q and k from a packed 2-D qkv tensor and return (q, k, v).

        The rank, world size, epsilon and workspace are all taken from
        ``q_norm``; both norms must use the same epsilon. v is returned as the
        third split of ``qkv``.
        """
        assert qkv.ndim == 2
        assert q_norm.variance_epsilon == k_norm.variance_epsilon
        q_out, k_out = torch.ops.vllm.minimax_qk_norm_fusion(
            qkv,
            q_norm.weight,
            k_norm.weight,
            q_size,
            kv_size,
            q_norm.tp_rank,
            q_norm.tp_world,
            q_norm.variance_epsilon,
            q_norm.workspace,
        )
        _q, _k, v_out = qkv.split([q_size, kv_size, kv_size], dim=-1)
        return q_out, k_out, v_out
@@ -39,7 +39,6 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.linear_attn import (
MiniMaxText01LinearAttention,
MiniMaxText01LinearKernel,
MiniMaxText01RMSNormTP,
clear_linear_attention_cache_for_new_sequences,
linear_attention_decode,
linear_attention_prefill_and_mix,
@@ -49,6 +48,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
+4 -3
View File
@@ -50,7 +50,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -243,8 +243,9 @@ class MiniMaxM2Attention(nn.Module):
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = MiniMaxText01RMSNormTP.forward_qk(self.q_norm, self.k_norm, q, k)
q, k, v = MiniMaxText01RMSNormTP.forward_qkv(
self.q_norm, self.k_norm, qkv, self.q_size, self.kv_size
)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)