TensorRT-LLMs/tensorrt_llm/serve/visual_gen_utils.py
Chang Liu 26901e4aa0
[TRTLLM-10612][feat] Initial support of AIGV models in TRTLLM (#11462)
Signed-off-by: Chang Liu (Enterprise Products) <liuc@nvidia.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
Co-authored-by: Freddy Qi <junq@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zhenhua Wang <zhenhuaw@nvidia.com>
2026-02-14 06:11:11 +08:00

113 lines
4.3 KiB
Python

import asyncio
import base64
import os
import shutil
from typing import Any, Dict, List, Optional
from tensorrt_llm.llmapi.visual_gen import VisualGenParams
from tensorrt_llm.serve.openai_protocol import (
ImageEditRequest,
ImageGenerationRequest,
VideoGenerationRequest,
)
def parse_visual_gen_params(
    request: ImageGenerationRequest | VideoGenerationRequest | ImageEditRequest,
    id: str,
    media_storage_path: Optional[str] = None,
) -> VisualGenParams:
    """Translate an OpenAI-style image/video request into ``VisualGenParams``.

    Args:
        request: The incoming image-generation, image-edit or video-generation
            request. Only fields that are set (not ``None``) override the
            ``VisualGenParams`` defaults.
        id: Request id. It is used only to name the saved video
            ``input_reference`` file. The name shadows the builtin but is kept
            for caller compatibility.
        media_storage_path: Directory where a video ``input_reference`` is
            written. Required only when the request carries one.

    Returns:
        A populated ``VisualGenParams`` instance.

    Raises:
        ValueError: if ``size`` is not ``"<width>x<height>"``, or if a video
            ``input_reference`` is given without ``media_storage_path``.
        binascii.Error: (a ``ValueError`` subclass) on malformed base64 input.
    """
    params = VisualGenParams()
    _apply_common_params(params, request)
    if isinstance(request, (ImageGenerationRequest, ImageEditRequest)):
        _apply_image_params(params, request)
    elif isinstance(request, VideoGenerationRequest):
        _apply_video_params(params, request, id, media_storage_path)
    # Seed is applied last, matching the original assignment order.
    if request.seed is not None:
        params.seed = int(request.seed)
    return params


def _apply_common_params(
    params: VisualGenParams,
    request: ImageGenerationRequest | VideoGenerationRequest | ImageEditRequest,
) -> None:
    """Copy the fields shared by all request kinds onto ``params``."""
    params.prompt = request.prompt
    if request.negative_prompt is not None:
        params.negative_prompt = request.negative_prompt
    if request.size is not None and request.size != "auto":
        # ``size`` is "<width>x<height>"; "auto" keeps the pipeline defaults.
        try:
            params.width, params.height = map(int, request.size.split("x"))
        except ValueError as err:
            raise ValueError(
                f"Invalid size {request.size!r}; expected '<width>x<height>'"
            ) from err
    if request.guidance_scale is not None:
        params.guidance_scale = request.guidance_scale
    if request.guidance_rescale is not None:
        params.guidance_rescale = request.guidance_rescale


def _apply_image_params(
    params: VisualGenParams,
    request: ImageGenerationRequest | ImageEditRequest,
) -> None:
    """Copy image-generation/edit specific fields onto ``params``."""
    if request.num_inference_steps is not None:
        params.num_inference_steps = request.num_inference_steps
    elif isinstance(request, ImageGenerationRequest) and request.quality == "hd":
        # "hd" quality without an explicit step count maps to 30 inference steps.
        params.num_inference_steps = 30
    if request.n is not None:
        params.num_images_per_prompt = request.n
    if isinstance(request, ImageEditRequest):
        if request.image is not None:
            # Normalize a single base64 image to a one-element list.
            images = request.image if isinstance(request.image, list) else [request.image]
            params.image = [base64.b64decode(image) for image in images]
        if request.mask is not None:
            if isinstance(request.mask, list):
                params.mask = [base64.b64decode(mask) for mask in request.mask]
            else:
                # NOTE(review): unlike ``image``, a single mask is stored as raw
                # bytes rather than a one-element list — confirm this is what
                # the pipeline expects.
                params.mask = base64.b64decode(request.mask)


def _apply_video_params(
    params: VisualGenParams,
    request: VideoGenerationRequest,
    request_id: str,
    media_storage_path: Optional[str],
) -> None:
    """Copy video-generation specific fields onto ``params``."""
    if request.num_inference_steps is not None:
        params.num_inference_steps = request.num_inference_steps
    if request.input_reference is not None:
        params.input_reference = _save_input_reference(
            request.input_reference, request_id, media_storage_path
        )
    params.frame_rate = request.fps
    params.num_frames = int(request.seconds * request.fps)


def _save_input_reference(
    reference: Any,
    request_id: str,
    media_storage_path: Optional[str],
) -> str:
    """Write a video reference image to ``<media_storage_path>/<request_id>_reference.png``.

    ``reference`` is either a base64 string or an object exposing a readable
    binary ``.file`` attribute (presumably an uploaded file — verify against
    the request model). Returns the path written.
    """
    if media_storage_path is None:
        raise ValueError("media_storage_path is required when input_reference is provided")
    path = os.path.join(media_storage_path, f"{request_id}_reference.png")
    if isinstance(reference, str):
        # Decode before opening so malformed base64 raises without leaving an
        # empty file behind on disk.
        data = base64.b64decode(reference)
        with open(path, "wb") as f:
            f.write(data)
    else:
        with open(path, "wb") as f:
            shutil.copyfileobj(reference.file, f)
    return path
class AsyncDictStore:
    """In-memory ``str`` -> ``dict`` store with every operation guarded by an asyncio lock.

    Packages the common "module-level dict plus ``asyncio.Lock``" pattern behind
    a tiny CRUD interface. FastAPI request handlers and background tasks running
    on the same event loop can share it. asyncio locks are not thread-safe, so
    the store is not meant to be shared across threads.

    Note: ``get``, ``update_fields`` and ``list_values`` return the stored dict
    objects themselves, not copies. Mutating them afterwards happens outside
    the lock.
    """

    def __init__(self) -> None:
        # Backing mapping; every method below touches it only while holding ``_lock``.
        self._items: Dict[str, Dict[str, Any]] = {}
        self._lock = asyncio.Lock()

    async def upsert(self, key: str, value: Dict[str, Any]) -> None:
        """Insert ``value`` under ``key``, replacing any previous entry."""
        async with self._lock:
            self._items[key] = value

    async def update_fields(self, key: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Merge ``updates`` into the existing entry for ``key`` in place.

        Returns the (mutated) entry, or ``None`` when ``key`` is unknown. A
        missing key is not created.
        """
        async with self._lock:
            if (entry := self._items.get(key)) is not None:
                entry.update(updates)
            return entry

    async def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the entry for ``key``, or ``None`` if it is absent."""
        async with self._lock:
            found = self._items.get(key)
        return found

    async def pop(self, key: str) -> Optional[Dict[str, Any]]:
        """Remove and return the entry for ``key``; ``None`` if it was absent."""
        async with self._lock:
            removed = self._items.pop(key, None)
        return removed

    async def list_values(self) -> List[Dict[str, Any]]:
        """Return a new list holding every stored entry, in insertion order."""
        async with self._lock:
            snapshot = [*self._items.values()]
        return snapshot
# Global stores shared by the OpenAI-compatible entrypoints.
# VIDEO_STORE maps request_id -> dict (presumably per-video-request state;
# the dict's fields are defined by the callers, not in this module).
VIDEO_STORE: AsyncDictStore = AsyncDictStore()