TensorRT-LLMs/tensorrt_llm/serve/chat_utils.py
Yechan Kim 5460d18b10
feat: trtllm-serve multimodal support (#3590)
* feat: trtllm-serve multimodal support

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>

* remove disable argument

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>

* remove disable

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>

* add and separate tests and move the doc

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>

* remove block_reuse arg from serve.py

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>

---------

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Co-authored-by: Haohang Huang <31998628+symphonylyh@users.noreply.github.com>
2025-04-19 05:01:28 +08:00

248 lines
8.0 KiB
Python

from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Optional,
Tuple, TypeAlias, TypedDict, Union, cast)
from openai.types.chat import ChatCompletionContentPartImageParam
from openai.types.chat import \
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import (ChatCompletionContentPartTextParam,
ChatCompletionMessageParam)
from transformers import AutoConfig, ProcessorMixin
from typing_extensions import Required
from tensorrt_llm.inputs import load_image, load_video
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from tensorrt_llm.logger import logger
class VideoURL(TypedDict):
    """URL container for a video content part (mirrors OpenAI's ImageURL shape)."""
    # Location of the video; may be a remote URL or local path — the loader
    # (load_video) decides how to fetch it.
    url: Required[str]
class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Video content part — an extension of the OpenAI content-part schema,
    which has no native video type."""
    # The video payload.
    video_url: Required[VideoURL]
    # Discriminator; always "video_url" for this part type.
    type: Required[Literal["video_url"]]
class ConversationMessage(TypedDict):
    """A flattened chat message: role plus fully-textualized content
    (multimodal parts are represented by placeholder tokens in `content`)."""
    role: str
    content: str
# Narrowing casts for structured content-part dicts. `cast` is a runtime
# no-op; these only inform the type checker.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)

# A content part is either OpenAI's standard part, this module's video
# extension, or a plain string.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
    str]

# Maps a part's "type" field to a parser that extracts its payload (the text,
# or the media URL). Each parser returns None when the expected key is absent.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], Union[str, dict[str, str]]],
] = {
    "text":
    lambda part: _TextParser(part).get("text", None),
    "image_url":
    lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
    "video_url":
    lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}

# Part types whose parsed content must be non-empty to be usable.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ["text", "image_url", "video_url"]
def retrieve_multimodal_placeholder(
    modality: str,
    model_config: AutoConfig,
    current_count: int,
) -> Optional[str]:
    """Return the model-specific prompt placeholder for one media item.

    Args:
        modality: Either "image" or "video".
        model_config: HF model config; only ``model_type`` is consulted.
        current_count: 1-based index of this media item (currently unused by
            the supported model families).

    Returns:
        The placeholder token string for the given modality and model.

    Raises:
        TypeError: If the modality or the model type is not supported.
    """
    image_placeholders = {
        "mllama": "<|image|>",
        "llama4": "<|image|>",
        "qwen2_vl": "<|vision_start|><|image_pad|><|vision_end|>",
        "qwen2_5_vl": "<|vision_start|><|image_pad|><|vision_end|>",
    }
    video_placeholders = {
        "qwen2_vl": "<|vision_start|><|video_pad|><|vision_end|>",
        "qwen2_5_vl": "<|vision_start|><|video_pad|><|vision_end|>",
    }

    model_type = model_config.model_type
    if modality == "image":
        table = image_placeholders
    elif modality == "video":
        table = video_placeholders
    else:
        raise TypeError(f"Unknown modality: {modality}")

    if model_type in table:
        return table[model_type]
    raise TypeError(f"Unknown {modality} model type: {model_type}")
def add_multimodal_placeholders(
    text_prompt: str,
    mm_content_dict: dict[str, list[Any]],
    model_config: AutoConfig,
) -> str:
    """Prepend model-specific media placeholders to a text prompt.

    Args:
        text_prompt: Joined text content of the message.
        mm_content_dict: Mapping of media type ("image"/"video") to the list
            of loaded media items of that type.
        model_config: HF model config; only ``model_type`` is consulted.

    Returns:
        The prompt with one placeholder line per media *item*, followed by
        the original text, joined by newlines.
    """
    placeholders: list[str] = []
    counts: dict[str, int] = {}
    for media_type, media_items in mm_content_dict.items():
        # Bug fix: emit one placeholder per media item. The previous loop
        # added a single placeholder per media *type*, so a message carrying
        # e.g. two images produced only one image placeholder.
        for _ in media_items:
            counts[media_type] = counts.get(media_type, 0) + 1
            placeholder = retrieve_multimodal_placeholder(
                media_type, model_config, counts[media_type])
            if placeholder is not None:
                placeholders.append(placeholder)
    return "\n".join(placeholders + [text_prompt])
def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam
) -> tuple[str, Union[str, dict[str, str]]]:
    """Extract ``(part_type, content)`` from a structured content-part dict.

    Returns:
        The part's "type" field and the payload extracted by the matching
        parser from MM_PARSER_MAP, or the literal string
        "unknown part_type content" for unrecognized (but string-typed)
        part types.

    Raises:
        ValueError: If the "type" field is missing or not a string.
    """
    # Plain-string parts are handled by the caller; a dict is required here
    # (also keeps mypy happy about .get on a possibly-str value).
    assert isinstance(part, dict)
    part_type = part.get("type", None)
    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    parser = MM_PARSER_MAP.get(part_type)
    if parser is not None:
        return part_type, parser(part)
    return part_type, "unknown part_type content"
def parse_chat_message_content_part(part: ChatCompletionMessageParam, ):
    """Parse one message content part into ``(text, mm_content)``.

    Args:
        part: A plain string or a structured content-part dict.

    Returns:
        A 2-tuple ``(text, mm_content)``: ``text`` is the textual content
        (None for pure media parts); ``mm_content`` is a single-entry dict
        mapping the modality ("image"/"video") to its loaded data (None for
        text parts). Empty/unparsable multimodal parts yield ``(None, None)``.

    Raises:
        NotImplementedError: For a structured part of unsupported type.
    """
    if isinstance(part, str):  # Handle plain text parts
        # Bug fix: always return a 2-tuple — the caller unpacks two values,
        # so returning the bare string crashed (or silently mis-unpacked a
        # 2-character string).
        return part, None

    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # if part_type is text/image_url/video_url but content is None, log a warning and skip
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') "
            "with empty / unparsable content.", part, part_type)
        # Bug fix: 2-tuple here too; a bare None cannot be unpacked.
        return None, None

    if part_type == "text":
        return cast(str, content), None

    # TODO: make them async on multimodal data as loading video/image is time consuming
    # Handle all non-text multimodal types
    str_content = cast(str, content)
    if part_type == "image_url":
        mm_content = {"image": load_image(str_content)}
    elif part_type == "video_url":
        mm_content = {"video": load_video(str_content, num_frames=8)}
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
    return None, mm_content
def parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionMessageParam],
    model_config: AutoConfig,
) -> Tuple[List[ConversationMessage], Dict[str, List[Any]]]:
    """Parse all content parts of one message into a flattened conversation
    entry plus the collected multimodal data.

    Args:
        role: The message's role ("user", "assistant", ...).
        parts: The message's content parts (strings or structured dicts).
        model_config: HF model config, used to pick placeholder tokens.

    Returns:
        A tuple of (single-element list with the flattened message, dict
        mapping media type to the loaded media items).

    NOTE(review): relies on parse_chat_message_content_part returning a
    2-tuple for every part — verify plain-string parts satisfy this.
    """
    content: List[str] = []
    mm_content_dict: Dict[str, List[Any]] = {}
    for part in parts:
        parse_res, mm_content = parse_chat_message_content_part(part, )
        if parse_res:
            content.append(parse_res)
        # Collect multimodal content
        if mm_content:
            for media_type, media_value in mm_content.items():
                if media_type not in mm_content_dict:
                    mm_content_dict[media_type] = []
                mm_content_dict[media_type].append(media_value)
    text_prompt = "\n".join(content)
    # Media present: prefix the text with the model's placeholder tokens.
    if mm_content_dict:
        text_prompt = add_multimodal_placeholders(text_prompt, mm_content_dict,
                                                  model_config)
    return [ConversationMessage(role=role,
                                content=text_prompt)], mm_content_dict
def parse_chat_message_content(
    message: ChatCompletionMessageParam,
    model_config: AutoConfig,
) -> Tuple[List[ConversationMessage], Optional[Dict[str, List[Any]]]]:
    """Normalize a chat message's content and parse it.

    A missing content becomes an empty part list; a bare string is wrapped
    as a single text part so that downstream parsing is uniform.

    Returns:
        The parsed conversation messages and the collected multimodal data.
    """
    role = message["role"]
    raw_content = message.get("content")
    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    else:
        parts = raw_content
    return parse_chat_message_content_parts(role, parts, model_config)
def resolve_hf_chat_template(
    tokenizer: TokenizerBase,
    processor: ProcessorMixin,
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
) -> Optional[str]:
    """Resolve which chat template to use, in priority order.

    Priority: an explicitly supplied template, then the processor's default
    (when no tools are involved), then the tokenizer's tool-aware template.

    Returns:
        The resolved template string, or None when none could be found.
    """
    # 1. If chat_template is not None, return it
    if chat_template is not None:
        return chat_template
    # 2. If tool is not provided, use the processor's default chat template
    if not tools and processor and hasattr(processor, 'chat_template'):
        return processor.chat_template
    # 3. If tool is provided, use the tool
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        # Best-effort: fall through to None so the caller can raise a
        # clearer error.
        logger.debug("Failed to load AutoTokenizer chat template for %s",
                     tokenizer.name_or_path)
    return None
def apply_chat_template(
    *,
    tokenizer: TokenizerBase,
    processor: ProcessorMixin,
    conversation: list[ConversationMessage],
    add_generation_prompt: bool,
    tools: Optional[list[dict[str, Any]]] = None,
    documents: Optional[list[dict[str, str]]] = None,
    chat_template: Optional[str] = None,
    chat_template_kwargs: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> str:
    """Render a conversation into a prompt string via a HF chat template.

    The template is resolved with resolve_hf_chat_template and then applied
    by the tokenizer without tokenization.

    Raises:
        ValueError: If no chat template could be resolved.
    """
    resolved_template = resolve_hf_chat_template(tokenizer, processor,
                                                 chat_template, tools)
    if resolved_template is None:
        raise ValueError(
            "No chat template found for the given tokenizer and tools.")
    extra_kwargs = chat_template_kwargs or {}
    return tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
        tools=tools,
        documents=documents,
        chat_template=resolved_template,
        **extra_kwargs,
    )