[MyPy] Enable mypy for vllm/model_executor/layers/ (#40159)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
This commit is contained in:
Martin Hickey
2026-04-22 04:15:02 +01:00
committed by GitHub
parent 6f2c71be8f
commit 3951d3eacd
28 changed files with 243 additions and 146 deletions
-1
View File
@@ -29,7 +29,6 @@ SEPARATE_GROUPS = [
"tests",
# v0 related
"vllm/lora",
"vllm/model_executor/layers",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
+15 -12
View File
@@ -666,16 +666,7 @@ _ACTIVATION_REGISTRY = LazyDict(
"gelu": lambda: GELU(),
"gelu_fast": lambda: FastGELU(),
"gelu_new": lambda: NewGELU(),
"gelu_pytorch_tanh": lambda: (
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
logger.warning_once(
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
"Falling back to GELU(approximate='none')."
),
nn.GELU(approximate="none"),
)[1]
if current_platform.is_rocm()
else nn.GELU(approximate="tanh"),
"gelu_pytorch_tanh": lambda: _get_gelu_pytorch_tanh(),
"relu": lambda: nn.ReLU(),
"relu2": lambda: ReLUSquaredActivation(),
"silu": lambda: nn.SiLU(),
@@ -687,6 +678,18 @@ _ACTIVATION_REGISTRY = LazyDict(
)
def _get_gelu_pytorch_tanh() -> nn.Module:
"""Get PyTorch GELU with tanh approximation, with ROCm fallback."""
if current_platform.is_rocm():
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
logger.warning_once(
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
"Falling back to GELU(approximate='none')."
)
return nn.GELU(approximate="none")
return nn.GELU(approximate="tanh")
def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
@@ -703,12 +706,12 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
return _ACTIVATION_REGISTRY[act_fn_name]
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
_ACTIVATION_AND_MUL_REGISTRY: LazyDict[nn.Module] = LazyDict(
{
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
"geglu": lambda: GeluAndMul(),
"swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
"swigluoai": lambda: SwigluOAIAndMul(),
}
)
@@ -33,6 +33,7 @@ from vllm.utils.torch_utils import (
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@@ -209,6 +210,7 @@ class Attention(nn.Module, AttentionLayerBase):
`self.kv_cache`.
"""
super().__init__()
sliding_window: int | None
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
@@ -335,7 +337,7 @@ class Attention(nn.Module, AttentionLayerBase):
cache_config.enable_prefix_caching = False
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an AttentionImpl subclass
num_heads,
head_size,
scale,
@@ -576,7 +578,7 @@ class Attention(nn.Module, AttentionLayerBase):
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
@@ -680,9 +682,16 @@ def get_attention_context(
extracted from the forward context.
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
attn_metadata_raw = forward_context.attn_metadata
attn_metadata: AttentionMetadata
if isinstance(attn_metadata_raw, dict):
attn_metadata = attn_metadata_raw[layer_name]
elif isinstance(attn_metadata_raw, list):
# list[dict[str, AttentionMetadata]]: used in speculative decoding
# where [0] is the base-model (non-speculative) metadata dict.
attn_metadata = attn_metadata_raw[0][layer_name]
else:
attn_metadata = attn_metadata_raw
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping
@@ -708,7 +717,7 @@ def unified_kv_cache_update(
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
)
attn_layer.impl.do_kv_cache_update(
attn_layer.impl.do_kv_cache_update( # type: ignore[attr-defined]
attn_layer,
key,
value,
@@ -29,7 +29,7 @@ from vllm.v1.kv_cache_interface import (
@functools.lru_cache
def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend,
underlying_attn_backend: type[AttentionBackend],
attention_chunk_size: int,
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_"
@@ -72,7 +72,7 @@ def _get_cross_slot_mapping(
@functools.lru_cache
def create_cross_attention_backend(
underlying_attn_backend: AttentionBackend,
underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
@@ -87,6 +87,7 @@ def create_cross_attention_backend(
) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
assert new_metadata.encoder_seq_lens_cpu is not None
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
# Any computed tokens indicated decode step>1 (no chunked prefill)
@@ -118,7 +119,7 @@ def create_cross_attention_backend(
self.device,
)
attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
attn_metadata.slot_mapping = slot_mapping
attn_metadata.slot_mapping = slot_mapping # type: ignore[attr-defined]
return attn_metadata
# NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
@@ -144,8 +145,12 @@ def create_cross_attention_backend(
and key is not None
and value is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
self.do_kv_cache_update( # type: ignore[attr-defined]
layer,
key,
value,
kv_cache,
attn_metadata.slot_mapping, # type: ignore[attr-defined]
)
return super().forward(
@@ -21,7 +21,7 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
@functools.lru_cache
def create_encoder_only_attention_backend(
underlying_attn_backend: AttentionBackend,
underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
prefix = "EncoderOnlyAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
@@ -93,6 +93,6 @@ class EncoderOnlyAttention(Attention):
**kwargs,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
# Does not need KV cache
return None
@@ -389,7 +389,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
cache_config.enable_prefix_caching = False
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an MLAAttentionImpl subclass
num_heads=self.num_heads,
head_size=self.head_size,
scale=self.scale,
@@ -485,16 +485,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
attn_metadata_raw = forward_context.attn_metadata
attn_metadata: MLACommonMetadata
if isinstance(attn_metadata_raw, dict):
attn_metadata = attn_metadata_raw[self.layer_name] # type: ignore[assignment]
elif isinstance(attn_metadata_raw, list):
# list[dict[str, AttentionMetadata]]: used in speculative decoding
# where [0] is the base-model (non-speculative) metadata dict.
attn_metadata = attn_metadata_raw[0][self.layer_name] # type: ignore[assignment]
else:
attn_metadata = attn_metadata_raw
self_kv_cache = self.kv_cache
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
self.impl.do_kv_cache_update(
self.impl.do_kv_cache_update( # type: ignore[attr-defined]
kv_c_normed,
k_pe,
self_kv_cache,
@@ -612,7 +619,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
num_mha_tokens = q.size(0) - num_mqa_tokens
if num_mha_tokens > 0:
self.impl.forward_mha(
self.impl.forward_mha( # type: ignore[attr-defined]
q[num_mqa_tokens:],
k_c_normed[num_mqa_tokens:],
k_pe[num_mqa_tokens:],
@@ -695,7 +702,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# call decode attn
if not is_sparse_impl:
assert attn_metadata.decode is not None
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) # type: ignore[attr-defined]
# correct dcp attn_out with lse.
if self.impl.dcp_world_size > 1:
@@ -1053,9 +1060,9 @@ except ImportError:
"AITER_MLA backends use aiter kernels instead."
)
elif current_platform.is_xpu():
from vllm._xpu_ops import xpu_ops as ops
from vllm._xpu_ops import xpu_ops
flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef]
flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func # type: ignore[no-redef,attr-defined,assignment]
def dynamic_per_batched_tensor_quant(
@@ -1988,7 +1995,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
self._build_fi_prefill_wrappers(attn_metadata.prefill)
return attn_metadata
return attn_metadata # type: ignore[return-value]
def reorg_kvcache(
@@ -117,17 +117,20 @@ def maybe_make_prepare_finalize(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
assert device_communicator.all2all_manager is not None
return make_moe_prepare_and_finalize_naive_dp_ep(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
num_dispatchers=(device_communicator.all2all_manager.world_size),
use_monolithic=use_monolithic,
)
else:
return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
all2all_manager = get_ep_group().device_communicator.all2all_manager
device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
all2all_manager = device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
@@ -7,6 +7,7 @@ from typing import Union
import torch
from vllm.config import ParallelConfig, SchedulerConfig
from vllm.config.kernel import MoEBackend
from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -1192,7 +1193,7 @@ class FusedMoEConfig:
# Defaults to intermediate_size_per_partition if not specified.
intermediate_size_per_partition_unpadded: int | None = None
moe_backend: str = "auto"
moe_backend: MoEBackend = "auto"
max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP
has_bias: bool = False
is_act_and_mul: bool = True
@@ -210,9 +210,9 @@ def persistent_masked_m_silu_mul_quant(
DeepGemmQuantScaleFMT.UE8M0,
]
cuda_arch = current_platform.get_device_capability(
device_id=y.device.index
).to_int()
device_capability = current_platform.get_device_capability(device_id=y.device.index)
assert device_capability is not None
cuda_arch = device_capability.to_int()
if current_platform.is_cuda() and cuda_arch >= 80:
torch.ops._C.persistent_masked_m_silu_mul_quant(
@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig,
@@ -146,7 +147,7 @@ def backend_to_kernel_cls(
raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
"""Map user's moe_backend string to Mxfp4MoeBackend."""
mapping: dict[str, Mxfp4MoeBackend] = {
"flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
@@ -201,10 +202,12 @@ def select_gpt_oss_mxfp4_moe_backend(
Select the primary MXFP4 MoE backend.
Note: Shape-specific fallbacks may still occur at runtime.
"""
triton_kernels_supported = has_triton_kernels() and (
9,
0,
) <= current_platform.get_device_capability() < (11, 0)
device_capability = current_platform.get_device_capability()
triton_kernels_supported = (
has_triton_kernels()
and device_capability is not None
and (9, 0) <= device_capability < (11, 0)
)
# LoRA: separate experts backend path
if config.is_lora_enabled:
@@ -4,6 +4,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.base_device_communicator import (
All2AllManagerBase,
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
@@ -11,12 +14,16 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
return dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""FlashInfer implementation using the Moe AlltoAll kernel."""
all2all_manager: All2AllManagerBase
def __init__(
self,
max_num_tokens: int,
@@ -32,8 +39,12 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
self.hidden_size = hidden_size
self.num_dispatchers_ = num_dispatchers
self.all2all_manager = get_ep_group().device_communicator.all2all_manager
self.all2all_manager.initialize(
device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
all2all_manager = device_communicator.all2all_manager
assert all2all_manager is not None
self.all2all_manager = all2all_manager
self.all2all_manager.initialize( # type: ignore[attr-defined]
max_num_tokens=self.max_num_tokens,
top_k=self.top_k,
num_experts=self.num_experts,
@@ -97,7 +108,8 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
payloads.append(topk_ids)
payloads.append(topk_weights)
recv_payloads = self.all2all_manager.moe_alltoall.dispatch(
assert self.all2all_manager.moe_alltoall is not None # type: ignore[attr-defined]
recv_payloads = self.all2all_manager.moe_alltoall.dispatch( # type: ignore[attr-defined]
token_selected_experts=topk_ids,
input_payloads=payloads,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
@@ -131,7 +143,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
assert self.all2all_manager.moe_alltoall is not None
assert self.all2all_manager.moe_alltoall is not None # type: ignore[attr-defined]
ep_size = self.all2all_manager.world_size
hidden_size = fused_expert_output.shape[-1]
@@ -139,7 +151,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
ep_size, self.runtime_max_tokens_per_rank, hidden_size
)
combined_output = self.all2all_manager.moe_alltoall.combine(
combined_output = self.all2all_manager.moe_alltoall.combine( # type: ignore[attr-defined]
payload=fused_expert_output,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
)
@@ -15,19 +15,26 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
return dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""Base class for FlashInfer MoE prepare and finalize operations."""
all2all_manager: All2AllManagerBase
def __init__(
self,
num_dispatchers: int = 1,
):
super().__init__()
self.num_dispatchers_ = num_dispatchers
self.all2all_manager = get_ep_group().device_communicator.all2all_manager
device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
assert device_communicator.all2all_manager is not None
self.all2all_manager = device_communicator.all2all_manager
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
@@ -129,7 +136,7 @@ def flashinfer_alltoall_dispatch(
):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert all2all_manager.ensure_alltoall_workspace_initialized(), (
assert all2all_manager.ensure_alltoall_workspace_initialized(), ( # type: ignore[attr-defined]
"FlashInfer AllToAll workspace not available"
)
@@ -144,7 +151,7 @@ def flashinfer_alltoall_dispatch(
topk_ids,
topk_weights,
None,
all2all_manager.prepare_workspace_tensor,
all2all_manager.prepare_workspace_tensor, # type: ignore[attr-defined]
max_num_token,
ep_rank,
ep_size,
@@ -172,7 +179,7 @@ def flashinfer_alltoall_dispatch(
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
all2all_manager.workspace_tensor, # type: ignore[attr-defined]
ep_rank,
ep_size,
)
@@ -180,7 +187,7 @@ def flashinfer_alltoall_dispatch(
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
x_sf,
alltoall_info,
all2all_manager.workspace_tensor,
all2all_manager.workspace_tensor, # type: ignore[attr-defined]
ep_rank,
ep_size,
)
@@ -196,7 +203,7 @@ def flashinfer_alltoall_dispatch(
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
all2all_manager.workspace_tensor, # type: ignore[attr-defined]
ep_rank,
ep_size,
)
@@ -212,13 +219,13 @@ def flashinfer_alltoall_combine(
):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert all2all_manager.ensure_alltoall_workspace_initialized(), (
assert all2all_manager.ensure_alltoall_workspace_initialized(), ( # type: ignore[attr-defined]
"FlashInfer AllToAll workspace not available"
)
return MnnvlMoe.mnnvl_moe_alltoallv_combine(
output,
alltoall_info,
all2all_manager.workspace_tensor,
all2all_manager.workspace_tensor, # type: ignore[attr-defined]
ep_rank=all2all_manager.rank,
ep_size=all2all_manager.world_size,
top_k=top_k,
@@ -132,9 +132,11 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
)
if scales is None:
assert len(res) == 3
a1q, topk_weights, topk_ids = res
a1q_scale = None
else:
assert len(res) == 4
a1q, topk_weights, topk_ids, scales = res
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
@@ -217,9 +219,11 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono
)
if scales is None:
assert len(res) == 2
a1q, router_logits = res
a1q_scale = None
else:
assert len(res) == 3
a1q, router_logits, scales = res
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
@@ -54,11 +54,13 @@ class DefaultMoERunner(MoERunnerBase):
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if self.do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
res = get_ep_group().dispatch_router_logits(
hidden_states,
router_logits,
self.moe_config.is_sequence_parallel,
)
assert len(res) == 2
hidden_states, router_logits = res
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
+25 -18
View File
@@ -16,7 +16,6 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from .fla.ops.kda import (
@@ -123,7 +122,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
self.cache_config = cache_config
if model_config is None:
raise ValueError("model_config must be provided")
kda_config = model_config.linear_attn_config
kda_config = model_config.linear_attn_config # type: ignore[attr-defined]
self.head_dim = kda_config["head_dim"]
self.num_heads = kda_config["num_heads"]
self.layer_idx = layer_idx
@@ -297,19 +296,21 @@ class KimiDeltaAttention(nn.Module, MambaBase):
core_attn_out: torch.Tensor,
) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
attn_metadata_raw = forward_context.attn_metadata
if attn_metadata is None:
if attn_metadata_raw is None:
# # V1 profile run
return
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
num_actual_tokens = attn_metadata.num_actual_tokens
assert isinstance(attn_metadata_raw, dict)
attn_metadata_narrowed = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata_narrowed, GDNAttentionMetadata)
has_initial_state = attn_metadata_narrowed.has_initial_state
non_spec_query_start_loc = attn_metadata_narrowed.non_spec_query_start_loc
non_spec_state_indices_tensor = (
attn_metadata_narrowed.non_spec_state_indices_tensor
) # noqa: E501
num_actual_tokens = attn_metadata_narrowed.num_actual_tokens
constant_caches = self.kv_cache
q_proj_states = q_proj_states[:num_actual_tokens]
@@ -335,7 +336,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
v_conv_weights = self.v_conv1d.weight.view(
self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
)
if attn_metadata.num_prefills > 0:
if attn_metadata_narrowed.num_prefills > 0:
q_proj_states = q_proj_states.transpose(0, 1)
k_proj_states = k_proj_states.transpose(0, 1)
v_proj_states = v_proj_states.transpose(0, 1)
@@ -348,7 +349,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
metadata=attn_metadata_narrowed,
).transpose(0, 1)
k = causal_conv1d_fn(
k_proj_states,
@@ -359,7 +360,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
metadata=attn_metadata_narrowed,
).transpose(0, 1)
v = causal_conv1d_fn(
v_proj_states,
@@ -370,11 +371,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
metadata=attn_metadata_narrowed,
).transpose(0, 1)
else:
assert non_spec_state_indices_tensor is not None
decode_conv_indices = non_spec_state_indices_tensor[
: attn_metadata.num_actual_tokens
: attn_metadata_narrowed.num_actual_tokens
]
q = causal_conv1d_update(
q_proj_states,
@@ -408,7 +410,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
)
if attn_metadata.num_prefills > 0:
if attn_metadata_narrowed.num_prefills > 0:
assert non_spec_state_indices_tensor is not None
assert has_initial_state is not None
zero_idx = non_spec_state_indices_tensor[~has_initial_state]
recurrent_state[zero_idx] = 0
initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
@@ -429,6 +433,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
# Init cache
recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
else:
assert non_spec_query_start_loc is not None
(
core_attn_out_non_spec,
last_recurrent_state,
@@ -440,7 +445,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
beta=beta,
initial_state=recurrent_state,
use_qk_l2norm_in_kernel=True,
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
cu_seqlens=non_spec_query_start_loc[
: attn_metadata_narrowed.num_decodes + 1
],
ssm_state_indices=non_spec_state_indices_tensor,
)
core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[
+1 -1
View File
@@ -76,7 +76,7 @@ def poly_norm(
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.poly_norm(
ops.poly_norm( # type: ignore[attr-defined]
out,
x,
weight,
+2 -1
View File
@@ -42,9 +42,10 @@ class MambaBase(AttentionLayerBase):
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
mamba_block_size = vllm_config.cache_config.mamba_block_size
assert mamba_block_size is not None
page_size_padded = vllm_config.cache_config.mamba_page_size_padded
return MambaSpec(
shapes=self.get_state_shape(),
shapes=tuple(self.get_state_shape()),
dtypes=self.get_state_dtype(),
block_size=mamba_block_size,
page_size_padded=page_size_padded,
@@ -62,7 +62,6 @@ from vllm.utils.torch_utils import (
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
logger = init_logger(__name__)
@@ -121,9 +120,9 @@ def fi_chunk_gated_delta_rule(
class ChunkGatedDeltaRule(CustomOp):
def __init__(self) -> None:
super().__init__()
backend_cfg = get_current_vllm_config().additional_config.get(
"gdn_prefill_backend", "auto"
)
additional_config = get_current_vllm_config().additional_config
assert isinstance(additional_config, dict)
backend_cfg = additional_config.get("gdn_prefill_backend", "auto")
backend = str(backend_cfg).strip().lower()
supports_flashinfer = (
@@ -621,18 +620,19 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# Part 2: Core Attention
# ============================================================
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
attn_metadata_raw = forward_context.attn_metadata
core_attn_out = torch.zeros(
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
z = torch.empty_like(core_attn_out)
if attn_metadata is not None:
attn_metadata = attn_metadata[self.prefix]
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
# TODO: xpu does not support this param yet
spec_sequence_masks = attn_metadata.spec_sequence_masks
spec_sequence_masks = attn_metadata.spec_sequence_masks # type: ignore[attr-defined]
assert spec_sequence_masks is None
conv_weights = self.conv1d.weight.view(
@@ -658,12 +658,12 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
activation=self.activation,
A_log=self.A_log,
dt_bias=self.dt_bias,
num_prefills=attn_metadata.num_prefills,
num_decodes=attn_metadata.num_decodes,
has_initial_state=attn_metadata.has_initial_state,
non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc,
non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor,
num_actual_tokens=attn_metadata.num_actual_tokens,
num_prefills=attn_metadata.num_prefills, # type: ignore[attr-defined]
num_decodes=attn_metadata.num_decodes, # type: ignore[attr-defined]
has_initial_state=attn_metadata.has_initial_state, # type: ignore[attr-defined]
non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, # type: ignore[attr-defined]
non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, # type: ignore[attr-defined]
num_actual_tokens=attn_metadata.num_actual_tokens, # type: ignore[attr-defined]
tp_size=self.tp_size,
reorder_input=not self.gqa_interleaved_layout,
)
@@ -792,16 +792,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
core_attn_out: torch.Tensor,
):
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
attn_metadata_raw = forward_context.attn_metadata
if attn_metadata is None:
if attn_metadata_raw is None:
# V1 profile run — warm up prefill kernels so that
# autotuning completes before KV cache allocation.
self._warmup_prefill_kernels(mixed_qkv)
return
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix] # type: ignore[index]
assert isinstance(attn_metadata, GDNAttentionMetadata)
if (
@@ -860,14 +860,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# 1.1: Process the multi-query part
if spec_sequence_masks is not None:
# spec_state_indices_tensor is always set when spec_sequence_masks is set
assert spec_state_indices_tensor is not None
mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=spec_state_indices_tensor[:, 0][
: attn_metadata.num_spec_decodes
conv_state_indices=spec_state_indices_tensor[:, 0][ # type: ignore[index]
: attn_metadata.num_spec_decodes # type: ignore[attr-defined]
],
num_accepted_tokens=num_accepted_tokens,
query_start_loc=spec_query_start_loc,
@@ -900,8 +902,8 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=non_spec_state_indices_tensor[
: attn_metadata.num_actual_tokens
conv_state_indices=non_spec_state_indices_tensor[ # type: ignore[index]
: attn_metadata.num_actual_tokens # type: ignore[attr-defined]
],
validate_data=True,
)
@@ -965,8 +967,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
v=value_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[
: attn_metadata.num_spec_decodes + 1
cu_seqlens=spec_query_start_loc[ # type: ignore[index]
: attn_metadata.num_spec_decodes
+ 1 # type: ignore[attr-defined]
],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
@@ -978,8 +981,10 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# 2.2: Process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
assert non_spec_state_indices_tensor is not None
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() # type: ignore[index]
assert has_initial_state is not None
initial_state[~has_initial_state, ...] = 0 # type: ignore[operator]
(
core_attn_out_non_spec,
last_recurrent_state,
@@ -1012,8 +1017,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
v=value_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[
: attn_metadata.num_decodes + 1
cu_seqlens=non_spec_query_start_loc[ # type: ignore[index]
: attn_metadata.num_decodes
+ 1 # type: ignore[attr-defined]
],
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
@@ -1073,7 +1079,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index]
validate_data=False,
)
out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
@@ -1086,7 +1092,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
scale=self.head_k_dim**-0.5,
initial_state=ssm_state,
out=out_buf,
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index]
use_qk_l2norm_in_kernel=True,
)
return
@@ -396,10 +396,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
attn_metadata_raw = forward_context.attn_metadata
attn_metadata: AttentionMetadata | None = None
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = (
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
@@ -40,6 +40,7 @@ from vllm.utils.torch_utils import (
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@@ -258,15 +259,16 @@ class MambaMixer(MambaBase, PluggableLayer):
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
attn_metadata_raw = forward_context.attn_metadata
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
attn_metadata: AttentionMetadata | None = None
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, Mamba1AttentionMetadata)
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
@@ -391,6 +393,9 @@ class MambaMixer(MambaBase, PluggableLayer):
ssm_outputs.append(scan_out_p)
if has_decode:
# state_indices_tensor_d is assigned when attn_metadata is not None,
# and has_decode is only True when attn_metadata is not None
assert state_indices_tensor_d is not None
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
@@ -572,14 +572,16 @@ class MambaMixer2(MambaBase, PluggableLayer):
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata
attn_metadata_raw = forward_context.attn_metadata
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
attn_metadata: AttentionMetadata | None = None
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a
@@ -708,6 +710,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
# 3. State Space Model sequence transformation
initial_states = None
if has_initial_states_p is not None and prep_initial_states:
assert state_indices_tensor_p is not None
kernel_ssm_indices = state_indices_tensor_p
if is_mamba_cache_all:
kernel_ssm_indices = state_indices_tensor_p.gather(
@@ -746,6 +749,13 @@ class MambaMixer2(MambaBase, PluggableLayer):
)
if is_mamba_cache_all:
assert mamba_block_size is not None
assert state_indices_tensor_p is not None
assert block_idx_first_scheduled_token_p is not None
assert block_idx_last_scheduled_token_p is not None
assert last_chunk_indices_p is not None
assert num_computed_tokens_p is not None
# The chunk_stride is the number of chunks per mamba block
# e.g., if mamba_block_size = 512 and chunk_size = 256,
# then chunk_stride = 2
@@ -810,6 +820,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
ssm_state[cache_blocks_to_fill] = from_where
# For all seqs, store the last state (note: might be partial):
assert state_indices_tensor_p is not None
ssm_state[
state_indices_tensor_p.gather(
1, block_idx_last_scheduled_token_p.unsqueeze(1)
@@ -820,10 +831,12 @@ class MambaMixer2(MambaBase, PluggableLayer):
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
# tensor
assert state_indices_tensor_p is not None
ssm_state[state_indices_tensor_p] = varlen_states
# Process decode requests
if has_decode:
assert state_indices_tensor_d is not None
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
@@ -113,10 +113,11 @@ class ShortConv(MambaBase, CustomOp):
# chunked prefill modes; they are computed at top-level model forward
# since they stay the same and reused for all mamba layers in the same
# iteration.
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
attn_metadata_raw = forward_context.attn_metadata
attn_metadata: AttentionMetadata | None = None
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
conv_state = (
self.kv_cache[0]
@@ -115,6 +115,7 @@ def pooler_for_classify(
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
assert model_config.pooler_config is not None
head = ClassifierPoolerHead(
head_dtype=model_config.head_dtype,
classifier=classifier,
@@ -124,6 +124,7 @@ def pooler_for_token_classify(
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
assert model_config.pooler_config is not None
head = TokenClassifierPoolerHead(
head_dtype=model_config.head_dtype,
classifier=classifier,
@@ -3,7 +3,7 @@
# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202
from typing import Any
from typing import Any, Literal, cast
import torch
from torch.nn.parameter import Parameter
@@ -251,7 +251,11 @@ class FPQuantLinearMethod(LinearMethodBase):
def fused_quantize_mx(
x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str
) -> tuple[torch.Tensor, torch.Tensor]:
return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method)
return fusedQuantizeMx(
x_flat,
hadamard_matrix,
method=cast(Literal["quest", "abs_max"], forward_method),
)
def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method):
@@ -114,7 +114,7 @@ class QuarkConfig(QuantizationConfig):
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
quant_config_with_hf_to_vllm_mapper = {}
quant_config_with_hf_to_vllm_mapper: dict[str, Any] = {}
for k, v in self.quant_config.items():
if isinstance(v, list):
@@ -26,7 +26,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._xpu_ops import xpu_ops as ops
from vllm._xpu_ops import xpu_ops
logger = init_logger(__name__)
@@ -84,12 +84,12 @@ def sparse_attn_indexer(
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
attn_metadata_narrowed = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata_narrowed, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata_narrowed.slot_mapping
has_decode = attn_metadata_narrowed.num_decodes > 0
has_prefill = attn_metadata_narrowed.num_prefills > 0
num_decode_tokens = attn_metadata_narrowed.num_decode_tokens
# During speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens. Truncate k to avoid
@@ -97,6 +97,8 @@ def sparse_attn_indexer(
num_tokens = slot_mapping.shape[0]
k = k[:num_tokens]
# scale_fmt can be None, but the function expects str
assert scale_fmt is not None
ops.indexer_k_quant_and_cache(
k,
kv_cache,
@@ -107,7 +109,7 @@ def sparse_attn_indexer(
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
prefill_metadata = attn_metadata_narrowed.prefill
assert prefill_metadata is not None
# Get the full shared workspace buffers once (will allocate on first use)
@@ -144,7 +146,7 @@ def sparse_attn_indexer(
]
if current_platform.is_xpu():
ops.top_k_per_row_prefill(
xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined]
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
@@ -167,7 +169,7 @@ def sparse_attn_indexer(
)
if has_decode:
decode_metadata = attn_metadata.decode
decode_metadata = attn_metadata_narrowed.decode
assert decode_metadata is not None
# kv_cache shape [
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
@@ -217,11 +219,11 @@ def sparse_attn_indexer(
topk_indices,
topk_workspace,
topk_tokens,
attn_metadata.max_seq_len,
attn_metadata_narrowed.max_seq_len,
)
else:
if current_platform.is_xpu():
ops.top_k_per_row_decode(
xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined]
logits,
next_n,
seq_lens,