mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Update TensorRT-LLM --------- Co-authored-by: Starrick Liu <73152103+StarrickLiu@users.noreply.github.com>
241 lines
11 KiB
Python
241 lines
11 KiB
Python
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from tensorrt_llm.logger import logger
|
|
|
|
from ..._utils import pad_vocab_size
|
|
from ..llama.convert import (get_tllm_linear_weight, get_weight, split,
|
|
split_matrix_tp, split_qkv_tp)
|
|
|
|
|
|
# Per-layer linear weights to convert, in the order they are emitted.
# Each entry is (HF sub-module name, TRT-LLM key prefix, split dim).
# A split dim of None marks a fused QKV matrix that goes through
# split_qkv_tp(); otherwise the matrix is sliced with split_matrix_tp() along
# the given dim. Note the naming swap: HF `up_proj` is stored under the
# TRT-LLM `gate` key and HF `gate_proj` under the TRT-LLM `fc` key.
_COGVLM_LINEAR_LAYERS = (
    ('self_attn.language_expert_query_key_value', 'attention.qkv.', None),
    ('self_attn.vision_expert_query_key_value', 'attention.vis_qkv.', None),
    ('self_attn.language_expert_dense', 'attention.dense.', 1),
    ('self_attn.vision_expert_dense', 'attention.vis_dense.', 1),
    ('mlp.language_mlp.up_proj', 'mlp.gate.', 0),
    ('mlp.vision_mlp.up_proj', 'vis_mlp.gate.', 0),
    ('mlp.language_mlp.gate_proj', 'mlp.fc.', 0),
    ('mlp.vision_mlp.gate_proj', 'vis_mlp.fc.', 0),
    ('mlp.language_mlp.down_proj', 'mlp.proj.', 1),
    ('mlp.vision_mlp.down_proj', 'vis_mlp.proj.', 1),
)


def _pad_vocab_rows(weight, vocab_size, tp_size):
    """Zero-pad the rows of a 2-D weight so the vocab splits evenly.

    Returns `weight` unchanged when `vocab_size` is already a multiple of
    `tp_size`. Otherwise appends `pad_vocab_size(vocab_size, tp_size) -
    vocab_size` rows of zeros and returns the result as a CPU tensor.
    """
    if vocab_size % tp_size == 0:
        return weight
    pad_width = pad_vocab_size(vocab_size, tp_size) - vocab_size
    # torch's pad tuple starts at the last dimension: (0, 0) leaves the
    # columns alone and (0, pad_width) appends rows at the end. Padding in
    # torch rather than via tensor.numpy() also works for dtypes NumPy
    # cannot represent (e.g. bfloat16).
    return torch.nn.functional.pad(weight.detach().cpu(),
                                   (0, 0, 0, pad_width),
                                   mode='constant',
                                   value=0)


