Fix TP8 for NVFP4 kv dupilcation. (#4143)

Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>
This commit is contained in:
Tracin 2025-05-08 17:30:02 +08:00 committed by GitHub
parent d1fa80dee3
commit b0dd581e6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -280,13 +280,16 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
for new_name in params_map[names[-1]]:
fw = filter_weights(".".join(names[:-1] + [new_name]),
weights)
tensors_need_duplication = ["weight", "bias"]
if module.quant_config.quant_mode.has_nvfp4():
tensors_need_duplication.append("weight_scale")
if new_name in ["k_proj", "v_proj"]:
fw = {
k: (duplicate_kv_weight(
weight=v[:],
head_dim=head_dim,
tensor_parallel_size=tp_size)
if k in ["weight", "bias"] else v)
if k in tensors_need_duplication else v)
for k, v in fw.items()
}
module_weights.append(fw)