mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-06 03:01:50 +08:00
637 lines
22 KiB
Python
637 lines
22 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import Tuple, Union
|
|
|
|
from ...functional import concat
|
|
from ...module import Module, ModuleList
|
|
from .attention import AttentionBlock, Transformer2DModel
|
|
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
|
|
|
|
|
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    transformer_layers_per_block=1,
    cross_attention_dim=None,
    downsample_padding=None,
    use_linear_projection=False,
    dtype=None,
):
    """Build a UNet down block from its type name.

    A leading ``"UNetRes"`` prefix on ``down_block_type`` is stripped before
    dispatch, so ``"UNetResDownBlock2D"`` is treated as ``"DownBlock2D"``.

    Args:
        down_block_type: ``"DownBlock2D"`` or ``"CrossAttnDownBlock2D"``,
            optionally prefixed with ``"UNetRes"``.
        num_layers: Number of layers in the block, forwarded as-is.
        in_channels: Input channel count, forwarded as-is.
        out_channels: Output channel count, forwarded as-is.
        temb_channels: Time-embedding channel count, forwarded as-is.
        add_downsample: Whether the block appends a downsampler.
        resnet_eps: Forwarded to the block as ``resnet_eps``.
        resnet_act_fn: Forwarded to the block as ``resnet_act_fn``.
        attn_num_head_channels: Used by ``CrossAttnDownBlock2D`` only.
        transformer_layers_per_block: Used by ``CrossAttnDownBlock2D`` only.
        cross_attention_dim: Required (non-``None``) for
            ``CrossAttnDownBlock2D``; ignored otherwise.
        downsample_padding: Forwarded as-is. NOTE(review): the ``None``
            default overrides the blocks' own default of ``1``, so callers
            appear to be expected to pass it explicitly; confirm.
        use_linear_projection: Used by ``CrossAttnDownBlock2D`` only.
        dtype: Forwarded to the block.

    Returns:
        The constructed ``DownBlock2D`` or ``CrossAttnDownBlock2D``.

    Raises:
        ValueError: If ``cross_attention_dim`` is ``None`` for the
            cross-attention variant, or if the block type is unknown.
    """
    prefix = "UNetRes"
    if down_block_type.startswith(prefix):
        down_block_type = down_block_type[len(prefix):]

    # Arguments shared by both block variants.
    common_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        add_downsample=add_downsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        downsample_padding=downsample_padding,
        dtype=dtype,
    )

    if down_block_type == "DownBlock2D":
        return DownBlock2D(**common_kwargs)

    if down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            # The message must name the block actually being built.
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnDownBlock2D"
            )
        return CrossAttnDownBlock2D(
            **common_kwargs,
            transformer_layers_per_block=transformer_layers_per_block,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            use_linear_projection=use_linear_projection,
        )

    raise ValueError(f"{down_block_type} does not exist.")
|
|
|
|
|
|
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    transformer_layers_per_block=1,
    cross_attention_dim=None,
    use_linear_projection=False,
    dtype=None,
):
    """Construct a UNet up block selected by name.

    ``"UNetRes"`` is removed from the front of ``up_block_type`` when present,
    then the remainder must be ``"UpBlock2D"`` or ``"CrossAttnUpBlock2D"``.

    Args:
        up_block_type: Block name, optionally ``"UNetRes"``-prefixed.
        num_layers, in_channels, out_channels, prev_output_channel,
        temb_channels, add_upsample, resnet_eps, resnet_act_fn, dtype:
            Forwarded unchanged to whichever block is built.
        attn_num_head_channels, transformer_layers_per_block,
        use_linear_projection: Forwarded to ``CrossAttnUpBlock2D`` only.
        cross_attention_dim: Must be non-``None`` for ``CrossAttnUpBlock2D``.

    Returns:
        An ``UpBlock2D`` or ``CrossAttnUpBlock2D`` instance.

    Raises:
        ValueError: Missing ``cross_attention_dim`` for the cross-attention
            block, or an unrecognised block name.
    """
    marker = "UNetRes"
    if up_block_type.startswith(marker):
        up_block_type = up_block_type[len(marker):]

    shared = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        prev_output_channel=prev_output_channel,
        temb_channels=temb_channels,
        add_upsample=add_upsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        dtype=dtype,
    )

    if up_block_type == "UpBlock2D":
        return UpBlock2D(**shared)

    if up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            **shared,
            transformer_layers_per_block=transformer_layers_per_block,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            use_linear_projection=use_linear_projection,
        )

    raise ValueError(f"{up_block_type} does not exist.")
