TensorRT-LLMs/tensorrt_llm/serve/chat_utils.py
Pengyun Lin 971d16a2ee
[TRTLLM-1658][feat] Enable multiple response in trtllm-serve for TRT backend (#4623)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
2025-05-28 11:36:44 +08:00

301 lines
10 KiB
Python

import asyncio
from collections import defaultdict
from functools import partial
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
Optional, Tuple, TypeAlias, TypedDict, Union, cast)
from openai.types.chat import ChatCompletionContentPartImageParam
from openai.types.chat import \
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import (ChatCompletionContentPartTextParam,
ChatCompletionMessageParam)
from transformers import AutoConfig, ProcessorMixin
from typing_extensions import Required
from tensorrt_llm.inputs import async_load_image, async_load_video
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from tensorrt_llm.logger import logger
class VideoURL(TypedDict):
"""Type definition for video URL structure."""
url: Required[str]
class ChatCompletionContentPartVideoParam(TypedDict, total=False):
"""Type definition for video content part parameters."""
video_url: Required[VideoURL]
type: Required[Literal["video_url"]]
class ConversationMessage(TypedDict):
"""Type definition for conversation message structure."""
role: str
content: str
# Type Aliases and Constants
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
str]
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ["text", "image_url", "video_url"]
# Parser Functions
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
str, dict[str, str]]]] = {
"text":
lambda part: _TextParser(part).get("text", None),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"video_url":
lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}
class AsyncMultimodalDataTracker:
"""Tracks and manages multimodal data for async processing."""
def __init__(self, model_config: AutoConfig):
self.model_config = model_config
self.mm_data = defaultdict[str](list)
self.mm_placeholder_counts = defaultdict[str](int)
async def retrieve_all_mm_data(self) -> Optional[Dict[str, List[Any]]]:
"""Retrieve all collected multimodal data."""
if not self.mm_data:
return None
return {
modality: await asyncio.gather(*items)
for modality, items in self.mm_data.items()
}
def retrieve_multimodal_placeholder(self, modality: str,
current_count: int) -> Optional[str]:
"""Get the appropriate placeholder for a given modality and model type."""
model_type = self.model_config.model_type
if modality == "image":
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"
elif model_type in ("mllama", "llama4"):
return "<|image|>"
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "video":
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
raise TypeError(f"Unknown {modality} model type: {model_type}")
raise TypeError(f"Unknown modality: {modality}")
def add_mm_data(self, media_type: str, data: Coroutine):
current_count = len(self.mm_data[media_type]) + 1
placeholder = self.retrieve_multimodal_placeholder(
media_type, current_count)
self.mm_data[media_type].append(data)
if placeholder:
self.mm_placeholder_counts[placeholder] += 1
def mm_data_counts(self) -> Dict[str, int]:
"""Get the count of multimodal placeholders."""
return dict(self.mm_placeholder_counts)
def add_multimodal_placeholders(text_prompt: str,
mm_placeholder_counts: dict[str, int]) -> str:
"""Add multimodal placeholders to the text prompt."""
placeholders = []
for placeholder in mm_placeholder_counts:
placeholders.extend([placeholder] * mm_placeholder_counts[placeholder])
return "\n".join(placeholders + [text_prompt])
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam
) -> tuple[str, Union[str, dict[str, str]]]:
"""Parse a single multimodal part of a chat message."""
assert isinstance(part, dict)
part_type = part.get("type", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
return part_type, MM_PARSER_MAP[part_type](part)
if not isinstance(part_type, str):
raise ValueError("Invalid 'type' field in multimodal part.")
return part_type, "unknown part_type content"
def parse_chat_message_content_part(
part: ChatCompletionMessageParam,
mm_data_tracker: AsyncMultimodalDataTracker,
) -> Optional[str]:
"""Parse a single part of a chat message."""
if isinstance(part, str):
return part
part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/image_url/video_url but content is None, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
logger.warning(
"Skipping multimodal part '%s' (type: '%s') with empty / unparsable content.",
part, part_type)
return None
if part_type == "text":
return cast(str, content)
if part_type == "image_url":
str_content = cast(str, content)
async def load_image_async():
try:
return await async_load_image(str_content)
except Exception as e:
logger.error(f"Failed to load image: {str(e)}")
return None
mm_data_tracker.add_mm_data("image", load_image_async())
return None
if part_type == "video_url":
str_content = cast(str, content)
async def load_video_async():
try:
return await async_load_video(str_content, num_frames=8)
except Exception as e:
logger.error(f"Failed to load video: {str(e)}")
return None
mm_data_tracker.add_mm_data("video", load_video_async())
return None
raise NotImplementedError(f"Unknown part type: {part_type}")
def parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionMessageParam],
mm_data_tracker: AsyncMultimodalDataTracker,
) -> List[ConversationMessage]:
"""Parse multiple parts of a chat message."""
content_parts = []
for part in parts:
parse_res = parse_chat_message_content_part(part, mm_data_tracker)
if parse_res:
content_parts.append(parse_res)
text_prompt = "\n".join(content_parts)
mm_placeholder_counts = mm_data_tracker.mm_data_counts()
if mm_placeholder_counts:
text_prompt = add_multimodal_placeholders(text_prompt,
mm_placeholder_counts)
return [ConversationMessage(role=role, content=text_prompt)]
def parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_data_tracker: AsyncMultimodalDataTracker,
) -> List[ConversationMessage]:
"""Parse the content of a chat message."""
role = message["role"]
content = message.get("content")
if content is None:
content = []
elif isinstance(content, str):
content = [
ChatCompletionContentPartTextParam(type="text", text=content)
]
result = parse_chat_message_content_parts(
role,
content,
mm_data_tracker,
)
return result
def parse_chat_messages_coroutines(
messages: List[ChatCompletionMessageParam],
model_config: AutoConfig,
) -> Tuple[List[ConversationMessage], Optional[Coroutine[
Any, Any, Optional[Dict[str, List[Any]]]]]]:
"""Parse multiple chat messages and return conversation and coroutine."""
conversation = []
mm_data_tracker = AsyncMultimodalDataTracker(model_config)
for msg in messages:
sub_messages = parse_chat_message_content(msg, mm_data_tracker)
conversation.extend(sub_messages)
return conversation, mm_data_tracker.retrieve_all_mm_data()
def resolve_hf_chat_template(
tokenizer: TokenizerBase,
processor: ProcessorMixin,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
) -> Optional[str]:
"""Resolve the appropriate chat template to use."""
# 1. If chat_template is not None, return it
if chat_template is not None:
return chat_template
# 2. If tool is not provided, use the processor's default chat template
if not tools and processor and hasattr(processor, 'chat_template'):
return processor.chat_template
# 3. If tool is provided, use the tool
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path)
return None
def apply_chat_template(
*,
tokenizer: TokenizerBase,
processor: ProcessorMixin,
conversation: list[ConversationMessage],
add_generation_prompt: bool,
tools: Optional[list[dict[str, Any]]] = None,
documents: Optional[list[dict[str, str]]] = None,
chat_template: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None,
) -> str:
"""Apply chat template to the conversation."""
hf_chat_template = resolve_hf_chat_template(tokenizer, processor,
chat_template, tools)
if hf_chat_template is None:
raise ValueError(
"No chat template found for the given tokenizer and tools.")
return tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=add_generation_prompt,
tools=tools,
documents=documents,
chat_template=hf_chat_template,
**(chat_template_kwargs or {}),
)
def check_multiple_response(n: int, backend: Optional[str]):
if n > 1 and backend == "pytorch":
raise ValueError(
"Multiple response is not supported in PyTorch workflow")