Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 687982e607 | |||
| 802651e205 | |||
| 907ecf72b1 |
@@ -27,9 +27,36 @@ Chroma can use all the same optimizations as Flux.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Inference (Single File)
|
||||
## Inference
|
||||
|
||||
The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
|
||||
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ChromaPipeline
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = [
|
||||
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
|
||||
]
|
||||
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=torch.Generator("cpu").manual_seed(433),
|
||||
num_inference_steps=40,
|
||||
guidance_scale=3.0,
|
||||
num_images_per_prompt=1,
|
||||
).images[0]
|
||||
image.save("chroma.png")
|
||||
```
|
||||
|
||||
## Loading from a single file
|
||||
|
||||
To use updated model checkpoints that are not in the Diffusers format, you can use the `ChromaTransformer2DModel` class to load the model from a single file in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
|
||||
|
||||
The following example demonstrates how to run Chroma from a single file.
|
||||
|
||||
@@ -38,30 +65,29 @@ Then run the following example
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ChromaTransformer2DModel, ChromaPipeline
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
bfl_repo = "black-forest-labs/FLUX.1-dev"
|
||||
model_id = "lodestones/Chroma"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
prompt = [
|
||||
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
|
||||
]
|
||||
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
guidance_scale=4.0,
|
||||
output_type="pil",
|
||||
num_inference_steps=26,
|
||||
generator=torch.Generator("cpu").manual_seed(0)
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=torch.Generator("cpu").manual_seed(433),
|
||||
num_inference_steps=40,
|
||||
guidance_scale=3.0,
|
||||
).images[0]
|
||||
|
||||
image.save("image.png")
|
||||
image.save("chroma-single-file.png")
|
||||
```
|
||||
|
||||
## ChromaPipeline
|
||||
@@ -69,3 +95,9 @@ image.save("image.png")
|
||||
[[autodoc]] ChromaPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ChromaImg2ImgPipeline
|
||||
|
||||
[[autodoc]] ChromaImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -20,7 +20,6 @@ import safetensors.torch
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger, is_accelerate_available
|
||||
from ..utils.import_utils import is_deepspeed_available, is_deepspeed_version
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
@@ -28,8 +27,6 @@ if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload
|
||||
from accelerate.utils import send_to_device
|
||||
|
||||
if is_deepspeed_available() and is_deepspeed_version(">=", "0.16"):
|
||||
from ..utils.state_dict_utils import _fast_aio_save
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -65,7 +62,6 @@ class ModuleGroup:
|
||||
low_cpu_mem_usage: bool = False,
|
||||
onload_self: bool = True,
|
||||
offload_to_disk_path: Optional[str] = None,
|
||||
_enable_deepnvme_disk_offloading: Optional[bool] = False,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
self.offload_device = offload_device
|
||||
@@ -84,9 +80,7 @@ class ModuleGroup:
|
||||
self._is_offloaded_to_disk = False
|
||||
|
||||
if self.offload_to_disk_path:
|
||||
self._enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading
|
||||
ext = ".pt" if _enable_deepnvme_disk_offloading else ".safetensors"
|
||||
self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}{ext}")
|
||||
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
|
||||
|
||||
all_tensors = []
|
||||
for module in self.modules:
|
||||
@@ -159,11 +153,8 @@ class ModuleGroup:
|
||||
|
||||
with context:
|
||||
if self.stream is not None:
|
||||
# Load to CPU from disk, pin, and async copy to device for overlapping transfer and compute
|
||||
if self._enable_deepnvme_disk_offloading:
|
||||
loaded_cpu_tensors = torch.load(self.param_file_path, weights_only=True, map_location="cpu")
|
||||
else:
|
||||
loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu")
|
||||
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
|
||||
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
|
||||
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
@@ -174,12 +165,7 @@ class ModuleGroup:
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
if self._enable_deepnvme_disk_offloading:
|
||||
loaded_tensors = torch.load(
|
||||
self.param_file_path, weights_only=True, map_location=onload_device
|
||||
)
|
||||
else:
|
||||
loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
return
|
||||
@@ -232,18 +218,15 @@ class ModuleGroup:
|
||||
if self.offload_to_disk_path:
|
||||
# TODO: we can potentially optimize this code path by checking if _all_ the desired
|
||||
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
|
||||
# overhead. Currently, we just check if the given `param_file_path` exists and if not
|
||||
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
|
||||
# we perform a write.
|
||||
# Check if the file has been saved in this session or if it already exists on disk.
|
||||
if not self._is_offloaded_to_disk and not os.path.exists(self.param_file_path):
|
||||
os.makedirs(os.path.dirname(self.param_file_path), exist_ok=True)
|
||||
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
|
||||
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
|
||||
tensors_to_save = {
|
||||
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
|
||||
}
|
||||
if self._enable_deepnvme_disk_offloading:
|
||||
_fast_aio_save(tensors_to_save, self.param_file_path)
|
||||
else:
|
||||
safetensors.torch.save_file(tensors_to_save, self.param_file_path)
|
||||
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
|
||||
|
||||
# The group is now considered offloaded to disk for the rest of the session.
|
||||
self._is_offloaded_to_disk = True
|
||||
@@ -443,7 +426,6 @@ def apply_group_offloading(
|
||||
record_stream: bool = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
offload_to_disk_path: Optional[str] = None,
|
||||
_enable_deepnvme_disk_offloading: Optional[bool] = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
|
||||
@@ -502,8 +484,6 @@ def apply_group_offloading(
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
|
||||
(TODO: include example with `offload_to_disk_path`)
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from diffusers import CogVideoXTransformer3DModel
|
||||
@@ -549,7 +529,6 @@ def apply_group_offloading(
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(
|
||||
@@ -561,7 +540,6 @@ def apply_group_offloading(
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported offload_type: {offload_type}")
|
||||
@@ -577,7 +555,6 @@ def _apply_group_offloading_block_level(
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
offload_to_disk_path: Optional[str] = None,
|
||||
_enable_deepnvme_disk_offloading: Optional[bool] = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
|
||||
@@ -638,7 +615,6 @@ def _apply_group_offloading_block_level(
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
for j in range(i, i + len(current_modules)):
|
||||
@@ -673,7 +649,6 @@ def _apply_group_offloading_block_level(
|
||||
stream=None,
|
||||
record_stream=False,
|
||||
onload_self=True,
|
||||
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
|
||||
)
|
||||
if stream is None:
|
||||
_apply_group_offloading_hook(module, unmatched_group, None)
|
||||
@@ -690,7 +665,6 @@ def _apply_group_offloading_leaf_level(
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
offload_to_disk_path: Optional[str] = None,
|
||||
_enable_deepnvme_disk_offloading: Optional[bool] = False,
|
||||
) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
||||
@@ -741,7 +715,6 @@ def _apply_group_offloading_leaf_level(
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
modules_with_group_offloading.add(name)
|
||||
@@ -789,7 +762,6 @@ def _apply_group_offloading_leaf_level(
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
|
||||
@@ -811,7 +783,6 @@ def _apply_group_offloading_leaf_level(
|
||||
record_stream=False,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
|
||||
@@ -549,7 +549,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
record_stream: bool = False,
|
||||
low_cpu_mem_usage=False,
|
||||
offload_to_disk_path: Optional[str] = None,
|
||||
_enable_deepnvme_disk_offloading: Optional[bool] = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Activates group offloading for the current model.
|
||||
@@ -600,7 +599,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
|
||||
@@ -52,20 +52,21 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import ChromaPipeline
|
||||
|
||||
>>> model_id = "lodestones/Chroma"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
||||
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
|
||||
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
|
||||
... "black-forest-labs/FLUX.1-schnell",
|
||||
>>> pipe = ChromaPipeline.from_pretrained(
|
||||
... model_id,
|
||||
... transformer=transformer,
|
||||
... text_encoder=text_encoder,
|
||||
... tokenizer=tokenizer,
|
||||
... torch_dtype=torch.bfloat16,
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
||||
>>> prompt = [
|
||||
... "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
|
||||
... ]
|
||||
>>> negative_prompt = [
|
||||
... "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
||||
... ]
|
||||
>>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
|
||||
>>> image.save("chroma.png")
|
||||
```
|
||||
|
||||
@@ -51,26 +51,21 @@ EXAMPLE_DOC_STRING = """
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
|
||||
>>> from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
>>> model_id = "lodestones/Chroma"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
||||
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
|
||||
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
|
||||
... "black-forest-labs/FLUX.1-schnell",
|
||||
... model_id,
|
||||
... transformer=transformer,
|
||||
... text_encoder=text_encoder,
|
||||
... tokenizer=tokenizer,
|
||||
... torch_dtype=torch.bfloat16,
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> image = load_image(
|
||||
>>> init_image = load_image(
|
||||
... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
... )
|
||||
>>> prompt = "a scenic fantasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
|
||||
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
||||
>>> image = pipe(prompt, image=image, negative_prompt=negative_prompt).images[0]
|
||||
>>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
|
||||
>>> image.save("chroma-img2img.png")
|
||||
```
|
||||
"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -220,11 +220,6 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
|
||||
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
|
||||
_nltk_available, _nltk_version = _is_package_available("nltk")
|
||||
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
|
||||
_deepspeed_available, _deepspeed_version = _is_package_available("deepspeed")
|
||||
|
||||
|
||||
def is_deepspeed_available():
|
||||
return _deepspeed_available
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -660,19 +655,6 @@ def is_torch_version(operation: str, version: str):
|
||||
return compare_versions(parse(_torch_version), operation, version)
|
||||
|
||||
|
||||
def is_deepspeed_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current DeepSpeed version to a given reference with an operation.
|
||||
|
||||
Args:
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A string version of DeepSpeed
|
||||
"""
|
||||
return compare_versions(parse(_deepspeed_version), operation, version)
|
||||
|
||||
|
||||
def is_torch_xla_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current torch_xla version to a given reference with an operation.
|
||||
|
||||
@@ -18,19 +18,13 @@ State dict utilities: utility methods for converting state dicts easily
|
||||
import enum
|
||||
import json
|
||||
|
||||
from .import_utils import is_deepspeed_available, is_deepspeed_version, is_torch_available
|
||||
from .import_utils import is_torch_available
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_deepspeed_available() and is_deepspeed_version(">", "0.16"):
|
||||
from deepspeed.io import FastFileWriter, FastFileWriterConfig
|
||||
from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder
|
||||
|
||||
from .deep_nvme_utils import save as _nvme_save
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -370,54 +364,3 @@ def _load_sft_state_dict_metadata(model_file: str):
|
||||
return json.loads(raw) if raw else None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# Utilities below are taken from
|
||||
# https://github.com/deepspeedai/DeepSpeedExamples/blob/28a984e77b8d096dadc6389b6d1440b823587e28/deepnvme/model_checkpoint/torch_save_utils.py#L16
|
||||
def _load_io_ops(args):
|
||||
if AsyncIOBuilder().is_compatible():
|
||||
AsyncIOBuilder().load(verbose=False)
|
||||
if args.gpu and GDSBuilder().is_compatible():
|
||||
GDSBuilder().load(verbose=False)
|
||||
|
||||
|
||||
def _get_aio_handle():
|
||||
AIO_QUEUE_DEPTH = 8
|
||||
AIO_BLOCK_SIZE = 8 * (1024**2)
|
||||
AIO_INTRA_OP_PARALLEL = 1
|
||||
AIO_SINGLE_SUBMIT = False
|
||||
|
||||
h = (
|
||||
AsyncIOBuilder()
|
||||
.load(verbose=False)
|
||||
.aio_handle(
|
||||
block_size=AIO_BLOCK_SIZE,
|
||||
queue_depth=AIO_QUEUE_DEPTH,
|
||||
single_submit=AIO_SINGLE_SUBMIT,
|
||||
overlap_events=AIO_SINGLE_SUBMIT,
|
||||
intra_op_parallelism=AIO_INTRA_OP_PARALLEL,
|
||||
)
|
||||
)
|
||||
return h
|
||||
|
||||
|
||||
def _get_aio_components():
|
||||
PINNED_BUFFER_MB = 64
|
||||
h = _get_aio_handle()
|
||||
pinned_memory = torch.zeros(PINNED_BUFFER_MB * (1024**2), dtype=torch.uint8, device="cpu").pin_memory()
|
||||
return h, pinned_memory
|
||||
|
||||
|
||||
def _fast_aio_save(buffer, file, single_io_buffer=False):
|
||||
h, pinned_memory = _get_aio_components()
|
||||
fast_writer_config = FastFileWriterConfig(
|
||||
dnvme_handle=h,
|
||||
pinned_tensor=pinned_memory,
|
||||
double_buffer=not single_io_buffer,
|
||||
num_parallel_writers=1,
|
||||
writer_rank=0,
|
||||
)
|
||||
|
||||
ds_fast_writer = FastFileWriter(file_path=file, config=fast_writer_config)
|
||||
_nvme_save(f=ds_fast_writer, obj=buffer, _use_new_zipfile_serialization=False)
|
||||
ds_fast_writer.close()
|
||||
|
||||
Reference in New Issue
Block a user