import math
import typing
from typing import Union

import numpy as np
import torch  # pytype: disable=import-error

from tensorrt_llm._utils import str_dtype_to_torch


def split(v: Union[np.ndarray, torch.Tensor],
          tp_size: int,
          tp_rank: int,
          dim=0):
    '''Split a weight along `dim` for tensor parallelism and return the
    contiguous shard owned by `tp_rank`.'''
    if tp_size == 1:
        if isinstance(v, np.ndarray):
            return np.ascontiguousarray(v.copy())
        else:
            return v.clone().detach()
    assert len(v.shape) > 1 or dim == 0
    if isinstance(v, np.ndarray):
        return np.ascontiguousarray(
            np.split(v, tp_size, axis=dim)[tp_rank].copy())
    else:
        assert v.shape[dim] % tp_size == 0, \
            f'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'
        split_size = v.shape[dim] // tp_size
        return v.split(split_size, dim=dim)[tp_rank].clone().detach()
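
# Example (illustrative, not executed at import time): splitting a [4, 6]
# weight across tp_size=2 along dim=1 gives each rank a contiguous [4, 3] shard:
#   w = torch.arange(24, dtype=torch.float32).reshape(4, 6)
#   split(w, tp_size=2, tp_rank=0, dim=1)  # columns 0-2
#   split(w, tp_size=2, tp_rank=1, dim=1)  # columns 3-5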


def reshape(v: torch.Tensor, shape=None):
    if shape is None:
        return v.contiguous()
    else:
        return v.reshape(shape).contiguous()


def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size,
                       tp_rank, model_type, weight_shape, bias_shape):

    qkv_module_names = get_qkv_module_name(model_type)

    weight = {}

    # fuse weights of q, k, v
    q_w = params[f'{attn_module_name}.{qkv_module_names["q"]}.weight']
    k_w = params[f'{attn_module_name}.{qkv_module_names["k"]}.weight']
    v_w = params[f'{attn_module_name}.{qkv_module_names["v"]}.weight']

    # fuse qkv weight
    shape = q_w.shape  # (do, din)
    qkv_w = torch.cat([q_w, k_w, v_w],
                      dim=0).reshape([3, shape[0], shape[1]])  # (3, do, din)
    qkv_w = split(qkv_w, tp_size, tp_rank, dim=1)
    weight[f'{trtllm_layer_name}.qkv.weight'] = reshape(qkv_w,
                                                        shape=weight_shape)

    # fuse qkv biases if present
    q_bias_name = f'{attn_module_name}.{qkv_module_names["q"]}.bias'
    if q_bias_name in params and params[q_bias_name] is not None:
        q_b = params[q_bias_name]
        k_b = params[f'{attn_module_name}.{qkv_module_names["k"]}.bias']
        v_b = params[f'{attn_module_name}.{qkv_module_names["v"]}.bias']
        shape = q_b.shape[0]  # (do,)
        qkv_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, shape])  # (3, do)
        qkv_b = split(qkv_b, tp_size, tp_rank, dim=1)
        weight[f'{trtllm_layer_name}.qkv.bias'] = reshape(qkv_b,
                                                          shape=bias_shape)
    return weight
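
# Illustrative call (hypothetical module and layer names; shapes assume hidden
# size d with do == din == d and tp_size-way tensor parallelism):
#   weights = fuse_qkv_one_layer(
#       params, 'encoder.block.0.layer.0.SelfAttention',
#       'encoder_layers.0.attention', tp_size, tp_rank, 't5',
#       weight_shape=(d * 3 // tp_size, d), bias_shape=(d * 3 // tp_size, ))
# The returned dict maps '<trtllm_layer_name>.qkv.weight' (and '.qkv.bias' when
# biases exist) to the fused, TP-sharded tensors.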


def get_qkv_module_name(model_type):
    if model_type in ["t5", "blip2"]:
        q = "q"
        k = "k"
        v = "v"
    elif model_type in ["bart", "nmt", "language_adapter"]:
        q = "q_proj"
        k = "k_proj"
        v = "v_proj"
    elif model_type == "pix2struct":
        q = "query"
        k = "key"
        v = "value"
    else:
        # Fail loudly instead of hitting a NameError below.
        raise ValueError(f"Unsupported model_type for QKV fusion: {model_type}")
    return {"q": q, "k": k, "v": v}
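
# For example, get_qkv_module_name('t5') returns {'q': 'q', 'k': 'k', 'v': 'v'},
# while 'bart'-style models map to {'q': 'q_proj', 'k': 'k_proj', 'v': 'v_proj'}.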


def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor],
                            dtype: typing.Optional[str] = None):
    if dtype is not None:
        assert isinstance(
            dtype, str), f"dtype must be str, but got type {type(dtype)}"
        for name in params.keys():
            params[name] = params[name].to(str_dtype_to_torch(dtype))
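
# Example (illustrative): convert_weight_to_dtype(params, 'float16') casts every
# tensor in `params` in place, mapping the dtype string via str_dtype_to_torch.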


def fairseq_sin_pos_embedding(num_embeddings: int, embedding_dim: int):
    '''
    generate fairseq specific sinusoidal position embedding [sin, sin, ... cos, cos...]
    https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/sinusoidal_positional_embedding.py
    '''
    padding_offset = 2
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float16) * -emb)
    emb = torch.arange(num_embeddings + padding_offset,
                       dtype=torch.float16).unsqueeze(1) * emb.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)],
                    dim=1).view(num_embeddings + padding_offset, -1)
    if embedding_dim % 2 == 1:
        # zero pad the odd embedding dimension (pad created in float16 to match emb)
        emb = torch.cat([
            emb,
            torch.zeros(num_embeddings + padding_offset, 1,
                        dtype=torch.float16)
        ],
                        dim=1)
    '''
    remove the first 2 rows to match the position_id setup difference between fairseq & trt
    fairseq position_id starts with 2, e.g. [2, 3, 4, ...]
    trt position_id starts with 0, e.g. [0, 1, 2, ...]
    '''
    return emb[padding_offset:, :]
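
# Shape check (illustrative): fairseq_sin_pos_embedding(8, 512) returns an
# (8, 512) float16 tensor; row i corresponds to fairseq position i + 2, with
# the first half_dim columns holding sin terms and the remaining cos terms.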