mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
198 lines
7.6 KiB
Python
198 lines
7.6 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import time
|
|
|
|
import numpy as np
|
|
|
|
from ...logger import logger
|
|
|
|
|
|
def update_timestep_weight(src, dst):
    """Copy the two timestep linear layers from ``src`` into ``dst``.

    ``linear_1`` is handled first, then ``linear_2``; in both cases the layer
    on ``dst`` receives the matching ``src`` layer through its
    ``update_parameters`` method.

    Args:
        src: Object exposing ``linear_1`` and ``linear_2`` (weight source).
        dst: Object whose ``linear_1`` and ``linear_2`` implement
            ``update_parameters``.
    """
    for layer_name in ('linear_1', 'linear_2'):
        getattr(dst, layer_name).update_parameters(getattr(src, layer_name))
|
|
|
|
|
|
def update_crossattn_downblock_2d_weight(src, dst):
    """Copy the weights of a cross-attention down block from ``src`` to ``dst``.

    Resnets, attentions and downsamplers are walked in that order. Each
    ``src`` entry is paired with the ``dst`` entry at the same index.

    Args:
        src: Source block exposing ``resnets``, ``attentions`` and
            ``downsamplers``.
        dst: Destination block with the same three collections, indexable by
            position.
    """
    for pos, src_resnet in enumerate(src.resnets):
        update_resnet_block_weight(src_resnet, dst.resnets[pos])

    for pos, src_attention in enumerate(src.attentions):
        # The transformer helper takes the destination first, source second.
        update_transformer_2d_model_weight(dst.attentions[pos], src_attention)

    # A falsy ``downsamplers`` (e.g. None or empty) means nothing to copy.
    for pos, src_downsampler in enumerate(src.downsamplers or ()):
        dst.downsamplers[pos].conv.update_parameters(src_downsampler.conv)
|
|
|
|
|
|
def update_transformer_2d_model_weight(gm, m):
    """Copy the weights of the 2D transformer model ``m`` into ``gm``.

    Note the parameter order: the destination ``gm`` comes first and the
    source ``m`` second.

    ``norm``, ``proj_in`` and ``proj_out`` are copied one-to-one. For every
    index of ``gm.transformer_blocks`` the block of ``m`` at the same index
    supplies:

    * ``attn1``: the ``to_q``/``to_k``/``to_v`` weights, concatenated in that
      order and stored as ``attn1.to_qkv.weight.value``; ``to_out[0]`` is
      copied into ``to_out``.
    * ``attn2``: ``to_q`` is copied directly, the ``to_k``/``to_v`` weights
      are concatenated into ``attn2.to_kv.weight.value``; ``to_out[0]`` is
      copied into ``to_out``.
    * ``ff``: ``net[0].proj`` is copied into ``proj_in`` and ``net[2]`` into
      ``proj_out``.
    * ``norm1``, ``norm2`` and ``norm3`` are copied one-to-one.

    Only the ``weight`` of the fused projections is written here; no bias is
    copied for them.

    Args:
        gm: Destination model whose layers implement ``update_parameters``.
        m: Source model; its projection weights must support
            ``.detach().cpu().numpy()``.
    """

    def weight_as_numpy(layer):
        # Conversion chain applied to every weight before concatenation.
        return layer.weight.detach().cpu().numpy()

    gm.norm.update_parameters(m.norm)
    gm.proj_in.update_parameters(m.proj_in)

    for idx in range(len(gm.transformer_blocks)):
        target = gm.transformer_blocks[idx]
        source = m.transformer_blocks[idx]

        # attn1: q, k and v are fused into a single projection weight.
        first_attn = source.attn1
        fused_qkv = np.concatenate((
            weight_as_numpy(first_attn.to_q),
            weight_as_numpy(first_attn.to_k),
            weight_as_numpy(first_attn.to_v),
        ))
        target.attn1.to_qkv.weight.value = fused_qkv
        target.attn1.to_out.update_parameters(first_attn.to_out[0])

        # attn2: q stays separate, only k and v are fused.
        second_attn = source.attn2
        target.attn2.to_q.update_parameters(second_attn.to_q)
        fused_kv = np.concatenate((
            weight_as_numpy(second_attn.to_k),
            weight_as_numpy(second_attn.to_v),
        ))
        target.attn2.to_kv.weight.value = fused_kv
        target.attn2.to_out.update_parameters(second_attn.to_out[0])

        # Feed-forward: source layers live in ``ff.net`` at positions 0 and 2.
        feed_forward = source.ff.net
        target.ff.proj_in.update_parameters(feed_forward[0].proj)
        target.ff.proj_out.update_parameters(feed_forward[2])

        target.norm1.update_parameters(source.norm1)
        target.norm2.update_parameters(source.norm2)
        target.norm3.update_parameters(source.norm3)

    gm.proj_out.update_parameters(m.proj_out)
