Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
Fix TP8 for NVFP4 kv duplication. (#4143)
Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>
parent d1fa80dee3
commit b0dd581e6b
@@ -280,13 +280,16 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
                 for new_name in params_map[names[-1]]:
                     fw = filter_weights(".".join(names[:-1] + [new_name]),
                                         weights)
+                    tensors_need_duplication = ["weight", "bias"]
+                    if module.quant_config.quant_mode.has_nvfp4():
+                        tensors_need_duplication.append("weight_scale")
                     if new_name in ["k_proj", "v_proj"]:
                         fw = {
                             k: (duplicate_kv_weight(
                                 weight=v[:],
                                 head_dim=head_dim,
                                 tensor_parallel_size=tp_size)
-                                if k in ["weight", "bias"] else v)
+                                if k in tensors_need_duplication else v)
                             for k, v in fw.items()
                         }
                     module_weights.append(fw)
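For readers who want the mechanics behind the diff, here is a minimal sketch of KV-head duplication under tensor parallelism. The helper name duplicate_kv_heads, the derivation of the head count from head_dim, and the row-major [num_kv_heads * head_dim, ...] layout are assumptions for illustration; this is not TensorRT-LLM's duplicate_kv_weight. The commit's point is that under NVFP4 quantization the block-scale tensor weight_scale is also indexed by output row, so it must be replicated head-by-head along with the packed weight and bias, otherwise the duplicated k_proj/v_proj rows on a TP8 rank would no longer line up with their scales.

import torch

def duplicate_kv_heads(tensor: torch.Tensor, head_dim: int,
                       tensor_parallel_size: int) -> torch.Tensor:
    """Replicate KV heads so every tensor-parallel rank owns at least one.

    Illustrative only. Works for any tensor laid out head-by-head along
    dim 0: k_proj/v_proj weights [num_kv_heads * head_dim, hidden],
    biases [num_kv_heads * head_dim], and row-wise NVFP4 weight_scale blocks.
    """
    num_kv_heads = tensor.shape[0] // head_dim
    if num_kv_heads >= tensor_parallel_size:
        return tensor  # every rank already gets at least one KV head
    reps = tensor_parallel_size // num_kv_heads
    # Group rows by KV head, repeat each head `reps` times contiguously,
    # then restore the original trailing dimensions.
    per_head = tensor.reshape(num_kv_heads, -1)
    per_head = per_head.repeat_interleave(reps, dim=0)
    return per_head.reshape(reps * tensor.shape[0], *tensor.shape[1:])

For example, with TP8 and 4 KV heads each head is repeated twice, so every rank still receives a contiguous head slice. Adding "weight_scale" to tensors_need_duplication when quant_mode.has_nvfp4() applies the same replication to the scales, keeping them aligned with the duplicated weight rows.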