# TensorRT-LLMs/examples/gptneox/convert_checkpoint.py

import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, wait
from typing import List, Optional
import safetensors
import safetensors.torch
import torch
from safetensors import safe_open
from transformers import AutoConfig, AutoModelForCausalLM
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default=None)
parser.add_argument('--world_size',
type=int,
default=1,
help='world size, only support tensor parallelism now')
parser.add_argument('--dtype',
type=str,
default='float16',
choices=['float32', 'bfloat16', 'float16'])
parser.add_argument(
'--use_weight_only',
default=False,
action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
parser.add_argument(
'--weight_only_precision',
const='int8',
type=str,
nargs='?',
default='int8',
choices=['int8', 'int4', 'int4_gptq'],
help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for this argument to have an effect.'
)
parser.add_argument('--per_group',
action="store_true",
default=False,
help='Use per group quantization')
parser.add_argument('--ammo_quant_ckpt_path',
type=str,
default=None,
help='Path of a quantized model checkpoint')
parser.add_argument(
'--use_parallel_embedding',
action="store_true",
default=False,
help=
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
)
parser.add_argument(
'--embedding_sharding_dim',
type=int,
default=0,
choices=[0, 1],
help=
        'By default the embedding lookup table is sharded along the vocab dimension (embedding_sharding_dim=0). '
        'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0.'
)
parser.add_argument(
'--use_embedding_sharing',
action="store_true",
default=False,
help=
        'Try to reduce the engine size by sharing the embedding lookup table between two layers. '
        'Note: the flag might not take effect when the criteria are not met.')
parser.add_argument('--output_dir',
type=str,
default='tllm_checkpoint',
help='The path to save the TensorRT-LLM checkpoint')
parser.add_argument(
'--workers',
type=int,
default=1,
help='The number of workers for converting checkpoint in parallel')
args = parser.parse_args()
if args.use_weight_only:
if args.per_group and args.weight_only_precision == 'int4':
args.weight_only_precision = 'int4_gptq'
elif args.weight_only_precision == 'int4_gptq':
args.per_group = True
return args
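# Illustrative invocations (a sketch only; the model and output paths below are
# hypothetical placeholders, not files shipped with this example):
#
#   # FP16 checkpoint with 2-way tensor parallelism:
#   python convert_checkpoint.py --model_dir ./gptneox_model --dtype float16 \
#       --world_size 2 --output_dir ./gptneox_tllm_checkpoint
#
#   # INT4-GPTQ weight-only checkpoint (requires a pre-quantized safetensors file):
#   python convert_checkpoint.py --model_dir ./gptneox_model --dtype float16 \
#       --use_weight_only --weight_only_precision int4_gptq \
#       --ammo_quant_ckpt_path ./gptneox_gptq.safetensors \
#       --output_dir ./gptneox_tllm_checkpoint_gptq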
# TODO: It seems all convert_checkpoint scripts could use the following utility
# functions; consider moving them into one common module.
def reorder_qkv_weight_or_bias(weight: torch.Tensor,
head_dim: int,
num_heads: int,
num_kv_heads: Optional[int] = None,
tp_size: int = 1,
is_bias: bool = False) -> torch.Tensor:
""" Reorder the qkv weight for TRT-LLM use.
The shape of the fused QKV weights in HF is different from the shape that
TRT-LLM requires. In particular, the weight of HF consists of interleaved
q, k, v head weights, while that of TRT-LLM is contiguous.
HF : [q1, k1, v1, ..., qh, kh, vh]
TRT-LLM: [q1, ..., qh, k1, ..., kh, v1, vh]
where qi, vi, ki are weight vectors corresponding to attention head i.
It's similar to multi/grouped query attention cases.
We reorder and split the weight of an attention layer to fit into TRT-LLM.
The reordered weight and bias will be
weight: (T, Qh * D + 2 * KVh * D, H)
bias : (T, Qh * D + 2 * KVh * D)
where T=tp_size, Qh=local_num_q_heads, KVh=local_num_kv_heads, D=head_dim,
H=hidden_dim. In the multi/grouped query attention, the number of K/V
attention heads are less than that of Q attention, so that K/V attention
heads may be shared across different ranks if necessary.
For tensor parallelism, we use the first dimension to select the
corresponding weights.
"""
# Query types and expected kv heads.
# - Conventional MHA: num_heads = num_kv_heads
# - Multi-Query Attention: num_kv_heads = 1
# - Grouped-Query Attention: num_heads % num_kv_heads = 0
num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
assert num_heads % num_kv_heads == 0, \
f'num_heads({num_heads}) must be divisible by ' \
        f'num_kv_heads({num_kv_heads}).'
# The number of attention heads per group: N q head + 1 k head + 1 v head.
num_group_heads = num_heads // num_kv_heads + 2
assert weight.shape[0] == num_kv_heads * num_group_heads * head_dim, \
f'{weight.shape[0]} != {num_kv_heads} * {num_group_heads} * {head_dim}'
qkv_in = num_heads * head_dim if not is_bias else 1
# Split Q/K/V weights
weight = weight.reshape(num_kv_heads, num_heads // num_kv_heads + 2,
head_dim, qkv_in)
q_w = weight[:, :-2, ...] # (nKV, num_heads // nKV, head_dim, qkv_in)
k_w = weight[:, -2:-1, ...] # (nKV, 1, head_dim, qkv_in)
v_w = weight[:, -1:, ...] # (nKV, 1, head_dim, qkv_in)
if num_kv_heads < num_heads and num_kv_heads < tp_size:
        # Duplicate K/V heads to make sure that each rank has at least one
        # K/V head. For instance, with num_heads=8, num_kv_heads=2, tp_size=4,
        # we will make the qkv weight as below.
        # Orig: [q0 q1 q2 q3 k0 v0 q4 q5 q6 q7 k1 v1]
# >>>> [[q0 q1 k0 v0], [q2 q3 k0 v0], [q4 q5 k1 v1], [q6 q7 k1 v1]]
assert tp_size % num_kv_heads == 0
num_dups = tp_size // num_kv_heads
# k_w and v_w have the same shape.
new_shape = (num_kv_heads, num_dups) + k_w.shape[2:]
k_w = torch.broadcast_to(k_w, size=new_shape)
v_w = torch.broadcast_to(v_w, size=new_shape)
# Update the number of kv heads.
num_kv_heads = tp_size
reordered = torch.concat(
[
q_w.reshape(tp_size, num_heads // tp_size, head_dim, qkv_in),
k_w.reshape(tp_size, num_kv_heads // tp_size, head_dim, qkv_in),
v_w.reshape(tp_size, num_kv_heads // tp_size, head_dim, qkv_in),
],
dim=1,
)
qkv_out = (num_heads + 2 * num_kv_heads) // tp_size * head_dim
return reordered.reshape((tp_size, qkv_out, -1))
def load_from_gptq_gptneox(quant_ckpt_path,
hf_config=None,
use_parallel_embedding=False,
sharding_dim=0,
share_embedding_table=False,
mapping=Mapping(),
dtype='float16'):
tensorrt_llm.logger.info(
        'Loading weights from groupwise GPTQ GPT-NeoX safetensors...')
weights = {}
tik = time.time()
gptq_model = safe_open(quant_ckpt_path, framework="pt", device=0)
gptq_prefix = "gpt_neox."
gptq_suffix_list = [".qweight", ".qzeros", ".scales"]
split_sym = "."
packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm
torch_dtype = str_dtype_to_torch(dtype)
def load(key, no_prefix=0):
if no_prefix:
return gptq_model.get_tensor(key).cpu()
else:
return gptq_model.get_tensor(gptq_prefix + key).cpu()
def torch_split(v, dim):
if v.shape[dim] % mapping.tp_size != 0:
tensorrt_llm.logger.error(
"Current weight shape is invalid for mapping.tp_size=" +
str(mapping.tp_size))
assert False, "Invalid TP size"
return v.split(v.shape[dim] // mapping.tp_size,
dim=dim)[mapping.tp_rank].contiguous()
def unpack_int32_into_int8(w_packed):
# Unpack inputs packed in int32/float32 into uint4 and store them in int8 format
w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
w_unpacked = torch.zeros(w_packed_int4x2.shape[0],
w_packed_int4x2.shape[1] * 2,
dtype=torch.int8,
device=w_packed.device)
w_unpacked[:, ::2] = w_packed_int4x2 % 16
w_unpacked[:, 1::2] = w_packed_int4x2 // 16
return w_unpacked.contiguous()
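    # For illustration: each int32 element packs eight 4-bit values, so a
    # [rows, cols] int32 tensor unpacks to a [rows, cols * 8] int8 tensor. On a
    # little-endian machine the packed value 0x21 is viewed as the uint8 bytes
    # [0x21, 0x00, 0x00, 0x00]; its first byte yields nibbles 1 (low, x % 16)
    # and 2 (high, x // 16), so the unpacked row starts with [1, 2, 0, 0, ...].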
def process_and_assign_weight(v: List[torch.Tensor],
tllm_prex: str,
tp_dim: int = -1):
if tp_dim == -1:
qweight_int32, qzeros_int32, scales_fp16 = [
item.cpu() for item in v
]
else:
qweight_int32, qzeros_int32, scales_fp16 = [
torch_split(item, tp_dim).cpu() for item in v
]
        USE_UINT4_INPUT = 1  # Set to true if the checkpoint stores UINT4 weights
USE_GPTQ_FOR_LLAMA = 1 # GPTQ-for-LLaMA added 1 to zeros
qweight_unpacked_int8 = unpack_int32_into_int8(
qweight_int32.T).T.contiguous() - 8
qweight_interleaved = preprocessor(packer(qweight_unpacked_int8),
torch.quint4x2).view(torch.float16)
# zeros = zeros * scales
qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32)
if not USE_UINT4_INPUT:
# Correcting UINT4 values back to INT4 order
mask_negative = qzeros_unpacked_int32[qzeros_unpacked_int32 < 0]
mask_positive = qzeros_unpacked_int32[qzeros_unpacked_int32 >= 0]
qzeros_unpacked_int32 = qzeros_unpacked_int32 + 16 * mask_negative - 16 * mask_positive
zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * USE_UINT4_INPUT -
USE_GPTQ_FOR_LLAMA) * scales_fp16
zeros_x_scales_fp16 = zeros_x_scales_fp16.half()
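        # Rationale (as understood from the flags above, stated as a sketch):
        # GPTQ dequantizes as w = scale * (uint4_weight - (stored_zero + 1)),
        # while the TRT-LLM groupwise kernel expects int4 weights plus an
        # additive term, w = scale * (uint4_weight - 8) + zero_x_scale. The two
        # agree when zero_x_scale = (-stored_zero + 8 - 1) * scale, which is
        # exactly what is computed here.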
results = {
f'{tllm_prex}.weight': qweight_interleaved,
f'{tllm_prex}.weights_scaling_factor': scales_fp16,
f'{tllm_prex}.zero': zeros_x_scales_fp16,
}
return results
def preprocess_groupwise_weight_params(qweight_unpacked_int8, scales_fp16,
qzeros_unpacked_int8):
UINT4_TO_INT4_FLAG = 1
GPTQ_FLAG = 1
qweight_interleaved = preprocessor(packer(qweight_unpacked_int8),
torch.quint4x2).view(torch.float16)
# zeros = zeros * scales
zeros_x_scales_fp16 = (-qzeros_unpacked_int8 + 8 * UINT4_TO_INT4_FLAG -
GPTQ_FLAG) * scales_fp16
zeros_x_scales_fp16 = zeros_x_scales_fp16.half()
# return processed interleaved weight, original scales and zeros * scales
return qweight_interleaved.contiguous(), scales_fp16.contiguous(
), zeros_x_scales_fp16.contiguous()
# Load weights from GPTQ checkpoint into TRT-LLM module
# 1. vocab_embedding
v = load('embed_in.weight')
if mapping.is_first_pp_rank():
if not use_parallel_embedding:
weights['transformer.vocab_embedding.weight'] = v.to(torch_dtype)
else:
assert hf_config.vocab_size % mapping.tp_size == 0
weights['transformer.vocab_embedding.weight'] = torch_split(
v, sharding_dim).to(torch_dtype)
# 2. lm_head
if not share_embedding_table:
v = load('embed_out.weight', no_prefix=1)
if mapping.is_last_pp_rank():
if not share_embedding_table:
weights['lm_head.weight'] = torch_split(v, 0).to(torch_dtype)
elif not mapping.is_first_pp_rank():
weights['transformer.vocab_embedding.weight'] = torch_split(
v, 0).to(torch_dtype)
# 3. ln_f
v = load('final_layer_norm.weight')
b = load('final_layer_norm.bias')
if mapping.is_last_pp_rank():
weights['transformer.ln_f.weight'] = v.to(torch_dtype)
weights['transformer.ln_f.bias'] = b.to(torch_dtype)
# 4. Weights inside each layer
num_hidden_layers = hf_config.num_hidden_layers
layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
layers_range = list(
range(mapping.pp_rank * layers_per_pipeline_stage,
(mapping.pp_rank + 1) * layers_per_pipeline_stage, 1))
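    # For example, with num_hidden_layers=8 and pp_size=2, pp_rank 0 gets
    # layers [0, 1, 2, 3] and pp_rank 1 gets layers [4, 5, 6, 7]. Note that
    # this script constructs Mapping with pp_size=1, so layers_range normally
    # covers every layer.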
for l in layers_range:
layer_idx = l - mapping.pp_rank * layers_per_pipeline_stage
prefix = "layers" + split_sym + str(l) + split_sym
tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
tllm_prex = f'transformer.layers.{l - layers_range[0]}'
# 4.1 attention.qkv
num_heads = hf_config.num_attention_heads
hidden_size = hf_config.hidden_size
head_size = hidden_size // num_heads
qweight_int32 = load(prefix + 'attention.query_key_value.qweight')
scales_fp16 = load(prefix + 'attention.query_key_value.scales')
qzeros_int32 = load(prefix + 'attention.query_key_value.qzeros')
biases_fp16 = load(prefix + 'attention.query_key_value.bias')
GROUP_SIZE = hidden_size // scales_fp16.shape[0]
# [hidden_size // 8, hidden_size * 3] -> [hidden_size * 3, hidden_size]
qweight_unpacked_int8 = unpack_int32_into_int8(
qweight_int32.T).contiguous() - 8
# [hidden_size // GROUP_SIZE, hidden_size * 3 // 8] ->
# [hidden_size // GROUP_SIZE, hidden_size * 3]
qzeros_unpacked_int8 = unpack_int32_into_int8(qzeros_int32)
# qkv_weights [num_heads x (q|k|v), hidden_size] ->
# [(num_heads x q)|(num_heads x k)|(num_heads x v), hidden_size]
new_qkv_weight_shape = torch.Size(
[num_heads, 3, head_size * qweight_unpacked_int8.size()[-1]])
# [hidden_size * 3, hidden_size]
qweight_unpacked_int8 = qweight_unpacked_int8.view(
new_qkv_weight_shape).permute(1, 0, 2).reshape(
[hidden_size * 3, hidden_size]).contiguous()
new_qkv_scale_shape = torch.Size(
[num_heads, 3, head_size * (hidden_size // GROUP_SIZE)])
# [hidden_size * 3, hidden_size // GROUP_SIZE]
scales_fp16 = scales_fp16.T.contiguous().view(
new_qkv_scale_shape).permute(1, 0, 2).reshape(
[hidden_size * 3, hidden_size // GROUP_SIZE]).contiguous()
new_qkv_zero_shape = torch.Size(
[num_heads, 3, head_size * (hidden_size // GROUP_SIZE)])
# [hidden_size * 3, hidden_size // GROUP_SIZE]
qzeros_unpacked_int8 = qzeros_unpacked_int8.T.contiguous().view(
new_qkv_zero_shape).permute(1, 0, 2).reshape(
[hidden_size * 3, hidden_size // GROUP_SIZE]).contiguous()
new_qkv_bias_shape = torch.Size([num_heads, 3, head_size])
biases_fp16 = biases_fp16.view(new_qkv_bias_shape).permute(
1, 0, 2).reshape([hidden_size * 3])
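        # In short: GPT-NeoX stores the fused QKV rows per head as [q|k|v]
        # blocks of head_size each, i.e. [h0_q, h0_k, h0_v, h1_q, ...]. The
        # view/permute/reshape above regroups them into contiguous
        # [all_Q, all_K, all_V] blocks of hidden_size rows each, so the
        # tensor-parallel split below can slice Q, K and V independently.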
tp_size = mapping.tp_size
if tp_size > 1:
qweight_unpacked_int8 = qweight_unpacked_int8.reshape(
[3, hidden_size, hidden_size])
qweight_unpacked_int8 = torch_split(qweight_unpacked_int8, dim=1)
qweight_unpacked_int8 = qweight_unpacked_int8.reshape(
[3 * hidden_size // tp_size, hidden_size])
scales_fp16 = scales_fp16.reshape(
[3, hidden_size, hidden_size // GROUP_SIZE])
scales_fp16 = torch_split(scales_fp16, dim=1)
scales_fp16 = scales_fp16.reshape(
[3 * hidden_size // tp_size, hidden_size // GROUP_SIZE])
qzeros_unpacked_int8 = qzeros_unpacked_int8.reshape(
[3, hidden_size, hidden_size // GROUP_SIZE])
qzeros_unpacked_int8 = torch_split(qzeros_unpacked_int8, dim=1)
qzeros_unpacked_int8 = qzeros_unpacked_int8.reshape(
[3 * hidden_size // tp_size, hidden_size // GROUP_SIZE])
biases_fp16 = biases_fp16.reshape([3, hidden_size])
biases_fp16 = torch_split(biases_fp16, dim=1)
biases_fp16 = biases_fp16.reshape([3 * hidden_size // tp_size])
qweight_fp32, scales_fp16, zeros_fp16 = preprocess_groupwise_weight_params(
qweight_unpacked_int8.T.contiguous(), scales_fp16.T.contiguous(),
qzeros_unpacked_int8.T.contiguous())
weights.update({
f'{tllm_prex}.attention.qkv.weight': qweight_fp32,
f'{tllm_prex}.attention.qkv.weights_scaling_factor': scales_fp16,
f'{tllm_prex}.attention.qkv.zero': zeros_fp16,
f'{tllm_prex}.attention.qkv.bias': biases_fp16,
})
# 4.2 attention.dense
v = [load(prefix + 'attention.dense' + suf) for suf in gptq_suffix_list]
        # Pre-scale the bias down since the bias add is duplicated on every TP rank.
b = load(prefix + 'attention.dense.bias') / mapping.tp_size
weights.update(
process_and_assign_weight(v,
f'{tllm_prex}.attention.dense',
tp_dim=0))
weights.update({f'{tllm_prex}.attention.dense.bias': b.to(torch_dtype)})
# 4.3 mlp.fc
v = [
load(prefix + 'mlp.dense_h_to_4h' + suf) for suf in gptq_suffix_list
]
b = load(prefix + 'mlp.dense_h_to_4h.bias')
weights.update(
process_and_assign_weight(v, f'{tllm_prex}.mlp.fc', tp_dim=1))
weights.update(
{f'{tllm_prex}.mlp.fc.bias': torch_split(b, dim=0).to(torch_dtype)})
# 4.4 mlp.proj
v = [
load(prefix + 'mlp.dense_4h_to_h' + suf) for suf in gptq_suffix_list
]
        # Pre-scale the bias down since the bias add is duplicated on every TP rank.
b = load(prefix + 'mlp.dense_4h_to_h.bias') / mapping.tp_size
weights.update(
process_and_assign_weight(v, f'{tllm_prex}.mlp.proj', tp_dim=0))
weights.update({f'{tllm_prex}.mlp.proj.bias': b.to(torch_dtype)})
# 4.5 input_layernorm
v = load(prefix + 'input_layernorm.weight')
b = load(prefix + 'input_layernorm.bias')
weights[f'{tllm_prex}.input_layernorm.weight'] = v.to(torch_dtype)
weights[f'{tllm_prex}.input_layernorm.bias'] = b.to(torch_dtype)
# 4.6 post_layernorm
v = load(prefix + 'post_attention_layernorm.weight')
b = load(prefix + 'post_attention_layernorm.bias')
weights[f'{tllm_prex}.post_attention_layernorm.weight'] = v.to(
torch_dtype)
weights[f'{tllm_prex}.post_attention_layernorm.bias'] = b.to(
torch_dtype)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
return weights
def split_qkv_weight(weight: torch.Tensor,
hidden_size: int,
num_heads: int,
tp_size: int,
rank: int,
is_bias: bool,
num_kv_heads: Optional[int] = None) -> torch.Tensor:
""" Splits the QKV matrix according to tensor parallelism """
head_dim = hidden_size // num_heads
weight = reorder_qkv_weight_or_bias(weight,
head_dim=head_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
tp_size=tp_size,
is_bias=is_bias)
    # Copy the sliced tensor to prevent a memory leak. A sliced tensor shares
    # the memory buffer of the original tensor, so returning it without copying
    # keeps the whole loaded "qkv" buffer referenced, and the GC cannot release
    # those weights until the whole process ends.
if not is_bias:
return weight[rank, ...].clone().contiguous()
else:
return weight[rank, ...].ravel().clone().contiguous()
def split(v, tp_size, idx, dim=0):
if tp_size == 1:
return v
if len(v.shape) == 1:
return torch.chunk(v, tp_size)[idx].contiguous()
else:
return torch.chunk(v, tp_size, dim=dim)[idx].contiguous()
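# For instance, split(torch.arange(8), tp_size=4, idx=1) returns tensor([2, 3]);
# for a 2-D tensor the same chunking is applied along `dim`.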
def split_matrix_tp(v, tensor_parallel, rank, dim):
return split(v, tensor_parallel, rank, dim=dim)
def get_weight(config, prefix, dtype):
return config[prefix + '.weight'].to(dtype).detach()
def get_bias(config, prefix, dtype):
return config[prefix + '.bias'].to(dtype).detach()
def get_weight_and_bias(config, prefix, dtype):
return get_weight(config, prefix, dtype), get_bias(config, prefix, dtype)
def get_tllm_linear_weight(weight,
prefix,
bias=None,
use_weight_only=False,
plugin_weight_only_quant_type=torch.int8):
results = {}
if use_weight_only:
v = weight.t().contiguous()
processed_torch_weights, torch_weight_scales = \
torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
v, plugin_weight_only_quant_type)
results[prefix + 'weight'] = processed_torch_weights
results[prefix + 'per_channel_scale'] = torch_weight_scales
else:
results[prefix + 'weight'] = weight.contiguous()
if bias is not None:
results[prefix + 'bias'] = bias
return results
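# Note on the weight-only branch above (hedged, since the op is provided by the
# tensorrt_llm package): symmetric_quantize_last_axis_of_batched_matrix returns
# the processed quantized weight together with one scale per output channel,
# which is why both the 'weight' and the 'per_channel_scale' entries are filled.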
def convert_hf_gptneox(hf_model,
mapping: Mapping,
dtype='float32',
use_parallel_embedding=False,
sharding_dim=0,
share_embedding_table=False,
use_weight_only=False,
plugin_weight_only_quant_type=torch.int8):
weights = {}
tik = time.time()
model_params = dict(hf_model.named_parameters())
dtype = getattr(torch, dtype)
num_attention_heads = hf_model.config.num_attention_heads
hidden_size = hf_model.config.hidden_size
tensor_parallel = mapping.tp_size
rank = mapping.rank
for l in range(hf_model.config.num_hidden_layers):
prefix = f'gpt_neox.layers.{l}.'
tllm_prex = f'transformer.layers.{l}.'
qkv_weight, qkv_bias = get_weight_and_bias(
model_params, prefix + 'attention.query_key_value', dtype)
qkv_w = split_qkv_weight(qkv_weight,
hidden_size,
num_attention_heads,
mapping.tp_size,
mapping.tp_rank,
is_bias=False,
num_kv_heads=num_attention_heads)
if qkv_bias is None:
qkv_b = None
else:
qkv_b = split_qkv_weight(qkv_bias,
hidden_size,
num_attention_heads,
mapping.tp_size,
mapping.tp_rank,
is_bias=True,
num_kv_heads=num_attention_heads)
weights.update(
get_tllm_linear_weight(qkv_w, tllm_prex + 'attention.qkv.', qkv_b,
use_weight_only,
plugin_weight_only_quant_type))
attn_dense_weight, attn_dense_bias = get_weight_and_bias(
model_params, prefix + 'attention.dense', dtype)
split_v = split_matrix_tp(attn_dense_weight,
tensor_parallel,
rank,
dim=1)
weights.update(
get_tllm_linear_weight(split_v, tllm_prex + 'attention.dense.',
attn_dense_bias, use_weight_only,
plugin_weight_only_quant_type))
mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(
model_params, prefix + 'mlp.dense_h_to_4h', dtype)
split_v = split_matrix_tp(mlp_fc_weight, tensor_parallel, rank, dim=0)
bias = split_matrix_tp(mlp_fc_bias, tensor_parallel, rank, dim=0)
weights.update(
get_tllm_linear_weight(split_v, tllm_prex + 'mlp.fc.', bias,
use_weight_only,
plugin_weight_only_quant_type))
mlp_proj_weight, mlp_proj_bias = get_weight_and_bias(
model_params, prefix + 'mlp.dense_4h_to_h', dtype)
split_v = split_matrix_tp(mlp_proj_weight, tensor_parallel, rank, dim=1)
weights.update(
get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.',
mlp_proj_bias, use_weight_only,
plugin_weight_only_quant_type))
# Layer norms do not use tensor parallelism
input_ln_weight, input_ln_bias = get_weight_and_bias(
model_params, prefix + 'input_layernorm', dtype)
weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight
weights[tllm_prex + 'input_layernorm.bias'] = input_ln_bias
post_ln_weight, post_ln_bias = get_weight_and_bias(
model_params, prefix + 'post_attention_layernorm', dtype)
weights[tllm_prex + 'post_attention_layernorm.weight'] = post_ln_weight
weights[tllm_prex + 'post_attention_layernorm.bias'] = post_ln_bias
embed_w = get_weight(model_params, 'gpt_neox.embed_in', dtype)
lm_head_w = get_weight(model_params, 'embed_out', dtype)
if not share_embedding_table:
weights['lm_head.weight'] = split_matrix_tp(lm_head_w,
tensor_parallel,
rank,
dim=0)
if not use_parallel_embedding:
weights['transformer.vocab_embedding.weight'] = embed_w
else:
assert hf_model.config.vocab_size % tensor_parallel == 0
weights['transformer.vocab_embedding.weight'] = split_matrix_tp(
embed_w, tensor_parallel, rank, dim=sharding_dim)
ln_f_w, ln_f_b = get_weight_and_bias(model_params,
'gpt_neox.final_layer_norm', dtype)
weights['transformer.ln_f.weight'] = ln_f_w
weights['transformer.ln_f.bias'] = ln_f_b
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Weights loaded. Total time: {t}')
return weights
if __name__ == '__main__':
    # TODO(qijun): Currently, the convert script depends on a torch op:
    # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix,
    # which is included in the tensorrt_llm Python package. Otherwise, the
    # convert script would not need to import tensorrt_llm at all. The import
    # will be removed after the op is reimplemented with PyTorch.
print(tensorrt_llm.__version__)
args = parse_arguments()
tik = time.time()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
hf_config = AutoConfig.from_pretrained(args.model_dir)
hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir,
torch_dtype="auto")
config = {
'architecture': hf_config.architectures[0],
'dtype': args.dtype,
'num_hidden_layers': hf_config.num_hidden_layers,
'num_attention_heads': hf_config.num_attention_heads,
'hidden_size': hf_config.hidden_size,
'vocab_size': hf_config.vocab_size,
'position_embedding_type': 'learned_absolute',
'max_position_embeddings': hf_config.max_position_embeddings,
'rotary_emb_base': hf_config.rotary_emb_base,
'rotary_pct': hf_config.rotary_pct,
'hidden_act': hf_config.hidden_act,
'quantization': {
'use_weight_only': args.use_weight_only,
'weight_only_precision': args.weight_only_precision,
},
'mapping': {
'world_size': args.world_size,
'tp_size': args.world_size,
},
'use_parallel_embedding': args.use_parallel_embedding,
'embedding_sharding_dim': args.embedding_sharding_dim,
'share_embedding_table': args.use_embedding_sharing,
}
if args.use_weight_only and args.weight_only_precision == 'int4_gptq':
assert args.per_group
config['quantization'].update({
'weight_only_precision': 'int4',
'per_group': args.per_group,
'zero': True,
})
with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
json.dump(config, f, indent=4)
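    # The written config.json mirrors the dict above; for a hypothetical FP16,
    # 2-way tensor-parallel conversion it would contain entries such as
    #   "architecture": "GPTNeoXForCausalLM", "dtype": "float16",
    #   "mapping": {"world_size": 2, "tp_size": 2},
    # plus the model dimensions read from the Hugging Face config.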
    def convert_and_save(rank):
mapping = Mapping(world_size=args.world_size,
rank=rank,
tp_size=args.world_size,
pp_size=1)
try:
if args.use_weight_only and args.weight_only_precision == 'int4_gptq':
weights = load_from_gptq_gptneox(
args.ammo_quant_ckpt_path,
hf_config,
use_parallel_embedding=args.use_parallel_embedding,
sharding_dim=args.embedding_sharding_dim,
share_embedding_table=args.use_embedding_sharing,
mapping=mapping,
dtype=args.dtype)
else:
if args.weight_only_precision == 'int4':
plugin_weight_only_quant_type = torch.quint4x2
else:
plugin_weight_only_quant_type = torch.int8
weights = convert_hf_gptneox(
hf_model,
mapping,
dtype=args.dtype,
use_weight_only=args.use_weight_only,
plugin_weight_only_quant_type=plugin_weight_only_quant_type,
use_parallel_embedding=args.use_parallel_embedding,
sharding_dim=args.embedding_sharding_dim,
share_embedding_table=args.use_embedding_sharing)
safe_save_path = os.path.join(args.output_dir,
f'rank{rank}.safetensors')
tensorrt_llm.logger.info(f'Saving safetensors to: {safe_save_path}')
safetensors.torch.save_file(weights, safe_save_path)
tensorrt_llm.logger.info(f'Saved safetensors to: {safe_save_path}')
except Exception as e:
            tensorrt_llm.logger.info(f'Exception when converting: {e}')
if args.workers == 1:
for rank in range(args.world_size):
            convert_and_save(rank)
else:
with ThreadPoolExecutor(max_workers=args.workers) as p:
futures = [
                p.submit(convert_and_save, rank)
for rank in range(args.world_size)
]
wait(futures)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Total time of converting checkpoints: {t}')