def convert_hf_cogvlm(hf_model,
                      mapping,
                      vocab_size=32000,
                      dtype='float32',
                      use_parallel_embedding=False,
                      sharding_dim=0,
                      use_weight_only=False,
                      use_gemm_woq_plugin=False,
                      plugin_weight_only_quant_type=torch.int8,
                      use_smooth_quant=False,
                      per_channel=False,
                      per_token=False,
                      int8_kv_cache=False,
                      act_range=None,
                      qkv_para=None,
                      smoother=None):
    """Convert the weights of a HF CogVLM model for one TP/PP rank.

    Args:
        hf_model: Model exposing `named_parameters()` and a `config` with
            `num_attention_heads`, `hidden_size`, `num_hidden_layers`,
            `tie_word_embeddings` and, optionally, `num_key_value_heads`.
        mapping: Parallelism description; this function reads `tp_size`,
            `tp_rank`, `pp_layers()`, `is_first_pp_rank()` and
            `is_last_pp_rank()`.
        vocab_size: Vocabulary size used to decide whether the LM head
            needs row padding before it is split across TP ranks.
        dtype: Name of a `torch` dtype attribute (e.g. 'float16').
        use_parallel_embedding: Split the embedding table across TP ranks.
        sharding_dim: Dimension along which the embedding is split.
        use_weight_only: Forwarded to `get_tllm_linear_weight`.
        use_gemm_woq_plugin: Forwarded to `get_tllm_linear_weight`.
        plugin_weight_only_quant_type: Forwarded to
            `get_tllm_linear_weight`.
        use_smooth_quant: Must be falsy; smooth quant is rejected.
        int8_kv_cache: Must be falsy; int8 KV cache is rejected.
        per_channel, per_token, act_range, qkv_para, smoother: Accepted for
            call compatibility; not read by this function.

    Returns:
        dict keyed by TensorRT-LLM parameter names holding the converted
        weights for this rank.

    Raises:
        AssertionError: If the model is not MHA (KV heads != attention
            heads), or smooth quant / int8 KV cache is requested.
    """
    weights = {}
    tik = time.time()
    tensor_parallel = mapping.tp_size
    model_params = dict(hf_model.named_parameters())
    torch_dtype = getattr(torch, dtype)
    num_attention_heads = hf_model.config.num_attention_heads
    hidden_size = hf_model.config.hidden_size
    num_key_value_heads = getattr(hf_model.config, "num_key_value_heads",
                                  num_attention_heads)
    mha_mode = (num_key_value_heads == num_attention_heads)
    layers_range = mapping.pp_layers(hf_model.config.num_hidden_layers)

    # Explicit raises (rather than `assert`) so that these checks are not
    # stripped under `python -O`. AssertionError is kept so callers observe
    # the same exception type as before.
    if not mha_mode:
        raise AssertionError("CogVLM only supports mha mode")
    if use_smooth_quant:
        raise AssertionError("CogVLM currently doesn't support smooth quant")
    if int8_kv_cache:
        raise AssertionError("CogVLM currently doesn't support int8 kv cache")

    for layer_idx in layers_range:
        hf_prefix = f'model.layers.{layer_idx}.'
        # Layers are renumbered from zero within this pipeline stage.
        tllm_prefix = f'transformer.layers.{layer_idx - layers_range[0]}.'

        for hf_name, tllm_name, split_dim in _COGVLM_LINEAR_LAYERS:
            full_weight = get_weight(model_params, hf_prefix + hf_name,
                                     torch_dtype)
            if split_dim is None:
                shard = split_qkv_tp(full_weight, num_attention_heads,
                                     hidden_size, tensor_parallel,
                                     mapping.tp_rank)
            else:
                shard = split_matrix_tp(full_weight,
                                        tensor_parallel,
                                        mapping.tp_rank,
                                        dim=split_dim)
            weights.update(
                get_tllm_linear_weight(shard, tllm_prefix + tllm_name, None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type,
                                       torch_dtype, use_gemm_woq_plugin))

        # Layer norms do not use tensor parallelism
        weights[tllm_prefix + 'input_layernorm.weight'] = get_weight(
            model_params, hf_prefix + 'input_layernorm', torch_dtype)
        weights[tllm_prefix + 'post_layernorm.weight'] = get_weight(
            model_params, hf_prefix + 'post_attention_layernorm', torch_dtype)

        # Drop this layer's entries from the local parameter dict once they
        # have been converted. The name list is materialized first so the
        # dict is not mutated while being iterated.
        converted_names = [
            weight_name for weight_name in model_params
            if hf_prefix in weight_name
        ]
        for weight_name in converted_names:
            model_params[weight_name] = None

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    v = get_weight(model_params, 'model.embed_tokens', torch_dtype)
    if hf_model.config.tie_word_embeddings:
        # lm_head.weight has the same weights as embedding
        if mapping.is_last_pp_rank():
            # `v` is rebound here, so the padded table is also what the
            # embedding code below sees (same as the original behavior).
            v = _pad_vocab_rows(v, vocab_size, mapping.tp_size)
            weights['lm_head.weight'] = split(v, mapping.tp_size,
                                              mapping.tp_rank)

    if use_parallel_embedding:
        v = split_matrix_tp(v,
                            mapping.tp_size,
                            mapping.tp_rank,
                            dim=sharding_dim)

    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = v

    # NOTE(review): this load is unconditional, so when tie_word_embeddings
    # is set the 'lm_head.weight' entry written above is overwritten on the
    # last PP rank — confirm this is intended.
    lm_head_weights = get_weight(model_params, 'lm_head', torch_dtype)

    if mapping.is_last_pp_rank():
        lm_head_weights = _pad_vocab_rows(lm_head_weights, vocab_size,
                                          mapping.tp_size)
        weights['lm_head.weight'] = split_matrix_tp(lm_head_weights,
                                                    tensor_parallel,
                                                    mapping.tp_rank,
                                                    dim=0)
        weights['transformer.ln_f.weight'] = get_weight(
            model_params, 'model.norm', torch_dtype)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')
    return weights
|