|
|
|
|
|
|
def update_upblock_2d_weight(src, dst):
    """Copy the weights of a plain 2D up block from ``src`` into ``dst``.

    Upsampler convolutions are copied first (skipped when ``src.upsamplers``
    is falsy), then each resnet is copied layer by layer into the ``dst``
    resnet at the same index.

    Args:
        src: Source block exposing ``upsamplers`` and ``resnets``.
        dst: Destination block with matching collections whose layers
            implement ``update_parameters``.
    """
    for pos, src_upsampler in enumerate(src.upsamplers or ()):
        dst.upsamplers[pos].conv.update_parameters(src_upsampler.conv)

    for pos, src_resnet in enumerate(src.resnets):
        dst_resnet = dst.resnets[pos]

        for layer_name in ('norm1', 'conv1', 'norm2', 'conv2'):
            getattr(dst_resnet, layer_name).update_parameters(
                getattr(src_resnet, layer_name))

        # The shortcut convolution is optional on the source resnet.
        if src_resnet.conv_shortcut:
            dst_resnet.conv_shortcut.update_parameters(
                src_resnet.conv_shortcut)

        dst_resnet.time_emb_proj.update_parameters(src_resnet.time_emb_proj)
|
|
|
|
|
|
def update_downblock_2d_weight(src, dst):
    """Copy the weights of a plain 2D down block from ``src`` into ``dst``.

    Downsampler convolutions are copied first (skipped when
    ``src.downsamplers`` is falsy), then each resnet is copied layer by layer
    into the ``dst`` resnet at the same index.

    Args:
        src: Source block exposing ``downsamplers`` and ``resnets``.
        dst: Destination block with matching collections whose layers
            implement ``update_parameters``.
    """
    for pos, src_downsampler in enumerate(src.downsamplers or ()):
        dst.downsamplers[pos].conv.update_parameters(src_downsampler.conv)

    for pos, src_resnet in enumerate(src.resnets):
        dst_resnet = dst.resnets[pos]

        for layer_name in ('norm1', 'conv1', 'norm2', 'conv2'):
            getattr(dst_resnet, layer_name).update_parameters(
                getattr(src_resnet, layer_name))

        # The shortcut convolution is optional on the source resnet.
        if src_resnet.conv_shortcut:
            dst_resnet.conv_shortcut.update_parameters(
                src_resnet.conv_shortcut)

        dst_resnet.time_emb_proj.update_parameters(src_resnet.time_emb_proj)
|
|
|
|
|
|
def update_unet_mid_block_2d_weight(src, dst):
    """Copy the weights of a UNet mid block from ``src`` into ``dst``.

    All resnets are copied first, then all attentions; each ``src`` entry is
    paired with the ``dst`` entry at the same index.

    Args:
        src: Source block exposing ``resnets`` and ``attentions``.
        dst: Destination block with the same two collections, indexable by
            position.
    """
    for pos, src_resnet in enumerate(src.resnets):
        update_resnet_block_weight(src_resnet, dst.resnets[pos])

    for pos, src_attention in enumerate(src.attentions):
        # The transformer helper takes the destination first, source second.
        update_transformer_2d_model_weight(dst.attentions[pos], src_attention)
|
|
|
|
|
|
def update_crossattn_upblock_2d_weight(src, dst):
    """Copy the weights of a cross-attention up block from ``src`` to ``dst``.

    Resnets, attentions and upsamplers are walked in that order. Each ``src``
    entry is paired with the ``dst`` entry at the same index.

    Args:
        src: Source block exposing ``resnets``, ``attentions`` and
            ``upsamplers``.
        dst: Destination block with the same three collections, indexable by
            position.
    """
    for pos, src_resnet in enumerate(src.resnets):
        update_resnet_block_weight(src_resnet, dst.resnets[pos])

    for pos, src_attention in enumerate(src.attentions):
        # The transformer helper takes the destination first, source second.
        update_transformer_2d_model_weight(dst.attentions[pos], src_attention)

    # A falsy ``upsamplers`` (e.g. None or empty) means nothing to copy.
    for pos, src_upsampler in enumerate(src.upsamplers or ()):
        dst.upsamplers[pos].conv.update_parameters(src_upsampler.conv)
|
|
|
|
|
|
def update_resnet_block_weight(src, dst):
    """Copy the layers of one resnet block from ``src`` into ``dst``.

    ``norm1``, ``conv1``, ``norm2``, ``conv2`` and ``time_emb_proj`` are
    always copied, in that order. ``conv_shortcut`` is copied last and only
    when ``src.conv_shortcut`` is truthy.

    Args:
        src: Source resnet block.
        dst: Destination resnet block whose layers implement
            ``update_parameters``.
    """
    for layer_name in ('norm1', 'conv1', 'norm2', 'conv2', 'time_emb_proj'):
        getattr(dst, layer_name).update_parameters(getattr(src, layer_name))

    # Nothing more to do when the source block has no shortcut convolution.
    if not src.conv_shortcut:
        return
    dst.conv_shortcut.update_parameters(src.conv_shortcut)
