Compare commits

...

8 Commits

Author SHA1 Message Date
DN6 82d7676fe3 update 2025-09-19 22:31:11 +05:30
Dave Lage 7e7e62c6ff Convert alphas for embedders for sd-scripts to ai toolkit conversion (#12332)
* Convert alphas for embedders for sd-scripts to ai toolkit conversion

* Add kohya embedders conversion test

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-09-18 12:41:22 +05:30
Fredy eda9ff8300 Add RequestScopedPipeline for safe concurrent inference, tokenizer lock and non-mutating retrieve_timesteps (#12328)
* Basic implementation of request scheduling

* Basic editing in SD and Flux Pipelines

* Small Fix

* Fix

* Update for more pipelines

* Add examples/server-async

* Add examples/server-async

* Updated RequestScopedPipeline to handle a single tokenizer lock to avoid race conditions

* Fix

* Fix _TokenizerLockWrapper

* Fix _TokenizerLockWrapper

* Delete _TokenizerLockWrapper

* Fix tokenizer

* Update examples/server-async

* Fix server-async

* Optimizations in examples/server-async

* We keep the implementation simple in examples/server-async

* Update examples/server-async/README.md

* Update examples/server-async/README.md for changes to tokenizer locks and backward-compatible retrieve_timesteps

* The changes to the diffusers core have been undone and all logic is being moved to exmaples/server-async

* Update examples/server-async/utils/*

* Fix BaseAsyncScheduler

* Rollback in the core of the diffusers

* Update examples/server-async/README.md

* Complete rollback of diffusers core files

* Simple implementation of an asynchronous server compatible with SD3-3.5 and Flux Pipelines

* Update examples/server-async/README.md

* Fixed import errors in 'examples/server-async/serverasync.py'

* Flux Pipeline Discard

* Update examples/server-async/README.md

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-09-18 11:33:43 +05:30
DefTruth efb7a299af Fix many type hint errors (#12289)
* fix hidream type hint

* fix hunyuan-video type hint

* fix many type hint

* fix many type hint errors

* fix many type hint errors

* fix many type hint errors

* make stype & make quality
2025-09-16 18:52:15 -10:00
Zijian Zhou d06750a5fd Fix autoencoder_kl_wan.py bugs for Wan2.2 VAE (#12335)
* Update autoencoder_kl_wan.py

When using the Wan2.2 VAE, the spatial compression ratio calculated here is incorrect. It should be 16 instead of 8. Pass it in directly via the config to ensure it’s correct here.

* Update autoencoder_kl_wan.py
2025-09-16 13:43:15 -10:00
Sari Hleihil 8c72cd12ee Added LucyEditPipeline (#12340)
* Added LucyEditPipeline

* add import & stype

missing copied from

* Fix example doc string

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2025-09-16 13:41:05 -10:00
Samarth Agrawal 751e250f70 fixed bug in defining embed dim for UNet1D (#12111)
* fixed bug in defining embed dim

* matched 1d temb process to 2d

* Update src/diffusers/models/unets/unet_1d.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-09-16 12:18:48 +05:30
Linoy Tsaban b50014067d Add Wan2.2 VACE - Fun (#12324)
* support Wan2.2-VACE-Fun-A14B

* support Wan2.2-VACE-Fun-A14B

* support Wan2.2-VACE-Fun-A14B

* Apply style fixes

* test

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-09-15 21:31:26 +05:30
34 changed files with 2064 additions and 104 deletions
+91
View File
@@ -0,0 +1,91 @@
import logging
import os
from dataclasses import dataclass, field
from typing import List
import torch
from pydantic import BaseModel
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
logger = logging.getLogger(__name__)
class TextToImageInput(BaseModel):
model: str
prompt: str
size: str | None = None
n: int | None = None
@dataclass
class PresetModels:
SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
SD3_5: List[str] = field(
default_factory=lambda: [
"stabilityai/stable-diffusion-3.5-large",
"stabilityai/stable-diffusion-3.5-large-turbo",
"stabilityai/stable-diffusion-3.5-medium",
]
)
class TextToImagePipelineSD3:
def __init__(self, model_path: str | None = None):
self.model_path = model_path or os.getenv("MODEL_PATH")
self.pipeline: StableDiffusion3Pipeline | None = None
self.device: str | None = None
def start(self):
if torch.cuda.is_available():
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
logger.info("Loading CUDA")
self.device = "cuda"
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
model_path,
torch_dtype=torch.float16,
).to(device=self.device)
elif torch.backends.mps.is_available():
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
logger.info("Loading MPS for Mac M Series")
self.device = "mps"
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
).to(device=self.device)
else:
raise Exception("No CUDA or MPS device available")
class ModelPipelineInitializer:
def __init__(self, model: str = "", type_models: str = "t2im"):
self.model = model
self.type_models = type_models
self.pipeline = None
self.device = "cuda" if torch.cuda.is_available() else "mps"
self.model_type = None
def initialize_pipeline(self):
if not self.model:
raise ValueError("Model name not provided")
# Check if model exists in PresetModels
preset_models = PresetModels()
# Determine which model type we're dealing with
if self.model in preset_models.SD3:
self.model_type = "SD3"
elif self.model in preset_models.SD3_5:
self.model_type = "SD3_5"
# Create appropriate pipeline based on model type and type_models
if self.type_models == "t2im":
if self.model_type in ["SD3", "SD3_5"]:
self.pipeline = TextToImagePipelineSD3(self.model)
else:
raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
elif self.type_models == "t2v":
raise ValueError(f"Unsupported type_models: {self.type_models}")
return self.pipeline
+171
View File
@@ -0,0 +1,171 @@
# Asynchronous server and parallel execution of models
> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
> We recommend running 10 to 50 inferences in parallel for optimal performance, averaging between 25 and 30 seconds to 1 minute and 1 minute and 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two inferences in parallel to avoid decoding or saving errors due to memory shortages.)
## ⚠️ IMPORTANT
* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.
## Necessary components
All the components needed to create the inference server are in the current directory:
```
server-async/
├── utils/
├─────── __init__.py
├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences
├─────── requestscopedpipeline.py # RequestScoped Pipeline for inference with a single in-memory model
├─────── utils.py # Image/video saving utilities and service configuration
├── Pipelines.py # pipeline loader classes (SD3)
├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
├── test.py # Client test script for inference requests
├── requirements.txt # Dependencies
└── README.md # This documentation
```
## What `diffusers-async` adds / Why we needed it
Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per-request.
`diffusers-async` / this example addresses that by:
* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.
* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not using `return_scheduler=True`, the behavior is identical to the original API.
* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.
* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
## How the server works (high-level flow)
1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
2. On each HTTP inference request:
* The server uses `RequestScopedPipeline.generate(...)` which:
* automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),
* obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
* does `local_pipe = copy.copy(base_pipe)` (shallow copy),
* sets `local_pipe.scheduler = local_scheduler` (if possible),
* clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,
* wraps tokenizers with thread-safe locks to prevent race conditions,
* optionally enters a `model_cpu_offload_context()` for memory offload hooks,
* calls the pipeline on the local view (`local_pipe(...)`).
3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).
4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state.
## How to set up and run the server
### 1) Install dependencies
Recommended: create a virtualenv / conda environment.
```bash
pip install diffusers
pip install -r requirements.txt
```
### 2) Start the server
Using the `serverasync.py` file that already has everything you need:
```bash
python serverasync.py
```
The server will start on `http://localhost:8500` by default with the following features:
- FastAPI application with async lifespan management
- Automatic model loading and pipeline initialization
- Request counting and active inference tracking
- Memory cleanup after each inference
- CORS middleware for cross-origin requests
### 3) Test the server
Use the included test script:
```bash
python test.py
```
Or send a manual request:
`POST /api/diffusers/inference` with JSON body:
```json
{
"prompt": "A futuristic cityscape, vibrant colors",
"num_inference_steps": 30,
"num_images_per_prompt": 1
}
```
Response example:
```json
{
"response": ["http://localhost:8500/images/img123.png"]
}
```
### 4) Server endpoints
- `GET /` - Welcome message
- `POST /api/diffusers/inference` - Main inference endpoint
- `GET /images/{filename}` - Serve generated images
- `GET /api/status` - Server status and memory info
## Advanced Configuration
### RequestScopedPipeline Parameters
```python
RequestScopedPipeline(
pipeline, # Base pipeline to wrap
mutable_attrs=None, # Custom list of attributes to clone
auto_detect_mutables=True, # Enable automatic detection of mutable attributes
tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning
tokenizer_lock=None, # Custom threading lock for tokenizers
wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler
)
```
### BaseAsyncScheduler Features
* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
* `clone_for_request()` method for safe per-request scheduler cloning
* Enhanced debugging with `__repr__` and `__str__` methods
* Full compatibility with existing scheduler APIs
### Server Configuration
The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:
```python
@dataclass
class ServerConfigModels:
model: str = 'stabilityai/stable-diffusion-3.5-medium'
type_models: str = 't2im'
host: str = '0.0.0.0'
port: int = 8500
```
## Troubleshooting (quick)
* `Already borrowed` — previously a Rust tokenizer concurrency error.
✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.
* `can't set attribute 'components'` — pipeline exposes read-only `components`.
✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.
* Scheduler issues:
* If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log and fallback — but prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
✅ Note: `async_retrieve_timesteps` is fully retro-compatible — if you don't pass `return_scheduler=True`, the behavior is unchanged.
* Memory issues with large tensors:
✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones.
* Automatic tokenizer detection:
✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers.
+10
View File
@@ -0,0 +1,10 @@
torch
torchvision
transformers
sentencepiece
fastapi
uvicorn
ftfy
accelerate
xformers
protobuf
+230
View File
@@ -0,0 +1,230 @@
import asyncio
import gc
import logging
import os
import random
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from Pipelines import ModelPipelineInitializer
from pydantic import BaseModel
from utils import RequestScopedPipeline, Utils
@dataclass
class ServerConfigModels:
model: str = "stabilityai/stable-diffusion-3.5-medium"
type_models: str = "t2im"
constructor_pipeline: Optional[Type] = None
custom_pipeline: Optional[Type] = None
components: Optional[Dict[str, Any]] = None
torch_dtype: Optional[torch.dtype] = None
host: str = "0.0.0.0"
port: int = 8500
server_config = ServerConfigModels()
@asynccontextmanager
async def lifespan(app: FastAPI):
logging.basicConfig(level=logging.INFO)
app.state.logger = logging.getLogger("diffusers-server")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
app.state.total_requests = 0
app.state.active_inferences = 0
app.state.metrics_lock = asyncio.Lock()
app.state.metrics_task = None
app.state.utils_app = Utils(
host=server_config.host,
port=server_config.port,
)
async def metrics_loop():
try:
while True:
async with app.state.metrics_lock:
total = app.state.total_requests
active = app.state.active_inferences
app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
await asyncio.sleep(5)
except asyncio.CancelledError:
app.state.logger.info("Metrics loop cancelled")
raise
app.state.metrics_task = asyncio.create_task(metrics_loop())
try:
yield
finally:
task = app.state.metrics_task
if task:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
try:
stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
if callable(stop_fn):
await run_in_threadpool(stop_fn)
except Exception as e:
app.state.logger.warning(f"Error during pipeline shutdown: {e}")
app.state.logger.info("Lifespan shutdown complete")
app = FastAPI(lifespan=lifespan)
logger = logging.getLogger("DiffusersServer.Pipelines")
initializer = ModelPipelineInitializer(
model=server_config.model,
type_models=server_config.type_models,
)
model_pipeline = initializer.initialize_pipeline()
model_pipeline.start()
request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
pipeline_lock = threading.Lock()
logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")
app.state.MODEL_INITIALIZER = initializer
app.state.MODEL_PIPELINE = model_pipeline
app.state.REQUEST_PIPE = request_pipe
app.state.PIPELINE_LOCK = pipeline_lock
class JSONBodyQueryAPI(BaseModel):
model: str | None = None
prompt: str
negative_prompt: str | None = None
num_inference_steps: int = 28
num_images_per_prompt: int = 1
@app.middleware("http")
async def count_requests_middleware(request: Request, call_next):
async with app.state.metrics_lock:
app.state.total_requests += 1
response = await call_next(request)
return response
@app.get("/")
async def root():
return {"message": "Welcome to the Diffusers Server"}
@app.post("/api/diffusers/inference")
async def api(json: JSONBodyQueryAPI):
prompt = json.prompt
negative_prompt = json.negative_prompt or ""
num_steps = json.num_inference_steps
num_images_per_prompt = json.num_images_per_prompt
wrapper = app.state.MODEL_PIPELINE
initializer = app.state.MODEL_INITIALIZER
utils_app = app.state.utils_app
if not wrapper or not wrapper.pipeline:
raise HTTPException(500, "Model not initialized correctly")
if not prompt.strip():
raise HTTPException(400, "No prompt provided")
def make_generator():
g = torch.Generator(device=initializer.device)
return g.manual_seed(random.randint(0, 10_000_000))
req_pipe = app.state.REQUEST_PIPE
def infer():
gen = make_generator()
return req_pipe.generate(
prompt=prompt,
negative_prompt=negative_prompt,
generator=gen,
num_inference_steps=num_steps,
num_images_per_prompt=num_images_per_prompt,
device=initializer.device,
output_type="pil",
)
try:
async with app.state.metrics_lock:
app.state.active_inferences += 1
output = await run_in_threadpool(infer)
async with app.state.metrics_lock:
app.state.active_inferences = max(0, app.state.active_inferences - 1)
urls = [utils_app.save_image(img) for img in output.images]
return {"response": urls}
except Exception as e:
async with app.state.metrics_lock:
app.state.active_inferences = max(0, app.state.active_inferences - 1)
logger.error(f"Error during inference: {e}")
raise HTTPException(500, f"Error in processing: {e}")
finally:
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.ipc_collect()
gc.collect()
@app.get("/images/{filename}")
async def serve_image(filename: str):
utils_app = app.state.utils_app
file_path = os.path.join(utils_app.image_dir, filename)
if not os.path.isfile(file_path):
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(file_path, media_type="image/png")
@app.get("/api/status")
async def get_status():
memory_info = {}
if torch.cuda.is_available():
memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
memory_info = {
"memory_allocated_gb": round(memory_allocated, 2),
"memory_reserved_gb": round(memory_reserved, 2),
"device": torch.cuda.get_device_name(0),
}
return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=server_config.host, port=server_config.port)
+65
View File
@@ -0,0 +1,65 @@
import os
import time
import urllib.parse
import requests
SERVER_URL = "http://localhost:8500/api/diffusers/inference"
BASE_URL = "http://localhost:8500"
DOWNLOAD_FOLDER = "generated_images"
WAIT_BEFORE_DOWNLOAD = 2 # seconds
os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
def save_from_url(url: str) -> str:
"""Download the given URL (relative or absolute) and save it locally."""
if url.startswith("/"):
direct = BASE_URL.rstrip("/") + url
else:
direct = url
resp = requests.get(direct, timeout=60)
resp.raise_for_status()
filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
path = os.path.join(DOWNLOAD_FOLDER, filename)
with open(path, "wb") as f:
f.write(resp.content)
return path
def main():
payload = {
"prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
"num_inference_steps": 30,
"num_images_per_prompt": 1,
}
print("Sending request...")
try:
r = requests.post(SERVER_URL, json=payload, timeout=480)
r.raise_for_status()
except Exception as e:
print(f"Request failed: {e}")
return
body = r.json().get("response", [])
# Normalize to a list
urls = body if isinstance(body, list) else [body] if body else []
if not urls:
print("No URLs found in the response. Check the server output.")
return
print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
time.sleep(WAIT_BEFORE_DOWNLOAD)
for u in urls:
try:
path = save_from_url(u)
print(f"Image saved to: {path}")
except Exception as e:
print(f"Error downloading {u}: {e}")
if __name__ == "__main__":
main()
+2
View File
@@ -0,0 +1,2 @@
from .requestscopedpipeline import RequestScopedPipeline
from .utils import Utils
@@ -0,0 +1,296 @@
import copy
import threading
from typing import Any, Iterable, List, Optional
import torch
from diffusers.utils import logging
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
logger = logging.get_logger(__name__)
def safe_tokenize(tokenizer, *args, lock, **kwargs):
with lock:
return tokenizer(*args, **kwargs)
class RequestScopedPipeline:
DEFAULT_MUTABLE_ATTRS = [
"_all_hooks",
"_offload_device",
"_progress_bar_config",
"_progress_bar",
"_rng_state",
"_last_seed",
"latents",
]
def __init__(
self,
pipeline: Any,
mutable_attrs: Optional[Iterable[str]] = None,
auto_detect_mutables: bool = True,
tensor_numel_threshold: int = 1_000_000,
tokenizer_lock: Optional[threading.Lock] = None,
wrap_scheduler: bool = True,
):
self._base = pipeline
self.unet = getattr(pipeline, "unet", None)
self.vae = getattr(pipeline, "vae", None)
self.text_encoder = getattr(pipeline, "text_encoder", None)
self.components = getattr(pipeline, "components", None)
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
self._auto_detect_mutables = bool(auto_detect_mutables)
self._tensor_numel_threshold = int(tensor_numel_threshold)
self._auto_detected_attrs: List[str] = []
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
base_sched = getattr(self._base, "scheduler", None)
if base_sched is None:
return None
if not isinstance(base_sched, BaseAsyncScheduler):
wrapped_scheduler = BaseAsyncScheduler(base_sched)
else:
wrapped_scheduler = base_sched
try:
return wrapped_scheduler.clone_for_request(
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
)
except Exception as e:
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
try:
return copy.deepcopy(wrapped_scheduler)
except Exception as e:
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
return wrapped_scheduler
def _autodetect_mutables(self, max_attrs: int = 40):
if not self._auto_detect_mutables:
return []
if self._auto_detected_attrs:
return self._auto_detected_attrs
candidates: List[str] = []
seen = set()
for name in dir(self._base):
if name.startswith("__"):
continue
if name in self._mutable_attrs:
continue
if name in ("to", "save_pretrained", "from_pretrained"):
continue
try:
val = getattr(self._base, name)
except Exception:
continue
import types
# skip callables and modules
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
continue
# containers -> candidate
if isinstance(val, (dict, list, set, tuple, bytearray)):
candidates.append(name)
seen.add(name)
else:
# try Tensor detection
try:
if isinstance(val, torch.Tensor):
if val.numel() <= self._tensor_numel_threshold:
candidates.append(name)
seen.add(name)
else:
logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
except Exception:
continue
if len(candidates) >= max_attrs:
break
self._auto_detected_attrs = candidates
logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
return self._auto_detected_attrs
def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
try:
cls = type(base_obj)
descriptor = getattr(cls, attr_name, None)
if isinstance(descriptor, property):
return descriptor.fset is None
if hasattr(descriptor, "__set__") is False and descriptor is not None:
return False
except Exception:
pass
return False
def _clone_mutable_attrs(self, base, local):
attrs_to_clone = list(self._mutable_attrs)
attrs_to_clone.extend(self._autodetect_mutables())
EXCLUDE_ATTRS = {
"components",
}
for attr in attrs_to_clone:
if attr in EXCLUDE_ATTRS:
logger.debug(f"Skipping excluded attr '{attr}'")
continue
if not hasattr(base, attr):
continue
if self._is_readonly_property(base, attr):
logger.debug(f"Skipping read-only property '{attr}'")
continue
try:
val = getattr(base, attr)
except Exception as e:
logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
continue
try:
if isinstance(val, dict):
setattr(local, attr, dict(val))
elif isinstance(val, (list, tuple, set)):
setattr(local, attr, list(val))
elif isinstance(val, bytearray):
setattr(local, attr, bytearray(val))
else:
# small tensors or atomic values
if isinstance(val, torch.Tensor):
if val.numel() <= self._tensor_numel_threshold:
setattr(local, attr, val.clone())
else:
# don't clone big tensors, keep reference
setattr(local, attr, val)
else:
try:
setattr(local, attr, copy.copy(val))
except Exception:
setattr(local, attr, val)
except (AttributeError, TypeError) as e:
logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
continue
except Exception as e:
logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
continue
def _is_tokenizer_component(self, component) -> bool:
if component is None:
return False
tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
class_name = component.__class__.__name__.lower()
has_tokenizer_in_name = "tokenizer" in class_name
tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
try:
local_pipe = copy.copy(self._base)
except Exception as e:
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
local_pipe = copy.deepcopy(self._base)
if local_scheduler is not None:
try:
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
local_scheduler.scheduler,
num_inference_steps=num_inference_steps,
device=device,
return_scheduler=True,
**{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
)
final_scheduler = BaseAsyncScheduler(configured_scheduler)
setattr(local_pipe, "scheduler", final_scheduler)
except Exception:
logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
self._clone_mutable_attrs(self._base, local_pipe)
# 4) wrap tokenizers on the local pipe with the lock wrapper
tokenizer_wrappers = {} # name -> original_tokenizer
try:
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
for name in dir(local_pipe):
if "tokenizer" in name and not name.startswith("_"):
tok = getattr(local_pipe, name, None)
if tok is not None and self._is_tokenizer_component(tok):
tokenizer_wrappers[name] = tok
setattr(
local_pipe,
name,
lambda *args, tok=tok, **kwargs: safe_tokenize(
tok, *args, lock=self._tokenizer_lock, **kwargs
),
)
# b) wrap tokenizers in components dict
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
for key, val in local_pipe.components.items():
if val is None:
continue
if self._is_tokenizer_component(val):
tokenizer_wrappers[f"components[{key}]"] = val
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
)
except Exception as e:
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
result = None
cm = getattr(local_pipe, "model_cpu_offload_context", None)
try:
if callable(cm):
try:
with cm():
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
except TypeError:
# cm might be a context manager instance rather than callable
try:
with cm:
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
except Exception as e:
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
else:
# no offload context available — call directly
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
return result
finally:
try:
for name, tok in tokenizer_wrappers.items():
if name.startswith("components["):
key = name[len("components[") : -1]
local_pipe.components[key] = tok
else:
setattr(local_pipe, name, tok)
except Exception as e:
logger.debug(f"Error restoring wrapped tokenizers: {e}")
+141
View File
@@ -0,0 +1,141 @@
import copy
import inspect
from typing import Any, List, Optional, Union
import torch
class BaseAsyncScheduler:
def __init__(self, scheduler: Any):
self.scheduler = scheduler
def __getattr__(self, name: str):
if hasattr(self.scheduler, name):
return getattr(self.scheduler, name)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __setattr__(self, name: str, value):
if name == "scheduler":
super().__setattr__(name, value)
else:
if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
setattr(self.scheduler, name, value)
else:
super().__setattr__(name, value)
def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
local = copy.deepcopy(self.scheduler)
local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
cloned = self.__class__(local)
return cloned
def __repr__(self):
return f"BaseAsyncScheduler({repr(self.scheduler)})"
def __str__(self):
return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
def async_retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Backwards compatible: by default the function behaves exactly as before and returns
(timesteps_tensor, num_inference_steps)
If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
(timesteps_tensor, num_inference_steps, scheduler_in_use)
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Optional kwargs:
return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
where `scheduler_in_use` is a scheduler instance that already has timesteps set.
This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.
Returns:
`(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
`(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
"""
# pop our optional control kwarg (keeps compatibility)
return_scheduler = bool(kwargs.pop("return_scheduler", False))
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
# choose scheduler to call set_timesteps on
scheduler_in_use = scheduler
if return_scheduler:
# Do not mutate the provided scheduler: prefer to clone if possible
if hasattr(scheduler, "clone_for_request"):
try:
# clone_for_request may accept num_inference_steps or other kwargs; be permissive
scheduler_in_use = scheduler.clone_for_request(
num_inference_steps=num_inference_steps or 0, device=device
)
except Exception:
scheduler_in_use = copy.deepcopy(scheduler)
else:
# fallback deepcopy (scheduler tends to be smallish - acceptable)
scheduler_in_use = copy.deepcopy(scheduler)
# helper to test if set_timesteps supports a particular kwarg
def _accepts(param_name: str) -> bool:
try:
return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
except (ValueError, TypeError):
# if signature introspection fails, be permissive and attempt the call later
return False
# now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
if timesteps is not None:
accepts_timesteps = _accepts("timesteps")
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps_out = scheduler_in_use.timesteps
num_inference_steps = len(timesteps_out)
elif sigmas is not None:
accept_sigmas = _accepts("sigmas")
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps_out = scheduler_in_use.timesteps
num_inference_steps = len(timesteps_out)
else:
# default path
scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps_out = scheduler_in_use.timesteps
if return_scheduler:
return timesteps_out, num_inference_steps, scheduler_in_use
return timesteps_out, num_inference_steps
+48
View File
@@ -0,0 +1,48 @@
import gc
import logging
import os
import tempfile
import uuid
import torch
logger = logging.getLogger(__name__)
class Utils:
def __init__(self, host: str = "0.0.0.0", port: int = 8500):
self.service_url = f"http://{host}:{port}"
self.image_dir = os.path.join(tempfile.gettempdir(), "images")
if not os.path.exists(self.image_dir):
os.makedirs(self.image_dir)
self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
if not os.path.exists(self.video_dir):
os.makedirs(self.video_dir)
def save_image(self, image):
if hasattr(image, "to"):
try:
image = image.to("cpu")
except Exception:
pass
if isinstance(image, torch.Tensor):
from torchvision import transforms
to_pil = transforms.ToPILImage()
image = to_pil(image.squeeze(0).clamp(0, 1))
filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
image_path = os.path.join(self.image_dir, filename)
logger.info(f"Saving image to {image_path}")
image.save(image_path, format="PNG", optimize=True)
del image
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return os.path.join(self.service_url, "images", filename)
+34 -1
View File
@@ -278,6 +278,29 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-VACE-Fun-14B":
config = {
"model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
"vace_in_channels": 96,
},
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-I2V-14B-720p":
config = {
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
@@ -975,7 +998,17 @@ if __name__ == "__main__":
image_encoder=image_encoder,
image_processor=image_processor,
)
elif "VACE" in args.model_type:
elif "Wan2.2-VACE" in args.model_type:
pipe = WanVACEPipeline(
transformer=transformer,
transformer_2=transformer_2,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
boundary_ratio=0.875,
)
elif "Wan-VACE" in args.model_type:
pipe = WanVACEPipeline(
transformer=transformer,
text_encoder=text_encoder,
+2
View File
@@ -495,6 +495,7 @@ else:
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
@@ -1149,6 +1150,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
LuminaPipeline,
+39 -47
View File
@@ -558,70 +558,62 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
ait_sd[target_key] = value
if any("guidance_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
None,
),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_guidance_in_in_layer",
"time_text_embed.guidance_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_guidance_in_out_layer",
"time_text_embed.guidance_embedder.linear_2",
)
if any("img_in" in k for k in sds_sd):
assign_remaining_weights(
[
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_img_in",
"x_embedder",
)
if any("txt_in" in k for k in sds_sd):
assign_remaining_weights(
[
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_txt_in",
"context_embedder",
)
if any("time_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
None,
),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_time_in_in_layer",
"time_text_embed.timestep_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_time_in_out_layer",
"time_text_embed.timestep_embedder.linear_2",
)
if any("vector_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
None,
),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_vector_in_in_layer",
"time_text_embed.text_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_vector_in_out_layer",
"time_text_embed.text_embedder.linear_2",
)
if any("final_layer" in k for k in sds_sd):
+1 -1
View File
@@ -674,7 +674,7 @@ class JointTransformerBlock(nn.Module):
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -1052,7 +1052,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
is_residual=is_residual,
)
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
self.spatial_compression_ratio = scale_factor_spatial
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
@@ -1145,12 +1145,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
self.clear_cache()
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
return selected_indices
def forward(self, latent):
def forward(self, latent) -> torch.Tensor:
batch_size, num_channels, height, width = latent.size()
latent = latent.view(
batch_size,
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
) -> torch.Tensor:
residual = hidden_states
attention_kwargs = attention_kwargs or {}
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
residual_context = encoder_hidden_states
attention_kwargs = attention_kwargs or {}
@@ -472,7 +472,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
timestep: torch.LongTensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}
@@ -441,7 +441,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
id_cond: Optional[torch.Tensor] = None,
id_vit_hidden: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
encoder_mask: torch.Tensor,
temb: torch.Tensor,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
) -> torch.Tensor:
"""
Perform a forward pass through the LuminaNextDiTBlock.
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
image_rotary_emb: torch.Tensor,
cross_attention_kwargs: Dict[str, Any] = None,
return_dict=True,
) -> torch.Tensor:
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
Forward pass of LuminaNextDiT.
@@ -472,7 +472,7 @@ class BriaSingleTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -588,7 +588,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
return_dict: bool = True,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
The [`BriaTransformer2DModel`] forward method.
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Dict, Union
from typing import Dict, Tuple, Union
import torch
import torch.nn as nn
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
target_size: torch.Tensor,
crop_coords: torch.Tensor,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
The [`CogView3PlusTransformer2DModel`] forward method.
@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Timestep conditioning
(
norm_hidden_states,
@@ -717,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
self.out_channels = out_channels
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
def forward(self, latent):
def forward(self, latent) -> torch.Tensor:
latent = self.proj(latent)
return latent
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
wtype = hidden_states.dtype
(
shift_msa_i,
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
return self.block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
**kwargs,
):
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if encoder_hidden_states is not None:
@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
indices_latents_history_4x: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
+11 -2
View File
@@ -82,6 +82,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
time_embedding_dim: Optional[int] = None,
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
@@ -100,15 +101,23 @@ class UNet1DModel(ModelMixin, ConfigMixin):
# time
if time_embedding_type == "fourier":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = 2 * block_out_channels[0]
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
+2
View File
@@ -285,6 +285,7 @@ else:
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
@@ -682,6 +683,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
+47
View File
@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_lucy_edit import LucyEditPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,735 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
# Copyright 2025 The Decart AI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications by Decart AI Team:
# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
import html
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import regex as re
import torch
from PIL import Image
from transformers import AutoTokenizer, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import WanLoraLoaderMixin
from ...models import AutoencoderKLWan, WanTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import LucyPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> from typing import List
>>> import torch
>>> from PIL import Image
>>> from diffusers import AutoencoderKLWan, LucyEditPipeline
>>> from diffusers.utils import export_to_video, load_video
>>> # Arguments
>>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4"
>>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights."
>>> negative_prompt = ""
>>> num_frames = 81
>>> height = 480
>>> width = 832
>>> # Load video
>>> def convert_video(video: List[Image.Image]) -> List[Image.Image]:
... video = load_video(url)[:num_frames]
... video = [video[i].resize((width, height)) for i in range(num_frames)]
... return video
>>> video = load_video(url, convert_method=convert_video)
>>> # Load model
>>> model_id = "decart-ai/Lucy-Edit-Dev"
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
>>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> # Generate video
>>> output = pipe(
... prompt=prompt,
... video=video,
... negative_prompt=negative_prompt,
... height=480,
... width=832,
... num_frames=81,
... guidance_scale=5.0,
... ).frames[0]
>>> # Export video
>>> export_to_video(output, "output.mp4", fps=24)
```
"""
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
r"""
Pipeline for video-to-video generation using Lucy Edit.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
tokenizer ([`T5Tokenizer`]):
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer_2 ([`WanTransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
stages. If not provided, only `transformer` is used.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer", "transformer_2"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer: Optional[WanTransformer3DModel] = None,
transformer_2: Optional[WanTransformer3DModel] = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False, # Wan2.2 ti2v
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
transformer_2=transformer_2,
)
self.register_to_config(boundary_ratio=boundary_ratio)
self.register_to_config(expand_timesteps=expand_timesteps)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def check_inputs(
self,
video,
prompt,
negative_prompt,
height,
width,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
guidance_scale_2=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
if video is None:
raise ValueError("`video` is required, received None.")
def prepare_latents(
self,
video: Optional[torch.Tensor] = None,
batch_size: int = 1,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
num_latent_frames = (
(video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
)
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
# Prepare noise latents
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# Prepare condition latents
condition_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video
]
condition_latents = torch.cat(condition_latents, dim=0).to(dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
device, dtype
)
condition_latents = (condition_latents - latents_mean) * latents_std
# Check shapes
assert latents.shape == condition_latents.shape, (
f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input."
)
return latents, condition_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@property
def attention_kwargs(self):
return self._attention_kwargs
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
video: List[Image.Image],
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
The call function to the pipeline for generation.
Args:
video (`List[Image.Image]`):
The video to use as the condition for the video generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
instead. Ignored when not using guidance (`guidance_scale` < `1`).
height (`int`, defaults to `480`):
The height in pixels of the generated image.
width (`int`, defaults to `832`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `81`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
guidance_scale_2 (`float`, *optional*, defaults to `None`):
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
truncated. If the prompt is shorter, it will be padded to this length.
Examples:
Returns:
[`~LucyPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
video,
prompt,
negative_prompt,
height,
width,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale
self._guidance_scale = guidance_scale
self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = (
self.transformer.config.out_channels
if self.transformer is not None
else self.transformer_2.config.out_channels
)
video = self.video_processor.preprocess_video(video, height=height, width=width).to(
device, dtype=torch.float32
)
latents, condition_latents = self.prepare_latents(
video,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
torch.float32,
device,
generator,
latents,
)
mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if boundary_timestep is None or t >= boundary_timestep:
# wan2.1 or high-noise stage in wan2.2
current_model = self.transformer
current_guidance_scale = guidance_scale
else:
# low-noise stage in wan2.2
current_model = self.transformer_2
current_guidance_scale = guidance_scale_2
# latent_model_input = latents.to(transformer_dtype)
latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype)
# latent_model_input = torch.cat([latents, latents], dim=1).to(transformer_dtype)
if self.config.expand_timesteps:
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
# batch_size, seq_len
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
else:
timestep = t.expand(latents.shape[0])
with current_model.cache_context("cond"):
noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with current_model.cache_context("uncond"):
noise_uncond = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LucyPipelineOutput(frames=video)
@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class LucyPipelineOutput(BaseOutput):
r"""
Output class for Lucy pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
@@ -48,7 +48,6 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
@@ -112,7 +111,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -191,7 +190,7 @@ def filter_model_files(filenames):
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
@@ -212,7 +211,7 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -152,16 +152,26 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanTransformer3DModel`]):
transformer ([`WanVACETransformer3DModel`]):
Conditional Transformer to denoise the input latents.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
`transformer` is used.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]
def __init__(
self,
@@ -170,6 +180,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer_2: WanVACETransformer3DModel = None,
boundary_ratio: Optional[float] = None,
):
super().__init__()
@@ -178,9 +190,10 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
transformer_2=transformer_2,
scheduler=scheduler,
)
self.register_to_config(boundary_ratio=boundary_ratio)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -321,6 +334,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
video=None,
mask=None,
reference_images=None,
guidance_scale_2=None,
):
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
if height % base != 0 or width % base != 0:
@@ -332,6 +346,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -667,6 +683,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -728,6 +745,10 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
guidance_scale_2 (`float`, *optional*, defaults to `None`):
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -793,6 +814,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
video,
mask,
reference_images,
guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -802,7 +824,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale
self._guidance_scale = guidance_scale
self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -896,36 +922,53 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if boundary_timestep is None or t >= boundary_timestep:
# wan2.1 or high-noise stage in wan2.2
current_model = self.transformer
current_guidance_scale = guidance_scale
else:
# low-noise stage in wan2.2
current_model = self.transformer_2
current_guidance_scale = guidance_scale_2
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
control_hidden_states=conditioning_latents,
control_hidden_states_scale=conditioning_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
with current_model.cache_context("cond"):
noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
control_hidden_states=conditioning_latents,
control_hidden_states_scale=conditioning_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
if self.do_classifier_free_guidance:
with current_model.cache_context("uncond"):
noise_uncond = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
control_hidden_states=conditioning_latents,
control_hidden_states_scale=conditioning_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -1592,6 +1592,21 @@ class LTXPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LucyEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Lumina2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
+7
View File
@@ -907,6 +907,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
assert max_diff < 1e-3
def test_flux_kohya_embedders_conversion(self):
"""Test that embedders load without throwing errors"""
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
self.pipeline.unload_lora_weights()
assert True
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
+1
View File
@@ -87,6 +87,7 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer_2": None,
}
return components