Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
feat: support duplicate_kv_weight for qwen3 blockwise scale (#5459)
Signed-off-by: Jiying Dong <87510204+dongjiyingdjy@users.noreply.github.com>
commit 852b79053d (parent 1db63c2546)
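The change below replaces duplicate_kv_weight's head_dim parameter with an explicit num_kv_heads, so the helper no longer infers the head count from the tensor shape and can therefore duplicate quantization scale tensors (NVFP4 weight_scale, FP8 blockwise weight_scale_inv) whose first dimension is far smaller than the full K/V projection. A minimal sketch of the shape assumption, with invented sizes (num_kv_heads, head_dim, hidden and the 128x128 block size are illustrative, not taken from this diff):

import torch

# Assumed sizes for one k_proj under GQA; only the divisibility matters.
num_kv_heads, head_dim, hidden = 4, 128, 2048
weight = torch.empty(num_kv_heads * head_dim, hidden)      # k_proj.weight: 512 x 2048
scale_inv = torch.empty((num_kv_heads * head_dim) // 128,  # FP8 block scales: 4 x 16
                        hidden // 128)

# Old contract: num_kv_heads = weight.shape[0] // head_dim, which breaks for
# scale_inv (4 // 128 == 0). New contract: the caller passes num_kv_heads and
# the helper only assumes that dim 0 splits evenly across KV heads.
for t in (weight, scale_inv):
    assert t.shape[0] % num_kv_heads == 0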
@@ -239,7 +239,7 @@ class OPTForCausalLM(DecoderModelForCausalLM[OPTModel, OPTConfig]):

     def load_weights(self, weights: dict):
         tp_size = self.model_config.mapping.tp_size
-        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        num_kv_heads = self.model_config.pretrained_config.num_attention_heads

         def filter_weights(prefix: str, weights: dict):
             result = {}
@@ -280,7 +280,7 @@ class OPTForCausalLM(DecoderModelForCausalLM[OPTModel, OPTConfig]):
                         k:
                         duplicate_kv_weight(
                             weight=v[:],
-                            head_dim=head_dim,
+                            num_kv_heads=num_kv_heads,
                             tensor_parallel_size=tp_size)
                         if k in ['weight', 'bias'] else v
                         for k, v in fw.items()
@@ -328,9 +328,7 @@ class Gemma3ForCausalLM(DecoderModelForCausalLM[Gemma3TextModel,
     # minor change for Gemma3 RMSNorm.
     def load_weights(self, weights: Dict):
         tp_size = self.model_config.mapping.tp_size
-        head_dim = getattr(
-            self.config, "head_dim",
-            self.config.hidden_size // self.config.num_attention_heads)
+        num_kv_heads = self.config.num_key_value_heads

         params_map = {
             'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -364,7 +362,7 @@ class Gemma3ForCausalLM(DecoderModelForCausalLM[Gemma3TextModel,
                         k:
                         duplicate_kv_weight(
                             weight=v[:],
-                            head_dim=head_dim,
+                            num_kv_heads=num_kv_heads,
                             tensor_parallel_size=tp_size)
                         if k in ["weight", "bias"] else v
                         for k, v in fw.items()
@@ -333,8 +333,8 @@ class MllamaForConditionalGeneration(nn.Module):
         tp_size = self.config.mapping.tp_size
         vision_config = self.config.pretrained_config.vision_config
         text_config = self.config.pretrained_config.text_config
-        text_head_dim = text_config.hidden_size // text_config.num_attention_heads
-        vision_head_dim = vision_config.hidden_size // vision_config.attention_heads
+        text_config.hidden_size // text_config.num_attention_heads
+        vision_config.hidden_size // vision_config.attention_heads

         params_map = {
             'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -347,7 +347,7 @@ class MllamaForConditionalGeneration(nn.Module):
             # skip load weights if tie word embeddings is enabled and layer is lm_head
             if text_config.tie_word_embeddings and "lm_head" in name:
                 continue
-            head_dim = vision_head_dim if "vision_model" in name else text_head_dim
+            num_kv_heads = vision_config.num_key_value_heads if "vision_model" in name else text_config.num_key_value_heads

             names = name.split('.')
             if names[-1] in params_map:
@@ -360,7 +360,7 @@ class MllamaForConditionalGeneration(nn.Module):
                         k:
                         duplicate_kv_weight(
                             weight=v[:],
-                            head_dim=head_dim,
+                            num_kv_heads=num_kv_heads,
                             tensor_parallel_size=tp_size)
                         if k in ["weight", "bias"] else v
                         for k, v in fw.items()
@@ -394,9 +394,7 @@ class Qwen3MoeForCausalLM(SpecDecOneEngineForCausalLM[Qwen3MoEModel,
         tp_size = self.model_config.mapping.tp_size
         enable_attention_dp = self.model_config.mapping.enable_attention_dp

-        head_dim = getattr(
-            self.config, "head_dim",
-            self.config.hidden_size // self.config.num_attention_heads)
+        num_kv_heads = self.config.num_key_value_heads

         params_map = {
             "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -419,11 +417,14 @@ class Qwen3MoeForCausalLM(SpecDecOneEngineForCausalLM[Qwen3MoEModel,
                         tensors_need_duplication = ["weight", "bias"]
                         if module.quant_config.quant_mode.has_nvfp4():
                             tensors_need_duplication.append("weight_scale")
+                        if module.quant_config.quant_mode.has_fp8_block_scales(
+                        ):
+                            tensors_need_duplication.append("weight_scale_inv")
                         if new_name in ["k_proj", "v_proj"]:
                             fw = {
                                 k: (duplicate_kv_weight(
                                     weight=v[:],
-                                    head_dim=head_dim,
+                                    num_kv_heads=num_kv_heads,
                                     tensor_parallel_size=tp_size
                                     if not enable_attention_dp else 1)
                                     if k in tensors_need_duplication else v)
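The weight_scale_inv entry added above is what makes FP8 blockwise checkpoints load correctly: with one inverse scale per 128x128 weight block, a k_proj scale tensor has (num_kv_heads * head_dim) / 128 rows, and it must be duplicated across tensor-parallel ranks just like the weight itself. A hedged shape check, assuming a Qwen3-style head_dim of 128 and a 128x128 block size (both assumptions, not read from the diff):

num_kv_heads, head_dim, hidden = 4, 128, 4096
weight_rows = num_kv_heads * head_dim    # k_proj.weight has 512 rows
scale_rows = weight_rows // 128          # weight_scale_inv has 4 rows

# Each KV head owns head_dim // 128 rows of the scale tensor, so the new
# assert in duplicate_kv_weight (dim 0 divisible by num_kv_heads) holds.
assert scale_rows % num_kv_heads == 0
print(scale_rows // num_kv_heads)        # 1 scale row per KV head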
@@ -256,7 +256,7 @@ class Qwen2MoeForCausalLM(DecoderModelForCausalLM[QwenMoeModel,

     def load_weights(self, weights: Dict):
         tp_size = self.model_config.mapping.tp_size
-        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        num_kv_heads = self.config.num_key_value_heads

         params_map = {
             'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -281,7 +281,7 @@ class Qwen2MoeForCausalLM(DecoderModelForCausalLM[QwenMoeModel,
                         k:
                         duplicate_kv_weight(
                             weight=v[:],
-                            head_dim=head_dim,
+                            num_kv_heads=num_kv_heads,
                             tensor_parallel_size=tp_size)
                         if k in ["weight", "bias"] else v
                         for k, v in fw.items()
@@ -93,11 +93,9 @@ class MetaInitMode(TorchDispatchMode):
             return func(*args, **kwargs)


-def duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
+def duplicate_kv_weight(weight: torch.Tensor, num_kv_heads: int,
                         tensor_parallel_size: int):

-    num_kv_heads = weight.shape[0] // head_dim
-
     if num_kv_heads >= tensor_parallel_size:
         assert num_kv_heads % tensor_parallel_size == 0
         return weight
@@ -109,11 +107,15 @@ def duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
     if weight.ndim == 1:
         return weight.repeat_interleave(reps)

-    # weight
-    weight = weight.reshape(num_kv_heads, head_dim,
-                            -1)[:, None, :, :].expand(num_kv_heads, reps,
-                                                      head_dim, weight.shape[1])
-    return weight.reshape(num_kv_heads * reps * head_dim, -1).clone().detach()
+    # weight and scale
+    assert weight.shape[0] % num_kv_heads == 0
+    size_per_kv_head = weight.shape[0] // num_kv_heads
+    weight = weight.reshape(num_kv_heads, size_per_kv_head,
+                            -1)[:, None, :, :].expand(num_kv_heads, reps,
+                                                      size_per_kv_head,
+                                                      weight.shape[1])
+    return weight.reshape(num_kv_heads * reps * size_per_kv_head,
+                          -1).clone().detach()


 def iter_modules(
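For reference, a standalone sketch of what the rewritten helper computes; this is an illustrative torch re-implementation with invented sizes, not the library function itself:

import torch


def duplicate_rows(t: torch.Tensor, num_kv_heads: int, tp_size: int) -> torch.Tensor:
    # Repeat each KV head's slice of dim 0 so every TP rank gets a copy.
    if num_kv_heads >= tp_size:
        return t
    reps = tp_size // num_kv_heads
    if t.ndim == 1:  # bias-like tensors
        return t.repeat_interleave(reps)
    size_per_kv_head = t.shape[0] // num_kv_heads
    return (t.reshape(num_kv_heads, 1, size_per_kv_head, -1)
            .expand(num_kv_heads, reps, size_per_kv_head, t.shape[1])
            .reshape(num_kv_heads * reps * size_per_kv_head, -1)
            .contiguous())


w = torch.randn(4 * 128, 2048)        # k_proj.weight with num_kv_heads=4
s = torch.randn(4, 16)                # its 128x128 blockwise scales
print(duplicate_rows(w, 4, 8).shape)  # torch.Size([1024, 2048])
print(duplicate_rows(s, 4, 8).shape)  # torch.Size([8, 16])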
@@ -648,9 +650,9 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
         logger.info(f"Renamed weights with params_map: {params_map}")

     tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
-    head_dim = getattr(
-        model.config, "head_dim",
-        model.config.hidden_size // model.config.num_attention_heads)
+    num_kv_heads = model.config.num_key_value_heads if hasattr(
+        model.config, 'num_key_value_heads'
+    ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads

     params_map = {
         'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
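The replacement above also changes the fallback: configs without num_key_value_heads (plain multi-head attention) are treated as having one KV head per attention head, in which case num_kv_heads is large and duplicate_kv_weight becomes a no-op for any practical tp_size. A small sketch with a hypothetical config object (_Cfg is made up for illustration):

class _Cfg:                   # hypothetical stand-in for a pretrained config
    num_attention_heads = 32  # no num_key_value_heads attribute


cfg = _Cfg()
num_kv_heads = (cfg.num_key_value_heads
                if getattr(cfg, "num_key_value_heads", None) is not None
                else cfg.num_attention_heads)
print(num_kv_heads)           # 32, so no KV duplication is needed for tp_size <= 32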
@@ -687,13 +689,18 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
                 fw = filter_weights('.'.join(names[:-1] + [new_name]),
                                     weights)
                 if new_name in ['k_proj', 'v_proj']:
+                    num_kv_heads_list = [num_kv_heads
+                                         ] * len(fw) if isinstance(
+                                             num_kv_heads,
+                                             int) else num_kv_heads
                     fw = {
                         k:
-                        duplicate_kv_weight(weight=v[:],
-                                            head_dim=head_dim,
-                                            tensor_parallel_size=tp_size)
+                        duplicate_kv_weight(
+                            weight=v[:],
+                            num_kv_heads=num_kv_heads_list[i],
+                            tensor_parallel_size=tp_size)
                         if k in ["weight", "bias"] else v
-                        for k, v in fw.items()
+                        for i, (k, v) in enumerate(fw.items())
                     }

                 module_weights.append(fw)
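In this generic path num_kv_heads may arrive as a single int or as one value per tensor in fw; the isinstance check broadcasts the int case and the switch to enumerate pairs each tensor with its own count by position. A compact illustration of just that pairing, with placeholder values:

fw = {"weight": "W", "weight_scale_inv": "S", "bias": "B"}  # placeholder values
num_kv_heads = 4                                            # could also be e.g. [4, 4, 4]

num_kv_heads_list = [num_kv_heads] * len(fw) if isinstance(
    num_kv_heads, int) else num_kv_heads
paired = {
    k: (v, num_kv_heads_list[i])
    for i, (k, v) in enumerate(fw.items())
}
print(paired)  # {'weight': ('W', 4), 'weight_scale_inv': ('S', 4), 'bias': ('B', 4)}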