mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[TRTLLM-9455][feat] support for new checkpoint (#10082)
Signed-off-by: binghanc <176802681+binghanc@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
parent
3e0344a53d
commit
692d8f2023
@@ -37,6 +37,7 @@ from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
@@ -142,6 +143,44 @@ class DeepseekV3WeightLoader:

    def load_weights(self, weights: Dict, skip_modules: List[str] = []):

        def requantize_weight_with_new_scale(weight, weight_scale, old_scale_2,
                                             new_scale_2, device):
            """
            Dequantize an FP4 weight tensor and requantize it with a new global scale.

            Args:
                weight: FP4-quantized 2D weight tensor
                weight_scale: FP8 per-block scaling factors
                old_scale_2: original global scale (amax / (448 * 6))
                new_scale_2: new global scale (amax / (448 * 6))
                device: target device for the computation

            Returns:
                (requantized_weight, new_weight_scale)
            """
            # Remember the original dtype and shape of weight_scale.
            original_scale_dtype = weight_scale.dtype
            original_scale_shape = weight_scale.shape

            # Dequantize: each packed FP4 byte holds two e2m1 values, so the
            # dequantized tensor is twice as wide as the packed one.
            dequant_shape = (weight.shape[0], weight.shape[1] * 2)
            weight_dequant = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2(
                weight.contiguous(),
                weight_scale.flatten().view(
                    fp4_utils.float4_sf_dtype).contiguous(), old_scale_2, 16, 1,
                True).to(dtype=torch.bfloat16).reshape(dequant_shape)

            # Requantize using new_scale_2.
            weight_requant, weight_scale_requant = torch.ops.trtllm.fp4_quantize(
                weight_dequant.to(device),
                1.0 / new_scale_2.to(device),
                16,  # scaling_vector_size
                False)

            # Ensure the returned scale has the same dtype as the input scale.
            return weight_requant.cpu(), weight_scale_requant.reshape(
                original_scale_shape).view(original_scale_dtype).cpu()
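
        # Usage sketch (mirrors the call sites in the kv_a_proj_with_mqa branch
        # further down in this diff): when q_a_proj and kv_a_proj_with_mqa were
        # exported with different weight_scale_2 values, the tensor quantized
        # with the smaller global scale is requantized onto the shared one, e.g.
        #
        #   kv_a_proj_with_mqa, kv_a_proj_with_mqa_scale = requantize_weight_with_new_scale(
        #       kv_a_proj_with_mqa, kv_a_proj_with_mqa_scale,
        #       kv_a_proj_with_mqa_scale_2, shared_weight_scale_2,
        #       device=module.weight.device)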

        def rename_moe_weight(weights: Dict, rename_rules: Dict):
            result = {}
            for key, value in weights.items():
@@ -355,27 +394,128 @@ class DeepseekV3WeightLoader:
                        ).view(*attn_module.v_b_proj_dequant.shape).to(
                            attn_module.v_b_proj_dequant.dtype))
            elif names[-1] == "kv_a_proj_with_mqa":
                nvfp4_fused_a = self.model_config.get_quant_config(
                ).layer_quant_mode.has_nvfp4() and weights[
                    f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"].dtype == fp4_utils.float4_e2m1x2 and weights[
                        f"{'.'.join(names[:-1])}.q_a_proj.weight"].dtype == fp4_utils.float4_e2m1x2
                if nvfp4_fused_a:
                    ########### input_scale
                    kv_a_proj_with_mqa_input_scale = weights[
                        f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.input_scale"]
                    if not is_lite:
                        q_a_proj_input_scale = weights[
                            f"{'.'.join(names[:-1])}.q_a_proj.input_scale"]
                        assert kv_a_proj_with_mqa_input_scale == q_a_proj_input_scale, "kv_a_proj_with_mqa.input_scale and q_a_proj.input_scale should be the same"
                    # modelopt ckpt stores amax/(448*6), convert to (448*6)/amax
                    shared_input_scale = kv_a_proj_with_mqa_input_scale
                    module.input_scale.data.copy_(1.0 / shared_input_scale)
                    E2M1_MAX = 6.0
                    module.inv_input_scale.data.copy_(module.input_scale /
                                                      E2M1_MAX)
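                    # For illustration (hypothetical numbers): if the activation
                    # amax were 672, the checkpoint would store 672 / (448 * 6)
                    # = 0.25, so input_scale becomes 1 / 0.25 = 4.0 and
                    # inv_input_scale becomes 4.0 / 6.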
                    ########### weight_scale_2
                    need_requant_kv_a_proj_with_mqa = False
                    need_requant_q_a_proj = False
                    kv_a_proj_with_mqa_scale_2 = weights[
                        f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_2"]
                    shared_weight_scale_2 = kv_a_proj_with_mqa_scale_2
                    if not is_lite:
                        q_a_proj_scale_2 = weights[
                            f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_2"]
                        if kv_a_proj_with_mqa_scale_2 < q_a_proj_scale_2:
                            shared_weight_scale_2 = q_a_proj_scale_2
                            need_requant_kv_a_proj_with_mqa = True
                        elif q_a_proj_scale_2 < kv_a_proj_with_mqa_scale_2:
                            need_requant_q_a_proj = True
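                    # Why the larger scale wins (inferred from this block): the
                    # two projections are loaded into one fused tensor, so they
                    # must share a single global scale. weight_scale_2 is
                    # amax / (448 * 6), so the larger value corresponds to the
                    # larger amax; the tensor quantized with the smaller scale
                    # is requantized below so its values do not clip under the
                    # shared scale.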

                    ########### alpha
                    alpha = shared_input_scale.float(
                    ) * shared_weight_scale_2.float()
                    module.alpha.data.copy_(alpha)
                    module.scalar_alpha = alpha.item()
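                    # As used here, alpha = input_scale * weight_scale_2 acts as
                    # the combined dequantization factor applied to the NVFP4
                    # GEMM result (both factors are amax / (448 * 6) values).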

                    ########### weights
                    kv_a_proj_with_mqa = weights[
                        f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]

                    if not is_lite:
                        q_a_proj = weights[
                            f"{'.'.join(names[:-1])}.q_a_proj.weight"][:]

                    ########### weight_scale
                    kv_a_proj_with_mqa_scale = weights[
                        f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale"][:]
                    kv_a_proj_with_mqa_scale = torch.ops.trtllm.block_scale_interleave(
                        kv_a_proj_with_mqa_scale.view(
                            fp4_utils.float4_sf_dtype))
                    if not is_lite:
                        q_a_proj_scale = weights[
                            f"{'.'.join(names[:-1])}.q_a_proj.weight_scale"][:]
                        q_a_proj_scale = torch.ops.trtllm.block_scale_interleave(
                            q_a_proj_scale.view(fp4_utils.float4_sf_dtype))
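                    # block_scale_interleave presumably reorders the per-block
                    # FP8 scaling factors into the swizzled layout the FP4 GEMM
                    # kernels expect.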

                    ########### requantize
                    if need_requant_kv_a_proj_with_mqa:
                        # requant kv_a_proj_with_mqa
                        kv_a_proj_with_mqa, kv_a_proj_with_mqa_scale = requantize_weight_with_new_scale(
                            kv_a_proj_with_mqa,
                            kv_a_proj_with_mqa_scale,
                            kv_a_proj_with_mqa_scale_2,
                            shared_weight_scale_2,
                            device=module.weight.device,
                        )
                    if need_requant_q_a_proj:
                        # requant q_a_proj
                        q_a_proj, q_a_proj_scale = requantize_weight_with_new_scale(
                            q_a_proj,
                            q_a_proj_scale,
                            q_a_proj_scale_2,
                            shared_weight_scale_2,
                            device=module.weight.device)

                    ########### fuse and load weights
                    if not is_lite:
                        fused_a = torch.cat([q_a_proj, kv_a_proj_with_mqa],
                                            dim=0)
                    else:
                        fused_a = kv_a_proj_with_mqa

                    # For DeepseekV32: kv_a_proj_with_mqa is oversized to include
                    # indexer k weights, which are filled in post_load_weights.
                    module.weight.data[0:fused_a.shape[0]].copy_(fused_a)

                    ########### fuse weight_scale
                    if not is_lite:
                        fused_a_scale = torch.cat(
                            [q_a_proj_scale, kv_a_proj_with_mqa_scale],
                            dim=0)
                    else:
                        fused_a_scale = kv_a_proj_with_mqa_scale
                    # For DeepseekV32: kv_a_proj_with_mqa is oversized to include
                    # indexer k weights, which are filled in post_load_weights.
                    module.weight_scale.data[0:fused_a_scale.
                                             shape[0]].copy_(fused_a_scale)
                else:
                    fused_a = weights[
                        f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
                    if not is_lite:
                        q_a_proj = weights[
                            f"{'.'.join(names[:-1])}.q_a_proj.weight"][:]
                        fused_a = torch.cat([q_a_proj, fused_a], dim=0)

                    if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights:
                        fused_a_scale = weights[
                            f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"]
                        if not is_lite:
                            q_a_proj_scale = weights[
                                f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:]
                            fused_a_scale = torch.cat(
                                [q_a_proj_scale, fused_a_scale], dim=0)

                        module.weight_scale.data[
                            0:fused_a_scale.shape[0]].copy_(fused_a_scale)
                    # For DeepseekV32: kv_a_proj_with_mqa is oversized to include
                    # indexer k weights, which are filled in post_load_weights.
                    module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
            elif names[-1] in params_map:
                module_weights = []
                for new_name in params_map[names[-1]]: