TensorRT-LLMs/tensorrt_llm/models/deepseek_v1/convert.py
mpikulski 93a4b7f1b6
[None][chore] update torch_dtype -> dtype in 'transformers' (#8263)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
2025-10-15 17:09:30 +09:00

291 lines
12 KiB
Python
Executable File

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import torch
from transformers import AutoModelForCausalLM
from ..._utils import pad_vocab_size, release_gc
## Get HF model
def load_hf_deepseek(model_dir):
    """Load a Hugging Face DeepSeek causal-LM checkpoint from ``model_dir``.

    The model is loaded with ``device_map='auto'`` (transformers decides the
    device placement) and ``dtype='auto'`` (the dtype recorded with the
    checkpoint is used). ``trust_remote_code=True`` is passed, presumably
    because the checkpoint ships its own modeling code — confirm per model.

    Returns:
        The loaded ``AutoModelForCausalLM`` instance.
    """
    load_kwargs = dict(device_map='auto', dtype='auto', trust_remote_code=True)
    return AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
## Prepare weights for TP
def split(v, tp_size, idx, dim=0):
    """Return shard ``idx`` of ``v`` after cutting it into ``tp_size`` chunks.

    Args:
        v: tensor to shard.
        tp_size: number of tensor-parallel shards.
        idx: which shard to return.
        dim: axis to cut along. Ignored for 1-D tensors, which are always
            cut along their only axis.

    Returns:
        ``v`` itself (no copy) when ``tp_size == 1``; otherwise a contiguous
        copy of the selected chunk (as produced by ``torch.chunk``).
    """
    if tp_size == 1:
        return v
    axis = 0 if v.ndim == 1 else dim
    shards = torch.chunk(v, tp_size, dim=axis)
    return shards[idx].contiguous()
def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """Shard a fused QKV weight matrix for tensor parallelism.

    ``v`` is reshaped to ``(3, n_hidden, n_hidden)`` — assumes it is the
    row-wise concatenation of three ``(n_hidden, n_hidden)`` projection
    matrices (q, k, v); TODO confirm this holds for every caller. Each
    projection keeps ``n_hidden // tensor_parallel`` rows for ``rank``, and
    the three pieces are re-fused into a 2-D matrix.

    Args:
        v: fused QKV weight.
        n_head: unused; accepted for interface compatibility.
        n_hidden: hidden size.
        tensor_parallel: number of TP ranks.
        rank: this rank's TP index.

    Returns:
        Contiguous tensor of shape
        ``(3 * (n_hidden // tensor_parallel), n_hidden)``.
    """
    per_proj = v.reshape(3, n_hidden, n_hidden)
    if tensor_parallel != 1:
        # Cut the output-row axis of each projection independently.
        per_proj = torch.chunk(per_proj, tensor_parallel,
                               dim=1)[rank].contiguous()
    rows_per_rank = n_hidden // tensor_parallel
    return per_proj.reshape(3 * rows_per_rank, n_hidden).contiguous()
def split_matrix_tp(v, tensor_parallel, rank, dim):
    """Return this rank's tensor-parallel shard of ``v`` along ``dim``.

    Thin wrapper around ``split``; see it for the ``tensor_parallel == 1``
    and 1-D tensor behavior.
    """
    return split(v, tp_size=tensor_parallel, idx=rank, dim=dim)
def get_weight(config, prefix, dtype, postfix='.weight'):
    """Fetch ``config[prefix + postfix]`` as a detached CPU tensor of ``dtype``.

    Side effect: if the stored tensor has a different dtype, it is cast
    **in place** (its ``.data`` is replaced), so the entry in ``config`` —
    and the model parameter it may alias — is converted too.

    Args:
        config: mapping from parameter names to tensors.
        prefix: parameter name without the suffix.
        dtype: target ``torch.dtype``.
        postfix: suffix appended to ``prefix`` to form the key.

    Raises:
        KeyError: if the key is absent from ``config``.
    """
    tensor = config[prefix + postfix]
    if tensor.dtype != dtype:
        tensor.data = tensor.to(dtype)
    return tensor.detach().cpu()
def get_trtllm_linear_weight(weight, prefix, postfix='weight'):
    """Wrap ``weight`` in a single-entry dict keyed by ``prefix + postfix``.

    The result is meant to be merged into the converted-weights dict with
    ``dict.update``.
    """
    return {prefix + postfix: weight}
def _convert_attention(model_params, prefix, trtllm_prex, config, mapping,
                       dtype):
    """Convert one layer's q/k/v/o projections into TRT-LLM attention weights."""
    weights = {}
    q_weight = get_weight(model_params, prefix + 'self_attn.q_proj', dtype)
    k_weight = get_weight(model_params, prefix + 'self_attn.k_proj', dtype)
    v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', dtype)
    # Fuse q, k, v row-wise, then shard the fused matrix per TP rank.
    qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
    split_v = split_qkv_tp(qkv_weight, config.num_attention_heads,
                           config.hidden_size, mapping.tp_size,
                           mapping.tp_rank)
    weights.update(
        get_trtllm_linear_weight(split_v, trtllm_prex + 'attention.qkv.'))

    # The output projection is sharded along its input dimension (dim=1).
    attn_dense_weight = get_weight(model_params, prefix + 'self_attn.o_proj',
                                   dtype)
    split_v = split_matrix_tp(attn_dense_weight,
                              mapping.tp_size,
                              mapping.tp_rank,
                              dim=1)
    weights.update(
        get_trtllm_linear_weight(split_v, trtllm_prex + 'attention.dense.'))
    return weights


def _convert_moe_mlp(model_params, prefix, trtllm_prex, moe_config, mapping,
                     dtype):
    """Convert one MoE MLP: routed experts, router gate and shared experts."""
    weights = {}

    # Experts held by this rank: all of them, or this rank's slice under EP.
    rank_experts = list(range(moe_config.num_experts))
    if mapping.has_moe_ep():
        rank_experts = mapping.ep_experts(moe_config.num_experts)

    def stack_experts(suffix):
        # Stack per-expert matrices along a new leading expert axis.
        return torch.stack([
            model_params[f'{prefix}mlp.experts.{expert}.{suffix}.weight'].
            detach().cpu() for expert in rank_experts
        ])

    # Stacked tensors stay local (they are not stored back into model_params)
    # so they can be freed once this layer has been converted.
    gate_proj = stack_experts('gate_proj')
    down_proj = stack_experts('down_proj')
    up_proj = stack_experts('up_proj')

    if mapping.has_moe_tp():
        # Expert weights are sharded across the MoE tensor-parallel group,
        # which is smaller than the full TP group when expert parallelism is
        # also enabled.
        gate_proj = split(gate_proj,
                          mapping.moe_tp_size,
                          mapping.moe_tp_rank,
                          dim=1)
        down_proj = split(down_proj,
                          mapping.moe_tp_size,
                          mapping.moe_tp_rank,
                          dim=2)
        up_proj = split(up_proj,
                        mapping.moe_tp_size,
                        mapping.moe_tp_rank,
                        dim=1)

    up_gate_proj = torch.concat([up_proj, gate_proj], dim=-2)
    weights.update(
        get_trtllm_linear_weight(down_proj.to(dtype),
                                 trtllm_prex + 'mlp.proj.'))
    weights.update(
        get_trtllm_linear_weight(up_gate_proj.to(dtype),
                                 trtllm_prex + 'mlp.fc.'))

    # The MoE layer hard-codes the routing input to float32 (see moe.py), so
    # the router weight is kept in float32 regardless of ``dtype``.
    moe_experts_gate_weights = get_weight(model_params, prefix + 'mlp.gate',
                                          torch.float32)
    weights.update(
        get_trtllm_linear_weight(moe_experts_gate_weights,
                                 trtllm_prex + 'mlp.router.'))

    if moe_config.shared_expert_intermediate_size > 0:
        # Shared experts are dense and sharded across the full TP group.
        shared_up = split_matrix_tp(
            get_weight(model_params, prefix + 'mlp.shared_experts.up_proj',
                       dtype),
            mapping.tp_size,
            mapping.tp_rank,
            dim=0)
        shared_down = split_matrix_tp(
            get_weight(model_params, prefix + 'mlp.shared_experts.down_proj',
                       dtype),
            mapping.tp_size,
            mapping.tp_rank,
            dim=1)
        shared_gate = split_matrix_tp(
            get_weight(model_params, prefix + 'mlp.shared_experts.gate_proj',
                       dtype),
            mapping.tp_size,
            mapping.tp_rank,
            dim=0)
        # Fuse this rank's up and gate shards as [up; gate].
        shared_up_gate = torch.concat([shared_up, shared_gate], dim=-2)
        weights.update(
            get_trtllm_linear_weight(shared_up_gate,
                                     trtllm_prex + 'mlp.shared_expert.fc.'))
        weights.update(
            get_trtllm_linear_weight(shared_down,
                                     trtllm_prex + 'mlp.shared_expert.proj.'))
    return weights


def _convert_dense_mlp(model_params, prefix, trtllm_prex, mapping, dtype):
    """Convert a dense gated MLP (up/gate/down projections), sharded for TP.

    Only the non-MoE layer(s) take this path, so up and gate are not fused.
    """
    weights = {}
    # up_proj is stored as ``mlp.gate`` and gate_proj as ``mlp.fc``.
    mlp_gate_weight = get_weight(model_params, prefix + 'mlp.up_proj', dtype)
    split_gate = split_matrix_tp(mlp_gate_weight,
                                 mapping.tp_size,
                                 mapping.tp_rank,
                                 dim=0)
    weights.update(
        get_trtllm_linear_weight(split_gate, trtllm_prex + 'mlp.gate.'))

    mlp_fc_weight = get_weight(model_params, prefix + 'mlp.gate_proj', dtype)
    split_fc = split_matrix_tp(mlp_fc_weight,
                               mapping.tp_size,
                               mapping.tp_rank,
                               dim=0)
    weights.update(get_trtllm_linear_weight(split_fc, trtllm_prex + 'mlp.fc.'))

    mlp_proj_weight = get_weight(model_params, prefix + 'mlp.down_proj', dtype)
    split_proj = split_matrix_tp(mlp_proj_weight,
                                 mapping.tp_size,
                                 mapping.tp_rank,
                                 dim=1)
    weights.update(
        get_trtllm_linear_weight(split_proj, trtllm_prex + 'mlp.proj.'))
    return weights


def _pad_vocab_rows(weight, vocab_size, tp_size):
    """Zero-pad dim 0 of ``weight`` so ``vocab_size`` divides ``tp_size``."""
    if vocab_size % tp_size == 0:
        return weight
    pad_width = pad_vocab_size(vocab_size, tp_size) - vocab_size
    return torch.nn.functional.pad(weight, (0, 0, 0, pad_width), 'constant',
                                   0)


def _convert_embedding_and_lm_head(hf_model, model_params, config, mapping,
                                   dtype, use_parallel_embedding):
    """Convert the token embedding and the LM head for this rank."""
    weights = {}
    tied = hf_model.config.tie_word_embeddings
    v = get_weight(model_params, 'model.embed_tokens', dtype)
    if tied and mapping.is_last_pp_rank():
        # lm_head.weight shares the embedding's weights.
        v = _pad_vocab_rows(v, config.vocab_size, mapping.tp_size)
        weights['lm_head.weight'] = split(v, mapping.tp_size, mapping.tp_rank)

    if use_parallel_embedding:
        v = split_matrix_tp(v,
                            mapping.tp_size,
                            mapping.tp_rank,
                            dim=config.embedding_sharding_dim)
    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = v

    if not tied:
        # Only untied models expose a separate 'lm_head.weight':
        # named_parameters() de-duplicates tied parameters.
        lm_head_weights = get_weight(model_params, 'lm_head', dtype)
        if mapping.is_last_pp_rank():
            lm_head_weights = _pad_vocab_rows(lm_head_weights,
                                              config.vocab_size,
                                              mapping.tp_size)
            weights['lm_head.weight'] = split_matrix_tp(lm_head_weights,
                                                        mapping.tp_size,
                                                        mapping.tp_rank,
                                                        dim=0)
    return weights


def convert_deepseek(hf_model,
                     config,
                     mapping,
                     dtype='float32',
                     use_parallel_embedding=False,
                     sharding_dim=0):
    """Convert a Hugging Face DeepSeek-V1 model into TRT-LLM weights for one rank.

    Args:
        hf_model: loaded HF model. Its parameters are read through
            ``get_weight``, which casts them to ``dtype`` in place.
        config: TRT-LLM model config. Reads ``num_hidden_layers``,
            ``num_attention_heads``, ``hidden_size``, ``vocab_size``,
            ``embedding_sharding_dim`` and ``moe``.
        mapping: parallelism mapping for this rank (TP / PP / MoE EP+TP).
        dtype: name of a ``torch`` dtype attribute, e.g. ``'float16'``.
        use_parallel_embedding: shard the vocab embedding across TP ranks.
        sharding_dim: unused; ``config.embedding_sharding_dim`` is used
            instead. Kept for interface compatibility.

    Returns:
        Dict mapping TRT-LLM weight names to CPU tensors for this rank's
        pipeline stage.
    """
    weights = {}
    tik = time.time()
    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, dtype)
    moe_config = config.moe
    layers_range = mapping.pp_layers(config.num_hidden_layers)

    for layer_idx in layers_range:
        prefix = f'model.layers.{layer_idx}.'
        print(prefix)
        # TRT-LLM layer indices are local to this pipeline stage.
        trtllm_prex = f'transformer.layers.{layer_idx - layers_range[0]}.'

        weights.update(
            _convert_attention(model_params, prefix, trtllm_prex, config,
                               mapping, dtype))
        # Layer 0 is dense; every later layer is MoE when MoE is enabled.
        if moe_config.has_moe() and layer_idx > 0:
            weights.update(
                _convert_moe_mlp(model_params, prefix, trtllm_prex,
                                 moe_config, mapping, dtype))
        else:
            weights.update(
                _convert_dense_mlp(model_params, prefix, trtllm_prex, mapping,
                                   dtype))

        # Layer norms are not sharded across TP ranks.
        weights[trtllm_prex + 'input_layernorm.weight'] = get_weight(
            model_params, prefix + 'input_layernorm', dtype)
        weights[trtllm_prex + 'post_layernorm.weight'] = get_weight(
            model_params, prefix + 'post_attention_layernorm', dtype)
        release_gc()

    weights.update(
        _convert_embedding_and_lm_head(hf_model, model_params, config,
                                       mapping, dtype,
                                       use_parallel_embedding))

    weights['transformer.ln_f.weight'] = get_weight(model_params, 'model.norm',
                                                    dtype)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights