TensorRT-LLMs/examples/models/contrib/stdit/scheduler.py
bhsueh_NV 322ac565fc
chore: clean some ci of qa test (#3083)
* move some models to examples/models/contrib

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* update the document

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* remove arctic, blip2, cogvlm, dbrx from qa test list

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* remove tests of dit, mmdit and stdit from qa test

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* remove grok, jais, sdxl, skywork, smaug from qa test list

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* re-organize the glm examples

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* fix issues after running pre-commit

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* fix some typos in glm_4_9b readme

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* fix bug

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

---------

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
2025-03-31 14:30:41 +08:00

102 lines
3.9 KiB
Python

# Copyright 2024 HPC-AI Technology Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
import torch
from torch.distributions import LogisticNormal
def timestep_transform(
    t,
    model_kwargs,
    base_resolution=512 * 512,
    base_num_frames=1,
    scale=1.0,
    num_timesteps=1,
):
    """Shift timesteps according to the spatial/temporal size of each sample.

    The normalized timestep ``t / num_timesteps`` is remapped with
    ``ratio * t / (1 + (ratio - 1) * t)`` and then rescaled by
    ``num_timesteps``. ``ratio`` is the product of three factors:
    ``sqrt(height * width / base_resolution)``,
    ``sqrt(mapped_num_frames / base_num_frames)`` and ``scale``.

    Args:
        t: Timestep tensor, in the same units as ``num_timesteps``.
        model_kwargs: Mapping with tensor entries ``"height"``, ``"width"``
            and ``"num_frames"``. Any of these that are float16 are replaced
            by float32 copies *in the caller's mapping*.
        base_resolution: Reference pixel count (height * width).
        base_num_frames: Reference (mapped) frame count.
        scale: Extra multiplicative factor applied to the shift ratio.
        num_timesteps: Value used to normalize ``t`` before the transform and
            to rescale the result afterwards.

    Returns:
        The transformed timesteps, in the same units as ``t``.
    """
    # Force fp16 input to fp32 to avoid nan output. NOTE: this intentionally
    # writes back into ``model_kwargs``, so callers observe the upcast.
    for key in ("height", "width", "num_frames"):
        if model_kwargs[key].dtype == torch.float16:
            model_kwargs[key] = model_kwargs[key].float()

    t = t / num_timesteps
    resolution = model_kwargs["height"] * model_kwargs["width"]
    ratio_space = (resolution / base_resolution).sqrt()

    # NOTE: currently, we do not take fps into account
    # NOTE: temporal_reduction is hardcoded (each full group of 17 frames
    # counts as 5); this should be equal to the temporal reduction factor of
    # the vae.
    raw_frames = model_kwargs["num_frames"]
    # Decide per entry rather than from element 0 only, so batches with mixed
    # frame counts (and 0-dim tensors) are handled. Entries with exactly one
    # frame keep a frame count of 1 instead of collapsing to 1 // 17 * 5 == 0.
    num_frames = torch.where(raw_frames == 1, torch.ones_like(raw_frames),
                             raw_frames // 17 * 5)
    ratio_time = (num_frames / base_num_frames).sqrt()

    ratio = ratio_space * ratio_time * scale
    new_t = ratio * t / (1 + (ratio - 1) * t)
    return new_t * num_timesteps
class RFlowScheduler:
    """Rectified-flow scheduler adapted from Open-Sora.

    Holds the scheduler configuration and provides ``add_noise``, which
    linearly interpolates between clean samples and noise.
    """

    def __init__(
        self,
        num_timesteps=1000,
        num_sampling_steps=10,
        sample_method="uniform",
        loc=0.0,
        scale=1.0,
        use_timestep_transform=False,
    ):
        """Store the scheduler configuration.

        Args:
            num_timesteps: Total number of timesteps; ``add_noise`` divides
                timesteps by this value.
            num_sampling_steps: Number of sampling steps (stored only).
            sample_method: ``"uniform"`` or ``"logit-normal"``.
            loc: Location of the LogisticNormal distribution (logit-normal only).
            scale: Scale of the LogisticNormal distribution (logit-normal only).
            use_timestep_transform: Flag stored for callers; not read here.

        Raises:
            AssertionError: If ``sample_method`` is not supported.
        """
        self.num_timesteps = num_timesteps
        self.num_sampling_steps = num_sampling_steps
        assert sample_method in ("uniform", "logit-normal"), (
            f"Unsupported sample_method: {sample_method!r}")
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(torch.tensor([loc]),
                                               torch.tensor([scale]))
            # One draw per entry along dim 0 of ``x``, taking the first
            # component of each sample and moving it to ``x``'s device.
            self.sample_t = lambda x: self.distribution.sample(
                (x.shape[0], ))[:, 0].to(x.device)
        self.use_timestep_transform = use_timestep_transform

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Interpolate between clean samples and noise.

        Computes ``w * original_samples + (1 - w) * noise`` with
        ``w = 1 - timesteps / num_timesteps`` applied per entry along dim 0.

        Args:
            original_samples: Clean samples; batch along dim 0.
            noise: Noise tensor of the same shape as ``original_samples``.
            timesteps: One timestep per batch entry (shape ``(bsz,)``), or a
                single timestep applied to the whole batch.

        Returns:
            The noised samples.
        """
        timepoints = timesteps.float() / self.num_timesteps
        # 1 at timestep 0, 0 at timestep == num_timesteps.
        timepoints = 1 - timepoints
        # Reshape (bsz,) -> (bsz, 1, ..., 1) so it broadcasts over every
        # non-batch dim of ``noise``. This works for any input rank and avoids
        # materializing a full-size copy of the weights via ``repeat``.
        timepoints = timepoints.reshape(-1, *([1] * (noise.dim() - 1)))
        return timepoints * original_samples + (1 - timepoints) * noise