# TensorRT-LLM/examples/dit/utils_modelopt.py
import os
import re
from collections import OrderedDict

import modelopt.torch.opt as mto
import torch
from diffusers import DiTPipeline
from modelopt.torch.export.layer_utils import (get_activation_scaling_factor,
                                               get_weight_scaling_factor)
from modelopt.torch.export.model_config_utils import to_quantized_weight
from torchvision.datasets.utils import download_url
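
# Regex mapping from HuggingFace diffusers parameter names to the official
# Facebook DiT naming scheme. A "*" in the target name is substituted with the
# transformer block index captured by the pattern (see _get_fb_dit_name below).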
HUGGINGFACE_TO_FACEBOOK_DIT_NAME_MAPPING = {
"^transformer_blocks.(\d+).norm1.emb.class_embedder.embedding_table.weight$":
"y_embedder.embedding_table.weight",
"^transformer_blocks.(\d+).norm1.emb.timestep_embedder.linear_1.weight$":
"t_embedder.mlp.0.weight",
"^transformer_blocks.(\d+).norm1.emb.timestep_embedder.linear_1.bias$":
"t_embedder.mlp.0.bias",
"^transformer_blocks.(\d+).norm1.emb.timestep_embedder.linear_2.weight$":
"t_embedder.mlp.2.weight",
"^transformer_blocks.(\d+).norm1.emb.timestep_embedder.linear_2.bias$":
"t_embedder.mlp.2.bias",
"^pos_embed.proj.weight$":
"x_embedder.proj.weight",
"^pos_embed.proj.bias$":
"x_embedder.proj.bias",
"^transformer_blocks.(\d+).attn1.to_qkv.weight$":
"blocks.*.attn.qkv.weight",
"^transformer_blocks.(\d+).attn1.to_qkv.bias$":
"blocks.*.attn.qkv.bias",
"^transformer_blocks.(\d+).attn1.to_out.0.weight$":
"blocks.*.attn.proj.weight",
"^transformer_blocks.(\d+).attn1.to_out.0.bias$":
"blocks.*.attn.proj.bias",
"^transformer_blocks.(\d+).ff.net.0.proj.weight$":
"blocks.*.mlp.fc1.weight",
"^transformer_blocks.(\d+).ff.net.0.proj.bias$":
"blocks.*.mlp.fc1.bias",
"^transformer_blocks.(\d+).ff.net.2.weight$":
"blocks.*.mlp.fc2.weight",
"^transformer_blocks.(\d+).ff.net.2.bias$":
"blocks.*.mlp.fc2.bias",
"^transformer_blocks.(\d+).norm1.linear.weight$":
"blocks.*.adaLN_modulation.1.weight",
"^transformer_blocks.(\d+).norm1.linear.bias$":
"blocks.*.adaLN_modulation.1.bias",
"^proj_out_2.weight$":
"final_layer.linear.weight",
"^proj_out_2.bias$":
"final_layer.linear.bias",
"^proj_out_1.weight$":
"final_layer.adaLN_modulation.1.weight",
"^proj_out_1.bias$":
"final_layer.adaLN_modulation.1.bias",
"^transformer_blocks.(\d+).norm1.emb.timestep_embedder.linear_1.weights_scaling_factor$":
"t_embedder.mlp.0.weights_scaling_factor",
"^transformer_blocks.(\d+).norm1.emb.timestep_embedder.linear_1.activation_scaling_factor":
"t_embedder.mlp.0.activation_scaling_factor",
"^transformer_blocks.(\d+).norm1.emb.timestep_embedder.linear_2.weights_scaling_factor":
"t_embedder.mlp.2.weights_scaling_factor",
"^transformer_blocks.(\d+).norm1.emb.timestep_embedder.linear_2.activation_scaling_factor":
"t_embedder.mlp.2.activation_scaling_factor",
"^transformer_blocks.(\d+).attn1.to_qkv.weights_scaling_factor$":
"blocks.*.attn.qkv.weights_scaling_factor",
"^transformer_blocks.(\d+).attn1.to_qkv.activation_scaling_factor$":
"blocks.*.attn.qkv.activation_scaling_factor",
"^transformer_blocks.(\d+).attn1.to_out.0.weights_scaling_factor$":
"blocks.*.attn.proj.weights_scaling_factor",
"^transformer_blocks.(\d+).attn1.to_out.0.activation_scaling_factor$":
"blocks.*.attn.proj.activation_scaling_factor",
"^transformer_blocks.(\d+).ff.net.0.proj.weights_scaling_factor$":
"blocks.*.mlp.fc1.weights_scaling_factor",
"^transformer_blocks.(\d+).ff.net.0.proj.activation_scaling_factor$":
"blocks.*.mlp.fc1.activation_scaling_factor",
"^transformer_blocks.(\d+).ff.net.2.weights_scaling_factor$":
"blocks.*.mlp.fc2.weights_scaling_factor",
"^transformer_blocks.(\d+).ff.net.2.activation_scaling_factor$":
"blocks.*.mlp.fc2.activation_scaling_factor",
"^transformer_blocks.(\d+).norm1.linear.weights_scaling_factor$":
"blocks.*.adaLN_modulation.1.weights_scaling_factor",
"^transformer_blocks.(\d+).norm1.linear.activation_scaling_factor$":
"blocks.*.adaLN_modulation.1.activation_scaling_factor",
"^proj_out_2.weights_scaling_factor$":
"final_layer.linear.weights_scaling_factor",
"^proj_out_2.activation_scaling_factor$":
"final_layer.linear.activation_scaling_factor",
"^proj_out_1.weights_scaling_factor$":
"final_layer.adaLN_modulation.1.weights_scaling_factor",
"^proj_out_1.activation_scaling_factor$":
"final_layer.adaLN_modulation.1.activation_scaling_factor",
}


def convert_amax_to_scaling_factor(model, state_dict):
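    """Turn ModelOpt quantizer amax values into explicit scaling factors.

    For every ``torch.nn.Linear`` in ``model``, add per-module
    ``activation_scaling_factor`` / ``weights_scaling_factor`` entries to a copy
    of ``state_dict`` and, when a weight scaling factor is available, replace
    the stored weight with its FP8-quantized counterpart.
    """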
    ret_dict = state_dict.copy()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            activation_scaling_factor = get_activation_scaling_factor(module)
            weight_scaling_factor = get_weight_scaling_factor(module)
            if activation_scaling_factor:
                ret_dict[
                    f'{name}.activation_scaling_factor'] = activation_scaling_factor
            if weight_scaling_factor:
                ret_dict[
                    f'{name}.weights_scaling_factor'] = weight_scaling_factor
            weight = module.weight.detach().cpu()
            if weight_scaling_factor:
                # only quantize modules that have a valid weight scaling factor
                weight = to_quantized_weight(
                    weight=weight,
                    weights_scaling_factor=weight_scaling_factor,
                    quantization="fp8",
                )
                # replace the original weight with its quantized version
                ret_dict[f'{name}.weight'] = weight
    return ret_dict


def get_weights_map(state_dict):
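    """Remap HuggingFace DiT state dict keys to the Facebook DiT naming scheme.

    Keys of the unfused q/k/v projections are dropped (the fused ``to_qkv``
    entries are kept) and each remapped target name is only stored once.
    """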
    def _get_fb_dit_name(dit_name):
        for k, v in HUGGINGFACE_TO_FACEBOOK_DIT_NAME_MAPPING.items():
            m = re.match(k, dit_name)
            if m is not None:
                if "*" in v:
                    v = v.replace("*", m.groups()[0])
                return v
        return dit_name

    weights_map = OrderedDict()
    for key, value in state_dict.items():
        if ("to_q." in key) or ("to_k." in key) or ("to_v." in key):
            continue
        if _get_fb_dit_name(key) in weights_map:
            continue
        else:
            weights_map[_get_fb_dit_name(key)] = value
    return weights_map


def download_model(ckpt="DiT-XL-2-512x512.pt"):
"""
Downloads a pre-trained DiT model from the web.
"""
if not os.path.isfile(ckpt):
os.makedirs('pretrained_models', exist_ok=True)
web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{ckpt}"
download_url(web_path, '.')
model = torch.load(ckpt, map_location=lambda storage, loc: storage)
    return model


def remap_model(model_name="facebook/DiT-XL-2-512",
                quantized_ckpt="dit.quantized.pt",
                output_ckpt="dit.converted.pt",
                fuse_qkv=True,
                dtype=torch.float32):
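    """Convert a ModelOpt-quantized HuggingFace DiT checkpoint to Facebook DiT naming.

    Loads the diffusers DiT pipeline, optionally fuses the q/k/v projections,
    restores the ModelOpt quantized state from ``quantized_ckpt``, converts amax
    values to scaling factors, remaps parameter names to the Facebook DiT
    convention, and saves the result to ``output_ckpt``.
    """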
    pipe = DiTPipeline.from_pretrained(model_name, torch_dtype=dtype)
    transformer = pipe.transformer
    # fuse the separate q/k/v projections into a single qkv GEMM
    if fuse_qkv:
        from diffusers.models.attention_processor import (Attention,
                                                          AttnProcessor2_0)
        for module in transformer.modules():
            if isinstance(module, Attention):
                assert isinstance(module.processor, AttnProcessor2_0)
                module.fuse_projections(fuse=True)
    pretrained_state_dict = transformer.state_dict()
    mto.restore(transformer, quantized_ckpt)
    quantized_dict = convert_amax_to_scaling_factor(transformer,
                                                    pretrained_state_dict)
    assert set(pretrained_state_dict.keys()).issubset(set(
        quantized_dict.keys()))
    remapped_dict = get_weights_map(quantized_dict)
    # TODO: currently we take the pos_embed and x_embedder weights from the
    # official DiT checkpoint because the HuggingFace implementation modifies
    # these modules.
    stored_params = download_model(ckpt="DiT-XL-2-512x512.pt")
    for name in ["pos_embed", "x_embedder.proj.weight", "x_embedder.proj.bias"]:
        remapped_dict[name] = stored_params[name]
    torch.save(remapped_dict, output_ckpt)
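

# A minimal usage sketch (illustrative only, not part of the original flow): it
# assumes a ModelOpt-quantized DiT checkpoint named "dit.quantized.pt" has
# already been produced, and simply drives remap_model() from the command line.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Remap a ModelOpt-quantized DiT checkpoint to Facebook DiT names.")
    parser.add_argument("--model-name", default="facebook/DiT-XL-2-512")
    parser.add_argument("--quantized-ckpt", default="dit.quantized.pt")
    parser.add_argument("--output-ckpt", default="dit.converted.pt")
    args = parser.parse_args()

    remap_model(model_name=args.model_name,
                quantized_ckpt=args.quantized_ckpt,
                output_ckpt=args.output_ckpt)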