mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
333 lines
12 KiB
Python
333 lines
12 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
import time
|
|
import traceback
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
import safetensors.torch
|
|
import torch
|
|
|
|
import tensorrt_llm
|
|
from tensorrt_llm import str_dtype_to_torch
|
|
from tensorrt_llm.mapping import Mapping
|
|
from tensorrt_llm.models.convert_utils import (split, split_matrix_tp,
|
|
split_qkv_bias_tp, split_qkv_tp)
|
|
|
|
FACEBOOK_DIT_NAME_MAPPING = {
|
|
'^t_embedder.mlp.0.weight$':
|
|
't_embedder.mlp1.weight',
|
|
'^t_embedder.mlp.0.bias$':
|
|
't_embedder.mlp1.bias',
|
|
'^t_embedder.mlp.2.weight$':
|
|
't_embedder.mlp2.weight',
|
|
'^t_embedder.mlp.2.bias$':
|
|
't_embedder.mlp2.bias',
|
|
'^t_embedder.mlp.0.weights_scaling_factor$':
|
|
't_embedder.mlp1.weights_scaling_factor',
|
|
'^t_embedder.mlp.0.activation_scaling_factor$':
|
|
't_embedder.mlp1.activation_scaling_factor',
|
|
'^t_embedder.mlp.2.weights_scaling_factor$':
|
|
't_embedder.mlp2.weights_scaling_factor',
|
|
'^t_embedder.mlp.2.activation_scaling_factor$':
|
|
't_embedder.mlp2.activation_scaling_factor',
|
|
|
|
# Add negative lookhead for scaling matching.
|
|
'^blocks.(\d+).mlp.fc1.weight$':
|
|
'blocks.*.mlp.fc.weight',
|
|
'^blocks.(\d+).mlp.fc1.bias$':
|
|
'blocks.*.mlp.fc.bias',
|
|
'^blocks.(\d+).mlp.fc2.weight$':
|
|
'blocks.*.mlp.proj.weight',
|
|
'^blocks.(\d+).mlp.fc2.bias$':
|
|
'blocks.*.mlp.proj.bias',
|
|
'^blocks.(\d+).mlp.fc1.weights_scaling_factor$':
|
|
'blocks.*.mlp.fc.weights_scaling_factor',
|
|
'^blocks.(\d+).mlp.fc1.activation_scaling_factor$':
|
|
'blocks.*.mlp.fc.activation_scaling_factor',
|
|
'^blocks.(\d+).mlp.fc2.weights_scaling_factor$':
|
|
'blocks.*.mlp.proj.weights_scaling_factor',
|
|
'^blocks.(\d+).mlp.fc2.activation_scaling_factor$':
|
|
'blocks.*.mlp.proj.activation_scaling_factor',
|
|
|
|
# Add negative lookhead for scaling matching.
|
|
'^blocks.(\d+).attn.proj.weight$':
|
|
'blocks.*.attn.dense.weight',
|
|
'^blocks.(\d+).attn.proj.bias$':
|
|
'blocks.*.attn.dense.bias',
|
|
'^blocks.(\d+).attn.proj.weights_scaling_factor$':
|
|
'blocks.*.attn.dense.weights_scaling_factor',
|
|
'^blocks.(\d+).attn.proj.activation_scaling_factor$':
|
|
'blocks.*.attn.dense.activation_scaling_factor',
|
|
'^blocks.(\d+).adaLN_modulation.1.weight$':
|
|
'blocks.*.adaLN_modulation.weight',
|
|
'^blocks.(\d+).adaLN_modulation.1.bias$':
|
|
'blocks.*.adaLN_modulation.bias',
|
|
'^blocks.(\d+).adaLN_modulation.1.weights_scaling_factor$':
|
|
'blocks.*.adaLN_modulation.weights_scaling_factor',
|
|
'^blocks.(\d+).adaLN_modulation.1.activation_scaling_factor$':
|
|
'blocks.*.adaLN_modulation.activation_scaling_factor',
|
|
'^final_layer.adaLN_modulation.1.weight$':
|
|
'final_layer.adaLN_modulation.weight',
|
|
'^final_layer.adaLN_modulation.1.bias$':
|
|
'final_layer.adaLN_modulation.bias',
|
|
'^final_layer.adaLN_modulation.1.weights_scaling_factor$':
|
|
'final_layer.adaLN_modulation.weights_scaling_factor',
|
|
'^final_layer.adaLN_modulation.1.activation_scaling_factor$':
|
|
'final_layer.adaLN_modulation.activation_scaling_factor',
|
|
}
|
|
|
|
|
|
def parse_arguments():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--timm_ckpt',
|
|
type=str,
|
|
default="./DiT-XL-2-512x512.pt")
|
|
parser.add_argument('--output_dir',
|
|
type=str,
|
|
default='tllm_checkpoint',
|
|
help='The path to save the TensorRT LLM checkpoint')
|
|
parser.add_argument('--input_size',
|
|
type=int,
|
|
default=64,
|
|
help='The input latent size')
|
|
parser.add_argument('--patch_size',
|
|
type=int,
|
|
default=2,
|
|
help='The patch size for patchify')
|
|
parser.add_argument('--in_channels',
|
|
type=int,
|
|
default=4,
|
|
help='The channels of input latent')
|
|
parser.add_argument('--hidden_size',
|
|
type=int,
|
|
default=1152,
|
|
help='The hidden size of DiT')
|
|
parser.add_argument('--depth',
|
|
type=int,
|
|
default=28,
|
|
help='The number of DiTBlock layers')
|
|
parser.add_argument('--num_heads',
|
|
type=int,
|
|
default=16,
|
|
help='The number of heads of attention module')
|
|
parser.add_argument(
|
|
'--mlp_ratio',
|
|
type=float,
|
|
default=4.0,
|
|
help=
|
|
'The ratio of hidden size compared to input hidden size in MLP layer')
|
|
parser.add_argument(
|
|
'--class_dropout_prob',
|
|
type=float,
|
|
default=0.1,
|
|
help='The probability to drop class token when training')
|
|
parser.add_argument('--num_classes',
|
|
type=int,
|
|
default=1000,
|
|
help='The number of classes for conditional control')
|
|
parser.add_argument('--learn_sigma',
|
|
type=bool,
|
|
default=True,
|
|
help='Whether the model learn sigma')
|
|
parser.add_argument('--cfg_scale', type=float, default=4.0)
|
|
parser.add_argument('--tp_size',
|
|
type=int,
|
|
default=1,
|
|
help='N-way tensor parallelism size')
|
|
parser.add_argument('--cp_size',
|
|
type=int,
|
|
default=1,
|
|
help='Context parallelism size')
|
|
parser.add_argument('--pp_size',
|
|
type=int,
|
|
default=1,
|
|
help='N-way pipeline parallelism size')
|
|
parser.add_argument('--dtype',
|
|
type=str,
|
|
default='float16',
|
|
choices=['float32', 'bfloat16', 'float16'])
|
|
parser.add_argument('--fp8_linear',
|
|
action='store_true',
|
|
help='Whether use FP8 for linear layers')
|
|
parser.add_argument(
|
|
'--diffusers_dit',
|
|
action='store_true',
|
|
help='Convert checkpoint provided by `HuggingFace/diffusers`')
|
|
parser.add_argument(
|
|
'--workers',
|
|
type=int,
|
|
default=1,
|
|
help='The number of workers for converting checkpoint in parallel')
|
|
args = parser.parse_args()
|
|
return args
|
|
|
|
|
|
def convert_timm_dit(args, mapping, dtype='float32'):
|
|
|
|
weights = {}
|
|
tik = time.time()
|
|
torch_dtype = str_dtype_to_torch(dtype)
|
|
tensor_parallel = mapping.tp_size
|
|
if args.diffusers_dit and args.fp8_linear:
|
|
from utils_modelopt import remap_model
|
|
converted_ckpt = "transformer.fp8.converted.state_dict.pt"
|
|
remap_model(quantized_ckpt=args.timm_ckpt, output_ckpt=converted_ckpt)
|
|
model_params = dict(torch.load(converted_ckpt))
|
|
else:
|
|
model_params = dict(torch.load(args.timm_ckpt))
|
|
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
|
|
|
|
def get_trtllm_name(timm_name):
|
|
for k, v in timm_to_trtllm_name.items():
|
|
m = re.match(k, timm_name)
|
|
if m is not None:
|
|
if "*" in v:
|
|
v = v.replace("*", m.groups()[0])
|
|
return v
|
|
return timm_name
|
|
|
|
weights = dict()
|
|
for name, param in model_params.items():
|
|
if param.dtype in [torch.int8, torch.float8_e4m3fn
|
|
] or 'scaling_factor' in name:
|
|
if 'scaling_factor' in name:
|
|
assert param.dtype == torch.float32
|
|
weights[get_trtllm_name(name)] = param.contiguous()
|
|
else:
|
|
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)
|
|
|
|
assert len(weights) == len(model_params)
|
|
|
|
for k, v in weights.items():
|
|
if re.match('^blocks.*.attn.qkv.weight$', k):
|
|
weights[k] = split_qkv_tp(v, args.num_heads, args.hidden_size,
|
|
tensor_parallel, mapping.tp_rank)
|
|
elif re.match('^blocks.*.attn.qkv.bias$', k):
|
|
weights[k] = split_qkv_bias_tp(v, args.num_heads, args.hidden_size,
|
|
tensor_parallel, mapping.tp_rank)
|
|
elif re.match('^blocks.*.attn.dense.weight$', k):
|
|
weights[k] = split_matrix_tp(v,
|
|
tensor_parallel,
|
|
mapping.tp_rank,
|
|
dim=1)
|
|
elif re.match('^blocks.*.mlp.fc.weight$', k):
|
|
weights[k] = split_matrix_tp(v,
|
|
tensor_parallel,
|
|
mapping.tp_rank,
|
|
dim=0)
|
|
elif re.match('^blocks.*.mlp.fc.bias$', k):
|
|
weights[k] = split(v, tensor_parallel, mapping.tp_rank)
|
|
elif re.match('^blocks.*.mlp.proj.weight$', k):
|
|
weights[k] = split_matrix_tp(v,
|
|
tensor_parallel,
|
|
mapping.tp_rank,
|
|
dim=1)
|
|
elif re.match(r'.*adaLN_modulation.weight$', k):
|
|
weights[k] = split_matrix_tp(v,
|
|
tensor_parallel,
|
|
mapping.tp_rank,
|
|
dim=0)
|
|
elif re.match(r'.*adaLN_modulation.bias$', k):
|
|
weights[k] = split(v, tensor_parallel, mapping.tp_rank)
|
|
|
|
tok = time.time()
|
|
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
|
print(f'Weights loaded. Total time: {t}')
|
|
return weights
|
|
|
|
|
|
def save_config(args):
|
|
if not os.path.exists(args.output_dir):
|
|
os.makedirs(args.output_dir)
|
|
|
|
config = {
|
|
'architecture': "DiT",
|
|
'dtype': args.dtype,
|
|
'input_size': args.input_size,
|
|
'patch_size': args.patch_size,
|
|
'in_channels': args.in_channels,
|
|
'hidden_size': args.hidden_size,
|
|
'num_hidden_layers': args.depth,
|
|
'num_attention_heads': args.num_heads,
|
|
'mlp_ratio': args.mlp_ratio,
|
|
'class_dropout_prob': args.class_dropout_prob,
|
|
'num_classes': args.num_classes,
|
|
'learn_sigma': args.learn_sigma,
|
|
'cfg_scale': args.cfg_scale,
|
|
'mapping': {
|
|
'world_size': args.cp_size * args.tp_size * args.pp_size,
|
|
'cp_size': args.cp_size,
|
|
'tp_size': args.tp_size,
|
|
'pp_size': args.pp_size,
|
|
}
|
|
}
|
|
if args.fp8_linear:
|
|
config['quantization'] = {
|
|
'quant_algo': "FP8",
|
|
# TODO: add support for exclude modules.
|
|
# 'exclude_modules': ["*final_layer*"],
|
|
}
|
|
|
|
with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
|
|
json.dump(config, f, indent=4)
|
|
|
|
|
|
def covert_and_save(args, rank):
|
|
if rank == 0:
|
|
save_config(args)
|
|
|
|
mapping = Mapping(world_size=args.cp_size * args.tp_size * args.pp_size,
|
|
rank=rank,
|
|
cp_size=args.cp_size,
|
|
tp_size=args.tp_size,
|
|
pp_size=args.pp_size)
|
|
|
|
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
|
|
|
|
safetensors.torch.save_file(
|
|
weights, os.path.join(args.output_dir, f'rank{rank}.safetensors'))
|
|
|
|
|
|
def execute(workers, func, args):
|
|
if workers == 1:
|
|
for rank, f in enumerate(func):
|
|
f(args, rank)
|
|
else:
|
|
with ThreadPoolExecutor(max_workers=workers) as p:
|
|
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
|
|
exceptions = []
|
|
for future in as_completed(futures):
|
|
try:
|
|
future.result()
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
exceptions.append(e)
|
|
assert len(
|
|
exceptions
|
|
) == 0, "Checkpoint conversion failed, please check error log."
|
|
|
|
|
|
def main():
|
|
print(tensorrt_llm.__version__)
|
|
args = parse_arguments()
|
|
world_size = args.cp_size * args.tp_size * args.pp_size
|
|
|
|
assert args.pp_size == 1, "PP is not supported yet."
|
|
|
|
tik = time.time()
|
|
|
|
if args.timm_ckpt is None:
|
|
return
|
|
|
|
execute(args.workers, [covert_and_save] * world_size, args)
|
|
|
|
tok = time.time()
|
|
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
|
print(f'Total time of converting checkpoints: {t}')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|