feat: support duplicate_kv_weight for qwen3 blockwise scale (#5459)

Signed-off-by: Jiying Dong <87510204+dongjiyingdjy@users.noreply.github.com>
dongjiyingdjy 2025-06-30 11:49:22 +08:00 committed by GitHub
parent 1db63c2546
commit 852b79053d
6 changed files with 36 additions and 30 deletions
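
A reading of the change, for context: `duplicate_kv_weight` used to infer the number of KV heads as `weight.shape[0] // head_dim`, which only holds for the K/V projection weight and bias themselves. FP8 blockwise-quantized checkpoints (the "qwen3 blockwise scale" case in the title) also carry a `weight_scale_inv` tensor whose leading dimension is the output dimension divided by the scale block size, so the head_dim-based inference breaks for it. The diff below therefore passes `num_kv_heads` in explicitly and derives each tensor's per-head slice from its own leading dimension. A rough sketch of the shape arithmetic, with hypothetical sizes and assuming 128x128 FP8 weight blocks (neither number is taken from the commit):

```python
# Hypothetical GQA projection: 8 KV heads of head_dim 128, FP8 block size 128.
num_kv_heads, head_dim, block = 8, 128, 128

k_weight_rows = num_kv_heads * head_dim  # leading dim of k_proj.weight           -> 1024
k_scale_rows = k_weight_rows // block    # leading dim of k_proj.weight_scale_inv -> 8

# Old inference, num_kv_heads = shape[0] // head_dim:
print(k_weight_rows // head_dim)         # 8 -> correct for the weight
print(k_scale_rows // head_dim)          # 0 -> meaningless for the blockwise scale

# New approach: num_kv_heads comes from the model config, and the per-head
# slice size is derived per tensor as shape[0] // num_kv_heads:
print(k_weight_rows // num_kv_heads)     # 128 weight rows per KV head
print(k_scale_rows // num_kv_heads)      # 1 scale row per KV head
```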

View File

@@ -239,7 +239,7 @@ class OPTForCausalLM(DecoderModelForCausalLM[OPTModel, OPTConfig]):
     def load_weights(self, weights: dict):
         tp_size = self.model_config.mapping.tp_size
-        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        num_kv_heads = self.model_config.pretrained_config.num_attention_heads
         def filter_weights(prefix: str, weights: dict):
             result = {}
@@ -280,7 +280,7 @@ class OPTForCausalLM(DecoderModelForCausalLM[OPTModel, OPTConfig]):
                     k:
                     duplicate_kv_weight(
                         weight=v[:],
-                        head_dim=head_dim,
+                        num_kv_heads=num_kv_heads,
                         tensor_parallel_size=tp_size)
                     if k in ['weight', 'bias'] else v
                     for k, v in fw.items()

View File

@@ -328,9 +328,7 @@ class Gemma3ForCausalLM(DecoderModelForCausalLM[Gemma3TextModel,
     # minor change for Gemma3 RMSNorm.
     def load_weights(self, weights: Dict):
         tp_size = self.model_config.mapping.tp_size
-        head_dim = getattr(
-            self.config, "head_dim",
-            self.config.hidden_size // self.config.num_attention_heads)
+        num_kv_heads = self.config.num_key_value_heads
         params_map = {
             'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -364,7 +362,7 @@ class Gemma3ForCausalLM(DecoderModelForCausalLM[Gemma3TextModel,
                     k:
                     duplicate_kv_weight(
                         weight=v[:],
-                        head_dim=head_dim,
+                        num_kv_heads=num_kv_heads,
                         tensor_parallel_size=tp_size)
                     if k in ["weight", "bias"] else v
                     for k, v in fw.items()

View File

@@ -333,8 +333,8 @@ class MllamaForConditionalGeneration(nn.Module):
         tp_size = self.config.mapping.tp_size
         vision_config = self.config.pretrained_config.vision_config
         text_config = self.config.pretrained_config.text_config
-        text_head_dim = text_config.hidden_size // text_config.num_attention_heads
-        vision_head_dim = vision_config.hidden_size // vision_config.attention_heads
+        text_config.hidden_size // text_config.num_attention_heads
+        vision_config.hidden_size // vision_config.attention_heads
         params_map = {
             'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -347,7 +347,7 @@ class MllamaForConditionalGeneration(nn.Module):
             # skip load weights if tie word embeddings is enabled and layer is lm_head
             if text_config.tie_word_embeddings and "lm_head" in name:
                 continue
-            head_dim = vision_head_dim if "vision_model" in name else text_head_dim
+            num_kv_heads = vision_config.num_key_value_heads if "vision_model" in name else text_config.num_key_value_heads
            names = name.split('.')
            if names[-1] in params_map:
@@ -360,7 +360,7 @@ class MllamaForConditionalGeneration(nn.Module):
                     k:
                     duplicate_kv_weight(
                         weight=v[:],
-                        head_dim=head_dim,
+                        num_kv_heads=num_kv_heads,
                         tensor_parallel_size=tp_size)
                     if k in ["weight", "bias"] else v
                     for k, v in fw.items()

View File

@@ -394,9 +394,7 @@ class Qwen3MoeForCausalLM(SpecDecOneEngineForCausalLM[Qwen3MoEModel,
         tp_size = self.model_config.mapping.tp_size
         enable_attention_dp = self.model_config.mapping.enable_attention_dp
-        head_dim = getattr(
-            self.config, "head_dim",
-            self.config.hidden_size // self.config.num_attention_heads)
+        num_kv_heads = self.config.num_key_value_heads
         params_map = {
             "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -419,11 +417,14 @@
                     tensors_need_duplication = ["weight", "bias"]
                     if module.quant_config.quant_mode.has_nvfp4():
                         tensors_need_duplication.append("weight_scale")
+                    if module.quant_config.quant_mode.has_fp8_block_scales(
+                    ):
+                        tensors_need_duplication.append("weight_scale_inv")
                     if new_name in ["k_proj", "v_proj"]:
                         fw = {
                             k: (duplicate_kv_weight(
                                 weight=v[:],
-                                head_dim=head_dim,
+                                num_kv_heads=num_kv_heads,
                                 tensor_parallel_size=tp_size
                                 if not enable_attention_dp else 1)
                                 if k in tensors_need_duplication else v)
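
With FP8 block scales enabled, `weight_scale_inv` now joins `weight`/`bias` (and `weight_scale` for NVFP4) in the duplication list above. The reason it needs duplicating at all: when `num_kv_heads < tp_size`, every KV-head slice of `k_proj`/`v_proj` is replicated for the extra TP ranks, and the per-block scale rows have to follow the same per-head pattern or they no longer line up with the duplicated weight rows. A self-contained illustration with made-up sizes (this mirrors, but is not, the `duplicate_kv_weight` helper changed in the last file of this diff):

```python
import torch

num_kv_heads, head_dim, hidden, tp_size = 4, 128, 2048, 8
reps = tp_size // num_kv_heads  # each KV head serves 2 ranks -> 2 copies

k_weight = torch.randn(num_kv_heads * head_dim, hidden)                # [512, 2048]
k_scale = torch.randn(num_kv_heads * head_dim // 128, hidden // 128)   # [4, 16], 128x128 blocks

def duplicate_per_kv_head(t: torch.Tensor) -> torch.Tensor:
    # Replicate each KV head's block of rows `reps` times, whatever the
    # per-head row count is (128 weight rows vs. 1 block-scale row here).
    rows_per_head = t.shape[0] // num_kv_heads
    return (t.reshape(num_kv_heads, 1, rows_per_head, -1)
             .expand(num_kv_heads, reps, rows_per_head, t.shape[1])
             .reshape(num_kv_heads * reps * rows_per_head, -1))

print(duplicate_per_kv_head(k_weight).shape)  # torch.Size([1024, 2048])
print(duplicate_per_kv_head(k_scale).shape)   # torch.Size([8, 16])
```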

View File

@@ -256,7 +256,7 @@ class Qwen2MoeForCausalLM(DecoderModelForCausalLM[QwenMoeModel,
     def load_weights(self, weights: Dict):
         tp_size = self.model_config.mapping.tp_size
-        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        num_kv_heads = self.config.num_key_value_heads
         params_map = {
             'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -281,7 +281,7 @@ class Qwen2MoeForCausalLM(DecoderModelForCausalLM[QwenMoeModel,
                     k:
                     duplicate_kv_weight(
                         weight=v[:],
-                        head_dim=head_dim,
+                        num_kv_heads=num_kv_heads,
                         tensor_parallel_size=tp_size)
                     if k in ["weight", "bias"] else v
                     for k, v in fw.items()

View File

@@ -93,11 +93,9 @@ class MetaInitMode(TorchDispatchMode):
             return func(*args, **kwargs)
-def duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
+def duplicate_kv_weight(weight: torch.Tensor, num_kv_heads: int,
                         tensor_parallel_size: int):
-    num_kv_heads = weight.shape[0] // head_dim
     if num_kv_heads >= tensor_parallel_size:
         assert num_kv_heads % tensor_parallel_size == 0
         return weight
@@ -109,11 +107,15 @@ def duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
     if weight.ndim == 1:
         return weight.repeat_interleave(reps)
-    # weight
-    weight = weight.reshape(num_kv_heads, head_dim,
+    # weight and scale
+    assert weight.shape[0] % num_kv_heads == 0
+    size_per_kv_head = weight.shape[0] // num_kv_heads
+    weight = weight.reshape(num_kv_heads, size_per_kv_head,
                             -1)[:, None, :, :].expand(num_kv_heads, reps,
-                                                      head_dim, weight.shape[1])
-    return weight.reshape(num_kv_heads * reps * head_dim, -1).clone().detach()
+                                                      size_per_kv_head,
+                                                      weight.shape[1])
+    return weight.reshape(num_kv_heads * reps * size_per_kv_head,
+                          -1).clone().detach()
 def iter_modules(
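
A quick usage sketch of the helper as it now reads above, with invented sizes (assuming `duplicate_kv_weight` is imported from this module):

```python
import torch

num_kv_heads, head_dim, hidden, tp_size = 2, 64, 512, 8

k_weight = torch.randn(num_kv_heads * head_dim, hidden)  # [128, 512]
k_bias = torch.randn(num_kv_heads * head_dim)            # [128]

# tp_size > num_kv_heads: each KV head slice is replicated 4 times.
print(duplicate_kv_weight(weight=k_weight, num_kv_heads=num_kv_heads,
                          tensor_parallel_size=tp_size).shape)  # torch.Size([512, 512])
print(duplicate_kv_weight(weight=k_bias, num_kv_heads=num_kv_heads,
                          tensor_parallel_size=tp_size).shape)  # torch.Size([512])

# tp_size <= num_kv_heads: the tensor is returned unchanged.
print(duplicate_kv_weight(weight=k_weight, num_kv_heads=num_kv_heads,
                          tensor_parallel_size=2).shape)        # torch.Size([128, 512])
```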
@@ -648,9 +650,9 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
     logger.info(f"Renamed weights with params_map: {params_map}")
     tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
-    head_dim = getattr(
-        model.config, "head_dim",
-        model.config.hidden_size // model.config.num_attention_heads)
+    num_kv_heads = model.config.num_key_value_heads if hasattr(
+        model.config, 'num_key_value_heads'
+    ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads
     params_map = {
         'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -687,13 +689,18 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
                 fw = filter_weights('.'.join(names[:-1] + [new_name]),
                                     weights)
                 if new_name in ['k_proj', 'v_proj']:
+                    num_kv_heads_list = [num_kv_heads
+                                         ] * len(fw) if isinstance(
+                                             num_kv_heads,
+                                             int) else num_kv_heads
                     fw = {
                         k:
-                        duplicate_kv_weight(weight=v[:],
-                                            head_dim=head_dim,
-                                            tensor_parallel_size=tp_size)
+                        duplicate_kv_weight(
+                            weight=v[:],
+                            num_kv_heads=num_kv_heads_list[i],
+                            tensor_parallel_size=tp_size)
                         if k in ["weight", "bias"] else v
-                        for k, v in fw.items()
+                        for i, (k, v) in enumerate(fw.items())
                     }
                 module_weights.append(fw)
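
In the generic `_load_weights_impl` path above, `num_kv_heads` may come out of the config as either an int or a list; the new `num_kv_heads_list` broadcasts the int case to one entry per filtered tensor, while a list-valued config is consumed positionally. A minimal sketch of just that broadcast (placeholder tensors, hypothetical value):

```python
# Filtered weights for one hypothetical k_proj module.
fw = {"weight": ..., "bias": ...}
num_kv_heads = 4  # e.g. model.config.num_key_value_heads

num_kv_heads_list = [num_kv_heads] * len(fw) if isinstance(
    num_kv_heads, int) else num_kv_heads
print(num_kv_heads_list)  # [4, 4] -> one entry per tensor passed to duplicate_kv_weight
```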