import json
import uuid
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
                    Optional, Tuple, TypeAlias, TypedDict, Union, cast)

from openai.types.chat import (ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartInputAudioParam)
from openai.types.chat import \
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import (ChatCompletionContentPartTextParam,
                               ChatCompletionMessageParam)
from transformers import AutoConfig
from typing_extensions import Required

from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
                                 MultimodalDataTracker,
                                 add_multimodal_placeholders, async_load_audio,
                                 async_load_image, async_load_video)
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.logger import logger


class VideoURL(TypedDict):
    """Type definition for video URL structure."""
    # Location of the video resource (e.g. an http(s) or data URL);
    # downstream loading is handled by async_load_video.
    url: Required[str]


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Type definition for video content part parameters.

    Mirrors the OpenAI content-part layout for a non-standard "video_url"
    modality so videos can appear alongside the official part types.
    """
    video_url: Required[VideoURL]
    type: Required[Literal["video_url"]]


# Type Aliases and Constants
# A content part may be any official OpenAI part, the video extension
# above, or a bare string (treated as plain text downstream).
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
    str]

# TODO: Add "input_audio" to support byte_encoded audio input.
# Multimodal part "type" values this module knows how to parse.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
    "text", "image_url", "video_url", "audio_url"
]

# Parser Functions
# These are typing-only narrowing helpers: `cast` is a no-op at runtime,
# so each "parser" simply returns its argument with a narrower static type.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)

# Maps a part "type" to a function extracting its payload (the text, or the
# media URL). Each extractor returns None when the expected key is missing.
MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
    str, dict[str, str]]]] = {
        "text":
        lambda part: _TextParser(part).get("text", None),
        "image_url":
        lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
        "video_url":
        lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
        "audio_url":
        lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
    }


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam
) -> tuple[str, Union[str, dict[str, str]]]:
    """Parse a single multimodal part of a chat message.

    Returns a ``(part_type, content)`` pair. For recognized types the content
    comes from MM_PARSER_MAP (possibly None if the payload key is absent);
    unrecognized string types yield the sentinel "unknown part_type content".

    Raises:
        ValueError: if the part has no string "type" field.
    """
    assert isinstance(part, dict)
    part_type = part.get("type", None)
    if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
        return part_type, MM_PARSER_MAP[part_type](part)
    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


def parse_chat_message_content_part(
    part: ChatCompletionMessageParam,
    mm_data_tracker: MultimodalDataTracker,
) -> Optional[Any]:
    """Parse a single part of a chat message.

    Returns:
        - the text itself for string/"text" parts,
        - a MultimodalData wrapping a *not-yet-awaited* coroutine for media
          parts (the caller is responsible for awaiting/collecting it),
        - None for recognized parts whose payload could not be extracted.

    Raises:
        NotImplementedError: for an unrecognized part type.
    """
    if isinstance(part, str):
        return part

    part_type, content = _parse_chat_message_content_mm_part(part)

    # if part_type is text/image_url/video_url/audio_url but content is None, log a warning and skip
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') with empty / unparsable content.",
            part, part_type)
        return None

    if part_type == "text":
        return cast(str, content)
    if part_type == "image_url":
        str_content = cast(str, content)

        async def load_image_async():
            # Load failures are deliberately swallowed (logged, None returned)
            # so one bad media URL does not abort the whole request.
            try:
                # NOTE(review): reaches into a private attribute of the
                # tracker; presumably media_io_kwargs maps modality -> kwargs
                # for the loader — confirm against MultimodalDataTracker.
                image_kwargs = (
                    mm_data_tracker._multimodal_server_config.media_io_kwargs
                    or {}).get("image", {})
                return await async_load_image(str_content, **image_kwargs)
            except Exception as e:
                logger.error(f"Failed to load image: {str(e)}")
                return None

        # data is the coroutine object itself; it is awaited later by the
        # tracker's retrieval machinery.
        return MultimodalData(modality="image", data=load_image_async())
    if part_type == "video_url":
        str_content = cast(str, content)

        async def load_video_async():
            # Same best-effort contract as the image loader above.
            try:
                video_kwargs = (
                    mm_data_tracker._multimodal_server_config.media_io_kwargs
                    or {}).get("video", {})
                return await async_load_video(str_content, **video_kwargs)
            except Exception as e:
                logger.error(f"Failed to load video: {str(e)}")
                return None

        return MultimodalData(modality="video", data=load_video_async())
    if part_type == "audio_url":
        str_content = cast(str, content)

        async def load_audio_async():
            # Same best-effort contract as the image loader above.
            try:
                audio_kwargs = (
                    mm_data_tracker._multimodal_server_config.media_io_kwargs
                    or {}).get("audio", {})
                return await async_load_audio(str_content, **audio_kwargs)
            except Exception as e:
                logger.error(f"Failed to load audio: {str(e)}")
                return None

        return MultimodalData(modality="audio", data=load_audio_async())
    raise NotImplementedError(f"Unknown part type: {part_type}")


def parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionMessageParam],
    mm_data_tracker: MultimodalDataTracker,
) -> ConversationMessage:
    """Parse multiple parts of a chat message.

    Text parts are joined with newlines into a single content string; media
    parts are collected into the message's "media" list.
    """
    text_parts = []
    media_parts = []
    for part in parts:
        parse_res = parse_chat_message_content_part(part, mm_data_tracker)
        # Falsy results (None for skipped parts, and also empty strings) are
        # dropped here.
        if parse_res:
            if isinstance(parse_res, str):
                text_parts.append(parse_res)
            else:
                media_parts.append(parse_res)

    text_prompt = "\n".join(text_parts)

    return ConversationMessage(role=role,
                               content=text_prompt,
                               media=media_parts)


def parse_chat_message_content(
        message: ChatCompletionMessageParam,
        mm_data_tracker: MultimodalDataTracker) -> ConversationMessage:
    """Parse the content of a chat message.

    Normalizes content to a list of parts (None -> [], bare string -> a single
    text part), then augments the result with role-specific fields: parsed
    tool_calls for "assistant" messages, tool_call_id for "tool" messages.
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        content = []
    elif isinstance(content, str):
        content = [
            ChatCompletionContentPartTextParam(type="text", text=content)
        ]
    result = parse_chat_message_content_parts(
        role,
        content,
        mm_data_tracker,
    )
    if role == "assistant":
        result.update(_parse_assistant_message_content(message))
    elif role == "tool":
        result.update(_parse_tool_message_content(message))
    return result


