[mamba] unify KDA conv states into one cache to match 2-state SSM layout (#44539)

This commit is contained in:
Jiangyun Zhu
2026-06-05 02:38:05 +08:00
committed by GitHub
parent 41a4829f22
commit a55fccfc7c
3 changed files with 16 additions and 30 deletions
@@ -85,7 +85,7 @@ direct_register_custom_op(
class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
def get_state_dtype(
self,
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
) -> tuple[torch.dtype, torch.dtype]:
if self.model_config is None or self.cache_config is None:
raise ValueError("model_config and cache_config must be set")
return MambaStateDtypeCalculator.kda_state_dtype(
@@ -94,7 +94,7 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
def get_state_shape(
self,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.kda_state_shape(
self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size
)
@@ -300,13 +300,13 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
g1 = g1[:, :num_actual_tokens]
beta = beta[:, :num_actual_tokens]
(conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
(conv_state, recurrent_state) = constant_caches
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
if not is_conv_state_dim_first():
conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
conv_state = conv_state.transpose(-1, -2)
conv_state_q, conv_state_k, conv_state_v = conv_state.chunk(3, dim=-2)
q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
@@ -120,9 +120,9 @@ class MambaStateDtypeCalculator:
cls,
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
):
) -> tuple[torch.dtype, torch.dtype]:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype, state_dtype, torch.float32)
return (state_dtype, torch.float32)
class MambaStateShapeCalculator:
@@ -243,7 +243,7 @@ class MambaStateShapeCalculator:
head_k_dim: int | None = None,
conv_kernel_size: int = 4,
num_spec: int = 0,
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
) -> tuple[tuple[int, int], tuple[int, int, int]]:
if num_k_heads is None:
num_k_heads = num_heads
if head_k_dim is None:
@@ -252,19 +252,12 @@ class MambaStateShapeCalculator:
proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim
conv_dim = proj_size + 2 * proj_k_size
conv_state_shape = cls._orient_conv_shape(
divide(proj_size, tp_world_size), conv_kernel_size - 1
)
conv_state_k_shape = cls._orient_conv_shape(
divide(proj_k_size, tp_world_size), conv_kernel_size - 1
divide(conv_dim, tp_world_size), conv_kernel_size - 1
)
recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
return (
conv_state_shape,
conv_state_k_shape,
conv_state_k_shape,
recurrent_state_shape,
)
return (conv_state_shape, recurrent_state_shape)
@dataclass
@@ -365,9 +358,4 @@ class MambaStateCopyFuncCalculator:
@classmethod
def kda_state_copy_func(cls):
return (
get_conv_copy_spec,
get_conv_copy_spec,
get_conv_copy_spec,
get_temporal_copy_spec,
)
return (get_conv_copy_spec, get_temporal_copy_spec)
+3 -5
View File
@@ -600,7 +600,7 @@ class KimiLinearForCausalLM(
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.kda_state_dtype(
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
)
@@ -608,7 +608,7 @@ class KimiLinearForCausalLM(
@classmethod
def get_mamba_state_shape_from_config(
cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
) -> tuple[tuple[int, ...], tuple[int, ...]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
tp_size = parallel_config.tensor_parallel_size
@@ -628,9 +628,7 @@ class KimiLinearForCausalLM(
@classmethod
def get_mamba_state_copy_func(
cls,
) -> tuple[
MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
]:
) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.kda_state_copy_func()
def compute_logits(