import asyncio
import queue
import socket
import threading
import time
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch.multiprocessing as mp
import zmq

from tensorrt_llm._torch.visual_gen import DiffusionRequest, DiffusionResponse
from tensorrt_llm._torch.visual_gen.executor import run_diffusion_worker
from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm.executor.ipc import ZeroMqQueue
from tensorrt_llm.inputs.data import VisualGenInputs
from tensorrt_llm.logger import logger

__all__ = ["VisualGen", "VisualGenParams", "MediaOutput"]

# Timeouts (seconds)
POLL_TIMEOUT = 0.01
AWAIT_TIMEOUT = 0.05
THREAD_TIMEOUT = 5.0
WORKER_TIMEOUT = 2.0
READY_TIMEOUT = 1200  # 20 minutes for large models (Wan 2.2 with transformer_2)


def find_free_port() -> int:
    """Return a free TCP port picked by the OS."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def get_ip_address() -> str:
    """Get the local IP address (falls back to loopback if offline)."""
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # Connecting a UDP socket sends no packets; it only selects the
        # interface that would route to this address.
        s.connect(("10.255.255.255", 1))
        return s.getsockname()[0]
    except Exception:
        return "127.0.0.1"
    finally:
        s.close()
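

# A quick illustration of how the helpers above compose into the endpoints used
# below (values are examples only):
#
#   port = find_free_port()      # e.g. 51234, an OS-assigned free TCP port
#   ip = get_ip_address()        # e.g. "10.0.0.5", or "127.0.0.1" when offline
#   addr = f"tcp://{ip}:{port}"  # the form the ZeroMQ queues below expect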


class DiffusionRemoteClient:
    """Client proxy for a remote DiffusionExecutor running in worker processes."""

    def __init__(
        self,
        model_path: Union[str, Path],
        n_workers: int = 1,
        diffusion_config: Optional[dict] = None,
    ):
        self.model_path = str(model_path)
        self.n_workers = n_workers
        self.diffusion_config = diffusion_config

        # Set up the distributed env
        self.master_addr = "127.0.0.1"
        self.master_port = find_free_port()

        # Set up IPC addresses
        self.host_ip = get_ip_address()
        req_port, resp_port = find_free_port(), find_free_port()

        self.request_queue_addr = f"tcp://0.0.0.0:{req_port}"
        self.response_queue_addr = f"tcp://0.0.0.0:{resp_port}"
        self.req_addr_connect = f"tcp://{self.host_ip}:{req_port}"
        self.resp_addr_connect = f"tcp://{self.host_ip}:{resp_port}"

        # IPC setup
        self.requests_ipc = None
        self.responses_ipc = None
        self.pending_requests = queue.Queue()
        self.completed_responses: Dict[int, DiffusionResponse] = {}

        # Asyncio primitives are created in the background thread's event loop
        self._event_loop = None
        self.response_event = None
        self.lock = None
        self.shutdown_event = threading.Event()
        self.event_loop_ready = threading.Event()

        # Start the background thread (it creates its own event loop)
        self.background_thread = threading.Thread(target=self._serve_forever_thread, daemon=True)
        self.background_thread.start()

        # Wait for the background thread to initialize the event loop
        self.event_loop_ready.wait()

        # Launch workers
        logger.info(f"DiffusionClient: Launching {n_workers} workers")
        ctx = mp.get_context("spawn")
        self.worker_processes = []
        for rank in range(n_workers):
            p = ctx.Process(
                target=run_diffusion_worker,
                kwargs={
                    "rank": rank,
                    "world_size": n_workers,
                    "master_addr": self.master_addr,
                    "master_port": self.master_port,
                    "model_path": self.model_path,
                    "request_queue_addr": self.req_addr_connect,
                    "response_queue_addr": self.resp_addr_connect,
                    "diffusion_config": self.diffusion_config,
                },
            )
            p.start()
            self.worker_processes.append(p)

        self._wait_ready()

    @staticmethod
    def _close_socket(ipc_queue):
        """Close an IPC queue's socket immediately, discarding unsent messages."""
        if ipc_queue and ipc_queue.socket:
            ipc_queue.socket.setsockopt(zmq.LINGER, 0)
            ipc_queue.close()

    def enqueue_requests(self, requests: List[DiffusionRequest]) -> List[int]:
        """Enqueue requests and return their IDs."""
        req_ids = []
        for req in requests:
            self.pending_requests.put(req)
            req_ids.append(req.request_id)
        return req_ids

    async def await_responses(
        self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
    ) -> Union[DiffusionResponse, List[DiffusionResponse]]:
        """Wait for responses by request IDs.

        Args:
            request_ids: Single request ID or list of request IDs to wait for
            timeout: Maximum total wait time in seconds (None = wait indefinitely)

        Returns:
            Single response or list of responses (None if request timed out)
        """
        is_single = isinstance(request_ids, int)
        ids = [request_ids] if is_single else request_ids

        start_time = time.time()
        results = {}

        while len(results) < len(ids):
            async with self.lock:
                for req_id in ids:
                    if req_id in self.completed_responses:
                        results[req_id] = self.completed_responses.pop(req_id)

            # All responses collected
            if len(results) == len(ids):
                break

            # Check if overall timeout exceeded
            if timeout is not None:
                elapsed = time.time() - start_time
                if elapsed >= timeout:
                    break
                # Wait for remaining time or AWAIT_TIMEOUT, whichever is shorter
                wait_time = min(timeout - elapsed, AWAIT_TIMEOUT)
            else:
                wait_time = AWAIT_TIMEOUT

            try:
                await asyncio.wait_for(self.response_event.wait(), timeout=wait_time)
            except asyncio.TimeoutError:
                pass
            self.response_event.clear()

        out = [results.get(rid) for rid in ids]
        return out[0] if is_single else out

    def await_responses_sync(
        self, request_ids: Union[int, List[int]], timeout: Optional[float] = None
    ) -> Union[DiffusionResponse, List[DiffusionResponse]]:
        """Sync wrapper to await responses from the main thread."""
        future = asyncio.run_coroutine_threadsafe(
            self.await_responses(request_ids, timeout), self._event_loop
        )
        # Give the coroutine slack past its own timeout so it can return
        # partial results instead of this call raising TimeoutError first.
        return future.result(timeout=timeout + THREAD_TIMEOUT if timeout else None)

    def _init_ipc(self) -> bool:
        """Initialize IPC queues."""
        try:
            logger.info("DiffusionClient: Initializing IPC")
            self.requests_ipc = ZeroMqQueue(
                (self.request_queue_addr, None),
                is_server=True,
                socket_type=zmq.PUSH,
                use_hmac_encryption=False,
            )
            self.responses_ipc = ZeroMqQueue(
                (self.response_queue_addr, None),
                is_server=True,
                socket_type=zmq.PULL,
                use_hmac_encryption=False,
            )
            logger.info("DiffusionClient: IPC ready")
            return True
        except Exception as e:
            logger.error(f"DiffusionClient: IPC init failed: {e}")
            return False
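
    # IPC topology as wired above. The worker side lives in run_diffusion_worker;
    # the complementary connect-side socket types are an inference from the
    # PUSH/PULL server sockets created here, not a guarantee:
    #
    #   client (this process)               worker processes
    #   requests_ipc   PUSH bind   ------>  PULL connect   (DiffusionRequest)
    #   responses_ipc  PULL bind   <------  PUSH connect   (DiffusionResponse)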

    def _send_shutdown(self):
        """Send shutdown signal."""
        logger.info("DiffusionClient: Sending shutdown signal")
        if self.requests_ipc:
            self.requests_ipc.put(None)
            self._close_socket(self.requests_ipc)

    def _process_requests(self):
        """Process pending requests."""
        try:
            req = self.pending_requests.get(timeout=POLL_TIMEOUT)
            if req is None:
                self._send_shutdown()
                self.shutdown_event.set()
                return

            logger.info(f"DiffusionClient: Sending request {req.request_id}")
            self.requests_ipc.put(req)
        except queue.Empty:
            pass
        except Exception as e:
            logger.error(f"DiffusionClient: Error sending request: {e}")
            logger.error(traceback.format_exc())

    def _process_responses(self):
        """Poll and process responses."""
        try:
            if self.responses_ipc.poll(timeout=POLL_TIMEOUT):
                response = self.responses_ipc.get()
                if isinstance(response, DiffusionResponse):
                    if response.request_id == -1:
                        logger.info("DiffusionClient: Received READY signal")

                    # Schedule the lock acquisition and event setting in the event loop
                    asyncio.run_coroutine_threadsafe(
                        self._store_response(response), self._event_loop
                    )
        except Exception as e:
            logger.error(f"DiffusionClient: Error processing response: {e}")

    async def _store_response(self, response: DiffusionResponse):
        """Store response in the completed_responses dict (async helper)."""
        async with self.lock:
            self.completed_responses[response.request_id] = response
            self.response_event.set()

    def _cleanup_ipc(self):
        """Cleanup IPC."""
        logger.info("DiffusionClient: Cleaning up IPC")
        self._close_socket(self.requests_ipc)
        self._close_socket(self.responses_ipc)

    def _serve_forever_thread(self):
        """Background thread wrapper that creates and runs an event loop."""
        logger.info("DiffusionClient: Background thread started")

        # Create a new event loop for this thread
        self._event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._event_loop)

        # Create async primitives in this thread's event loop
        self.response_event = asyncio.Event()
        self.lock = asyncio.Lock()

        # Signal that the event loop is ready
        self.event_loop_ready.set()

        # Run the async serve_forever
        try:
            self._event_loop.run_until_complete(self._serve_forever())
        finally:
            self._event_loop.close()
            logger.info("DiffusionClient: Background thread stopped")

    async def _serve_forever(self):
        """Background thread main loop (async version)."""
        if not self._init_ipc():
            return

        while not self.shutdown_event.is_set():
            self._process_requests()
            self._process_responses()
            await asyncio.sleep(0.001)  # Yield control to allow other coroutines to run

        self._cleanup_ipc()
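
    # The threading/asyncio pattern used above, in isolation (a minimal,
    # self-contained sketch; the names here are illustrative only):
    #
    #   import asyncio, threading
    #
    #   loop = asyncio.new_event_loop()
    #   ready = threading.Event()
    #
    #   def _run():
    #       asyncio.set_event_loop(loop)
    #       ready.set()
    #       loop.run_forever()
    #
    #   threading.Thread(target=_run, daemon=True).start()
    #   ready.wait()
    #   # From any other thread: schedule a coroutine and block on its result.
    #   fut = asyncio.run_coroutine_threadsafe(asyncio.sleep(0.1, result=42), loop)
    #   assert fut.result() == 42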

    def shutdown(self):
        """Shutdown client and workers."""
        logger.info("DiffusionClient: Shutting down")
        self.pending_requests.put(None)

        self.background_thread.join(timeout=THREAD_TIMEOUT)
        if self.background_thread.is_alive():
            logger.warning("DiffusionClient: Force stopping background thread")
            self.shutdown_event.set()
            self.background_thread.join(timeout=1.0)

        # Shutdown workers
        logger.info("DiffusionClient: Stopping workers")
        for p in self.worker_processes:
            p.join(timeout=WORKER_TIMEOUT)
            if p.is_alive():
                logger.warning(f"DiffusionClient: Terminating worker {p.pid} with SIGTERM")
                p.terminate()
                p.join(timeout=WORKER_TIMEOUT)
                if p.is_alive():
                    logger.warning(f"DiffusionClient: Force killing worker {p.pid} with SIGKILL")
                    p.kill()
                    p.join(timeout=WORKER_TIMEOUT)

    def _wait_ready(self, timeout: float = READY_TIMEOUT):
        """Wait for workers to be ready (sync wrapper for async operation)."""
        logger.info("DiffusionClient: Waiting for workers")

        # Run the async wait in the background thread's event loop
        future = asyncio.run_coroutine_threadsafe(self._wait_ready_async(timeout), self._event_loop)
        return future.result(timeout=timeout)

    async def _wait_ready_async(self, timeout: float = READY_TIMEOUT):
        """Wait for workers to be ready (async version)."""
        start_time = time.time()

        while True:
            async with self.lock:
                if -1 in self.completed_responses:
                    self.completed_responses.pop(-1)
                    logger.info("DiffusionClient: Workers ready")
                    return

            if time.time() - start_time > timeout:
                raise RuntimeError("DiffusionClient: Timeout waiting for workers")

            try:
                await asyncio.wait_for(self.response_event.wait(), timeout=AWAIT_TIMEOUT)
            except asyncio.TimeoutError:
                pass
            self.response_event.clear()
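

# Driving DiffusionRemoteClient directly, without the VisualGen wrapper (a
# sketch; assumes `req` is a fully populated DiffusionRequest and the ready
# handshake in __init__ has completed):
#
#   client = DiffusionRemoteClient("/path/to/model", n_workers=1)
#   req_ids = client.enqueue_requests([req])
#   resp = client.await_responses_sync(req_ids[0], timeout=600.0)
#   if resp is None:
#       ...  # timed out; the worker may still be generating
#   elif resp.error_msg:
#       ...  # worker-side failure
#   else:
#       media = resp.output  # MediaOutput
#   client.shutdown()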


class DiffusionGenerationResult:
    """Future-like object for async generation."""

    def __init__(self, request_id: int, executor: DiffusionRemoteClient):
        self.request_id = request_id
        self.executor = executor
        self._result = None
        self._finished = False
        self._error = None

    async def result(self, timeout: Optional[float] = None) -> Any:
        """Wait for and return the result (async version).

        Can be awaited from any async context (e.g., FastAPI background tasks).
        """
        if self._finished:
            if self._error:
                raise RuntimeError(self._error)
            return self._result

        # Use run_coroutine_threadsafe to execute in the background thread's event loop
        future = asyncio.run_coroutine_threadsafe(
            self.executor.await_responses(self.request_id, timeout=timeout),
            self.executor._event_loop,
        )

        # Await the future in the current event loop
        response = await asyncio.wrap_future(future)

        # await_responses returns None for requests that timed out
        if response is None:
            raise TimeoutError(f"Request {self.request_id} timed out after {timeout}s")

        if response.error_msg:
            self._error = response.error_msg
            self._finished = True
            raise RuntimeError(f"Generation failed: {response.error_msg}")

        self._result = response.output
        self._finished = True
        return self._result

    def cancel(self):
        raise NotImplementedError("Cancel request (not yet implemented).")
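

# Awaiting a DiffusionGenerationResult from an async context such as a web
# handler (a sketch; `gen` is a hypothetical VisualGen created at startup):
#
#   async def handle_request(prompt: str):
#       fut = gen.generate_async(inputs=prompt, params=VisualGenParams())
#       return await fut.result(timeout=600.0)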


@dataclass
class VisualGenParams:
    """Parameters for visual generation.

    Attributes:
        height: Output height in pixels
        width: Output width in pixels
        num_inference_steps: Number of denoising steps
        guidance_scale: Classifier-free guidance scale
        max_sequence_length: Maximum sequence length for text encoding
        seed: Random seed for reproducibility

        # Video-specific parameters
        num_frames: Number of video frames to generate
        frame_rate: Frame rate for video output in fps
        input_reference: Path to a reference image for image-to-video generation

        # Image-specific parameters
        num_images_per_prompt: Number of images to generate per prompt (for image models)

        # Image edit parameters
        image: Paths to input images to edit (for image-editing models)
        mask: Path to an inpainting mask

        # Advanced parameters
        guidance_rescale: Guidance rescale factor (for some models)
        output_type: Output type ("pt" for PyTorch tensors, "pil" for PIL images)

        # Wan-specific parameters
        guidance_scale_2: Guidance scale for the second transformer (Wan 2.2)
        boundary_ratio: Boundary ratio between the two transformers (Wan 2.2)
        last_image: Path to the last frame for first/last-frame conditioning
    """

    height: int = 720
    width: int = 1280
    num_inference_steps: int = 50
    guidance_scale: float = 5.0
    max_sequence_length: int = 512
    seed: int = 42

    # Video-specific parameters
    num_frames: int = 81
    frame_rate: float = 24.0
    input_reference: Optional[str] = None

    # Image-specific parameters
    num_images_per_prompt: int = 1

    # Image edit parameters
    image: Optional[List[str]] = None
    mask: Optional[str] = None

    # Advanced parameters
    guidance_rescale: float = 0.0
    output_type: str = "pt"

    # Wan-specific parameters
    guidance_scale_2: Optional[float] = None
    boundary_ratio: Optional[float] = None
    last_image: Optional[str] = None
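

# Example parameter sets (values are illustrative; supported ranges and which
# fields are honored depend on the underlying model):
#
#   # Short 480p video clip
#   video_params = VisualGenParams(
#       height=480, width=832, num_inference_steps=30,
#       num_frames=33, frame_rate=16.0, seed=0,
#   )
#   # Single 1024x1024 image
#   image_params = VisualGenParams(
#       height=1024, width=1024, num_images_per_prompt=1,
#   )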


class VisualGen:
    """High-level API for visual generation."""

    def __init__(
        self,
        model_path: Union[str, Path],
        n_workers: int = 1,
        diffusion_config: Optional[dict] = None,
    ):
        self.model_path = str(model_path)
        self.n_workers = n_workers
        self.diffusion_config = diffusion_config

        self.executor = DiffusionRemoteClient(
            model_path=self.model_path,
            n_workers=self.n_workers,
            diffusion_config=self.diffusion_config,
        )
        self.req_counter = 0

    def generate(
        self,
        inputs: VisualGenInputs,
        params: VisualGenParams,
    ) -> MediaOutput:
        """Synchronous generation. Blocks until complete.

        Args:
            inputs: Prompt inputs (a prompt string, or a dict with "prompt" and
                an optional "negative_prompt").
            params: Generation parameters.

        Returns:
            MediaOutput: Generated media with model-specific fields populated:
                - FLUX2: MediaOutput(image=torch.Tensor)
                - WAN: MediaOutput(video=torch.Tensor)
                - LTX2: MediaOutput(video=torch.Tensor, audio=torch.Tensor)
        """
        future = self.generate_async(
            inputs=inputs,
            params=params,
        )

        # Use the sync wrapper to get the result
        response = self.executor.await_responses_sync(future.request_id, timeout=None)
        if response.error_msg:
            raise RuntimeError(f"Generation failed: {response.error_msg}")
        return response.output

    def generate_async(
        self,
        inputs: VisualGenInputs,
        params: VisualGenParams,
    ) -> DiffusionGenerationResult:
        """Async generation. Returns immediately with a future-like object.

        Args:
            inputs: Prompt inputs (a prompt string, or a dict with "prompt" and
                an optional "negative_prompt").
            params: Generation parameters.

        Returns:
            DiffusionGenerationResult: Await result() to get the MediaOutput.
        """
        req_id = self.req_counter
        self.req_counter += 1

        if isinstance(inputs, dict):
            prompt = inputs.get("prompt")
            negative_prompt = inputs.get("negative_prompt", None)
        elif isinstance(inputs, str):
            prompt = inputs
            negative_prompt = None
        else:
            # TODO: Support batch generation
            raise ValueError(f"Invalid inputs type: {type(inputs)}")

        request = DiffusionRequest(
            request_id=req_id,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=params.height,
            width=params.width,
            num_inference_steps=params.num_inference_steps,
            guidance_scale=params.guidance_scale,
            max_sequence_length=params.max_sequence_length,
            seed=params.seed,
            num_frames=params.num_frames,
            frame_rate=params.frame_rate,
            num_images_per_prompt=params.num_images_per_prompt,
            guidance_rescale=params.guidance_rescale,
            output_type=params.output_type,
            image=params.input_reference,
            guidance_scale_2=params.guidance_scale_2,
            boundary_ratio=params.boundary_ratio,
            last_image=params.last_image,
        )

        self.executor.enqueue_requests([request])
        return DiffusionGenerationResult(req_id, self.executor)

    def shutdown(self):
        """Shutdown the executor and clean up."""
        logger.info("VisualGen: Shutting down")
        self.executor.shutdown()
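

# Minimal end-to-end usage (a sketch: the model path is a placeholder, and which
# MediaOutput fields are populated depends on the model; see generate()):
if __name__ == "__main__":
    gen = VisualGen(model_path="/path/to/model", n_workers=1)
    try:
        output = gen.generate(
            inputs="a red panda surfing at sunset",
            params=VisualGenParams(height=480, width=832, num_frames=33),
        )
        print(output)  # MediaOutput with image/video/audio fields
    finally:
        gen.shutdown()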