mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[MyPy] Enable mypy for vllm/model_executor/layers/ (#40159)
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
This commit is contained in:
@@ -29,7 +29,6 @@ SEPARATE_GROUPS = [
|
||||
"tests",
|
||||
# v0 related
|
||||
"vllm/lora",
|
||||
"vllm/model_executor/layers",
|
||||
]
|
||||
|
||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||
|
||||
@@ -666,16 +666,7 @@ _ACTIVATION_REGISTRY = LazyDict(
|
||||
"gelu": lambda: GELU(),
|
||||
"gelu_fast": lambda: FastGELU(),
|
||||
"gelu_new": lambda: NewGELU(),
|
||||
"gelu_pytorch_tanh": lambda: (
|
||||
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
|
||||
logger.warning_once(
|
||||
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
|
||||
"Falling back to GELU(approximate='none')."
|
||||
),
|
||||
nn.GELU(approximate="none"),
|
||||
)[1]
|
||||
if current_platform.is_rocm()
|
||||
else nn.GELU(approximate="tanh"),
|
||||
"gelu_pytorch_tanh": lambda: _get_gelu_pytorch_tanh(),
|
||||
"relu": lambda: nn.ReLU(),
|
||||
"relu2": lambda: ReLUSquaredActivation(),
|
||||
"silu": lambda: nn.SiLU(),
|
||||
@@ -687,6 +678,18 @@ _ACTIVATION_REGISTRY = LazyDict(
|
||||
)
|
||||
|
||||
|
||||
def _get_gelu_pytorch_tanh() -> nn.Module:
|
||||
"""Get PyTorch GELU with tanh approximation, with ROCm fallback."""
|
||||
if current_platform.is_rocm():
|
||||
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
|
||||
logger.warning_once(
|
||||
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
|
||||
"Falling back to GELU(approximate='none')."
|
||||
)
|
||||
return nn.GELU(approximate="none")
|
||||
return nn.GELU(approximate="tanh")
|
||||
|
||||
|
||||
def get_act_fn(act_fn_name: str) -> nn.Module:
|
||||
"""Get an activation function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
@@ -703,12 +706,12 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
|
||||
return _ACTIVATION_REGISTRY[act_fn_name]
|
||||
|
||||
|
||||
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
|
||||
_ACTIVATION_AND_MUL_REGISTRY: LazyDict[nn.Module] = LazyDict(
|
||||
{
|
||||
"gelu": lambda: GeluAndMul(),
|
||||
"silu": lambda: SiluAndMul(),
|
||||
"geglu": lambda: GeluAndMul(),
|
||||
"swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
|
||||
"swigluoai": lambda: SwigluOAIAndMul(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from vllm.utils.torch_utils import (
|
||||
)
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
@@ -209,6 +210,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
`self.kv_cache`.
|
||||
"""
|
||||
super().__init__()
|
||||
sliding_window: int | None
|
||||
if per_layer_sliding_window is not None:
|
||||
# per-layer sliding window
|
||||
sliding_window = per_layer_sliding_window
|
||||
@@ -335,7 +337,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
cache_config.enable_prefix_caching = False
|
||||
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(
|
||||
self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an AttentionImpl subclass
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
@@ -576,7 +578,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.attn_backend
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
|
||||
# Block size may get updated after model loading, refresh it
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
# Should not be called for enc-dec or encoder-only attention.
|
||||
@@ -680,9 +682,16 @@ def get_attention_context(
|
||||
extracted from the forward context.
|
||||
"""
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
attn_metadata_raw = forward_context.attn_metadata
|
||||
attn_metadata: AttentionMetadata
|
||||
if isinstance(attn_metadata_raw, dict):
|
||||
attn_metadata = attn_metadata_raw[layer_name]
|
||||
elif isinstance(attn_metadata_raw, list):
|
||||
# list[dict[str, AttentionMetadata]]: used in speculative decoding
|
||||
# where [0] is the base-model (non-speculative) metadata dict.
|
||||
attn_metadata = attn_metadata_raw[0][layer_name]
|
||||
else:
|
||||
attn_metadata = attn_metadata_raw
|
||||
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = attn_layer.kv_cache
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
@@ -708,7 +717,7 @@ def unified_kv_cache_update(
|
||||
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
|
||||
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
|
||||
)
|
||||
attn_layer.impl.do_kv_cache_update(
|
||||
attn_layer.impl.do_kv_cache_update( # type: ignore[attr-defined]
|
||||
attn_layer,
|
||||
key,
|
||||
value,
|
||||
|
||||
@@ -29,7 +29,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
|
||||
@functools.lru_cache
|
||||
def create_chunked_local_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend,
|
||||
underlying_attn_backend: type[AttentionBackend],
|
||||
attention_chunk_size: int,
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_"
|
||||
|
||||
@@ -72,7 +72,7 @@ def _get_cross_slot_mapping(
|
||||
|
||||
@functools.lru_cache
|
||||
def create_cross_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend,
|
||||
underlying_attn_backend: type[AttentionBackend],
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = "CrossAttention_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
@@ -87,6 +87,7 @@ def create_cross_attention_backend(
|
||||
) -> AttentionMetadata:
|
||||
new_metadata = copy(common_attn_metadata)
|
||||
new_metadata.causal = False
|
||||
assert new_metadata.encoder_seq_lens_cpu is not None
|
||||
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
|
||||
new_metadata.max_seq_len = max_encoder_len
|
||||
# Any computed tokens indicated decode step>1 (no chunked prefill)
|
||||
@@ -118,7 +119,7 @@ def create_cross_attention_backend(
|
||||
self.device,
|
||||
)
|
||||
attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
|
||||
attn_metadata.slot_mapping = slot_mapping
|
||||
attn_metadata.slot_mapping = slot_mapping # type: ignore[attr-defined]
|
||||
return attn_metadata
|
||||
|
||||
# NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
|
||||
@@ -144,8 +145,12 @@ def create_cross_attention_backend(
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
self.do_kv_cache_update(
|
||||
layer, key, value, kv_cache, attn_metadata.slot_mapping
|
||||
self.do_kv_cache_update( # type: ignore[attr-defined]
|
||||
layer,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
return super().forward(
|
||||
|
||||
@@ -21,7 +21,7 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
|
||||
@functools.lru_cache
|
||||
def create_encoder_only_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend,
|
||||
underlying_attn_backend: type[AttentionBackend],
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = "EncoderOnlyAttention_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
@@ -93,6 +93,6 @@ class EncoderOnlyAttention(Attention):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
|
||||
# Does not need KV cache
|
||||
return None
|
||||
|
||||
@@ -389,7 +389,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
cache_config.enable_prefix_caching = False
|
||||
|
||||
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
|
||||
self.impl = impl_cls(
|
||||
self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an MLAAttentionImpl subclass
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.head_size,
|
||||
scale=self.scale,
|
||||
@@ -485,16 +485,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
|
||||
if self.use_direct_call:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
attn_metadata_raw = forward_context.attn_metadata
|
||||
attn_metadata: MLACommonMetadata
|
||||
if isinstance(attn_metadata_raw, dict):
|
||||
attn_metadata = attn_metadata_raw[self.layer_name] # type: ignore[assignment]
|
||||
elif isinstance(attn_metadata_raw, list):
|
||||
# list[dict[str, AttentionMetadata]]: used in speculative decoding
|
||||
# where [0] is the base-model (non-speculative) metadata dict.
|
||||
attn_metadata = attn_metadata_raw[0][self.layer_name] # type: ignore[assignment]
|
||||
else:
|
||||
attn_metadata = attn_metadata_raw
|
||||
self_kv_cache = self.kv_cache
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
|
||||
)
|
||||
self.impl.do_kv_cache_update(
|
||||
self.impl.do_kv_cache_update( # type: ignore[attr-defined]
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
self_kv_cache,
|
||||
@@ -612,7 +619,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
num_mha_tokens = q.size(0) - num_mqa_tokens
|
||||
|
||||
if num_mha_tokens > 0:
|
||||
self.impl.forward_mha(
|
||||
self.impl.forward_mha( # type: ignore[attr-defined]
|
||||
q[num_mqa_tokens:],
|
||||
k_c_normed[num_mqa_tokens:],
|
||||
k_pe[num_mqa_tokens:],
|
||||
@@ -695,7 +702,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
# call decode attn
|
||||
if not is_sparse_impl:
|
||||
assert attn_metadata.decode is not None
|
||||
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
|
||||
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) # type: ignore[attr-defined]
|
||||
|
||||
# correct dcp attn_out with lse.
|
||||
if self.impl.dcp_world_size > 1:
|
||||
@@ -1053,9 +1060,9 @@ except ImportError:
|
||||
"AITER_MLA backends use aiter kernels instead."
|
||||
)
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._xpu_ops import xpu_ops as ops
|
||||
from vllm._xpu_ops import xpu_ops
|
||||
|
||||
flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef]
|
||||
flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func # type: ignore[no-redef,attr-defined,assignment]
|
||||
|
||||
|
||||
def dynamic_per_batched_tensor_quant(
|
||||
@@ -1988,7 +1995,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
|
||||
self._build_fi_prefill_wrappers(attn_metadata.prefill)
|
||||
|
||||
return attn_metadata
|
||||
return attn_metadata # type: ignore[return-value]
|
||||
|
||||
|
||||
def reorg_kvcache(
|
||||
|
||||
@@ -117,17 +117,20 @@ def maybe_make_prepare_finalize(
|
||||
"Detected DP deployment with no --enable-expert-parallel. "
|
||||
"Falling back to AllGather+ReduceScatter dispatch/combine."
|
||||
)
|
||||
device_communicator = get_ep_group().device_communicator
|
||||
assert device_communicator is not None
|
||||
assert device_communicator.all2all_manager is not None
|
||||
return make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
|
||||
num_dispatchers=(
|
||||
get_ep_group().device_communicator.all2all_manager.world_size
|
||||
),
|
||||
num_dispatchers=(device_communicator.all2all_manager.world_size),
|
||||
use_monolithic=use_monolithic,
|
||||
)
|
||||
else:
|
||||
return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
|
||||
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
device_communicator = get_ep_group().device_communicator
|
||||
assert device_communicator is not None
|
||||
all2all_manager = device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Union
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, SchedulerConfig
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
@@ -1192,7 +1193,7 @@ class FusedMoEConfig:
|
||||
# Defaults to intermediate_size_per_partition if not specified.
|
||||
intermediate_size_per_partition_unpadded: int | None = None
|
||||
|
||||
moe_backend: str = "auto"
|
||||
moe_backend: MoEBackend = "auto"
|
||||
max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP
|
||||
has_bias: bool = False
|
||||
is_act_and_mul: bool = True
|
||||
|
||||
@@ -210,9 +210,9 @@ def persistent_masked_m_silu_mul_quant(
|
||||
DeepGemmQuantScaleFMT.UE8M0,
|
||||
]
|
||||
|
||||
cuda_arch = current_platform.get_device_capability(
|
||||
device_id=y.device.index
|
||||
).to_int()
|
||||
device_capability = current_platform.get_device_capability(device_id=y.device.index)
|
||||
assert device_capability is not None
|
||||
cuda_arch = device_capability.to_int()
|
||||
|
||||
if current_platform.is_cuda() and cuda_arch >= 80:
|
||||
torch.ops._C.persistent_masked_m_silu_mul_quant(
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoEConfig,
|
||||
@@ -146,7 +147,7 @@ def backend_to_kernel_cls(
|
||||
raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
|
||||
|
||||
|
||||
def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
|
||||
def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
|
||||
"""Map user's moe_backend string to Mxfp4MoeBackend."""
|
||||
mapping: dict[str, Mxfp4MoeBackend] = {
|
||||
"flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
|
||||
@@ -201,10 +202,12 @@ def select_gpt_oss_mxfp4_moe_backend(
|
||||
Select the primary MXFP4 MoE backend.
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
triton_kernels_supported = has_triton_kernels() and (
|
||||
9,
|
||||
0,
|
||||
) <= current_platform.get_device_capability() < (11, 0)
|
||||
device_capability = current_platform.get_device_capability()
|
||||
triton_kernels_supported = (
|
||||
has_triton_kernels()
|
||||
and device_capability is not None
|
||||
and (9, 0) <= device_capability < (11, 0)
|
||||
)
|
||||
|
||||
# LoRA: separate experts backend path
|
||||
if config.is_lora_enabled:
|
||||
|
||||
+18
-6
@@ -4,6 +4,9 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
All2AllManagerBase,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
@@ -11,12 +14,16 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
def get_local_sizes():
|
||||
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
return dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
|
||||
class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""FlashInfer implementation using the Moe AlltoAll kernel."""
|
||||
|
||||
all2all_manager: All2AllManagerBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
@@ -32,8 +39,12 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
|
||||
self.hidden_size = hidden_size
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
|
||||
self.all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
self.all2all_manager.initialize(
|
||||
device_communicator = get_ep_group().device_communicator
|
||||
assert device_communicator is not None
|
||||
all2all_manager = device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
self.all2all_manager = all2all_manager
|
||||
self.all2all_manager.initialize( # type: ignore[attr-defined]
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
top_k=self.top_k,
|
||||
num_experts=self.num_experts,
|
||||
@@ -97,7 +108,8 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
|
||||
payloads.append(topk_ids)
|
||||
payloads.append(topk_weights)
|
||||
|
||||
recv_payloads = self.all2all_manager.moe_alltoall.dispatch(
|
||||
assert self.all2all_manager.moe_alltoall is not None # type: ignore[attr-defined]
|
||||
recv_payloads = self.all2all_manager.moe_alltoall.dispatch( # type: ignore[attr-defined]
|
||||
token_selected_experts=topk_ids,
|
||||
input_payloads=payloads,
|
||||
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
|
||||
@@ -131,7 +143,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
assert self.all2all_manager.moe_alltoall is not None
|
||||
assert self.all2all_manager.moe_alltoall is not None # type: ignore[attr-defined]
|
||||
|
||||
ep_size = self.all2all_manager.world_size
|
||||
hidden_size = fused_expert_output.shape[-1]
|
||||
@@ -139,7 +151,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
|
||||
ep_size, self.runtime_max_tokens_per_rank, hidden_size
|
||||
)
|
||||
|
||||
combined_output = self.all2all_manager.moe_alltoall.combine(
|
||||
combined_output = self.all2all_manager.moe_alltoall.combine( # type: ignore[attr-defined]
|
||||
payload=fused_expert_output,
|
||||
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
|
||||
)
|
||||
|
||||
+16
-9
@@ -15,19 +15,26 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
def get_local_sizes():
|
||||
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
return dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
|
||||
class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""Base class for FlashInfer MoE prepare and finalize operations."""
|
||||
|
||||
all2all_manager: All2AllManagerBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_dispatchers: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
self.all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
device_communicator = get_ep_group().device_communicator
|
||||
assert device_communicator is not None
|
||||
assert device_communicator.all2all_manager is not None
|
||||
self.all2all_manager = device_communicator.all2all_manager
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
@@ -129,7 +136,7 @@ def flashinfer_alltoall_dispatch(
|
||||
):
|
||||
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
|
||||
|
||||
assert all2all_manager.ensure_alltoall_workspace_initialized(), (
|
||||
assert all2all_manager.ensure_alltoall_workspace_initialized(), ( # type: ignore[attr-defined]
|
||||
"FlashInfer AllToAll workspace not available"
|
||||
)
|
||||
|
||||
@@ -144,7 +151,7 @@ def flashinfer_alltoall_dispatch(
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
None,
|
||||
all2all_manager.prepare_workspace_tensor,
|
||||
all2all_manager.prepare_workspace_tensor, # type: ignore[attr-defined]
|
||||
max_num_token,
|
||||
ep_rank,
|
||||
ep_size,
|
||||
@@ -172,7 +179,7 @@ def flashinfer_alltoall_dispatch(
|
||||
x = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||
x,
|
||||
alltoall_info,
|
||||
all2all_manager.workspace_tensor,
|
||||
all2all_manager.workspace_tensor, # type: ignore[attr-defined]
|
||||
ep_rank,
|
||||
ep_size,
|
||||
)
|
||||
@@ -180,7 +187,7 @@ def flashinfer_alltoall_dispatch(
|
||||
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||
x_sf,
|
||||
alltoall_info,
|
||||
all2all_manager.workspace_tensor,
|
||||
all2all_manager.workspace_tensor, # type: ignore[attr-defined]
|
||||
ep_rank,
|
||||
ep_size,
|
||||
)
|
||||
@@ -196,7 +203,7 @@ def flashinfer_alltoall_dispatch(
|
||||
x = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||
x,
|
||||
alltoall_info,
|
||||
all2all_manager.workspace_tensor,
|
||||
all2all_manager.workspace_tensor, # type: ignore[attr-defined]
|
||||
ep_rank,
|
||||
ep_size,
|
||||
)
|
||||
@@ -212,13 +219,13 @@ def flashinfer_alltoall_combine(
|
||||
):
|
||||
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
|
||||
|
||||
assert all2all_manager.ensure_alltoall_workspace_initialized(), (
|
||||
assert all2all_manager.ensure_alltoall_workspace_initialized(), ( # type: ignore[attr-defined]
|
||||
"FlashInfer AllToAll workspace not available"
|
||||
)
|
||||
return MnnvlMoe.mnnvl_moe_alltoallv_combine(
|
||||
output,
|
||||
alltoall_info,
|
||||
all2all_manager.workspace_tensor,
|
||||
all2all_manager.workspace_tensor, # type: ignore[attr-defined]
|
||||
ep_rank=all2all_manager.rank,
|
||||
ep_size=all2all_manager.world_size,
|
||||
top_k=top_k,
|
||||
|
||||
@@ -132,9 +132,11 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
|
||||
)
|
||||
|
||||
if scales is None:
|
||||
assert len(res) == 3
|
||||
a1q, topk_weights, topk_ids = res
|
||||
a1q_scale = None
|
||||
else:
|
||||
assert len(res) == 4
|
||||
a1q, topk_weights, topk_ids, scales = res
|
||||
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
|
||||
|
||||
@@ -217,9 +219,11 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono
|
||||
)
|
||||
|
||||
if scales is None:
|
||||
assert len(res) == 2
|
||||
a1q, router_logits = res
|
||||
a1q_scale = None
|
||||
else:
|
||||
assert len(res) == 3
|
||||
a1q, router_logits, scales = res
|
||||
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
|
||||
|
||||
|
||||
@@ -54,11 +54,13 @@ class DefaultMoERunner(MoERunnerBase):
|
||||
# NOTE: this will be removed once all kernels are migrated into the
|
||||
# MoEKernel framework.
|
||||
if self.do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
|
||||
res = get_ep_group().dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
)
|
||||
assert len(res) == 2
|
||||
hidden_states, router_logits = res
|
||||
|
||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||
|
||||
@@ -16,7 +16,6 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
|
||||
from .fla.ops.kda import (
|
||||
@@ -123,7 +122,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
self.cache_config = cache_config
|
||||
if model_config is None:
|
||||
raise ValueError("model_config must be provided")
|
||||
kda_config = model_config.linear_attn_config
|
||||
kda_config = model_config.linear_attn_config # type: ignore[attr-defined]
|
||||
self.head_dim = kda_config["head_dim"]
|
||||
self.num_heads = kda_config["num_heads"]
|
||||
self.layer_idx = layer_idx
|
||||
@@ -297,19 +296,21 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
core_attn_out: torch.Tensor,
|
||||
) -> None:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
attn_metadata_raw = forward_context.attn_metadata
|
||||
|
||||
if attn_metadata is None:
|
||||
if attn_metadata_raw is None:
|
||||
# # V1 profile run
|
||||
return
|
||||
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
||||
has_initial_state = attn_metadata.has_initial_state
|
||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
assert isinstance(attn_metadata_raw, dict)
|
||||
attn_metadata_narrowed = attn_metadata_raw[self.prefix]
|
||||
assert isinstance(attn_metadata_narrowed, GDNAttentionMetadata)
|
||||
has_initial_state = attn_metadata_narrowed.has_initial_state
|
||||
non_spec_query_start_loc = attn_metadata_narrowed.non_spec_query_start_loc
|
||||
non_spec_state_indices_tensor = (
|
||||
attn_metadata_narrowed.non_spec_state_indices_tensor
|
||||
) # noqa: E501
|
||||
num_actual_tokens = attn_metadata_narrowed.num_actual_tokens
|
||||
constant_caches = self.kv_cache
|
||||
|
||||
q_proj_states = q_proj_states[:num_actual_tokens]
|
||||
@@ -335,7 +336,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
v_conv_weights = self.v_conv1d.weight.view(
|
||||
self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
|
||||
)
|
||||
if attn_metadata.num_prefills > 0:
|
||||
if attn_metadata_narrowed.num_prefills > 0:
|
||||
q_proj_states = q_proj_states.transpose(0, 1)
|
||||
k_proj_states = k_proj_states.transpose(0, 1)
|
||||
v_proj_states = v_proj_states.transpose(0, 1)
|
||||
@@ -348,7 +349,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
has_initial_state=has_initial_state,
|
||||
cache_indices=non_spec_state_indices_tensor,
|
||||
query_start_loc=non_spec_query_start_loc,
|
||||
metadata=attn_metadata,
|
||||
metadata=attn_metadata_narrowed,
|
||||
).transpose(0, 1)
|
||||
k = causal_conv1d_fn(
|
||||
k_proj_states,
|
||||
@@ -359,7 +360,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
has_initial_state=has_initial_state,
|
||||
cache_indices=non_spec_state_indices_tensor,
|
||||
query_start_loc=non_spec_query_start_loc,
|
||||
metadata=attn_metadata,
|
||||
metadata=attn_metadata_narrowed,
|
||||
).transpose(0, 1)
|
||||
v = causal_conv1d_fn(
|
||||
v_proj_states,
|
||||
@@ -370,11 +371,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
has_initial_state=has_initial_state,
|
||||
cache_indices=non_spec_state_indices_tensor,
|
||||
query_start_loc=non_spec_query_start_loc,
|
||||
metadata=attn_metadata,
|
||||
metadata=attn_metadata_narrowed,
|
||||
).transpose(0, 1)
|
||||
else:
|
||||
assert non_spec_state_indices_tensor is not None
|
||||
decode_conv_indices = non_spec_state_indices_tensor[
|
||||
: attn_metadata.num_actual_tokens
|
||||
: attn_metadata_narrowed.num_actual_tokens
|
||||
]
|
||||
q = causal_conv1d_update(
|
||||
q_proj_states,
|
||||
@@ -408,7 +410,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
|
||||
)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
if attn_metadata_narrowed.num_prefills > 0:
|
||||
assert non_spec_state_indices_tensor is not None
|
||||
assert has_initial_state is not None
|
||||
zero_idx = non_spec_state_indices_tensor[~has_initial_state]
|
||||
recurrent_state[zero_idx] = 0
|
||||
initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
|
||||
@@ -429,6 +433,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
# Init cache
|
||||
recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
|
||||
else:
|
||||
assert non_spec_query_start_loc is not None
|
||||
(
|
||||
core_attn_out_non_spec,
|
||||
last_recurrent_state,
|
||||
@@ -440,7 +445,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
beta=beta,
|
||||
initial_state=recurrent_state,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
|
||||
cu_seqlens=non_spec_query_start_loc[
|
||||
: attn_metadata_narrowed.num_decodes + 1
|
||||
],
|
||||
ssm_state_indices=non_spec_state_indices_tensor,
|
||||
)
|
||||
core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[
|
||||
|
||||
@@ -76,7 +76,7 @@ def poly_norm(
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.poly_norm(
|
||||
ops.poly_norm( # type: ignore[attr-defined]
|
||||
out,
|
||||
x,
|
||||
weight,
|
||||
|
||||
@@ -42,9 +42,10 @@ class MambaBase(AttentionLayerBase):
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
|
||||
mamba_block_size = vllm_config.cache_config.mamba_block_size
|
||||
assert mamba_block_size is not None
|
||||
page_size_padded = vllm_config.cache_config.mamba_page_size_padded
|
||||
return MambaSpec(
|
||||
shapes=self.get_state_shape(),
|
||||
shapes=tuple(self.get_state_shape()),
|
||||
dtypes=self.get_state_dtype(),
|
||||
block_size=mamba_block_size,
|
||||
page_size_padded=page_size_padded,
|
||||
|
||||
@@ -62,7 +62,6 @@ from vllm.utils.torch_utils import (
|
||||
_resolve_layer_name,
|
||||
direct_register_custom_op,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -121,9 +120,9 @@ def fi_chunk_gated_delta_rule(
|
||||
class ChunkGatedDeltaRule(CustomOp):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
backend_cfg = get_current_vllm_config().additional_config.get(
|
||||
"gdn_prefill_backend", "auto"
|
||||
)
|
||||
additional_config = get_current_vllm_config().additional_config
|
||||
assert isinstance(additional_config, dict)
|
||||
backend_cfg = additional_config.get("gdn_prefill_backend", "auto")
|
||||
backend = str(backend_cfg).strip().lower()
|
||||
|
||||
supports_flashinfer = (
|
||||
@@ -621,18 +620,19 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
# Part 2: Core Attention
|
||||
# ============================================================
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
attn_metadata_raw = forward_context.attn_metadata
|
||||
core_attn_out = torch.zeros(
|
||||
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
z = torch.empty_like(core_attn_out)
|
||||
if attn_metadata is not None:
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
if attn_metadata_raw is not None:
|
||||
assert isinstance(attn_metadata_raw, dict)
|
||||
attn_metadata = attn_metadata_raw[self.prefix]
|
||||
|
||||
# TODO: xpu does not support this param yet
|
||||
spec_sequence_masks = attn_metadata.spec_sequence_masks
|
||||
spec_sequence_masks = attn_metadata.spec_sequence_masks # type: ignore[attr-defined]
|
||||
assert spec_sequence_masks is None
|
||||
|
||||
conv_weights = self.conv1d.weight.view(
|
||||
@@ -658,12 +658,12 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
activation=self.activation,
|
||||
A_log=self.A_log,
|
||||
dt_bias=self.dt_bias,
|
||||
num_prefills=attn_metadata.num_prefills,
|
||||
num_decodes=attn_metadata.num_decodes,
|
||||
has_initial_state=attn_metadata.has_initial_state,
|
||||
non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc,
|
||||
non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor,
|
||||
num_actual_tokens=attn_metadata.num_actual_tokens,
|
||||
num_prefills=attn_metadata.num_prefills, # type: ignore[attr-defined]
|
||||
num_decodes=attn_metadata.num_decodes, # type: ignore[attr-defined]
|
||||
has_initial_state=attn_metadata.has_initial_state, # type: ignore[attr-defined]
|
||||
non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, # type: ignore[attr-defined]
|
||||
non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, # type: ignore[attr-defined]
|
||||
num_actual_tokens=attn_metadata.num_actual_tokens, # type: ignore[attr-defined]
|
||||
tp_size=self.tp_size,
|
||||
reorder_input=not self.gqa_interleaved_layout,
|
||||
)
|
||||
@@ -792,16 +792,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
core_attn_out: torch.Tensor,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
attn_metadata_raw = forward_context.attn_metadata
|
||||
|
||||
if attn_metadata is None:
|
||||
if attn_metadata_raw is None:
|
||||
# V1 profile run — warm up prefill kernels so that
|
||||
# autotuning completes before KV cache allocation.
|
||||
self._warmup_prefill_kernels(mixed_qkv)
|
||||
return
|
||||
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata_raw, dict)
|
||||
attn_metadata = attn_metadata_raw[self.prefix] # type: ignore[index]
|
||||
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
||||
|
||||
if (
|
||||
@@ -860,14 +860,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
|
||||
# 1.1: Process the multi-query part
|
||||
if spec_sequence_masks is not None:
|
||||
# spec_state_indices_tensor is always set when spec_sequence_masks is set
|
||||
assert spec_state_indices_tensor is not None
|
||||
mixed_qkv_spec = causal_conv1d_update(
|
||||
mixed_qkv_spec,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=spec_state_indices_tensor[:, 0][
|
||||
: attn_metadata.num_spec_decodes
|
||||
conv_state_indices=spec_state_indices_tensor[:, 0][ # type: ignore[index]
|
||||
: attn_metadata.num_spec_decodes # type: ignore[attr-defined]
|
||||
],
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
query_start_loc=spec_query_start_loc,
|
||||
@@ -900,8 +902,8 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=non_spec_state_indices_tensor[
|
||||
: attn_metadata.num_actual_tokens
|
||||
conv_state_indices=non_spec_state_indices_tensor[ # type: ignore[index]
|
||||
: attn_metadata.num_actual_tokens # type: ignore[attr-defined]
|
||||
],
|
||||
validate_data=True,
|
||||
)
|
||||
@@ -965,8 +967,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
v=value_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=spec_query_start_loc[
|
||||
: attn_metadata.num_spec_decodes + 1
|
||||
cu_seqlens=spec_query_start_loc[ # type: ignore[index]
|
||||
: attn_metadata.num_spec_decodes
|
||||
+ 1 # type: ignore[attr-defined]
|
||||
],
|
||||
ssm_state_indices=spec_state_indices_tensor,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
@@ -978,8 +981,10 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
|
||||
# 2.2: Process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
|
||||
initial_state[~has_initial_state, ...] = 0
|
||||
assert non_spec_state_indices_tensor is not None
|
||||
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() # type: ignore[index]
|
||||
assert has_initial_state is not None
|
||||
initial_state[~has_initial_state, ...] = 0 # type: ignore[operator]
|
||||
(
|
||||
core_attn_out_non_spec,
|
||||
last_recurrent_state,
|
||||
@@ -1012,8 +1017,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
v=value_non_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=non_spec_query_start_loc[
|
||||
: attn_metadata.num_decodes + 1
|
||||
cu_seqlens=non_spec_query_start_loc[ # type: ignore[index]
|
||||
: attn_metadata.num_decodes
|
||||
+ 1 # type: ignore[attr-defined]
|
||||
],
|
||||
ssm_state_indices=non_spec_state_indices_tensor,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
@@ -1073,7 +1079,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
|
||||
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index]
|
||||
validate_data=False,
|
||||
)
|
||||
out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
|
||||
@@ -1086,7 +1092,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
scale=self.head_k_dim**-0.5,
|
||||
initial_state=ssm_state,
|
||||
out=out_buf,
|
||||
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
|
||||
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index]
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -396,10 +396,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
|
||||
) -> None:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
attn_metadata_raw = forward_context.attn_metadata
|
||||
attn_metadata: AttentionMetadata | None = None
|
||||
if attn_metadata_raw is not None:
|
||||
assert isinstance(attn_metadata_raw, dict)
|
||||
attn_metadata = attn_metadata_raw[self.prefix]
|
||||
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||
num_actual_tokens = (
|
||||
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
|
||||
|
||||
@@ -40,6 +40,7 @@ from vllm.utils.torch_utils import (
|
||||
_resolve_layer_name,
|
||||
direct_register_custom_op,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
|
||||
|
||||
|
||||
@@ -258,15 +259,16 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
"""
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
attn_metadata_raw = forward_context.attn_metadata
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
|
||||
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
attn_metadata: AttentionMetadata | None = None
|
||||
if attn_metadata_raw is not None:
|
||||
assert isinstance(attn_metadata_raw, dict)
|
||||
attn_metadata = attn_metadata_raw[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba1AttentionMetadata)
|
||||
query_start_loc_p = attn_metadata.query_start_loc_p
|
||||
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
|
||||
@@ -391,6 +393,9 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
ssm_outputs.append(scan_out_p)
|
||||
|
||||
if has_decode:
|
||||
# state_indices_tensor_d is assigned when attn_metadata is not None,
|
||||
# and has_decode is only True when attn_metadata is not None
|
||||
assert state_indices_tensor_d is not None
|
||||
if is_mamba_cache_all:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_computed_token_d.unsqueeze(1)
|
||||
|
||||
@@ -572,14 +572,16 @@ class MambaMixer2(MambaBase, PluggableLayer):
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
attn_metadata_raw = forward_context.attn_metadata
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
|
||||
attn_metadata: AttentionMetadata | None = None
|
||||
if attn_metadata_raw is not None:
|
||||
assert isinstance(attn_metadata_raw, dict)
|
||||
attn_metadata = attn_metadata_raw[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
# conv_state must be (..., dim, width-1) for the conv kernels.
|
||||
# DS layout stores it that way directly; SD layout needs a
|
||||
@@ -708,6 +710,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
if has_initial_states_p is not None and prep_initial_states:
|
||||
assert state_indices_tensor_p is not None
|
||||
kernel_ssm_indices = state_indices_tensor_p
|
||||
if is_mamba_cache_all:
|
||||
kernel_ssm_indices = state_indices_tensor_p.gather(
|
||||
@@ -746,6 +749,13 @@ class MambaMixer2(MambaBase, PluggableLayer):
|
||||
)
|
||||
|
||||
if is_mamba_cache_all:
|
||||
assert mamba_block_size is not None
|
||||
assert state_indices_tensor_p is not None
|
||||
assert block_idx_first_scheduled_token_p is not None
|
||||
assert block_idx_last_scheduled_token_p is not None
|
||||
assert last_chunk_indices_p is not None
|
||||
assert num_computed_tokens_p is not None
|
||||
|
||||
# The chunk_stride is the number of chunks per mamba block
|
||||
# e.g., if mamba_block_size = 512 and chunk_size = 256,
|
||||
# then chunk_stride = 2
|
||||
@@ -810,6 +820,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
|
||||
ssm_state[cache_blocks_to_fill] = from_where
|
||||
|
||||
# For all seqs, store the last state (note: might be partial):
|
||||
assert state_indices_tensor_p is not None
|
||||
ssm_state[
|
||||
state_indices_tensor_p.gather(
|
||||
1, block_idx_last_scheduled_token_p.unsqueeze(1)
|
||||
@@ -820,10 +831,12 @@ class MambaMixer2(MambaBase, PluggableLayer):
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate)
|
||||
# tensor
|
||||
assert state_indices_tensor_p is not None
|
||||
ssm_state[state_indices_tensor_p] = varlen_states
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
assert state_indices_tensor_d is not None
|
||||
if is_mamba_cache_all:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_computed_token_d.unsqueeze(1)
|
||||
|
||||
@@ -113,10 +113,11 @@ class ShortConv(MambaBase, CustomOp):
|
||||
# chunked prefill modes; they are computed at top-level model forward
|
||||
# since they stay the same and reused for all mamba layers in the same
|
||||
# iteration.
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
attn_metadata_raw = forward_context.attn_metadata
|
||||
attn_metadata: AttentionMetadata | None = None
|
||||
if attn_metadata_raw is not None:
|
||||
assert isinstance(attn_metadata_raw, dict)
|
||||
attn_metadata = attn_metadata_raw[self.prefix]
|
||||
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
|
||||
conv_state = (
|
||||
self.kv_cache[0]
|
||||
|
||||
@@ -115,6 +115,7 @@ def pooler_for_classify(
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
assert model_config.pooler_config is not None
|
||||
head = ClassifierPoolerHead(
|
||||
head_dtype=model_config.head_dtype,
|
||||
classifier=classifier,
|
||||
|
||||
@@ -124,6 +124,7 @@ def pooler_for_token_classify(
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
assert model_config.pooler_config is not None
|
||||
head = TokenClassifierPoolerHead(
|
||||
head_dtype=model_config.head_dtype,
|
||||
classifier=classifier,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -251,7 +251,11 @@ class FPQuantLinearMethod(LinearMethodBase):
|
||||
def fused_quantize_mx(
|
||||
x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method)
|
||||
return fusedQuantizeMx(
|
||||
x_flat,
|
||||
hadamard_matrix,
|
||||
method=cast(Literal["quest", "abs_max"], forward_method),
|
||||
)
|
||||
|
||||
|
||||
def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method):
|
||||
|
||||
@@ -114,7 +114,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
|
||||
structure of the qconfig) to vllm model structure
|
||||
"""
|
||||
quant_config_with_hf_to_vllm_mapper = {}
|
||||
quant_config_with_hf_to_vllm_mapper: dict[str, Any] = {}
|
||||
|
||||
for k, v in self.quant_config.items():
|
||||
if isinstance(v, list):
|
||||
|
||||
@@ -26,7 +26,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm import _custom_ops as ops
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._xpu_ops import xpu_ops as ops
|
||||
from vllm._xpu_ops import xpu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -84,12 +84,12 @@ def sparse_attn_indexer(
|
||||
total_seq_lens,
|
||||
topk_indices_buffer,
|
||||
)
|
||||
attn_metadata = attn_metadata[k_cache_prefix]
|
||||
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
attn_metadata_narrowed = attn_metadata[k_cache_prefix]
|
||||
assert isinstance(attn_metadata_narrowed, DeepseekV32IndexerMetadata)
|
||||
slot_mapping = attn_metadata_narrowed.slot_mapping
|
||||
has_decode = attn_metadata_narrowed.num_decodes > 0
|
||||
has_prefill = attn_metadata_narrowed.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata_narrowed.num_decode_tokens
|
||||
|
||||
# During speculative decoding, k may be padded to the CUDA graph batch
|
||||
# size while slot_mapping only covers actual tokens. Truncate k to avoid
|
||||
@@ -97,6 +97,8 @@ def sparse_attn_indexer(
|
||||
num_tokens = slot_mapping.shape[0]
|
||||
k = k[:num_tokens]
|
||||
|
||||
# scale_fmt can be None, but the function expects str
|
||||
assert scale_fmt is not None
|
||||
ops.indexer_k_quant_and_cache(
|
||||
k,
|
||||
kv_cache,
|
||||
@@ -107,7 +109,7 @@ def sparse_attn_indexer(
|
||||
|
||||
topk_indices_buffer[: hidden_states.shape[0]] = -1
|
||||
if has_prefill:
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
prefill_metadata = attn_metadata_narrowed.prefill
|
||||
assert prefill_metadata is not None
|
||||
|
||||
# Get the full shared workspace buffers once (will allocate on first use)
|
||||
@@ -144,7 +146,7 @@ def sparse_attn_indexer(
|
||||
]
|
||||
|
||||
if current_platform.is_xpu():
|
||||
ops.top_k_per_row_prefill(
|
||||
xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined]
|
||||
logits,
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
@@ -167,7 +169,7 @@ def sparse_attn_indexer(
|
||||
)
|
||||
|
||||
if has_decode:
|
||||
decode_metadata = attn_metadata.decode
|
||||
decode_metadata = attn_metadata_narrowed.decode
|
||||
assert decode_metadata is not None
|
||||
# kv_cache shape [
|
||||
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
|
||||
@@ -217,11 +219,11 @@ def sparse_attn_indexer(
|
||||
topk_indices,
|
||||
topk_workspace,
|
||||
topk_tokens,
|
||||
attn_metadata.max_seq_len,
|
||||
attn_metadata_narrowed.max_seq_len,
|
||||
)
|
||||
else:
|
||||
if current_platform.is_xpu():
|
||||
ops.top_k_per_row_decode(
|
||||
xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined]
|
||||
logits,
|
||||
next_n,
|
||||
seq_lens,
|
||||
|
||||
Reference in New Issue
Block a user