From b0dd581e6b9363e693a7d8daf3052d2acaec5f25 Mon Sep 17 00:00:00 2001
From: Tracin <10434017+Tracin@users.noreply.github.com>
Date: Thu, 8 May 2025 17:30:02 +0800
Subject: [PATCH] Fix TP8 for NVFP4 kv duplication. (#4143)

Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>
---
 tensorrt_llm/_torch/models/modeling_qwen3_moe.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
index 3d43e3e016..943c2e2ed2 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -280,13 +280,16 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
                 for new_name in params_map[names[-1]]:
                     fw = filter_weights(".".join(names[:-1] + [new_name]),
                                         weights)
+                    tensors_need_duplication = ["weight", "bias"]
+                    if module.quant_config.quant_mode.has_nvfp4():
+                        tensors_need_duplication.append("weight_scale")
                     if new_name in ["k_proj", "v_proj"]:
                         fw = {
                             k: (duplicate_kv_weight(
                                 weight=v[:],
                                 head_dim=head_dim,
                                 tensor_parallel_size=tp_size)
-                                if k in ["weight", "bias"] else v)
+                                if k in tensors_need_duplication else v)
                             for k, v in fw.items()
                         }
                     module_weights.append(fw)
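
Editor's note: before this patch, only the "weight" and "bias" tensors of k_proj/v_proj were passed through duplicate_kv_weight. With an NVFP4 checkpoint, the per-block "weight_scale" tensor therefore kept its original row count and no longer lined up with the duplicated weights when TP8 requires replicating KV heads across ranks. The sketch below illustrates the row-wise KV-head duplication the helper performs; the body of duplicate_kv_weight here is an illustrative assumption (only its keyword signature comes from the call site above), and the NVFP4 block size of 16 is likewise assumed.

import torch


def duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
                        tensor_parallel_size: int) -> torch.Tensor:
    """Replicate KV heads so every TP rank receives at least one head.

    Works for 2-D weights [num_kv_heads * head_dim, in_dim], 1-D biases
    [num_kv_heads * head_dim], and 2-D NVFP4 scales
    [num_kv_heads * head_dim, in_dim // 16], because all of them index
    KV heads along dim 0.
    """
    num_kv_heads = weight.shape[0] // head_dim
    if num_kv_heads >= tensor_parallel_size:
        return weight  # at least one head per rank already, nothing to do

    reps = -(-tensor_parallel_size // num_kv_heads)  # ceil division
    # Split dim 0 into heads, repeat each head `reps` times, flatten back.
    w = weight.reshape(num_kv_heads, head_dim, -1)
    w = w.repeat_interleave(reps, dim=0)
    return w.reshape(num_kv_heads * reps * head_dim, *weight.shape[1:])


if __name__ == "__main__":
    head_dim, num_kv_heads, tp_size, in_dim = 128, 4, 8, 2048
    k_weight = torch.randn(num_kv_heads * head_dim, in_dim)
    # NVFP4: one scale per 16-element block of the input dim (assumed here).
    k_scale = torch.randn(num_kv_heads * head_dim, in_dim // 16)

    dup_w = duplicate_kv_weight(weight=k_weight, head_dim=head_dim,
                                tensor_parallel_size=tp_size)
    dup_s = duplicate_kv_weight(weight=k_scale, head_dim=head_dim,
                                tensor_parallel_size=tp_size)
    # Both tensors now cover 8 KV heads, so rows stay aligned on every rank.
    print(dup_w.shape)  # torch.Size([1024, 2048])
    print(dup_s.shape)  # torch.Size([1024, 128])

Because NVFP4 weight scales share the weight's output-row layout (one scale row per output feature), the same head-wise replication applies to them unchanged, which is why the patch only appends "weight_scale" to the duplication list rather than adding a separate code path.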