# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import collections
import functools
import json
import math
import os
import re
from collections import OrderedDict
from typing import Optional
import numpy as np
import tensorrt as trt
import torch
from tqdm import tqdm
import tensorrt_llm
from tensorrt_llm._common import default_net
from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_str
from tensorrt_llm.functional import (ACT2FN, AttentionMaskType, LayerNormType,
PositionEmbeddingType, Tensor,
constant_to_tensor_)
from tensorrt_llm.layers import (ColumnLinear, Conv3d, LayerNorm, Linear,
RowLinear)
from tensorrt_llm.layers.attention import (Attention, AttentionParams,
BertAttention, KeyValueCacheParams,
bert_attention, layernorm_map)
from tensorrt_llm.layers.normalization import RmsNorm
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.generation_mixin import GenerationMixin
from tensorrt_llm.models.model_weights_loader import (ModelWeightsFormat,
ModelWeightsLoader)
from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel
from tensorrt_llm.module import Module, ModuleList
from tensorrt_llm.parameter import Parameter
from tensorrt_llm.plugin import current_all_reduce_helper
from tensorrt_llm.quantization import QuantMode
from ...functional import (allgather, arange, cast, chunk, concat, constant,
cos, div, einsum, exp, expand, expand_dims,
expand_mask, masked_select, matmul, meshgrid2d, pad,
permute, pow, rearrange, repeat, repeat_interleave,
rms_norm, shape, sin, slice, softmax, split, squeeze,
stack, sum, unsqueeze, where)
from .config import STDiTModelConfig
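# This module defines the STDiT3 spatial-temporal diffusion-transformer backbone as a
# TensorRT-LLM `PretrainedModel`, together with the embedders, attention blocks and
# weight-loading helpers needed to build a TensorRT engine for it.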
# [TODO] For now, we only support static shapes; shapes may contain `-1` when inputs use dynamic shapes.
USE_STATIC_SHAPE = True
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple([x] * n)
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
# [TODO] make constant `1` compatible with `scale`
def t2i_modulate(x, shift, scale):
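    """AdaLN-style modulation: scale and shift the (normalized) activations, x * (1 + scale) + shift."""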
return x * (1.0 + scale) + shift
class ModuleSequential(ModuleList):
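    """Minimal sequential container: the first module receives the call arguments, every
    subsequent module consumes the previous module's output."""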
def __init__(self, modules) -> None:
super(ModuleSequential, self).__init__(modules=modules)
def forward(self, *args, **kwargs):
module = self.__getitem__(0)
outputs = module(*args, **kwargs)
for idx in range(1, len(self._modules)):
module = self.__getitem__(idx)
outputs = module(outputs)
return outputs
class Activation(Module):
def __init__(self, act_fn='silu'):
super().__init__()
self.act_fn = act_fn
def forward(self, input: Tensor):
return ACT2FN[self.act_fn](input)
class RotaryEmbedder(Module):
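    """Rotary position embedding (RoPE) helper that precomputes inverse frequencies and
    applies the rotation to query/key tensors, with optional caching of the angles."""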
def __init__(self,
dim,
theta=10000,
interpolate_factor=1.,
theta_rescale_factor=1.,
seq_before_head_dim=False,
cache_if_possible=True,
use_xpos=False,
dtype=None,
mapping=Mapping(),
quant_mode=QuantMode(0)):
super().__init__()
theta *= theta_rescale_factor**(dim / (dim - 2))
freqs = 1. / (theta
**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
self.freqs = Parameter(freqs, dtype=dtype)
self.cached_freqs = None
self.seq_before_head_dim = seq_before_head_dim
self.default_seq_dim = -3 if seq_before_head_dim else -2
self.scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
assert interpolate_factor >= 1.
self.interpolate_factor = interpolate_factor
self.cache_if_possible = cache_if_possible
self.use_xpos = use_xpos
self.dtype = dtype
self.mapping = mapping
self.quant_mode = quant_mode
def get_freqs(self,
t: Tensor,
seq_len: Optional[int] = None,
offset: int = 0):
should_cache = self.cache_if_possible and seq_len is not None
if should_cache and isinstance(self.cached_freqs, Tensor):
if (offset + seq_len) <= self.cached_freqs.shape[0]:
return slice(self.cached_freqs,
starts=[offset] +
[0] * len(self.cached_freqs.shape[1:]),
sizes=[seq_len, *self.cached_freqs.shape[1:]])
freqs = self.freqs.value
freqs = unsqueeze(t, axis=-1) * unsqueeze(freqs, axis=0)
freqs = repeat_interleave(freqs, repeats=2, dim=(freqs.ndim() - 1))
if should_cache:
self.cached_freqs = freqs
return freqs
def get_seq_pos(self, seq_len: int, dtype: trt.DataType, offset: int = 0):
return (arange(start=0, end=seq_len, dtype=trt_dtype_to_str(dtype)) +
offset) / self.interpolate_factor
def rotate_half(self, x: Tensor):
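        """Rotate channel pairs (x1, x2) -> (-x2, x1), the half-rotation used by RoPE."""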
x = x.view([*x.shape[:-1], x.shape[-1] // 2, 2])
x1, x2 = x.unbind(x.ndim() - 1)
x = stack([-1 * x2, x1], dim=-1)
x = x.view([*x.shape[:-2], x.shape[-2] * x.shape[-1]])
return x
def apply_rotary_emb(self,
freqs: Tensor,
t: Tensor,
start_index: int = 0,
                         scale: float = 1.,
seq_dim: int = -2):
if t.ndim() == 3:
seq_len = t.shape[seq_dim]
            # freqs = freqs[-seq_len:]
            freqs = slice(freqs,
                          starts=[freqs.shape[0] - seq_len],
                          sizes=[seq_len])
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size ' + \
            f'to rotate in all the positions {rot_dim}'
t_left = slice(t,
starts=[0] * t.ndim(),
sizes=[*t.shape[:-1], start_index])
t_right = slice(t,
starts=[0] * (t.ndim() - 1) + [end_index],
sizes=[*t.shape[:-1], t.shape[-1] - end_index])
t = (t * cos(freqs) * scale) + (self.rotate_half(t) * sin(freqs) *
scale)
return concat([t_left, t, t_right], dim=-1)
def rotate_queries_or_keys(self,
t: Tensor,
seq_dim: Optional[int] = None,
offset: int = 0,
freq_seq_len: Optional[int] = None):
seq_dim = self.default_seq_dim if seq_dim is None else seq_dim
assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method ' + \
'instead and pass in both queries and keys, for ' + \
'length extrapolatable rotary embeddings'
seq_len = t.shape[seq_dim]
if freq_seq_len is not None:
assert freq_seq_len >= seq_len
seq_len = freq_seq_len
freqs = self.get_freqs(self.get_seq_pos(seq_len,
dtype=t.dtype,
offset=offset),
seq_len=seq_len,
offset=offset)
if seq_dim == -3:
freqs = rearrange(freqs, 'n d -> n 1 d')
rope_output = self.apply_rotary_emb(freqs, t, seq_dim=seq_dim)
return rope_output
class STDiTRmsNorm(RmsNorm):
def __init__(self,
normalized_shape,
num_groups=1,
eps=1e-06,
elementwise_affine=True,
dtype=None):
super().__init__(normalized_shape, num_groups, eps, elementwise_affine,
dtype)
def forward(self, hidden_states):
weight = None if self.weight is None else self.weight.value
return rms_norm(input=hidden_states,
normalized_shape=self.normalized_shape,
num_groups=self.num_groups,
weight=weight,
eps=self.eps)
class STDiTAttention(BertAttention):
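    """Self-attention over spatial or temporal tokens, based on BertAttention, with
    optional per-head q/k RmsNorm and an optional rotary embedding applied to q/k."""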
def __init__(self,
hidden_size,
num_attention_heads,
qk_layernorm=True,
layernorm_type=LayerNormType.RmsNorm,
layernorm_eps=1e-06,
bias=True,
rotary_embedding_func=None,
dtype=None,
tp_group=None,
tp_size=1,
tp_rank=0,
cp_group=None,
cp_size=1,
quant_mode: QuantMode = QuantMode(0)):
assert hidden_size % num_attention_heads == 0, "hidden_size should be divisible by num_attention_heads"
super().__init__(hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
bias=bias,
dtype=dtype,
tp_group=tp_group,
tp_size=tp_size,
tp_rank=tp_rank,
cp_group=cp_group,
cp_size=cp_size,
quant_mode=quant_mode)
self.qk_layernorm = qk_layernorm
if self.qk_layernorm:
ln_type = layernorm_map[layernorm_type]
self.q_layernorm = ln_type(self.attention_head_size,
eps=layernorm_eps,
dtype=dtype)
self.k_layernorm = ln_type(self.attention_head_size,
eps=layernorm_eps,
dtype=dtype)
self.rotary_embedding_func = rotary_embedding_func
def forward(self,
hidden_states: Tensor,
attention_mask=None,
input_lengths=None,
max_input_length: int = None):
assert isinstance(hidden_states, Tensor)
B = shape(hidden_states, 0)
N = shape(hidden_states, 1)
C = shape(hidden_states, 2)
input_lengths = expand(unsqueeze(N, 0).cast('int32'), unsqueeze(B, 0))
assert (self.qkv is not None)
qkv = self.qkv(hidden_states)
kv_size = self.attention_head_size * self.num_attention_kv_heads
query, key, value = split(
qkv, [self.attention_hidden_size, kv_size, kv_size], dim=2)
query = query.view(
concat([B, N, self.num_attention_heads,
self.attention_head_size])).permute(dims=[0, 2, 1, 3])
key = key.view(
concat(
[B, N, self.num_attention_kv_heads,
self.attention_head_size])).permute(dims=[0, 2, 1, 3])
if self.qk_layernorm:
query = self.q_layernorm(query)
key = self.k_layernorm(key)
if self.rotary_embedding_func is not None:
query = self.rotary_embedding_func(query)
key = self.rotary_embedding_func(key)
# TODO deal with qkv
query = query.permute(dims=[0, 2, 1, 3]).view(
concat([B, N, self.attention_hidden_size]))
key = key.permute(dims=[0, 2, 1, 3]).view(concat([B, N, kv_size]))
qkv = concat([query, key, value], dim=2)
if default_net().plugin_config.bert_attention_plugin:
# TRT plugin mode
assert input_lengths is not None
assert self.cp_size == 1
if default_net().plugin_config.remove_input_padding:
qkv = qkv.view(
concat([-1, self.attention_hidden_size + 2 * kv_size]))
max_input_length = constant(
np.zeros([
max_input_length,
], dtype=np.int32))
context = bert_attention(qkv,
input_lengths,
self.num_attention_heads,
self.attention_head_size,
q_scaling=self.q_scaling,
max_distance=self.max_distance,
max_input_length=max_input_length)
else:
# plain TRT mode
def transpose_for_scores(x):
new_x_shape = concat([
shape(x, 0),
shape(x, 1), self.num_attention_heads,
self.attention_head_size
])
return x.view(new_x_shape).permute([0, 2, 1, 3])
kv_size = self.attention_head_size * self.num_attention_kv_heads
query, key, value = split(
qkv, [self.attention_hidden_size, kv_size, kv_size], dim=2)
if self.cp_size > 1 and self.cp_group is not None:
key = allgather(key, self.cp_group, gather_dim=1)
value = allgather(value, self.cp_group, gather_dim=1)
query = transpose_for_scores(query)
key = transpose_for_scores(key)
value = transpose_for_scores(value)
key = key.permute([0, 1, 3, 2])
attention_scores = matmul(query, key, use_fp32_acc=False)
attention_scores = attention_scores / (self.q_scaling *
self.norm_factor)
if attention_mask is not None:
attention_mask = expand_mask(attention_mask, shape(query, 2))
attention_mask = cast(attention_mask, attention_scores.dtype)
attention_scores = attention_scores + attention_mask
attention_probs = softmax(attention_scores, dim=-1)
context = matmul(attention_probs, value,
use_fp32_acc=False).permute([0, 2, 1, 3])
context = context.view(
concat([
shape(context, 0),
shape(context, 1), self.attention_hidden_size
]))
context = self.dense(context)
context = context.view(concat([B, N, C]))
return context
class STDiTCrossAttention(Attention):
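    """Cross-attention between video tokens and caption tokens, based on the TRT-LLM
    Attention layer with cross_attention=True."""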
def __init__(self,
*,
local_layer_idx,
hidden_size,
num_attention_heads,
attention_mask_type=AttentionMaskType.causal,
qkv_bias=True,
dense_bias=True,
position_embedding_type=PositionEmbeddingType.learned_absolute,
dtype=None,
tp_group=None,
tp_size=1,
tp_rank=0,
cp_group=[0],
cp_size=1,
cp_rank=0,
quant_mode: QuantMode = QuantMode(0)):
assert hidden_size % num_attention_heads == 0, "hidden_size should be divisible by num_attention_heads"
super().__init__(local_layer_idx=local_layer_idx,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_mask_type=attention_mask_type,
bias=qkv_bias,
dense_bias=dense_bias,
cross_attention=True,
position_embedding_type=position_embedding_type,
dtype=dtype,
tp_group=tp_group,
tp_size=tp_size,
tp_rank=tp_rank,
cp_group=cp_group,
cp_size=cp_size,
cp_rank=cp_rank,
quant_mode=quant_mode)
def forward(self,
hidden_states: Tensor,
encoder_output: Tensor,
use_cache=False,
attention_params: Optional[AttentionParams] = None,
kv_cache_params: Optional[KeyValueCacheParams] = None):
bs = shape(encoder_output, 0)
encoder_input_length = shape(encoder_output, 1)
encoder_hidden_size = shape(encoder_output, 2)
encoder_output = encoder_output.view(
concat([bs * 2, encoder_input_length // 2, encoder_hidden_size]))
        B = shape(hidden_states, 0)
        N = shape(hidden_states, 1)
        C = shape(hidden_states, 2)
        if default_net().plugin_config.remove_input_padding:
            hidden_states = hidden_states.view(concat([B * N, C]))
encoder_output = encoder_output.view(
concat([-1, encoder_hidden_size]))
context = super().forward(hidden_states=hidden_states,
encoder_output=encoder_output,
use_cache=use_cache,
attention_params=attention_params,
kv_cache_params=kv_cache_params)
context = context.view(concat([B, -1, C]))
return context
class T2IFinalLayer(Module):
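    """Final DiT layer: modulated LayerNorm (shift/scale from the timestep embedding plus a
    learned table) followed by a linear projection to num_patch * out_channels."""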
def __init__(self,
hidden_size,
num_patch,
out_channels,
d_t=None,
d_s=None,
dtype=None,
mapping=Mapping(),
quant_mode=QuantMode(0)):
super().__init__()
self.norm_final = LayerNorm(hidden_size,
elementwise_affine=False,
eps=1e-6,
dtype=dtype)
self.linear = Linear(hidden_size,
num_patch * out_channels,
bias=True,
dtype=dtype)
self.scale_shift_table = Parameter(torch.randn(2, hidden_size) /
hidden_size**0.5,
dtype=dtype)
self.out_channels = out_channels
self.d_t = d_t
self.d_s = d_s
self.dtype = dtype
self.mapping = mapping
self.quant_mode = quant_mode
def t_mask_select(self, x_mask, x, masked_x, T: int, S: int):
# x: [B, (T, S), C]
        # masked_x: [B, (T, S), C]
# x_mask: [B, T]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
x = where(expand_dims(x_mask,
[x_mask.ndim(), x_mask.ndim() + 1]), x, masked_x)
x = rearrange(x, "B T S C -> B (T S) C")
return x
def forward(self,
x,
t,
x_mask=None,
t0=None,
T: Optional[int] = None,
S: Optional[int] = None):
if T is None:
T = self.d_t
if S is None:
S = self.d_s
shift, scale = chunk(expand_dims(self.scale_shift_table.value, 0) +
expand_dims(t, 1),
chunks=2,
dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
if x_mask is not None:
shift_zero, scale_zero = chunk(
expand_dims(self.scale_shift_table.value, 0) +
expand_dims(t0, 1),
chunks=2,
dim=1)
x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
x = self.t_mask_select(x_mask, x, x_zero, T, S)
x = self.linear(x)
self.register_network_output('output', x)
return x
class PositionEmbedding2D(Module):
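    """Factorized 2D sinusoidal position embedding over the (h, w) patch grid, cached per
    (dtype, h, w, scale, base_size)."""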
def __init__(self,
dim: int,
dtype=None,
mapping=Mapping(),
quant_mode=QuantMode(0)):
super().__init__()
self.dim = dim
assert dim % 4 == 0, "dim must be divisible by 4"
half_dim = dim // 2
self.inv_freq = Parameter(
1.0 / (10000**(torch.arange(0, half_dim, 2).float() / half_dim)),
is_buffer=True,
dtype=dtype)
self.dtype = dtype
self.mapping = mapping
self.quant_mode = quant_mode
def _get_sin_cos_emb(self, t):
out = einsum("i,d->id", [t, self.inv_freq.value])
emb_cos = cos(out)
emb_sin = sin(out)
return concat([emb_sin, emb_cos], dim=-1)
@functools.lru_cache(maxsize=512)
def _get_cached_emb(
self,
dtype,
h: int,
w: int,
scale: Tensor,
base_size: Optional[int] = None,
):
grid_h = div(arange(0, h, 'float32'), scale.cast('float32'))
grid_w = div(arange(0, w, 'float32'), scale.cast('float32'))
if base_size is not None:
grid_h *= float(base_size) / h
grid_w *= float(base_size) / w
grid_h, grid_w = meshgrid2d(grid_w, grid_h) # here w goes first
grid_h = permute(grid_h, [1, 0]).flatten()
grid_w = permute(grid_w, [1, 0]).flatten()
emb_h = self._get_sin_cos_emb(grid_h)
emb_w = self._get_sin_cos_emb(grid_w)
return unsqueeze(concat([emb_h, emb_w], dim=-1), 0).cast(dtype)
def forward(self, x, h: int, w: int, scale: Tensor, base_size=None):
pos_embedding = self._get_cached_emb(x.dtype, h, w, scale, base_size)
self.register_network_output('output', pos_embedding)
return pos_embedding
class TimestepEmbedder(Module):
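    """Embeds scalar diffusion timesteps: sinusoidal frequency features followed by a
    two-layer MLP with SiLU activation."""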
def __init__(self,
hidden_size,
frequency_embedding_size=256,
dtype=None,
mapping=Mapping(),
quant_mode=QuantMode(0)):
super().__init__()
self.mlp = ModuleSequential([
Linear(frequency_embedding_size,
hidden_size,
bias=True,
dtype=dtype),
Activation('silu'),
Linear(hidden_size, hidden_size, bias=True, dtype=dtype)
])
self.frequency_embedding_size = frequency_embedding_size
self.dtype = dtype
self.mapping = mapping
self.quant_mode = quant_mode
@staticmethod
def timestep_embedding(t, dim, max_period=10000, dtype=None):
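        """Create sinusoidal timestep embeddings: dim // 2 cosine and dim // 2 sine features."""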
half = dim // 2
freqs = exp(
-math.log(max_period) *
arange(start=0, end=half, dtype=trt_dtype_to_str(trt.float32)) /
constant(np.array([half], dtype=np.float32)))
args = unsqueeze(t, -1).cast(trt.float32) * unsqueeze(freqs, 0)
embedding = concat([cos(args), sin(args)], dim=-1)
if dtype is not None:
embedding = embedding.cast(dtype)
if dim % 2:
embedding = pad(embedding, (0, 0, 0, 1))
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(
t, self.frequency_embedding_size).cast(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class SizeEmbedder(TimestepEmbedder):
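    """TimestepEmbedder variant for scalar conditioning signals (e.g. fps): each scalar of a
    [B, d] input is embedded separately and the results are concatenated per sample."""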
def __init__(self,
hidden_size,
frequency_embedding_size=256,
dtype=str_dtype_to_trt("float16"),
mapping=Mapping(),
quant_mode=QuantMode(0)):
super().__init__(hidden_size=hidden_size,
frequency_embedding_size=frequency_embedding_size,
dtype=dtype,
mapping=mapping,
quant_mode=quant_mode)
self.mlp = ModuleSequential([
Linear(frequency_embedding_size,
hidden_size,
bias=True,
dtype=dtype),
Activation('silu'),
Linear(hidden_size, hidden_size, bias=True, dtype=dtype)
])
self.outdim = hidden_size
def forward(self, s, bs: int):
if not USE_STATIC_SHAPE:
raise NotImplementedError('Only static shape is supported')
if s.ndim() == 1:
s = unsqueeze(s, 1)
assert s.ndim() == 2
if s.shape[0] != bs:
s = repeat(s, [bs // s.shape[0], 1])
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = self.timestep_embedding(s, self.frequency_embedding_size).cast(
self.dtype)
s_emb = self.mlp(s_freq)
s_emb = rearrange(s_emb,
"(b d) d2 -> b (d d2)",
b=b,
d=dims,
d2=self.outdim)
self.register_network_output('output', s_emb)
return s_emb
class CaptionMLP(Module):
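    """Two-layer MLP (fc1 -> activation -> fc2) with tensor-parallel Column/Row linear layers;
    gated activations (swiglu/gegelu) double the fc1 output width."""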
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer="gelu",
bias=True,
dtype=None,
mapping=Mapping(),
quant_mode=QuantMode(0),
inner_layernorm=False,
eps=1e-05,
):
super().__init__()
hidden_act = act_layer
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
if hidden_act not in ACT2FN:
raise ValueError(
'unsupported activation function: {}'.format(hidden_act))
fc_output_size = 2 * hidden_features if hidden_act in [
'swiglu', 'gegelu'
] else hidden_features
self.inner_layernorm = LayerNorm(hidden_features, dtype=dtype,
eps=eps) if inner_layernorm else None
self.fc1 = ColumnLinear(in_features,
fc_output_size,
bias=bias[0],
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=False)
self.fc2 = RowLinear(hidden_features,
out_features,
bias=bias[1],
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size)
self.in_features = in_features
self.hidden_features = hidden_features
self.out_features = out_features
self.hidden_act = hidden_act
self.dtype = dtype
self.bias = bias
self.mapping = mapping
self.quant_mode = quant_mode
self.eps = eps
def forward(self, hidden_states, gegelu_limit=None):
inter = self.fc1(hidden_states)
if self.hidden_act == 'gegelu':
inter = ACT2FN[self.hidden_act](inter, gegelu_limit)
else:
inter = ACT2FN[self.hidden_act](inter)
if self.inner_layernorm is not None:
inter = self.inner_layernorm(inter)
output = self.fc2(inter)
return output
class CaptionEmbedder(Module):
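    """Projects caption encoder features to the model hidden size; token_drop replaces
    captions with a learned null embedding (classifier-free guidance)."""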
def __init__(self,
in_channels,
hidden_size,
uncond_prob,
act_layer='gelu',
token_num=120,
dtype=None,
mapping=Mapping(),
quant_mode=QuantMode(0)):
super().__init__()
self.y_proj = CaptionMLP(
in_features=in_channels,
hidden_features=hidden_size,
out_features=hidden_size,
act_layer=act_layer,
mapping=mapping,
dtype=dtype,
)
self.y_embedding = Parameter(torch.randn(token_num, in_channels) /
in_channels**0.5,
dtype=dtype)
self.uncond_prob = uncond_prob
self.dtype = dtype
self.mapping = mapping
self.quant_mode = quant_mode
def token_drop(self, caption, force_drop_ids=None):
if not USE_STATIC_SHAPE:
raise NotImplementedError('Only static shape is supported')
        assert (isinstance(force_drop_ids, torch.Tensor)
                or isinstance(force_drop_ids, np.ndarray))
if force_drop_ids is None:
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = torch.Tensor(force_drop_ids) == 1
drop_ids = constant(drop_ids.cpu().numpy())
caption = where(expand_dims(drop_ids, [1, 2, 3]),
self.y_embedding.value, caption)
return caption
def forward(self, caption, force_drop_ids=None):
if force_drop_ids is not None:
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
self.register_network_output('output', caption)
return caption
class PatchEmbed3D(Module):
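    """Video-to-patch embedding: pads T/H/W up to multiples of patch_size, applies a Conv3d
    with kernel and stride equal to patch_size, and optionally flattens to [B, N, C]."""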
def __init__(self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
norm_layer=None,
flatten=True,
dtype=None,
mapping=Mapping(),
quant_mode=QuantMode(0)):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = Conv3d(in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
dtype=dtype)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
self.dtype = dtype
self.mapping = mapping
self.quant_mode = quant_mode
def forward(self, x):
if not USE_STATIC_SHAPE:
raise NotImplementedError('Only static shape is supported')
_, _, D, H, W = x.shape
if W % self.patch_size[2] != 0:
x = pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = pad(
x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D = shape(x, 2)
Wh = shape(x, 3)
Ww = shape(x, 4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view([-1, self.embed_dim, D, Wh, Ww])
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
self.register_network_output('output', x)
return x
class STDiT3Block(Module):
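    """One STDiT3 transformer block: AdaLN-style modulation, self-attention over spatial or
    temporal tokens (selected by `temporal`), cross-attention to caption tokens, and a
    modulated MLP, with optional per-frame masking via x_mask."""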
def __init__(self,
hidden_size,
num_heads,
mlp_ratio=4.0,
rope=None,
qk_norm=False,
temporal=False,
dtype=None,
local_layer_idx=0,
mapping=Mapping(),
quant_mode=QuantMode(0)):
super().__init__()
self.temporal = temporal
self.hidden_size = hidden_size
attn_cls = STDiTAttention
mha_cls = STDiTCrossAttention
self.norm1 = LayerNorm(hidden_size,
eps=1e-6,
elementwise_affine=False,
dtype=dtype)
self.attn = attn_cls(hidden_size=hidden_size,
num_attention_heads=num_heads,
qk_layernorm=qk_norm,
bias=True,
rotary_embedding_func=rope,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
tp_rank=mapping.tp_rank,
cp_group=mapping.cp_group,
cp_size=mapping.cp_size,
quant_mode=quant_mode)
self.cross_attn = mha_cls(
local_layer_idx=local_layer_idx,
hidden_size=hidden_size,
num_attention_heads=num_heads,
attention_mask_type=tensorrt_llm.layers.AttentionMaskType.causal,
qkv_bias=True,
dense_bias=True,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
tp_rank=mapping.tp_rank,
cp_group=mapping.cp_group,
cp_size=mapping.cp_size,
quant_mode=quant_mode)
self.norm2 = LayerNorm(hidden_size,
eps=1e-6,
elementwise_affine=False,
dtype=dtype)
self.mlp = CaptionMLP(
in_features=hidden_size,
hidden_features=int(hidden_size * mlp_ratio),
act_layer='gelu',
mapping=mapping,
dtype=dtype,
)
self.scale_shift_table = Parameter(torch.randn(6, hidden_size) /
hidden_size**0.5,
dtype=dtype)
self.dtype = dtype
self.mapping = mapping
self.quant_mode = quant_mode
def t_mask_select(self, x_mask, x, masked_x, T: int, S: int):
# x: [B, (T, S), C]
        # masked_x: [B, (T, S), C]
# x_mask: [B, T]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
x = where(expand_dims(x_mask, [2, 3]), x, masked_x)
x = rearrange(x, "B T S C -> B (T S) C")
return x
def forward(
self,
x,
y,
t,
x_mask=None, # temporal mask
t0=None, # t with timestamp=0
T: Optional[int] = None, # number of frames
S: Optional[int] = None, # number of pixel patches
attention_params: Optional[AttentionParams] = None,
kv_cache_params: Optional[KeyValueCacheParams] = None,
):
if not USE_STATIC_SHAPE:
raise NotImplementedError('Only static shape is supported')
# prepare modulate parameters
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(
expand_dims(self.scale_shift_table.value, 0) + t.view([B, 6, -1]),
chunks=6,
dim=1)
if x_mask is not None:
shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = chunk(
expand_dims(self.scale_shift_table.value, 0) +
t0.view([B, 6, -1]),
chunks=6,
dim=1)
# modulate (attention)
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
if x_mask is not None:
x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero,
scale_msa_zero)
x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)
# attention
if self.temporal:
x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
x_m = self.attn(
x_m, max_input_length=attention_params.max_context_length)
x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S)
else:
x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
x_m = self.attn(
x_m, max_input_length=attention_params.max_context_length)
x_m = rearrange(x_m, "(B T) S C -> B (T S) C", T=T, S=S)
# modulate (attention)
x_m_s = gate_msa * x_m
if x_mask is not None:
x_m_s_zero = gate_msa_zero * x_m
x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)
# residual
x = x + x_m_s
# cross attention
cattn = self.cross_attn(
hidden_states=x,
encoder_output=y,
attention_params=AttentionParams(
sequence_length=attention_params.sequence_length,
context_lengths=attention_params.context_lengths,
host_context_lengths=attention_params.host_context_lengths,
max_context_length=attention_params.max_context_length,
host_request_types=attention_params.host_request_types,
encoder_input_lengths=attention_params.encoder_input_lengths,
encoder_max_input_length=attention_params.
encoder_max_input_length,
host_runtime_perf_knobs=attention_params.
host_runtime_perf_knobs,
host_context_progress=attention_params.host_context_progress,
),
kv_cache_params=KeyValueCacheParams(
past_key_value=kv_cache_params.past_key_value,
host_past_key_value_lengths=kv_cache_params.
host_past_key_value_lengths,
host_max_attention_window_sizes=kv_cache_params.
host_max_attention_window_sizes,
host_sink_token_length=kv_cache_params.host_sink_token_length,
cache_indirection=kv_cache_params.cache_indirection,
kv_cache_block_offsets=None,
host_kv_cache_block_offsets=None,
host_kv_cache_pool_pointers=None,
host_kv_cache_pool_mapping=None,
cross_kv_cache_block_offsets=None,
host_cross_kv_cache_block_offsets=None,
host_cross_kv_cache_pool_pointers=None,
host_cross_kv_cache_pool_mapping=None,
))
x = x + cattn
# modulate (MLP)
x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
if x_mask is not None:
x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero,
scale_mlp_zero)
x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)
# MLP
x_m = self.mlp(x_m)
# modulate (MLP)
x_m_s = gate_mlp * x_m
if x_mask is not None:
x_m_s_zero = gate_mlp_zero * x_m
x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)
# residual
x = x + x_m_s
return x
class STDiT3Model(PretrainedModel):
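    """STDiT3 backbone: embeds latent video patches, timestep, fps and captions, alternates
    spatial and temporal STDiT3Blocks, and projects back to the latent space through
    T2IFinalLayer."""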
def __init__(self, config: STDiTModelConfig):
self.check_config(config)
super().__init__(config)
self.learn_sigma = config.learn_sigma
self.in_channels = config.in_channels
self.out_channels = config.in_channels * 2 if config.learn_sigma else config.in_channels
self.caption_channels = config.caption_channels
self.depth = config.num_hidden_layers
self.mlp_ratio = config.mlp_ratio
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.model_max_length = config.model_max_length
self.latent_size = config.latent_size
self.input_sq_size = config.input_sq_size
self.patch_size = config.stdit_patch_size
self.class_dropout_prob = config.class_dropout_prob
self.qk_norm = config.qk_norm
self.dtype = config.dtype
self.mapping = config.mapping
self.pos_embed = PositionEmbedding2D(self.hidden_size, dtype=self.dtype)
self.rope = RotaryEmbedder(dim=self.hidden_size // self.num_heads,
dtype=self.dtype)
self.x_embedder = PatchEmbed3D(self.patch_size,
self.in_channels,
self.hidden_size,
dtype=self.dtype)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=self.dtype)
self.fps_embedder = SizeEmbedder(self.hidden_size, dtype=self.dtype)
self.t_block = ModuleSequential([
Activation('silu'),
Linear(self.hidden_size,
6 * self.hidden_size,
bias=True,
dtype=self.dtype)
])
self.y_embedder = CaptionEmbedder(in_channels=self.caption_channels,
hidden_size=self.hidden_size,
uncond_prob=self.class_dropout_prob,
act_layer='gelu',
token_num=self.model_max_length,
dtype=self.dtype)
self.spatial_blocks = ModuleList([
STDiT3Block(hidden_size=self.hidden_size,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
qk_norm=self.qk_norm,
dtype=self.dtype,
local_layer_idx=idx,
mapping=self.mapping) for idx in range(self.depth)
])
self.temporal_blocks = ModuleList([
STDiT3Block(hidden_size=self.hidden_size,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
qk_norm=self.qk_norm,
temporal=True,
rope=self.rope.rotate_queries_or_keys,
dtype=self.dtype,
local_layer_idx=idx,
mapping=self.mapping) for idx in range(self.depth)
])
self.final_layer = T2IFinalLayer(self.hidden_size,
np.prod(self.patch_size),
self.out_channels,
dtype=self.dtype,
mapping=self.mapping)
def check_config(self, config: PretrainedConfig):
config.set_if_not_exist('caption_channels', 4096)
config.set_if_not_exist('num_hidden_layers', 28)
config.set_if_not_exist('latent_size', [30, 45, 80])
config.set_if_not_exist('hidden_size', 1152)
config.set_if_not_exist('stdit_patch_size', [1, 2, 2])
config.set_if_not_exist('in_channels', 4)
config.set_if_not_exist('input_sq_size', 512)
config.set_if_not_exist('num_attention_heads', 16)
config.set_if_not_exist('mlp_ratio', 4.0)
config.set_if_not_exist('class_dropout_prob', 0.1)
config.set_if_not_exist('model_max_length', 300)
config.set_if_not_exist('learn_sigma', True)
config.set_if_not_exist('dtype', None)
config.set_if_not_exist('qk_norm', True)
config.set_if_not_exist('skip_y_embedder', False)
def __post_init__(self):
return
def get_dynamic_size(self, x: Tensor):
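        """Return the patch-grid size (T, H, W) after rounding each latent dimension up to a
        multiple of the patch size."""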
if not USE_STATIC_SHAPE:
raise NotImplementedError('Only static shape is supported')
_, _, T, H, W = x.shape
if T % self.patch_size[0] != 0:
T += self.patch_size[0] - T % self.patch_size[0]
if H % self.patch_size[1] != 0:
H += self.patch_size[1] - H % self.patch_size[1]
if W % self.patch_size[2] != 0:
W += self.patch_size[2] - W % self.patch_size[2]
T = T // self.patch_size[0]
H = H // self.patch_size[1]
W = W // self.patch_size[2]
return (T, H, W)
def encode_text(self, y: Tensor, mask: Optional[Tensor] = None):
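        """Project captions with y_embedder; when a mask is given, keep only the valid tokens
        (packed into a single sequence) and return the per-sample token counts."""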
y = self.y_embedder(y) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = repeat(mask, sizes=(y.shape[0] // mask.shape[0], 1))
mask = squeeze(squeeze(mask, 1), 1)
y = masked_select(
squeeze(y, 1),
where(
unsqueeze(mask, -1).__eq__(
constant_to_tensor_(0, dtype=mask.dtype)),
constant_to_tensor_(False),
constant_to_tensor_(True))).view((1, -1, self.hidden_size))
# [TODO] how to convert y_lens to list?
# y_lens = mask.sum(dim=1).tolist()
y_lens = sum(mask, dim=1)
else:
y_lens = constant(
np.array([y.shape[2]] * y.shape[0], dtype=np.int64))
y = squeeze(y, 1).view((1, -1, self.hidden_size))
self.register_network_output('encode_text.output.y', y)
self.register_network_output('encode_text.output.y_lens', y_lens)
return y, y_lens
def unpatchify(self, x: Tensor, N_t: int, N_h: int, N_w: int, R_t: int,
R_h: int, R_w: int):
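        """Rearrange patch tokens back into a [B, C_out, T, H, W] latent video and crop the
        padded dimensions down to (R_t, R_h, R_w)."""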
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
# unpad
x = slice(x,
starts=[0] * x.ndim(),
sizes=concat([
shape(x, 0),
shape(x, 1),
constant(np.array([R_t, R_h, R_w]).astype(np.int64))
]))
self.register_network_output('unpatchify.output', x)
return x
def forward(self,
x: Tensor,
timestep: Tensor,
y: Tensor,
fps: Tensor,
height: Tensor,
width: Tensor,
mask: Optional[Tensor] = None,
x_mask: Optional[Tensor] = None,
attention_params: Optional[AttentionParams] = None,
kv_cache_params: Optional[KeyValueCacheParams] = None,
**kwargs):
if not USE_STATIC_SHAPE:
raise NotImplementedError('Only static shape is supported')
assert tuple(x.shape[2:]) == tuple(
self.latent_size), "For now only static shape is supported."
B = x.shape[0]
x = x.cast(self.dtype)
timestep = timestep.cast(self.dtype)
y = y.cast(self.dtype)
fps = fps.cast(self.dtype)
# === get pos embed ===
_, _, Tx, Hx, Wx = x.shape
T, H, W = self.get_dynamic_size(x)
S = H * W
base_size = round(S**0.5)
resolution_sq = pow(
height.cast('float32') * width.cast('float32'),
constant_to_tensor_(0.5, dtype='float32'))
scale = (resolution_sq / self.input_sq_size).cast(self.dtype)
pos_emb = self.pos_embed(x, h=H, w=W, scale=scale, base_size=base_size)
# === get timestep embed ===
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
fps = self.fps_embedder(unsqueeze(fps, 1), bs=B)
t = t + fps
t_mlp = self.t_block(t)
t0 = t0_mlp = None
if x_mask is not None:
t0_timestep = constant(
np.zeros(shape=timestep.shape).astype(np.float32)).cast(x.dtype)
t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
t0 = t0 + fps
t0_mlp = self.t_block(t0)
# === get y embed ===
if self.config.skip_y_embedder:
y_lens = mask
if isinstance(y_lens, Tensor):
y_lens = y_lens.cast('int64')
else:
y, y_lens = self.encode_text(y, mask)
y_lens = None #[11, 11]
# === get x embed ===
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
x = x + pos_emb
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
# === blocks ===
cnt = 0
for spatial_block, temporal_block in zip(self.spatial_blocks,
self.temporal_blocks):
x = spatial_block(
x,
y,
t_mlp,
x_mask=x_mask,
t0=t0_mlp,
T=T,
S=S,
attention_params=attention_params,
kv_cache_params=kv_cache_params,
)
x = temporal_block(
x,
y,
t_mlp,
x_mask=x_mask,
t0=t0_mlp,
T=T,
S=S,
attention_params=attention_params,
kv_cache_params=kv_cache_params,
)
cnt += 1
# === final layer ===
x = self.final_layer(x, t, x_mask, t0, T, S)
x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
# cast to float32 for better accuracy
output = x.cast('float32')
output.mark_output('output', 'float32')
return output
def prepare_inputs(self, max_batch_size, **kwargs):
        '''@brief: Prepare input Tensors for the model. The given sizes are used to determine the
            ranges of the dimensions when using TRT dynamic shapes.
        @return: a dict containing the tensors and parameters that can be fed into self.forward()
        '''
mapping = self.config.mapping
if mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
def stdit_default_batch_range(max_batch_size):
return [max_batch_size, max_batch_size, max_batch_size]
default_range = stdit_default_batch_range
# [NOTE] For now only static batch size is supported, so we run the model with max_batch_size.
batch_size = max_batch_size
x = Tensor(name='x',
dtype=self.dtype,
shape=[batch_size, self.in_channels, *self.latent_size],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
('in_channels', [[self.in_channels] * 3]),
('latent_frames', [[self.latent_size[0]] * 3]),
('latent_height', [[self.latent_size[1]] * 3]),
('latent_width', [[self.latent_size[2]] * 3]),
]))
timestep = Tensor(name='timestep',
dtype=self.dtype,
shape=[batch_size],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
]))
y = Tensor(
name='y',
dtype=self.dtype,
shape=[batch_size, 1, self.model_max_length, self.caption_channels],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
('mask_batch_size', [[1, 1, 1]]),
('num_tokens', [[self.model_max_length] * 3]),
('caption_channels', [[self.caption_channels] * 3]),
]))
mask = Tensor(name='mask',
dtype=trt.int32,
shape=[1, self.model_max_length],
dim_range=OrderedDict([
('mask_batch_size', [[1, 1, 1]]),
('num_tokens', [[self.model_max_length] * 3]),
]))
x_mask = Tensor(name='x_mask',
dtype=tensorrt_llm.str_dtype_to_trt('bool'),
shape=[batch_size, self.latent_size[0]],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
('latent_frames', [[self.latent_size[0]] * 3]),
]))
fps = Tensor(name='fps', dtype=self.dtype, shape=[
1,
])
height = Tensor(name='height', dtype=self.dtype, shape=[
1,
])
width = Tensor(name='width', dtype=self.dtype, shape=[
1,
])
use_gpt_attention_plugin = default_net(
).plugin_config.gpt_attention_plugin
remove_input_padding = default_net().plugin_config.remove_input_padding
cross_attn_batch_size = batch_size
max_cattn_seq_len = int(
np.prod([
np.ceil(d / p)
for d, p in zip(self.latent_size, self.patch_size)
]))
max_cattn_enc_len = self.model_max_length
attn_inputs = GenerationMixin().prepare_attention_inputs(
max_batch_size=cross_attn_batch_size,
opt_batch_size=cross_attn_batch_size,
max_beam_width=1,
max_input_len=max_cattn_seq_len,
max_seq_len=max_cattn_seq_len,
num_kv_heads=self.num_heads,
head_size=self.hidden_size // self.num_heads,
num_layers=self.depth,
kv_dtype=self.dtype,
kv_cache_type=KVCacheType.DISABLED,
remove_input_padding=remove_input_padding,
use_gpt_attention_plugin=use_gpt_attention_plugin,
enable_ctx_gen_opt_profiles=False,
mapping=self.mapping,
)
sequence_length = attn_inputs['sequence_length']
host_context_lengths = attn_inputs['host_context_lengths']
host_max_attention_window_sizes = attn_inputs[
'host_max_attention_window_sizes']
host_sink_token_length = attn_inputs['host_sink_token_length']
context_lengths = attn_inputs['context_lengths']
host_request_types = attn_inputs['host_request_types']
host_past_key_value_lengths = attn_inputs['host_past_key_value_lengths']
past_key_value = attn_inputs['past_key_value']
if past_key_value:
past_key_value = past_key_value[0]
cache_indirection = attn_inputs['cache_indirection']
host_runtime_perf_knobs_tensor = attn_inputs['host_runtime_perf_knobs']
host_context_progress = attn_inputs['host_context_progress']
cross_encoder_input_lengths = Tensor(name='encoder_input_lengths',
shape=(cross_attn_batch_size, ),
dtype=str_dtype_to_trt('int32'))
cross_max_encoder_seq_len = Tensor(
name='encoder_max_input_length',
shape=[-1],
dim_range=OrderedDict([
("encoder_max_input_length",
[[1, (max_cattn_enc_len + 1) // 2, max_cattn_enc_len]])
]),
dtype=str_dtype_to_trt('int32'))
attention_params = AttentionParams(
sequence_length=sequence_length,
context_lengths=context_lengths,
host_context_lengths=host_context_lengths,
max_context_length=max_cattn_seq_len,
host_request_types=host_request_types,
encoder_input_lengths=cross_encoder_input_lengths,
encoder_max_input_length=cross_max_encoder_seq_len,
host_runtime_perf_knobs=host_runtime_perf_knobs_tensor,
host_context_progress=host_context_progress)
kv_cache_params = KeyValueCacheParams(
past_key_value=past_key_value,
host_past_key_value_lengths=host_past_key_value_lengths,
host_max_attention_window_sizes=host_max_attention_window_sizes,
host_sink_token_length=host_sink_token_length,
cache_indirection=cache_indirection,
kv_cache_block_offsets=None,
host_kv_cache_block_offsets=None,
host_kv_cache_pool_pointers=None,
host_kv_cache_pool_mapping=None,
cross_kv_cache_block_offsets=None,
host_cross_kv_cache_block_offsets=None,
host_cross_kv_cache_pool_pointers=None,
host_cross_kv_cache_pool_mapping=None,
)
return {
'x': x,
'timestep': timestep,
'y': y,
'mask': mask,
'x_mask': x_mask,
'fps': fps,
'height': height,
'width': width,
'attention_params': attention_params,
'kv_cache_params': kv_cache_params,
}
@classmethod
def from_pretrained(cls,
pretrained_model_dir: str,
dtype='float16',
mapping=Mapping(),
**kwargs):
quant_ckpt_path = kwargs.pop('quant_ckpt_path', None)
assert os.path.exists(f"{pretrained_model_dir}/config.json")
with open(f"{pretrained_model_dir}/config.json", 'r') as f:
hf_config = json.load(f)
hf_tllm_config_remapping = {
'model_type': 'architecture',
'depth': 'num_hidden_layers',
'num_heads': 'num_attention_heads',
'patch_size': 'stdit_patch_size',
'pred_sigma': 'learn_sigma',
}
for hf_key, tllm_key in hf_tllm_config_remapping.items():
hf_config[tllm_key] = hf_config.pop(hf_key)
hf_config.update(kwargs)
model_config = STDiTModelConfig.from_input_config(hf_config,
dtype=dtype,
mapping=mapping)
model_dir = pretrained_model_dir
custom_dict = {}
if quant_ckpt_path is not None:
model_dir = quant_ckpt_path
loader = STDiT3ModelWeightsLoader(model_dir, custom_dict)
model = cls(model_config)
loader.generate_tllm_weights(model)
return model
class STDiT3ModelWeightsLoader(ModelWeightsLoader):
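    """ModelWeightsLoader that maps TensorRT-LLM parameter names to the external checkpoint
    names, fuses q/kv weights for cross-attention, and applies tensor-parallel sharding."""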
    def translate_to_external_key(self, tllm_key: str,
                                  tllm_to_externel_key_dict: dict):
        """Translate a TensorRT-LLM parameter name into the corresponding key of the
        external checkpoint.
        """
trtllm_to_hf_name = {
r"spatial_blocks.(\d+).attn.q_layernorm.weight":
"spatial_blocks.*.attn.q_norm.weight",
r"spatial_blocks.(\d+).attn.k_layernorm.weight":
"spatial_blocks.*.attn.k_norm.weight",
r"spatial_blocks.(\d+).attn.dense.weight":
"spatial_blocks.*.attn.proj.weight",
r"spatial_blocks.(\d+).attn.dense.bias":
"spatial_blocks.*.attn.proj.bias",
r"temporal_blocks.(\d+).attn.q_layernorm.weight":
"temporal_blocks.*.attn.q_norm.weight",
r"temporal_blocks.(\d+).attn.k_layernorm.weight":
"temporal_blocks.*.attn.k_norm.weight",
r"temporal_blocks.(\d+).attn.dense.weight":
"temporal_blocks.*.attn.proj.weight",
r"temporal_blocks.(\d+).attn.dense.bias":
"temporal_blocks.*.attn.proj.bias",
r"spatial_blocks.(\d+).cross_attn.dense.weight":
"spatial_blocks.*.cross_attn.proj.weight",
r"spatial_blocks.(\d+).cross_attn.dense.bias":
"spatial_blocks.*.cross_attn.proj.bias",
r"temporal_blocks.(\d+).cross_attn.dense.weight":
"temporal_blocks.*.cross_attn.proj.weight",
r"temporal_blocks.(\d+).cross_attn.dense.bias":
"temporal_blocks.*.cross_attn.proj.bias",
}
for k, v in trtllm_to_hf_name.items():
m = re.match(k, tllm_key)
if m is not None:
matched_pos = m.groups()
placeholders = v.count('*')
assert len(matched_pos) == placeholders
for i in range(len(matched_pos)):
v = v.replace('*', matched_pos[i], 1)
return v
return tllm_key
def load_tensor(self, key, tp_size=1, tp_dim=-1, tp_rank=0):
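        """Load a single checkpoint tensor; cross-attention q/kv weights are concatenated and,
        when tp_size > 1, the tensor is sliced along tp_dim for the current rank."""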
hidden_size = self.model.config.hidden_size
if "attn.qkv" in key:
is_cross_attn = "cross_attn.qkv" in key
if is_cross_attn:
# process for cross attention
process_qkv_names = [
'q_linear'.join(key.split('qkv')),
'kv_linear'.join(key.split('qkv'))
]
else:
process_qkv_names = [key]
qkv_tensors = []
for qkv_key in process_qkv_names:
# Retrieve shard index
assert qkv_key in self.shard_map
ptr_idx = self.shard_map[qkv_key]
if self.format == ModelWeightsFormat.SAFETENSORS:
# Force to load Pytorch tensor
tensor = self.shards[ptr_idx].get_tensor(qkv_key)
else:
tensor = self.shards[ptr_idx][qkv_key]
qkv_tensors.append(tensor)
if is_cross_attn:
tensor = torch.concat(qkv_tensors, dim=0)
else:
tensor = qkv_tensors[0]
# Post-process weight and bias if tp_size > 1
if tp_size > 1:
if "weight" in key:
tensor = tensor.reshape(3, hidden_size, hidden_size)
elif "bias" in key:
tensor = tensor.reshape(3, hidden_size)
tp_dim = 1
tensor_shape = tensor.shape
else:
# Retrieve shard index
if key in self.shard_map:
ptr_idx = self.shard_map[key]
else:
return None
if self.format == ModelWeightsFormat.SAFETENSORS:
tensor = self.shards[ptr_idx].get_slice(key)
tensor_shape = tensor.get_shape()
if tensor_shape == []:
tensor = self.shards[ptr_idx].get_tensor(key).unsqueeze(0)
tensor_shape = tensor.shape
else:
tensor = self.shards[ptr_idx][key]
tensor_shape = tensor.shape
if tp_size <= 1 or tp_dim < 0:
return tensor[:]
else:
if len(tensor_shape) == 1 and (tp_dim > 0 or tensor_shape[0] == 1):
return tensor[:]
else:
width = tensor_shape[tp_dim]
if width == 1:
return tensor[:]
slice_width = math.ceil(width / tp_size)
slice_start = tp_rank * slice_width
slice_end = builtins.min((tp_rank + 1) * slice_width, width)
slice_obj = [builtins.slice(None)] * len(tensor_shape)
slice_obj[tp_dim] = builtins.slice(slice_start, slice_end)
res = tensor[tuple(slice_obj)]
if "qkv.weight" in key:
res = res.reshape(3 * (hidden_size // tp_size), hidden_size)
elif "qkv.bias" in key:
res = res.reshape(3 * (hidden_size // tp_size))
return res
def generate_tllm_weights(self,
model,
custom_postprocess_kwargs: dict = {}):
self.update_key_mapping(model)
tp_module_patterns = [
r'.*_blocks.*.attn.qkv.weight$',
r'.*_blocks.*.attn.qkv.bias$',
r'.*_blocks.*.attn.dense.weight$',
r'.*_blocks.*.cross_attn.qkv.weight$',
r'.*_blocks.*.cross_attn.qkv.bias$',
r'.*_blocks.*.cross_attn.dense.weight$',
r'.*_blocks.*.mlp.fc1.weight$',
r'.*_blocks.*.mlp.fc1.bias$',
r'.*_blocks.*.mlp.fc2.weight$',
]
tllm_weights = {}
for tllm_key, _ in tqdm(model.named_parameters()):
skip_tp = not any([
re.match(pattern, tllm_key) is not None
for pattern in tp_module_patterns
])
tllm_weights.update(
self.load(tllm_key,
custom_postprocess_kwargs=custom_postprocess_kwargs,
skip_tp=skip_tp))
self.fill(tllm_weights)