# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import torch
from transformers import AutoModelForCausalLM

from ..._utils import pad_vocab_size, release_gc


## Get HF model
def load_hf_deepseek(model_dir):
    model = AutoModelForCausalLM.from_pretrained(model_dir,
                                                 device_map='auto',
                                                 dtype='auto',
                                                 trust_remote_code=True)
    return model


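# A minimal usage sketch (the checkpoint path is hypothetical): `device_map='auto'`
# lets the HF loader place the model across available devices, and `dtype='auto'`
# keeps the checkpoint's native precision.
#
#   hf_model = load_hf_deepseek('/path/to/deepseek-moe-16b-base')

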
## Prepare weights for TP
def split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    if len(v.shape) == 1:
        return torch.chunk(v, tp_size)[idx].contiguous()
    else:
        return torch.chunk(v, tp_size, dim=dim)[idx].contiguous()


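# For example, with tp_size=2 a [4096, 11008] weight split along dim=1 yields two
# [4096, 5504] shards, and rank idx keeps the idx-th chunk; 1-D tensors (biases,
# norm weights) are always chunked along their only dimension.

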
def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """
    Splits the QKV matrix according to tensor parallelism
    """
    v = v.reshape(3, n_hidden, n_hidden)
    split_v = split(v, tensor_parallel, rank, dim=1)
    split_v = split_v.reshape(3 * (n_hidden // tensor_parallel), n_hidden)
    return split_v.contiguous()


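# The fused QKV weight arrives as [3 * n_hidden, n_hidden]; reshaping it to
# [3, n_hidden, n_hidden] and splitting on dim=1 gives each rank its own slice of
# the Q, K and V heads, e.g. n_hidden=4096 with tp=2 leaves a [6144, 4096] shard
# per rank after the final reshape.

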
def split_matrix_tp(v, tensor_parallel, rank, dim):
    return split(v, tensor_parallel, rank, dim=dim)


def get_weight(config, prefix, dtype, postfix='.weight'):
    if config[prefix + postfix].dtype != dtype:
        config[prefix + postfix].data = config[prefix + postfix].to(dtype)
    return config[prefix + postfix].detach().cpu()


def get_trtllm_linear_weight(weight, prefix, postfix='weight'):
    results = {}
    results[prefix + postfix] = weight

    return results


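# These two helpers bridge the naming schemes: get_weight pulls an HF parameter
# (e.g. 'model.layers.0.self_attn.q_proj.weight') to CPU in the target dtype, and
# get_trtllm_linear_weight re-keys the resulting tensor under the TensorRT-LLM
# prefix (e.g. 'transformer.layers.0.attention.qkv.weight').

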
def convert_deepseek(hf_model,
                     config,
                     mapping,
                     dtype='float32',
                     use_parallel_embedding=False,
                     sharding_dim=0):

    weights = {}
    tik = time.time()
    mapping.tp_size
    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, dtype)
    moe_config = config.moe
    layers_range = mapping.pp_layers(config.num_hidden_layers)

    def convert_layer(l):
        prefix = f'model.layers.{l}.'
        print(prefix)
        trtllm_prex = f'transformer.layers.{l - layers_range[0]}.'
        q_weight = get_weight(model_params, prefix + 'self_attn.q_proj', dtype)
        k_weight = get_weight(model_params, prefix + 'self_attn.k_proj', dtype)
        v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', dtype)

        qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)

        split_v = split_qkv_tp(qkv_weight, config.num_attention_heads,
                               config.hidden_size, mapping.tp_size,
                               mapping.tp_rank)

        weights.update(
            get_trtllm_linear_weight(split_v, trtllm_prex + 'attention.qkv.'))

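        # Q, K and V are converted as one fused [3 * hidden, hidden] matrix so the
        # engine can run them as a single GEMM; split_qkv_tp above keeps only this
        # rank's share of the attention heads.
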
        attn_dense_weight = get_weight(model_params,
                                       prefix + 'self_attn.o_proj', dtype)
        split_v = split_matrix_tp(attn_dense_weight,
                                  mapping.tp_size,
                                  mapping.tp_rank,
                                  dim=1)

        weights.update(
            get_trtllm_linear_weight(split_v, trtllm_prex + 'attention.dense.'))

        if moe_config.has_moe() and l > 0:
            rank_experts = list(range(moe_config.num_experts))
            if mapping.has_moe_ep():
                rank_experts = mapping.ep_experts(moe_config.num_experts)
            for suffix in ["gate_proj", "down_proj", "up_proj"]:
                model_params[f'model.layers.{l}.mlp.experts.{suffix}.weight'] = \
                    torch.stack([model_params[f'model.layers.{l}.mlp.experts.{expert}.{suffix}.weight'].detach().cpu()
                                 for expert in rank_experts])

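            # Per-expert HF weights are stacked into one tensor per projection,
            # roughly [num_experts_on_rank, inter_size, hidden] for gate/up and
            # [num_experts_on_rank, hidden, inter_size] for down, so expert
            # parallelism simply selects this rank's subset of experts.
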
            gate_proj = model_params[
                f'model.layers.{l}.mlp.experts.gate_proj.weight']
            down_proj = model_params[
                f'model.layers.{l}.mlp.experts.down_proj.weight']
            up_proj = model_params[
                f'model.layers.{l}.mlp.experts.up_proj.weight']
            if mapping.has_moe_tp():
                gate_proj = split(gate_proj,
                                  mapping.tp_size,
                                  mapping.tp_rank,
                                  dim=1)
                down_proj = split(down_proj,
                                  mapping.tp_size,
                                  mapping.tp_rank,
                                  dim=2)
                up_proj = split(up_proj,
                                mapping.tp_size,
                                mapping.tp_rank,
                                dim=1)

            model_params[
                f'model.layers.{l}.mlp.experts.up_gate_proj.weight'] = torch.concat(
                    [up_proj, gate_proj], dim=-2)
            model_params[
                f'model.layers.{l}.mlp.experts.down_proj.weight'] = down_proj

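            # Up and gate are fused along the intermediate dimension (dim=-2) to
            # match TensorRT-LLM's gated-MLP "fc" layout; the TP split above runs
            # first so each rank fuses only its own slice.
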
            ## mlp.experts.down_proj.weight
            moe_experts_down_proj_weights = get_weight(
                model_params, prefix + 'mlp.experts.down_proj', dtype)
            weights.update(
                get_trtllm_linear_weight(moe_experts_down_proj_weights,
                                         trtllm_prex + 'mlp.proj.'))
            ## mlp.experts.up_gate_proj.weight
            moe_experts_up_gate_proj_weights = get_weight(
                model_params, prefix + 'mlp.experts.up_gate_proj', dtype)
            weights.update(
                get_trtllm_linear_weight(moe_experts_up_gate_proj_weights,
                                         trtllm_prex + 'mlp.fc.'))
            ## MoE hard-codes routing_input to trt.float32; see moe.py line 397
            moe_experts_gate_weights = get_weight(model_params,
                                                  prefix + 'mlp.gate',
                                                  torch.float32)
            weights.update(
                get_trtllm_linear_weight(moe_experts_gate_weights,
                                         trtllm_prex + 'mlp.router.'))

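            # DeepSeek-MoE also has always-active shared experts; below they are
            # converted like a plain gated MLP (gate/up column-split on dim=0,
            # down row-split on dim=1) and fused into a single shared_expert.fc.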
            if moe_config.shared_expert_intermediate_size > 0:
                shared_moe_up_proj_weights = get_weight(
                    model_params, prefix + 'mlp.shared_experts.up_proj', dtype)
                shared_moe_up_proj_weights = split_matrix_tp(
                    shared_moe_up_proj_weights,
                    mapping.tp_size,
                    mapping.tp_rank,
                    dim=0)
                shared_moe_down_proj_weights = get_weight(
                    model_params, prefix + 'mlp.shared_experts.down_proj',
                    dtype)
                shared_moe_down_proj_weights = split_matrix_tp(
                    shared_moe_down_proj_weights,
                    mapping.tp_size,
                    mapping.tp_rank,
                    dim=1)
                shared_moe_gate_proj_weights = get_weight(
                    model_params, prefix + 'mlp.shared_experts.gate_proj',
                    dtype)
                shared_moe_gate_proj_weights = split_matrix_tp(
                    shared_moe_gate_proj_weights,
                    mapping.tp_size,
                    mapping.tp_rank,
                    dim=0)
                shared_moe_gate_up_proj_weights = torch.concat(
                    [shared_moe_up_proj_weights, shared_moe_gate_proj_weights],
                    dim=-2)

                ## mlp.shared_experts.gate_up_proj.weight
                weights.update(
                    get_trtllm_linear_weight(
                        shared_moe_gate_up_proj_weights,
                        trtllm_prex + 'mlp.shared_expert.fc.'))

                ## mlp.shared_experts.down_proj.weight
                weights.update(
                    get_trtllm_linear_weight(
                        shared_moe_down_proj_weights,
                        trtllm_prex + 'mlp.shared_expert.proj.'))

        else:
            ## The current DeepSeek model has only one dense MLP layer; if that
            ## number grows, consider fusing gate and up projections as above.
            mlp_gate_weight = get_weight(model_params, prefix + 'mlp.up_proj',
                                         dtype)
            split_gate = split_matrix_tp(mlp_gate_weight,
                                         mapping.tp_size,
                                         mapping.tp_rank,
                                         dim=0)
            weights.update(
                get_trtllm_linear_weight(split_gate, trtllm_prex + 'mlp.gate.'))

            mlp_fc_weight = get_weight(model_params, prefix + 'mlp.gate_proj',
                                       dtype)
            split_fc = split_matrix_tp(mlp_fc_weight,
                                       mapping.tp_size,
                                       mapping.tp_rank,
                                       dim=0)
            weights.update(
                get_trtllm_linear_weight(split_fc, trtllm_prex + 'mlp.fc.'))

            mlp_proj_weight = get_weight(model_params, prefix + 'mlp.down_proj',
                                         dtype)
            split_proj = split_matrix_tp(mlp_proj_weight,
                                         mapping.tp_size,
                                         mapping.tp_rank,
                                         dim=1)
            weights.update(
                get_trtllm_linear_weight(split_proj, trtllm_prex + 'mlp.proj.'))

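            # Note the crossed naming: HF's up_proj lands in TensorRT-LLM's
            # mlp.gate and HF's gate_proj lands in mlp.fc (the engine applies the
            # activation to fc); gate and fc are column-split (dim=0), down_proj
            # is row-split (dim=1).
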
        # Layer norms do not use tensor parallelism
        input_ln_weight = get_weight(model_params, prefix + 'input_layernorm',
                                     dtype)
        weights[trtllm_prex + 'input_layernorm.weight'] = input_ln_weight

        post_ln_weight = get_weight(model_params,
                                    prefix + 'post_attention_layernorm', dtype)
        weights[trtllm_prex + 'post_layernorm.weight'] = post_ln_weight

    for l in layers_range:
        convert_layer(l)
        release_gc()

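    # release_gc() triggers garbage collection between layers so per-layer
    # temporaries (e.g. the stacked expert tensors) are freed instead of
    # accumulating while the whole model is converted.
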
    v = get_weight(model_params, 'model.embed_tokens', dtype)
    if hf_model.config.tie_word_embeddings:
        # lm_head.weight has the same weights as embedding
        if mapping.is_last_pp_rank():
            if config.vocab_size % mapping.tp_size != 0:
                # padding
                vocab_size_padded = pad_vocab_size(config.vocab_size,
                                                   mapping.tp_size)
                pad_width = vocab_size_padded - config.vocab_size
                v = torch.nn.functional.pad(v, (0, 0, 0, pad_width), 'constant',
                                            0)
            weights['lm_head.weight'] = split(v, mapping.tp_size,
                                              mapping.tp_rank)
    if use_parallel_embedding:
        v = split_matrix_tp(v,
                            mapping.tp_size,
                            mapping.tp_rank,
                            dim=config.embedding_sharding_dim)
    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = v

    lm_head_weights = get_weight(model_params, 'lm_head', dtype)

    if mapping.is_last_pp_rank():
        if config.vocab_size % mapping.tp_size != 0:
            # padding
            vocab_size_padded = pad_vocab_size(config.vocab_size,
                                               mapping.tp_size)
            pad_width = vocab_size_padded - config.vocab_size
            lm_head_weights = torch.nn.functional.pad(lm_head_weights,
                                                      (0, 0, 0, pad_width),
                                                      'constant',
                                                      value=0)
        weights['lm_head.weight'] = split_matrix_tp(lm_head_weights,
                                                    mapping.tp_size,
                                                    mapping.tp_rank,
                                                    dim=0)

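    # For example, a 102400-entry vocabulary with tp_size=3 is padded to 102402
    # rows so every rank receives an equal 34134-row shard of lm_head.
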
    ln_f_w = get_weight(model_params, 'model.norm', dtype)
    weights['transformer.ln_f.weight'] = ln_f_w

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    #print(set(weights.keys()))
    return weights
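

# A rough sketch of how this module is typically driven (the rank loop,
# `trtllm_config`, world_size/tp_size/pp_size and the output path are
# illustrative, not defined in this file):
#
#   from tensorrt_llm.mapping import Mapping
#   import safetensors.torch
#
#   hf_model = load_hf_deepseek('/path/to/deepseek-moe-16b-base')
#   for rank in range(world_size):
#       mapping = Mapping(world_size=world_size, rank=rank,
#                         tp_size=tp_size, pp_size=pp_size)
#       weights = convert_deepseek(hf_model, trtllm_config, mapping,
#                                  dtype='bfloat16')
#       safetensors.torch.save_file(weights,
#                                   f'/path/to/output/rank{rank}.safetensors')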