import time
from ast import literal_eval
from os import path
from pathlib import Path
from typing import Optional, Union

import numpy as np

from tensorrt_llm import logger
from tensorrt_llm._utils import numpy_to_dtype, str_dtype_to_np
from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
                                     MLPType)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import (  # TODO: probably need to change model name to distinguish from other models
    DecoderModel, EncoderModel)

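# Lookup tables from enum member names to their values; parse_bart_config uses
# them to turn string fields from the checkpoint config into TensorRT-LLM enums.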
layernorm_type_map = {i.name: i.value for i in LayerNormType}
layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
mlp_type_map = {i.name: i.value for i in MLPType}


def parse_bart_config(config, component, args):
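    """Populate `args` with BART hyperparameters read from a checkpoint config.

    `config` is expected to behave like a configparser.ConfigParser with
    '[encoder]', '[decoder]' and '[structure]' sections; `component` selects
    which of the two transformer stacks is being described.
    """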
    assert component in ('encoder', 'decoder'), 'Unsupported component!'

    args.n_layer = config.getint(component, f'{component}_layers')
    args.n_head = config.getint(component, f'{component}_attention_heads')
    args.hidden_size = config.getint(component, 'd_model')
    args.head_size = config.getint(component,
                                   'd_kv',
                                   fallback=args.hidden_size // args.n_head)
    args.ffn_hidden_size = config.getint(component, f'{component}_ffn_dim')
    args.vocab_size = config.getint(component, 'vocab_size')
    args.n_positions = config.getint(component, 'max_position_embeddings')
    args.has_position_embedding = config.getboolean(
        component, 'has_position_embedding',
        fallback=True)  # TODO: hardcoded here
    args.has_token_type_embedding = config.getboolean(
        component, 'has_token_type_embedding', fallback=False)
    args.has_embedding_layernorm = config.getboolean(component,
                                                     'has_embedding_layernorm',
                                                     fallback=True)
    args.has_embedding_scale = config.getboolean(component, 'scale_embedding')
    args.q_scaling = config.getfloat(component, 'q_scaling', fallback=1.0)
    args.has_attention_qkvo_bias = config.getboolean('structure',
                                                     't5_with_bias',
                                                     fallback=True)
    args.has_mlp_bias = config.getboolean('structure',
                                          't5_with_bias',
                                          fallback=True)
    args.has_model_final_layernorm = config.getboolean(
        component, 'has_model_final_layernorm')
    args.layernorm_eps = config.getfloat(component,
                                         'layer_norm_epsilon',
                                         fallback=False)

    normalize_before = config.getboolean(component, 'normalize_before')
    args.layernorm_position = layernorm_position_map[
        'pre_layernorm' if normalize_before else 'post_layernorm']

    args.layernorm_type = layernorm_type_map[config.get(component,
                                                        'layernorm_type',
                                                        fallback='LayerNorm')]
    args.hidden_act = config.get(component, 'activation_function')
    args.gated_act = config.getboolean(component,
                                       'is_gated_act',
                                       fallback=False)
    args.mlp_type = mlp_type_map['GatedMLP' if args.gated_act else 'MLP']
    args.relative_attention = config.get(
        'structure', 'position_embedding_type') == 'relative'

    args.num_buckets = config.getint(component,
                                     'relative_attention_num_buckets',
                                     fallback=0)
    args.max_distance = config.getint(component,
                                      'relative_attention_max_distance',
                                      fallback=0)
    args.ckpt_weight_dtype = config.get(component, 'weight_data_type')
    args.max_lora_rank = config.getint(component, 'max_lora_rank', fallback=0)
    args.lora_target_modules = literal_eval(
        config.get(component, 'lora_target_modules'))
    args.hf_modules_to_trtllm_modules = literal_eval(
        config.get(component, 'hf_modules_to_trtllm_modules'))
    args.trtllm_modules_to_hf_modules = literal_eval(
        config.get(component, 'trtllm_modules_to_hf_modules'))

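    # The decoder additionally needs the encoder's width and head layout,
    # since its cross-attention projects the encoder hidden states.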
    if component == 'decoder':
        args.rescale_before_lm_head = config.getboolean(
            component, 'rescale_before_lm_head')
        args.logits_dtype = config.get(component,
                                       'logits_dtype',
                                       fallback='float32')
        args.encoder_hidden_size = config.getint('encoder', 'd_model')
        args.encoder_num_heads = config.getint('encoder',
                                               'encoder_attention_heads')
        args.encoder_head_size = config.getint(
            'encoder',
            'd_kv',
            fallback=args.encoder_hidden_size // args.encoder_num_heads)

    return args


def load_from_binary_bart(tllm_model: Union[DecoderModel, EncoderModel],
                          dir_path,
                          args,
                          mapping: Optional[Mapping] = None,
                          dtype='float32',
                          use_parallel_embedding=False,
                          sharding_dim=0,
                          share_embedding_table=False,
                          scaling_factors=None):
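    """Load BART weights from per-tensor .bin files under `dir_path` into `tllm_model`.

    Tensor-parallel shards are read per TP rank, and each pipeline-parallel
    rank only loads its own slice of layers. `use_parallel_embedding`,
    `sharding_dim`, `share_embedding_table` and `scaling_factors` are accepted
    for interface compatibility but are not used by this loader.
    """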
    logger.info('Loading weights from binary...')
    tik = time.time()

    if mapping is None:
        mapping = Mapping()

    ckpt_np_dtype = str_dtype_to_np(args.ckpt_weight_dtype)

    def fromfile(name, split=True, shape=None) -> Optional[np.ndarray]:
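        # Read one flat binary tensor and cast it to the target dtype.
        # split=True selects this TP rank's shard ('<name>.<tp_rank>.bin');
        # split=False reads the unsharded file ('<name>.bin').
        # Returns None when the file does not exist.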
        p = path.join(
            dir_path,
            f'{name}.{mapping.tp_rank}.bin' if split else f'{name}.bin')
        if Path(p).exists():
            t = np.fromfile(p, dtype=ckpt_np_dtype)
            t = numpy_to_dtype(t, dtype)
            if shape is not None:
                t = t.reshape(shape)
            t = np.ascontiguousarray(t)
            return t
        return None

    component = 'encoder' if isinstance(tllm_model, EncoderModel) else 'decoder'
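    # Weight file names follow the HF BART module layout, i.e. they are
    # prefixed with 'model.encoder.' or 'model.decoder.'.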

    # only load word / pos emb and emb layernorm to first PP rank
    if mapping.is_first_pp_rank():
        wte = fromfile(f'model.{component}.embed_tokens.weight',
                       shape=[args.vocab_size, -1],
                       split=False)

        # word embedding
        tllm_model.embedding.vocab_embedding.weight.value = wte

        # positional embedding
        wpe = fromfile(f'model.{component}.embed_positions.weight',
                       shape=[args.n_positions, args.hidden_size],
                       split=False)
        tllm_model.embedding.position_embedding.weight.value = wpe

        # Embedding layer norm
        tllm_model.embedding.embedding_layernorm.weight.value = fromfile(
            f'model.{component}.layernorm_embedding.weight', split=False)
        tllm_model.embedding.embedding_layernorm.bias.value = fromfile(
            f'model.{component}.layernorm_embedding.bias', split=False)

    local_num_layers = tllm_model.num_layers

    for local_idx, global_idx in enumerate(
            range(mapping.pp_rank * local_num_layers,
                  (mapping.pp_rank + 1) * local_num_layers)
    ):  # TODO: does this load the correct layers for PP?
        layer = getattr(tllm_model, f'{component}_layers')[local_idx]
        layer_prefix = f'model.{component}.layers.{global_idx}'

        self_attention_layer = getattr(
            layer, 'attention' if component == 'encoder' else 'self_attention')

        # self attention
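        # The fused QKV projection is split along its output dimension across
        # TP ranks (shape [3 * hidden_size // tp_size, hidden_size]); the output
        # projection is split along its input dimension, so its bias is read
        # unsharded.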
        self_attention_layer.qkv.weight.value = fromfile(
            f'{layer_prefix}.self_attn.qkv_proj.weight',
            shape=[args.hidden_size * 3 // mapping.tp_size, args.hidden_size])
        self_attention_layer.qkv.bias.value = fromfile(
            f'{layer_prefix}.self_attn.qkv_proj.bias',
            shape=[args.hidden_size * 3 // mapping.tp_size])

        self_attention_layer.dense.weight.value = fromfile(
            f'{layer_prefix}.self_attn.out_proj.weight',
            shape=[args.hidden_size, args.hidden_size // mapping.tp_size])
        self_attention_layer.dense.bias.value = fromfile(
            f'{layer_prefix}.self_attn.out_proj.bias',
            shape=[args.hidden_size],
            split=False)

        self_attention_layernorm = getattr(
            layer, 'self_attention_layernorm'
            if component == 'decoder' else 'attention_layernorm')
        self_attention_layernorm.weight.value = fromfile(
            f'{layer_prefix}.self_attn_layer_norm.weight', split=False)
        self_attention_layernorm.bias.value = fromfile(
            f'{layer_prefix}.self_attn_layer_norm.bias', split=False)

        # cross attention
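        # (decoder only) cross-attention weights come from the HF `encoder_attn`
        # modules and are sharded the same way as the self-attention ones.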
        if component == 'decoder':
            layer.cross_attention.qkv.weight.value = fromfile(
                f'{layer_prefix}.encoder_attn.qkv_proj.weight',
                shape=[
                    args.hidden_size * 3 // mapping.tp_size, args.hidden_size
                ])
            layer.cross_attention.qkv.bias.value = fromfile(
                f'{layer_prefix}.encoder_attn.qkv_proj.bias',
                shape=[args.hidden_size * 3 // mapping.tp_size])

            layer.cross_attention.dense.weight.value = fromfile(
                f'{layer_prefix}.encoder_attn.out_proj.weight',
                shape=[args.hidden_size, args.hidden_size // mapping.tp_size])
            layer.cross_attention.dense.bias.value = fromfile(
                f'{layer_prefix}.encoder_attn.out_proj.bias',
                shape=[args.hidden_size],
                split=False)

            layer.cross_attention_layernorm.weight.value = fromfile(
                f'{layer_prefix}.encoder_attn_layer_norm.weight', split=False)
            layer.cross_attention_layernorm.bias.value = fromfile(
                f'{layer_prefix}.encoder_attn_layer_norm.bias', split=False)

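        # MLP: fc1 is split along its output dimension, fc2 along its input
        # dimension, so only the fc1 bias is sharded per TP rank.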
        layer.mlp.fc.weight.value = fromfile(
            f'{layer_prefix}.fc1.weight',
            shape=[args.ffn_hidden_size // mapping.tp_size, args.hidden_size])
        layer.mlp.fc.bias.value = fromfile(
            f'{layer_prefix}.fc1.bias',
            shape=[args.ffn_hidden_size // mapping.tp_size])
        layer.mlp.proj.weight.value = fromfile(
            f'{layer_prefix}.fc2.weight',
            shape=[args.hidden_size, args.ffn_hidden_size // mapping.tp_size])
        layer.mlp.proj.bias.value = fromfile(f'{layer_prefix}.fc2.bias',
                                             shape=[args.hidden_size],
                                             split=False)

        layer.mlp_layernorm.weight.value = fromfile(
            f'{layer_prefix}.final_layer_norm.weight', split=False)
        layer.mlp_layernorm.bias.value = fromfile(
            f'{layer_prefix}.final_layer_norm.bias', split=False)

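    # The final layernorm (when present) and the LM head live on the last
    # pipeline-parallel rank; the LM head is sharded over the vocabulary.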
    if mapping.is_last_pp_rank():
        if tllm_model.has_model_final_layernorm:  # True for mBART, False for BART
            tllm_model.final_layernorm.weight.value = fromfile(
                f'model.{component}.layer_norm.weight', split=False)
            tllm_model.final_layernorm.bias.value = fromfile(
                f'model.{component}.layer_norm.bias', split=False)
        if component == 'decoder':
            tllm_model.lm_head.weight.value = fromfile(
                'lm_head.weight',
                shape=[args.vocab_size // mapping.tp_size, args.hidden_size])

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')
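

# Illustrative usage (a sketch, not part of the original module): the checkpoint
# config, weight directory, TP/PP sizes and the `args` namespace are assumed to
# come from the surrounding convert/build scripts.
#
#   import configparser
#   config = configparser.ConfigParser()
#   config.read(path.join(weight_dir, 'config.ini'))
#   args = parse_bart_config(config, 'encoder', args)
#   mapping = Mapping(world_size=tp_size * pp_size, rank=rank,
#                     tp_size=tp_size, pp_size=pp_size)
#   load_from_binary_bart(tllm_encoder, weight_dir, args,
#                         mapping=mapping, dtype='float16')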