TensorRT-LLMs/tensorrt_llm/serve/chat_utils.py
William Zhang 121140cfec
[None][fixes] Add tool call parsing fixes and Qwen3 coder parser (#8817)
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2025-11-13 04:34:38 -08:00

257 lines
9.2 KiB
Python

import json
import uuid
from functools import partial
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
Optional, Tuple, TypeAlias, TypedDict, Union, cast)
from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartInputAudioParam)
from openai.types.chat import \
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import (ChatCompletionContentPartTextParam,
ChatCompletionMessageParam)
from transformers import AutoConfig
from typing_extensions import Required
from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
MultimodalDataTracker,
add_multimodal_placeholders, async_load_audio,
async_load_image, async_load_video)
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.logger import logger
class VideoURL(TypedDict):
    """Type definition for video URL structure.

    The object nested under the ``"video_url"`` key of a video content part;
    its ``url`` is what ``MM_PARSER_MAP["video_url"]`` extracts.
    """
    # Location of the video, handed to the async video loader.
    url: Required[str]
class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Type definition for video content part parameters.

    Shape: ``{"type": "video_url", "video_url": {"url": ...}}``. It is declared
    locally and unioned with the OpenAI SDK content-part types in
    ``ChatCompletionContentPartParam``.
    """
    # `total=False` is overridden per field by `Required[...]`, so both keys
    # below are still mandatory for static type checkers.
    video_url: Required[VideoURL]
    # Discriminator used to dispatch the part to the video parser.
    type: Required[Literal["video_url"]]
# Type Aliases and Constants
# Any single content part accepted in a chat message: an OpenAI SDK content
# part, the locally declared video part, or a bare string (which
# parse_chat_message_content_part returns unchanged as text).
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
    str]
# Part types parsed via MM_PARSER_MAP. A part of one of these types whose
# parsed content is None is skipped with a warning instead of raising.
# TODO: Add "input_audio" to support byte_encoded audio input.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
    "text", "image_url", "video_url", "audio_url"
]
# Parser Functions
# `cast` is a runtime no-op: these partials only narrow the static type of a
# content part before its payload is read.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
# NOTE: the OpenAI SDK has no typed "audio_url" part. The input-audio param type
# is only a static-typing stand-in here; its own schema uses "input_audio".
_AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)


def _get_url_field(part: Any, key: str) -> Optional[str]:
    """Return ``part[key]["url"]``, or None if that path is absent or malformed.

    A missing key, an explicit ``null``, or a non-object value (e.g. a bare
    string) all yield None. The caller then skips the part with a warning
    instead of crashing with ``AttributeError`` on ``None.get`` / ``str.get``.
    """
    container = part.get(key)
    if not isinstance(container, dict):
        return None
    return container.get("url")


# Maps a content-part "type" to a function extracting that part's payload:
# the text for "text", the URL string for the media types. A None result
# means the part has no usable content.
MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam],
                                  Optional[str]]] = {
    "text":
    lambda part: _TextParser(part).get("text", None),
    "image_url":
    lambda part: _get_url_field(_ImageParser(part), "image_url"),
    "video_url":
    lambda part: _get_url_field(_VideoParser(part), "video_url"),
    "audio_url":
    lambda part: _get_url_field(_AudioParser(part), "audio_url"),
}
def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam
) -> tuple[str, Union[str, dict[str, str]]]:
    """Split one content part into its type tag and parsed payload.

    Returns ``(part_type, payload)``. For a type registered in
    ``MM_PARSER_MAP``, the payload is that parser's result, which may be None.
    For any other string type, it is the placeholder
    ``"unknown part_type content"``.

    Raises:
        ValueError: if the part's ``type`` field is missing or not a string.
    """
    assert isinstance(part, dict)
    part_type = part.get("type")
    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    parser = MM_PARSER_MAP.get(part_type)
    if parser is None:
        return part_type, "unknown part_type content"
    return part_type, parser(part)
def _get_media_io_kwargs(mm_data_tracker: MultimodalDataTracker,
                         modality: str) -> Dict[str, Any]:
    """Return the loader kwargs configured for ``modality``, or ``{}``.

    A missing server config, missing/None ``media_io_kwargs``, or a missing or
    None entry for ``modality`` all yield ``{}``, so the loader runs with its
    own defaults instead of failing.
    """
    server_config = mm_data_tracker._multimodal_server_config
    media_io_kwargs = getattr(server_config, "media_io_kwargs", None) or {}
    return media_io_kwargs.get(modality) or {}


async def _load_media_or_none(loader: Callable[..., Any], url: str,
                              modality: str,
                              mm_data_tracker: MultimodalDataTracker) -> Any:
    """Await ``loader(url, **kwargs)``; on any exception log it and return None."""
    try:
        media_kwargs = _get_media_io_kwargs(mm_data_tracker, modality)
        return await loader(url, **media_kwargs)
    except Exception as e:
        logger.error(f"Failed to load {modality}: {str(e)}")
        return None


def parse_chat_message_content_part(
    part: ChatCompletionMessageParam,
    mm_data_tracker: MultimodalDataTracker,
) -> Optional[Any]:
    """Parse a single part of a chat message.

    Args:
        part: A bare string, or a content-part dict carrying a ``type`` field.
        mm_data_tracker: Tracker whose server config supplies the
            per-modality ``media_io_kwargs`` passed to the media loaders.

    Returns:
        - The text itself for bare strings and ``"text"`` parts.
        - For ``image_url``/``video_url``/``audio_url`` parts, a
          ``MultimodalData`` whose ``data`` is a loader coroutine that has not
          been awaited yet.
        - None when a known part type has no usable content.

    Raises:
        ValueError: if the part's ``type`` field is missing or not a string.
        NotImplementedError: for an unrecognized part type.
    """
    if isinstance(part, str):
        return part

    part_type, content = _parse_chat_message_content_mm_part(part)
    # if part_type is text/image_url/video_url/audio_url but content is None,
    # log a warning and skip
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') with empty / unparsable content.",
            part, part_type)
        return None

    if part_type == "text":
        return cast(str, content)

    # Built per call so the loader globals are resolved at call time.
    media_loaders = {
        "image_url": ("image", async_load_image),
        "video_url": ("video", async_load_video),
        "audio_url": ("audio", async_load_audio),
    }
    if part_type in media_loaders:
        modality, loader = media_loaders[part_type]
        # Creating the coroutine does not start the load; it only runs when
        # awaited.
        return MultimodalData(modality=modality,
                              data=_load_media_or_none(loader,
                                                       cast(str, content),
                                                       modality,
                                                       mm_data_tracker))

    raise NotImplementedError(f"Unknown part type: {part_type}")
def parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionMessageParam],
    mm_data_tracker: MultimodalDataTracker,
) -> ConversationMessage:
    """Parse every content part of one message into a ConversationMessage.

    String results are joined with newlines into ``content``. Every other
    truthy result is collected, in order, into ``media``. Falsy results
    (None, empty strings) are dropped.
    """
    texts: List[Any] = []
    media: List[Any] = []
    parsed_items = (parse_chat_message_content_part(p, mm_data_tracker)
                    for p in parts)
    for item in filter(None, parsed_items):
        bucket = texts if isinstance(item, str) else media
        bucket.append(item)
    joined_text = "\n".join(texts)
    return ConversationMessage(role=role, content=joined_text, media=media)
def parse_chat_message_content(
        message: ChatCompletionMessageParam,
        mm_data_tracker: MultimodalDataTracker) -> ConversationMessage:
    """Convert one chat message into a ConversationMessage.

    ``content`` may be None (no parts), a plain string (wrapped as a single
    text part), or an iterable of content parts. Assistant messages are then
    merged with ``_parse_assistant_message_content`` and tool messages with
    ``_parse_tool_message_content``.
    """
    role = message["role"]
    raw_content = message.get("content")

    if raw_content is None:
        content_parts = []
    elif isinstance(raw_content, str):
        content_parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    else:
        content_parts = raw_content

    conversation_message = parse_chat_message_content_parts(
        role, content_parts, mm_data_tracker)

    if role == "assistant":
        extra_fields = _parse_assistant_message_content(message)
        conversation_message.update(extra_fields)
    elif role == "tool":
        extra_fields = _parse_tool_message_content(message)
        conversation_message.update(extra_fields)
    return conversation_message
# Adapted from: https://github.com/vllm-project/vllm/blob/4574d48bab9c4e38b7c0a830eeefc8f0980e8c58/vllm/entrypoints/chat_utils.py#L1406
def _parse_assistant_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
result = {}
tool_calls = message.get("tool_calls")
if tool_calls is not None:
result["tool_calls"] = []
for item in tool_calls:
if content := item["function"].get("arguments"):
if isinstance(content, str):
item["function"]["arguments"] = json.loads(content)
else:
item["function"]["arguments"] = content
else:
item["function"]["arguments"] = {}
result["tool_calls"].append(item)
return result
def _parse_tool_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
result = {}
if "tool_call_id" in message:
result["tool_call_id"] = message["tool_call_id"]
return result
def parse_chat_messages_coroutines(
    messages: List[ChatCompletionMessageParam],
    model_config: AutoConfig,
    multimodal_server_config: Optional[MultimodalServerConfig] = None
) -> Tuple[List[ConversationMessage], Optional[Coroutine[
        Any, Any, Optional[Dict[str, List[Any]]]]], List[Any]]:
    """Parse chat messages into a conversation plus pending multimodal loads.

    All messages share one ``MultimodalDataTracker``. After a message is
    parsed, its media parts are registered with the tracker. If the tracker
    then reports placeholder counts, ``add_multimodal_placeholders`` rewrites
    that message's text content.

    Args:
        messages: OpenAI-style chat messages.
        model_config: Model config; only ``model_type`` is read here.
        multimodal_server_config: Optional server-side multimodal settings,
            forwarded to the tracker.

    Returns:
        A 3-tuple ``(conversation, mm_coroutine, mm_placeholder_counts)``:
        - the parsed messages;
        - the tracker's ``retrieve_all_async()`` result;
        - one ``placeholder_counts()`` result per message, in the same order
          as ``messages``.
    """
    conversation = []
    mm_placeholder_counts = []
    mm_data_tracker = MultimodalDataTracker(model_config.model_type,
                                            multimodal_server_config)
    for msg in messages:
        parsed_msg = parse_chat_message_content(msg, mm_data_tracker)
        conversation.append(parsed_msg)
        if parsed_msg["media"]:
            for mdata in parsed_msg["media"]:
                mm_data_tracker.add_data(mdata["modality"], mdata["data"])
        # NOTE(review): queried once per message, so this presumably reports
        # counts for the media just added -- confirm in MultimodalDataTracker.
        mm_placeholder_count = mm_data_tracker.placeholder_counts()
        if mm_placeholder_count:
            parsed_msg["content"] = add_multimodal_placeholders(
                model_config.model_type, parsed_msg["content"],
                mm_placeholder_count)
        mm_placeholder_counts.append(mm_placeholder_count)
    return conversation, mm_data_tracker.retrieve_all_async(
    ), mm_placeholder_counts
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build an identifier for a generated tool call.

    ``id_type="kimi_k2"`` yields the deterministic ``functions.<func_name>:<idx>``
    form. Any other ``id_type`` yields ``chatcmpl-tool-`` followed by the
    32-character hex form of a random UUID4.
    """
    if id_type != "kimi_k2":
        # Default: random ID; func_name and idx are ignored.
        return f"chatcmpl-tool-{uuid.uuid4().hex}"
    return f"functions.{func_name}:{idx}"