Compare commits

..

3 Commits

Author SHA1 Message Date
Sayak Paul 687982e607 Merge branch 'main' into chroma-docs 2025-06-19 20:19:14 +05:30
DN6 802651e205 update 2025-06-19 19:41:32 +05:30
DN6 907ecf72b1 update 2025-06-19 14:20:40 +05:30
8 changed files with 71 additions and 1992 deletions
+49 -17
View File
@@ -27,9 +27,36 @@ Chroma can use all the same optimizations as Flux.
</Tip>
## Inference (Single File)
## Inference
The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
```python
import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
]
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
generator=torch.Generator("cpu").manual_seed(433),
num_inference_steps=40,
guidance_scale=3.0,
num_images_per_prompt=1,
).images[0]
image.save("chroma.png")
```
## Loading from a single file
To use updated model checkpoints that are not in the Diffusers format, you can use the `ChromaTransformer2DModel` class to load the model from a single file in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
The following example demonstrates how to run Chroma from a single file.
@@ -38,30 +65,29 @@ Then run the following example
```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
from transformers import T5EncoderModel
bfl_repo = "black-forest-labs/FLUX.1-dev"
model_id = "lodestones/Chroma"
dtype = torch.bfloat16
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)
pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
]
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
image = pipe(
prompt,
guidance_scale=4.0,
output_type="pil",
num_inference_steps=26,
generator=torch.Generator("cpu").manual_seed(0)
prompt=prompt,
negative_prompt=negative_prompt,
generator=torch.Generator("cpu").manual_seed(433),
num_inference_steps=40,
guidance_scale=3.0,
).images[0]
image.save("image.png")
image.save("chroma-single-file.png")
```
## ChromaPipeline
@@ -69,3 +95,9 @@ image.save("image.png")
[[autodoc]] ChromaPipeline
- all
- __call__
## ChromaImg2ImgPipeline
[[autodoc]] ChromaImg2ImgPipeline
- all
- __call__
+8 -37
View File
@@ -20,7 +20,6 @@ import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
from ..utils.import_utils import is_deepspeed_available, is_deepspeed_version
from .hooks import HookRegistry, ModelHook
@@ -28,8 +27,6 @@ if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload
from accelerate.utils import send_to_device
if is_deepspeed_available() and is_deepspeed_version(">=", "0.16"):
from ..utils.state_dict_utils import _fast_aio_save
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -65,7 +62,6 @@ class ModuleGroup:
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -84,9 +80,7 @@ class ModuleGroup:
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
self._enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading
ext = ".pt" if _enable_deepnvme_disk_offloading else ".safetensors"
self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}{ext}")
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
all_tensors = []
for module in self.modules:
@@ -159,11 +153,8 @@ class ModuleGroup:
with context:
if self.stream is not None:
# Load to CPU from disk, pin, and async copy to device for overlapping transfer and compute
if self._enable_deepnvme_disk_offloading:
loaded_cpu_tensors = torch.load(self.param_file_path, weights_only=True, map_location="cpu")
else:
loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu")
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
@@ -174,12 +165,7 @@ class ModuleGroup:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
if self._enable_deepnvme_disk_offloading:
loaded_tensors = torch.load(
self.param_file_path, weights_only=True, map_location=onload_device
)
else:
loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
@@ -232,18 +218,15 @@ class ModuleGroup:
if self.offload_to_disk_path:
# TODO: we can potentially optimize this code path by checking if _all_ the desired
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
# overhead. Currently, we just check if the given `param_file_path` exists and if not
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
# we perform a write.
# Check if the file has been saved in this session or if it already exists on disk.
if not self._is_offloaded_to_disk and not os.path.exists(self.param_file_path):
os.makedirs(os.path.dirname(self.param_file_path), exist_ok=True)
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
tensors_to_save = {
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
}
if self._enable_deepnvme_disk_offloading:
_fast_aio_save(tensors_to_save, self.param_file_path)
else:
safetensors.torch.save_file(tensors_to_save, self.param_file_path)
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
# The group is now considered offloaded to disk for the rest of the session.
self._is_offloaded_to_disk = True
@@ -443,7 +426,6 @@ def apply_group_offloading(
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -502,8 +484,6 @@ def apply_group_offloading(
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
(TODO: include example with `offload_to_disk_path`)
Example:
```python
>>> from diffusers import CogVideoXTransformer3DModel
@@ -549,7 +529,6 @@ def apply_group_offloading(
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
@@ -561,7 +540,6 @@ def apply_group_offloading(
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -577,7 +555,6 @@ def _apply_group_offloading_block_level(
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -638,7 +615,6 @@ def _apply_group_offloading_block_level(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
@@ -673,7 +649,6 @@ def _apply_group_offloading_block_level(
stream=None,
record_stream=False,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
if stream is None:
_apply_group_offloading_hook(module, unmatched_group, None)
@@ -690,7 +665,6 @@ def _apply_group_offloading_leaf_level(
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -741,7 +715,6 @@ def _apply_group_offloading_leaf_level(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
_apply_group_offloading_hook(submodule, group, None)
modules_with_group_offloading.add(name)
@@ -789,7 +762,6 @@ def _apply_group_offloading_leaf_level(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
_apply_group_offloading_hook(parent_module, group, None)
@@ -811,7 +783,6 @@ def _apply_group_offloading_leaf_level(
record_stream=False,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
-2
View File
@@ -549,7 +549,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
r"""
Activates group offloading for the current model.
@@ -600,7 +599,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
def save_pretrained(
@@ -52,20 +52,21 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaPipeline, ChromaTransformer2DModel
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-schnell",
>>> pipe = ChromaPipeline.from_pretrained(
... model_id,
... transformer=transformer,
... text_encoder=text_encoder,
... tokenizer=tokenizer,
... torch_dtype=torch.bfloat16,
... )
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A cat holding a sign that says hello world"
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
>>> prompt = [
... "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
... ]
>>> negative_prompt = [
... "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
... ]
>>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
>>> image.save("chroma.png")
```
@@ -51,26 +51,21 @@ EXAMPLE_DOC_STRING = """
```py
>>> import torch
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
>>> from transformers import AutoModel, AutoTokenizer
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-schnell",
... model_id,
... transformer=transformer,
... text_encoder=text_encoder,
... tokenizer=tokenizer,
... torch_dtype=torch.bfloat16,
... )
>>> pipe.enable_model_cpu_offload()
>>> image = load_image(
>>> init_image = load_image(
... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
... )
>>> prompt = "a scenic fantasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
>>> image = pipe(prompt, image=image, negative_prompt=negative_prompt).images[0]
>>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
>>> image.save("chroma-img2img.png")
```
"""
File diff suppressed because it is too large Load Diff
-18
View File
@@ -220,11 +220,6 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
_deepspeed_available, _deepspeed_version = _is_package_available("deepspeed")
def is_deepspeed_available():
    """
    Return the module-level `_deepspeed_available` flag, i.e. the availability result recorded for the `deepspeed`
    package.
    """
    return _deepspeed_available
def is_torch_available():
@@ -660,19 +655,6 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)
def is_deepspeed_version(operation: str, version: str):
    """
    Check the recorded DeepSpeed version against a reference version using the given comparison.

    Args:
        operation (`str`):
            String form of a comparison operator, for example `">"` or `"<="`.
        version (`str`):
            The DeepSpeed version string to compare against.
    """
    installed_version = parse(_deepspeed_version)
    return compare_versions(installed_version, operation, version)
def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.
+1 -58
View File
@@ -18,19 +18,13 @@ State dict utilities: utility methods for converting state dicts easily
import enum
import json
from .import_utils import is_deepspeed_available, is_deepspeed_version, is_torch_available
from .import_utils import is_torch_available
from .logging import get_logger
if is_torch_available():
import torch
if is_deepspeed_available() and is_deepspeed_version(">", "0.16"):
from deepspeed.io import FastFileWriter, FastFileWriterConfig
from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder
from .deep_nvme_utils import save as _nvme_save
logger = get_logger(__name__)
@@ -370,54 +364,3 @@ def _load_sft_state_dict_metadata(model_file: str):
return json.loads(raw) if raw else None
else:
return None
# Utilities below are taken from
# https://github.com/deepspeedai/DeepSpeedExamples/blob/28a984e77b8d096dadc6389b6d1440b823587e28/deepnvme/model_checkpoint/torch_save_utils.py#L16
def _load_io_ops(args):
    """
    Load the DeepSpeed I/O op builders that report themselves as compatible.

    Args:
        args: Object exposing a `gpu` attribute; the GDS builder is only considered when `args.gpu` is truthy.
    """
    aio_supported = AsyncIOBuilder().is_compatible()
    if aio_supported:
        AsyncIOBuilder().load(verbose=False)

    # Skip the GDS compatibility probe entirely when GPU usage was not requested.
    if not args.gpu:
        return
    if GDSBuilder().is_compatible():
        GDSBuilder().load(verbose=False)
def _get_aio_handle():
    """
    Load the DeepSpeed async-I/O op and return an `aio_handle` created with fixed settings.
    """
    single_submit = False
    handle_kwargs = {
        "block_size": 8 * (1024**2),
        "queue_depth": 8,
        "single_submit": single_submit,
        # `overlap_events` is given the same value as `single_submit`.
        "overlap_events": single_submit,
        "intra_op_parallelism": 1,
    }
    aio_op = AsyncIOBuilder().load(verbose=False)
    return aio_op.aio_handle(**handle_kwargs)
def _get_aio_components():
    """
    Return a tuple `(handle, pinned_buffer)`: the handle from `_get_aio_handle()` and a zero-filled, pinned CPU
    `torch.uint8` tensor holding `64 * 1024**2` elements.
    """
    aio_handle = _get_aio_handle()
    buffer_numel = 64 * (1024**2)
    cpu_buffer = torch.zeros(buffer_numel, dtype=torch.uint8, device="cpu")
    return aio_handle, cpu_buffer.pin_memory()
def _fast_aio_save(buffer, file, single_io_buffer=False):
    """
    Write `buffer` to `file` through a DeepSpeed `FastFileWriter`.

    Args:
        buffer:
            Object passed to `_nvme_save` as `obj`.
        file:
            Destination passed to `FastFileWriter` as `file_path`.
        single_io_buffer (`bool`, defaults to `False`):
            When `False`, the writer config is created with `double_buffer=True`.
    """
    aio_handle, pinned_buffer = _get_aio_components()
    writer_config = FastFileWriterConfig(
        dnvme_handle=aio_handle,
        pinned_tensor=pinned_buffer,
        double_buffer=not single_io_buffer,
        num_parallel_writers=1,
        writer_rank=0,
    )
    writer = FastFileWriter(file_path=file, config=writer_config)
    _nvme_save(f=writer, obj=buffer, _use_new_zipfile_serialization=False)
    # Close only after the save call returns, matching the original sequencing.
    writer.close()