Rename layer to comply with deepseek (#6393)

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
peaceh-nv 2025-07-30 10:00:48 +08:00 committed by GitHub
parent ab40369053
commit 5b420ad267
2 changed files with 9 additions and 8 deletions
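This change is a pure rename of the fused down-projection layer from fused_a to kv_a_proj_with_mqa, so the module attribute matches the layer name used in DeepSeek checkpoints (the weight loader below already looks up kv_a_proj_with_mqa.weight). As a rough sketch of what the renamed layer computes in the non-lite path, assuming DeepSeek-V3-style dimensions (the numbers here are illustrative, not taken from this diff):

import torch
import torch.nn as nn

# Illustrative dimensions only; the real values come from the model config.
hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 7168, 1536, 512, 64

# One fused down-projection covering the q LoRA, compressed-KV, and RoPE parts.
kv_a_proj_with_mqa = nn.Linear(
    hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim, bias=False)

hidden_states = torch.randn(2, hidden_size)
q, compressed_kv, k_pe = kv_a_proj_with_mqa(hidden_states).split(
    [q_lora_rank, kv_lora_rank, qk_rope_head_dim], dim=-1)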


@@ -248,7 +248,7 @@ class DeepseekV3Attention(MLA):
             dtype=config.torch_dtype,
             config=model_config,
             aux_stream=aux_stream)
-        self.fused_a = DeepseekV3Linear(
+        self.kv_a_proj_with_mqa = DeepseekV3Linear(
             config.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim +
             (self.q_lora_rank if not self.is_lite else 0),
@@ -1384,7 +1384,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
                 attn_module.v_b_proj_scale = nn.Parameter(
                     v_b_proj_scale, requires_grad=False)
-            elif names[-1] == "fused_a":
+            elif names[-1] == "kv_a_proj_with_mqa":
                 fused_a = weights[
                     f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
                 if not is_lite:


@@ -502,7 +502,7 @@ class MLA(nn.Module):
         self.quant_config = quant_config
         if not self.is_lite:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,
@@ -528,7 +528,7 @@ class MLA(nn.Module):
                 allreduce_strategy=config.allreduce_strategy,
                 force_dynamic_quantization=config.force_dynamic_quantization)
         else:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,
@@ -743,14 +743,15 @@ class MLA(nn.Module):
             torch.Tensor: The output tensor.
         """
         if self.is_lite:
-            compressed_kv, k_pe = self.fused_a(hidden_states).split(
+            compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)
             compressed_kv = self.kv_a_layernorm(compressed_kv)
             q = hidden_states
         else:
-            q, compressed_kv, k_pe = self.fused_a(hidden_states).split(
-                [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
-                -1)
+            q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states).split([
+                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
+                ], -1)
             q, compressed_kv = maybe_execute_in_parallel(
                 lambda: self.q_a_layernorm(q),