# TensorRT-LLMs/examples/gpt/convert_checkpoint.py
import argparse
import functools
import json
import logging
import os
import shutil
import tarfile
import time
import traceback
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
import numpy as np
import safetensors
import torch
import torch.nn as nn
import yaml
from datasets import load_dataset
from tqdm import tqdm
from transformers import (AutoConfig, AutoModelForCausalLM,
AutoModelForVision2Seq, AutoTokenizer, GPT2Config)
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers.pytorch_utils import Conv1D
import tensorrt_llm
from tensorrt_llm._utils import pad_vocab_size, str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import retrieved_layer_index_from_name
from tensorrt_llm.quantization import QuantAlgo
LOGGER = logging.getLogger(__name__)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default=None)
parser.add_argument('--nemo_ckpt_path', type=str, default=None)
parser.add_argument('--load_nemo_on_gpu',
default=False,
action="store_true",
help="Whether to load NeMo checkpoint on GPU")
parser.add_argument(
'--gpt_variant',
default=None,
choices=[
None, 'gpt2', 'santacoder', 'starcoder', 'starcoder2', 'persimmon',
'kosmos-2'
],
help=
"By default the script will try to infer the gpt_variant from model_dir. "
"Or users may overwrite gpt_variant by explicitly passing the variant.")
parser.add_argument('--tp_size',
type=int,
default=1,
help='N-way tensor parallelism size')
parser.add_argument('--pp_size',
type=int,
default=1,
help='N-way pipeline parallelism size')
parser.add_argument('--dtype',
type=str,
default='float16',
choices=['float32', 'bfloat16', 'float16'])
parser.add_argument(
'--use_parallel_embedding',
action="store_true",
default=False,
help=
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
)
parser.add_argument(
'--embedding_sharding_dim',
type=int,
default=0,
choices=[0, 1],
help=
        'By default the embedding lookup table is sharded along the vocab dimension (embedding_sharding_dim=0). '
        'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0.'
)
parser.add_argument(
'--use_embedding_sharing',
action="store_true",
default=False,
help=
        'Try to reduce the engine size by sharing the embedding lookup table between two layers. '
        'Note: the flag might not take effect when the criteria are not met.')
parser.add_argument(
'--use_weight_only',
default=False,
action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
parser.add_argument(
'--weight_only_precision',
const='int8',
type=str,
nargs='?',
default='int8',
choices=['int8', 'int4'],
help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
)
parser.add_argument(
'--int8_kv_cache',
default=False,
action="store_true",
help=
        'By default, the model dtype is used for the KV cache. int8_kv_cache chooses int8 quantization for the KV cache.'
)
parser.add_argument(
'--per_channel',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor for the GEMM\'s result. '
'per_channel instead uses a different static scaling factor for each channel. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--per_token',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor to scale activations in the int8 range. '
'per_token chooses at run time, and for each token, a custom scaling factor. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
"--smoothquant",
"-sq",
type=float,
default=None,
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
" to Smoothquant the model, and output int8 weights."
" A good first try is 0.5. Must be in [0, 1]")
parser.add_argument("--dataset_cache_dir",
type=str,
default=None,
help="cache dir to load the hugging face dataset")
parser.add_argument('--output_dir',
type=str,
default='tllm_checkpoint',
help='The path to save the TensorRT-LLM checkpoint')
parser.add_argument(
'--workers',
type=int,
default=1,
help='The number of workers for converting checkpoint in parallel')
parser.add_argument('--log_level', type=str, default='info')
parser.add_argument(
'--nemo_rename_key',
type=str,
nargs='+',
default=[],
help=
"Change a layer name when loading a NeMo checkpoint. Should follow <old_name_pattern>:<new_name_pattern>"
)
args = parser.parse_args()
tensorrt_llm.logger.set_level(args.log_level)
return args
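# Illustrative invocation (a hedged sketch; the model path and sizes below are
# placeholders, not values taken from this file):
#   python convert_checkpoint.py --model_dir ./gpt2 --dtype float16 \
#       --tp_size 2 --pp_size 1 --output_dir ./tllm_checkpoint_2gpu
# The output directory receives a config.json plus one rank<i>.safetensors file
# per rank of the tp_size * pp_size mapping.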
def rename_keys(model_state, layer_rename_config: Dict[str, str]):
if not layer_rename_config:
return model_state
new_state_dict = {}
for key, value in model_state.items():
for old, new in layer_rename_config.items():
key = key.replace(old, new)
assert key not in new_state_dict, f"Key already exists: {key}"
new_state_dict[key] = value
return new_state_dict
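# Example (illustrative only; the patterns are hypothetical): running with
#   --nemo_rename_key attention.linear_qkv:attention.query_key_value
# yields layer_rename_config = {'attention.linear_qkv': 'attention.query_key_value'},
# so every checkpoint key containing the old substring is stored under the
# substituted name before conversion.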
def load_gpt_config(model_dir: str,
gpt_variant: Optional[str] = None) -> GPT2Config:
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
if gpt_variant is None:
print("Inferring gpt variant from path...")
for v in [
'starcoder2', 'starcoder', 'santacoder', 'gpt2', 'persimmon',
'kosmos-2'
]:
if v in config._name_or_path or ('fuyu' in config._name_or_path
and v == 'persimmon'):
gpt_variant = v
break
assert gpt_variant in [
'gpt2', 'santacoder', 'starcoder', 'starcoder2', 'persimmon', 'kosmos-2'
]
print(f"Gpt variant: {gpt_variant}")
if gpt_variant in ['starcoder2', 'persimmon']:
config.n_embd = config.hidden_size
config.n_inner = config.intermediate_size
config.n_head = config.num_attention_heads
config.n_kv_head = config.num_key_value_heads if hasattr(
config, 'num_key_value_heads') else config.n_head
config.n_layer = config.num_hidden_layers
config.n_positions = config.max_position_embeddings
config.activation_function = 'gelu' if gpt_variant == 'starcoder2' else 'squared-relu'
config.layer_norm_epsilon = config.norm_epsilon if gpt_variant == 'starcoder2' else config.layer_norm_eps
config.bias = config.use_bias if gpt_variant == 'starcoder2' else True
config.position_embedding_type = 'rope_gpt_neox'
config.rotary_base = config.rope_theta
config.rotary_pct = getattr(config, 'partial_rotary_factor', 1.0)
elif gpt_variant == "kosmos-2":
config.n_embd = config.text_config.embed_dim
config.n_inner = config.text_config.ffn_dim
config.n_head = config.text_config.attention_heads
config.n_kv_head = config.n_head
config.n_layer = config.text_config.layers
config.n_positions = config.text_config.max_position_embeddings
config.activation_function = config.text_config.activation_function
config.layer_norm_epsilon = config.text_config.layer_norm_eps
config.bias = True
config.vocab_size = config.text_config.vocab_size
else:
if config.n_inner is None:
config.n_inner = config.n_embd * 4
if gpt_variant in ['santacoder', 'starcoder']:
config.n_kv_head = 1
else:
config.n_kv_head = config.n_head
return config, gpt_variant
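# Note: the returned config is normalized so that downstream conversion code can
# rely on GPT-2 style attribute names (n_embd, n_head, n_kv_head, n_layer,
# n_inner, n_positions) regardless of which variant was detected.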
def split(param: torch.Tensor,
tp_rank: int,
tp_size: int,
is_column: bool = True) -> torch.Tensor:
"""Split linear layer's weight, bias or scaling factors for tensor parallelism."""
if param is None:
return None
assert param.ndim in [1, 2]
if tp_size == 1:
return param
if param.numel() == 1:
return param
if param.ndim == 1 and not is_column:
return param
split_dim = 0 if (param.ndim == 1 or is_column) else 1
return torch.chunk(param, tp_size, dim=split_dim)[tp_rank].contiguous()
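# Example (illustrative sketch):
#   w = torch.arange(32, dtype=torch.float32).reshape(8, 4)
#   split(w, tp_rank=0, tp_size=2, is_column=True).shape   # torch.Size([4, 4])
#   split(w, tp_rank=0, tp_size=2, is_column=False).shape  # torch.Size([8, 2])
# Column-parallel layers are split along dim 0 (output features), row-parallel
# layers along dim 1 (input features).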
def split_qkv(
param: torch.Tensor,
tp_rank: int,
tp_size: int,
hidden_size: int,
num_heads: int,
num_kv_heads: Optional[int] = None,
) -> torch.Tensor:
"""Split qkv layer's weight, bias or scaling factors for tensor parallelism.
param: (num_heads*head_dim + 2*num_kv_heads*head_dim, in_dim)
"""
if param is None:
return None
assert hidden_size % num_heads == 0
head_dim = hidden_size // num_heads
num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
assert num_heads % num_kv_heads == 0
assert num_heads % tp_size == 0
q_param, k_param, v_param = torch.split(
param, [hidden_size, num_kv_heads * head_dim, num_kv_heads * head_dim],
dim=0)
if num_kv_heads < tp_size:
assert tp_size % num_kv_heads == 0
num_dups = tp_size // num_kv_heads
remain_shape = k_param.shape[1:]
k_param = k_param.view(
num_kv_heads, head_dim,
*remain_shape).repeat_interleave(num_dups, dim=0).view(
num_kv_heads * head_dim * num_dups, *remain_shape)
v_param = v_param.view(
num_kv_heads, head_dim,
*remain_shape).repeat_interleave(num_dups, dim=0).view(
num_kv_heads * head_dim * num_dups, *remain_shape)
else:
assert num_kv_heads % tp_size == 0
q_param = split(q_param, tp_rank, tp_size, is_column=True)
k_param = split(k_param, tp_rank, tp_size, is_column=True)
v_param = split(v_param, tp_rank, tp_size, is_column=True)
return torch.cat([q_param, k_param, v_param], dim=0)
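# Example (illustrative sketch): for hidden_size=8, num_heads=2 (head_dim=4) and
# num_kv_heads=2, the fused parameter has 8 + 2*4 + 2*4 = 24 output rows.
#   qkv_w = torch.randn(24, 8)
#   split_qkv(qkv_w, tp_rank=0, tp_size=2, hidden_size=8, num_heads=2).shape
#   # torch.Size([12, 8]): each rank keeps one head of Q, K and V.
# When num_kv_heads < tp_size, the K/V heads are duplicated so that every rank
# still receives a complete K/V head.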
def split_embedding(
param: torch.Tensor,
tp_rank: int,
tp_size: int,
use_parallel_embedding: bool = False,
sharding_dim: int = 0,
) -> torch.Tensor:
if param is None:
return None
if not use_parallel_embedding:
return param
vocab_size, hidden_size = param.size()
if sharding_dim == 0:
if vocab_size % tp_size != 0:
vocab_size_padded = pad_vocab_size(vocab_size, tp_size)
pad_width = vocab_size_padded - vocab_size
param = torch.nn.functional.pad(param, (0, 0, 0, pad_width),
value=0)
else:
assert hidden_size % tp_size == 0
return split(param, tp_rank, tp_size, is_column=(sharding_dim == 0))
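# Example (illustrative sketch): with use_parallel_embedding=True, sharding_dim=0
# and tp_size=2, a (50257, 768) GPT-2 vocab embedding is padded to 50258 rows so
# the vocab dimension divides evenly, and each rank receives a (25129, 768) shard.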
def get_weight(params: Dict[str, torch.Tensor], prefix: str,
dtype: torch.dtype) -> torch.Tensor:
if f'{prefix}.weight' not in params:
return None
return params[f'{prefix}.weight'].to(dtype).detach().cpu()
def get_bias(params: Dict[str, torch.Tensor], prefix: str,
dtype: torch.dtype) -> torch.Tensor:
if f'{prefix}.bias' not in params:
return None
return params[f'{prefix}.bias'].to(dtype).detach().cpu()
def get_weight_and_bias(params: Dict[str, torch.Tensor], prefix: str,
                        dtype: torch.dtype) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
return get_weight(params, prefix, dtype), get_bias(params, prefix, dtype)
def get_tllm_linear_weight(
weight: torch.Tensor,
prefix: str,
bias: Optional[torch.Tensor] = None,
use_weight_only: bool = False,
plugin_weight_only_quant_type: torch.dtype = torch.int8
) -> Dict[str, torch.Tensor]:
results = {}
if use_weight_only:
v = weight.t().contiguous()
processed_torch_weights, torch_weight_scales = \
torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
v, plugin_weight_only_quant_type)
results[f'{prefix}.weight'] = processed_torch_weights
results[f'{prefix}.per_channel_scale'] = torch_weight_scales
else:
results[f'{prefix}.weight'] = weight
if bias is not None:
results[f'{prefix}.bias'] = bias
return results
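# Example (sketch of the produced keys): for prefix 'transformer.layers.0.mlp.fc'
# and use_weight_only=True, the returned dict contains
#   'transformer.layers.0.mlp.fc.weight'             (weight quantized to int8/int4)
#   'transformer.layers.0.mlp.fc.per_channel_scale'  (dequantization scales)
# plus 'transformer.layers.0.mlp.fc.bias' when a bias tensor is supplied.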
def convert_hf_gpt(hf_model: AutoModelForCausalLM,
hf_config: AutoConfig,
gpt_variant: str,
mapping: Mapping,
dtype: str = 'float32',
use_parallel_embedding: bool = False,
sharding_dim: int = 0,
share_embedding_table: bool = False,
use_weight_only: bool = False,
plugin_weight_only_quant_type: torch.dtype = torch.int8):
weights = {}
tik = time.time()
model_params = dict(hf_model.named_parameters())
dtype = getattr(torch, dtype)
num_attention_heads = hf_config.n_head
hidden_size = hf_config.n_embd
vocab_size = hf_config.vocab_size
num_kv_heads = hf_config.n_kv_head
num_hidden_layers = hf_config.n_layer
layers_range = mapping.pp_layers(num_hidden_layers)
for l in layers_range:
if gpt_variant == 'starcoder2':
prefix = f'model.layers.{l}'
elif gpt_variant == 'persimmon':
            is_fuyu = 'language_model.model.embed_tokens.weight' in model_params
prefix = f'language_model.model.layers.{l}' if is_fuyu else f'model.layers.{l}'
elif gpt_variant == 'kosmos-2':
prefix = f'text_model.model.layers.{l}'
else:
prefix = f'transformer.h.{l}'
tllm_prex = f'transformer.layers.{l-layers_range[0]}'
if gpt_variant == 'santacoder':
q_w, q_b = get_weight_and_bias(model_params,
f'{prefix}.attn.q_attn', dtype)
kv_w, kv_b = get_weight_and_bias(model_params,
f'{prefix}.attn.kv_attn', dtype)
qkv_w = torch.cat([q_w, kv_w], dim=-1)
qkv_b = torch.cat([q_b, kv_b], dim=-1)
elif gpt_variant in ['starcoder2', 'kosmos-2']:
q_w, q_b = get_weight_and_bias(model_params,
f'{prefix}.self_attn.q_proj', dtype)
k_w, k_b = get_weight_and_bias(model_params,
f'{prefix}.self_attn.k_proj', dtype)
v_w, v_b = get_weight_and_bias(model_params,
f'{prefix}.self_attn.v_proj', dtype)
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
qkv_b = torch.cat([q_b, k_b, v_b], dim=0)
elif gpt_variant == 'persimmon':
qkv_w, qkv_b = get_weight_and_bias(
model_params, f'{prefix}.self_attn.query_key_value', dtype)
else:
qkv_w, qkv_b = get_weight_and_bias(model_params,
f'{prefix}.attn.c_attn', dtype)
if gpt_variant in ['gpt2', 'santacoder']:
qkv_w = qkv_w.t().contiguous() # transpose for Conv1D
if gpt_variant == 'persimmon':
qkv_w = split(qkv_w,
mapping.tp_rank,
mapping.tp_size,
is_column=True)
qkv_b = split(qkv_b,
mapping.tp_rank,
mapping.tp_size,
is_column=True)
else:
qkv_w = split_qkv(qkv_w, mapping.tp_rank, mapping.tp_size,
hidden_size, num_attention_heads, num_kv_heads)
qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size,
hidden_size, num_attention_heads, num_kv_heads)
weights.update(
get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv', qkv_b,
use_weight_only,
plugin_weight_only_quant_type))
if gpt_variant == 'starcoder2':
attn_dense_w, attn_dense_b = get_weight_and_bias(
model_params, f'{prefix}.self_attn.o_proj', dtype)
elif gpt_variant == 'persimmon':
attn_dense_w, attn_dense_b = get_weight_and_bias(
model_params, f'{prefix}.self_attn.dense', dtype)
elif gpt_variant == 'kosmos-2':
attn_dense_w, attn_dense_b = get_weight_and_bias(
model_params, f'{prefix}.self_attn.out_proj', dtype)
else:
attn_dense_w, attn_dense_b = get_weight_and_bias(
model_params, f'{prefix}.attn.c_proj', dtype)
if gpt_variant in ['gpt2', 'santacoder']:
attn_dense_w = attn_dense_w.t().contiguous() # transpose for Conv1D
attn_dense_w = split(attn_dense_w,
mapping.tp_rank,
mapping.tp_size,
is_column=False)
weights.update(
get_tllm_linear_weight(attn_dense_w, f'{tllm_prex}.attention.dense',
attn_dense_b, use_weight_only,
plugin_weight_only_quant_type))
if gpt_variant == 'persimmon':
mlp_fc_w, mlp_fc_b = get_weight_and_bias(
model_params, f'{prefix}.mlp.dense_h_to_4h', dtype)
elif gpt_variant == 'kosmos-2':
mlp_fc_w, mlp_fc_b = get_weight_and_bias(model_params,
f'{prefix}.ffn.fc1', dtype)
else:
mlp_fc_w, mlp_fc_b = get_weight_and_bias(model_params,
f'{prefix}.mlp.c_fc',
dtype)
if gpt_variant in ['gpt2', 'santacoder']:
mlp_fc_w = mlp_fc_w.t().contiguous() # transpose for Conv1D
mlp_fc_w = split(mlp_fc_w,
mapping.tp_rank,
mapping.tp_size,
is_column=True)
mlp_fc_b = split(mlp_fc_b,
mapping.tp_rank,
mapping.tp_size,
is_column=True)
weights.update(
get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc', mlp_fc_b,
use_weight_only,
plugin_weight_only_quant_type))
if gpt_variant == 'persimmon':
mlp_proj_w, mlp_proj_b = get_weight_and_bias(
model_params, f'{prefix}.mlp.dense_4h_to_h', dtype)
elif gpt_variant == 'kosmos-2':
mlp_proj_w, mlp_proj_b = get_weight_and_bias(
model_params, f'{prefix}.ffn.fc2', dtype)
else:
mlp_proj_w, mlp_proj_b = get_weight_and_bias(
model_params, f'{prefix}.mlp.c_proj', dtype)
if gpt_variant in ['gpt2', 'santacoder']:
mlp_proj_w = mlp_proj_w.t().contiguous() # transpose for Conv1D
mlp_proj_w = split(mlp_proj_w,
mapping.tp_rank,
mapping.tp_size,
is_column=False)
weights.update(
get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj',
mlp_proj_b, use_weight_only,
plugin_weight_only_quant_type))
if gpt_variant in ['starcoder2', 'persimmon']:
input_ln_w, input_ln_b = get_weight_and_bias(
model_params, f'{prefix}.input_layernorm', dtype)
elif gpt_variant == 'kosmos-2':
input_ln_w, input_ln_b = get_weight_and_bias(
model_params, f'{prefix}.self_attn_layer_norm', dtype)
else:
input_ln_w, input_ln_b = get_weight_and_bias(
model_params, f'{prefix}.ln_1', dtype)
weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_w
if input_ln_b is not None:
weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_b
if gpt_variant in ['starcoder2', 'persimmon']:
post_ln_w, post_ln_b = get_weight_and_bias(
model_params, f'{prefix}.post_attention_layernorm', dtype)
elif gpt_variant == 'kosmos-2':
post_ln_w, post_ln_b = get_weight_and_bias(
model_params, f'{prefix}.final_layer_norm', dtype)
else:
post_ln_w, post_ln_b = get_weight_and_bias(model_params,
f'{prefix}.ln_2', dtype)
weights[f'{tllm_prex}.post_layernorm.weight'] = post_ln_w
if post_ln_b is not None:
weights[f'{tllm_prex}.post_layernorm.bias'] = post_ln_b
if gpt_variant == 'persimmon':
q_layernorm_w, q_layernorm_b = get_weight_and_bias(
model_params, f'{prefix}.self_attn.q_layernorm', dtype)
weights[f'{tllm_prex}.attention.q_layernorm.weight'] = q_layernorm_w
weights[f'{tllm_prex}.attention.q_layernorm.bias'] = q_layernorm_b
k_layernorm_w, k_layernorm_b = get_weight_and_bias(
model_params, f'{prefix}.self_attn.k_layernorm', dtype)
weights[f'{tllm_prex}.attention.k_layernorm.weight'] = k_layernorm_w
weights[f'{tllm_prex}.attention.k_layernorm.bias'] = k_layernorm_b
if gpt_variant == 'kosmos-2':
q_layernorm_w, q_layernorm_b = get_weight_and_bias(
model_params, f'{prefix}.self_attn.inner_attn_ln', dtype)
weights[
f'{tllm_prex}.attention.inner_layernorm.weight'] = q_layernorm_w
weights[
f'{tllm_prex}.attention.inner_layernorm.bias'] = q_layernorm_b
k_layernorm_w, k_layernorm_b = get_weight_and_bias(
model_params, f'{prefix}.ffn.ffn_layernorm', dtype)
weights[f'{tllm_prex}.mlp.inner_layernorm.weight'] = k_layernorm_w
weights[f'{tllm_prex}.mlp.inner_layernorm.bias'] = k_layernorm_b
if mapping.is_first_pp_rank():
if gpt_variant == 'starcoder2':
embed_w = get_weight(model_params, 'model.embed_tokens', dtype)
elif gpt_variant == 'kosmos-2':
embed_w = get_weight(model_params, 'text_model.model.embed_tokens',
dtype)
elif gpt_variant == 'persimmon':
embed_w = get_weight(model_params,
('language_model.' if is_fuyu else '') +
'model.embed_tokens', dtype)
else:
embed_w = get_weight(model_params, 'transformer.wte', dtype)
weights['transformer.vocab_embedding.weight'] = split_embedding(
embed_w,
mapping.tp_rank,
mapping.tp_size,
use_parallel_embedding=use_parallel_embedding,
sharding_dim=sharding_dim)
if gpt_variant == 'kosmos-2':
padding_idx = hf_config.text_config.pad_token_id
sin_pos_embedding = hf_model.text_model.model.embed_positions.get_embedding(
padding_idx + 1 + hf_config.text_config.max_position_embeddings,
hf_config.text_config.embed_dim,
padding_idx=padding_idx) # [2 + num_embeddings, embed_dim]
pos_embed_w = sin_pos_embedding[2:].to(dtype).detach().cpu()
else:
pos_embed_w = get_weight(model_params, 'transformer.wpe', dtype)
if pos_embed_w is not None:
weights['transformer.position_embedding.weight'] = split_embedding(
pos_embed_w,
mapping.tp_rank,
mapping.tp_size,
use_parallel_embedding=use_parallel_embedding,
sharding_dim=sharding_dim)
if mapping.is_last_pp_rank():
if gpt_variant == 'starcoder2':
embed_w = get_weight(model_params, 'lm_head', dtype)
if embed_w is None:
embed_w = get_weight(model_params, 'model.embed_tokens', dtype)
elif gpt_variant == 'persimmon':
embed_w = get_weight(model_params,
('language_model.' if is_fuyu else '') +
'lm_head', dtype)
elif gpt_variant == 'kosmos-2':
embed_w = get_weight(model_params, 'text_model.model.embed_tokens',
dtype)
else:
embed_w = get_weight(model_params, 'transformer.wte', dtype)
if not share_embedding_table:
if vocab_size % mapping.tp_size != 0:
vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
pad_width = vocab_size_padded - vocab_size
embed_w = torch.nn.functional.pad(embed_w, (0, 0, 0, pad_width),
value=0)
weights['lm_head.weight'] = split(embed_w.clone(),
mapping.tp_rank,
mapping.tp_size,
is_column=True)
if gpt_variant == 'starcoder2':
ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'model.norm',
dtype)
elif gpt_variant == 'persimmon':
ln_f_w, ln_f_b = get_weight_and_bias(
model_params, ('language_model.' if is_fuyu else '') +
'model.final_layernorm', dtype)
elif gpt_variant == 'kosmos-2':
ln_f_w, ln_f_b = get_weight_and_bias(model_params,
'text_model.model.layer_norm',
dtype)
else:
ln_f_w, ln_f_b = get_weight_and_bias(model_params,
'transformer.ln_f', dtype)
weights['transformer.ln_f.weight'] = ln_f_w
if ln_f_b is not None:
weights['transformer.ln_f.bias'] = ln_f_b
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Weights loaded. Total time: {t}')
return weights
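# The dict returned above maps TensorRT-LLM checkpoint names (e.g.
# 'transformer.layers.<i>.attention.qkv.weight', 'transformer.vocab_embedding.weight',
# 'transformer.ln_f.weight', 'lm_head.weight') to tensors owned by the rank
# described by `mapping`; layer indices are renumbered per pipeline stage via
# l - layers_range[0].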
def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
"""
This function has two purposes:
- compute quantized weights, scaled either per-tensor or per-column
- compute scaling factors
Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ.
CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W.
    CUBLAS only has one (we can't do per-row scaling), so we must provide a pre-multiplied scaling factor.
Here is the list of what we need (T means per-tensor, C per-column):
- scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T)
- scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T)
- scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C)
- scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32)
to quant range (int8) (used for CUBLAS) (T, C)
Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too,
but then the model would change depending on the number of GPUs used.
For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it
as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V.
"""
# compute weight scaling factors for fp->int8 and int8->fp
if is_qkv and not multi_query_mode:
scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max(
dim=-1, keepdims=True)[0].cpu().numpy()
scale_w_orig_quant_c = 127. / act_range["w"].reshape(3,
-1).cpu().numpy()
elif is_qkv and multi_query_mode:
        raise ValueError(
            "Multi-query attention with int8 quantization is not supported yet")
else:
scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy()
scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy()
scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t
scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c
# compute the rest of needed scaling factors
scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item())
scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item())
scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.)
scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t *
scale_w_orig_quant_t)
scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t *
scale_w_orig_quant_c)
if is_qkv:
scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t,
scale_w_orig_quant_c.shape)
scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t,
scale_w_orig_quant_c.shape)
to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8)
return {
"weight.int8": to_i8(weights * scale_w_orig_quant_t),
"weight.int8.col": to_i8(weights * scale_w_orig_quant_c),
"scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32),
"scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32),
"scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32),
"scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32),
"scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32),
"scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32),
}
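# Worked example (illustrative numbers): if the recorded activation maximum is 4.0
# and the per-tensor weight maximum is 0.5, then
#   scale_x_orig_quant = 127 / 4.0 = 31.75      (fp activation -> int8)
#   scale_w_orig_quant = 127 / 0.5 = 254.0      (fp weight     -> int8)
#   scale_w_quant_orig = 1 / 254.0 ≈ 0.0039     (int8 weight   -> fp)
# and scale_y_accum_quant = scale_y_orig_quant / (scale_x_orig_quant * scale_w_orig_quant)
# rescales the int32 GEMM accumulator back into the int8 output range.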
@torch.no_grad()
def apply_smoothing(scales,
gemm_weights,
layernorm_weights=None,
layernorm_bias=None,
dtype=torch.float32,
layernorm_1p=False):
if not isinstance(gemm_weights, list):
gemm_weights = [gemm_weights]
if layernorm_weights is not None:
assert layernorm_weights.numel() == scales.numel()
layernorm_weights.div_(scales).to(dtype)
if layernorm_bias is not None:
assert layernorm_bias.numel() == scales.numel()
layernorm_bias.div_(scales).to(dtype)
if layernorm_1p:
layernorm_weights += (1 / scales) - 1
for gemm in gemm_weights:
gemm.mul_(scales.view(1, -1)).to(dtype)
@torch.no_grad()
def smooth_gemm(gemm_weights,
act_scales,
layernorm_weights=None,
layernorm_bias=None,
alpha=0.5,
weight_scales=None):
if not isinstance(gemm_weights, list):
gemm_weights = [gemm_weights]
orig_dtype = gemm_weights[0].dtype
for gemm in gemm_weights:
# gemm_weights are expected to be transposed
assert gemm.shape[1] == act_scales.numel()
if weight_scales is None:
weight_scales = torch.cat(
[gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
dim=0)
weight_scales = weight_scales.max(dim=0)[0]
    weight_scales = weight_scales.to(float).clamp(min=1e-5)
scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
weight_scales.pow(1 - alpha)).clamp(min=1e-5)
apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias,
orig_dtype)
return scales
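# SmoothQuant recap (see https://arxiv.org/pdf/2211.10438.pdf): each input channel
# j gets a smoothing factor s_j = max|X_j|**alpha / max|W_j|**(1 - alpha).
# apply_smoothing divides the preceding layernorm weight/bias by s and multiplies
# the corresponding weight columns by s, so X @ W is mathematically unchanged
# while activation outliers are shifted into the (easier to quantize) weights.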
@torch.no_grad()
def capture_activation_range(model,
tokenizer,
dataset,
num_samples=512,
seq_len=512):
model.eval()
device = next(model.parameters()).device
act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
def stat_tensor(name, tensor, act_scales, key):
hidden_dim = tensor.shape[-1]
tensor = tensor.view(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float()
        if act_scales[name][key] is None:
            act_scales[name][key] = coming_max
        else:
            act_scales[name][key] = torch.max(act_scales[name][key],
                                              coming_max)
def stat_input_hook(m, x, y, name):
if isinstance(x, tuple):
x = x[0]
stat_tensor(name, x, act_scales, "x")
stat_tensor(name, y, act_scales, "y")
if act_scales[name]["w"] is None:
act_scales[name]["w"] = m.weight.abs().clip(1e-8,
None).max(dim=0)[0]
hooks = []
for name, m in model.named_modules():
if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
hooks.append(
m.register_forward_hook(
functools.partial(stat_input_hook, name=name)))
for i in tqdm(range(num_samples), desc="calibrating model"):
input_ids = tokenizer(dataset[i]["text"],
return_tensors="pt",
max_length=seq_len,
truncation=True).input_ids.to(device)
model(input_ids)
for h in hooks:
h.remove()
return act_scales
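# Usage sketch: the returned act_scales maps every nn.Linear / Conv1D module name
# to per-channel absolute maxima {"x": input, "y": output, "w": weight column},
# collected over num_samples calibration sentences. generate_int8() consumes these
# ranges to derive the int8 scaling factors above.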
@torch.no_grad()
def smooth_gpt_model(model, scales, alpha):
# Smooth the activation and weights with smoother = $\diag{s}$
for name, module in model.named_modules():
if not isinstance(module, GPT2Block):
continue
# qkv_proj
layer_name = name + ".attn.c_attn"
smoother = smooth_gemm(module.attn.c_attn.weight.T,
scales[layer_name]["x"], module.ln_1.weight,
module.ln_1.bias, alpha)
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=0)[0]
# fc1
layer_name = name + ".mlp.c_fc"
smoother = smooth_gemm(module.mlp.c_fc.weight.T,
scales[layer_name]["x"], module.ln_2.weight,
module.ln_2.bias, alpha)
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
scales[layer_name]["w"] = module.mlp.c_fc.weight.abs().max(dim=0)[0]
def get_tllm_linear_sq_weight(vals,
prefix,
shape,
tensor_parallel,
is_qkv=False,
per_token=False,
per_channel=False,
last_prefix=None,
bias=None,
smoother_value=None,
smoother_shape=None,
rank=0,
cat_dim=0,
multi_query_mode=False):
results = {}
def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1)
q_split = np.split(q, tp_size, axis=-1)
k_split = np.split(k, tp_size, axis=-1)
v_split = np.split(v, tp_size, axis=-1)
return [
np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1)
for ii in range(tp_size)
][cur_rank]
col_shape = shape if (is_qkv or per_channel) else [1, 1]
if per_token:
if per_channel:
original_weights = np.array(vals["weight.int8.col"])
else:
original_weights = np.array(vals["weight.int8"])
local_dim = original_weights.shape[0]
head_size = (original_weights.shape[1] - local_dim) // 2
if multi_query_mode:
cur_weights = multi_query_split(original_weights, local_dim,
head_size, tensor_parallel, rank)
else:
cur_weights = np.split(original_weights,
tensor_parallel,
axis=cat_dim)[rank]
if is_qkv:
hidden_dim = cur_weights.shape[0]
cur_weights = cur_weights.reshape(hidden_dim, -1)
results[prefix +
'weight'] = torch.from_numpy(cur_weights).t().contiguous()
if smoother_value is None:
results[last_prefix] = torch.from_numpy(
np.array([1.0], dtype=np.float32))
if per_channel:
cur_per_channel_value = vals["scale_w_quant_orig.col"]
if smoother_value is None:
if multi_query_mode:
cur_per_channel_value = multi_query_split(
vals["scale_w_quant_orig.col"], local_dim, head_size,
tensor_parallel, rank)
else:
cur_per_channel_value = np.split(
vals["scale_w_quant_orig.col"],
tensor_parallel,
axis=cat_dim)[rank]
else:
cur_per_channel_value = vals["scale_w_quant_orig"]
if is_qkv:
if multi_query_mode:
cur_per_channel_value = multi_query_split(
vals["scale_w_quant_orig"], local_dim, head_size,
tensor_parallel, rank)
else:
cur_per_channel_value = np.split(vals["scale_w_quant_orig"],
tensor_parallel,
axis=cat_dim)[rank]
results[prefix + 'per_channel_scale'] = torch.from_numpy(
np.array(cur_per_channel_value,
dtype=np.float32).reshape(col_shape)).contiguous()
else:
if per_channel:
original_weights = np.array(vals["weight.int8.col"])
else:
original_weights = np.array(vals["weight.int8"])
local_dim = original_weights.shape[0]
head_size = (original_weights.shape[1] - local_dim) // 2
if multi_query_mode:
cur_weights = multi_query_split(original_weights, local_dim,
head_size, tensor_parallel, rank)
else:
cur_weights = np.split(original_weights,
tensor_parallel,
axis=cat_dim)[rank]
if is_qkv:
hidden_dim = cur_weights.shape[0]
cur_weights = cur_weights.reshape(hidden_dim, -1)
results[prefix +
'weight'] = torch.from_numpy(cur_weights).t().contiguous()
if per_channel:
cur_per_channel_value = vals["scale_y_accum_quant.col"]
if smoother_value is None:
if multi_query_mode:
cur_per_channel_value = multi_query_split(
vals["scale_y_accum_quant.col"], local_dim, head_size,
tensor_parallel, rank)
else:
cur_per_channel_value = np.split(
vals["scale_y_accum_quant.col"],
tensor_parallel,
axis=cat_dim)[rank]
else:
cur_per_channel_value = vals["scale_y_accum_quant"]
# QKV is always per_channel
if is_qkv:
if multi_query_mode:
cur_per_channel_value = multi_query_split(
vals["scale_y_accum_quant"], local_dim, head_size,
tensor_parallel, rank)
else:
cur_per_channel_value = np.split(
vals["scale_y_accum_quant"],
tensor_parallel,
axis=cat_dim)[rank]
results[prefix + 'per_channel_scale'] = torch.from_numpy(
np.array([cur_per_channel_value],
dtype=np.float32).reshape(col_shape)).contiguous()
results[last_prefix] = torch.from_numpy(
np.array([vals['scale_x_orig_quant']],
dtype=np.float32)).contiguous()
results[prefix + 'act_scale'] = torch.from_numpy(
np.array([[vals["scale_y_quant_orig"]]],
dtype=np.float32)).contiguous()
if smoother_value is not None:
cur_smoother_value = np.split(smoother_value,
tensor_parallel,
axis=cat_dim)[rank]
results[prefix + 'smoother'] = cur_smoother_value.reshape(
smoother_shape).contiguous().to(torch.float32)
if bias is not None:
results[prefix + 'bias'] = bias
return results
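# Keys produced above (sketch), for e.g. prefix='transformer.layers.0.mlp.fc.':
#   '<prefix>weight'              int8 weight (per-tensor or per-channel quantized)
#   '<prefix>per_channel_scale'   weight dequantization scales
#   '<prefix>act_scale'           static activation scale (per-tensor mode only)
#   last_prefix                   activation quantization scale attached to the
#                                 preceding layer (e.g. '...input_layernorm.scale_to_int')
# plus optional '<prefix>smoother' and '<prefix>bias' entries.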
def convert_hf_gpt_legacy(hf_model: AutoModelForCausalLM,
hf_config: AutoConfig,
gpt_variant: str,
mapping: Mapping,
dtype: str = 'float32',
use_parallel_embedding: bool = False,
sharding_dim: int = 0,
share_embedding_table: bool = False,
use_smooth_quant=False,
per_channel=False,
per_token=False,
int8_kv_cache=False,
act_range=None):
weights = {}
tik = time.time()
model_params = dict(hf_model.named_parameters())
dtype = getattr(torch, dtype)
num_attention_heads = hf_config.n_head
hidden_size = hf_config.n_embd
vocab_size = hf_config.vocab_size
num_kv_heads = hf_config.n_kv_head
num_hidden_layers = hf_config.n_layer
multi_query_mode = (num_kv_heads != num_attention_heads)
tensor_parallel = mapping.tp_size
layers_range = mapping.pp_layers(num_hidden_layers)
for l in layers_range:
prefix = f'transformer.h.{l}'
tllm_prex = f'transformer.layers.{l-layers_range[0]}'
if gpt_variant == 'santacoder':
q_w, q_b = get_weight_and_bias(model_params,
f'{prefix}.attn.q_attn', dtype)
kv_w, kv_b = get_weight_and_bias(model_params,
f'{prefix}.attn.kv_attn', dtype)
qkv_w = torch.cat([q_w, kv_w], dim=-1)
qkv_b = torch.cat([q_b, kv_b], dim=-1)
else:
qkv_w, qkv_b = get_weight_and_bias(model_params,
f'{prefix}.attn.c_attn', dtype)
if gpt_variant in ['gpt2', 'santacoder']:
qkv_w = qkv_w.t().contiguous() # transpose for Conv1D
if use_smooth_quant:
qkv_out_dim = qkv_w.shape[0]
qkv_w_numpy = qkv_w.t().numpy()
if not multi_query_mode:
qkv_w_numpy = qkv_w_numpy.reshape(hidden_size, 3, hidden_size)
int8_weights = generate_int8(qkv_w_numpy,
act_range.get(f'{prefix}.attn.c_attn'),
is_qkv=True,
multi_query_mode=multi_query_mode)
qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size,
hidden_size, num_attention_heads, num_kv_heads)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
f'{tllm_prex}.attention.qkv.',
[1, qkv_out_dim // tensor_parallel],
tensor_parallel,
is_qkv=True,
per_token=per_token,
per_channel=per_channel,
last_prefix=f'{tllm_prex}.input_layernorm.scale_to_int',
bias=qkv_b,
smoother_value=None,
smoother_shape=None,
                    rank=mapping.tp_rank,
cat_dim=-1,
multi_query_mode=multi_query_mode))
else:
qkv_w = split_qkv(qkv_w, mapping.tp_rank, mapping.tp_size,
hidden_size, num_attention_heads, num_kv_heads)
qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size,
hidden_size, num_attention_heads, num_kv_heads)
weights.update(
get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv',
qkv_b))
if int8_kv_cache:
qkv_w_numpy = qkv_w.t().numpy()
if not multi_query_mode:
qkv_w_numpy = qkv_w_numpy.reshape(hidden_size, 3, hidden_size)
int8_weights = generate_int8(qkv_w_numpy,
act_range.get(f'{prefix}.attn.c_attn'),
is_qkv=True,
multi_query_mode=multi_query_mode)
weights[
f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.from_numpy(
np.array([int8_weights['scale_y_quant_orig']],
dtype=np.float32)).contiguous()
attn_dense_w, attn_dense_b = get_weight_and_bias(
model_params, f'{prefix}.attn.c_proj', dtype)
if gpt_variant in ['gpt2', 'santacoder']:
attn_dense_w = attn_dense_w.t().contiguous() # transpose for Conv1D
if use_smooth_quant:
attn_dense_w_numpy = attn_dense_w.t().numpy()
int8_weights = generate_int8(attn_dense_w_numpy,
act_range.get(f'{prefix}.attn.c_proj'))
            # change this to the real smoother if smooth quant is applied to the dense layer
fake_smoother_value = torch.ones([1, hidden_size],
dtype=torch.float32)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
f'{tllm_prex}.attention.dense.', [1, hidden_size],
tensor_parallel,
is_qkv=False,
per_token=per_token,
per_channel=per_channel,
last_prefix=
f'{tllm_prex}.attention.quantization_scaling_factor',
bias=attn_dense_b,
smoother_value=fake_smoother_value,
smoother_shape=[1, hidden_size // tensor_parallel],
                    rank=mapping.tp_rank,
cat_dim=0))
else:
attn_dense_w = split(attn_dense_w,
mapping.tp_rank,
mapping.tp_size,
is_column=False)
weights.update(
get_tllm_linear_weight(attn_dense_w,
f'{tllm_prex}.attention.dense',
attn_dense_b))
mlp_fc_w, mlp_fc_b = get_weight_and_bias(model_params,
f'{prefix}.mlp.c_fc', dtype)
if gpt_variant in ['gpt2', 'santacoder']:
mlp_fc_w = mlp_fc_w.t().contiguous() # transpose for Conv1D
if use_smooth_quant:
mlp_fc_w_numpy = mlp_fc_w.t().numpy()
int8_weights = generate_int8(mlp_fc_w_numpy,
act_range.get(f'{prefix}.mlp.c_fc'))
mlp_fc_b = split(mlp_fc_b,
mapping.tp_rank,
mapping.tp_size,
is_column=True)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
f'{tllm_prex}.mlp.fc.',
[1, 4 * hidden_size // tensor_parallel],
tensor_parallel,
is_qkv=False,
per_token=per_token,
per_channel=per_channel,
last_prefix=f'{tllm_prex}.post_layernorm.scale_to_int',
bias=mlp_fc_b,
smoother_value=None,
smoother_shape=None,
                    rank=mapping.tp_rank,
cat_dim=-1))
else:
mlp_fc_w = split(mlp_fc_w,
mapping.tp_rank,
mapping.tp_size,
is_column=True)
mlp_fc_b = split(mlp_fc_b,
mapping.tp_rank,
mapping.tp_size,
is_column=True)
weights.update(
get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc',
mlp_fc_b))
mlp_proj_w, mlp_proj_b = get_weight_and_bias(model_params,
f'{prefix}.mlp.c_proj',
dtype)
if gpt_variant in ['gpt2', 'santacoder']:
mlp_proj_w = mlp_proj_w.t().contiguous() # transpose for Conv1D
if use_smooth_quant:
mlp_proj_w_numpy = mlp_proj_w.t().numpy()
int8_weights = generate_int8(mlp_proj_w_numpy,
act_range.get(f'{prefix}.mlp.c_proj'))
            # change this to the real smoother if smooth quant is applied to the proj layer
fake_smoother_value = torch.ones([1, 4 * hidden_size],
dtype=torch.float32)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
f'{tllm_prex}.mlp.proj.', [1, hidden_size],
tensor_parallel,
is_qkv=False,
per_token=per_token,
per_channel=per_channel,
last_prefix=f'{tllm_prex}.mlp.quantization_scaling_factor',
bias=mlp_proj_b,
smoother_value=fake_smoother_value,
smoother_shape=[1, 4 * hidden_size // tensor_parallel],
                    rank=mapping.tp_rank,
cat_dim=0))
else:
mlp_proj_w = split(mlp_proj_w,
mapping.tp_rank,
mapping.tp_size,
is_column=False)
weights.update(
get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj',
mlp_proj_b))
input_ln_w, input_ln_b = get_weight_and_bias(model_params,
f'{prefix}.ln_1', dtype)
weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_w
if input_ln_b is not None:
weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_b
post_ln_w, post_ln_b = get_weight_and_bias(model_params,
f'{prefix}.ln_2', dtype)
weights[f'{tllm_prex}.post_layernorm.weight'] = post_ln_w
if post_ln_b is not None:
weights[f'{tllm_prex}.post_layernorm.bias'] = post_ln_b
if mapping.is_first_pp_rank():
embed_w = get_weight(model_params, 'transformer.wte', dtype)
weights['transformer.vocab_embedding.weight'] = split_embedding(
embed_w,
mapping.tp_rank,
mapping.tp_size,
use_parallel_embedding=use_parallel_embedding,
sharding_dim=sharding_dim)
pos_embed_w = get_weight(model_params, 'transformer.wpe', dtype)
if pos_embed_w is not None:
weights['transformer.position_embedding.weight'] = split_embedding(
pos_embed_w,
mapping.tp_rank,
mapping.tp_size,
use_parallel_embedding=use_parallel_embedding,
sharding_dim=sharding_dim)
if mapping.is_last_pp_rank():
embed_w = get_weight(model_params, 'transformer.wte', dtype)
if not share_embedding_table:
if vocab_size % mapping.tp_size != 0:
vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
pad_width = vocab_size_padded - vocab_size
embed_w = torch.nn.functional.pad(embed_w, (0, 0, 0, pad_width),
value=0)
weights['lm_head.weight'] = split(embed_w.clone(),
mapping.tp_rank,
mapping.tp_size,
is_column=True)
ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'transformer.ln_f',
dtype)
weights['transformer.ln_f.weight'] = ln_f_w
if ln_f_b is not None:
weights['transformer.ln_f.bias'] = ln_f_b
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Weights loaded. Total time: {t}')
return weights
def cpu_map_location(storage, loc):
return storage.cpu()
def gpu_map_location(storage, loc):
if loc.startswith("cuda"):
training_gpu_idx = int(loc.split(":")[1])
inference_gpu_idx = training_gpu_idx % torch.cuda.device_count()
return storage.cuda(inference_gpu_idx)
elif loc.startswith("cpu"):
return storage.cpu()
else:
raise ValueError(f"Not handled {loc}")
def copy_tokenizer_files(config, out_dir):
basenames = {
"model": "tokenizer",
"vocab_file": "vocab",
"merge_file": "merges",
}
for key in basenames.keys():
if config[key] is None:
continue
path = Path(config[key])
if not path.exists():
LOGGER.debug(f"Tokenizer {key}: {path} file not found")
continue
dst_path = out_dir / f"{basenames[key]}{path.suffix}"
LOGGER.debug(f"Copy tokenizer {key}: {path}->{dst_path}")
shutil.copy(path.as_posix(), dst_path.as_posix())
def update_tokenizer_paths(tokenizer_config: Dict,
tokenizer_file_paths: Dict[str, Optional[str]]):
for key, new_path in tokenizer_file_paths.items():
old_path = tokenizer_config[key]
if old_path is None:
continue
old_path = Path(old_path)
if new_path:
LOGGER.debug(f"Update tokenizer {key} {old_path} -> {new_path}")
tokenizer_config[key] = new_path.as_posix()
elif not old_path.exists():
LOGGER.warning(
f"Tokenizer {key}'s path {old_path} does not exists: set it to None"
)
tokenizer_config[key] = None
return tokenizer_config
def unpack_nemo_ckpt(nemo_archive_path: Union[str, Path],
out_dir_path: Union[str, Path]):
nemo_archive_path = Path(nemo_archive_path)
if not nemo_archive_path.exists():
raise FileNotFoundError(f"{nemo_archive_path} does not exist")
for tar_mode in ["r:", "r:gz"]:
try:
with tarfile.open(nemo_archive_path, mode=tar_mode) as tar_file:
def is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory
def safe_members(tar_file):
members = []
for member in tar_file.getmembers():
member_path = os.path.join(out_dir_path, member.name)
if not is_within_directory(out_dir_path, member_path):
raise Exception(
"Attempted Path Traversal in Tar File")
members.append(member)
return members
tar_file.extractall(out_dir_path,
members=safe_members(tar_file),
numeric_owner=False)
return out_dir_path
except tarfile.ReadError:
pass
raise RuntimeError(f"Could not unpack {nemo_archive_path}")
def extract_layers_with_prefix(model_, prefix):
length_to_trim = len(prefix)
model_state = model_.get("state_dict", model_)
return {
key[length_to_trim:]: model_state[key]
for key in model_state.keys() if prefix in key
}
class UnpackedNemoCheckpointDir:
def __init__(self,
checkpoints_dir: Union[str, Path],
load_checkpoints_to_cpu: bool = False):
self._checkpoints_dir = Path(checkpoints_dir)
self._load_checkpoints_to_cpu = load_checkpoints_to_cpu
@property
@functools.lru_cache
def model_config(self):
model_config = None
model_config_filename = "model_config.yaml"
model_configs_paths = list(
self._checkpoints_dir.rglob(model_config_filename))
if model_configs_paths:
if len(model_configs_paths) > 1:
raise RuntimeError(
f"There are more than single {model_config_filename} "
f"in {self._checkpoints_dir}: {', '.join(map(lambda p: p.as_posix(), model_configs_paths))}"
)
model_config_path = model_configs_paths[0]
LOGGER.debug("Loading model config from %s", model_config_path)
with model_config_path.open("r") as model_config_file:
model_config = yaml.load(model_config_file,
Loader=yaml.SafeLoader)
else:
LOGGER.debug("Searching model config in checkpoints")
# try to obtain from checkpoint
checkpoint_name = self.checkpoint_name
checkpoints_paths = sorted(
self._checkpoints_dir.rglob(checkpoint_name))
if checkpoints_paths:
                # assume that the parallel rank 0 checkpoint has the model config embedded
checkpoint_path = checkpoints_paths[0]
map_location_fn = cpu_map_location if self._load_checkpoints_to_cpu else gpu_map_location
model_00 = torch.load(checkpoint_path,
map_location=map_location_fn)
if "hyper_parameters" in model_00 and "cfg" in model_00[
"hyper_parameters"]:
model_config = model_00["hyper_parameters"]["cfg"]
LOGGER.debug("Loaded model config from checkpoint %s",
checkpoint_path)
else:
LOGGER.debug("Could not find model config in checkpoint %s",
checkpoint_path)
del model_00
if model_config is None:
LOGGER.warning(
"Could not find checkpoint with NeMo model config in %s",
self._checkpoints_dir)
LOGGER.debug("Loaded model config %s", model_config)
return model_config
@property
def checkpoints_dir(self):
return self._checkpoints_dir
def get_checkpoints_paths(self,
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1):
"""
Injects tensor/pipeline model parallel ranks into the filepath.
Does nothing if not using model parallelism.
"""
checkpoint_path_without_rank = self.checkpoints_dir / self.checkpoint_name
def _inject_parallel_ranks(tp_rank, pp_rank):
if tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1:
if pipeline_model_parallel_size is None or pipeline_model_parallel_size == 1:
checkpoint_path = (checkpoint_path_without_rank.parent /
f"mp_rank_{tp_rank:02d}" /
checkpoint_path_without_rank.name)
else:
checkpoint_path = (
checkpoint_path_without_rank.parent /
f"tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}" /
checkpoint_path_without_rank.name)
return checkpoint_path
else:
return checkpoint_path_without_rank
return [[
_inject_parallel_ranks(tp_rank=tp_rank, pp_rank=pp_rank)
for pp_rank in range(pipeline_model_parallel_size)
] for tp_rank in range(tensor_model_parallel_size)]
@property
@functools.lru_cache
def checkpoint_name(self):
patterns = [
"model_weights.ckpt", # older megatron checkpoints
"*last.ckpt", # newer format of checkpoints
]
for pattern in patterns:
model_files = sorted(list(self._checkpoints_dir.rglob(pattern)))
if model_files:
return model_files[0].name
raise ValueError(
f"Could not find checkpoint files in {self._checkpoints_dir}")
@functools.lru_cache
def get_tokenizer_file_path(self, tokenizer_key, file_key,
default_filename_pattern):
model_config = self.model_config
file_property = None
if tokenizer_key in model_config and file_key in model_config[
tokenizer_key]:
file_property = model_config[tokenizer_key][file_key]
elif file_key in model_config:
file_property = model_config[file_key]
LOGGER.debug("model_config[%s][%s]=%s", tokenizer_key, file_key,
file_property)
if file_property and file_property.startswith("nemo:"):
filename = file_property.split("nemo:")[1]
filename_pattern = f"*{filename}"
elif file_property and file_property.startswith("/artifacts/"):
filename = Path(file_property).name
filename_pattern = f"*{filename}"
elif file_property is None or file_property == "None":
filename_pattern = None
else:
filename_pattern = default_filename_pattern
LOGGER.warning(
f"Tokenizer file from config: {tokenizer_key}.{file_key}={file_property} "
f"looks like unsupported path. Pattern {filename_pattern} will be used."
)
file_path = None
if filename_pattern is not None:
files_paths = list(self._checkpoints_dir.glob(filename_pattern))
if files_paths:
assert len(files_paths) == 1
file_path = files_paths[0]
return file_path
@functools.lru_cache
def get_all_tokenizer_file_paths(self):
return {
"model":
self.get_tokenizer_file_path("tokenizer", "model", "*.model"),
"vocab_file":
self.get_tokenizer_file_path("tokenizer", "vocab_file", "*vocab*"),
"merge_file":
self.get_tokenizer_file_path("tokenizer", "merge_file",
"*merge*.txt"),
}
def load_nemo_gpt_config(
unpacked_checkpoints_dir: UnpackedNemoCheckpointDir,
layer_rename_config: Dict[str, str] = None) -> GPT2Config:
nemo_model_config = unpacked_checkpoints_dir.model_config
training_tp_size = nemo_model_config.get("tensor_model_parallel_size", 1)
training_pp_size = nemo_model_config.get("pipeline_model_parallel_size", 1)
checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths(
training_tp_size,
training_pp_size,
)
if unpacked_checkpoints_dir._load_checkpoints_to_cpu:
map_location_fn = cpu_map_location
else:
map_location_fn = gpu_map_location
model_00 = torch.load(checkpoints_paths[0][0], map_location=map_location_fn)
model_00 = rename_keys(model_00, layer_rename_config)
vocab_size = model_00[
"model.language_model.embedding.word_embeddings.weight"].shape[
0] * training_tp_size
del model_00
hf_config = GPT2Config(
vocab_size=vocab_size,
n_positions=nemo_model_config['max_position_embeddings'],
n_embd=nemo_model_config['hidden_size'],
n_layer=nemo_model_config['num_layers'],
n_head=nemo_model_config['num_attention_heads'],
n_inner=nemo_model_config['ffn_hidden_size'],
activation_function=nemo_model_config['activation'],
layer_norm_epsilon=nemo_model_config['layernorm_epsilon'],
)
hf_config.n_kv_head = hf_config.n_head
hf_config.bias = nemo_model_config['bias']
# hf_config.apply_query_key_layer_scaling = nemo_model_config['apply_query_key_layer_scaling']
hf_config.apply_query_key_layer_scaling = False
hf_config.position_embedding_type = nemo_model_config.get(
'position_embedding_type', 'learned_absolute')
if hf_config.position_embedding_type == 'rope':
hf_config.position_embedding_type = 'rope_gpt_neox'
hf_config.rotary_base = nemo_model_config.get('rotary_base', 10000.0)
hf_config.rotary_pct = nemo_model_config.get('rotary_percentage', 1.0)
assert hf_config.rotary_pct >= 0 and hf_config.rotary_pct <= 1
rotary_scaling_factor = nemo_model_config.get(
'seq_len_interpolation_factor', None)
if rotary_scaling_factor is None:
hf_config.rotary_scaling = None
else:
assert rotary_scaling_factor > 1
hf_config.rotary_scaling = {
'type': 'linear',
'factor': rotary_scaling_factor
}
tokenizer_config = update_tokenizer_paths(
nemo_model_config["tokenizer"],
unpacked_checkpoints_dir.get_all_tokenizer_file_paths())
return hf_config, tokenizer_config
@torch.no_grad()
def load_torch_checkpoints(checkpoints_paths,
merge_factor,
tp_rank,
pp_rank,
map_location_fn,
handle_model_level_weights,
layer_rename_config: Dict[str, str] = {}):
models = []
for k in range(merge_factor):
rank_weights = checkpoints_paths[tp_rank * merge_factor + k][pp_rank]
model = torch.load(rank_weights, map_location=map_location_fn)
model = rename_keys(model, layer_rename_config)
handle_model_level_weights(model, tp_rank * merge_factor + k, pp_rank)
layers = extract_layers_with_prefix(model,
"model.language_model.encoder.")
models.append(layers)
return models
@torch.no_grad()
def convert_nemo_gpt(unpacked_checkpoints_dir: UnpackedNemoCheckpointDir,
mapping: Mapping,
dtype: str = 'float32',
layer_rename_config: Dict[str, str] = None):
nemo_model_config = unpacked_checkpoints_dir.model_config
checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths(
nemo_model_config.get("tensor_model_parallel_size", 1),
nemo_model_config.get("pipeline_model_parallel_size", 1),
)
if unpacked_checkpoints_dir._load_checkpoints_to_cpu:
map_location_fn = cpu_map_location
else:
map_location_fn = gpu_map_location
dtype = str_dtype_to_torch(dtype)
# load position_embedding from rank 0
model_00 = torch.load(checkpoints_paths[0][0], map_location=map_location_fn)
model_00 = model_00.get("state_dict", model_00)
model_00 = rename_keys(model_00, layer_rename_config)
has_position_embedding = "model.language_model.embedding.position_embeddings.weight" in model_00
has_lm_head = "model.language_model.output_layer.weight" in model_00
del model_00
num_layers = nemo_model_config["num_layers"]
training_tp_size = nemo_model_config.get("tensor_model_parallel_size", 1)
training_pp_size = nemo_model_config.get("pipeline_model_parallel_size", 1)
inference_tp_size = mapping.tp_size
inference_tp_rank = mapping.tp_rank
apply_layernorm_1p = (nemo_model_config.get('normalization',
'') == "layernorm1p")
split_gated_activation = ("swiglu"
in nemo_model_config.get('activation', "gelu"))
num_attention_heads = nemo_model_config["num_attention_heads"]
# use_attention_nemo_shape = True
transpose_weights = True
# multi_query_mode = False
local_dim = None
    # merge_factor: how many training TP ranks are merged into one inference TP rank
    # split_factor: into how many parts each training TP rank is split
gcd = np.gcd(training_tp_size, inference_tp_size)
merge_factor = training_tp_size // gcd
split_factor = inference_tp_size // gcd
model_level_weights = defaultdict(list)
def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
if tp_idx == 0 and pp_idx == 0:
if has_position_embedding:
val = model[
"model.language_model.embedding.position_embeddings.weight"]
model_level_weights[
"transformer.position_embedding.weight"].append(val)
if pp_idx == 0:
val = model.get(
"state_dict",
model)["model.language_model.embedding.word_embeddings.weight"]
model_level_weights["transformer.vocab_embedding.weight"].append(
val)
if has_lm_head and pp_idx == training_pp_size - 1:
val = model.get("state_dict",
model)["model.language_model.output_layer.weight"]
model_level_weights["lm_head.weight"].append(val)
weights = {}
tik = time.time()
tp_rank = inference_tp_rank // split_factor
# for tp_rank in range(training_tp_size // merge_factor):
for pp_rank in range(training_pp_size):
models = load_torch_checkpoints(checkpoints_paths, merge_factor,
tp_rank, pp_rank, map_location_fn,
handle_model_level_weights,
layer_rename_config)
for name in list(models[0].keys()):
params = [model[name] for model in models]
if transpose_weights and params[0].ndim == 2:
params = [p.T for p in params]
if "layernorm.weight" in name and apply_layernorm_1p:
params = [p + 1.0 for p in params]
l = retrieved_layer_index_from_name(name)
if l is not None:
new_l = l + pp_rank * num_layers // training_pp_size
prefix = f'transformer.layers.{new_l}'
if 'attention.query_key_value' in name:
if name.endswith('weight'):
hidden_dim = params[0].shape[0]
if local_dim is None:
local_dim = params[0].shape[-1] // 3
# multi_query_mode = False; use_attention_nemo_shape = True
head_num = num_attention_heads // training_tp_size
size_per_head = hidden_dim // num_attention_heads
params = [
param.reshape(hidden_dim, head_num, 3,
size_per_head) for param in params
]
params = [param.permute(0, 2, 1, 3) for param in params]
params = [
param.reshape(hidden_dim, 3, local_dim)
for param in params
]
cat_dim = -1
param = torch.concat(params, dim=cat_dim)
param = torch.chunk(param, split_factor,
dim=cat_dim)[inference_tp_rank %
split_factor]
weights[
f'{prefix}.attention.qkv.weight'] = param.reshape(
hidden_dim, -1).t()
else:
if local_dim is None:
local_dim = params[0].shape[-1] // 3
# multi_query_mode = False; use_attention_nemo_shape = True
head_num = num_attention_heads // training_tp_size
size_per_head = local_dim // head_num
params = [
param.reshape(head_num, 3, size_per_head)
for param in params
]
params = [param.permute(1, 0, 2) for param in params]
params = [
param.reshape(3, local_dim) for param in params
]
cat_dim = -1
param = torch.concat(params, dim=cat_dim)
param = torch.chunk(param, split_factor,
dim=cat_dim)[inference_tp_rank %
split_factor]
weights[f'{prefix}.attention.qkv.bias'] = param.reshape(
-1)
elif 'attention.dense' in name:
if name.endswith('weight'):
cat_dim = 0
param = torch.concat(params, dim=cat_dim)
param = torch.chunk(param, split_factor,
dim=cat_dim)[inference_tp_rank %
split_factor]
weights[f'{prefix}.attention.dense.weight'] = param.t()
else:
weights[f'{prefix}.attention.dense.bias'] = params[0]
elif 'mlp.dense_h_to_4h' in name:
if name.endswith('weight'):
if split_gated_activation:
params = [torch.chunk(p, 2, dim=-1) for p in params]
params, gate_params = list(zip(*params))
cat_dim = -1
param = torch.concat(params, dim=cat_dim)
param = torch.chunk(param, split_factor,
dim=cat_dim)[inference_tp_rank %
split_factor]
weights[f'{prefix}.mlp.fc.weight'] = param.t()
if split_gated_activation:
gate_param = torch.concat(gate_params, dim=cat_dim)
gate_param = torch.chunk(
gate_param, split_factor,
dim=cat_dim)[inference_tp_rank % split_factor]
weights[f'{prefix}.mlp.gate.weight'] = gate_param.t(
)
else:
if split_gated_activation:
params = [torch.chunk(p, 2, dim=-1) for p in params]
params, gate_params = list(zip(*params))
cat_dim = -1
param = torch.concat(params, dim=cat_dim)
param = torch.chunk(param, split_factor,
dim=cat_dim)[inference_tp_rank %
split_factor]
weights[f'{prefix}.mlp.fc.bias'] = param
if split_gated_activation:
gate_param = torch.concat(gate_params, dim=cat_dim)
gate_param = torch.chunk(
gate_param, split_factor,
dim=cat_dim)[inference_tp_rank % split_factor]
weights[f'{prefix}.mlp.gate.bias'] = gate_param
elif 'mlp.dense_4h_to_h' in name:
if name.endswith('weight'):
cat_dim = 0
param = torch.concat(params, dim=cat_dim)
param = torch.chunk(param, split_factor,
dim=cat_dim)[inference_tp_rank %
split_factor]
weights[f'{prefix}.mlp.proj.weight'] = param.t()
else:
weights[f'{prefix}.mlp.proj.bias'] = params[0]
elif 'input_layernorm' in name:
if name.endswith('weight'):
weights[f'{prefix}.input_layernorm.weight'] = params[0]
else:
weights[f'{prefix}.input_layernorm.bias'] = params[0]
elif 'post_attention_layernorm' in name:
if name.endswith('weight'):
weights[f'{prefix}.post_layernorm.weight'] = params[0]
else:
weights[f'{prefix}.post_layernorm.bias'] = params[0]
elif 'final_layernorm' in name:
if name.endswith('weight'):
weights['transformer.ln_f.weight'] = params[0]
else:
weights['transformer.ln_f.bias'] = params[0]
for model in models:
del model[name]
del models
for key in list(model_level_weights.keys()):
weights[key] = torch.concat(model_level_weights[key], dim=0)
del model_level_weights[key]
    for key, param in weights.items():
        weights[key] = param.to(dtype).contiguous()
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Weights loaded. Total time: {t}')
return weights
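# Example (illustrative): converting a TP=4 training checkpoint for TP=2 inference
# gives gcd=2, merge_factor=2, split_factor=1, so each inference rank merges two
# training shards; going from TP=2 training to TP=4 inference gives merge_factor=1,
# split_factor=2, so each training shard is split in half.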
if __name__ == '__main__':
# TODO(qijun): Currently, the convert script depends on a torch op:
# torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix,
# which is included in tensorrt_llm Python package. Otherwise, the convert
# script does not need to import tensorrt_llm. Will remove it after reimplementing
# the op with PyTorch.
print(tensorrt_llm.__version__)
args = parse_arguments()
world_size = args.tp_size * args.pp_size
tik = time.time()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
quant_algo = None
kv_cache_quant_algo = None
plugin_weight_only_quant_type = None
if args.use_weight_only:
if args.weight_only_precision == 'int8':
plugin_weight_only_quant_type = torch.int8
quant_algo = QuantAlgo.W8A16
elif args.weight_only_precision == 'int4':
plugin_weight_only_quant_type = torch.quint4x2
quant_algo = QuantAlgo.W4A16
elif args.smoothquant:
if args.per_token and args.per_channel:
quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
elif not args.per_token and not args.per_channel:
quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
elif not args.per_token and args.per_channel:
quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
elif args.per_token and not args.per_channel:
quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
if args.int8_kv_cache:
kv_cache_quant_algo = QuantAlgo.INT8
if args.model_dir is not None:
hf_config, gpt_variant = load_gpt_config(args.model_dir,
args.gpt_variant)
elif args.nemo_ckpt_path is not None:
nemo_dir = Path(args.output_dir) / "unpacked"
nemo_dir = unpack_nemo_ckpt(args.nemo_ckpt_path, nemo_dir)
unpacked_checkpoints_dir = UnpackedNemoCheckpointDir(
nemo_dir, load_checkpoints_to_cpu=not args.load_nemo_on_gpu)
layer_rename_config = {
pattern.split(':')[0]: pattern.split(':')[1]
for pattern in args.nemo_rename_key
}
hf_config, tokenizer_config = load_nemo_gpt_config(
unpacked_checkpoints_dir, layer_rename_config)
copy_tokenizer_files(tokenizer_config, Path(args.output_dir))
args.use_parallel_embedding = True
args.embedding_sharding_dim = 0
else:
raise NotImplementedError("No source model path specified!")
config = {
'architecture':
'GPTForCausalLM',
'dtype':
args.dtype,
'num_hidden_layers':
hf_config.n_layer,
'num_attention_heads':
hf_config.n_head,
'num_key_value_heads':
hf_config.n_kv_head,
'hidden_size':
hf_config.n_embd,
'intermediate_size':
hf_config.n_inner,
'norm_epsilon':
hf_config.layer_norm_epsilon,
'vocab_size':
hf_config.vocab_size,
'position_embedding_type':
getattr(hf_config, 'position_embedding_type', 'learned_absolute'),
'max_position_embeddings':
hf_config.n_positions,
'hidden_act':
hf_config.activation_function,
'use_parallel_embedding':
args.use_parallel_embedding,
'embedding_sharding_dim':
args.embedding_sharding_dim,
'share_embedding_table':
args.use_embedding_sharing,
'quantization': {
'quant_algo': quant_algo,
'kv_cache_quant_algo': kv_cache_quant_algo,
},
'mapping': {
'world_size': world_size,
'tp_size': args.tp_size,
'pp_size': args.pp_size,
},
'bias':
getattr(hf_config, 'bias', True),
'apply_query_key_layer_scaling':
getattr(hf_config, 'apply_query_key_layer_scaling', False),
'rotary_pct':
getattr(hf_config, 'rotary_pct', 1.0),
'rotary_base':
getattr(hf_config, 'rotary_base', 10000.0),
'rotary_scaling':
getattr(hf_config, 'rotary_scaling', None),
'qk_layernorm':
args.model_dir is not None and gpt_variant == 'persimmon',
'inner_layernorm':
args.model_dir is not None and gpt_variant == 'kosmos-2',
'norm_before_bmm1':
args.model_dir is not None and gpt_variant == 'kosmos-2',
'scale_embedding':
args.model_dir is not None and gpt_variant == 'kosmos-2'
and hf_config.text_config.scale_embedding,
}
with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
json.dump(config, f, indent=4)
if args.model_dir is not None:
if gpt_variant == 'kosmos-2':
hf_model = AutoModelForVision2Seq.from_pretrained(
args.model_dir, trust_remote_code=True)
else:
hf_model = AutoModelForCausalLM.from_pretrained(
args.model_dir,
trust_remote_code=True,
device_map="auto",
torch_dtype="auto")
if args.smoothquant is not None or args.int8_kv_cache:
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
"TOKENIZERS_PARALLELISM", "false")
dataset = load_dataset("lambada",
split="validation",
cache_dir=args.dataset_cache_dir)
tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
act_range = capture_activation_range(hf_model, tokenizer, dataset)
if args.smoothquant is not None:
smooth_gpt_model(hf_model, act_range, args.smoothquant)
def convert_and_save(rank):
mapping = Mapping(world_size=world_size,
rank=rank,
tp_size=args.tp_size,
pp_size=args.pp_size)
if args.model_dir is not None:
if args.smoothquant is not None or args.int8_kv_cache:
weights = convert_hf_gpt_legacy(
hf_model,
hf_config,
gpt_variant,
mapping,
dtype=args.dtype,
use_parallel_embedding=args.use_parallel_embedding,
sharding_dim=args.embedding_sharding_dim,
share_embedding_table=args.use_embedding_sharing,
use_smooth_quant=(args.smoothquant is not None),
per_channel=args.per_channel,
per_token=args.per_token,
int8_kv_cache=args.int8_kv_cache,
act_range=act_range,
)
else:
weights = convert_hf_gpt(
hf_model,
hf_config,
gpt_variant,
mapping,
dtype=args.dtype,
use_parallel_embedding=args.use_parallel_embedding,
sharding_dim=args.embedding_sharding_dim,
share_embedding_table=args.use_embedding_sharing,
use_weight_only=args.use_weight_only,
plugin_weight_only_quant_type=plugin_weight_only_quant_type,
)
elif args.nemo_ckpt_path is not None:
weights = convert_nemo_gpt(unpacked_checkpoints_dir, mapping,
args.dtype, layer_rename_config)
safetensors.torch.save_file(
weights, os.path.join(args.output_dir, f'rank{rank}.safetensors'))
if args.workers == 1:
for rank in range(world_size):
convert_and_save(rank)
else:
with ThreadPoolExecutor(max_workers=args.workers) as p:
futures = [
p.submit(convert_and_save, rank) for rank in range(world_size)
]
exceptions = []
for future in as_completed(futures):
try:
future.result()
except Exception as e:
traceback.print_exc()
exceptions.append(e)
assert len(
exceptions
) == 0, "Checkpoint conversion failed, please check error log."
if args.model_dir is not None:
del hf_model
elif args.nemo_ckpt_path is not None:
shutil.rmtree(nemo_dir)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Total time of converting checkpoints: {t}')