Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 087a6f00aa | |||
| 2c1ed50fc5 | |||
| 15ad97f782 | |||
| 9f2d5c9ee9 | |||
| dc62e6931e | |||
| 56f740051d | |||
| a34d97cef0 | |||
| fc28791fc8 | |||
| ae14612673 | |||
| 0ab8fe49bf | |||
| 3be6706018 | |||
| cb1b8b21b8 | |||
| 27916822b2 | |||
| 3fe3bc0642 | |||
| 813d42cc96 | |||
| b4d7e9c632 | |||
| 2e83cbbb6d | |||
| 33d10af28f |
@@ -28,7 +28,51 @@ env:
|
||||
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_support_list.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
name: Setup Torch Pipelines CUDA Slow Tests Matrix
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
@@ -133,6 +177,7 @@ jobs:
|
||||
|
||||
torch_cuda_tests:
|
||||
name: Torch CUDA Tests
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
@@ -201,7 +246,7 @@ jobs:
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
@@ -220,6 +265,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXConditionPipeline
|
||||
|
||||
[[autodoc]] LTXConditionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
|
||||
|
||||
@@ -627,6 +627,7 @@ def main(args):
|
||||
ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
|
||||
perceptual_loss = lpips.LPIPS(net="vgg").eval()
|
||||
discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
|
||||
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
|
||||
|
||||
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
|
||||
def unwrap_model(model):
|
||||
@@ -951,13 +952,20 @@ def main(args):
|
||||
logits_fake = discriminator(reconstructions)
|
||||
disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
|
||||
disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
|
||||
disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
|
||||
d_loss = disc_factor * disc_loss(logits_real, logits_fake)
|
||||
logs = {
|
||||
"disc_loss": disc_loss.detach().mean().item(),
|
||||
"disc_loss": d_loss.detach().mean().item(),
|
||||
"logits_real": logits_real.detach().mean().item(),
|
||||
"logits_fake": logits_fake.detach().mean().item(),
|
||||
"disc_lr": disc_lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
accelerator.backward(d_loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = discriminator.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
disc_optimizer.step()
|
||||
disc_lr_scheduler.step()
|
||||
disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# Generating images using Flux and PyTorch/XLA
|
||||
|
||||
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
|
||||
|
||||
It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
|
||||
The `flux_inference` script shows how to generate images with Flux on TPU devices using PyTorch/XLA. It uses the Pallas flash-attention kernel with custom flash block sizes, which gives faster generation on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
|
||||
|
||||
## Create TPU
|
||||
|
||||
@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
|
||||
python3 -c "import torch; import torch_xla;"
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
Clone the diffusers repo and install dependencies
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
pip install transformers accelerate sentencepiece structlog
|
||||
pushd ../../..
|
||||
pip install .
|
||||
popd
|
||||
cd examples/research_projects/pytorch_xla/inference/flux/
|
||||
```
|
||||
|
||||
## Run the inference job
|
||||
|
||||
### Authenticate
|
||||
|
||||
Run the following command to authenticate your token in order to download Flux weights.
|
||||
**Gated Model**
|
||||
|
||||
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
|
||||
@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_095_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
|
||||
"model.diffusion_model": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_091_SPECIAL_KEYS_REMAP = {
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
version: str = "0.9.0",
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
config = {}
|
||||
if version == "0.9.5":
|
||||
config["_use_causal_rope_fix"] = True
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel()
|
||||
transformer = LTXVideoTransformer3DModel(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"down_block_types": (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"downsample_type": ("conv", "conv", "conv", "conv"),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"patch_size": 4,
|
||||
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"down_block_types": (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (5, 6, 7, 8),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"downsample_type": ("conv", "conv", "conv", "conv"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
|
||||
elif version == "0.9.5":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
||||
return config
|
||||
|
||||
|
||||
@@ -223,7 +294,7 @@ def get_args():
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -277,14 +348,17 @@ if __name__ == "__main__":
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
if args.version == "0.9.5":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTXPipeline(
|
||||
scheduler=scheduler,
|
||||
|
||||
@@ -402,6 +402,7 @@ else:
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LTXConditionPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXPipeline",
|
||||
"Lumina2Pipeline",
|
||||
@@ -947,6 +948,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LTXConditionPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXPipeline,
|
||||
Lumina2Pipeline,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
@@ -56,7 +56,7 @@ class ModuleGroup:
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
|
||||
low_cpu_mem_usage=False,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
@@ -64,15 +64,50 @@ class ModuleGroup:
|
||||
self.onload_device = onload_device
|
||||
self.offload_leader = offload_leader
|
||||
self.onload_leader = onload_leader
|
||||
self.parameters = parameters
|
||||
self.buffers = buffers
|
||||
self.parameters = parameters or []
|
||||
self.buffers = buffers or []
|
||||
self.non_blocking = non_blocking or stream is not None
|
||||
self.stream = stream
|
||||
self.cpu_param_dict = cpu_param_dict
|
||||
self.onload_self = onload_self
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
|
||||
if self.stream is not None and self.cpu_param_dict is None:
|
||||
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
|
||||
self.cpu_param_dict = self._init_cpu_param_dict()
|
||||
|
||||
def _init_cpu_param_dict(self):
    """Build the mapping from each parameter/buffer of this group to a CPU copy of its data.

    Returns an empty dict when no stream is configured. Otherwise every
    parameter and buffer of every module in ``self.modules``, followed by the
    standalone ``self.parameters`` and ``self.buffers``, is mapped to
    ``tensor.data.cpu()``; the copy is additionally pinned unless
    ``self.low_cpu_mem_usage`` is set.
    """
    host_copies = {}
    if self.stream is None:
        return host_copies

    def _host_copy(tensor):
        # With low_cpu_mem_usage the copy is left unpinned here.
        on_cpu = tensor.data.cpu()
        if self.low_cpu_mem_usage:
            return on_cpu
        return on_cpu.pin_memory()

    # Tensors owned by the grouped modules come first ...
    for group_module in self.modules:
        for weight in group_module.parameters():
            host_copies[weight] = _host_copy(weight)
        for buf in group_module.buffers():
            host_copies[buf] = _host_copy(buf)

    # ... then the parameters and buffers that were handed to the group directly.
    for weight in self.parameters:
        host_copies[weight] = _host_copy(weight)
    for buf in self.buffers:
        host_copies[buf] = _host_copy(buf)

    return host_copies
|
||||
|
||||
@contextmanager
def _pinned_memory_tensors(self):
    """Context manager yielding a dict of pinned counterparts for ``self.cpu_param_dict``.

    Entries whose tensor already reports ``is_pinned()`` are passed through
    unchanged; the others are replaced by ``tensor.pin_memory()``. The local
    reference to the dict is cleared when the ``with`` block exits.
    """
    staged = {
        key: (cpu_tensor if cpu_tensor.is_pinned() else cpu_tensor.pin_memory())
        for key, cpu_tensor in self.cpu_param_dict.items()
    }
    try:
        yield staged
    finally:
        # Drop this frame's reference to the pinned copies on exit.
        staged = None
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
@@ -82,12 +117,30 @@ class ModuleGroup:
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
@@ -98,15 +151,18 @@ class ModuleGroup:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
@@ -172,6 +228,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
self._layer_execution_tracker_module_names = set()
|
||||
|
||||
def initialize_hook(self, module):
|
||||
def make_execution_order_update_callback(current_name, current_submodule):
|
||||
def callback():
|
||||
logger.debug(f"Adding {current_name} to the execution order")
|
||||
self.execution_order.append((current_name, current_submodule))
|
||||
|
||||
return callback
|
||||
|
||||
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
|
||||
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
|
||||
# layers are executed during the forward pass.
|
||||
@@ -183,14 +246,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
|
||||
|
||||
if group_offloading_hook is not None:
|
||||
|
||||
def make_execution_order_update_callback(current_name, current_submodule):
|
||||
def callback():
|
||||
logger.debug(f"Adding {current_name} to the execution order")
|
||||
self.execution_order.append((current_name, current_submodule))
|
||||
|
||||
return callback
|
||||
|
||||
# For the first forward pass, we have to load in a blocking manner
|
||||
group_offloading_hook.group.non_blocking = False
|
||||
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
|
||||
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
|
||||
self._layer_execution_tracker_module_names.add(name)
|
||||
@@ -220,6 +277,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
# Remove the layer execution tracker hooks from the submodules
|
||||
base_module_registry = module._diffusers_hook
|
||||
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
|
||||
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
||||
|
||||
for i in range(num_executed):
|
||||
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
|
||||
@@ -227,8 +285,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
|
||||
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
|
||||
|
||||
# Apply lazy prefetching by setting required attributes
|
||||
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
|
||||
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
|
||||
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
|
||||
# see the benefits of prefetching.
|
||||
for hook in group_offloading_hooks:
|
||||
hook.group.non_blocking = True
|
||||
|
||||
# Set required attributes for prefetching
|
||||
if num_executed > 0:
|
||||
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
|
||||
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
|
||||
@@ -268,6 +331,7 @@ def apply_group_offloading(
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
low_cpu_mem_usage=False,
|
||||
) -> None:
|
||||
r"""
|
||||
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
|
||||
@@ -349,10 +413,12 @@ def apply_group_offloading(
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
|
||||
_apply_group_offloading_leaf_level(
|
||||
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported offload_type: {offload_type}")
|
||||
|
||||
@@ -364,6 +430,7 @@ def _apply_group_offloading_block_level(
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
|
||||
@@ -384,13 +451,6 @@ def _apply_group_offloading_block_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
unmatched_modules = []
|
||||
@@ -411,7 +471,7 @@ def _apply_group_offloading_block_level(
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=stream is None,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -448,7 +508,6 @@ def _apply_group_offloading_block_level(
|
||||
buffers=buffers,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
@@ -461,6 +520,7 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
||||
@@ -483,13 +543,6 @@ def _apply_group_offloading_leaf_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
for name, submodule in module.named_modules():
|
||||
@@ -503,7 +556,7 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
@@ -548,7 +601,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
@@ -567,7 +620,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=None,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
@@ -4249,7 +4249,33 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
@classmethod
def _maybe_expand_t2v_lora_for_i2v(
    cls,
    transformer: torch.nn.Module,
    state_dict,
):
    """Add zero-filled ``add_k_proj`` / ``add_v_proj`` LoRA weights to a LoRA state dict that lacks them.

    Args:
        transformer: Model whose ``config.image_dim`` decides whether expansion applies.
        state_dict: LoRA state dict; it is modified in place and also returned.

    Returns:
        ``state_dict``, unchanged when ``transformer.config.image_dim`` is ``None``, when no key
        starts with ``"transformer.blocks."``, or when the dict already contains both
        ``add_k_proj`` and ``add_v_proj`` keys. Otherwise, for every block index in
        ``range(num_blocks)`` that has an ``attn2.to_k`` LoRA pair, zero tensors shaped like that
        pair are added under ``attn2.add_k_proj`` and ``attn2.add_v_proj``.
    """
    if transformer.config.image_dim is None:
        return state_dict

    if not any(k.startswith("transformer.blocks.") for k in state_dict):
        return state_dict

    has_add_k = any("add_k_proj" in k for k in state_dict)
    has_add_v = any("add_v_proj" in k for k in state_dict)
    if has_add_k and has_add_v:
        # Already carries the additional projections; nothing to add.
        return state_dict

    # Only keys that actually contain "blocks." can be split on it; others (e.g. LoRA weights on
    # layers outside the blocks) would otherwise raise IndexError.
    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})

    for i in range(num_blocks):
        ref_lora_a = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
        ref_lora_b = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
        # Skip blocks without a `to_k` reference pair instead of failing with KeyError.
        if ref_lora_a not in state_dict or ref_lora_b not in state_dict:
            continue

        for proj in ("add_k_proj", "add_v_proj"):
            state_dict[f"transformer.blocks.{i}.attn2.{proj}.lora_A.weight"] = torch.zeros_like(
                state_dict[ref_lora_a]
            )
            state_dict[f"transformer.blocks.{i}.attn2.{proj}.lora_B.weight"] = torch.zeros_like(
                state_dict[ref_lora_b]
            )

    return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
@@ -4287,7 +4313,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
|
||||
state_dict = self._maybe_expand_t2v_lora_for_i2v(
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
@@ -449,9 +449,9 @@ class TextualInversionLoaderMixin:
|
||||
|
||||
# 7.5 Offload the model again
|
||||
if is_model_cpu_offload:
|
||||
self.enable_model_cpu_offload()
|
||||
self.enable_model_cpu_offload(device=device)
|
||||
elif is_sequential_cpu_offload:
|
||||
self.enable_sequential_cpu_offload()
|
||||
self.enable_sequential_cpu_offload(device=device)
|
||||
|
||||
# / Unsafe Code >
|
||||
|
||||
|
||||
@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoDownsampler3d(nn.Module):
    r"""
    Downsampling layer that combines a causal 3D convolution with a rearrangement that folds the stride factors of
    the last three dimensions into the channel dimension, plus a group-averaged residual of the input.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`):
            Number of channels of the returned tensor (after the stride factors have been folded into channels).
        stride (`int` or `Tuple[int, int, int]`, defaults to `1`):
            Downsampling factors; `stride[0]`, `stride[1]` and `stride[2]` apply to dims 2, 3 and 4 of the input
            respectively. An integer is broadcast to all three.
        is_causal (`bool`, defaults to `True`):
            Forwarded to `LTXVideoCausalConv3d`.
        padding_mode (`str`, defaults to `"zeros"`):
            Forwarded to `LTXVideoCausalConv3d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: Union[int, Tuple[int, int, int]] = 1,
        is_causal: bool = True,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        # Use the normalised tuple here: indexing the raw `stride` argument fails for integer strides,
        # including the default `stride=1`.
        stride_product = self.stride[0] * self.stride[1] * self.stride[2]
        # Number of rearranged input channels averaged into each residual output channel.
        self.group_size = (in_channels * stride_product) // out_channels

        # The convolution emits fewer channels; the rearrangement in `forward` multiplies them back
        # by the stride product.
        out_channels = out_channels // stride_product

        self.conv = LTXVideoCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            is_causal=is_causal,
            padding_mode=padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Prepend the first `stride[0] - 1` entries along dim 2.
        # NOTE(review): presumably this makes dim 2 divisible by stride[0] for the unflatten below - confirm.
        hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)

        # Residual branch: split dims 4, 3, 2 into (size // stride, stride), move the three stride
        # axes next to the channel axis and merge them into it, then average groups of
        # `group_size` channels. Assumes a 5D input.
        residual = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        residual = residual.unflatten(1, (-1, self.group_size))
        residual = residual.mean(dim=2)

        # Convolution branch: same rearrangement applied to the convolved features.
        hidden_states = self.conv(hidden_states)
        hidden_states = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        hidden_states = hidden_states + residual

        return hidden_states
|
||||
|
||||
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
is_causal: bool = True,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
padding_mode: str = "zeros",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
is_causal=is_causal,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideo095DownBlock3D(nn.Module):
    r"""
    Down block used in the LTXVideo model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`. Only the `"spatial"`, `"temporal"` and
            `"spatiotemporal"` downsamplers use this value; the resnets and the `"conv"` downsampler keep
            `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, output dimension would be same as input
            dimension.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
        downsample_type (`str`, defaults to `"conv"`):
            Which downsampling layer to build when `spatio_temporal_scale` is `True`. `"conv"` uses an
            `LTXVideoCausalConv3d` with stride `(2, 2, 2)`; `"spatial"`, `"temporal"` and `"spatiotemporal"` use
            an `LTXVideoDownsampler3d` with stride `(1, 2, 2)`, `(2, 1, 1)` and `(2, 2, 2)` respectively.

    Raises:
        ValueError: If `spatio_temporal_scale` is `True` and `downsample_type` is not one of the values above.
    """

    _supports_gradient_checkpointing = True

    # Stride used by `LTXVideoDownsampler3d` for each non-"conv" downsample type.
    _DOWNSAMPLER_STRIDES = {
        "spatial": (1, 2, 2),
        "temporal": (2, 1, 1),
        "spatiotemporal": (2, 2, 2),
    }

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        is_causal: bool = True,
        downsample_type: str = "conv",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        # The resnets never change the channel count; only the downsampler may map to `out_channels`.
        self.resnets = nn.ModuleList(
            [
                LTXVideoResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                )
                for _ in range(num_layers)
            ]
        )

        self.downsamplers = None
        if spatio_temporal_scale:
            self.downsamplers = nn.ModuleList()

            if downsample_type == "conv":
                self.downsamplers.append(
                    LTXVideoCausalConv3d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=3,
                        stride=(2, 2, 2),
                        is_causal=is_causal,
                    )
                )
            elif downsample_type in self._DOWNSAMPLER_STRIDES:
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        stride=self._DOWNSAMPLER_STRIDES[downsample_type],
                        is_causal=is_causal,
                    )
                )
            else:
                # Fail loudly: an empty `ModuleList` would otherwise silently skip the requested downsampling.
                raise ValueError(f"Unknown downsample type: {downsample_type}")

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        r"""Forward method of the `LTXVideo095DownBlock3D` class."""

        for resnet in self.resnets:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # `_gradient_checkpointing_func` is not defined in this class; it must be attached from outside.
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
            else:
                hidden_states = resnet(hidden_states, temb, generator)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states
|
||||
|
||||
|
||||
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
|
||||
class LTXVideoMidBlock3d(nn.Module):
|
||||
r"""
|
||||
@@ -593,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
@@ -617,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
)
|
||||
|
||||
# down blocks
|
||||
num_block_out_channels = len(block_out_channels)
|
||||
is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
|
||||
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
for i in range(num_block_out_channels):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
if not is_ltx_095:
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
else:
|
||||
output_channel = block_out_channels[i + 1]
|
||||
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
)
|
||||
if down_block_types[i] == "LTXVideoDownBlock3D":
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
)
|
||||
elif down_block_types[i] == "LTXVideo095DownBlock3D":
|
||||
down_block = LTXVideo095DownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
downsample_type=downsample_type[i],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown down block type: {down_block_types[i]}")
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -794,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
# timestep embedding
|
||||
self.time_embedder = None
|
||||
self.scale_shift_table = None
|
||||
self.timestep_scale_multiplier = None
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
|
||||
|
||||
@@ -803,6 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if self.timestep_scale_multiplier is not None:
|
||||
temb = temb * self.timestep_scale_multiplier
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
|
||||
|
||||
@@ -891,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
|
||||
timestep_conditioning: bool = False,
|
||||
@@ -906,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
scaling_factor: float = 1.0,
|
||||
encoder_causal: bool = True,
|
||||
decoder_causal: bool = False,
|
||||
spatial_compression_ratio: int = None,
|
||||
temporal_compression_ratio: int = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -913,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
down_block_types=down_block_types,
|
||||
spatio_temporal_scaling=spatio_temporal_scaling,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_type=downsample_type,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
@@ -941,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.register_buffer("latents_mean", latents_mean, persistent=True)
|
||||
self.register_buffer("latents_std", latents_std, persistent=True)
|
||||
|
||||
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
|
||||
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
|
||||
self.spatial_compression_ratio = (
|
||||
patch_size * 2 ** sum(spatio_temporal_scaling)
|
||||
if spatial_compression_ratio is None
|
||||
else spatial_compression_ratio
|
||||
)
|
||||
self.temporal_compression_ratio = (
|
||||
patch_size_t * 2 ** sum(spatio_temporal_scaling)
|
||||
if temporal_compression_ratio is None
|
||||
else temporal_compression_ratio
|
||||
)
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
|
||||
@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
low_cpu_mem_usage=False,
|
||||
) -> None:
|
||||
r"""
|
||||
Activates group offloading for the current model.
|
||||
@@ -584,7 +585,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
f"open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
apply_group_offloading(
|
||||
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
|
||||
self,
|
||||
onload_device,
|
||||
offload_device,
|
||||
offload_type,
|
||||
num_blocks_per_group,
|
||||
non_blocking,
|
||||
use_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
|
||||
@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
input_tensor = self.conv_shortcut(input_tensor.contiguous())
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -113,20 +113,19 @@ class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
self.patch_size_t = patch_size_t
|
||||
self.theta = theta
|
||||
|
||||
def forward(
|
||||
def _prepare_video_coords(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
rope_interpolation_scale: Tuple[torch.Tensor, float, float],
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
# Always compute rope in fp32
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
@@ -138,6 +137,38 @@ class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
|
||||
grid = grid.flatten(2, 4).transpose(1, 2)
|
||||
|
||||
return grid
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_frames: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
|
||||
video_coords: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
if video_coords is None:
|
||||
grid = self._prepare_video_coords(
|
||||
batch_size,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
else:
|
||||
grid = torch.stack(
|
||||
[
|
||||
video_coords[:, 0] / self.base_num_frames,
|
||||
video_coords[:, 1] / self.base_height,
|
||||
video_coords[:, 2] / self.base_width,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
start = 1.0
|
||||
end = self.theta
|
||||
freqs = self.theta ** torch.linspace(
|
||||
@@ -367,10 +398,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
|
||||
video_coords: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
@@ -389,7 +421,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
|
||||
@@ -264,7 +264,7 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["latte"] = ["LattePipeline"]
|
||||
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
|
||||
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
@@ -618,7 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXImageToVideoPipeline, LTXPipeline
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
|
||||
@@ -533,7 +533,6 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
|
||||
@@ -63,6 +63,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import FluxControlNetPipeline
|
||||
>>> from diffusers import FluxControlNetModel
|
||||
|
||||
>>> base_model = "black-forest-labs/FLUX.1-dev"
|
||||
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
|
||||
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
||||
>>> pipe = FluxControlNetPipeline.from_pretrained(
|
||||
|
||||
@@ -533,7 +533,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
|
||||
@@ -561,7 +561,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
@@ -614,7 +613,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
return latents, noise, image_latents, latent_image_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
|
||||
@@ -225,7 +225,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
|
||||
)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
@@ -634,7 +637,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
if image.shape[1] != self.latent_channels:
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
|
||||
@@ -222,11 +222,13 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
|
||||
)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2,
|
||||
vae_latent_channels=latent_channels,
|
||||
vae_latent_channels=self.latent_channels,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
@@ -653,7 +655,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
if image.shape[1] != self.latent_channels:
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
@@ -710,7 +715,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
else:
|
||||
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
|
||||
|
||||
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
masked_image_latents = (
|
||||
masked_image_latents - self.vae.config.shift_factor
|
||||
) * self.vae.config.scaling_factor
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
|
||||
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
|
||||
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
|
||||
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -34,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_ltx import LTXPipeline
|
||||
from .pipeline_ltx_condition import LTXConditionPipeline
|
||||
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
|
||||
|
||||
else:
|
||||
|
||||
@@ -694,9 +694,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -764,9 +764,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
@@ -427,7 +427,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
|
||||
)
|
||||
|
||||
if device_type == "cuda":
|
||||
if device_type in ["cuda", "xpu"]:
|
||||
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
|
||||
raise ValueError(
|
||||
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
||||
@@ -440,7 +440,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and device_type == "cuda":
|
||||
if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
|
||||
logger.warning(
|
||||
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
|
||||
)
|
||||
|
||||
@@ -941,8 +941,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
if num_inference_steps == 1:
|
||||
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1]
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
|
||||
@@ -108,31 +108,16 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor,
|
||||
latents_mean: torch.Tensor,
|
||||
latents_std: torch.Tensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
sample_mode: str = "sample",
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
|
||||
encoder_output.latent_dist.logvar = torch.clamp(
|
||||
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
|
||||
)
|
||||
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
|
||||
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
|
||||
encoder_output.latent_dist.logvar = torch.clamp(
|
||||
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
|
||||
)
|
||||
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
|
||||
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return (encoder_output.latents - latents_mean) * latents_std
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
@@ -412,13 +397,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
if isinstance(generator, list):
|
||||
latent_condition = [
|
||||
retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
|
||||
retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
|
||||
]
|
||||
latent_condition = torch.cat(latent_condition)
|
||||
else:
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
|
||||
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
|
||||
@@ -61,7 +61,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not torch.cuda.is_available():
|
||||
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
||||
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
|
||||
raise ImportError(
|
||||
@@ -238,11 +238,15 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
|
||||
def update_device_map(self, device_map):
    """
    Return the device map used to load the quantized model.

    A caller-supplied `device_map` is returned unchanged. When it is `None`, the whole model (the `""` key) is
    mapped to the current XPU device if `torch.xpu` reports one as available, and to the current CUDA device
    otherwise.

    Args:
        device_map: The user-provided device map, or `None`.

    Returns:
        The original `device_map`, or a single-entry dict of the form `{"": "<backend>:<index>"}`.
    """
    if device_map is None:
        # `torch.xpu` does not exist on every PyTorch build, so probe for it before querying availability.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            current_device = f"xpu:{torch.xpu.current_device()}"
        else:
            current_device = f"cuda:{torch.cuda.current_device()}"
        device_map = {"": current_device}
        # Interpolate the map that was actually chosen; a plain string literal would log the text
        # "{current_device}" verbatim instead of the device.
        logger.info(
            "The device_map was not initialized. "
            f"Setting device_map to {device_map}. "
            "If you want to use the model for inference, please set device_map ='auto' "
        )
    return device_map
|
||||
@@ -312,7 +316,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
logger.info(
|
||||
"Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
|
||||
)
|
||||
model.to(torch.cuda.current_device())
|
||||
if torch.xpu.is_available():
|
||||
model.to(torch.xpu.current_device())
|
||||
else:
|
||||
model.to(torch.cuda.current_device())
|
||||
|
||||
model = dequantize_and_replace(
|
||||
model, self.modules_to_not_convert, quantization_config=self.quantization_config
|
||||
@@ -343,7 +350,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not torch.cuda.is_available():
|
||||
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
||||
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
|
||||
raise ImportError(
|
||||
@@ -402,11 +409,15 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
|
||||
def update_device_map(self, device_map):
    """
    Return the device map used to load the quantized model.

    A caller-supplied `device_map` is returned unchanged. When it is `None`, the whole model (the `""` key) is
    mapped to the current XPU device if `torch.xpu` reports one as available, and to the current CUDA device
    otherwise.

    Args:
        device_map: The user-provided device map, or `None`.

    Returns:
        The original `device_map`, or a single-entry dict of the form `{"": "<backend>:<index>"}`.
    """
    if device_map is None:
        # `torch.xpu` does not exist on every PyTorch build, so probe for it before querying availability.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            current_device = f"xpu:{torch.xpu.current_device()}"
        else:
            current_device = f"cuda:{torch.cuda.current_device()}"
        device_map = {"": current_device}
        # Interpolate the map that was actually chosen; a plain string literal would log the text
        # "{current_device}" verbatim instead of the device.
        logger.info(
            "The device_map was not initialized. "
            f"Setting device_map to {device_map}. "
            "If you want to use the model for inference, please set device_map ='auto' "
        )
    return device_map
|
||||
|
||||
@@ -377,6 +377,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
per_token_timesteps: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -397,6 +398,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Scaling factor for noise added to the sample.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
per_token_timesteps (`torch.Tensor`, *optional*):
|
||||
The timesteps for each token in the sample.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a
|
||||
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
|
||||
@@ -427,16 +430,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# Upcast to avoid precision issues when computing prev_sample
|
||||
sample = sample.to(torch.float32)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
if per_token_timesteps is not None:
|
||||
per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
|
||||
|
||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
sigmas = self.sigmas[:, None, None]
|
||||
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
|
||||
lower_sigmas = lower_mask * sigmas
|
||||
lower_sigmas, _ = lower_sigmas.max(dim=0)
|
||||
dt = (per_token_sigmas - lower_sigmas)[..., None]
|
||||
else:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
dt = sigma_next - sigma
|
||||
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
prev_sample = sample + dt * model_output
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
if per_token_timesteps is None:
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -1217,6 +1217,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXConditionPipeline(metaclass=DummyObject):
    """Stand-in for `LTXConditionPipeline`: every entry point hands off to `requires_backends`."""

    # Backends this placeholder reports as required.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted and ignored; only the backend check runs.
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Same backend check, issued against the class rather than an instance.
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, issued against the class rather than an instance.
        requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXImageToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import struct
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -139,8 +139,31 @@ def _legacy_export_to_video(
|
||||
|
||||
|
||||
def export_to_video(
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
|
||||
output_video_path: str = None,
|
||||
fps: int = 10,
|
||||
quality: float = 5.0,
|
||||
bitrate: Optional[int] = None,
|
||||
macro_block_size: Optional[int] = 16,
|
||||
) -> str:
|
||||
"""
|
||||
quality:
|
||||
Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to
|
||||
prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead.
|
||||
Specifying a fixed bitrate using `bitrate` disables this parameter.
|
||||
|
||||
bitrate:
|
||||
Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead.
|
||||
Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter
|
||||
rather than specifying a fixed bitrate with this parameter.
|
||||
|
||||
macro_block_size:
|
||||
Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number
|
||||
imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs
|
||||
are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic
|
||||
feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some
|
||||
codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock.
|
||||
"""
|
||||
# TODO: Dhruv. Remove by Diffusers release 0.33.0
|
||||
# Added to prevent breaking existing code
|
||||
if not is_imageio_available():
|
||||
@@ -177,7 +200,9 @@ def export_to_video(
|
||||
elif isinstance(video_frames[0], PIL.Image.Image):
|
||||
video_frames = [np.array(frame) for frame in video_frames]
|
||||
|
||||
with imageio.get_writer(output_video_path, fps=fps) as writer:
|
||||
with imageio.get_writer(
|
||||
output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size
|
||||
) as writer:
|
||||
for frame in video_frames:
|
||||
writer.append_data(frame)
|
||||
|
||||
|
||||
@@ -367,7 +367,7 @@ def prepare_encode(
|
||||
if shift_factor is not None:
|
||||
parameters["shift_factor"] = shift_factor
|
||||
if isinstance(image, torch.Tensor):
|
||||
data = safetensors.torch._tobytes(image, "tensor")
|
||||
data = safetensors.torch._tobytes(image.contiguous(), "tensor")
|
||||
parameters["shape"] = list(image.shape)
|
||||
parameters["dtype"] = str(image.dtype).split(".")[-1]
|
||||
else:
|
||||
|
||||
@@ -320,6 +320,21 @@ def require_torch_multi_gpu(test_case):
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_accelerator(test_case):
    """
    Decorator marking a test that requires a multi-accelerator setup (in PyTorch). These tests are skipped on a machine
    without multiple hardware accelerators.

    The test runs when either `torch.cuda` or `torch.xpu` reports more than one device. It is skipped outright when
    PyTorch is not available.
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    # `torch.xpu` is absent from some PyTorch builds; treat a missing module as "no XPU devices" instead of
    # letting the decorator fail with an AttributeError at import time of the test module.
    xpu = getattr(torch, "xpu", None)
    has_multiple_accelerators = torch.cuda.device_count() > 1 or (xpu is not None and xpu.device_count() > 1)

    return unittest.skipUnless(has_multiple_accelerators, "test requires multiple hardware accelerators")(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp16(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
|
||||
@@ -354,6 +369,31 @@ def require_big_gpu_with_torch_cuda(test_case):
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_big_accelerator(test_case):
    """
    Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
    Flux, SD3, Cog, etc.

    The test is skipped when PyTorch is unavailable, when neither a CUDA nor an XPU device is available, or when
    device 0 reports less than `BIG_GPU_MEMORY` GB of total memory. XPU is preferred over CUDA when both are
    available.
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    # `torch.xpu` is absent from some PyTorch builds; probe for it before querying availability.
    xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()

    if not (torch.cuda.is_available() or xpu_available):
        return unittest.skip("test requires a PyTorch CUDA or XPU device")(test_case)

    if xpu_available:
        device_properties = torch.xpu.get_device_properties(0)
    else:
        device_properties = torch.cuda.get_device_properties(0)

    # `total_memory` is divided by 1024**3 so it can be compared against the GB threshold.
    total_memory = device_properties.total_memory / (1024**3)
    return unittest.skipUnless(
        total_memory >= BIG_GPU_MEMORY,
        f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
    )(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_training(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for training."""
|
||||
return unittest.skipUnless(
|
||||
@@ -574,10 +614,10 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
|
||||
return arry
|
||||
|
||||
|
||||
def load_pt(url: str, map_location: Optional[str] = None):
    """
    Download a serialized PyTorch object from `url` and deserialize it with `torch.load`.

    Args:
        url: Location of the serialized object; fetched with `requests.get`.
        map_location: Forwarded to `torch.load` to choose where tensors are placed. Defaults to `None`, which
            keeps `torch.load`'s own default so that existing one-argument callers keep working.

    Returns:
        Whatever object `torch.load` reconstructs from the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.get(url)
    response.raise_for_status()
    # NOTE(review): `torch.load` may unpickle arbitrary objects; only point this helper at trusted URLs.
    arry = torch.load(BytesIO(response.content), map_location=map_location)
    return arry
|
||||
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
    """
    Return a seeded random generator suited to `torch_device`.

    On "mps" the global generator is seeded via `torch.manual_seed` and returned. Otherwise a dedicated
    `torch.Generator` is created: on `torch_device` itself for CUDA and XPU devices, and on the CPU for
    everything else.

    Args:
        seed: Seed applied to the generator.
    """
    if torch_device == "mps":
        return torch.manual_seed(seed)

    # `torch_device.startswith(torch_device)` is always true, which made the CPU fallback unreachable;
    # test against the accelerator prefixes explicitly instead.
    generator_device = torch_device if torch_device.startswith(("cuda", "xpu")) else "cpu"
    return torch.Generator(device=generator_device).manual_seed(seed)
|
||||
|
||||
@@ -165,7 +165,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model.eval()
|
||||
|
||||
# Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors
|
||||
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
|
||||
generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
|
||||
if torch_device != "mps":
|
||||
generator = torch.Generator(device=generator_device).manual_seed(0)
|
||||
else:
|
||||
@@ -263,7 +263,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
    """
    Return a seeded random generator suited to `torch_device`.

    On "mps" the global generator is seeded via `torch.manual_seed` and returned. Otherwise a dedicated
    `torch.Generator` is created: on `torch_device` itself for CUDA and XPU devices, and on the CPU for
    everything else.

    Args:
        seed: Seed applied to the generator.
    """
    if torch_device == "mps":
        return torch.manual_seed(seed)

    # `torch_device.startswith(torch_device)` is always true, which made the CPU fallback unreachable;
    # test against the accelerator prefixes explicitly instead.
    generator_device = torch_device if torch_device.startswith(("cuda", "xpu")) else "cpu"
    return torch.Generator(device=generator_device).manual_seed(seed)
|
||||
|
||||
@@ -183,7 +183,7 @@ class AutoencoderOobleckIntegrationTests(unittest.TestCase):
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
    """
    Return a seeded random generator suited to `torch_device`.

    On "mps" the global generator is seeded via `torch.manual_seed` and returned. Otherwise a dedicated
    `torch.Generator` is created: on `torch_device` itself for CUDA and XPU devices, and on the CPU for
    everything else.

    Args:
        seed: Seed applied to the generator.
    """
    if torch_device == "mps":
        return torch.manual_seed(seed)

    # `torch_device.startswith(torch_device)` is always true, which made the CPU fallback unreachable;
    # test against the accelerator prefixes explicitly instead.
    generator_device = torch_device if torch_device.startswith(("cuda", "xpu")) else "cpu"
    return torch.Generator(device=generator_device).manual_seed(seed)
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.testing_utils import (
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_training,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
run_test_in_subprocess,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
@@ -1227,7 +1227,7 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_model_parallelism(self):
|
||||
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
@@ -89,8 +89,7 @@ class CogView4PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
|
||||
)
|
||||
text_encoder = GlmForCausalLM(text_encoder_config)
|
||||
# TODO(aryan): change this to THUDM/CogView4 once released
|
||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/CogView4-6B", subfolder="tokenizer")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
|
||||
@@ -31,9 +31,10 @@ from diffusers import (
|
||||
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -219,7 +220,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = StableDiffusion3ControlNetPipeline
|
||||
@@ -227,12 +228,12 @@ class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
    """Run the parent set-up, force a garbage collection, then call `backend_empty_cache` for `torch_device`."""
    super().setUp()
    gc.collect()
    # Only the device-agnostic helper is used; a hard-coded `torch.cuda.empty_cache()` would tie this
    # fixture to CUDA even though the suite is gated on a generic accelerator.
    backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
    """Run the parent tear-down, force a garbage collection, then call `backend_empty_cache` for `torch_device`."""
    super().tearDown()
    gc.collect()
    # Only the device-agnostic helper is used; a hard-coded `torch.cuda.empty_cache()` would tie this
    # fixture to CUDA even though the suite is gated on a generic accelerator.
    backend_empty_cache(torch_device)
|
||||
|
||||
def test_canny(self):
|
||||
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
|
||||
@@ -272,7 +273,7 @@ class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
@@ -304,7 +305,7 @@ class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
@@ -338,7 +339,7 @@ class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
@@ -12,7 +12,7 @@ from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -204,7 +204,7 @@ class FluxPipelineFastTests(
|
||||
|
||||
|
||||
@nightly
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class FluxPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = FluxPipeline
|
||||
@@ -292,7 +292,7 @@ class FluxPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = FluxPipeline
|
||||
@@ -304,12 +304,12 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
    """Run the parent set-up, force a garbage collection, then call `backend_empty_cache` for `torch_device`."""
    super().setUp()
    gc.collect()
    # Only the device-agnostic helper is used; a hard-coded `torch.cuda.empty_cache()` would tie this
    # fixture to CUDA even though the suite is gated on a generic accelerator.
    backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
    """Run the parent tear-down, force a garbage collection, then call `backend_empty_cache` for `torch_device`."""
    super().tearDown()
    gc.collect()
    # Only the device-agnostic helper is used; a hard-coded `torch.cuda.empty_cache()` would tie this
    # fixture to CUDA even though the suite is gated on a generic accelerator.
    backend_empty_cache(torch_device)
|
||||
|
||||
def get_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
|
||||
@@ -8,15 +8,16 @@ import torch
|
||||
from diffusers import FluxPipeline, FluxPriorReduxPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class FluxReduxSlowTests(unittest.TestCase):
|
||||
pipeline_class = FluxPriorReduxPipeline
|
||||
@@ -27,12 +28,12 @@ class FluxReduxSlowTests(unittest.TestCase):
|
||||
def setUp(self):
    """Run the parent set-up, force a garbage collection, then call `backend_empty_cache` for `torch_device`."""
    super().setUp()
    gc.collect()
    # Only the device-agnostic helper is used; a hard-coded `torch.cuda.empty_cache()` would tie this
    # fixture to CUDA even though the suite is gated on a generic accelerator.
    backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
    """Run the parent tear-down, force a garbage collection, then call `backend_empty_cache` for `torch_device`."""
    super().tearDown()
    gc.collect()
    # Only the device-agnostic helper is used; a hard-coded `torch.cuda.empty_cache()` would tie this
    # fixture to CUDA even though the suite is gated on a generic accelerator.
    backend_empty_cache(torch_device)
|
||||
|
||||
def get_inputs(self, device, seed=0):
|
||||
init_image = load_image(
|
||||
@@ -59,7 +60,7 @@ class FluxReduxSlowTests(unittest.TestCase):
|
||||
self.base_repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
|
||||
)
|
||||
pipe_redux.to(torch_device)
|
||||
pipe_base.enable_model_cpu_offload()
|
||||
pipe_base.enable_model_cpu_offload(device=torch_device)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
base_pipeline_inputs = self.get_base_pipeline_inputs(torch_device)
|
||||
|
||||
@@ -377,9 +377,10 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
|
||||
0
|
||||
]
|
||||
id_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt",
|
||||
map_location=torch_device,
|
||||
)[0]
|
||||
id_embeds = id_embeds.reshape((2, 1, 1, 512))
|
||||
inputs["ip_adapter_image_embeds"] = [id_embeds]
|
||||
inputs["ip_adapter_image"] = None
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright 2024 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTXVideo,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTXConditionPipeline,
|
||||
LTXVideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast tests for `LTXConditionPipeline`, run on tiny, seeded dummy components."""

    pipeline_class = LTXConditionPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False

    def get_dummy_components(self):
        """Build the small transformer, VAE, scheduler, text encoder and tokenizer the pipeline is assembled from."""
        # The global seed is reset before each randomly initialised component so that the weights are reproducible.
        torch.manual_seed(0)
        transformer = LTXVideoTransformer3DModel(
            in_channels=8,
            out_channels=8,
            patch_size=1,
            patch_size_t=1,
            num_attention_heads=4,
            attention_head_dim=8,
            cross_attention_dim=32,
            num_layers=1,
            caption_channels=32,
        )

        torch.manual_seed(0)
        vae = AutoencoderKLLTXVideo(
            in_channels=3,
            out_channels=3,
            latent_channels=8,
            block_out_channels=(8, 8, 8, 8),
            decoder_block_out_channels=(8, 8, 8, 8),
            layers_per_block=(1, 1, 1, 1, 1),
            decoder_layers_per_block=(1, 1, 1, 1, 1),
            spatio_temporal_scaling=(True, True, False, False),
            decoder_spatio_temporal_scaling=(True, True, False, False),
            decoder_inject_noise=(False, False, False, False, False),
            upsample_residual=(False, False, False, False),
            upsample_factor=(1, 1, 1, 1),
            timestep_conditioning=False,
            patch_size=1,
            patch_size_t=1,
            encoder_causal=True,
            decoder_causal=False,
        )
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0, use_conditions=False):
        """
        Build the keyword arguments for one pipeline call.

        Args:
            device: Device for the generator and the random conditioning image.
            seed: Seed for the generator.
            use_conditions: When `True` the random image is wrapped in an `LTXVideoCondition` and passed as
                `conditions` (with `image=None`); otherwise it is passed directly as `image`.

        Returns:
            A dict of pipeline call arguments.
        """
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
        if use_conditions:
            conditions = LTXVideoCondition(
                image=image,
            )
        else:
            conditions = None

        inputs = {
            "conditions": conditions,
            "image": None if use_conditions else image,
            "prompt": "dance monkey",
            "negative_prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "height": 32,
            "width": 32,
            # 8 * k + 1 is the recommendation
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        """Passing the image directly and passing it wrapped in a condition must produce the same video."""
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs2 = self.get_dummy_inputs(device, use_conditions=True)
        video = pipe(**inputs).frames
        generated_video = video[0]
        video2 = pipe(**inputs2).frames
        generated_video2 = video2[0]

        self.assertEqual(generated_video.shape, (9, 3, 32, 32))

        max_diff = np.abs(generated_video - generated_video2).max()
        self.assertLessEqual(max_diff, 1e-3)

    def test_callback_inputs(self):
        """Check that step-end callbacks only receive, and can modify, tensors listed in `_callback_tensor_inputs`."""
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        # Nothing to verify when the pipeline does not expose the callback API.
        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # check that we're only passing in allowed tensor inputs
            for tensor_name in callback_kwargs:
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            # every declared tensor input must be present ...
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # ... and nothing undeclared may be passed in
            for tensor_name in callback_kwargs:
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        pipe(**inputs)

        # Test passing in everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        pipe(**inputs)

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            # zero the latents on the final step so the effect is visible in the output
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

    def test_inference_batch_single_identical(self):
        """Delegate to the shared batch-vs-single check with a batch of 3 and a 1e-3 tolerance."""
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        """Outputs with attention slicing (slice sizes 1 and 2) must match the output without slicing."""
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # Inputs are rebuilt before every call so each run starts from an identically seeded generator.
        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        """Outputs at 128x128 with VAE tiling enabled must stay within `expected_diff_max` of the untiled output."""
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        # Compare magnitudes: a signed maximum would let arbitrarily large negative deviations pass.
        self.assertLess(
            np.abs(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )
|
||||
@@ -262,7 +262,7 @@ class StableDiffusion3PAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.(4|17)"]
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_model_cpu_offload(device=torch_device)
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device, guidance_scale=0.0, pag_scale=1.8)
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.testing_utils import (
|
||||
require_accelerate_version_greater,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
run_test_in_subprocess,
|
||||
skip_mps,
|
||||
slow,
|
||||
@@ -1409,7 +1409,7 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase):
|
||||
|
||||
# (sayakpaul): This test suite was run in the DGX with two GPUs (1, 2).
|
||||
@slow
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
@require_accelerate_version_greater("0.27.0")
|
||||
class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
@@ -1497,7 +1497,7 @@ class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `to()` can be used and the pipeline can be called.
|
||||
pipe = sd_pipe_with_device_map.to("cuda")
|
||||
pipe = sd_pipe_with_device_map.to(torch_device)
|
||||
_ = pipe("hello", num_inference_steps=2)
|
||||
|
||||
def test_reset_device_map_enable_model_cpu_offload(self):
|
||||
@@ -1509,7 +1509,7 @@ class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `enable_model_cpu_offload()` can be used and the pipeline can be called.
|
||||
sd_pipe_with_device_map.enable_model_cpu_offload()
|
||||
sd_pipe_with_device_map.enable_model_cpu_offload(device=torch_device)
|
||||
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
|
||||
|
||||
def test_reset_device_map_enable_sequential_cpu_offload(self):
|
||||
@@ -1521,5 +1521,5 @@ class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `enable_sequential_cpu_offload()` can be used and the pipeline can be called.
|
||||
sd_pipe_with_device_map.enable_sequential_cpu_offload()
|
||||
sd_pipe_with_device_map.enable_sequential_cpu_offload(device=torch_device)
|
||||
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
|
||||
|
||||
@@ -10,7 +10,7 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transfo
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -232,7 +232,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = StableDiffusion3Pipeline
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
floats_tensor,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -166,7 +166,7 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = StableDiffusion3Img2ImgPipeline
|
||||
@@ -202,11 +202,10 @@ class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
|
||||
}
|
||||
|
||||
def test_sd3_img2img_inference(self):
|
||||
torch.manual_seed(0)
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, :10, :10]
|
||||
expected_slice = np.array(
|
||||
|
||||
@@ -45,6 +45,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
require_accelerate_version_greater,
|
||||
require_accelerator,
|
||||
require_hf_hub_version_greater,
|
||||
@@ -1108,13 +1109,13 @@ class PipelineTesterMixin:
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_save_load_local(self, expected_max_difference=5e-4):
|
||||
components = self.get_dummy_components()
|
||||
@@ -1423,7 +1424,6 @@ class PipelineTesterMixin:
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
if not hasattr(self.pipeline_class, "_optional_components"):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
@@ -1438,6 +1438,7 @@ class PipelineTesterMixin:
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -1456,6 +1457,7 @@ class PipelineTesterMixin:
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||
@@ -1550,12 +1552,14 @@ class PipelineTesterMixin:
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_without_offload = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_sequential_cpu_offload(device=torch_device)
|
||||
assert pipe._execution_device.type == torch_device
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_with_offload = pipe(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
|
||||
@@ -1613,12 +1617,14 @@ class PipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_without_offload = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
assert pipe._execution_device.type == torch_device
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_with_offload = pipe(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
|
||||
|
||||
@@ -303,6 +303,7 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
decoder_latents = pipe.prepare_latents(
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
|
||||
@@ -407,6 +407,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
||||
pipe.super_res_first.config.sample_size,
|
||||
pipe.super_res_first.config.sample_size,
|
||||
)
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
super_res_latents = pipe.prepare_latents(
|
||||
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DMo
|
||||
from diffusers.utils import is_accelerate_version, logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
is_bitsandbytes_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
@@ -35,7 +36,7 @@ from diffusers.utils.testing_utils import (
|
||||
require_bitsandbytes_version_greater,
|
||||
require_peft_backend,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
require_transformers_version_greater,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -66,7 +67,7 @@ if is_bitsandbytes_available():
|
||||
@require_bitsandbytes_version_greater("0.43.2")
|
||||
@require_accelerate
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
class Base4bitTests(unittest.TestCase):
|
||||
# We need to test on relatively large models (aka >1b parameters), otherwise the quantization may not work as expected
|
||||
@@ -84,13 +85,16 @@ class Base4bitTests(unittest.TestCase):
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
prompt_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
|
||||
torch_device,
|
||||
)
|
||||
pooled_prompt_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
|
||||
torch_device,
|
||||
)
|
||||
latent_model_input = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
|
||||
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
|
||||
torch_device,
|
||||
)
|
||||
|
||||
input_dict_for_transformer = {
|
||||
@@ -106,7 +110,7 @@ class Base4bitTests(unittest.TestCase):
|
||||
class BnB4BitBasicTests(Base4bitTests):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Models
|
||||
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
|
||||
@@ -128,7 +132,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
del self.model_4bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quantization_num_parameters(self):
|
||||
r"""
|
||||
@@ -224,7 +228,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
self.assertTrue(module.weight.dtype == torch.uint8)
|
||||
|
||||
# test if inference works.
|
||||
with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
|
||||
with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch.float16):
|
||||
input_dict_for_transformer = self.get_dummy_inputs()
|
||||
model_inputs = {
|
||||
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
|
||||
@@ -266,9 +270,9 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
|
||||
|
||||
# Move back to the accelerator device
|
||||
for device in [0, "cuda", "cuda:0", "call()"]:
|
||||
for device in [0, f"{torch_device}", f"{torch_device}:0", "call()"]:
|
||||
if device == "call()":
|
||||
self.model_4bit.cuda(0)
|
||||
self.model_4bit.to(f"{torch_device}:0")
|
||||
else:
|
||||
self.model_4bit.to(device)
|
||||
self.assertEqual(self.model_4bit.device, torch.device(0))
|
||||
@@ -286,7 +290,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device` and `dtype`
|
||||
self.model_4bit.to(device="cuda:0", dtype=torch.float16)
|
||||
self.model_4bit.to(device=f"{torch_device}:0", dtype=torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a cast
|
||||
@@ -297,7 +301,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
self.model_4bit.half()
|
||||
|
||||
# This should work
|
||||
self.model_4bit.to("cuda")
|
||||
self.model_4bit.to(torch_device)
|
||||
|
||||
# Test if we did not break anything
|
||||
self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
|
||||
@@ -321,7 +325,7 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
_ = self.model_fp16.float()
|
||||
|
||||
# Check that this does not throw an error
|
||||
_ = self.model_fp16.cuda()
|
||||
_ = self.model_fp16.to(torch_device)
|
||||
|
||||
def test_bnb_4bit_wrong_config(self):
|
||||
r"""
|
||||
@@ -398,7 +402,7 @@ class BnB4BitTrainingTests(Base4bitTests):
|
||||
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
|
||||
|
||||
# Step 4: Check if the gradient is not None
|
||||
with torch.amp.autocast("cuda", dtype=torch.float16):
|
||||
with torch.amp.autocast(torch_device, dtype=torch.float16):
|
||||
out = self.model_4bit(**model_inputs)[0]
|
||||
out.norm().backward()
|
||||
|
||||
@@ -412,7 +416,7 @@ class BnB4BitTrainingTests(Base4bitTests):
|
||||
class SlowBnb4BitTests(Base4bitTests):
|
||||
def setUp(self) -> None:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
@@ -431,7 +435,7 @@ class SlowBnb4BitTests(Base4bitTests):
|
||||
del self.pipeline_4bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quality(self):
|
||||
output = self.pipeline_4bit(
|
||||
@@ -501,7 +505,7 @@ class SlowBnb4BitTests(Base4bitTests):
|
||||
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
|
||||
strict=True,
|
||||
)
|
||||
def test_pipeline_cuda_placement_works_with_nf4(self):
|
||||
def test_pipeline_device_placement_works_with_nf4(self):
|
||||
transformer_nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
@@ -532,7 +536,7 @@ class SlowBnb4BitTests(Base4bitTests):
|
||||
transformer=transformer_4bit,
|
||||
text_encoder_3=text_encoder_3_4bit,
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
).to(torch_device)
|
||||
|
||||
# Check if inference works.
|
||||
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
|
||||
@@ -696,7 +700,7 @@ class SlowBnb4BitFluxTests(Base4bitTests):
|
||||
class BaseBnb4BitSerializationTests(Base4bitTests):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
|
||||
r"""
|
||||
|
||||
@@ -31,6 +31,7 @@ from diffusers import (
|
||||
from diffusers.utils import is_accelerate_version
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
is_bitsandbytes_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
@@ -40,7 +41,7 @@ from diffusers.utils.testing_utils import (
|
||||
require_bitsandbytes_version_greater,
|
||||
require_peft_version_greater,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
require_transformers_version_greater,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -71,7 +72,7 @@ if is_bitsandbytes_available():
|
||||
@require_bitsandbytes_version_greater("0.43.2")
|
||||
@require_accelerate
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
class Base8bitTests(unittest.TestCase):
|
||||
# We need to test on relatively large models (aka >1b parameters), otherwise the quantization may not work as expected
|
||||
@@ -111,7 +112,7 @@ class Base8bitTests(unittest.TestCase):
|
||||
class BnB8bitBasicTests(Base8bitTests):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Models
|
||||
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
|
||||
@@ -129,7 +130,7 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
del self.model_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quantization_num_parameters(self):
|
||||
r"""
|
||||
@@ -279,7 +280,7 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
self.model_8bit.to(torch.device("cuda:0"))
|
||||
self.model_8bit.to(torch.device(f"{torch_device}:0"))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
@@ -317,7 +318,7 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
class Bnb8bitDeviceTests(Base8bitTests):
|
||||
def setUp(self) -> None:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
self.model_8bit = SanaTransformer2DModel.from_pretrained(
|
||||
@@ -331,7 +332,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
|
||||
del self.model_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_buffers_device_assignment(self):
|
||||
for buffer_name, buffer in self.model_8bit.named_buffers():
|
||||
@@ -345,7 +346,7 @@ class Bnb8bitDeviceTests(Base8bitTests):
|
||||
class BnB8bitTrainingTests(Base8bitTests):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
self.model_8bit = SD3Transformer2DModel.from_pretrained(
|
||||
@@ -389,7 +390,7 @@ class BnB8bitTrainingTests(Base8bitTests):
|
||||
class SlowBnb8bitTests(Base8bitTests):
|
||||
def setUp(self) -> None:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
model_8bit = SD3Transformer2DModel.from_pretrained(
|
||||
@@ -404,7 +405,7 @@ class SlowBnb8bitTests(Base8bitTests):
|
||||
del self.pipeline_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quality(self):
|
||||
output = self.pipeline_8bit(
|
||||
@@ -616,7 +617,7 @@ class SlowBnb8bitTests(Base8bitTests):
|
||||
class SlowBnb8bitFluxTests(Base8bitTests):
|
||||
def setUp(self) -> None:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
|
||||
t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
|
||||
@@ -633,7 +634,7 @@ class SlowBnb8bitFluxTests(Base8bitTests):
|
||||
del self.pipeline_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quality(self):
|
||||
# keep the resolution and max tokens to a lower number for faster execution.
|
||||
@@ -680,7 +681,7 @@ class SlowBnb8bitFluxTests(Base8bitTests):
|
||||
class BaseBnb8bitSerializationTests(Base8bitTests):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
@@ -693,7 +694,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
|
||||
del self.model_0
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_serialization(self):
|
||||
r"""
|
||||
|
||||
@@ -64,7 +64,7 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 167.47821044921875) < 1e-2
|
||||
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
|
||||
elif torch_device in ["cuda"]:
|
||||
elif torch_device in ["cuda", "xpu"]:
|
||||
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
|
||||
assert abs(result_mean.item() - 0.22342906892299652) < 1e-3
|
||||
else:
|
||||
@@ -96,7 +96,7 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 124.77149200439453) < 1e-2
|
||||
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
|
||||
elif torch_device in ["cuda"]:
|
||||
elif torch_device in ["cuda", "xpu"]:
|
||||
assert abs(result_sum.item() - 128.1663360595703) < 1e-2
|
||||
assert abs(result_mean.item() - 0.16688326001167297) < 1e-3
|
||||
else:
|
||||
@@ -127,7 +127,7 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 167.46957397460938) < 1e-2
|
||||
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
|
||||
elif torch_device in ["cuda"]:
|
||||
elif torch_device in ["cuda", "xpu"]:
|
||||
assert abs(result_sum.item() - 171.59353637695312) < 1e-2
|
||||
assert abs(result_mean.item() - 0.22342908382415771) < 1e-3
|
||||
else:
|
||||
@@ -159,7 +159,7 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
|
||||
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
|
||||
elif torch_device in ["cuda"]:
|
||||
elif torch_device in ["cuda", "xpu"]:
|
||||
assert abs(result_sum.item() - 177.63653564453125) < 1e-2
|
||||
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user