# Adapted from: https://github.com/vllm-project/vllm/blob/4574d48bab9c4e38b7c0a830eeefc8f0980e8c58/vllm/entrypoints/chat_utils.py#L1406
def _parse_assistant_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
    """Extract tool_calls from an assistant message.

    Each call's function "arguments" is normalized to a dict: JSON strings are
    decoded, missing/empty arguments become {}. NOTE: the items in the
    incoming message are mutated in place as a side effect.
    """
    result = {}
    tool_calls = message.get("tool_calls")
    if tool_calls is not None:
        result["tool_calls"] = []
        for item in tool_calls:
            if content := item["function"].get("arguments"):
                if isinstance(content, str):
                    item["function"]["arguments"] = json.loads(content)
                else:
                    item["function"]["arguments"] = content
            else:
                item["function"]["arguments"] = {}
            result["tool_calls"].append(item)
    return result


def _parse_tool_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
    """Extract tool_call_id (if present) from a tool message."""
    result = {}
    if "tool_call_id" in message:
        result["tool_call_id"] = message["tool_call_id"]
    return result


def parse_chat_messages_coroutines(
    messages: List[ChatCompletionMessageParam],
    model_config: AutoConfig,
    multimodal_server_config: Optional[MultimodalServerConfig] = None
    # NOTE(review): the original annotation declared a 2-tuple, but the
    # function returns 3 values; corrected to include the per-message
    # placeholder counts.
) -> Tuple[List[ConversationMessage],
           Optional[Coroutine[Any, Any, Optional[Dict[str, List[Any]]]]],
           List[Any]]:
    """Parse multiple chat messages and return conversation and coroutine.

    Returns:
        A tuple of (parsed conversation, coroutine that resolves all queued
        multimodal data, per-message placeholder counts). Media found in each
        message is registered with a shared MultimodalDataTracker, and the
        message content is rewritten with modality placeholders when any
        media is pending.
    """
    conversation = []
    mm_placeholder_counts = []
    mm_data_tracker = MultimodalDataTracker(model_config.model_type,
                                            multimodal_server_config)
    for msg in messages:
        parsed_msg = parse_chat_message_content(msg, mm_data_tracker)
        conversation.append(parsed_msg)
        if parsed_msg["media"]:
            for mdata in parsed_msg["media"]:
                mm_data_tracker.add_data(mdata["modality"], mdata["data"])
        mm_placeholder_count = mm_data_tracker.placeholder_counts()
        if mm_placeholder_count:
            parsed_msg["content"] = add_multimodal_placeholders(
                model_config.model_type, parsed_msg["content"],
                mm_placeholder_count)
        mm_placeholder_counts.append(mm_placeholder_count)
    return conversation, mm_data_tracker.retrieve_all_async(
    ), mm_placeholder_counts


def make_tool_call_id(id_type: str = "random",
                      func_name: Optional[str] = None,
                      idx: Optional[Any] = None) -> str:
    """Build a tool-call id: Kimi-K2's "functions.<name>:<idx>" scheme when
    id_type == "kimi_k2", otherwise a random chatcmpl-tool id."""
    if id_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"
    else:
        # by default return random
        return f"chatcmpl-tool-{uuid.uuid4().hex}"


# Adapted from
# https://github.com/vllm-project/vllm/blob/44b5ce956d3cf28841615a58c1c0873af87bcfe2/vllm/entrypoints/chat_utils.py
@lru_cache
def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Resolve a chat template given either a file path or the template text.

    With is_literal=True the string is returned as-is (a Path is rejected).
    Otherwise the argument is treated as a path and read from disk; if the
    file cannot be opened but the string contains Jinja-looking characters,
    it is reinterpreted as a literal template. Results are lru_cached.

    Raises:
        TypeError: if is_literal=True and a Path was supplied.
        ValueError: if a non-template-looking string fails to open as a file.
        OSError: if a Path fails to open.
    """
    if chat_template is None:
        return None

    if is_literal:
        if isinstance(chat_template, Path):
            raise TypeError(
                "chat_template is expected to be read directly from its value")
        return chat_template

    try:
        with open(chat_template) as f:
            return f.read()
    except OSError as e:
        # An explicit Path was clearly meant as a file; propagate the error.
        if isinstance(chat_template, Path):
            raise
        # Heuristic: a template normally contains braces or newlines; a bare
        # word without them was almost certainly a mistyped path.
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
            raise ValueError(msg) from e

        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
        return load_chat_template(chat_template, is_literal=True)