|
|
|
|
|
|
def update_unetmidblock_2d_weight(src, dst):
    """Copy attention and resnet weights of a mid block from ``src`` to ``dst``.

    For each attention, ``group_norm`` and ``proj_attn`` are copied directly.
    The ``query``/``key``/``value`` weights, and then their biases, are
    concatenated in that order into ``qkv.weight.value`` and
    ``qkv.bias.value`` of the ``dst`` attention at the same index. The
    resnets are copied afterwards.

    Args:
        src: Source block exposing ``attentions`` and ``resnets``; the
            projection tensors must support ``.detach().cpu().numpy()``.
        dst: Destination block with matching collections, indexable by
            position.
    """
    for pos, src_attention in enumerate(src.attentions):
        dst_attention = dst.attentions[pos]
        dst_attention.group_norm.update_parameters(src_attention.group_norm)
        dst_attention.proj_attn.update_parameters(src_attention.proj_attn)

        # Fusion order is fixed: query, then key, then value.
        projections = (src_attention.query, src_attention.key,
                       src_attention.value)
        fused_weight = np.concatenate(
            tuple(proj.weight.detach().cpu().numpy() for proj in projections))
        dst_attention.qkv.weight.value = fused_weight
        fused_bias = np.concatenate(
            tuple(proj.bias.detach().cpu().numpy() for proj in projections))
        dst_attention.qkv.bias.value = fused_bias

    for pos, src_resnet in enumerate(src.resnets):
        update_resnet_block_weight(src_resnet, dst.resnets[pos])
|
|
|
|
|
|
def update_unet_2d_condition_model_weights(src, dst):
    """Copy every weight of a conditional 2D UNet from ``src`` into ``dst``.

    The copy runs in this order: ``conv_in``, ``time_embedding``, the optional
    ``add_embedding`` (only when ``src.config.addition_embed_type`` is
    truthy), the down blocks, the mid block, the up blocks, ``conv_norm_out``
    and finally ``conv_out``.

    Down and up blocks are dispatched on the type names listed in
    ``src.config.down_block_types`` and ``src.config.up_block_types``; the
    block at each index of ``src`` is paired with the block at the same index
    of ``dst``. A block whose type name has no registered updater is skipped
    without raising.

    Args:
        src: Source UNet exposing ``config``, ``down_blocks``, ``mid_block``,
            ``up_blocks`` and the top-level layers named above.
        dst: Destination UNet with the same attributes; its layers implement
            ``update_parameters``.
    """
    dst.conv_in.update_parameters(src.conv_in)
    dst.time_embedding.update_parameters(src.time_embedding)
    if src.config.addition_embed_type:
        dst.add_embedding.update_parameters(src.add_embedding)

    # Dispatch tables replace the if/elif chains and avoid naming the loop
    # variable ``type``, which would shadow the builtin.
    down_block_updaters = {
        'CrossAttnDownBlock2D': update_crossattn_downblock_2d_weight,
        'DownBlock2D': update_downblock_2d_weight,
    }
    for index, block_type in enumerate(src.config.down_block_types):
        updater = down_block_updaters.get(block_type)
        if updater is not None:
            updater(src.down_blocks[index], dst.down_blocks[index])

    update_unet_mid_block_2d_weight(src.mid_block, dst.mid_block)

    up_block_updaters = {
        'CrossAttnUpBlock2D': update_crossattn_upblock_2d_weight,
        'UpBlock2D': update_upblock_2d_weight,
    }
    for index, block_type in enumerate(src.config.up_block_types):
        updater = up_block_updaters.get(block_type)
        if updater is not None:
            updater(src.up_blocks[index], dst.up_blocks[index])

    dst.conv_norm_out.update_parameters(src.conv_norm_out)
    dst.conv_out.update_parameters(src.conv_out)
|
|
|
|
|
|
def load_from_hf_unet(src, dst):
    """Load the weights of ``src`` into ``dst`` and log how long it took.

    An info message is logged before the copy starts. After
    ``update_unet_2d_condition_model_weights`` returns, the elapsed
    wall-clock time is logged formatted as ``HH:MM:SS``.

    Args:
        src: Source UNet (the log message calls it the HF Unet).
        dst: Destination UNet that receives the weights.
    """
    logger.info('Loading weights from HF Unet...')
    started_at = time.time()

    update_unet_2d_condition_model_weights(src, dst)

    elapsed_seconds = time.time() - started_at
    # gmtime turns the second count into a struct_time so strftime can
    # render it as a clock-style duration.
    t = time.strftime('%H:%M:%S', time.gmtime(elapsed_seconds))
    logger.info(f'Weights loaded. Total time: {t}')
|