From a55fccfc7cefc2d085d3557bace0e345cef67961 Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Fri, 5 Jun 2026 02:38:05 +0800 Subject: [PATCH] [mamba] unify KDA conv states into one cache to match 2-state SSM layout (#44539) --- .../layers/mamba/gdn/kimi_gdn_linear_attn.py | 12 ++++----- .../layers/mamba/mamba_utils.py | 26 +++++-------------- vllm/model_executor/models/kimi_linear.py | 8 +++--- 3 files changed, 16 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py index 59bab27c48c..23d7070cc80 100644 --- a/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py @@ -85,7 +85,7 @@ direct_register_custom_op( class KimiGatedDeltaNetAttention(GatedDeltaNetAttention): def get_state_dtype( self, - ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]: + ) -> tuple[torch.dtype, torch.dtype]: if self.model_config is None or self.cache_config is None: raise ValueError("model_config and cache_config must be set") return MambaStateDtypeCalculator.kda_state_dtype( @@ -94,7 +94,7 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention): def get_state_shape( self, - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + ) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.kda_state_shape( self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size ) @@ -300,13 +300,13 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention): g1 = g1[:, :num_actual_tokens] beta = beta[:, :num_actual_tokens] - (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches + (conv_state, recurrent_state) = constant_caches # conv_state must be (..., dim, width-1) for the conv kernels. # DS layout stores it that way directly; SD layout needs a transpose. if not is_conv_state_dim_first(): - conv_state_q = conv_state_q.transpose(-1, -2) - conv_state_k = conv_state_k.transpose(-1, -2) - conv_state_v = conv_state_v.transpose(-1, -2) + conv_state = conv_state.transpose(-1, -2) + + conv_state_q, conv_state_k, conv_state_v = conv_state.chunk(3, dim=-2) q_conv_weights = self.q_conv1d.weight.view( self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index c1fd81e40e3..0c86c787917 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -120,9 +120,9 @@ class MambaStateDtypeCalculator: cls, model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType, - ): + ) -> tuple[torch.dtype, torch.dtype]: state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) - return (state_dtype, state_dtype, state_dtype, torch.float32) + return (state_dtype, torch.float32) class MambaStateShapeCalculator: @@ -243,7 +243,7 @@ class MambaStateShapeCalculator: head_k_dim: int | None = None, conv_kernel_size: int = 4, num_spec: int = 0, - ) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]: + ) -> tuple[tuple[int, int], tuple[int, int, int]]: if num_k_heads is None: num_k_heads = num_heads if head_k_dim is None: @@ -252,19 +252,12 @@ class MambaStateShapeCalculator: proj_size = num_heads * head_dim proj_k_size = num_k_heads * head_k_dim + conv_dim = proj_size + 2 * proj_k_size conv_state_shape = cls._orient_conv_shape( - divide(proj_size, tp_world_size), conv_kernel_size - 1 - ) - conv_state_k_shape = cls._orient_conv_shape( - divide(proj_k_size, tp_world_size), conv_kernel_size - 1 + divide(conv_dim, tp_world_size), conv_kernel_size - 1 ) recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim) - return ( - conv_state_shape, - conv_state_k_shape, - conv_state_k_shape, - recurrent_state_shape, - ) + return (conv_state_shape, recurrent_state_shape) @dataclass @@ -365,9 +358,4 @@ class MambaStateCopyFuncCalculator: @classmethod def kda_state_copy_func(cls): - return ( - get_conv_copy_spec, - get_conv_copy_spec, - get_conv_copy_spec, - get_temporal_copy_spec, - ) + return (get_conv_copy_spec, get_temporal_copy_spec) diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py index a891950fa57..307b24ac112 100644 --- a/vllm/model_executor/models/kimi_linear.py +++ b/vllm/model_executor/models/kimi_linear.py @@ -600,7 +600,7 @@ class KimiLinearForCausalLM( def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", - ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]: + ) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.kda_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype ) @@ -608,7 +608,7 @@ class KimiLinearForCausalLM( @classmethod def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + ) -> tuple[tuple[int, ...], tuple[int, ...]]: parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config tp_size = parallel_config.tensor_parallel_size @@ -628,9 +628,7 @@ class KimiLinearForCausalLM( @classmethod def get_mamba_state_copy_func( cls, - ) -> tuple[ - MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc - ]: + ) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: return MambaStateCopyFuncCalculator.kda_state_copy_func() def compute_logits(