mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[mamba] unify KDA conv states into one cache to match 2-state SSM layout (#44539)
This commit is contained in:
@@ -85,7 +85,7 @@ direct_register_custom_op(
|
||||
class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
|
||||
def get_state_dtype(
|
||||
self,
|
||||
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
if self.model_config is None or self.cache_config is None:
|
||||
raise ValueError("model_config and cache_config must be set")
|
||||
return MambaStateDtypeCalculator.kda_state_dtype(
|
||||
@@ -94,7 +94,7 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
|
||||
|
||||
def get_state_shape(
|
||||
self,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.kda_state_shape(
|
||||
self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size
|
||||
)
|
||||
@@ -300,13 +300,13 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
|
||||
g1 = g1[:, :num_actual_tokens]
|
||||
beta = beta[:, :num_actual_tokens]
|
||||
|
||||
(conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
|
||||
(conv_state, recurrent_state) = constant_caches
|
||||
# conv_state must be (..., dim, width-1) for the conv kernels.
|
||||
# DS layout stores it that way directly; SD layout needs a transpose.
|
||||
if not is_conv_state_dim_first():
|
||||
conv_state_q = conv_state_q.transpose(-1, -2)
|
||||
conv_state_k = conv_state_k.transpose(-1, -2)
|
||||
conv_state_v = conv_state_v.transpose(-1, -2)
|
||||
conv_state = conv_state.transpose(-1, -2)
|
||||
|
||||
conv_state_q, conv_state_k, conv_state_v = conv_state.chunk(3, dim=-2)
|
||||
|
||||
q_conv_weights = self.q_conv1d.weight.view(
|
||||
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
|
||||
|
||||
@@ -120,9 +120,9 @@ class MambaStateDtypeCalculator:
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
):
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype, state_dtype, state_dtype, torch.float32)
|
||||
return (state_dtype, torch.float32)
|
||||
|
||||
|
||||
class MambaStateShapeCalculator:
|
||||
@@ -243,7 +243,7 @@ class MambaStateShapeCalculator:
|
||||
head_k_dim: int | None = None,
|
||||
conv_kernel_size: int = 4,
|
||||
num_spec: int = 0,
|
||||
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
if num_k_heads is None:
|
||||
num_k_heads = num_heads
|
||||
if head_k_dim is None:
|
||||
@@ -252,19 +252,12 @@ class MambaStateShapeCalculator:
|
||||
proj_size = num_heads * head_dim
|
||||
proj_k_size = num_k_heads * head_k_dim
|
||||
|
||||
conv_dim = proj_size + 2 * proj_k_size
|
||||
conv_state_shape = cls._orient_conv_shape(
|
||||
divide(proj_size, tp_world_size), conv_kernel_size - 1
|
||||
)
|
||||
conv_state_k_shape = cls._orient_conv_shape(
|
||||
divide(proj_k_size, tp_world_size), conv_kernel_size - 1
|
||||
divide(conv_dim, tp_world_size), conv_kernel_size - 1
|
||||
)
|
||||
recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
|
||||
return (
|
||||
conv_state_shape,
|
||||
conv_state_k_shape,
|
||||
conv_state_k_shape,
|
||||
recurrent_state_shape,
|
||||
)
|
||||
return (conv_state_shape, recurrent_state_shape)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -365,9 +358,4 @@ class MambaStateCopyFuncCalculator:
|
||||
|
||||
@classmethod
|
||||
def kda_state_copy_func(cls):
|
||||
return (
|
||||
get_conv_copy_spec,
|
||||
get_conv_copy_spec,
|
||||
get_conv_copy_spec,
|
||||
get_temporal_copy_spec,
|
||||
)
|
||||
return (get_conv_copy_spec, get_temporal_copy_spec)
|
||||
|
||||
@@ -600,7 +600,7 @@ class KimiLinearForCausalLM(
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
return MambaStateDtypeCalculator.kda_state_dtype(
|
||||
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
|
||||
)
|
||||
@@ -608,7 +608,7 @@ class KimiLinearForCausalLM(
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls, vllm_config: "VllmConfig"
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
@@ -628,9 +628,7 @@ class KimiLinearForCausalLM(
|
||||
@classmethod
|
||||
def get_mamba_state_copy_func(
|
||||
cls,
|
||||
) -> tuple[
|
||||
MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
|
||||
]:
|
||||
) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
|
||||
return MambaStateCopyFuncCalculator.kda_state_copy_func()
|
||||
|
||||
def compute_logits(
|
||||
|
||||
Reference in New Issue
Block a user