TensorRT-LLMs/tensorrt_llm/models/cogvlm/convert.py
Kaiyu Xie aaacc9bd68
Update TensorRT-LLM (#2562)
* Update TensorRT-LLM

---------

Co-authored-by: Starrick Liu <73152103+StarrickLiu@users.noreply.github.com>
2024-12-11 00:31:05 -08:00

241 lines
11 KiB
Python

import time
import numpy as np
import torch
from tensorrt_llm.logger import logger
from ..._utils import pad_vocab_size
from ..llama.convert import (get_tllm_linear_weight, get_weight, split,
split_matrix_tp, split_qkv_tp)
# Marker used in _COGVLM_LINEAR_LAYOUT for fused QKV weights, which are
# sharded with split_qkv_tp instead of split_matrix_tp.
_QKV_SPLIT = 'qkv'

# Per-layer linear weights, in the exact order they are read and inserted into
# the output dict. Each entry is:
#   (HF sub-module under ``model.layers.{i}.``,
#    TRT-LLM sub-prefix under ``transformer.layers.{j}.``,
#    ``_QKV_SPLIT`` or the dim passed to split_matrix_tp).
# NOTE: HF ``up_proj`` -> TRT-LLM ``gate`` and HF ``gate_proj`` -> ``fc`` is a
# deliberate swap carried over from the original conversion (presumably to
# match TRT-LLM's gated-MLP naming, where ``fc`` is the activated branch --
# confirm before changing).
_COGVLM_LINEAR_LAYOUT = (
    ('self_attn.language_expert_query_key_value', 'attention.qkv.',
     _QKV_SPLIT),
    ('self_attn.vision_expert_query_key_value', 'attention.vis_qkv.',
     _QKV_SPLIT),
    ('self_attn.language_expert_dense', 'attention.dense.', 1),
    ('self_attn.vision_expert_dense', 'attention.vis_dense.', 1),
    ('mlp.language_mlp.up_proj', 'mlp.gate.', 0),
    ('mlp.vision_mlp.up_proj', 'vis_mlp.gate.', 0),
    ('mlp.language_mlp.gate_proj', 'mlp.fc.', 0),
    ('mlp.vision_mlp.gate_proj', 'vis_mlp.fc.', 0),
    ('mlp.language_mlp.down_proj', 'mlp.proj.', 1),
    ('mlp.vision_mlp.down_proj', 'vis_mlp.proj.', 1),
)


def _pad_vocab_rows(weight, vocab_size, tp_size):
    """Zero-pad ``weight`` along dim 0 when ``vocab_size % tp_size != 0``.

    ``pad_vocab_size(vocab_size, tp_size) - vocab_size`` zero rows are
    appended. Padding uses ``torch.nn.functional.pad`` instead of a NumPy
    round-trip, so dtypes NumPy cannot represent (e.g. ``torch.bfloat16``)
    also work. Returns ``weight`` unchanged when no padding is needed.
    """
    if vocab_size % tp_size == 0:
        return weight
    pad_width = pad_vocab_size(vocab_size, tp_size) - vocab_size
    # Pad spec is last-dim-first: (0, 0) leaves the columns alone and
    # (0, pad_width) appends rows to dim 0 of the 2-D weight.
    return torch.nn.functional.pad(weight.detach().cpu(),
                                   (0, 0, 0, pad_width),
                                   mode='constant',
                                   value=0)


def convert_hf_cogvlm(hf_model,
                      mapping,
                      vocab_size=32000,
                      dtype='float32',
                      use_parallel_embedding=False,
                      sharding_dim=0,
                      use_weight_only=False,
                      use_gemm_woq_plugin=False,
                      plugin_weight_only_quant_type=torch.int8,
                      use_smooth_quant=False,
                      per_channel=False,
                      per_token=False,
                      int8_kv_cache=False,
                      act_range=None,
                      qkv_para=None,
                      smoother=None):
    """Convert a Hugging Face CogVLM model into a TensorRT-LLM weight dict.

    Only the shard owned by this rank is produced. Decoder layers are limited
    to ``mapping.pp_layers(num_hidden_layers)``, and linear weights are split
    across ``mapping.tp_size`` ranks, keeping ``mapping.tp_rank``. Each decoder
    layer contributes both its language-expert and vision-expert attention
    and MLP weights.

    Args:
        hf_model: HF CogVLM model; read via ``named_parameters()`` and
            ``config``.
        mapping: TRT-LLM parallel mapping (``tp_size``, ``tp_rank``,
            ``pp_layers``, ``is_first_pp_rank``, ``is_last_pp_rank``).
        vocab_size: Vocabulary size used to decide and size lm_head padding.
        dtype: Name of a ``torch`` dtype attribute, e.g. ``'float16'``.
        use_parallel_embedding: Shard the vocab embedding across TP ranks.
        sharding_dim: Dimension along which the vocab embedding is sharded.
        use_weight_only, use_gemm_woq_plugin, plugin_weight_only_quant_type:
            Forwarded to ``get_tllm_linear_weight`` for every linear layer.
        use_smooth_quant, int8_kv_cache: Must be False (not supported).
        per_channel, per_token, act_range, qkv_para, smoother: Accepted for
            signature compatibility with other converters; unused here.

    Returns:
        dict mapping TensorRT-LLM weight names to tensors.

    Raises:
        AssertionError: if the model is not plain multi-head attention, or if
            smooth quant / int8 KV cache is requested.
    """
    weights = {}
    tik = time.time()
    tensor_parallel = mapping.tp_size
    tp_rank = mapping.tp_rank
    # A fresh dict: clearing entries later only drops these references; the
    # parameters themselves remain owned by ``hf_model``.
    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, dtype)
    config = hf_model.config
    num_attention_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    num_key_value_heads = getattr(config, 'num_key_value_heads',
                                  num_attention_heads)
    layers_range = mapping.pp_layers(config.num_hidden_layers)

    assert num_key_value_heads == num_attention_heads, \
        "CogVLM only supports mha mode"
    assert not use_smooth_quant, "CogVLM currently doesn't support smooth quant"
    assert not int8_kv_cache, "CogVLM currently doesn't support int8 kv cache"

    for layer_idx in layers_range:
        hf_prefix = f'model.layers.{layer_idx}.'
        # TRT-LLM layer indices are local to this pipeline stage.
        tllm_prefix = f'transformer.layers.{layer_idx - layers_range[0]}.'

        for hf_name, tllm_name, split_how in _COGVLM_LINEAR_LAYOUT:
            weight = get_weight(model_params, hf_prefix + hf_name, dtype)
            if split_how == _QKV_SPLIT:
                shard = split_qkv_tp(weight, num_attention_heads, hidden_size,
                                     tensor_parallel, tp_rank)
            else:
                shard = split_matrix_tp(weight,
                                        tensor_parallel,
                                        tp_rank,
                                        dim=split_how)
            weights.update(
                get_tllm_linear_weight(shard, tllm_prefix + tllm_name, None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type, dtype,
                                       use_gemm_woq_plugin))

        # Layer norms are replicated on every TP rank, not split.
        weights[tllm_prefix + 'input_layernorm.weight'] = get_weight(
            model_params, hf_prefix + 'input_layernorm', dtype)
        weights[tllm_prefix + 'post_layernorm.weight'] = get_weight(
            model_params, hf_prefix + 'post_attention_layernorm', dtype)

        # Drop this layer's entries from the local lookup and release cached
        # CUDA memory before the next layer.
        for weight_name in [n for n in model_params if hf_prefix in n]:
            model_params[weight_name] = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    tie_word_embeddings = config.tie_word_embeddings
    embed = get_weight(model_params, 'model.embed_tokens', dtype)
    if tie_word_embeddings and mapping.is_last_pp_rank():
        # lm_head shares the embedding matrix.
        # NOTE(review): the padded tensor is also what is stored as the vocab
        # embedding below (same as the original flow) -- confirm intended.
        embed = _pad_vocab_rows(embed, vocab_size, mapping.tp_size)
        weights['lm_head.weight'] = split(embed, mapping.tp_size, tp_rank)
    if use_parallel_embedding:
        embed = split_matrix_tp(embed,
                                mapping.tp_size,
                                tp_rank,
                                dim=sharding_dim)
    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = embed

    # ``named_parameters()`` de-duplicates shared tensors, so a tied lm_head
    # usually has no 'lm_head.weight' entry. Reading it unconditionally would
    # fail on the missing key, so load a dedicated lm_head only when the
    # embeddings are untied or the entry actually exists.
    if not tie_word_embeddings or 'lm_head.weight' in model_params:
        lm_head_weights = get_weight(model_params, 'lm_head', dtype)
        if mapping.is_last_pp_rank():
            lm_head_weights = _pad_vocab_rows(lm_head_weights, vocab_size,
                                              mapping.tp_size)
            weights['lm_head.weight'] = split_matrix_tp(lm_head_weights,
                                                        tensor_parallel,
                                                        tp_rank,
                                                        dim=0)

    weights['transformer.ln_f.weight'] = get_weight(model_params, 'model.norm',
                                                    dtype)

    t = time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))
    logger.info(f'Weights loaded. Total time: {t}')
    return weights