|
|
|
|
|
|
class UpBlock2D(Module):
    """UNet up block: skip-fed resnets followed by an optional upsampler.

    Each resnet receives the running hidden states concatenated along dim 1
    with one tensor taken from the end of ``res_hidden_states_tuple`` (last
    element first, then the one before it, and so on).
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        dtype=None,
    ):
        """Create ``num_layers`` resnets and, if requested, one ``Upsample2D``.

        Input width of resnet ``k`` is the hidden width (``prev_output_channel``
        for ``k == 0``, else ``out_channels``) plus the skip width
        (``in_channels`` for the final resnet, else ``out_channels``). Every
        resnet outputs ``out_channels``.
        """
        super().__init__()

        final_layer = num_layers - 1
        layers = []
        for layer_idx in range(num_layers):
            hidden_width = out_channels if layer_idx else prev_output_channel
            skip_width = in_channels if layer_idx == final_layer else out_channels
            layers.append(
                ResnetBlock2D(
                    in_channels=hidden_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    dtype=dtype,
                ))
        self.resnets = ModuleList(layers)

        if not add_upsample:
            self.upsamplers = None
        else:
            self.upsamplers = ModuleList([
                Upsample2D(out_channels,
                           use_conv=True,
                           out_channels=out_channels,
                           dtype=dtype)
            ])

    def forward(self,
                hidden_states,
                res_hidden_states_tuple,
                temb=None,
                upsample_size=None):
        """Apply the skip-concatenated resnets, then the upsampler if any.

        Args:
            hidden_states: Running activations from the previous stage.
            res_hidden_states_tuple: Skip tensors; consumed from the end,
                one per resnet.
            temb: Time embedding passed to every resnet.
            upsample_size: Passed through to each upsampler.

        Returns:
            The block's output activations.
        """
        for depth, resnet in enumerate(self.resnets, start=1):
            skip = res_hidden_states_tuple[-depth]
            merged = concat([hidden_states, skip], dim=1)
            hidden_states = resnet(merged, temb)

        if self.upsamplers is None:
            return hidden_states
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states
|
|
|
|
|
|
class DownBlock2D(Module):
    """UNet down block: a stack of resnets with an optional downsampler.

    ``forward`` returns the final activations together with a tuple holding
    every resnet output (and, when a downsampler exists, the downsampled
    result). Presumably these are consumed as skip connections by an up
    block; confirm against the caller.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        dtype=None,
    ):
        """Create ``num_layers`` resnets and, optionally, one ``Downsample2D``.

        The first resnet maps ``in_channels -> out_channels``; the rest map
        ``out_channels -> out_channels``.
        """
        super().__init__()

        self.resnets = ModuleList([
            ResnetBlock2D(
                in_channels=in_channels if layer_idx == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                dtype=dtype,
            ) for layer_idx in range(num_layers)
        ])

        if not add_downsample:
            self.downsamplers = None
        else:
            self.downsamplers = ModuleList([
                Downsample2D(out_channels,
                             use_conv=True,
                             out_channels=out_channels,
                             padding=downsample_padding,
                             dtype=dtype)
            ])

    def forward(self, hidden_states, temb=None):
        """Run the resnets (and downsampler), recording each intermediate.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a
            tuple with one entry per resnet plus one for the downsampled
            result when a downsampler is present.
        """
        collected = []
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            # Only the result after all downsamplers is recorded.
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
|
|
|
|
|
|
class UNetMidBlock2D(Module):
    """UNet middle block: ``resnet, (attention, resnet) * num_layers``.

    All sub-blocks keep the channel count at ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        attention_type="default",
        output_scale_factor=1.0,
        dtype=None,
        **kwargs,
    ):
        """Create the leading resnet and ``num_layers`` attention/resnet pairs.

        ``resnet_groups=None`` falls back to ``min(in_channels // 4, 32)``.
        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()

        self.attention_type = attention_type
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def make_resnet():
            # Every resnet in this block is channel-preserving.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                dtype=dtype,
            )

        # There is always at least one resnet; each layer then adds an
        # attention block followed by another resnet.
        resnet_list = [make_resnet()]
        attention_list = []
        for _ in range(num_layers):
            attention_list.append(
                AttentionBlock(
                    in_channels,
                    num_head_channels=attn_num_head_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    num_groups=resnet_groups,
                    dtype=dtype,
                ))
            resnet_list.append(make_resnet())

        self.attentions = ModuleList(attention_list)
        self.resnets = ModuleList(resnet_list)

    def forward(self, hidden_states, temb=None, encoder_states=None):
        """Apply the first resnet, then alternate attention and resnet.

        With ``attention_type == "default"`` the attention blocks get only the
        hidden states; otherwise ``encoder_states`` is passed as a second
        positional argument.
        """
        hidden_states = self.resnets[0](hidden_states, temb)

        pass_encoder = self.attention_type != "default"
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if pass_encoder:
                attended = attn(hidden_states, encoder_states)
            else:
                attended = attn(hidden_states)
            hidden_states = resnet(attended, temb)

        return hidden_states
