# TensorRT-LLM/examples/models/contrib/gptneox/convert_checkpoint.py

import argparse
import json
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional
import safetensors
import safetensors.torch
import torch
from safetensors import safe_open
from transformers import AutoConfig, AutoModelForCausalLM
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import (get_weight, get_weight_and_bias,
split_matrix_tp)
from tensorrt_llm.quantization import QuantAlgo
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default=None)
parser.add_argument('--tp_size',
type=int,
default=1,
help='N-way tensor parallelism size')
parser.add_argument('--pp_size',
type=int,
default=1,
help='N-way pipeline parallelism size')
parser.add_argument('--dtype',
type=str,
default='float16',
choices=['float32', 'bfloat16', 'float16'])
parser.add_argument(
'--use_weight_only',
default=False,
action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
parser.add_argument(
'--weight_only_precision',
const='int8',
type=str,
nargs='?',
default='int8',
choices=['int8', 'int4', 'int4_gptq'],
help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for this argument to have an effect.'
)
parser.add_argument('--quant_ckpt_path',
type=str,
default=None,
help='Path of a quantized model checkpoint')
parser.add_argument(
'--use_parallel_embedding',
action="store_true",
default=False,
help=
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
)
parser.add_argument(
'--embedding_sharding_dim',
type=int,
default=0,
choices=[0, 1],
help=
        'By default the embedding lookup table is sharded along the vocab dimension (embedding_sharding_dim=0). '
        'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim=0.'
)
parser.add_argument('--output_dir',
type=str,
default='tllm_checkpoint',
help='The path to save the TensorRT-LLM checkpoint')
parser.add_argument(
'--workers',
type=int,
default=1,
help='The number of workers for converting checkpoint in parallel')
args = parser.parse_args()
return args
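# Example invocations (illustrative only; the model paths and sizes below are
# placeholders, not part of the repository):
#   python convert_checkpoint.py --model_dir ./gpt-neox-20b --dtype float16 \
#       --tp_size 2 --output_dir ./tllm_checkpoint_2gpu
#   python convert_checkpoint.py --model_dir ./gpt-neox-20b --use_weight_only \
#       --weight_only_precision int4_gptq \
#       --quant_ckpt_path ./gptneox-20b-4bit-gs128.safetensors \
#       --output_dir ./tllm_checkpoint_gptq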
# TODO: It seems all convert_checkpoint scripts could share the following
# utility functions. Consider moving them into a common module.
def reorder_qkv_weight_or_bias(weight: torch.Tensor,
head_dim: int,
num_heads: int,
num_kv_heads: Optional[int] = None,
tp_size: int = 1,
is_bias: bool = False) -> torch.Tensor:
""" Reorder the qkv weight for TRT-LLM use.
The shape of the fused QKV weights in HF is different from the shape that
TRT-LLM requires. In particular, the weight of HF consists of interleaved
q, k, v head weights, while that of TRT-LLM is contiguous.
HF : [q1, k1, v1, ..., qh, kh, vh]
    TRT-LLM: [q1, ..., qh, k1, ..., kh, v1, ..., vh]
    where qi, ki, vi are the weight vectors corresponding to attention head i.
    The same layout applies in the multi/grouped-query attention case.
We reorder and split the weight of an attention layer to fit into TRT-LLM.
The reordered weight and bias will be
weight: (T, Qh * D + 2 * KVh * D, H)
bias : (T, Qh * D + 2 * KVh * D)
where T=tp_size, Qh=local_num_q_heads, KVh=local_num_kv_heads, D=head_dim,
    H=hidden_dim. In multi/grouped-query attention, the number of K/V heads
    is smaller than the number of Q heads, so K/V heads may be shared across
    different ranks when necessary.
For tensor parallelism, we use the first dimension to select the
corresponding weights.
"""
# Query types and expected kv heads.
# - Conventional MHA: num_heads = num_kv_heads
# - Multi-Query Attention: num_kv_heads = 1
# - Grouped-Query Attention: num_heads % num_kv_heads = 0
num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
assert num_heads % num_kv_heads == 0, \
f'num_heads({num_heads}) must be divisible by ' \
        f'num_kv_heads({num_kv_heads}).'
# The number of attention heads per group: N q head + 1 k head + 1 v head.
num_group_heads = num_heads // num_kv_heads + 2
assert weight.shape[0] == num_kv_heads * num_group_heads * head_dim, \
f'{weight.shape[0]} != {num_kv_heads} * {num_group_heads} * {head_dim}'
qkv_in = num_heads * head_dim if not is_bias else 1
# Split Q/K/V weights
weight = weight.reshape(num_kv_heads, num_heads // num_kv_heads + 2,
head_dim, qkv_in)
q_w = weight[:, :-2, ...] # (nKV, num_heads // nKV, head_dim, qkv_in)
k_w = weight[:, -2:-1, ...] # (nKV, 1, head_dim, qkv_in)
v_w = weight[:, -1:, ...] # (nKV, 1, head_dim, qkv_in)
if num_kv_heads < num_heads and num_kv_heads < tp_size:
# Duplicate K/V heads to make sure that each rank has at least one
# K/V heads. For instance, num_heads=8, num_kv_heads=2, tp_size=4,
# we will make the qkv weight as below.
        # Orig: [q0 q1 q2 q3 k0 v0 q4 q5 q6 q7 k1 v1]
# >>>> [[q0 q1 k0 v0], [q2 q3 k0 v0], [q4 q5 k1 v1], [q6 q7 k1 v1]]
assert tp_size % num_kv_heads == 0
num_dups = tp_size // num_kv_heads
# k_w and v_w have the same shape.
new_shape = (num_kv_heads, num_dups) + k_w.shape[2:]
k_w = torch.broadcast_to(k_w, size=new_shape)
v_w = torch.broadcast_to(v_w, size=new_shape)
# Update the number of kv heads.
num_kv_heads = tp_size
reordered = torch.concat(
[
q_w.reshape(tp_size, num_heads // tp_size, head_dim, qkv_in),
k_w.reshape(tp_size, num_kv_heads // tp_size, head_dim, qkv_in),
v_w.reshape(tp_size, num_kv_heads // tp_size, head_dim, qkv_in),
],
dim=1,
)
qkv_out = (num_heads + 2 * num_kv_heads) // tp_size * head_dim
return reordered.reshape((tp_size, qkv_out, -1))
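# Illustrative example (a sketch, not executed): for a plain MHA layer with
# num_heads=8, head_dim=64 (hidden_size=512) and tp_size=2, the fused HF
# weight of shape (3 * 512, 512) is reordered to shape
# (2, (8 + 2 * 8) // 2 * 64, 512) = (2, 768, 512), so TP rank r can simply
# take reordered[r] as its local QKV weight slice.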
def get_gptq_gptneox_group_size(quant_ckpt_path, hf_config):
gptq_model = safe_open(quant_ckpt_path, framework="pt", device=0)
gptq_prefix = "gpt_neox."
split_sym = "."
def load(key, no_prefix=0):
if no_prefix:
return gptq_model.get_tensor(key).cpu()
else:
return gptq_model.get_tensor(gptq_prefix + key).cpu()
hidden_size = hf_config.hidden_size
prefix = "layers" + split_sym + "0" + split_sym
scales_fp16 = load(prefix + 'attention.query_key_value.scales')
return hidden_size // scales_fp16.shape[0]
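# Note: GPTQ keeps one row of scales per quantization group along the input
# dimension, so scales.shape[0] == hidden_size // group_size and the division
# above recovers the group size (e.g. 6144 // 48 == 128 for group size 128).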
def load_from_gptq_gptneox(quant_ckpt_path,
hf_config=None,
use_parallel_embedding=False,
sharding_dim=0,
mapping=Mapping(),
dtype='float16'):
tensorrt_llm.logger.info(
        'Loading weights from groupwise GPTQ GPT-NeoX safetensors...')
weights = {}
tik = time.time()
gptq_model = safe_open(quant_ckpt_path, framework="pt", device=0)
gptq_prefix = "gpt_neox."
gptq_suffix_list = [".qweight", ".qzeros", ".scales"]
split_sym = "."
packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm
torch_dtype = str_dtype_to_torch(dtype)
def load(key, no_prefix=0):
if no_prefix:
return gptq_model.get_tensor(key).cpu()
else:
return gptq_model.get_tensor(gptq_prefix + key).cpu()
def torch_split(v, dim):
if v.shape[dim] % mapping.tp_size != 0:
tensorrt_llm.logger.error(
"Current weight shape is invalid for mapping.tp_size=" +
str(mapping.tp_size))
assert False, "Invalid TP size"
return v.split(v.shape[dim] // mapping.tp_size,
dim=dim)[mapping.tp_rank].contiguous()
def unpack_int32_into_int8(w_packed):
# Unpack inputs packed in int32/float32 into uint4 and store them in int8 format
w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
w_unpacked = torch.zeros(w_packed_int4x2.shape[0],
w_packed_int4x2.shape[1] * 2,
dtype=torch.int8,
device=w_packed.device)
w_unpacked[:, ::2] = w_packed_int4x2 % 16
w_unpacked[:, 1::2] = w_packed_int4x2 // 16
return w_unpacked.contiguous()
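    # Note on unpack_int32_into_int8 above: every int32 packs eight 4-bit
    # values. Viewing the tensor as uint8 exposes two nibbles per byte;
    # `% 16` extracts the low nibble and `// 16` the high nibble, e.g. the
    # byte 0xBA unpacks into the int8 pair (0x0A, 0x0B).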
def process_and_assign_weight(v: List[torch.Tensor],
tllm_prex: str,
tp_dim: int = -1):
if tp_dim == -1:
qweight_int32, qzeros_int32, scales_fp16 = [
item.cpu() for item in v
]
else:
qweight_int32, qzeros_int32, scales_fp16 = [
torch_split(item, tp_dim).cpu() for item in v
]
        USE_UINT4_INPUT = 1  # Set to 1 if the checkpoint stores UINT4 weights
        USE_GPTQ_FOR_LLAMA = 1  # GPTQ-for-LLaMA adds 1 to the zero points
qweight_unpacked_int8 = unpack_int32_into_int8(
qweight_int32.T).T.contiguous() - 8
qweight_interleaved = preprocessor(packer(qweight_unpacked_int8),
torch.quint4x2,
torch.float16).view(torch.float16)
# zeros = zeros * scales
qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32)
if not USE_UINT4_INPUT:
# Correcting UINT4 values back to INT4 order
mask_negative = qzeros_unpacked_int32[qzeros_unpacked_int32 < 0]
mask_positive = qzeros_unpacked_int32[qzeros_unpacked_int32 >= 0]
qzeros_unpacked_int32 = qzeros_unpacked_int32 + 16 * mask_negative - 16 * mask_positive
zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * USE_UINT4_INPUT -
USE_GPTQ_FOR_LLAMA) * scales_fp16
zeros_x_scales_fp16 = zeros_x_scales_fp16.half()
results = {
f'{tllm_prex}.weight': qweight_interleaved,
f'{tllm_prex}.weights_scaling_factor': scales_fp16,
f'{tllm_prex}.zero': zeros_x_scales_fp16,
}
return results
def preprocess_groupwise_weight_params(qweight_unpacked_int8, scales_fp16,
qzeros_unpacked_int8):
UINT4_TO_INT4_FLAG = 1
GPTQ_FLAG = 1
qweight_interleaved = preprocessor(packer(qweight_unpacked_int8),
torch.quint4x2,
torch.float16).view(torch.float16)
# zeros = zeros * scales
zeros_x_scales_fp16 = (-qzeros_unpacked_int8 + 8 * UINT4_TO_INT4_FLAG -
GPTQ_FLAG) * scales_fp16
zeros_x_scales_fp16 = zeros_x_scales_fp16.half()
# return processed interleaved weight, original scales and zeros * scales
return qweight_interleaved.contiguous(), scales_fp16.contiguous(
), zeros_x_scales_fp16.contiguous()
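    # Note (a sketch of the convention assumed by the two helpers above):
    # GPTQ dequantizes as w = scales * (q_uint4 - zero_uint4), with
    # GPTQ-for-LLaMA storing the zero points offset by one. After shifting
    # the weights into the signed INT4 range (the `- 8` above), the same
    # identity becomes w = scales * q_int4 + (-zeros + 8 - 1) * scales,
    # which is exactly the precomputed zeros_x_scales_fp16 tensor written to
    # the checkpoint, so no extra multiply is needed at runtime.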
# Load weights from GPTQ checkpoint into TRT-LLM module
# 1. vocab_embedding
v = load('embed_in.weight')
if mapping.is_first_pp_rank():
if not use_parallel_embedding:
weights['transformer.vocab_embedding.weight'] = v.to(torch_dtype)
else:
assert hf_config.vocab_size % mapping.tp_size == 0
weights['transformer.vocab_embedding.weight'] = torch_split(
v, sharding_dim).to(torch_dtype)
# 2. lm_head
if mapping.is_last_pp_rank():
v = load('embed_out.weight', no_prefix=1)
weights['lm_head.weight'] = torch_split(v, 0).to(torch_dtype)
# 3. ln_f
v = load('final_layer_norm.weight')
b = load('final_layer_norm.bias')
if mapping.is_last_pp_rank():
weights['transformer.ln_f.weight'] = v.to(torch_dtype)
weights['transformer.ln_f.bias'] = b.to(torch_dtype)
# 4. Weights inside each layer
num_hidden_layers = hf_config.num_hidden_layers
layers_range = mapping.pp_layers(num_hidden_layers)
for l in layers_range:
layer_idx = l - layers_range[0]
prefix = "layers" + split_sym + str(l) + split_sym
tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
# layer = tensorrt_llm_llama.layers[layer_idx]
tllm_prex = f'transformer.layers.{l - layers_range[0]}'
# 4.1 attention.qkv
num_heads = hf_config.num_attention_heads
hidden_size = hf_config.hidden_size
head_size = hidden_size // num_heads
qweight_int32 = load(prefix + 'attention.query_key_value.qweight')
scales_fp16 = load(prefix + 'attention.query_key_value.scales')
qzeros_int32 = load(prefix + 'attention.query_key_value.qzeros')
biases_fp16 = load(prefix + 'attention.query_key_value.bias')
GROUP_SIZE = hidden_size // scales_fp16.shape[0]
# [hidden_size // 8, hidden_size * 3] -> [hidden_size * 3, hidden_size]
qweight_unpacked_int8 = unpack_int32_into_int8(
qweight_int32.T).contiguous() - 8
# [hidden_size // GROUP_SIZE, hidden_size * 3 // 8] ->
# [hidden_size // GROUP_SIZE, hidden_size * 3]
qzeros_unpacked_int8 = unpack_int32_into_int8(qzeros_int32)
# qkv_weights [num_heads x (q|k|v), hidden_size] ->
# [(num_heads x q)|(num_heads x k)|(num_heads x v), hidden_size]
new_qkv_weight_shape = torch.Size(
[num_heads, 3, head_size * qweight_unpacked_int8.size()[-1]])
# [hidden_size * 3, hidden_size]
qweight_unpacked_int8 = qweight_unpacked_int8.view(
new_qkv_weight_shape).permute(1, 0, 2).reshape(
[hidden_size * 3, hidden_size]).contiguous()
new_qkv_scale_shape = torch.Size(
[num_heads, 3, head_size * (hidden_size // GROUP_SIZE)])
# [hidden_size * 3, hidden_size // GROUP_SIZE]
scales_fp16 = scales_fp16.T.contiguous().view(
new_qkv_scale_shape).permute(1, 0, 2).reshape(
[hidden_size * 3, hidden_size // GROUP_SIZE]).contiguous()
new_qkv_zero_shape = torch.Size(
[num_heads, 3, head_size * (hidden_size // GROUP_SIZE)])
# [hidden_size * 3, hidden_size // GROUP_SIZE]
qzeros_unpacked_int8 = qzeros_unpacked_int8.T.contiguous().view(
new_qkv_zero_shape).permute(1, 0, 2).reshape(
[hidden_size * 3, hidden_size // GROUP_SIZE]).contiguous()
new_qkv_bias_shape = torch.Size([num_heads, 3, head_size])
biases_fp16 = biases_fp16.view(new_qkv_bias_shape).permute(
1, 0, 2).reshape([hidden_size * 3])
tp_size = mapping.tp_size
if tp_size > 1:
qweight_unpacked_int8 = qweight_unpacked_int8.reshape(
[3, hidden_size, hidden_size])
qweight_unpacked_int8 = torch_split(qweight_unpacked_int8, dim=1)
qweight_unpacked_int8 = qweight_unpacked_int8.reshape(
[3 * hidden_size // tp_size, hidden_size])
scales_fp16 = scales_fp16.reshape(
[3, hidden_size, hidden_size // GROUP_SIZE])
scales_fp16 = torch_split(scales_fp16, dim=1)
scales_fp16 = scales_fp16.reshape(
[3 * hidden_size // tp_size, hidden_size // GROUP_SIZE])
qzeros_unpacked_int8 = qzeros_unpacked_int8.reshape(
[3, hidden_size, hidden_size // GROUP_SIZE])
qzeros_unpacked_int8 = torch_split(qzeros_unpacked_int8, dim=1)
qzeros_unpacked_int8 = qzeros_unpacked_int8.reshape(
[3 * hidden_size // tp_size, hidden_size // GROUP_SIZE])
biases_fp16 = biases_fp16.reshape([3, hidden_size])
biases_fp16 = torch_split(biases_fp16, dim=1)
biases_fp16 = biases_fp16.reshape([3 * hidden_size // tp_size])
qweight_fp32, scales_fp16, zeros_fp16 = preprocess_groupwise_weight_params(
qweight_unpacked_int8.T.contiguous(), scales_fp16.T.contiguous(),
qzeros_unpacked_int8.T.contiguous())
weights.update({
f'{tllm_prex}.attention.qkv.weight': qweight_fp32,
f'{tllm_prex}.attention.qkv.weights_scaling_factor': scales_fp16,
f'{tllm_prex}.attention.qkv.zero': zeros_fp16,
f'{tllm_prex}.attention.qkv.bias': biases_fp16,
})
# 4.2 attention.dense
v = [load(prefix + 'attention.dense' + suf) for suf in gptq_suffix_list]
        # Pre-scale the bias down, since the bias add is duplicated across TP ranks.
b = load(prefix + 'attention.dense.bias') / mapping.tp_size
weights.update(
process_and_assign_weight(v,
f'{tllm_prex}.attention.dense',
tp_dim=0))
weights.update({f'{tllm_prex}.attention.dense.bias': b.to(torch_dtype)})
# 4.3 mlp.fc
v = [
load(prefix + 'mlp.dense_h_to_4h' + suf) for suf in gptq_suffix_list
]
b = load(prefix + 'mlp.dense_h_to_4h.bias')
weights.update(
process_and_assign_weight(v, f'{tllm_prex}.mlp.fc', tp_dim=1))
weights.update(
{f'{tllm_prex}.mlp.fc.bias': torch_split(b, dim=0).to(torch_dtype)})
# 4.4 mlp.proj
v = [
load(prefix + 'mlp.dense_4h_to_h' + suf) for suf in gptq_suffix_list
]
        # Pre-scale the bias down, since the bias add is duplicated across TP ranks.
b = load(prefix + 'mlp.dense_4h_to_h.bias') / mapping.tp_size
weights.update(
process_and_assign_weight(v, f'{tllm_prex}.mlp.proj', tp_dim=0))
weights.update({f'{tllm_prex}.mlp.proj.bias': b.to(torch_dtype)})
# 4.5 input_layernorm
v = load(prefix + 'input_layernorm.weight')
b = load(prefix + 'input_layernorm.bias')
weights[f'{tllm_prex}.input_layernorm.weight'] = v.to(torch_dtype)
weights[f'{tllm_prex}.input_layernorm.bias'] = b.to(torch_dtype)
# 4.6 post_layernorm
v = load(prefix + 'post_attention_layernorm.weight')
b = load(prefix + 'post_attention_layernorm.bias')
weights[f'{tllm_prex}.post_attention_layernorm.weight'] = v.to(
torch_dtype)
weights[f'{tllm_prex}.post_attention_layernorm.bias'] = b.to(
torch_dtype)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
return weights
def split_qkv_weight(weight: torch.Tensor,
hidden_size: int,
num_heads: int,
tp_size: int,
rank: int,
is_bias: bool,
num_kv_heads: Optional[int] = None) -> torch.Tensor:
""" Splits the QKV matrix according to tensor parallelism """
head_dim = hidden_size // num_heads
weight = reorder_qkv_weight_or_bias(weight,
head_dim=head_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
tp_size=tp_size,
is_bias=is_bias)
    # Copy the sliced tensor to prevent a memory leak. A sliced tensor shares
    # the memory buffer of the original tensor, so returning it without a copy
    # keeps the whole loaded "qkv" buffer referenced and the GC cannot release
    # those weights until the process ends.
if not is_bias:
return weight[rank, ...].clone().contiguous()
else:
return weight[rank, ...].ravel().clone().contiguous()
def get_tllm_linear_weight(weight,
prefix,
bias=None,
use_weight_only=False,
plugin_weight_only_quant_type=torch.int8):
results = {}
if use_weight_only:
v = weight.t().contiguous()
processed_torch_weights, torch_weight_scales = \
torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
v, plugin_weight_only_quant_type)
results[prefix + 'weight'] = processed_torch_weights
results[prefix + 'per_channel_scale'] = torch_weight_scales
else:
results[prefix + 'weight'] = weight.contiguous()
if bias is not None:
results[prefix + 'bias'] = bias
return results
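# Note: the weight-only branch above relies on the TRT-LLM torch op
# symmetric_quantize_last_axis_of_batched_matrix, which, as used here, returns
# the preprocessed integer weight plus one scale per output channel; at
# runtime the plugin dequantizes roughly as w ~= int_weight * per_channel_scale.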
def convert_hf_gptneox(hf_model,
mapping: Mapping,
dtype='float32',
use_parallel_embedding=False,
sharding_dim=0,
use_weight_only=False,
plugin_weight_only_quant_type=torch.int8):
weights = {}
tik = time.time()
model_params = dict(hf_model.named_parameters())
dtype = getattr(torch, dtype)
num_attention_heads = hf_model.config.num_attention_heads
hidden_size = hf_model.config.hidden_size
tensor_parallel = mapping.tp_size
rank = mapping.rank
for l in range(hf_model.config.num_hidden_layers):
prefix = f'gpt_neox.layers.{l}.'
tllm_prex = f'transformer.layers.{l}.'
qkv_weight, qkv_bias = get_weight_and_bias(
model_params, prefix + 'attention.query_key_value', dtype)
qkv_w = split_qkv_weight(qkv_weight,
hidden_size,
num_attention_heads,
mapping.tp_size,
mapping.tp_rank,
is_bias=False,
num_kv_heads=num_attention_heads)
if qkv_bias is None:
qkv_b = None
else:
qkv_b = split_qkv_weight(qkv_bias,
hidden_size,
num_attention_heads,
mapping.tp_size,
mapping.tp_rank,
is_bias=True,
num_kv_heads=num_attention_heads)
weights.update(
get_tllm_linear_weight(qkv_w, tllm_prex + 'attention.qkv.', qkv_b,
use_weight_only,
plugin_weight_only_quant_type))
attn_dense_weight, attn_dense_bias = get_weight_and_bias(
model_params, prefix + 'attention.dense', dtype)
split_v = split_matrix_tp(attn_dense_weight,
tensor_parallel,
rank,
dim=1)
weights.update(
get_tllm_linear_weight(split_v, tllm_prex + 'attention.dense.',
attn_dense_bias, use_weight_only,
plugin_weight_only_quant_type))
mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(
model_params, prefix + 'mlp.dense_h_to_4h', dtype)
split_v = split_matrix_tp(mlp_fc_weight, tensor_parallel, rank, dim=0)
bias = split_matrix_tp(mlp_fc_bias, tensor_parallel, rank, dim=0)
weights.update(
get_tllm_linear_weight(split_v, tllm_prex + 'mlp.fc.', bias,
use_weight_only,
plugin_weight_only_quant_type))
mlp_proj_weight, mlp_proj_bias = get_weight_and_bias(
model_params, prefix + 'mlp.dense_4h_to_h', dtype)
split_v = split_matrix_tp(mlp_proj_weight, tensor_parallel, rank, dim=1)
weights.update(
get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.',
mlp_proj_bias, use_weight_only,
plugin_weight_only_quant_type))
# Layer norms do not use tensor parallelism
input_ln_weight, input_ln_bias = get_weight_and_bias(
model_params, prefix + 'input_layernorm', dtype)
weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight
weights[tllm_prex + 'input_layernorm.bias'] = input_ln_bias
post_ln_weight, post_ln_bias = get_weight_and_bias(
model_params, prefix + 'post_attention_layernorm', dtype)
weights[tllm_prex + 'post_attention_layernorm.weight'] = post_ln_weight
weights[tllm_prex + 'post_attention_layernorm.bias'] = post_ln_bias
embed_w = get_weight(model_params, 'gpt_neox.embed_in', dtype)
lm_head_w = get_weight(model_params, 'embed_out', dtype)
weights['lm_head.weight'] = split_matrix_tp(lm_head_w,
tensor_parallel,
rank,
dim=0)
if not use_parallel_embedding:
weights['transformer.vocab_embedding.weight'] = embed_w
else:
assert hf_model.config.vocab_size % tensor_parallel == 0
weights['transformer.vocab_embedding.weight'] = split_matrix_tp(
embed_w, tensor_parallel, rank, dim=sharding_dim)
ln_f_w, ln_f_b = get_weight_and_bias(model_params,
'gpt_neox.final_layer_norm', dtype)
weights['transformer.ln_f.weight'] = ln_f_w
weights['transformer.ln_f.bias'] = ln_f_b
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Weights loaded. Total time: {t}')
return weights
if __name__ == '__main__':
    # TODO(qijun): Currently, this convert script depends on a torch op:
    # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix,
    # which is included in the tensorrt_llm Python package. Otherwise, this
    # script would not need to import tensorrt_llm at all. Remove the import
    # once the op is reimplemented in plain PyTorch.
print(tensorrt_llm.__version__)
args = parse_arguments()
world_size = args.tp_size * args.pp_size
assert args.pp_size == 1, "Pipeline parallelism is not supported."
tensorrt_llm.logger.set_level('info')
tik = time.time()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
quant_algo = None
plugin_weight_only_quant_type = None
if args.use_weight_only and args.weight_only_precision == 'int8':
plugin_weight_only_quant_type = torch.int8
quant_algo = QuantAlgo.W8A16
elif args.use_weight_only and args.weight_only_precision == 'int4':
plugin_weight_only_quant_type = torch.quint4x2
quant_algo = QuantAlgo.W4A16
elif args.use_weight_only and args.weight_only_precision == 'int4_gptq':
quant_algo = QuantAlgo.W4A16_GPTQ
hf_config = AutoConfig.from_pretrained(args.model_dir)
hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir,
torch_dtype="auto")
config = {
'architecture': hf_config.architectures[0],
'dtype': args.dtype,
'num_hidden_layers': hf_config.num_hidden_layers,
'num_attention_heads': hf_config.num_attention_heads,
'hidden_size': hf_config.hidden_size,
'vocab_size': hf_config.vocab_size,
'position_embedding_type': 'rope_gpt_neox',
'max_position_embeddings': hf_config.max_position_embeddings,
'rotary_emb_base': hf_config.rotary_emb_base,
'rotary_pct': hf_config.rotary_pct,
'hidden_act': hf_config.hidden_act,
'quantization': {
'quant_algo': quant_algo,
},
'mapping': {
'world_size': world_size,
'tp_size': args.tp_size,
'pp_size': args.pp_size,
},
'use_parallel_embedding': args.use_parallel_embedding,
'embedding_sharding_dim': args.embedding_sharding_dim,
}
if args.use_weight_only and args.weight_only_precision == 'int4_gptq':
config['quantization'].update({
'has_zero_point':
True,
'group_size':
get_gptq_gptneox_group_size(args.quant_ckpt_path, hf_config)
})
with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
json.dump(config, f, indent=4)
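    # For reference, the config.json written above looks roughly like the
    # following (values shown are for GPT-NeoX-20B and will differ per model):
    #   {
    #       "architecture": "GPTNeoXForCausalLM",
    #       "dtype": "float16",
    #       "num_hidden_layers": 44,
    #       "num_attention_heads": 64,
    #       "hidden_size": 6144,
    #       ...
    #       "mapping": {"world_size": 2, "tp_size": 2, "pp_size": 1}
    #   }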
    def convert_and_save(rank):
mapping = Mapping(world_size=world_size,
rank=rank,
tp_size=args.tp_size,
pp_size=args.pp_size)
if args.use_weight_only and args.weight_only_precision == 'int4_gptq':
weights = load_from_gptq_gptneox(
args.quant_ckpt_path,
hf_config,
use_parallel_embedding=args.use_parallel_embedding,
sharding_dim=args.embedding_sharding_dim,
mapping=mapping,
dtype=args.dtype)
else:
weights = convert_hf_gptneox(
hf_model,
mapping,
dtype=args.dtype,
use_weight_only=args.use_weight_only,
plugin_weight_only_quant_type=plugin_weight_only_quant_type,
use_parallel_embedding=args.use_parallel_embedding,
sharding_dim=args.embedding_sharding_dim)
safe_save_path = os.path.join(args.output_dir,
f'rank{rank}.safetensors')
tensorrt_llm.logger.info(f'Saving safetensors to: {safe_save_path}')
safetensors.torch.save_file(weights, safe_save_path)
tensorrt_llm.logger.info(f'Saved safetensors to: {safe_save_path}')
return True
if args.workers == 1:
for rank in range(world_size):
            passed = convert_and_save(rank)
assert passed, "Convert checkpoint failed"
else:
with ThreadPoolExecutor(max_workers=args.workers) as p:
futures = [
                p.submit(convert_and_save, rank) for rank in range(world_size)
]
exceptions = []
for future in as_completed(futures):
try:
future.result()
except Exception as e:
traceback.print_exc()
exceptions.append(e)
assert len(
exceptions
) == 0, "Checkpoint conversion failed, please check error log."
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Total time of converting checkpoints: {t}')