Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Rename layer to comply with deepseek (#6393)
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
parent ab40369053
commit 5b420ad267
@@ -248,7 +248,7 @@ class DeepseekV3Attention(MLA):
             dtype=config.torch_dtype,
             config=model_config,
             aux_stream=aux_stream)
-        self.fused_a = DeepseekV3Linear(
+        self.kv_a_proj_with_mqa = DeepseekV3Linear(
             config.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim +
             (self.q_lora_rank if not self.is_lite else 0),
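For context, the renamed attribute carries the same fused down-projection as before: a single matmul whose output stacks the KV latent, the RoPE key slice, and, for the non-lite model, the Q latent. A minimal sketch of that width arithmetic, assuming the published DeepSeek-V3 config values (q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64), which are not taken from this diff:

# Hypothetical illustration of the fused projection width; the dimension
# values are the published DeepSeek-V3 config, not read from this repository.
kv_lora_rank = 512
qk_rope_head_dim = 64
q_lora_rank = 1536        # the "lite" variant has no Q latent
is_lite = False

out_features = kv_lora_rank + qk_rope_head_dim + (
    q_lora_rank if not is_lite else 0)
print(out_features)  # 2112 for DeepSeek-V3, 576 for the lite variant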
@@ -1384,7 +1384,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
                     attn_module.v_b_proj_scale = nn.Parameter(
                         v_b_proj_scale, requires_grad=False)
 
-            elif names[-1] == "fused_a":
+            elif names[-1] == "kv_a_proj_with_mqa":
                 fused_a = weights[
                     f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
                 if not is_lite:
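The loader branch above dispatches on the last component of a dotted parameter path, and the checkpoint key it reads already ends in kv_a_proj_with_mqa.weight, so renaming the module attribute makes the two names coincide. A small, self-contained sketch of that matching idea, assuming a hypothetical weights dict and module path rather than the repository's real loader:

# Hypothetical sketch of name-based weight matching after the rename.
import torch

weights = {
    "model.layers.0.self_attn.kv_a_proj_with_mqa.weight":
    torch.zeros(2112, 7168),
}
module_path = "model.layers.0.self_attn.kv_a_proj_with_mqa"

names = module_path.split(".")
if names[-1] == "kv_a_proj_with_mqa":
    # Module attribute and checkpoint key now share the same suffix,
    # so the key can be rebuilt directly from the module path.
    fused_a = weights[f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
    print(fused_a.shape)  # torch.Size([2112, 7168])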
@@ -502,7 +502,7 @@ class MLA(nn.Module):
         self.quant_config = quant_config
 
         if not self.is_lite:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,
@@ -528,7 +528,7 @@ class MLA(nn.Module):
                 allreduce_strategy=config.allreduce_strategy,
                 force_dynamic_quantization=config.force_dynamic_quantization)
         else:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,
@@ -743,14 +743,15 @@ class MLA(nn.Module):
             torch.Tensor: The output tensor.
         """
         if self.is_lite:
-            compressed_kv, k_pe = self.fused_a(hidden_states).split(
+            compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)
             compressed_kv = self.kv_a_layernorm(compressed_kv)
             q = hidden_states
         else:
-            q, compressed_kv, k_pe = self.fused_a(hidden_states).split(
-                [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
-                -1)
+            q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states).split([
+                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
+                ], -1)
 
             q, compressed_kv = maybe_execute_in_parallel(
                 lambda: self.q_a_layernorm(q),
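The forward changes above are call-site renames only; the split sizes are unchanged. A short sketch of what the non-lite split produces, using a plain nn.Linear as a stand-in for the renamed projection and the same assumed DeepSeek-V3 dimensions as in the earlier sketch:

# Hypothetical stand-in for the MLA down-projection split (non-lite path).
import torch
from torch import nn

hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 7168, 1536, 512, 64
kv_a_proj_with_mqa = nn.Linear(
    hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim, bias=False)

hidden_states = torch.randn(2, hidden_size)
q, compressed_kv, k_pe = kv_a_proj_with_mqa(hidden_states).split(
    [q_lora_rank, kv_lora_rank, qk_rope_head_dim], -1)
print(q.shape, compressed_kv.shape, k_pe.shape)
# torch.Size([2, 1536]) torch.Size([2, 512]) torch.Size([2, 64])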