|
|
|
|
|
|
class CrossAttnUpBlock2D(Module):
    """UNet up block pairing each skip-fed resnet with a ``Transformer2DModel``.

    Layer ``k`` runs resnet ``k`` on the hidden states concatenated (dim 1)
    with a skip tensor taken from the end of ``res_hidden_states_tuple``, then
    runs transformer ``k`` with the encoder states as ``context``. An optional
    ``Upsample2D`` follows the last layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        attention_type="default",
        output_scale_factor=1.0,
        add_upsample=True,
        use_linear_projection: bool = False,
        dtype=None,
    ):
        """Create ``num_layers`` resnet/transformer pairs and an optional upsampler.

        ``transformer_layers_per_block`` is either one int applied to every
        layer or a per-layer sequence indexed by layer number. Each
        transformer uses ``attn_num_head_channels`` heads of size
        ``out_channels // attn_num_head_channels``.
        """
        super().__init__()

        self.attention_type = attention_type
        self.attn_num_head_channels = attn_num_head_channels

        # A single int means "same transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            depths = [transformer_layers_per_block] * num_layers
        else:
            depths = transformer_layers_per_block

        resnet_list, attention_list = [], []
        final_layer = num_layers - 1
        for layer_idx in range(num_layers):
            hidden_width = out_channels if layer_idx else prev_output_channel
            skip_width = in_channels if layer_idx == final_layer else out_channels

            resnet_list.append(
                ResnetBlock2D(
                    in_channels=hidden_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    dtype=dtype,
                ))
            attention_list.append(
                Transformer2DModel(
                    in_channels=out_channels,
                    num_layers=depths[layer_idx],
                    num_attention_heads=attn_num_head_channels,
                    attention_head_dim=out_channels // attn_num_head_channels,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    cross_attention_dim=cross_attention_dim,
                    dtype=dtype,
                ))

        self.attentions = ModuleList(attention_list)
        self.resnets = ModuleList(resnet_list)

        if not add_upsample:
            self.upsamplers = None
        else:
            self.upsamplers = ModuleList([
                Upsample2D(out_channels,
                           use_conv=True,
                           out_channels=out_channels,
                           dtype=dtype)
            ])

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
    ):
        """Run each resnet/transformer pair, then upsample if configured.

        Args:
            hidden_states: Running activations from the previous stage.
            res_hidden_states_tuple: Skip tensors, consumed from the end.
            temb: Time embedding for every resnet.
            encoder_hidden_states: Passed to each transformer as ``context``.
            upsample_size: Passed through to each upsampler.

        Returns:
            The block's output activations.
        """
        layer_pairs = zip(self.resnets, self.attentions)
        for depth, (resnet, attn) in enumerate(layer_pairs, start=1):
            skip = res_hidden_states_tuple[-depth]
            merged = concat([hidden_states, skip], dim=1)
            hidden_states = resnet(merged, temb)
            hidden_states = attn(hidden_states, context=encoder_hidden_states)

        if self.upsamplers is None:
            return hidden_states
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states
|
|
|
|
|
|
class CrossAttnDownBlock2D(Module):
    """UNet down block pairing each resnet with a ``Transformer2DModel``.

    Every resnet/transformer output is recorded, as is the downsampled result
    when a downsampler exists; ``forward`` returns these alongside the final
    activations.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        attention_type="default",
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        use_linear_projection: bool = False,
        dtype=None,
    ):
        """Create ``num_layers`` resnet/transformer pairs and an optional downsampler.

        The first resnet maps ``in_channels -> out_channels``, later ones keep
        ``out_channels``. ``transformer_layers_per_block`` is one int for all
        layers or a per-layer sequence.
        """
        super().__init__()

        self.attention_type = attention_type
        self.attn_num_head_channels = attn_num_head_channels

        # A single int means "same transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            depths = [transformer_layers_per_block] * num_layers
        else:
            depths = transformer_layers_per_block

        resnet_list, attention_list = [], []
        for layer_idx in range(num_layers):
            resnet_list.append(
                ResnetBlock2D(
                    in_channels=out_channels if layer_idx else in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    dtype=dtype,
                ))
            attention_list.append(
                Transformer2DModel(
                    in_channels=out_channels,
                    num_layers=depths[layer_idx],
                    num_attention_heads=attn_num_head_channels,
                    attention_head_dim=out_channels // attn_num_head_channels,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    cross_attention_dim=cross_attention_dim,
                    dtype=dtype,
                ))

        self.attentions = ModuleList(attention_list)
        self.resnets = ModuleList(resnet_list)

        if not add_downsample:
            self.downsamplers = None
        else:
            self.downsamplers = ModuleList([
                Downsample2D(out_channels,
                             use_conv=True,
                             out_channels=out_channels,
                             padding=downsample_padding,
                             dtype=dtype)
            ])

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Run each resnet then transformer, recording every output.

        Returns:
            ``(hidden_states, output_states)``; ``output_states`` holds one
            entry per layer plus the downsampled result when a downsampler
            is present.
        """
        collected = []
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, context=encoder_hidden_states)
            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            # Only the result after all downsamplers is recorded.
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
|
|
|
|
|
|
class UNetMidBlock2DCrossAttn(Module):
    """UNet middle block with cross-attention transformers.

    Layout: ``resnet, (Transformer2DModel, resnet) * num_layers``. All
    sub-blocks keep the channel count at ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        attention_type="default",
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        use_linear_projection: bool = False,
        dtype=None,
        **kwargs,
    ):
        """Create the leading resnet and ``num_layers`` transformer/resnet pairs.

        Args:
            in_channels: Channel count used throughout the block.
            temb_channels: Time-embedding channel count for the resnets.
            num_layers: Number of transformer/resnet pairs after the first
                resnet.
            transformer_layers_per_block: One int for every transformer, or
                a per-layer sequence indexed by layer number.
            resnet_groups: Group count; ``None`` falls back to
                ``min(in_channels // 4, 32)``.
            attn_num_head_channels: Number of attention heads; each head has
                size ``in_channels // attn_num_head_channels``.
            attention_type: Stored on the instance; not read by ``forward``.
            cross_attention_dim: Forwarded to every ``Transformer2DModel``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.attention_type = attention_type
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(
            in_channels // 4, 32)

        # support for variable transformer layers per block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block
                                            ] * num_layers

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(in_channels=in_channels,
                          out_channels=in_channels,
                          temb_channels=temb_channels,
                          eps=resnet_eps,
                          groups=resnet_groups,
                          dropout=dropout,
                          time_embedding_norm=resnet_time_scale_shift,
                          non_linearity=resnet_act_fn,
                          output_scale_factor=output_scale_factor,
                          pre_norm=resnet_pre_norm,
                          dtype=dtype)
        ]
        attentions = []

        for i in range(num_layers):
            attentions.append(
                Transformer2DModel(in_channels=in_channels,
                                   num_layers=transformer_layers_per_block[i],
                                   num_attention_heads=attn_num_head_channels,
                                   attention_head_dim=in_channels //
                                   attn_num_head_channels,
                                   norm_num_groups=resnet_groups,
                                   use_linear_projection=use_linear_projection,
                                   cross_attention_dim=cross_attention_dim,
                                   dtype=dtype))
            resnets.append(
                ResnetBlock2D(in_channels=in_channels,
                              out_channels=in_channels,
                              temb_channels=temb_channels,
                              eps=resnet_eps,
                              groups=resnet_groups,
                              dropout=dropout,
                              time_embedding_norm=resnet_time_scale_shift,
                              non_linearity=resnet_act_fn,
                              output_scale_factor=output_scale_factor,
                              pre_norm=resnet_pre_norm,
                              dtype=dtype))

        self.attentions = ModuleList(attentions)
        self.resnets = ModuleList(resnets)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Apply the first resnet, then alternate transformer and resnet.

        Args:
            hidden_states: Input activations.
            temb: Time embedding passed to every resnet.
            encoder_hidden_states: Passed to every transformer as ``context``.

        Returns:
            The block's output activations.
        """
        hidden_states = self.resnets[0](hidden_states, temb)

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # Pass the encoder states by keyword, matching how the CrossAttn
            # up/down blocks call the same Transformer2DModel, so this call
            # does not silently depend on positional parameter order.
            hidden_states = attn(hidden_states, context=encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
|