[TRTLLM-9455][feat] support for new checkpoint (#10082)

Signed-off-by: binghanc <176802681+binghanc@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
binghanc 2025-12-30 14:46:39 +08:00 committed by GitHub
parent 3e0344a53d
commit 692d8f2023


@@ -37,6 +37,7 @@ from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
@@ -142,6 +143,44 @@ class DeepseekV3WeightLoader:
def load_weights(self, weights: Dict, skip_modules: List[str] = []):
def requantize_weight_with_new_scale(weight, weight_scale, old_scale_2,
new_scale_2, device):
"""
Dequantize FP4 weights and requantize with a new scale.
Args:
weight: FP4-quantized weight tensor (2-D, two e2m1 values packed per byte)
weight_scale: FP8 per-block scaling factors
old_scale_2: original global scale (amax/(448*6))
new_scale_2: new global scale (amax/(448*6))
device: target device for computation
Returns:
(requantized_weight, new_weight_scale)
"""
# Remember original dtype of weight_scale
original_scale_dtype = weight_scale.dtype
original_scale_shape = weight_scale.shape
# Dequantize
dequant_shape = (weight.shape[0], weight.shape[1] * 2)
weight_dequant = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2(
weight.contiguous(),
weight_scale.flatten().view(
fp4_utils.float4_sf_dtype).contiguous(), old_scale_2, 16, 1,
True).to(dtype=torch.bfloat16).reshape(dequant_shape)
# Requantize using the new_scale_2
weight_requant, weight_scale_requant = torch.ops.trtllm.fp4_quantize(
weight_dequant.to(device),
1.0 / new_scale_2.to(device),
16, # scaling_vector_size
False)
# Ensure the returned scale has the same dtype as the input scale
return weight_requant.cpu(), weight_scale_requant.reshape(
original_scale_shape).view(original_scale_dtype).cpu()
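# Illustrative sketch (not part of this commit): the two-level NVFP4 value model
# that the helper above round-trips. A stored element decodes roughly as
#   value ≈ e2m1_code * fp8_block_scale * scale_2,
# so switching to a new scale_2 requires recomputing both the e2m1 codes and the
# per-block FP8 scales, which is why the helper dequantizes to bfloat16 first and
# then calls fp4_quantize with 1 / new_scale_2. All numbers below are made up.
import torch
scale_2_old = torch.tensor(0.02)  # hypothetical global scale, amax / (448 * 6)
block_scale = torch.tensor(25.0)  # hypothetical FP8 per-block scale factor
e2m1_code = torch.tensor(3.0)     # one of the e2m1 grid values {0, 0.5, 1, 1.5, 2, 3, 4, 6}
decoded = e2m1_code * block_scale * scale_2_old  # ≈ 1.5, the dequantized weight value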
def rename_moe_weight(weights: Dict, rename_rules: Dict):
result = {}
for key, value in weights.items():
@@ -355,27 +394,128 @@ class DeepseekV3WeightLoader:
).view(*attn_module.v_b_proj_dequant.shape).to(
attn_module.v_b_proj_dequant.dtype))
elif names[-1] == "kv_a_proj_with_mqa":
nvfp4_fused_a = self.model_config.get_quant_config(
).layer_quant_mode.has_nvfp4() and weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"].dtype == fp4_utils.float4_e2m1x2 and weights[
f"{'.'.join(names[:-1])}.q_a_proj.weight"].dtype == fp4_utils.float4_e2m1x2
if nvfp4_fused_a:
########### input_scale
kv_a_proj_with_mqa_input_scale = weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.input_scale"]
if not is_lite:
q_a_proj_input_scale = weights[
f"{'.'.join(names[:-1])}.q_a_proj.input_scale"]
assert kv_a_proj_with_mqa_input_scale == q_a_proj_input_scale, "kv_a_proj_with_mqa.input_scale and q_a_proj.input_scale should be the same"
# modelopt ckpt stores amax/(448*6), convert to (448*6)/amax
shared_input_scale = kv_a_proj_with_mqa_input_scale
module.input_scale.data.copy_(1.0 / shared_input_scale)
E2M1_MAX = 6.0
module.inv_input_scale.data.copy_(module.input_scale /
E2M1_MAX)
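# Illustrative sketch (not part of this commit): how the activation-scale
# conventions above relate. The ModelOpt checkpoint stores amax / (448 * 6),
# the module keeps its reciprocal as input_scale, and inv_input_scale divides
# out the e2m1 maximum (6.0), leaving 448 / amax. Numbers are made up.
import torch
amax = torch.tensor(1344.0)              # hypothetical activation amax
ckpt_input_scale = amax / (448.0 * 6.0)  # stored in the checkpoint -> 0.5
input_scale = 1.0 / ckpt_input_scale     # module.input_scale -> 2.0
inv_input_scale = input_scale / 6.0      # module.inv_input_scale == 448 / amax -> 1/3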
########### weight_scale_2
need_requant_kv_a_proj_with_mqa = False
need_requant_q_a_proj = False
kv_a_proj_with_mqa_scale_2 = weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_2"]
shared_weight_scale_2 = kv_a_proj_with_mqa_scale_2
if not is_lite:
q_a_proj_scale_2 = weights[
f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_2"]
if kv_a_proj_with_mqa_scale_2 < q_a_proj_scale_2:
shared_weight_scale_2 = q_a_proj_scale_2
need_requant_kv_a_proj_with_mqa = True
elif q_a_proj_scale_2 < kv_a_proj_with_mqa_scale_2:
need_requant_q_a_proj = True
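# Illustrative sketch (not part of this commit): the fused q_a/kv_a GEMM carries
# a single global weight scale, so the larger scale_2 (the one covering the
# larger amax) is shared and the tensor quantized with the smaller scale_2 is
# requantized against it below; requantizing toward the smaller scale_2 could
# clip weights above its amax. The hypothetical helper mirrors the inline logic.
def pick_shared_scale_2(kv_scale_2, q_scale_2):
    # Returns (shared_scale_2, requant_kv, requant_q).
    if kv_scale_2 < q_scale_2:
        return q_scale_2, True, False
    if q_scale_2 < kv_scale_2:
        return kv_scale_2, False, True
    return kv_scale_2, False, False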
if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights:
fused_a_scale = weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"]
########### alpha
alpha = shared_input_scale.float(
) * shared_weight_scale_2.float()
module.alpha.data.copy_(alpha)
module.scalar_alpha = alpha.item()
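# Illustrative sketch (not part of this commit): alpha is the combined global
# dequantization factor for the FP4 GEMM output, i.e. the product of the two
# checkpoint-side global scales (each stored as amax / (448 * 6)); the per-block
# FP8 scales are handled separately. Toy numbers reuse the sketches above.
import torch
act_scale_2 = torch.tensor(0.5)   # hypothetical activation amax / (448 * 6)
wt_scale_2 = torch.tensor(0.02)   # hypothetical shared weight amax / (448 * 6)
alpha = act_scale_2.float() * wt_scale_2.float()  # 0.01, what module.alpha holds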
########### weights
kv_a_proj_with_mqa = weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
if not is_lite:
q_a_proj = weights[
f"{'.'.join(names[:-1])}.q_a_proj.weight"][:]
########### weight_scale
kv_a_proj_with_mqa_scale = weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale"][:]
kv_a_proj_with_mqa_scale = torch.ops.trtllm.block_scale_interleave(
kv_a_proj_with_mqa_scale.view(
fp4_utils.float4_sf_dtype))
if not is_lite:
q_a_proj_scale = weights[
f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:]
fused_a_scale = torch.cat(
[q_a_proj_scale, fused_a_scale], dim=0)
f"{'.'.join(names[:-1])}.q_a_proj.weight_scale"][:]
q_a_proj_scale = torch.ops.trtllm.block_scale_interleave(
q_a_proj_scale.view(fp4_utils.float4_sf_dtype))
########### requantize
if need_requant_kv_a_proj_with_mqa:
# requant kv_a_proj_with_mqa
kv_a_proj_with_mqa, kv_a_proj_with_mqa_scale = requantize_weight_with_new_scale(
kv_a_proj_with_mqa,
kv_a_proj_with_mqa_scale,
kv_a_proj_with_mqa_scale_2,
shared_weight_scale_2,
device=module.weight.device,
)
if need_requant_q_a_proj:
# requant q_a_proj
q_a_proj, q_a_proj_scale = requantize_weight_with_new_scale(
q_a_proj,
q_a_proj_scale,
q_a_proj_scale_2,
shared_weight_scale_2,
device=module.weight.device)
########### fuse and load weights
if not is_lite:
fused_a = torch.cat([q_a_proj, kv_a_proj_with_mqa],
dim=0)
else:
fused_a = kv_a_proj_with_mqa
# For DeepseekV32: kv_a_proj_with_mqa is oversized
# to include indexer k weights, which is filled in post_load_weights.
module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
########### fuse weight_scale
if not is_lite:
fused_a_scale = torch.cat(
[q_a_proj_scale, kv_a_proj_with_mqa_scale],
dim=0)
else:
fused_a_scale = kv_a_proj_with_mqa_scale
# For DeepseekV32: kv_a_proj_with_mqa is oversized
# to include indexer k weights, which is filled in post_load_weights.
module.weight_scale.data[0:fused_a_scale.
shape[0]].copy_(fused_a_scale)
else:
fused_a = weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
if not is_lite:
q_a_proj = weights[
f"{'.'.join(names[:-1])}.q_a_proj.weight"][:]
fused_a = torch.cat([q_a_proj, fused_a], dim=0)
if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights:
fused_a_scale = weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"]
if not is_lite:
q_a_proj_scale = weights[
f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:]
fused_a_scale = torch.cat(
[q_a_proj_scale, fused_a_scale], dim=0)
module.weight_scale.data[
0:fused_a_scale.shape[0]].copy_(fused_a_scale)
# For DeepseekV32: kv_a_proj_with_mqa is oversized
# to include indexer k weights, which is filled in post_load_weights.
module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
elif names[-1] in params_map:
module_weights = []
for new_name in params_map[names[-1]]: