tensorrt_llm/llmapi/visual_gen.py

import asyncio
import queue
import socket
import threading
import time
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch.multiprocessing as mp
import zmq

from tensorrt_llm._torch.visual_gen import DiffusionRequest, DiffusionResponse
from tensorrt_llm._torch.visual_gen.executor import run_diffusion_worker
from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm.executor.ipc import ZeroMqQueue
from tensorrt_llm.inputs.data import VisualGenInputs
from tensorrt_llm.logger import logger

__all__ = ["VisualGen", "VisualGenParams", "MediaOutput"]

# Timeouts (seconds)
POLL_TIMEOUT = 0.01
AWAIT_TIMEOUT = 0.05
THREAD_TIMEOUT = 5.0
WORKER_TIMEOUT = 2.0
READY_TIMEOUT = 1200 # 20 minutes for large models (Wan 2.2 with transformer_2)
def find_free_port() -> int:
    """Return a free TCP port chosen by the OS."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Binding to port 0 makes the OS assign an ephemeral port; the port is
        # released on close, so a small reuse window exists before rebinding.
        s.bind(("", 0))
        return s.getsockname()[1]
def get_ip_address() -> str:
"""Get local IP address."""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
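        # connect() on a UDP socket sends no packets; it only asks the OS to
        # choose the outbound interface for this destination, so the address
        # never needs to be reachable.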
s.connect(("10.255.255.255", 1))
return s.getsockname()[0]
except Exception:
return "127.0.0.1"
finally:
s.close()
class DiffusionRemoteClient:
"""Client proxy for remote DiffusionExecutor in worker processes."""
def __init__(
self,
model_path: Union[str, Path],
n_workers: int = 1,
diffusion_config: Optional[dict] = None,
):
self.model_path = str(model_path)
self.n_workers = n_workers
self.diffusion_config = diffusion_config
# Setup distributed env
self.master_addr = "127.0.0.1"
self.master_port = find_free_port()
# Setup IPC addresses
self.host_ip = get_ip_address()
req_port, resp_port = find_free_port(), find_free_port()
self.request_queue_addr = f"tcp://0.0.0.0:{req_port}"
self.response_queue_addr = f"tcp://0.0.0.0:{resp_port}"
self.req_addr_connect = f"tcp://{self.host_ip}:{req_port}"
self.resp_addr_connect = f"tcp://{self.host_ip}:{resp_port}"
# IPC setup
self.requests_ipc = None
self.responses_ipc = None
self.pending_requests = queue.Queue()
self.completed_responses: Dict[int, DiffusionResponse] = {}
# We'll create asyncio primitives in the background thread's event loop
self._event_loop = None
self.response_event = None
self.lock = None
self.shutdown_event = threading.Event()
self.event_loop_ready = threading.Event()
# Start background thread (it will create its own event loop)
self.background_thread = threading.Thread(target=self._serve_forever_thread, daemon=True)
self.background_thread.start()
# Wait for the background thread to initialize the event loop
self.event_loop_ready.wait()
# Launch workers
logger.info(f"DiffusionClient: Launching {n_workers} workers")
ctx = mp.get_context("spawn")
self.worker_processes = []
for rank in range(n_workers):
p = ctx.Process(
target=run_diffusion_worker,
kwargs={
"rank": rank,
"world_size": n_workers,
"master_addr": self.master_addr,
"master_port": self.master_port,
"model_path": self.model_path,
"request_queue_addr": self.req_addr_connect,
"response_queue_addr": self.resp_addr_connect,
"diffusion_config": self.diffusion_config,
},
)
p.start()
self.worker_processes.append(p)
self._wait_ready()
@staticmethod
def _close_socket(ipc_queue):
if ipc_queue and ipc_queue.socket:
ipc_queue.socket.setsockopt(zmq.LINGER, 0)
ipc_queue.close()
def enqueue_requests(self, requests: List[DiffusionRequest]) -> List[int]:
"""Enqueue requests and return their IDs."""
req_ids = []
for req in requests:
self.pending_requests.put(req)
req_ids.append(req.request_id)
return req_ids
async def await_responses(
self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
) -> Union[DiffusionResponse, List[DiffusionResponse]]:
"""Wait for responses by request IDs.
Args:
request_ids: Single request ID or list of request IDs to wait for
timeout: Maximum total wait time in seconds (None = wait indefinitely)
Returns:
Single response or list of responses (None if request timed out)
"""
is_single = isinstance(request_ids, int)
ids = [request_ids] if is_single else request_ids
start_time = time.time()
results = {}
while len(results) < len(ids):
async with self.lock:
for req_id in ids:
if req_id in self.completed_responses:
results[req_id] = self.completed_responses.pop(req_id)
# All responses collected
if len(results) == len(ids):
break
# Check if overall timeout exceeded
if timeout is not None:
elapsed = time.time() - start_time
if elapsed >= timeout:
break
# Wait for remaining time or AWAIT_TIMEOUT, whichever is shorter
wait_time = min(timeout - elapsed, AWAIT_TIMEOUT)
else:
wait_time = AWAIT_TIMEOUT
try:
await asyncio.wait_for(self.response_event.wait(), timeout=wait_time)
except asyncio.TimeoutError:
pass
self.response_event.clear()
out = [results.get(rid) for rid in ids]
return out[0] if is_single else out
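    # Usage sketch for a hypothetical `client` instance (must run on the
    # background thread's event loop, e.g. via asyncio.run_coroutine_threadsafe
    # as await_responses_sync() does below):
    #   responses = await client.await_responses([0, 1], timeout=30.0)
    #   timed_out = [rid for rid, r in zip([0, 1], responses) if r is None]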
    def await_responses_sync(
        self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
    ) -> Union[DiffusionResponse, List[DiffusionResponse]]:
        """Sync wrapper to await responses from the main thread."""
        future = asyncio.run_coroutine_threadsafe(
            self.await_responses(request_ids, timeout), self._event_loop
        )
        # Pad the outer wait so the coroutine's own timeout fires first and
        # returns None results instead of raising TimeoutError here.
        return future.result(timeout=timeout + THREAD_TIMEOUT if timeout is not None else None)
def _init_ipc(self) -> bool:
"""Initialize IPC queues."""
try:
logger.info("DiffusionClient: Initializing IPC")
self.requests_ipc = ZeroMqQueue(
(self.request_queue_addr, None),
is_server=True,
socket_type=zmq.PUSH,
use_hmac_encryption=False,
)
self.responses_ipc = ZeroMqQueue(
(self.response_queue_addr, None),
is_server=True,
socket_type=zmq.PULL,
use_hmac_encryption=False,
)
logger.info("DiffusionClient: IPC ready")
return True
except Exception as e:
logger.error(f"DiffusionClient: IPC init failed: {e}")
return False
def _send_shutdown(self):
"""Send shutdown signal."""
logger.info("DiffusionClient: Sending shutdown signal")
if self.requests_ipc:
self.requests_ipc.put(None)
self._close_socket(self.requests_ipc)
def _process_requests(self):
"""Process pending requests."""
try:
req = self.pending_requests.get(timeout=POLL_TIMEOUT)
if req is None:
self._send_shutdown()
self.shutdown_event.set()
return
logger.info(f"DiffusionClient: Sending request {req.request_id}")
self.requests_ipc.put(req)
except queue.Empty:
pass
except Exception as e:
logger.error(f"DiffusionClient: Error sending request: {e}")
logger.error(traceback.format_exc())
def _process_responses(self):
"""Poll and process responses."""
try:
if self.responses_ipc.poll(timeout=POLL_TIMEOUT):
response = self.responses_ipc.get()
if isinstance(response, DiffusionResponse):
if response.request_id == -1:
logger.info("DiffusionClient: Received READY signal")
# Schedule the lock acquisition and event setting in the event loop
asyncio.run_coroutine_threadsafe(
self._store_response(response), self._event_loop
)
except Exception as e:
logger.error(f"DiffusionClient: Error processing response: {e}")
async def _store_response(self, response: DiffusionResponse):
"""Store response in the completed_responses dict (async helper)."""
async with self.lock:
self.completed_responses[response.request_id] = response
self.response_event.set()
def _cleanup_ipc(self):
"""Cleanup IPC."""
logger.info("DiffusionClient: Cleaning up IPC")
self._close_socket(self.requests_ipc)
self._close_socket(self.responses_ipc)
def _serve_forever_thread(self):
"""Background thread wrapper that creates and runs an event loop."""
logger.info("DiffusionClient: Background thread started")
# Create a new event loop for this thread
self._event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._event_loop)
# Create async primitives in this thread's event loop
self.response_event = asyncio.Event()
self.lock = asyncio.Lock()
# Signal that the event loop is ready
self.event_loop_ready.set()
# Run the async serve_forever
try:
self._event_loop.run_until_complete(self._serve_forever())
finally:
self._event_loop.close()
logger.info("DiffusionClient: Background thread stopped")
async def _serve_forever(self):
"""Background thread main loop (async version)."""
if not self._init_ipc():
return
while not self.shutdown_event.is_set():
self._process_requests()
self._process_responses()
await asyncio.sleep(0.001) # Yield control to allow other coroutines to run
self._cleanup_ipc()
def shutdown(self):
"""Shutdown client and workers."""
logger.info("DiffusionClient: Shutting down")
self.pending_requests.put(None)
self.background_thread.join(timeout=THREAD_TIMEOUT)
if self.background_thread.is_alive():
logger.warning("DiffusionClient: Force stopping background thread")
self.shutdown_event.set()
self.background_thread.join(timeout=1.0)
# Shutdown workers
logger.info("DiffusionClient: Stopping workers")
for p in self.worker_processes:
p.join(timeout=WORKER_TIMEOUT)
if p.is_alive():
logger.warning(f"DiffusionClient: Terminating worker {p.pid} with SIGTERM")
p.terminate()
p.join(timeout=WORKER_TIMEOUT)
if p.is_alive():
logger.warning(f"DiffusionClient: Force killing worker {p.pid} with SIGKILL")
p.kill()
p.join(timeout=WORKER_TIMEOUT)
    def _wait_ready(self, timeout: float = READY_TIMEOUT):
        """Wait for workers to be ready (sync wrapper for async operation)."""
        logger.info("DiffusionClient: Waiting for workers")
        # Run the async wait in the background thread's event loop; pad the outer
        # wait so the coroutine's own RuntimeError surfaces first on timeout.
        future = asyncio.run_coroutine_threadsafe(self._wait_ready_async(timeout), self._event_loop)
        return future.result(timeout=timeout + THREAD_TIMEOUT)
async def _wait_ready_async(self, timeout: float = READY_TIMEOUT):
"""Wait for workers to be ready (async version)."""
start_time = time.time()
while True:
async with self.lock:
if -1 in self.completed_responses:
self.completed_responses.pop(-1)
logger.info("DiffusionClient: Workers ready")
return
if time.time() - start_time > timeout:
raise RuntimeError("DiffusionClient: Timeout waiting for workers")
try:
await asyncio.wait_for(self.response_event.wait(), timeout=AWAIT_TIMEOUT)
except asyncio.TimeoutError:
pass
self.response_event.clear()
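
# Lifecycle sketch for DiffusionRemoteClient (paths, IDs, and timeouts are
# illustrative placeholders, not tested values):
#
#   client = DiffusionRemoteClient("/models/wan2.2", n_workers=2)  # blocks until workers are READY
#   client.enqueue_requests([DiffusionRequest(request_id=0, prompt="a red fox at dawn")])
#   response = client.await_responses_sync(0, timeout=600.0)       # DiffusionResponse, or None on timeout
#   client.shutdown()
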
class DiffusionGenerationResult:
"""Future-like object for async generation."""
def __init__(self, request_id: int, executor: DiffusionRemoteClient):
self.request_id = request_id
self.executor = executor
self._result = None
self._finished = False
self._error = None
async def result(self, timeout: Optional[float] = None) -> Any:
"""Wait for and return result (async version).
Can be awaited from any async context (e.g., FastAPI background tasks).
"""
if self._finished:
if self._error:
raise RuntimeError(self._error)
return self._result
# Use run_coroutine_threadsafe to execute in the background thread's event loop
future = asyncio.run_coroutine_threadsafe(
self.executor.await_responses(self.request_id, timeout=timeout),
self.executor._event_loop,
)
        # Await the future in the current event loop
        response = await asyncio.wrap_future(future)
        if response is None:
            # await_responses() returns None for requests that timed out
            raise TimeoutError(f"Request {self.request_id} timed out after {timeout}s")
        if response.error_msg:
            self._error = response.error_msg
            self._finished = True
            raise RuntimeError(f"Generation failed: {response.error_msg}")
self._result = response.output
self._finished = True
return self._result
    def cancel(self):
        raise NotImplementedError("Request cancellation is not yet implemented.")
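
# Awaiting a DiffusionGenerationResult from an async context (a sketch; the
# object is returned by VisualGen.generate_async() below):
#
#   handle = visual_gen.generate_async(inputs="a cat on a skateboard", params=VisualGenParams())
#   media = await handle.result(timeout=600.0)  # MediaOutput; raises on worker error or timeout
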
@dataclass
class VisualGenParams:
"""Parameters for visual generation.
Attributes:
height: Output height in pixels
width: Output width in pixels
num_inference_steps: Number of denoising steps
guidance_scale: Classifier-free guidance scale
max_sequence_length: Maximum sequence length for text encoding
seed: Random seed for reproducibility
# Video-specific parameters
num_frames: Number of video frames to generate
frame_rate: Frame rate for video output in fps
# Image-specific parameters
num_images_per_prompt: Number of images to generate per prompt (for image models)
# Advanced parameters
guidance_rescale: Guidance rescale factor (for some models)
output_type: Output type ("pt" for PyTorch tensors, "pil" for PIL images)
"""
height: int = 720
width: int = 1280
num_inference_steps: int = 50
guidance_scale: float = 5.0
max_sequence_length: int = 512
seed: int = 42
# Video-specific parameters
num_frames: int = 81
frame_rate: float = 24.0
input_reference: Optional[str] = None
# Image-specific parameters
num_images_per_prompt: int = 1
## Image edit parameters
image: Optional[List[str]] = None
mask: Optional[str] = None
# Advanced parameters
guidance_rescale: float = 0.0
output_type: str = "pt"
# Wan-specific parameters
guidance_scale_2: Optional[float] = None
boundary_ratio: Optional[float] = None
last_image: Optional[str] = None
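
# Example parameter sets (values are illustrative, not tuned recommendations):
#
#   image_params = VisualGenParams(height=1024, width=1024, num_images_per_prompt=2, seed=0)
#   video_params = VisualGenParams(height=480, width=832, num_frames=121, frame_rate=24.0)
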
class VisualGen:
"""High-level API for visual generation."""
def __init__(
self,
model_path: Union[str, Path],
n_workers: int = 1,
diffusion_config: Optional[dict] = None,
):
self.model_path = str(model_path)
self.n_workers = n_workers
self.diffusion_config = diffusion_config
self.executor = DiffusionRemoteClient(
model_path=self.model_path,
n_workers=self.n_workers,
diffusion_config=self.diffusion_config,
)
self.req_counter = 0
def generate(
self,
inputs: VisualGenInputs,
params: VisualGenParams,
) -> MediaOutput:
"""Synchronous generation. Blocks until complete.
Args:
params: Generation parameters.
Returns:
MediaOutput: Generated media with model-specific fields populated:
- FLUX2: MediaOutput(image=torch.Tensor)
- WAN: MediaOutput(video=torch.Tensor)
- LTX2: MediaOutput(video=torch.Tensor, audio=torch.Tensor)
"""
future = self.generate_async(
inputs=inputs,
params=params,
)
# Use the sync wrapper to get result
response = self.executor.await_responses_sync(future.request_id, timeout=None)
if response.error_msg:
raise RuntimeError(f"Generation failed: {response.error_msg}")
return response.output
def generate_async(
self,
inputs: VisualGenInputs,
params: VisualGenParams,
) -> DiffusionGenerationResult:
"""Async generation. Returns immediately with future-like object.
Args:
params: Generation parameters.
Returns:
DiffusionGenerationResult: Call result() to get output dict.
"""
req_id = self.req_counter
self.req_counter += 1
if isinstance(inputs, dict):
prompt = inputs.get("prompt")
negative_prompt = inputs.get("negative_prompt", None)
elif isinstance(inputs, str):
prompt = inputs
negative_prompt = None
else:
# TODO: Support batch generation
raise ValueError(f"Invalid inputs type: {type(inputs)}")
request = DiffusionRequest(
request_id=req_id,
prompt=prompt,
negative_prompt=negative_prompt,
height=params.height,
width=params.width,
num_inference_steps=params.num_inference_steps,
guidance_scale=params.guidance_scale,
max_sequence_length=params.max_sequence_length,
seed=params.seed,
num_frames=params.num_frames,
frame_rate=params.frame_rate,
num_images_per_prompt=params.num_images_per_prompt,
guidance_rescale=params.guidance_rescale,
output_type=params.output_type,
image=params.input_reference,
guidance_scale_2=params.guidance_scale_2,
boundary_ratio=params.boundary_ratio,
last_image=params.last_image,
)
self.executor.enqueue_requests([request])
return DiffusionGenerationResult(req_id, self.executor)
def shutdown(self):
"""Shutdown executor and cleanup."""
logger.info("VisualGen: Shutting down")
self.executor.shutdown()
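
# End-to-end usage sketch (model path and prompt are placeholders):
#
#   gen = VisualGen(model_path="/models/FLUX.2-dev", n_workers=1)
#   try:
#       output = gen.generate(
#           inputs={"prompt": "a watercolor city skyline at dusk"},
#           params=VisualGenParams(height=1024, width=1024),
#       )
#       image = output.image  # torch.Tensor when output_type="pt"
#   finally:
#       gen.shutdown()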