TensorRT-LLMs/tensorrt_llm/models/unet/weights.py
Kaiyu Xie 385626572d
Update TensorRT-LLM (#2502)
* Update TensorRT-LLM

---------

Co-authored-by: 岑灿 <yunyi.hyy@alibaba-inc.com>
2024-11-26 16:51:34 +08:00

198 lines
7.6 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import numpy as np
from ...logger import logger
def update_timestep_weight(src, dst):
    """Copy the timestep-embedding linear layers from ``src`` into ``dst``.

    Args:
        src: Object exposing ``linear_1`` and ``linear_2``.
        dst: Object whose ``linear_1`` and ``linear_2`` each provide an
            ``update_parameters`` method; it is called with the matching
            ``src`` layer.
    """
    for layer_name in ('linear_1', 'linear_2'):
        getattr(dst, layer_name).update_parameters(getattr(src, layer_name))
def update_crossattn_downblock_2d_weight(src, dst):
    """Copy a cross-attention down block's weights from ``src`` into ``dst``.

    Resnets, attention modules and (when present) downsampler convolutions
    are transferred position by position.

    Args:
        src: Source block exposing ``resnets``, ``attentions`` and
            ``downsamplers``.
        dst: Destination block with the same sub-module layout.
    """
    for pos, src_resnet in enumerate(src.resnets):
        update_resnet_block_weight(src_resnet, dst.resnets[pos])

    for pos, src_attn in enumerate(src.attentions):
        # This helper takes the destination first, unlike the other
        # ``update_*`` functions in this module.
        update_transformer_2d_model_weight(dst.attentions[pos], src_attn)

    # A falsy ``downsamplers`` (e.g. None or empty) means nothing to copy.
    if not src.downsamplers:
        return
    for pos, src_down in enumerate(src.downsamplers):
        dst.downsamplers[pos].conv.update_parameters(src_down.conv)
def update_transformer_2d_model_weight(gm, m):
    """Copy transformer-2D weights from ``m`` (source) into ``gm`` (destination).

    NOTE: the argument order is destination first, source second, which is
    the opposite of the ``(src, dst)`` helpers in this module.

    The source keeps separate ``to_q``/``to_k``/``to_v`` projections, while
    the destination uses a fused ``attn1.to_qkv`` and a fused ``attn2.to_kv``.
    The fused weights are built with ``np.concatenate`` (default axis 0) in
    q, k, v order (k, v for ``attn2``). Only ``weight`` is copied for these
    fused projections; no bias is written here.

    Args:
        gm: Destination model exposing ``norm``, ``proj_in``, ``proj_out``
            and ``transformer_blocks``.
        m: Source model with the matching layout; its projection weights
            must support ``.detach().cpu().numpy()`` (presumably torch
            tensors; confirm against the caller).
    """

    def _weight_as_numpy(layer):
        # Detach from autograd and move to host before converting.
        return layer.weight.detach().cpu().numpy()

    gm.norm.update_parameters(m.norm)
    gm.proj_in.update_parameters(m.proj_in)

    # The destination block count drives the loop, as in the original.
    for block_idx in range(len(gm.transformer_blocks)):
        dst_block = gm.transformer_blocks[block_idx]
        src_block = m.transformer_blocks[block_idx]

        # First attention: fuse q/k/v into a single weight.
        src_attn1 = src_block.attn1
        fused_qkv = np.concatenate((
            _weight_as_numpy(src_attn1.to_q),
            _weight_as_numpy(src_attn1.to_k),
            _weight_as_numpy(src_attn1.to_v),
        ))
        dst_block.attn1.to_qkv.weight.value = fused_qkv
        # The source ``to_out`` is indexable; element 0 is the one copied.
        dst_block.attn1.to_out.update_parameters(src_attn1.to_out[0])

        # Second attention: q stays separate, k/v are fused.
        src_attn2 = src_block.attn2
        dst_block.attn2.to_q.update_parameters(src_attn2.to_q)
        fused_kv = np.concatenate((
            _weight_as_numpy(src_attn2.to_k),
            _weight_as_numpy(src_attn2.to_v),
        ))
        dst_block.attn2.to_kv.weight.value = fused_kv
        dst_block.attn2.to_out.update_parameters(src_attn2.to_out[0])

        # Feed-forward: source ``ff.net[0].proj`` -> ``ff.proj_in`` and
        # source ``ff.net[2]`` -> ``ff.proj_out``.
        dst_block.ff.proj_in.update_parameters(src_block.ff.net[0].proj)
        dst_block.ff.proj_out.update_parameters(src_block.ff.net[2])

        dst_block.norm1.update_parameters(src_block.norm1)
        dst_block.norm2.update_parameters(src_block.norm2)
        dst_block.norm3.update_parameters(src_block.norm3)

    gm.proj_out.update_parameters(m.proj_out)
def update_upblock_2d_weight(src, dst):
    """Copy a plain (attention-free) up block's weights from ``src`` to ``dst``.

    Upsampler convolutions are copied first (when present), then every
    resnet is transferred through ``update_resnet_block_weight`` so this
    block shares one resnet copy routine with the cross-attention and
    mid-block helpers instead of repeating it inline.

    Args:
        src: Source block exposing ``upsamplers`` and ``resnets``.
        dst: Destination block with the same sub-module layout.
    """
    # A falsy ``upsamplers`` (e.g. None or empty) means nothing to copy.
    if src.upsamplers:
        for index, upsampler in enumerate(src.upsamplers):
            dst.upsamplers[index].conv.update_parameters(upsampler.conv)

    for index, resnet in enumerate(src.resnets):
        # Copies norm1/conv1/norm2/conv2/time_emb_proj and the optional
        # conv_shortcut, the same set the inline code used to copy.
        update_resnet_block_weight(resnet, dst.resnets[index])
def update_downblock_2d_weight(src, dst):
    """Copy a plain (attention-free) down block's weights from ``src`` to ``dst``.

    Downsampler convolutions are copied first (when present), then every
    resnet is transferred through ``update_resnet_block_weight`` so this
    block shares one resnet copy routine with the cross-attention and
    mid-block helpers instead of repeating it inline.

    Args:
        src: Source block exposing ``downsamplers`` and ``resnets``.
        dst: Destination block with the same sub-module layout.
    """
    # A falsy ``downsamplers`` (e.g. None or empty) means nothing to copy.
    if src.downsamplers:
        for index, downsampler in enumerate(src.downsamplers):
            dst.downsamplers[index].conv.update_parameters(downsampler.conv)

    for index, resnet in enumerate(src.resnets):
        # Copies norm1/conv1/norm2/conv2/time_emb_proj and the optional
        # conv_shortcut, the same set the inline code used to copy.
        update_resnet_block_weight(resnet, dst.resnets[index])
def update_unet_mid_block_2d_weight(src, dst):
    """Copy the mid block's resnets and transformer attentions to ``dst``.

    Not to be confused with ``update_unetmidblock_2d_weight``, which handles
    attention modules that expose ``query``/``key``/``value`` projections.

    Args:
        src: Source mid block exposing ``resnets`` and ``attentions``.
        dst: Destination mid block with the same sub-module layout.
    """
    for pos, src_resnet in enumerate(src.resnets):
        update_resnet_block_weight(src_resnet, dst.resnets[pos])

    for pos, src_attn in enumerate(src.attentions):
        # Destination comes first for this helper.
        update_transformer_2d_model_weight(dst.attentions[pos], src_attn)
def update_crossattn_upblock_2d_weight(src, dst):
    """Copy a cross-attention up block's weights from ``src`` into ``dst``.

    Resnets, attention modules and (when present) upsampler convolutions
    are transferred position by position.

    Args:
        src: Source block exposing ``resnets``, ``attentions`` and
            ``upsamplers``.
        dst: Destination block with the same sub-module layout.
    """
    for pos, src_resnet in enumerate(src.resnets):
        update_resnet_block_weight(src_resnet, dst.resnets[pos])

    for pos, src_attn in enumerate(src.attentions):
        # This helper takes the destination first, unlike the other
        # ``update_*`` functions in this module.
        update_transformer_2d_model_weight(dst.attentions[pos], src_attn)

    # A falsy ``upsamplers`` (e.g. None or empty) means nothing to copy.
    if not src.upsamplers:
        return
    for pos, src_up in enumerate(src.upsamplers):
        dst.upsamplers[pos].conv.update_parameters(src_up.conv)
def update_resnet_block_weight(src, dst):
    """Copy one resnet block's sub-module parameters from ``src`` to ``dst``.

    Always copies ``norm1``, ``conv1``, ``norm2``, ``conv2`` and
    ``time_emb_proj``. ``conv_shortcut`` is copied only when the source's
    ``conv_shortcut`` is truthy.

    Args:
        src: Source resnet block.
        dst: Destination resnet block whose sub-modules provide
            ``update_parameters``.
    """
    always_present = ('norm1', 'conv1', 'norm2', 'conv2', 'time_emb_proj')
    for attr in always_present:
        getattr(dst, attr).update_parameters(getattr(src, attr))

    # The shortcut convolution is optional on the source side.
    if src.conv_shortcut:
        dst.conv_shortcut.update_parameters(src.conv_shortcut)
def update_unetmidblock_2d_weight(src, dst):
    """Copy a mid block whose attentions expose ``query``/``key``/``value``.

    For each attention, ``group_norm`` and ``proj_attn`` are copied via
    ``update_parameters``; the separate query/key/value weights and biases
    are concatenated (``np.concatenate``, default axis 0, in q, k, v order)
    into the destination's fused ``qkv.weight`` and ``qkv.bias``. The
    resnets are then copied with ``update_resnet_block_weight``.

    Not to be confused with ``update_unet_mid_block_2d_weight``, which
    handles transformer-style attention modules.

    Args:
        src: Source mid block exposing ``attentions`` and ``resnets``.
        dst: Destination mid block with fused ``qkv`` attention parameters.
    """

    def _fuse_qkv(attn, field):
        # ``field`` is 'weight' or 'bias'; tensors are detached and moved
        # to host before being converted to NumPy.
        pieces = [
            getattr(proj, field).detach().cpu().numpy()
            for proj in (attn.query, attn.key, attn.value)
        ]
        return np.concatenate(pieces)

    for pos, src_attn in enumerate(src.attentions):
        dst_attn = dst.attentions[pos]
        dst_attn.group_norm.update_parameters(src_attn.group_norm)
        dst_attn.proj_attn.update_parameters(src_attn.proj_attn)
        dst_attn.qkv.weight.value = _fuse_qkv(src_attn, 'weight')
        dst_attn.qkv.bias.value = _fuse_qkv(src_attn, 'bias')

    for pos, src_resnet in enumerate(src.resnets):
        update_resnet_block_weight(src_resnet, dst.resnets[pos])
def update_unet_2d_condition_model_weights(src, dst):
    """Copy all weights of a conditional 2D UNet from ``src`` into ``dst``.

    Order of transfer: input convolution, time embedding, optional
    additional embedding, down blocks, mid block, up blocks, output norm
    and output convolution.

    Args:
        src: Source UNet. Must expose ``config.addition_embed_type``,
            ``config.down_block_types`` and ``config.up_block_types`` in
            addition to the sub-modules copied below.
        dst: Destination UNet with the matching sub-module layout.

    Note:
        Only ``'CrossAttnDownBlock2D'``/``'DownBlock2D'`` and
        ``'CrossAttnUpBlock2D'``/``'UpBlock2D'`` are handled. Any other
        block type is skipped silently, leaving the corresponding
        destination block untouched.
    """
    dst.conv_in.update_parameters(src.conv_in)
    dst.time_embedding.update_parameters(src.time_embedding)

    # The additional embedding exists only when the config enables it.
    if src.config.addition_embed_type:
        dst.add_embedding.update_parameters(src.add_embedding)

    # ``block_type`` (not ``type``) so the builtin is not shadowed.
    for index, block_type in enumerate(src.config.down_block_types):
        if block_type == 'CrossAttnDownBlock2D':
            update_crossattn_downblock_2d_weight(src.down_blocks[index],
                                                 dst.down_blocks[index])
        elif block_type == 'DownBlock2D':
            update_downblock_2d_weight(src.down_blocks[index],
                                       dst.down_blocks[index])

    update_unet_mid_block_2d_weight(src.mid_block, dst.mid_block)

    for index, block_type in enumerate(src.config.up_block_types):
        if block_type == 'CrossAttnUpBlock2D':
            update_crossattn_upblock_2d_weight(src.up_blocks[index],
                                               dst.up_blocks[index])
        elif block_type == 'UpBlock2D':
            update_upblock_2d_weight(src.up_blocks[index],
                                     dst.up_blocks[index])

    dst.conv_norm_out.update_parameters(src.conv_norm_out)
    dst.conv_out.update_parameters(src.conv_out)
def load_from_hf_unet(src, dst):
    """Load UNet weights from ``src`` into ``dst`` and log the elapsed time.

    Delegates the transfer to ``update_unet_2d_condition_model_weights``
    and reports the duration formatted as ``HH:MM:SS``.

    Args:
        src: Source UNet (per the log message, an "HF Unet"; presumably a
            Hugging Face model; confirm against the caller).
        dst: Destination UNet that receives the weights.
    """
    logger.info('Loading weights from HF Unet...')
    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by system clock adjustments the way a time.time() difference can.
    tik = time.perf_counter()
    update_unet_2d_condition_model_weights(src, dst)
    tok = time.perf_counter()
    # gmtime() of a duration in seconds yields its H:M:S breakdown.
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')