Compare commits

..

4 Commits

Author SHA1 Message Date
sayakpaul f9e27de31a start 2025-06-18 17:05:01 +05:30
Sayak Paul 05e867784d [tests] device_map tests for all models. (#11708)
* device_map tests for all models.

* updates

* Update tests/models/test_modeling_common.py

Co-authored-by: Aryan <aryan@huggingface.co>

* fix device_map in test

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-06-18 10:52:06 +05:30
Leo Jiang d72184eba3 [training] add ds support to lora hidream (#11737)
* [training] add ds support to lora hidream

* Apply style fixes

---------

Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-06-18 09:26:02 +05:30
Saurabh Misra 5ce4814af1 ️ Speed up method AutoencoderKLWan.clear_cache by 886% (#11665)
* ️ Speed up method `AutoencoderKLWan.clear_cache` by 886%

**Key optimizations:**
- Compute the number of `WanCausalConv3d` modules in each model (`encoder`/`decoder`) **only once during initialization**, store in `self._cached_conv_counts`. This removes unnecessary repeated tree traversals at every `clear_cache` call, which was the main bottleneck (from profiling).
- The internal helper `_count_conv3d_fast` is optimized via a generator expression with `sum` for efficiency.

All comments from the original code are preserved, except for updated or removed local docstrings/comments relevant to changed lines.  
**Function signatures and outputs remain unchanged.**

* Apply style fixes

* Apply suggestions from code review

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Apply style fixes

---------

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>
2025-06-18 08:46:03 +05:30
7 changed files with 1498 additions and 69 deletions
-13
View File
@@ -302,13 +302,6 @@ compute-bound, [group-offloading](#group-offloading) tends to be better. Group o
</Tip>
<Tip>
When using offloading, users can additionally compile the diffusion transformer/unet to get a
good speed-memory trade-off. First set `torch._dynamo.config.cache_size_limit=1000`, and then before calling the pipeline, add `pipeline.transformer.compile()`.
</Tip>
## Layerwise casting
Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
@@ -372,12 +365,6 @@ apply_layerwise_casting(
)
```
<Tip>
Layerwise casting can be combined with group offloading.
</Tip>
## torch.channels_last
[torch.channels_last](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) flips how tensors are stored from `(batch size, channels, height, width)` to `(batch size, heigh, width, channels)`. This aligns the tensors with how the hardware sequentially accesses the tensors stored in memory and avoids skipping around in memory to access the pixel values.
@@ -29,7 +29,7 @@ from pathlib import Path
import numpy as np
import torch
import transformers
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
@@ -1181,13 +1181,15 @@ def main(args):
transformer_lora_layers_to_save = None
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
model = unwrap_model(model)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()
HiDreamImagePipeline.save_lora_weights(
output_dir,
@@ -1197,13 +1199,20 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
while len(models) > 0:
model = models.pop()
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
model = unwrap_model(model)
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer"
)
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)
@@ -1655,7 +1664,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -749,6 +749,16 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
if self.decoder is not None
else 0,
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
if self.encoder is not None
else 0,
}
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
@@ -801,18 +811,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_slicing = False
def clear_cache(self):
def _count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, WanCausalConv3d):
count += 1
return count
self._conv_num = _count_conv3d(self.decoder)
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = _count_conv3d(self.encoder)
self._enc_conv_num = self._cached_conv_counts["encoder"]
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
File diff suppressed because it is too large Load Diff
+300
View File
@@ -0,0 +1,300 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Doc utilities: Utilities related to documentation
Adapted from:
https://github.com/huggingface/transformers/blob/5a95ed5ca0826c867e35e52f698db4d8fc907bcb/src/transformers/utils/doc.py
"""
import functools
import inspect
import re
import textwrap
import types
from collections import OrderedDict
from ..pipelines.auto_pipeline import AUTO_TEXT2IMAGE_PIPELINES_MAPPING
def get_docstring_indentation_level(func):
"""Return the indentation level of the start of the docstring of a class or function (or method)."""
# We assume classes are always defined in the global scope
if inspect.isclass(func):
return 4
source = inspect.getsource(func)
first_line = source.splitlines()[0]
function_def_level = len(first_line) - len(first_line.lstrip())
return 4 + function_def_level
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
return fn
return docstring_decorator
def add_start_docstrings_to_model_forward(*docstr):
def docstring_decorator(fn):
class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
intro = rf""" The {class_name} forward method, overrides the `__call__` special method.
<Tip>
Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.
</Tip>
"""
correct_indentation = get_docstring_indentation_level(fn)
current_doc = fn.__doc__ if fn.__doc__ is not None else ""
try:
first_non_empty = next(line for line in current_doc.splitlines() if line.strip() != "")
doc_indentation = len(first_non_empty) - len(first_non_empty.lstrip())
except StopIteration:
doc_indentation = correct_indentation
docs = docstr
# In this case, the correct indentation level (class method, 2 Python levels) was respected, and we should
# correctly reindent everything. Otherwise, the doc uses a single indentation level
if doc_indentation == 4 + correct_indentation:
docs = [textwrap.indent(textwrap.dedent(doc), " " * correct_indentation) for doc in docstr]
intro = textwrap.indent(textwrap.dedent(intro), " " * correct_indentation)
docstring = "".join(docs) + current_doc
fn.__doc__ = intro + docstring
return fn
return docstring_decorator
def add_end_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
return fn
return docstring_decorator
PT_RETURN_INTRODUCTION = r"""
Returns:
[`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed) comprising various
elements depending on the model and inputs.
"""
TEXT_TO_IMAGE_PIPELINE_CLASSES = list({p[0] for p in AUTO_TEXT2IMAGE_PIPELINES_MAPPING})
def _get_indent(t):
"""Returns the indentation in the first line of t"""
search = re.search(r"^(\s*)\S", t)
return "" if search is None else search.groups()[0]
def _convert_output_args_doc(output_args_doc):
"""Convert output_args_doc to display properly."""
# Split output_arg_doc in blocks argument/description
indent = _get_indent(output_args_doc)
blocks = []
current_block = ""
for line in output_args_doc.split("\n"):
# If the indent is the same as the beginning, the line is the name of new arg.
if _get_indent(line) == indent:
if len(current_block) > 0:
blocks.append(current_block[:-1])
current_block = f"{line}\n"
else:
# Otherwise it's part of the description of the current arg.
# We need to remove 2 spaces to the indentation.
current_block += f"{line[2:]}\n"
blocks.append(current_block[:-1])
# Format each block for proper rendering
for i in range(len(blocks)):
blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
return "\n".join(blocks)
def _prepare_output_docstrings(output_type, config_class, min_indent=None, add_intro=True):
"""
Prepares the return part of the docstring using `output_type`.
"""
output_docstring = output_type.__doc__
params_docstring = None
if output_docstring is not None:
# Remove the head of the docstring to keep the list of args only
lines = output_docstring.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
i += 1
if i < len(lines):
params_docstring = "\n".join(lines[(i + 1) :])
params_docstring = _convert_output_args_doc(params_docstring)
elif add_intro:
raise ValueError(
f"No `Args` or `Parameters` section is found in the docstring of `{output_type.__name__}`. Make sure it has "
"docstring and contain either `Args` or `Parameters`."
)
# Add the return introduction
if add_intro:
full_output_type = f"{output_type.__module__}.{output_type.__name__}"
intro = PT_RETURN_INTRODUCTION
intro = intro.format(full_output_type=full_output_type, config_class=config_class)
else:
full_output_type = str(output_type)
intro = f"\nReturns:\n `{full_output_type}`"
if params_docstring is not None:
intro += ":\n"
result = intro
if params_docstring is not None:
result += params_docstring
# Apply minimum indent if necessary
if min_indent is not None:
lines = result.split("\n")
# Find the indent of the first nonempty line
i = 0
while len(lines[i]) == 0:
i += 1
indent = len(_get_indent(lines[i]))
# If too small, add indentation to all nonempty lines
if indent < min_indent:
to_add = " " * (min_indent - indent)
lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
result = "\n".join(lines)
return result
FAKE_MODEL_DISCLAIMER = """
<Tip warning={true}>
This example uses a random model as the real ones are all very big. To get proper results, you should use
{real_checkpoint} instead of {fake_checkpoint}. If you get out-of-memory when loading that checkpoint, you can
refer to our optimization docs.
</Tip>
"""
PT_TEXT_TO_IMAGE_SAMPLE = r"""
Example:
```python
>>> from diffusers import DiffusionPipeline
>>> import torch
>>> # If memory doesn't allow, enable optimizations like `enable_model_cpu_offload()`.
>>> pipe = DiffusionPipeline.from_pretrained("{checkpoint}", torch_dtype=torch.bfloat16).to("cuda")
>>> prompt = "a photo of a cute dog."
>>> image = pipe(prompt).images[0] # Configure other pipe call arguments as needed.
```
"""
PT_SAMPLE_DOCSTRINGS = {
"Text2Image": PT_TEXT_TO_IMAGE_SAMPLE
}
PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS = OrderedDict(["text-to-image", PT_TEXT_TO_IMAGE_SAMPLE])
def filter_outputs_from_example(docstring, **kwargs):
"""
Removes the lines testing an output with the doctest syntax in a code sample when it's set to `None`.
"""
for key, value in kwargs.items():
if value is not None:
continue
doc_key = "{" + key + "}"
docstring = re.sub(rf"\n([^\n]+)\n\s+{doc_key}\n", "\n", docstring)
return docstring
def add_code_sample_docstrings(
*docstr,
checkpoint=None,
output_type=None,
config_class=None,
model_cls=None,
):
def docstring_decorator(fn):
# model_class defaults to function's class if not specified otherwise
model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
sample_docstrings = PT_SAMPLE_DOCSTRINGS
# putting all kwargs for docstrings in a dict to be used
# with the `.format(**doc_kwargs)`. Note that string might
# be formatted with non-existing keys, which is fine.
doc_kwargs = {
"checkpoint": checkpoint,
"true": "{true}", # For <Tip warning={true}> syntax that conflicts with formatting.
}
if model_class in TEXT_TO_IMAGE_PIPELINE_CLASSES:
code_sample = sample_docstrings["Text2Image"]
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
code_sample = filter_outputs_from_example(code_sample)
func_doc = (fn.__doc__ or "") + "".join(docstr)
output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
built_doc = code_sample.format(**doc_kwargs)
fn.__doc__ = func_doc + output_doc + built_doc
return fn
return docstring_decorator
def replace_return_docstrings(output_type=None, config_class=None):
def docstring_decorator(fn):
func_doc = fn.__doc__
lines = func_doc.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
i += 1
if i < len(lines):
indent = len(_get_indent(lines[i]))
lines[i] = _prepare_output_docstrings(output_type, config_class, min_indent=indent)
func_doc = "\n".join(lines)
else:
raise ValueError(
f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, "
f"current docstring is:\n{func_doc}"
)
fn.__doc__ = func_doc
return fn
return docstring_decorator
def copy_func(f):
"""Returns a copy of a function f."""
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = f.__kwdefaults__
return g
+39
View File
@@ -1736,6 +1736,45 @@ class ModelTesterMixin:
f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}",
)
@parameterized.expand(
[
(-1, "You can't pass device_map as a negative int"),
("foo", "When passing device_map as a string, the value needs to be a device name"),
]
)
def test_wrong_device_map_raises_error(self, device_map, msg_substring):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
with self.assertRaises(ValueError) as err_ctx:
_ = self.model_class.from_pretrained(tmpdir, device_map=device_map)
assert msg_substring in str(err_ctx.exception)
@parameterized.expand([0, "cuda", torch.device("cuda")])
@require_torch_gpu
def test_passing_non_dict_device_map_works(self, device_map):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).eval()
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
_ = loaded_model(**inputs_dict)
@parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
@require_torch_gpu
def test_passing_dict_device_map_works(self, name, device):
# There are other valid dict-based `device_map` values too. It's best to refer to
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).eval()
device_map = {name: device}
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
_ = loaded_model(**inputs_dict)
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
@@ -46,7 +46,6 @@ from diffusers.utils.testing_utils import (
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_gpu,
skip_mps,
slow,
torch_all_close,
@@ -1084,42 +1083,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@parameterized.expand(
[
(-1, "You can't pass device_map as a negative int"),
("foo", "When passing device_map as a string, the value needs to be a device name"),
]
)
def test_wrong_device_map_raises_error(self, device_map, msg_substring):
with self.assertRaises(ValueError) as err_ctx:
_ = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
)
assert msg_substring in str(err_ctx.exception)
@parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
@require_torch_gpu
def test_passing_non_dict_device_map_works(self, device_map):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
)
output = loaded_model(**inputs_dict)
assert output.sample.shape == (4, 4, 16, 16)
@parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
@require_torch_gpu
def test_passing_dict_device_map_works(self, name, device_map):
# There are other valid dict-based `device_map` values too. It's best to refer to
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map}
)
output = loaded_model(**inputs_dict)
assert output.sample.shape == (4, 4, 16, 16)
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()