import torch

from ..._utils import str_dtype_to_torch
from .split_weights import shuffle_qkv_weights, split_weights_tp


def load_weights_from_hf_model(hf_model, config):
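    """Convert a Hugging Face Phi-3-family checkpoint to TRT-LLM weights.

    Renames HF parameter keys to their TRT-LLM equivalents, fuses split
    projections (q/k/v, MoE expert w1/w3) where TRT-LLM expects a single
    tensor, and casts everything to the configured dtype on the CPU
    (MoE router weights are kept in float32).
    """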
    torch_dtype = str_dtype_to_torch(config.dtype)
    hf_state_dict = hf_model.state_dict()
    weights = {}

    # Weight-only quantization settings.
    use_weight_only = config.quant_mode.is_weight_only()
    plugin_weight_only_quant_type = None
    if config.quant_mode.is_int8_weight_only():
        plugin_weight_only_quant_type = torch.int8
    elif config.quant_mode.is_int4_weight_only():
        plugin_weight_only_quant_type = torch.quint4x2

    # Replace HF key names with their TRT-LLM equivalents.
    for key, value in hf_state_dict.items():
        # Decoder layers
        orig_key = key
if "model.layers." in key:
|
|
key = key.replace("model.layers.", "transformer.layers.")
|
|
#Attention
|
|
key = key.replace("self_attn.", "attention.")
|
|
key = key.replace("query_key_value.", "qkv.") # small
|
|
key = key.replace("Wqkv.weight", "qkv.weight")
|
|
key = key.replace("qkv_proj.", "qkv.") #128k
|
|
#MLP
|
|
key = key.replace("mlp.fc1.", "mlp.fc.")
|
|
key = key.replace("mlp.fc2.", "mlp.proj.")
|
|
key = key.replace("mlp.gate_up_proj.", "mlp.fc.")
|
|
key = key.replace("mlp.up_proj.", "mlp.fc." if config.architecture
|
|
== 'Phi3SmallForCausalLM' else "mlp.gate.") #128k
|
|
key = key.replace("mlp.down_proj.", "mlp.proj.") #128k
|
|
key = key.replace("mlp.gate_proj.", "mlp.fc.") #128k
|
|
key = key.replace("o_proj.", "dense.") #128k
|
|
|
|

            # MoE
            key = key.replace("block_sparse_moe.gate", "mlp.router")
            key = key.replace("block_sparse_moe.experts.0.w3", "mlp.fc")
            key = key.replace("block_sparse_moe.experts.0.w2", "mlp.proj")
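            # Only expert 0 is renamed here; the weights of all experts are
            # stacked into these fused tensors in the PhiMoE block below.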

            # Layer norm
            key = key.replace("post_attention_layernorm.",
                              "post_layernorm.")  # 128k

        # Embedding
        key = key.replace("model.embed_tokens.weight",
                          "transformer.vocab_embedding.weight")
        # Final layer norm
        key = key.replace("model.final_layernorm.", "transformer.ln_f.")
        key = key.replace("model.norm.", "transformer.ln_f.")  # 128k
if "mlp.gate_up_proj." in orig_key: #4k
|
|
original_weights = value.contiguous().clone()
|
|
half_split = original_weights.shape[0] // 2
|
|
first_half, second_half = original_weights[:
|
|
half_split, :], original_weights[
|
|
half_split:, :]
|
|
# Swap the halves
|
|
value = torch.cat((second_half, first_half), dim=0)
|
|
|
|
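
        # PhiMoE (Phi-3.5-MoE) stores each expert's projections as separate
        # tensors; gather them into fused per-layer MoE weights.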
        if config.architecture == "PhiMoEForCausalLM":
            num_experts = config.moe["num_experts"]
            mlp_hidden_size = config.intermediate_size
            num_hidden = config.hidden_size
            rank_experts = list(range(num_experts))
            if config.mapping.has_moe_ep():
                rank_experts = config.mapping.ep_experts(num_experts)

            def get_moe_weight(key, suffix):
                # Collect this projection from every expert owned by this
                # rank and stack them along a new leading dimension.
                param = []
                for expert in rank_experts:
                    name = key.replace(f"0.{suffix}", f"{expert}.{suffix}")
                    fc_value = hf_state_dict[name]
                    param.append(fc_value)
                w = torch.stack(param)
                return w.reshape(-1, mlp_hidden_size, num_hidden)

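            # Fuse each expert's w3 (up_proj) and w1 (gate_proj) into the
            # single gated-MLP fc tensor; w2 is the down projection.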
if ".0.w3" in orig_key:
|
|
w3 = get_moe_weight(orig_key, 'w3')
|
|
w1 = get_moe_weight(orig_key.replace("w3", "w1"), 'w1')
|
|
value = torch.concat([w3, w1], dim=-2)
|
|
elif ".0.w2" in orig_key:
|
|
w2 = get_moe_weight(orig_key, 'w2')
|
|
value = w2.reshape(-1, num_hidden, mlp_hidden_size)
|
|
elif any([k in orig_key for k in ["w1", "w2", "w3"]]):
|
|
continue
|
|
|
|
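        # The 128k variants keep separate q/k/v projections; fuse them into
        # one qkv tensor and skip the standalone k/v entries.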
if "q_proj" in key: #128k
|
|
q_param = value
|
|
k_param = hf_state_dict[orig_key.replace("q_proj", "k_proj")]
|
|
v_param = hf_state_dict[orig_key.replace("q_proj", "v_proj")]
|
|
value = torch.cat([q_param, k_param, v_param], dim=0)
|
|
key = key.replace("q_proj", "qkv")
|
|
elif "k_proj" in key or "v_proj" in key:
|
|
continue
|
|
|
|
        # Router weights stay in float32; everything else is cast to the
        # configured dtype.
        dtype = torch.float if "router" in key else torch_dtype
        weights[key] = value.to(dtype).cpu()

    # This is for InternVL-4B: keep the language-model weights and drop the
    # vision tower.
    if config.architecture == 'Phi3ForCausalLM':
        keys_to_rename = [
            key for key in weights.keys() if 'language_model.' in key
        ]
        keys_to_delete = [
            key for key in weights.keys() if 'vision_model.' in key
        ]
        for key in keys_to_rename:
            keys_rename = key.replace('language_model.', '')
            weights[keys_rename] = weights[key]
            del weights[key]
        for key in keys_to_delete:
            del weights[key]

    if config.architecture == 'Phi3SmallForCausalLM':
        # Phi-3-small ties the output head to the input embedding.
        weights['lm_head.weight'] = weights[
            'transformer.vocab_embedding.weight'].clone()

        # Transform QKV weights from the custom Phi3Small format to the
        # TRT-LLM format.
        for key, value in weights.items():
            if "qkv." in key:
                weights[key] = shuffle_qkv_weights(weights[key], config)

    if (config.architecture in ('Phi3SmallForCausalLM', 'PhiMoEForCausalLM')
            and config.mapping.has_tp()):
        weights = split_weights_tp(config, weights, torch_dtype)

    return weights
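

# Minimal usage sketch (the checkpoint name below and the availability of a
# suitable TRT-LLM `config` are assumptions, not part of this module):
#
#   from transformers import AutoModelForCausalLM
#
#   hf_model = AutoModelForCausalLM.from_pretrained(
#       "microsoft/Phi-3-mini-4k-instruct", torch_dtype="auto")
#   weights = load_weights_from_hf_model(hf_model, config)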