mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
206 lines
7.1 KiB
Python
206 lines
7.1 KiB
Python
from functools import partial
|
|
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
|
|
Optional, Tuple, TypeAlias, TypedDict, Union, cast)
|
|
|
|
from openai.types.chat import (ChatCompletionContentPartImageParam,
|
|
ChatCompletionContentPartInputAudioParam)
|
|
from openai.types.chat import \
|
|
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
|
|
from openai.types.chat import (ChatCompletionContentPartTextParam,
|
|
ChatCompletionMessageParam)
|
|
from transformers import AutoConfig
|
|
from typing_extensions import Required
|
|
|
|
from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
|
|
MultimodalDataTracker,
|
|
add_multimodal_placeholders, async_load_audio,
|
|
async_load_image, async_load_video)
|
|
from tensorrt_llm.logger import logger
|
|
|
|
|
|
class VideoURL(TypedDict):
    """Type definition for video URL structure.

    The nested object of a ``video_url`` content part; the ``"video_url"``
    entry of ``MM_PARSER_MAP`` reads its ``url`` key.
    """

    # Location of the video; handed to ``async_load_video`` when the part is
    # parsed.
    url: Required[str]
|
|
|
|
|
|
class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Type definition for video content part parameters.

    Describes a chat-message content part of the form
    ``{"type": "video_url", "video_url": {"url": "..."}}``. ``total=False``
    makes keys optional by default, but both keys below are explicitly
    marked ``Required``.
    """

    # Nested object carrying the video URL.
    video_url: Required[VideoURL]
    # Part discriminator; parsing dispatches on this value.
    type: Required[Literal["video_url"]]
|
|
|
|
|
|
# Type Aliases and Constants

# One element of a chat message's ``content``: any OpenAI content part, the
# locally defined video-URL part, or a bare string.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam,
    ChatCompletionContentPartVideoParam,
    str,
]

# Content-part types this module knows how to extract a payload from; a part
# of one of these types whose payload is missing is skipped with a warning.
# TODO: Add "input_audio" to support byte_encoded audio input.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
    "text",
    "image_url",
    "video_url",
    "audio_url",
]
|
|
|
|
# Parser Functions
# ``typing.cast`` returns its argument unchanged at runtime; these partials only
# narrow the static type of a content part before its fields are read below.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
# NOTE(review): ChatCompletionContentPartInputAudioParam describes OpenAI's
# "input_audio" part, not an "audio_url" part. Because the cast is a static
# typing hint only, the "audio_url" lookup below still works at runtime.
_AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)

# Part "type" -> function extracting that part's payload: the text of a
# "text" part, or the nested "url" of a media part. Every parser returns None
# when the expected field is absent, which callers treat as "skip this part".
MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam],
                                  Optional[Union[str, dict[str, str]]]]] = {
    "text":
    lambda part: _TextParser(part).get("text", None),
    "image_url":
    lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
    "video_url":
    lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
    "audio_url":
    lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
}
|
|
|
|
|
|
def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam
) -> tuple[str, Optional[Union[str, dict[str, str]]]]:
    """Split a content-part dict into its type and its payload.

    Args:
        part: A content-part dict carrying a string ``type`` key.

    Returns:
        ``(part_type, payload)``. For types present in ``MM_PARSER_MAP`` the
        payload is that parser's result (text, URL, or None when the field is
        missing); for any other type it is the placeholder string
        ``"unknown part_type content"``.

    Raises:
        ValueError: If ``part`` is not a dict, or its ``type`` field is not a
            string.
    """
    # Explicit check rather than ``assert`` so validation is not stripped when
    # Python runs with -O (which would surface as an AttributeError below).
    if not isinstance(part, dict):
        raise ValueError(
            f"Multimodal part must be a dict, got {type(part).__name__}.")

    part_type = part.get("type", None)
    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")

    parser = MM_PARSER_MAP.get(part_type)
    if parser is None:
        # Unsupported type: the caller decides how to reject it.
        return part_type, "unknown part_type content"
    return part_type, parser(part)
|
|
|
|
|
|
async def _load_media_or_none(loader: Callable[..., Coroutine[Any, Any, Any]],
                              source: str, media_name: str,
                              **loader_kwargs: Any) -> Optional[Any]:
    """Await ``loader(source, **loader_kwargs)``; return None if it raises.

    Any exception is logged as ``"Failed to load <media_name>: <error>"`` and
    swallowed, so the coroutine resolves to None instead of raising.
    """
    try:
        return await loader(source, **loader_kwargs)
    except Exception as e:
        logger.error(f"Failed to load {media_name}: {e}")
        return None


# Media part type -> (modality name, async loader, extra loader kwargs).
# The modality name doubles as the noun in the load-failure log message.
_MEDIA_PART_LOADERS: Dict[str, Tuple[str, Callable[..., Coroutine[Any, Any,
                                                                  Any]],
                                     Dict[str, Any]]] = {
    "image_url": ("image", async_load_image, {}),
    # Videos are always requested with a hard-coded 8 frames.
    "video_url": ("video", async_load_video, {"num_frames": 8}),
    "audio_url": ("audio", async_load_audio, {}),
}


def parse_chat_message_content_part(
    part: ChatCompletionContentPartParam, ) -> Optional[Any]:
    """Parse a single content part of a chat message.

    Args:
        part: A bare string, or a content-part dict with a ``type`` key.

    Returns:
        - The string itself for a bare string, or the text of a ``"text"``
          part.
        - For ``"image_url"`` / ``"video_url"`` / ``"audio_url"`` parts, a
          ``MultimodalData`` whose ``data`` is an un-awaited coroutine that
          loads the media and resolves to None if loading fails.
        - None when a supported part has no usable payload; a warning is
          logged and the part is skipped.

    Raises:
        ValueError: Propagated from ``_parse_chat_message_content_mm_part``
            when the part's ``type`` field is invalid.
        NotImplementedError: For a part whose ``type`` is not supported.
    """
    if isinstance(part, str):
        return part

    part_type, content = _parse_chat_message_content_mm_part(part)

    # A supported part type whose payload is missing (e.g. no URL) is skipped.
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
        # Pre-format the message: the project logger is not guaranteed to
        # apply %-style interpolation to extra positional arguments.
        logger.warning(
            f"Skipping multimodal part '{part}' (type: '{part_type}') with "
            "empty / unparsable content.")
        return None

    if part_type == "text":
        return cast(str, content)

    media_spec = _MEDIA_PART_LOADERS.get(part_type)
    if media_spec is not None:
        modality, loader, loader_kwargs = media_spec
        # Calling the async helper only creates a coroutine object; nothing is
        # fetched until the consumer awaits it.
        return MultimodalData(modality=modality,
                              data=_load_media_or_none(loader,
                                                       cast(str, content),
                                                       modality,
                                                       **loader_kwargs))

    raise NotImplementedError(f"Unknown part type: {part_type}")
|
|
|
|
|
|
def parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
) -> ConversationMessage:
    """Parse the content parts of one chat message.

    Text parts are joined with newlines into the message ``content``. Media
    parts (the non-string results of ``parse_chat_message_content_part``) are
    collected in order into ``media``. Parts that parse to a falsy value,
    such as skipped parts (None) and empty text, are dropped.

    Args:
        role: Role of the message author; copied into the result.
        parts: The message's content parts.

    Returns:
        A ``ConversationMessage`` with ``role``, ``content`` and ``media``.
    """
    text_parts: List[str] = []
    media_parts: List[Any] = []
    for part in parts:
        parsed = parse_chat_message_content_part(part)
        if not parsed:
            # Skipped or empty parts contribute nothing to the message.
            continue
        if isinstance(parsed, str):
            text_parts.append(parsed)
        else:
            media_parts.append(parsed)

    return ConversationMessage(role=role,
                               content="\n".join(text_parts),
                               media=media_parts)
|
|
|
|
|
|
def parse_chat_message_content(
|
|
message: ChatCompletionMessageParam, ) -> ConversationMessage:
|
|
"""Parse the content of a chat message."""
|
|
role = message["role"]
|
|
content = message.get("content")
|
|
|
|
if content is None:
|
|
content = []
|
|
elif isinstance(content, str):
|
|
content = [
|
|
ChatCompletionContentPartTextParam(type="text", text=content)
|
|
]
|
|
|
|
result = parse_chat_message_content_parts(
|
|
role,
|
|
content,
|
|
)
|
|
return result
|
|
|
|
|
|
def parse_chat_messages_coroutines(
    messages: List[ChatCompletionMessageParam],
    model_config: AutoConfig,
) -> Tuple[List[ConversationMessage], Optional[Coroutine[
        Any, Any, Optional[Dict[str, List[Any]]]]], List[Any]]:
    """Parse chat messages and register their media for deferred loading.

    Each message is parsed with ``parse_chat_message_content``. Its media
    items are added to a ``MultimodalDataTracker`` built for
    ``model_config.model_type``. Whenever the tracker reports non-empty
    placeholder counts, ``add_multimodal_placeholders`` rewrites the
    message's text content in place.

    Args:
        messages: Chat messages to parse.
        model_config: Model configuration; only ``model_type`` is read.

    Returns:
        A 3-tuple ``(conversation, mm_coroutine, mm_placeholder_counts)``:
        the parsed messages; the tracker's ``retrieve_all_async()`` coroutine,
        which is not awaited here; and one entry per message holding the
        tracker's ``placeholder_counts()`` as read after that message.
    """
    conversation = []
    mm_placeholder_counts = []
    mm_data_tracker = MultimodalDataTracker(model_config.model_type)

    for msg in messages:
        parsed_msg = parse_chat_message_content(msg)
        conversation.append(parsed_msg)
        for mdata in parsed_msg["media"]:
            mm_data_tracker.add_data(mdata["modality"], mdata["data"])
        # NOTE(review): counts are read from the shared tracker after every
        # message; if they accumulate across messages, later messages also
        # receive placeholders for earlier media — confirm against
        # MultimodalDataTracker.placeholder_counts().
        mm_placeholder_count = mm_data_tracker.placeholder_counts()
        if mm_placeholder_count:
            parsed_msg["content"] = add_multimodal_placeholders(
                model_config.model_type, parsed_msg["content"],
                mm_placeholder_count)
        mm_placeholder_counts.append(mm_placeholder_count)

    return conversation, mm_data_tracker.retrieve_all_async(
    ), mm_placeholder_counts
|
|
|
|
|
|
def check_multiple_response(n: int, backend: Optional[str]):
    """Reject requests for more than one response on the PyTorch backend.

    Args:
        n: Number of responses requested.
        backend: Name of the serving backend, or None.

    Raises:
        ValueError: If ``n`` is greater than 1 and ``backend`` is
            ``"pytorch"``.
    """
    wants_multiple = n > 1
    if not wants_multiple or backend != "pytorch":
        return
    raise ValueError("Multiple response is not supported in PyTorch workflow")
|