# Copyright 2024 HPC-AI Technology Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# reference: https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
import torch
from torch.distributions import LogisticNormal


def timestep_transform(
    t,
    model_kwargs,
    base_resolution=512 * 512,
    base_num_frames=1,
    scale=1.0,
    num_timesteps=1,
):
    """Remap timesteps according to each sample's resolution and frame count.

    Each timestep is normalised to ``t / num_timesteps`` and mapped through
    ``ratio * t / (1 + (ratio - 1) * t)``. The ratio is

        ratio = sqrt(H * W / base_resolution)
                * sqrt(frames / base_num_frames)
                * scale

    ``ratio == 1`` leaves ``t`` unchanged. ``ratio > 1`` moves it towards 1.

    Args:
        t: timesteps, one per sample (tensor).
        model_kwargs: dict with ``"height"``, ``"width"`` and ``"num_frames"``
            tensors indexed per sample. float16 entries are converted to
            float32 *in place*, so the caller's dict is modified.
        base_resolution: pixel count at which the spatial ratio is 1.
        base_num_frames: frame count (after the hardcoded reduction below) at
            which the temporal ratio is 1.
        scale: extra multiplier applied to the ratio.
        num_timesteps: scale of ``t``. The result is returned on the same scale.

    Returns:
        Tensor of transformed timesteps on the ``num_timesteps`` scale.
    """
    # Force fp16 input to fp32 to avoid nan output (e.g. 512 * 512 already
    # exceeds the float16 range).
    for key in ["height", "width", "num_frames"]:
        if model_kwargs[key].dtype == torch.float16:
            model_kwargs[key] = model_kwargs[key].float()

    t = t / num_timesteps
    resolution = model_kwargs["height"] * model_kwargs["width"]
    ratio_space = (resolution / base_resolution).sqrt()

    # NOTE: currently, we do not take fps into account
    # NOTE: temporal_reduction is hardcoded, this should be equal to the
    # temporal reduction factor of the vae (every 17 frames count as 5).
    # The single-frame check is made per sample, so a batch mixing images and
    # videos gets the right ratio for every element, not only the first one's.
    # NOTE(review): 1 < num_frames < 17 maps to 0 frames, which makes the
    # ratio 0 and collapses new_t to 0 -- confirm such inputs never occur.
    raw_frames = model_kwargs["num_frames"]
    num_frames = torch.where(
        raw_frames == 1,
        torch.ones_like(raw_frames),
        raw_frames // 17 * 5,
    )
    ratio_time = (num_frames / base_num_frames).sqrt()

    ratio = ratio_space * ratio_time * scale
    new_t = ratio * t / (1 + (ratio - 1) * t)

    new_t = new_t * num_timesteps
    return new_t


class RFlowScheduler:
    """Rectified-flow scheduler: timestep sampling and forward noising.

    Args:
        num_timesteps: scale of the ``timesteps`` passed to :meth:`add_noise`.
        num_sampling_steps: number of sampling steps (stored on the instance).
        sample_method: ``"uniform"`` or ``"logit-normal"``. With
            ``"logit-normal"``, a ``sample_t(x)`` callable is created. It
            returns one value in (0, 1) per row of ``x``, on ``x``'s device.
        loc: location of the logit-normal distribution.
        scale: scale of the logit-normal distribution.
        use_timestep_transform: stored flag. Presumably it tells the caller
            whether to apply :func:`timestep_transform` -- confirm with callers.
    """

    def __init__(
        self,
        num_timesteps=1000,
        num_sampling_steps=10,
        sample_method="uniform",
        loc=0.0,
        scale=1.0,
        use_timestep_transform=False,
    ):
        self.num_timesteps = num_timesteps
        self.num_sampling_steps = num_sampling_steps

        assert sample_method in ["uniform", "logit-normal"], (
            f"unsupported sample_method: {sample_method!r}")
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(
                torch.tensor([loc]), torch.tensor([scale]))
            # A 1-D LogisticNormal yields 2-component points on the simplex.
            # Component 0 is a single value in (0, 1). Samples are drawn on
            # the distribution's device and then moved to x's device.
            self.sample_t = lambda x: self.distribution.sample(
                (x.shape[0], ))[:, 0].to(x.device)

        self.use_timestep_transform = use_timestep_transform

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Linearly interpolate between clean samples and noise.

        With ``s = timesteps / num_timesteps``, the result is
        ``(1 - s) * original_samples + s * noise``, taken per sample.
        ``timesteps == 0`` gives the clean sample and
        ``timesteps == num_timesteps`` gives pure noise.

        Args:
            original_samples: clean samples, batch along dim 0.
            noise: noise with the same shape as ``original_samples``.
            timesteps: one timestep per sample, shape ``(batch,)``. It may be
                an integer or a float tensor; it is cast to float.

        Returns:
            Noised samples, with the same shape as ``noise`` when both
            inputs share it.
        """
        timepoints = 1 - timesteps.float() / self.num_timesteps
        # (batch,) -> (batch, 1, ..., 1): broadcasting expands it to the
        # sample shape without allocating a full-size copy. This works for any
        # sample rank (e.g. 4-D image or 5-D video latents).
        timepoints = timepoints.reshape(-1, *([1] * (noise.dim() - 1)))
        return timepoints * original_samples + (1 - timepoints) * noise