# TensorRT-LLMs/examples/stdit/scheduler.py
# 2025-03-11 21:13:42 +08:00
#
# 102 lines
# 3.9 KiB
# Python

# Copyright 2024 HPC-AI Technology Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
import torch
from torch.distributions import LogisticNormal
def timestep_transform(
    t,
    model_kwargs,
    base_resolution=512 * 512,
    base_num_frames=1,
    scale=1.0,
    num_timesteps=1,
):
    """Rescale timesteps according to spatial resolution and frame count.

    The timestep is first normalised, ``u = t / num_timesteps``. It is then
    mapped through ``r * u / (1 + (r - 1) * u)`` and scaled back by
    ``num_timesteps``. Here
    ``r = sqrt(height * width / base_resolution)
    * sqrt(frames / base_num_frames) * scale``.
    For ``r > 1`` and ``u`` in [0, 1], the result moves ``u`` upward
    (towards 1); ``r == 1`` is the identity.

    Args:
        t: Timestep(s) on the ``[0, num_timesteps]`` scale (tensor or number).
        model_kwargs: Dict with tensor entries ``"height"``, ``"width"`` and
            ``"num_frames"``, combined elementwise (presumably one value per
            batch element -- confirm against the caller). Any float16 entry
            is replaced IN PLACE in this dict by its float32 version.
        base_resolution: Pixel area at which the spatial factor equals 1.
        base_num_frames: Frame count at which the temporal factor equals 1.
        scale: Extra multiplier applied to the combined ratio ``r``.
        num_timesteps: Size of the timestep scale used by ``t``.

    Returns:
        The transformed timesteps, on the same ``[0, num_timesteps]`` scale.
    """
    # Force fp16 input to fp32 to avoid nan output: height * width easily
    # exceeds the fp16 maximum (65504), and inf / inf later produces nan.
    # NOTE: the upcast tensors are written back into ``model_kwargs``, so the
    # caller's dict is modified; kept as-is for backward compatibility.
    for key in ["height", "width", "num_frames"]:
        if model_kwargs[key].dtype == torch.float16:
            model_kwargs[key] = model_kwargs[key].float()

    t = t / num_timesteps
    resolution = model_kwargs["height"] * model_kwargs["width"]
    ratio_space = (resolution / base_resolution).sqrt()

    # NOTE: currently, we do not take fps into account.
    # NOTE: the temporal reduction is hardcoded (every full group of 17 frames
    # counts as 5); it should equal the temporal reduction factor of the VAE.
    # Single-frame entries count as one frame. The choice is made per batch
    # element rather than from element 0 only, so mixed batches and 0-d
    # tensors are handled correctly.
    # NOTE(review): frame counts 2..16 map to 0, which gives ratio 0 (and nan
    # at t == num_timesteps); confirm callers never pass such values.
    frames = model_kwargs["num_frames"]
    num_frames = torch.where(frames == 1, torch.ones_like(frames),
                             frames // 17 * 5)
    ratio_time = (num_frames / base_num_frames).sqrt()

    ratio = ratio_space * ratio_time * scale
    new_t = ratio * t / (1 + (ratio - 1) * t)

    return new_t * num_timesteps
class RFlowScheduler:
    """Rectified-flow scheduler: timestep-sampler setup and forward noising.

    Timesteps are measured on a ``[0, num_timesteps]`` scale. ``add_noise``
    interpolates linearly between clean samples and noise on that scale.
    """

    def __init__(
        self,
        num_timesteps=1000,
        num_sampling_steps=10,
        sample_method="uniform",
        loc=0.0,
        scale=1.0,
        use_timestep_transform=False,
    ):
        """Store scheduler settings and, for logit-normal sampling, build ``sample_t``.

        Args:
            num_timesteps: Size of the timestep scale; ``add_noise`` divides
                timesteps by it.
            num_sampling_steps: Stored as-is; not read inside this class
                (presumably consumed by the sampling loop -- confirm).
            sample_method: ``"uniform"`` or ``"logit-normal"``.
            loc: Location of the ``LogisticNormal`` distribution (used only
                for ``"logit-normal"``).
            scale: Scale of the ``LogisticNormal`` distribution (used only
                for ``"logit-normal"``).
            use_timestep_transform: Stored flag; not read inside this class.

        Raises:
            AssertionError: If ``sample_method`` is unsupported. This check
                only runs while assertions are enabled (not under ``python -O``).
        """
        self.num_timesteps = num_timesteps
        self.num_sampling_steps = num_sampling_steps

        assert sample_method in ["uniform", "logit-normal"]
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(torch.tensor([loc]),
                                               torch.tensor([scale]))
            # sample_t(x) draws x.shape[0] values: the first coordinate of a
            # LogisticNormal sample, moved to x.device. This attribute exists
            # only for "logit-normal".
            self.sample_t = lambda x: self.distribution.sample(
                (x.shape[0], ))[:, 0].to(x.device)

        self.use_timestep_transform = use_timestep_transform

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Blend clean samples with noise at the given per-sample timesteps.

        Computes ``tp * original_samples + (1 - tp) * noise``, where
        ``tp = 1 - timesteps / num_timesteps`` is applied per batch element.
        So ``timesteps == 0`` yields the clean sample and
        ``timesteps == num_timesteps`` yields pure noise.

        Args:
            original_samples: Clean samples; must broadcast against ``noise``.
            noise: Noise tensor whose leading dimension is the batch. Any rank
                >= 1 is accepted (the STDiT caller presumably passes 5-D
                latents -- confirm).
            timesteps: One timestep per batch element, shape ``(batch,)``.

        Returns:
            The noised samples.
        """
        timepoints = timesteps.float() / self.num_timesteps
        timepoints = 1 - timepoints

        # Reshape (batch,) -> (batch, 1, ..., 1) so the per-sample weight
        # broadcasts over every non-batch dim of ``noise``. Values are
        # identical to an explicit repeat, but this works for any rank and
        # does not allocate a noise-sized copy.
        timepoints = timepoints.reshape(-1, *([1] * (noise.dim() - 1)))

        return timepoints * original_samples + (1 - timepoints) * noise