Compare commits

..

8 Commits

Author SHA1 Message Date
Aryan 278b3b8825 fix geglu and swiglu gate computation 2025-01-25 06:43:40 +01:00
Aryan d7f369cbab memory-optimized ff 2025-01-21 23:18:50 +01:00
Lucain a647682224 Remove cache migration script (#10619) 2025-01-21 07:22:59 -10:00
YiYi Xu a1f9a71238 fix offload gpu tests etc (#10366)
* add

* style
2025-01-21 18:52:36 +05:30
Fanli Lin ec37e20972 [tests] make tests device-agnostic (part 3) (#10437)
* initial comit

* fix empty cache

* fix one more

* fix style

* update device functions

* update

* update

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/controlnet/test_controlnet.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/controlnet/test_controlnet.py

Co-authored-by: hlky <hlky@hlky.ac>

* with gc.collect

* update

* make style

* check_torch_dependencies

* add mps empty cache

* bug fix

* Apply suggestions from code review

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-01-21 12:15:45 +00:00
Muyang Li 158a5a87fb Remove the FP32 Wrapper when evaluating (#10617)
Remove the FP32 Wrapper

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-01-21 16:16:54 +05:30
jiqing-feng 012d08b1bc Enable dreambooth lora finetune example on other devices (#10602)
* enable dreambooth_lora on other devices

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable xpu

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* check cuda device before empty cache

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* import free_memory

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
2025-01-21 14:09:45 +05:30
Sayak Paul 4ace7d0483 [chore] change licensing to 2025 from 2024. (#10615)
change licensing to 2025 from 2024.
2025-01-20 16:57:27 -10:00
151 changed files with 645 additions and 427 deletions
+3 -3
View File
@@ -265,7 +265,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -505,7 +505,7 @@ jobs:
# shell: arch -arch arm64 bash {0}
# env:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
# run: |
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
@@ -561,7 +561,7 @@ jobs:
# shell: arch -arch arm64 bash {0}
# env:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
# run: |
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
+5 -5
View File
@@ -187,7 +187,7 @@ jobs:
- name: Run Flax TPU tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
@@ -235,7 +235,7 @@ jobs:
- name: Run ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
@@ -283,7 +283,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
@@ -326,7 +326,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
@@ -372,7 +372,7 @@ jobs:
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install timm
+8 -8
View File
@@ -81,7 +81,7 @@ jobs:
python utils/print_env.py
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -135,7 +135,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -186,7 +186,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -241,7 +241,7 @@ jobs:
- name: Run slow Flax TPU tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
@@ -289,7 +289,7 @@ jobs:
- name: Run slow ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
@@ -337,7 +337,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
@@ -380,7 +380,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
@@ -426,7 +426,7 @@ jobs:
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install timm
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
#
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -1,5 +1,5 @@
#
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -1,5 +1,5 @@
#
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+4 -4
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1716,9 +1716,9 @@ def main(args):
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
+12 -9
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,7 +54,11 @@ from diffusers import (
)
from diffusers.loaders import StableDiffusionLoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
free_memory,
)
from diffusers.utils import (
check_min_version,
convert_state_dict_to_diffusers,
@@ -151,14 +155,14 @@ def log_validation(
if args.validation_images is None:
images = []
for _ in range(args.num_validation_images):
with torch.cuda.amp.autocast():
with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, generator=generator).images[0]
images.append(image)
else:
images = []
for image in args.validation_images:
image = Image.open(image)
with torch.cuda.amp.autocast():
with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)
@@ -177,7 +181,7 @@ def log_validation(
)
del pipeline
torch.cuda.empty_cache()
free_memory()
return images
@@ -793,7 +797,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
@@ -829,8 +833,7 @@ def main(args):
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()
# Handle the repository creation
if accelerator.is_main_process:
@@ -1085,7 +1088,7 @@ def main(args):
tokenizer = None
gc.collect()
torch.cuda.empty_cache()
free_memory()
else:
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+10
View File
@@ -1215,10 +1215,20 @@ class FeedForward(nn.Module):
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self._dim = dim
self._dim_out = dim_out
self._mult = mult
self._dropout = dropout
self._activation_fn = activation_fn
self._final_dropout = final_dropout
self._inner_dim = inner_dim
self._bias = bias
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
+186
View File
@@ -0,0 +1,186 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from ..utils import logging
from .activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU
from .attention import FeedForward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class _MemoryOptimizedFeedForward(torch.nn.Module):
r"""
See [`~models.attention.FeedForward`] parameter documentation. This class is a copy of the FeedForward class. The
only difference is that this module is optimized for memory.
This method achieves memory savings by applying the ideas of tensor-parallelism sequentially. Input projection
layers are split column-wise and output projection layers are split row-wise. This allows for the computation of
the feedforward pass to occur without ever materializing the full intermediate tensor. Typically, the intermediate
tensor takes 4x-8x more memory than the input tensor. This method reduces that with a small performance tradeoff.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim: Optional[int] = None,
bias: bool = True,
num_splits: int = 4,
) -> None:
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
dim_split = inner_dim // num_splits
if inner_dim % dim_split != 0:
raise ValueError(f"inner_dim must be divisible by {mult=}, or {num_splits=} if provided.")
self._dim = dim
self._dim_out = dim_out
self._mult = mult
self._dropout = dropout
self._activation_fn = activation_fn
self._final_dropout = final_dropout
self._inner_dim = inner_dim
self._bias = bias
self._num_splits = num_splits
def get_activation_fn(dim_: int, inner_dim_: int):
if activation_fn == "gelu":
act_fn = GELU(dim_, inner_dim_, bias=bias)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim_, inner_dim_, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
act_fn = GEGLU(dim_, inner_dim_, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim_, inner_dim_, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim_, inner_dim_, bias=bias)
elif activation_fn == "linear-silu":
act_fn = LinearActivation(dim_, inner_dim_, bias=bias, activation="silu")
return act_fn
# Split column-wise
self.proj_in = torch.nn.ModuleList([get_activation_fn(dim, dim_split) for _ in range(inner_dim // dim_split)])
self.dropout = torch.nn.Dropout(dropout)
# Split row-wise
self.proj_out = torch.nn.ModuleList(
[torch.nn.Linear(dim_split, dim_out, bias=False) for _ in range(inner_dim // dim_split)]
)
self.bias = None
if bias:
self.bias = torch.nn.Parameter(torch.zeros(dim_out))
self.final_dropout = None
if final_dropout:
self.final_dropout = torch.nn.Dropout(dropout)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Output tensor for "all_reduce" operation
output = hidden_states.new_zeros(hidden_states.shape)
# Apply feedforward pass sequentially since this is intended for memory optimization on a single GPU
for proj_in, proj_out in zip(self.proj_in, self.proj_out):
out = proj_in(hidden_states)
out = self.dropout(out)
out = proj_out(out)
# Perform "all_reduce"
output += out
if self.bias is not None:
output += self.bias
if self.final_dropout is not None:
output = self.final_dropout(output)
return output
def apply_memory_optimized_feedforward(module: torch.nn.Module, num_splits: Optional[int] = None) -> torch.nn.Module:
module_dict = dict(module.named_modules())
for name, submodule in module_dict.items():
if not isinstance(submodule, FeedForward):
continue
logger.debug(f"Applying memory optimized feedforward to layer '{name}'")
state_dict = submodule.state_dict()
num_splits = submodule._mult if num_splits is None else num_splits
# remap net.0.proj.weight
if isinstance(submodule.net[0], (GEGLU, SwiGLU)):
net_0_proj = state_dict.pop("net.0.proj.weight")
proj, gate = net_0_proj.chunk(2, dim=0)
proj = proj.chunk(num_splits, dim=0)
gate = gate.chunk(num_splits, dim=0)
for i in range(num_splits):
state_dict[f"proj_in.{i}.proj.weight"] = torch.cat([proj[i], gate[i]], dim=0)
else:
net_0_proj = state_dict.pop("net.0.proj.weight")
net_0_proj = net_0_proj.chunk(num_splits, dim=0)
for i in range(num_splits):
state_dict[f"proj_in.{i}.proj.weight"] = net_0_proj[i]
# remap net.0.proj.bias
if "net.0.proj.bias" in state_dict:
net_0_proj_bias = state_dict.pop("net.0.proj.bias")
net_0_proj_bias = net_0_proj_bias.chunk(num_splits, dim=0)
for i in range(num_splits):
state_dict[f"proj_in.{i}.proj.bias"] = net_0_proj_bias[i]
# remap net.2.weight
net_2_weight = state_dict.pop("net.2.weight")
net_2_weight = net_2_weight.chunk(num_splits, dim=1)
for i in range(num_splits):
state_dict[f"proj_out.{i}.weight"] = net_2_weight[i]
# remap net.2.bias
if "net.2.bias" in state_dict:
net_2_bias = state_dict.pop("net.2.bias")
state_dict["bias"] = net_2_bias
with torch.device("meta"):
new_ff = _MemoryOptimizedFeedForward(
dim=submodule._dim,
dim_out=submodule._dim_out,
mult=submodule._mult,
dropout=submodule._dropout,
activation_fn=submodule._activation_fn,
final_dropout=submodule._final_dropout,
inner_dim=submodule._inner_dim,
bias=submodule._bias,
num_splits=num_splits,
)
new_ff.load_state_dict(state_dict, strict=True, assign=True)
parent_module_name, _, submodule_name = name.rpartition(".")
parent_module = module_dict[parent_module_name]
setattr(parent_module, submodule_name, new_ff)
return module
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -82,6 +82,20 @@ class GLUMBConv(nn.Module):
return hidden_states
class SanaModulatedNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
) -> torch.Tensor:
hidden_states = self.norm(hidden_states)
shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
return hidden_states
class SanaTransformerBlock(nn.Module):
r"""
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -221,7 +235,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
@register_to_config
def __init__(
@@ -288,8 +302,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 4. Output blocks
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
@@ -462,13 +475,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
# 3. Normalization
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
# 4. Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+1 -1
View File
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

Some files were not shown because too many files have changed in this diff Show More