Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 82d7676fe3 | |||
| 7e7e62c6ff | |||
| eda9ff8300 | |||
| efb7a299af | |||
| d06750a5fd | |||
| 8c72cd12ee | |||
| 751e250f70 | |||
| b50014067d |
@@ -0,0 +1,91 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToImageInput(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
size: str | None = None
|
||||
n: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PresetModels:
|
||||
SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
|
||||
SD3_5: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
"stabilityai/stable-diffusion-3.5-large-turbo",
|
||||
"stabilityai/stable-diffusion-3.5-medium",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TextToImagePipelineSD3:
|
||||
def __init__(self, model_path: str | None = None):
|
||||
self.model_path = model_path or os.getenv("MODEL_PATH")
|
||||
self.pipeline: StableDiffusion3Pipeline | None = None
|
||||
self.device: str | None = None
|
||||
|
||||
def start(self):
|
||||
if torch.cuda.is_available():
|
||||
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
|
||||
logger.info("Loading CUDA")
|
||||
self.device = "cuda"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.float16,
|
||||
).to(device=self.device)
|
||||
elif torch.backends.mps.is_available():
|
||||
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
|
||||
logger.info("Loading MPS for Mac M Series")
|
||||
self.device = "mps"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(device=self.device)
|
||||
else:
|
||||
raise Exception("No CUDA or MPS device available")
|
||||
|
||||
|
||||
class ModelPipelineInitializer:
|
||||
def __init__(self, model: str = "", type_models: str = "t2im"):
|
||||
self.model = model
|
||||
self.type_models = type_models
|
||||
self.pipeline = None
|
||||
self.device = "cuda" if torch.cuda.is_available() else "mps"
|
||||
self.model_type = None
|
||||
|
||||
def initialize_pipeline(self):
|
||||
if not self.model:
|
||||
raise ValueError("Model name not provided")
|
||||
|
||||
# Check if model exists in PresetModels
|
||||
preset_models = PresetModels()
|
||||
|
||||
# Determine which model type we're dealing with
|
||||
if self.model in preset_models.SD3:
|
||||
self.model_type = "SD3"
|
||||
elif self.model in preset_models.SD3_5:
|
||||
self.model_type = "SD3_5"
|
||||
|
||||
# Create appropriate pipeline based on model type and type_models
|
||||
if self.type_models == "t2im":
|
||||
if self.model_type in ["SD3", "SD3_5"]:
|
||||
self.pipeline = TextToImagePipelineSD3(self.model)
|
||||
else:
|
||||
raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
|
||||
elif self.type_models == "t2v":
|
||||
raise ValueError(f"Unsupported type_models: {self.type_models}")
|
||||
|
||||
return self.pipeline
|
||||
@@ -0,0 +1,171 @@
|
||||
# Asynchronous server and parallel execution of models
|
||||
|
||||
> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
|
||||
> We recommend running 10 to 50 inferences in parallel for optimal performance, averaging between 25 and 30 seconds to 1 minute and 1 minute and 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two inferences in parallel to avoid decoding or saving errors due to memory shortages.)
|
||||
|
||||
## ⚠️ IMPORTANT
|
||||
|
||||
* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.
|
||||
|
||||
## Necessary components
|
||||
|
||||
All the components needed to create the inference server are in the current directory:
|
||||
|
||||
```
|
||||
server-async/
|
||||
├── utils/
|
||||
├─────── __init__.py
|
||||
├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences
|
||||
├─────── requestscopedpipeline.py # RequestScoped Pipeline for inference with a single in-memory model
|
||||
├─────── utils.py # Image/video saving utilities and service configuration
|
||||
├── Pipelines.py # pipeline loader classes (SD3)
|
||||
├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
|
||||
├── test.py # Client test script for inference requests
|
||||
├── requirements.txt # Dependencies
|
||||
└── README.md # This documentation
|
||||
```
|
||||
|
||||
## What `diffusers-async` adds / Why we needed it
|
||||
|
||||
Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per-request.
|
||||
|
||||
`diffusers-async` / this example addresses that by:
|
||||
|
||||
* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
|
||||
* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.
|
||||
* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
|
||||
* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not using `return_scheduler=True`, the behavior is identical to the original API.
|
||||
* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.
|
||||
* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
|
||||
|
||||
## How the server works (high-level flow)
|
||||
|
||||
1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
|
||||
2. On each HTTP inference request:
|
||||
|
||||
* The server uses `RequestScopedPipeline.generate(...)` which:
|
||||
|
||||
* automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),
|
||||
* obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
|
||||
* does `local_pipe = copy.copy(base_pipe)` (shallow copy),
|
||||
* sets `local_pipe.scheduler = local_scheduler` (if possible),
|
||||
* clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,
|
||||
* wraps tokenizers with thread-safe locks to prevent race conditions,
|
||||
* optionally enters a `model_cpu_offload_context()` for memory offload hooks,
|
||||
* calls the pipeline on the local view (`local_pipe(...)`).
|
||||
3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).
|
||||
4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state.
|
||||
|
||||
## How to set up and run the server
|
||||
|
||||
### 1) Install dependencies
|
||||
|
||||
Recommended: create a virtualenv / conda environment.
|
||||
|
||||
```bash
|
||||
pip install diffusers
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2) Start the server
|
||||
|
||||
Using the `serverasync.py` file that already has everything you need:
|
||||
|
||||
```bash
|
||||
python serverasync.py
|
||||
```
|
||||
|
||||
The server will start on `http://localhost:8500` by default with the following features:
|
||||
- FastAPI application with async lifespan management
|
||||
- Automatic model loading and pipeline initialization
|
||||
- Request counting and active inference tracking
|
||||
- Memory cleanup after each inference
|
||||
- CORS middleware for cross-origin requests
|
||||
|
||||
### 3) Test the server
|
||||
|
||||
Use the included test script:
|
||||
|
||||
```bash
|
||||
python test.py
|
||||
```
|
||||
|
||||
Or send a manual request:
|
||||
|
||||
`POST /api/diffusers/inference` with JSON body:
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "A futuristic cityscape, vibrant colors",
|
||||
"num_inference_steps": 30,
|
||||
"num_images_per_prompt": 1
|
||||
}
|
||||
```
|
||||
|
||||
Response example:
|
||||
|
||||
```json
|
||||
{
|
||||
"response": ["http://localhost:8500/images/img123.png"]
|
||||
}
|
||||
```
|
||||
|
||||
### 4) Server endpoints
|
||||
|
||||
- `GET /` - Welcome message
|
||||
- `POST /api/diffusers/inference` - Main inference endpoint
|
||||
- `GET /images/{filename}` - Serve generated images
|
||||
- `GET /api/status` - Server status and memory info
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### RequestScopedPipeline Parameters
|
||||
|
||||
```python
|
||||
RequestScopedPipeline(
|
||||
pipeline, # Base pipeline to wrap
|
||||
mutable_attrs=None, # Custom list of attributes to clone
|
||||
auto_detect_mutables=True, # Enable automatic detection of mutable attributes
|
||||
tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning
|
||||
tokenizer_lock=None, # Custom threading lock for tokenizers
|
||||
wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler
|
||||
)
|
||||
```
|
||||
|
||||
### BaseAsyncScheduler Features
|
||||
|
||||
* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
|
||||
* `clone_for_request()` method for safe per-request scheduler cloning
|
||||
* Enhanced debugging with `__repr__` and `__str__` methods
|
||||
* Full compatibility with existing scheduler APIs
|
||||
|
||||
### Server Configuration
|
||||
|
||||
The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ServerConfigModels:
|
||||
model: str = 'stabilityai/stable-diffusion-3.5-medium'
|
||||
type_models: str = 't2im'
|
||||
host: str = '0.0.0.0'
|
||||
port: int = 8500
|
||||
```
|
||||
|
||||
## Troubleshooting (quick)
|
||||
|
||||
* `Already borrowed` — previously a Rust tokenizer concurrency error.
|
||||
✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.
|
||||
|
||||
* `can't set attribute 'components'` — pipeline exposes read-only `components`.
|
||||
✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.
|
||||
|
||||
* Scheduler issues:
|
||||
* If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log and fallback — but prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
|
||||
✅ Note: `async_retrieve_timesteps` is fully retro-compatible — if you don't pass `return_scheduler=True`, the behavior is unchanged.
|
||||
|
||||
* Memory issues with large tensors:
|
||||
✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones.
|
||||
|
||||
* Automatic tokenizer detection:
|
||||
✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers.
|
||||
@@ -0,0 +1,10 @@
|
||||
torch
|
||||
torchvision
|
||||
transformers
|
||||
sentencepiece
|
||||
fastapi
|
||||
uvicorn
|
||||
ftfy
|
||||
accelerate
|
||||
xformers
|
||||
protobuf
|
||||
@@ -0,0 +1,230 @@
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from Pipelines import ModelPipelineInitializer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils import RequestScopedPipeline, Utils
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerConfigModels:
|
||||
model: str = "stabilityai/stable-diffusion-3.5-medium"
|
||||
type_models: str = "t2im"
|
||||
constructor_pipeline: Optional[Type] = None
|
||||
custom_pipeline: Optional[Type] = None
|
||||
components: Optional[Dict[str, Any]] = None
|
||||
torch_dtype: Optional[torch.dtype] = None
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8500
|
||||
|
||||
|
||||
server_config = ServerConfigModels()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
app.state.logger = logging.getLogger("diffusers-server")
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
|
||||
|
||||
app.state.total_requests = 0
|
||||
app.state.active_inferences = 0
|
||||
app.state.metrics_lock = asyncio.Lock()
|
||||
app.state.metrics_task = None
|
||||
|
||||
app.state.utils_app = Utils(
|
||||
host=server_config.host,
|
||||
port=server_config.port,
|
||||
)
|
||||
|
||||
async def metrics_loop():
|
||||
try:
|
||||
while True:
|
||||
async with app.state.metrics_lock:
|
||||
total = app.state.total_requests
|
||||
active = app.state.active_inferences
|
||||
app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
app.state.logger.info("Metrics loop cancelled")
|
||||
raise
|
||||
|
||||
app.state.metrics_task = asyncio.create_task(metrics_loop())
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
task = app.state.metrics_task
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
|
||||
if callable(stop_fn):
|
||||
await run_in_threadpool(stop_fn)
|
||||
except Exception as e:
|
||||
app.state.logger.warning(f"Error during pipeline shutdown: {e}")
|
||||
|
||||
app.state.logger.info("Lifespan shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
logger = logging.getLogger("DiffusersServer.Pipelines")
|
||||
|
||||
|
||||
initializer = ModelPipelineInitializer(
|
||||
model=server_config.model,
|
||||
type_models=server_config.type_models,
|
||||
)
|
||||
model_pipeline = initializer.initialize_pipeline()
|
||||
model_pipeline.start()
|
||||
|
||||
request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
|
||||
pipeline_lock = threading.Lock()
|
||||
|
||||
logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")
|
||||
|
||||
app.state.MODEL_INITIALIZER = initializer
|
||||
app.state.MODEL_PIPELINE = model_pipeline
|
||||
app.state.REQUEST_PIPE = request_pipe
|
||||
app.state.PIPELINE_LOCK = pipeline_lock
|
||||
|
||||
|
||||
class JSONBodyQueryAPI(BaseModel):
|
||||
model: str | None = None
|
||||
prompt: str
|
||||
negative_prompt: str | None = None
|
||||
num_inference_steps: int = 28
|
||||
num_images_per_prompt: int = 1
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def count_requests_middleware(request: Request, call_next):
|
||||
async with app.state.metrics_lock:
|
||||
app.state.total_requests += 1
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Welcome to the Diffusers Server"}
|
||||
|
||||
|
||||
@app.post("/api/diffusers/inference")
|
||||
async def api(json: JSONBodyQueryAPI):
|
||||
prompt = json.prompt
|
||||
negative_prompt = json.negative_prompt or ""
|
||||
num_steps = json.num_inference_steps
|
||||
num_images_per_prompt = json.num_images_per_prompt
|
||||
|
||||
wrapper = app.state.MODEL_PIPELINE
|
||||
initializer = app.state.MODEL_INITIALIZER
|
||||
|
||||
utils_app = app.state.utils_app
|
||||
|
||||
if not wrapper or not wrapper.pipeline:
|
||||
raise HTTPException(500, "Model not initialized correctly")
|
||||
if not prompt.strip():
|
||||
raise HTTPException(400, "No prompt provided")
|
||||
|
||||
def make_generator():
|
||||
g = torch.Generator(device=initializer.device)
|
||||
return g.manual_seed(random.randint(0, 10_000_000))
|
||||
|
||||
req_pipe = app.state.REQUEST_PIPE
|
||||
|
||||
def infer():
|
||||
gen = make_generator()
|
||||
return req_pipe.generate(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=gen,
|
||||
num_inference_steps=num_steps,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=initializer.device,
|
||||
output_type="pil",
|
||||
)
|
||||
|
||||
try:
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences += 1
|
||||
|
||||
output = await run_in_threadpool(infer)
|
||||
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences = max(0, app.state.active_inferences - 1)
|
||||
|
||||
urls = [utils_app.save_image(img) for img in output.images]
|
||||
return {"response": urls}
|
||||
|
||||
except Exception as e:
|
||||
async with app.state.metrics_lock:
|
||||
app.state.active_inferences = max(0, app.state.active_inferences - 1)
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise HTTPException(500, f"Error in processing: {e}")
|
||||
|
||||
finally:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.ipc_collect()
|
||||
gc.collect()
|
||||
|
||||
|
||||
@app.get("/images/{filename}")
|
||||
async def serve_image(filename: str):
|
||||
utils_app = app.state.utils_app
|
||||
file_path = os.path.join(utils_app.image_dir, filename)
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
return FileResponse(file_path, media_type="image/png")
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status():
|
||||
memory_info = {}
|
||||
if torch.cuda.is_available():
|
||||
memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
|
||||
memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
|
||||
memory_info = {
|
||||
"memory_allocated_gb": round(memory_allocated, 2),
|
||||
"memory_reserved_gb": round(memory_reserved, 2),
|
||||
"device": torch.cuda.get_device_name(0),
|
||||
}
|
||||
|
||||
return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host=server_config.host, port=server_config.port)
|
||||
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
SERVER_URL = "http://localhost:8500/api/diffusers/inference"
|
||||
BASE_URL = "http://localhost:8500"
|
||||
DOWNLOAD_FOLDER = "generated_images"
|
||||
WAIT_BEFORE_DOWNLOAD = 2 # seconds
|
||||
|
||||
os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
|
||||
|
||||
|
||||
def save_from_url(url: str) -> str:
|
||||
"""Download the given URL (relative or absolute) and save it locally."""
|
||||
if url.startswith("/"):
|
||||
direct = BASE_URL.rstrip("/") + url
|
||||
else:
|
||||
direct = url
|
||||
resp = requests.get(direct, timeout=60)
|
||||
resp.raise_for_status()
|
||||
filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
|
||||
path = os.path.join(DOWNLOAD_FOLDER, filename)
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
return path
|
||||
|
||||
|
||||
def main():
|
||||
payload = {
|
||||
"prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
|
||||
"num_inference_steps": 30,
|
||||
"num_images_per_prompt": 1,
|
||||
}
|
||||
|
||||
print("Sending request...")
|
||||
try:
|
||||
r = requests.post(SERVER_URL, json=payload, timeout=480)
|
||||
r.raise_for_status()
|
||||
except Exception as e:
|
||||
print(f"Request failed: {e}")
|
||||
return
|
||||
|
||||
body = r.json().get("response", [])
|
||||
# Normalize to a list
|
||||
urls = body if isinstance(body, list) else [body] if body else []
|
||||
if not urls:
|
||||
print("No URLs found in the response. Check the server output.")
|
||||
return
|
||||
|
||||
print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
|
||||
time.sleep(WAIT_BEFORE_DOWNLOAD)
|
||||
|
||||
for u in urls:
|
||||
try:
|
||||
path = save_from_url(u)
|
||||
print(f"Image saved to: {path}")
|
||||
except Exception as e:
|
||||
print(f"Error downloading {u}: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,2 @@
|
||||
from .requestscopedpipeline import RequestScopedPipeline
|
||||
from .utils import Utils
|
||||
@@ -0,0 +1,296 @@
|
||||
import copy
|
||||
import threading
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def safe_tokenize(tokenizer, *args, lock, **kwargs):
|
||||
with lock:
|
||||
return tokenizer(*args, **kwargs)
|
||||
|
||||
|
||||
class RequestScopedPipeline:
|
||||
DEFAULT_MUTABLE_ATTRS = [
|
||||
"_all_hooks",
|
||||
"_offload_device",
|
||||
"_progress_bar_config",
|
||||
"_progress_bar",
|
||||
"_rng_state",
|
||||
"_last_seed",
|
||||
"latents",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline: Any,
|
||||
mutable_attrs: Optional[Iterable[str]] = None,
|
||||
auto_detect_mutables: bool = True,
|
||||
tensor_numel_threshold: int = 1_000_000,
|
||||
tokenizer_lock: Optional[threading.Lock] = None,
|
||||
wrap_scheduler: bool = True,
|
||||
):
|
||||
self._base = pipeline
|
||||
self.unet = getattr(pipeline, "unet", None)
|
||||
self.vae = getattr(pipeline, "vae", None)
|
||||
self.text_encoder = getattr(pipeline, "text_encoder", None)
|
||||
self.components = getattr(pipeline, "components", None)
|
||||
|
||||
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
|
||||
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
|
||||
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
|
||||
|
||||
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
|
||||
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
|
||||
|
||||
self._auto_detect_mutables = bool(auto_detect_mutables)
|
||||
self._tensor_numel_threshold = int(tensor_numel_threshold)
|
||||
|
||||
self._auto_detected_attrs: List[str] = []
|
||||
|
||||
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
|
||||
base_sched = getattr(self._base, "scheduler", None)
|
||||
if base_sched is None:
|
||||
return None
|
||||
|
||||
if not isinstance(base_sched, BaseAsyncScheduler):
|
||||
wrapped_scheduler = BaseAsyncScheduler(base_sched)
|
||||
else:
|
||||
wrapped_scheduler = base_sched
|
||||
|
||||
try:
|
||||
return wrapped_scheduler.clone_for_request(
|
||||
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
|
||||
try:
|
||||
return copy.deepcopy(wrapped_scheduler)
|
||||
except Exception as e:
|
||||
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
|
||||
return wrapped_scheduler
|
||||
|
||||
def _autodetect_mutables(self, max_attrs: int = 40):
|
||||
if not self._auto_detect_mutables:
|
||||
return []
|
||||
|
||||
if self._auto_detected_attrs:
|
||||
return self._auto_detected_attrs
|
||||
|
||||
candidates: List[str] = []
|
||||
seen = set()
|
||||
for name in dir(self._base):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
if name in self._mutable_attrs:
|
||||
continue
|
||||
if name in ("to", "save_pretrained", "from_pretrained"):
|
||||
continue
|
||||
try:
|
||||
val = getattr(self._base, name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
import types
|
||||
|
||||
# skip callables and modules
|
||||
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
|
||||
continue
|
||||
|
||||
# containers -> candidate
|
||||
if isinstance(val, (dict, list, set, tuple, bytearray)):
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
else:
|
||||
# try Tensor detection
|
||||
try:
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.numel() <= self._tensor_numel_threshold:
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
else:
|
||||
logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if len(candidates) >= max_attrs:
|
||||
break
|
||||
|
||||
self._auto_detected_attrs = candidates
|
||||
logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
|
||||
return self._auto_detected_attrs
|
||||
|
||||
def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
|
||||
try:
|
||||
cls = type(base_obj)
|
||||
descriptor = getattr(cls, attr_name, None)
|
||||
if isinstance(descriptor, property):
|
||||
return descriptor.fset is None
|
||||
if hasattr(descriptor, "__set__") is False and descriptor is not None:
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _clone_mutable_attrs(self, base, local):
|
||||
attrs_to_clone = list(self._mutable_attrs)
|
||||
attrs_to_clone.extend(self._autodetect_mutables())
|
||||
|
||||
EXCLUDE_ATTRS = {
|
||||
"components",
|
||||
}
|
||||
|
||||
for attr in attrs_to_clone:
|
||||
if attr in EXCLUDE_ATTRS:
|
||||
logger.debug(f"Skipping excluded attr '{attr}'")
|
||||
continue
|
||||
if not hasattr(base, attr):
|
||||
continue
|
||||
if self._is_readonly_property(base, attr):
|
||||
logger.debug(f"Skipping read-only property '{attr}'")
|
||||
continue
|
||||
|
||||
try:
|
||||
val = getattr(base, attr)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
setattr(local, attr, dict(val))
|
||||
elif isinstance(val, (list, tuple, set)):
|
||||
setattr(local, attr, list(val))
|
||||
elif isinstance(val, bytearray):
|
||||
setattr(local, attr, bytearray(val))
|
||||
else:
|
||||
# small tensors or atomic values
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.numel() <= self._tensor_numel_threshold:
|
||||
setattr(local, attr, val.clone())
|
||||
else:
|
||||
# don't clone big tensors, keep reference
|
||||
setattr(local, attr, val)
|
||||
else:
|
||||
try:
|
||||
setattr(local, attr, copy.copy(val))
|
||||
except Exception:
|
||||
setattr(local, attr, val)
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
|
||||
continue
|
||||
|
||||
def _is_tokenizer_component(self, component) -> bool:
|
||||
if component is None:
|
||||
return False
|
||||
|
||||
tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
|
||||
has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
|
||||
|
||||
class_name = component.__class__.__name__.lower()
|
||||
has_tokenizer_in_name = "tokenizer" in class_name
|
||||
|
||||
tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
|
||||
has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
|
||||
|
||||
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
|
||||
|
||||
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
|
||||
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
|
||||
|
||||
try:
|
||||
local_pipe = copy.copy(self._base)
|
||||
except Exception as e:
|
||||
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
|
||||
local_pipe = copy.deepcopy(self._base)
|
||||
|
||||
if local_scheduler is not None:
|
||||
try:
|
||||
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
|
||||
local_scheduler.scheduler,
|
||||
num_inference_steps=num_inference_steps,
|
||||
device=device,
|
||||
return_scheduler=True,
|
||||
**{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
|
||||
)
|
||||
|
||||
final_scheduler = BaseAsyncScheduler(configured_scheduler)
|
||||
setattr(local_pipe, "scheduler", final_scheduler)
|
||||
except Exception:
|
||||
logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
|
||||
|
||||
self._clone_mutable_attrs(self._base, local_pipe)
|
||||
|
||||
# 4) wrap tokenizers on the local pipe with the lock wrapper
|
||||
tokenizer_wrappers = {} # name -> original_tokenizer
|
||||
try:
|
||||
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
|
||||
for name in dir(local_pipe):
|
||||
if "tokenizer" in name and not name.startswith("_"):
|
||||
tok = getattr(local_pipe, name, None)
|
||||
if tok is not None and self._is_tokenizer_component(tok):
|
||||
tokenizer_wrappers[name] = tok
|
||||
setattr(
|
||||
local_pipe,
|
||||
name,
|
||||
lambda *args, tok=tok, **kwargs: safe_tokenize(
|
||||
tok, *args, lock=self._tokenizer_lock, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
# b) wrap tokenizers in components dict
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
for key, val in local_pipe.components.items():
|
||||
if val is None:
|
||||
continue
|
||||
|
||||
if self._is_tokenizer_component(val):
|
||||
tokenizer_wrappers[f"components[{key}]"] = val
|
||||
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
|
||||
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
|
||||
|
||||
result = None
|
||||
cm = getattr(local_pipe, "model_cpu_offload_context", None)
|
||||
try:
|
||||
if callable(cm):
|
||||
try:
|
||||
with cm():
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except TypeError:
|
||||
# cm might be a context manager instance rather than callable
|
||||
try:
|
||||
with cm:
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except Exception as e:
|
||||
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
else:
|
||||
# no offload context available — call directly
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
try:
|
||||
for name, tok in tokenizer_wrappers.items():
|
||||
if name.startswith("components["):
|
||||
key = name[len("components[") : -1]
|
||||
local_pipe.components[key] = tok
|
||||
else:
|
||||
setattr(local_pipe, name, tok)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error restoring wrapped tokenizers: {e}")
|
||||
@@ -0,0 +1,141 @@
|
||||
import copy
|
||||
import inspect
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseAsyncScheduler:
|
||||
def __init__(self, scheduler: Any):
|
||||
self.scheduler = scheduler
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if hasattr(self.scheduler, name):
|
||||
return getattr(self.scheduler, name)
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __setattr__(self, name: str, value):
|
||||
if name == "scheduler":
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
|
||||
setattr(self.scheduler, name, value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
|
||||
local = copy.deepcopy(self.scheduler)
|
||||
local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
|
||||
cloned = self.__class__(local)
|
||||
return cloned
|
||||
|
||||
def __repr__(self):
|
||||
return f"BaseAsyncScheduler({repr(self.scheduler)})"
|
||||
|
||||
def __str__(self):
|
||||
return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
|
||||
|
||||
|
||||
def async_retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
|
||||
Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Backwards compatible: by default the function behaves exactly as before and returns
|
||||
(timesteps_tensor, num_inference_steps)
|
||||
|
||||
If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
|
||||
scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
|
||||
or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
|
||||
(timesteps_tensor, num_inference_steps, scheduler_in_use)
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Optional kwargs:
|
||||
return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
|
||||
where `scheduler_in_use` is a scheduler instance that already has timesteps set.
|
||||
This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.
|
||||
|
||||
Returns:
|
||||
`(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
|
||||
`(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
|
||||
"""
|
||||
# pop our optional control kwarg (keeps compatibility)
|
||||
return_scheduler = bool(kwargs.pop("return_scheduler", False))
|
||||
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
|
||||
# choose scheduler to call set_timesteps on
|
||||
scheduler_in_use = scheduler
|
||||
if return_scheduler:
|
||||
# Do not mutate the provided scheduler: prefer to clone if possible
|
||||
if hasattr(scheduler, "clone_for_request"):
|
||||
try:
|
||||
# clone_for_request may accept num_inference_steps or other kwargs; be permissive
|
||||
scheduler_in_use = scheduler.clone_for_request(
|
||||
num_inference_steps=num_inference_steps or 0, device=device
|
||||
)
|
||||
except Exception:
|
||||
scheduler_in_use = copy.deepcopy(scheduler)
|
||||
else:
|
||||
# fallback deepcopy (scheduler tends to be smallish - acceptable)
|
||||
scheduler_in_use = copy.deepcopy(scheduler)
|
||||
|
||||
# helper to test if set_timesteps supports a particular kwarg
|
||||
def _accepts(param_name: str) -> bool:
|
||||
try:
|
||||
return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
|
||||
except (ValueError, TypeError):
|
||||
# if signature introspection fails, be permissive and attempt the call later
|
||||
return False
|
||||
|
||||
# now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = _accepts("timesteps")
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
num_inference_steps = len(timesteps_out)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = _accepts("sigmas")
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
num_inference_steps = len(timesteps_out)
|
||||
else:
|
||||
# default path
|
||||
scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps_out = scheduler_in_use.timesteps
|
||||
|
||||
if return_scheduler:
|
||||
return timesteps_out, num_inference_steps, scheduler_in_use
|
||||
return timesteps_out, num_inference_steps
|
||||
@@ -0,0 +1,48 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Utils:
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8500):
|
||||
self.service_url = f"http://{host}:{port}"
|
||||
self.image_dir = os.path.join(tempfile.gettempdir(), "images")
|
||||
if not os.path.exists(self.image_dir):
|
||||
os.makedirs(self.image_dir)
|
||||
|
||||
self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
|
||||
if not os.path.exists(self.video_dir):
|
||||
os.makedirs(self.video_dir)
|
||||
|
||||
def save_image(self, image):
|
||||
if hasattr(image, "to"):
|
||||
try:
|
||||
image = image.to("cpu")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
from torchvision import transforms
|
||||
|
||||
to_pil = transforms.ToPILImage()
|
||||
image = to_pil(image.squeeze(0).clamp(0, 1))
|
||||
|
||||
filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
|
||||
image_path = os.path.join(self.image_dir, filename)
|
||||
logger.info(f"Saving image to {image_path}")
|
||||
|
||||
image.save(image_path, format="PNG", optimize=True)
|
||||
|
||||
del image
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return os.path.join(self.service_url, "images", filename)
|
||||
@@ -278,6 +278,29 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-VACE-Fun-14B":
|
||||
config = {
|
||||
"model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
|
||||
"vace_in_channels": 96,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-I2V-14B-720p":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
|
||||
@@ -975,7 +998,17 @@ if __name__ == "__main__":
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
elif "VACE" in args.model_type:
|
||||
elif "Wan2.2-VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
boundary_ratio=0.875,
|
||||
)
|
||||
elif "Wan-VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
|
||||
@@ -495,6 +495,7 @@ else:
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
"LucyEditPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
"LuminaPipeline",
|
||||
@@ -1149,6 +1150,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
LucyEditPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
LuminaPipeline,
|
||||
|
||||
@@ -558,70 +558,62 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
ait_sd[target_key] = value
|
||||
|
||||
if any("guidance_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_in_layer",
|
||||
"time_text_embed.guidance_embedder.linear_1",
|
||||
)
|
||||
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_guidance_in_out_layer",
|
||||
"time_text_embed.guidance_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("img_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_img_in",
|
||||
"x_embedder",
|
||||
)
|
||||
|
||||
if any("txt_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_txt_in",
|
||||
"context_embedder",
|
||||
)
|
||||
|
||||
if any("time_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_in_layer",
|
||||
"time_text_embed.timestep_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_time_in_out_layer",
|
||||
"time_text_embed.timestep_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("vector_in" in k for k in sds_sd):
|
||||
assign_remaining_weights(
|
||||
[
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
|
||||
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
|
||||
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
|
||||
None,
|
||||
),
|
||||
],
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_in_layer",
|
||||
"time_text_embed.text_embedder.linear_1",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
"lora_unet_vector_in_out_layer",
|
||||
"time_text_embed.text_embedder.linear_2",
|
||||
)
|
||||
|
||||
if any("final_layer" in k for k in sds_sd):
|
||||
|
||||
@@ -674,7 +674,7 @@ class JointTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
|
||||
@@ -1052,7 +1052,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
is_residual=is_residual,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
||||
self.spatial_compression_ratio = scale_factor_spatial
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
@@ -1145,12 +1145,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def _encode(self, x: torch.Tensor):
|
||||
_, _, num_frame, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
self.clear_cache()
|
||||
if self.config.patch_size is not None:
|
||||
x = patchify(x, patch_size=self.config.patch_size)
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
iter_ = 1 + (num_frame - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
|
||||
|
||||
return selected_indices
|
||||
|
||||
def forward(self, latent):
|
||||
def forward(self, latent) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = latent.size()
|
||||
latent = latent.view(
|
||||
batch_size,
|
||||
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = hidden_states
|
||||
residual_context = encoder_hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
@@ -472,7 +472,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
timestep: torch.LongTensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -441,7 +441,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
id_cond: Optional[torch.Tensor] = None,
|
||||
id_vit_hidden: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
|
||||
encoder_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a forward pass through the LuminaNextDiTBlock.
|
||||
|
||||
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
|
||||
image_rotary_emb: torch.Tensor,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
return_dict=True,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
"""
|
||||
Forward pass of LuminaNextDiT.
|
||||
|
||||
|
||||
@@ -472,7 +472,7 @@ class BriaSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
@@ -588,7 +588,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
return_dict: bool = True,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`BriaTransformer2DModel`] forward method.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`CogView3PlusTransformer2DModel`] forward method.
|
||||
|
||||
|
||||
@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
|
||||
] = None,
|
||||
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
norm_hidden_states,
|
||||
@@ -717,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
|
||||
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
||||
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
|
||||
t_emb = self.timestep_embedder(t_emb)
|
||||
return t_emb
|
||||
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
|
||||
self.out_channels = out_channels
|
||||
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
|
||||
|
||||
def forward(self, latent):
|
||||
def forward(self, latent) -> torch.Tensor:
|
||||
latent = self.proj(latent)
|
||||
return latent
|
||||
|
||||
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
wtype = hidden_states.dtype
|
||||
(
|
||||
shift_msa_i,
|
||||
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
return self.block(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_masks=hidden_states_masks,
|
||||
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
|
||||
@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
guidance: torch.Tensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
indices_latents_history_4x: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -82,6 +82,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
time_embedding_dim: Optional[int] = None,
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
@@ -100,15 +101,23 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
||||
if time_embed_dim % 2 != 0:
|
||||
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
timestep_input_dim = time_embed_dim
|
||||
elif time_embedding_type == "positional":
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
||||
)
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
@@ -285,6 +285,7 @@ else:
|
||||
]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
[
|
||||
"MarigoldDepthPipeline",
|
||||
@@ -682,6 +683,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_lucy_edit import LucyEditPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,735 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 The Decart AI Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modifications by Decart AI Team:
|
||||
# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
|
||||
|
||||
import html
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...models import AutoencoderKLWan, WanTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import LucyPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> from typing import List
|
||||
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> from diffusers import AutoencoderKLWan, LucyEditPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> # Arguments
|
||||
>>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4"
|
||||
>>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights."
|
||||
>>> negative_prompt = ""
|
||||
>>> num_frames = 81
|
||||
>>> height = 480
|
||||
>>> width = 832
|
||||
|
||||
|
||||
>>> # Load video
|
||||
>>> def convert_video(video: List[Image.Image]) -> List[Image.Image]:
|
||||
... video = load_video(url)[:num_frames]
|
||||
... video = [video[i].resize((width, height)) for i in range(num_frames)]
|
||||
... return video
|
||||
|
||||
|
||||
>>> video = load_video(url, convert_method=convert_video)
|
||||
|
||||
>>> # Load model
|
||||
>>> model_id = "decart-ai/Lucy-Edit-Dev"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Generate video
|
||||
>>> output = pipe(
|
||||
... prompt=prompt,
|
||||
... video=video,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=480,
|
||||
... width=832,
|
||||
... num_frames=81,
|
||||
... guidance_scale=5.0,
|
||||
... ).frames[0]
|
||||
|
||||
>>> # Export video
|
||||
>>> export_to_video(output, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for video-to-video generation using Lucy Edit.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer_2 ([`WanTransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
|
||||
two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
|
||||
stages. If not provided, only `transformer` is used.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
||||
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
||||
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer", "transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer: Optional[WanTransformer3DModel] = None,
|
||||
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False, # Wan2.2 ti2v
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
transformer_2=transformer_2,
|
||||
)
|
||||
self.register_to_config(boundary_ratio=boundary_ratio)
|
||||
self.register_to_config(expand_timesteps=expand_timesteps)
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
video,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
||||
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
||||
|
||||
if video is None:
|
||||
raise ValueError("`video` is required, received None.")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
num_latent_frames = (
|
||||
(video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
|
||||
)
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
# Prepare noise latents
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# Prepare condition latents
|
||||
condition_latents = [
|
||||
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video
|
||||
]
|
||||
|
||||
condition_latents = torch.cat(condition_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
device, dtype
|
||||
)
|
||||
|
||||
condition_latents = (condition_latents - latents_mean) * latents_std
|
||||
|
||||
# Check shapes
|
||||
assert latents.shape == condition_latents.shape, (
|
||||
f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input."
|
||||
)
|
||||
|
||||
return latents, condition_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
video: List[Image.Image],
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
video (`List[Image.Image]`):
|
||||
The video to use as the condition for the video generation.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (`guidance_scale` < `1`).
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `832`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
||||
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
|
||||
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
|
||||
and the pipeline's `boundary_ratio` are not None.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
||||
truncated. If the prompt is shorter, it will be padded to this length.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~LucyPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
video,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
guidance_scale_2,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
logger.warning(
|
||||
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
|
||||
)
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
||||
guidance_scale_2 = guidance_scale
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_scale_2 = guidance_scale_2
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = (
|
||||
self.transformer.config.out_channels
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.out_channels
|
||||
)
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width).to(
|
||||
device, dtype=torch.float32
|
||||
)
|
||||
latents, condition_latents = self.prepare_latents(
|
||||
video,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if self.config.boundary_ratio is not None:
|
||||
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
else:
|
||||
boundary_timestep = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
if boundary_timestep is None or t >= boundary_timestep:
|
||||
# wan2.1 or high-noise stage in wan2.2
|
||||
current_model = self.transformer
|
||||
current_guidance_scale = guidance_scale
|
||||
else:
|
||||
# low-noise stage in wan2.2
|
||||
current_model = self.transformer_2
|
||||
current_guidance_scale = guidance_scale_2
|
||||
|
||||
# latent_model_input = latents.to(transformer_dtype)
|
||||
latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype)
|
||||
# latent_model_input = torch.cat([latents, latents], dim=1).to(transformer_dtype)
|
||||
if self.config.expand_timesteps:
|
||||
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
|
||||
# batch_size, seq_len
|
||||
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
|
||||
else:
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
with current_model.cache_context("cond"):
|
||||
noise_pred = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with current_model.cache_context("uncond"):
|
||||
noise_uncond = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LucyPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class LucyPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Lucy pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
@@ -48,7 +48,6 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
|
||||
@@ -112,7 +111,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
@@ -191,7 +190,7 @@ def filter_model_files(filenames):
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
|
||||
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
|
||||
|
||||
@@ -212,7 +211,7 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
|
||||
@@ -152,16 +152,26 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
transformer ([`WanVACETransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
|
||||
`transformer` is used.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
||||
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
||||
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -170,6 +180,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
transformer: WanVACETransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer_2: WanVACETransformer3DModel = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -178,9 +190,10 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.register_to_config(boundary_ratio=boundary_ratio)
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
@@ -321,6 +334,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
video=None,
|
||||
mask=None,
|
||||
reference_images=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
if height % base != 0 or width % base != 0:
|
||||
@@ -332,6 +346,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
||||
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -667,6 +683,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
@@ -728,6 +745,10 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
||||
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
|
||||
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
|
||||
and the pipeline's `boundary_ratio` are not None.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -793,6 +814,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
video,
|
||||
mask,
|
||||
reference_images,
|
||||
guidance_scale_2,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
@@ -802,7 +824,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
||||
guidance_scale_2 = guidance_scale
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_scale_2 = guidance_scale_2
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
@@ -896,36 +922,53 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if self.config.boundary_ratio is not None:
|
||||
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
else:
|
||||
boundary_timestep = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
if boundary_timestep is None or t >= boundary_timestep:
|
||||
# wan2.1 or high-noise stage in wan2.2
|
||||
current_model = self.transformer
|
||||
current_guidance_scale = guidance_scale
|
||||
else:
|
||||
# low-noise stage in wan2.2
|
||||
current_model = self.transformer_2
|
||||
current_guidance_scale = guidance_scale_2
|
||||
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
with current_model.cache_context("cond"):
|
||||
noise_pred = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with current_model.cache_context("uncond"):
|
||||
noise_uncond = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -1592,6 +1592,21 @@ class LTXPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LucyEditPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Lumina2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -907,6 +907,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_kohya_embedders_conversion(self):
|
||||
"""Test that embedders load without throwing errors"""
|
||||
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
|
||||
self.pipeline.unload_lora_weights()
|
||||
|
||||
assert True
|
||||
|
||||
def test_flux_xlabs(self):
|
||||
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
|
||||
@@ -87,6 +87,7 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer_2": None,
|
||||
}
|
||||
return components
|
||||
|
||||
|
||||
Reference in New Issue
Block a user