Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Merge 01cf98132a into 6df2c8a074
Commit: 2d45b482e0
@@ -6,6 +6,7 @@ markers =
     fmhca
     debug
     bench
+    needs_l40s
 # bin: unit tests
 # test: python script for invoking fmha.exe
 testpaths = bin test
@@ -170,6 +170,24 @@ TRT-LLM multimodal supports the following modalities and data types (depending o
 `load_base64_image utility <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/utils/load_base64_image.py>`__
 for implementation details.
 
+**Image embeddings**
+
+It is also possible to provide precomputed image embeddings directly to the
+multimodal model.
+
+* Using "image_embeds" with base64-encoded data:
+
+  .. code-block:: json
+
+     {"role": "user", "content": [
+       {"type": "text", "text": "What's in this image?"},
+       {"type": "image_embeds", "image_embeds": {"data": "{image_embeddings_base64}"}}
+     ]}
+
+.. note::
+   The contents of `image_embeddings_base64` can be generated by base64-encoding
+   the result of serializing a tensor via `torch.save`.
+
 **Video**
 
 * Using "video_url":
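For reference, a minimal sketch (not part of this commit) of how a client could produce the `image_embeddings_base64` payload described in the note above; the tensor is a random placeholder standing in for real vision-encoder output:

    import base64
    import io

    import torch

    # Shape and dtype are hypothetical; real embeddings come from the model's vision encoder.
    embeddings = torch.randn(16, 2048)
    buf = io.BytesIO()
    torch.save(embeddings, buf)
    image_embeddings_base64 = base64.b64encode(buf.getvalue()).decode("ascii")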
@@ -16,7 +16,8 @@ from .utils import (ALL_SUPPORTED_AUDIO_MODELS, ALL_SUPPORTED_IMAGE_MODELS,
                     async_load_audio, async_load_image, async_load_video,
                     convert_image_mode, default_multimodal_input_loader,
                     encode_base64_content_from_url, encode_base64_image,
-                    get_cache_salt_id, load_image, load_video)
+                    get_cache_salt_id, load_base64_image_embeds, load_image,
+                    load_video)
 
 __all__ = [
     "ALL_SUPPORTED_MULTIMODAL_MODELS",
@@ -57,4 +58,5 @@ __all__ = [
     "get_cache_salt_id",
     "compute_retained_tokens_count",
     "compute_retention_mask",
+    "load_base64_image_embeds",
 ]
@@ -114,6 +114,15 @@ def load_base64_image(parsed_url: str) -> Image.Image:
     return image
 
 
+def load_base64_image_embeds(str_content: str) -> torch.Tensor:
+    content_bytes = base64.b64decode(str_content)
+    with BytesIO(content_bytes) as buf:
+        image_data: torch.Tensor = torch.load(buf,
+                                              weights_only=True,
+                                              map_location="cpu")
+    return image_data
+
+
 def load_image(image: Union[str, Image.Image],
                format: str = "pt",
                device: str = "cpu") -> Union[Image.Image, torch.Tensor]:
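A round-trip sketch (again not part of the diff) showing how the new `load_base64_image_embeds` helper, exported from `tensorrt_llm.inputs`, decodes such a payload back into a CPU tensor; the shape is a placeholder:

    import base64
    import io

    import torch

    from tensorrt_llm.inputs import load_base64_image_embeds

    fake_embeds = torch.randn(4, 2048)  # placeholder for real image embeddings
    buf = io.BytesIO()
    torch.save(fake_embeds, buf)
    payload = base64.b64encode(buf.getvalue()).decode("ascii")

    decoded = load_base64_image_embeds(payload)
    assert torch.equal(decoded, fake_embeds)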
@@ -425,13 +434,14 @@ class MultimodalData(TypedDict):
     """Type definition for multimodal data structure."""
     modality: str
     data: Any
+    is_embedding: bool
 
 
 class ConversationMessage(TypedDict):
     """Type definition for conversation message structure."""
     role: str
     content: List[dict[str, Any]]
-    media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]]
+    media: List[MultimodalData]
 
     # @classmethod
     # def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage":
@@ -446,33 +456,57 @@ class MultimodalDataTracker:
                  model_type: str,
                  multimodal_server_config: Optional[MultimodalServerConfig] = None):
         self._model_type = model_type
-        self._data = defaultdict[str](list)
-        self._placeholder_counts = defaultdict[str](int)
+        self._data = defaultdict[str, list](list)
+        self._embeddings = defaultdict[str, list](list)
+        self._placeholder_counts = defaultdict[str, int](int)
         self._multimodal_server_config = multimodal_server_config if multimodal_server_config is not None else MultimodalServerConfig(
         )
 
-    async def retrieve_all_async(self) -> Optional[Dict[str, List[Any]]]:
-        """Retrieve all collected multimodal data."""
-        if not self._data:
-            return None
-
-        return {
-            modality: await asyncio.gather(*items)
-            for modality, items in self._data.items()
-        }
+    async def retrieve_all_async(
+        self
+    ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
+        """Retrieve all collected multimodal data and embeddings."""
+
+        async def _retrieve(
+                data: Optional[dict[str,
+                                    list]]) -> Optional[Dict[str, List[Any]]]:
+            if not data:
+                return None
+            return {
+                modality: await asyncio.gather(*items)
+                for modality, items in data.items() if items
+            }
+
+        return await _retrieve(self._data), await _retrieve(self._embeddings)
 
-    def retrieve_all_sync(self) -> Optional[Dict[str, List[Any]]]:
-        """Retrieve all collected multimodal data."""
-        if not self._data:
-            return None
-
-        return {modality: items for modality, items in self._data.items()}
+    def retrieve_all_sync(
+        self
+    ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
+        """Retrieve all collected multimodal data and embeddings."""
+
+        def _retrieve(
+                data: Optional[dict[str,
+                                    list]]) -> Optional[Dict[str, List[Any]]]:
+            if not data:
+                return None
+            return {
+                modality: items
+                for modality, items in data.items() if items
+            }
+
+        return _retrieve(self._data), _retrieve(self._embeddings)
 
-    def add_data(self, media_type: str, data: Union[Coroutine, Any]):
-        current_count = len(self._data[media_type]) + 1
+    def add_data(self,
+                 media_type: str,
+                 data: Union[Coroutine, Any],
+                 *,
+                 is_embedding: bool = False):
+        current_count = len(self._data[media_type]) + len(
+            self._embeddings[media_type]) + 1
         placeholder = retrieve_multimodal_placeholder(self._model_type,
                                                       media_type, current_count)
-        self._data[media_type].append(data)
+        (self._embeddings
+         if is_embedding else self._data)[media_type].append(data)
         if placeholder:
            self._placeholder_counts[placeholder] += 1
 
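To illustrate the tracker's new retrieval contract, a self-contained sketch (not using the real class) of the `_retrieve` filtering that both `retrieve_all_sync` and `retrieve_all_async` now apply to their data and embedding buckets:

    from typing import Any, Dict, List, Optional

    def _retrieve(data: Optional[dict[str, list]]) -> Optional[Dict[str, List[Any]]]:
        # Mirrors the sync closure above: drop empty modalities, return None when nothing was collected.
        if not data:
            return None
        return {modality: items for modality, items in data.items() if items}

    data = {"image": ["<loaded image>"], "video": []}
    embeddings = {}  # nothing was added with is_embedding=True
    mm_data, mm_embeddings = _retrieve(data), _retrieve(embeddings)
    assert mm_data == {"image": ["<loaded image>"]}
    assert mm_embeddings is None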
@@ -643,33 +677,34 @@ def default_multimodal_input_loader(
             media = [media]
         if modality in ["image", "multiple_image"]:
             if is_embedding:
-                # each mm_embedding corresponds to each image placeholder
-                if not isinstance(media, list):
-                    media = [media]
-
-                mm_data = [{
-                    'modality': modality,
-                    'mm_embedding_info': mm
-                } for mm in media]
+                _load = lambda mm: mm
             else:
-                mm_data = [
-                    MultimodalData(modality=modality,
-                                   data=load_image(i,
-                                                   format=image_data_format,
-                                                   device=device))
-                    for i in media
-                ]
+                _load = lambda mm: load_image(
+                    mm, format=image_data_format, device=device)
+
+            mm_data = [
+                MultimodalData(modality=modality,
+                               data=_load(mm),
+                               is_embedding=is_embedding) for mm in media
+            ]
         elif modality == "video":
             if is_embedding:
                 raise ValueError(
                     "External embedding is not supported for video modality yet."
                 )
             mm_data = [
-                MultimodalData(modality=modality,
-                               data=load_video(i,
-                                               num_frames,
-                                               format=image_data_format,
-                                               device=device)) for i in media
+                MultimodalData(
+                    modality=modality,
+                    data=load_video(i,
+                                    num_frames,
+                                    format=image_data_format,
+                                    device=device),
+                    is_embedding=False,
+                ) for i in media
             ]
         elif modality == "audio":
             if is_embedding:
@@ -677,8 +712,11 @@ def default_multimodal_input_loader(
                     "External embedding is not supported for audio modality yet."
                 )
             mm_data = [
-                MultimodalData(modality=modality,
-                               data=load_audio(i, device=device)) for i in media
+                MultimodalData(
+                    modality=modality,
+                    data=load_audio(i, device=device),
+                    is_embedding=False,
+                ) for i in media
             ]
         elif modality == "image_audio":
             if is_embedding:
@@ -706,16 +744,22 @@ def default_multimodal_input_loader(
                     pass
             if _modal is None:
                 raise ValueError(f"Unknown matching modality: {modality}")
-            mm_data.append(MultimodalData(modality=_modal, data=data))
+            mm_data.append(
+                MultimodalData(modality=_modal,
+                               data=data,
+                               is_embedding=False))
         elif modality == "mixture_text_image":
             mm_data = []
             for m in media:
                 if m:
                     mm_data.append(
-                        MultimodalData(modality="image",
-                                       data=load_image(m,
-                                                       format=image_data_format,
-                                                       device=device)))
+                        MultimodalData(
+                            modality="image",
+                            data=load_image(m,
+                                            format=image_data_format,
+                                            device=device),
+                            is_embedding=False,
+                        ))
         else:
             raise ValueError(f"Unknown modality: {modality}")
         return ConversationMessage(role="user", content=prompt, media=mm_data)
@@ -749,17 +793,12 @@ def default_multimodal_input_loader(
                                               is_embedding)
         mm_data_tracker = MultimodalDataTracker(model_type)
         for mdata in conv["media"]:
-            # Check if mdata is a MultimodalData
-            if isinstance(mdata,
-                          dict) and "modality" in mdata and "data" in mdata:
-                mdata_modality = mdata["modality"]
-                if modality == "multiple_image":
-                    mdata_modality = "image"
-                mm_data_tracker.add_data(mdata_modality, mdata["data"])
-            else:
-                # Add embeddings to the tracker for placeholder handling
-                mm_data_tracker.add_data(mdata["modality"],
-                                         mdata["mm_embedding_info"])
+            mdata_modality = mdata["modality"]
+            if modality == "multiple_image":
+                mdata_modality = "image"
+            mm_data_tracker.add_data(mdata_modality,
+                                     mdata["data"],
+                                     is_embedding=is_embedding)
         mm_placeholder_counts = mm_data_tracker.placeholder_counts()
         prompt = conv["content"]
         if mm_placeholder_counts:
@@ -776,11 +815,13 @@ def default_multimodal_input_loader(
 
         if mm_placeholder_counts:
             if mm_embeddings is not None:
-                input[
+                _, input[
                     "multi_modal_embeddings"] = mm_data_tracker.retrieve_all_sync(
                     )
             else:
-                input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync()
+                input[
+                    "multi_modal_data"], _ = mm_data_tracker.retrieve_all_sync(
+                    )
         inputs.append(input)
 
     return inputs
@@ -17,7 +17,8 @@ from typing_extensions import Required
 from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
                                  MultimodalDataTracker,
                                  add_multimodal_placeholders, async_load_audio,
-                                 async_load_image, async_load_video)
+                                 async_load_image, async_load_video,
+                                 load_base64_image_embeds)
 from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
 from tensorrt_llm.logger import logger
 
@@ -33,24 +34,45 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
     type: Required[Literal["video_url"]]
 
 
+class ImageEmbedsData(TypedDict):
+    """Type definition for serialized image embeddings structure."""
+    data: Required[str]
+
+
+class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
+    """Type definition for image embeddings passed in base64-encoded PyTorch tensor format."""
+    image_embeds: Required[
+        # TODO: Besides "data", could support "url" and "ipc_handle" in the future.
+        ImageEmbedsData]
+    type: Required[Literal["image_embeds"]]
+
+
 # Type Aliases and Constants
 ChatCompletionContentPartParam: TypeAlias = Union[
-    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
-    str]
+    OpenAIChatCompletionContentPartParam,
+    ChatCompletionContentPartVideoParam,
+    ChatCompletionContentPartImageEmbedsParam,
+    str,
+]
 
 # TODO: Add "input_audio" to support byte_encoded audio input.
 VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
-    "text", "image_url", "video_url", "audio_url"
+    "text",
+    "image_url",
+    "video_url",
+    "audio_url",
+    "image_embeds",
 ]
 
 # Parser Functions
 _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
+_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
 _VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
 _AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 
 MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
-    str, dict[str, str]]]] = {
+    str, dict[str, str], None]]] = {
         "text":
         lambda part: _TextParser(part).get("text", None),
         "image_url":
@@ -59,12 +81,15 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
         lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
         "audio_url":
         lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
+        "image_embeds":
+        lambda part: _ImageEmbedsParser(part).get("image_embeds", {}).get(
+            "data", None),
     }
 
 
 def _parse_chat_message_content_mm_part(
         part: ChatCompletionContentPartParam
-) -> tuple[str, Union[str, dict[str, str]]]:
+) -> tuple[str, Union[str, dict[str, str], None]]:
     """Parse a single multimodal part of a chat message."""
     assert isinstance(part, dict)
     part_type = part.get("type", None)
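A small standalone sketch (not importing the serving module) of what the new "image_embeds" dispatch extracts from a content part; the part dict mirrors the request schema shown in the docs hunk above:

    from typing import Any, Optional

    def parse_image_embeds_part(part: dict[str, Any]) -> Optional[str]:
        # Same access pattern as the "image_embeds" lambda registered in MM_PARSER_MAP.
        return part.get("image_embeds", {}).get("data", None)

    part = {
        "type": "image_embeds",
        "image_embeds": {"data": "<base64 of a torch.save-serialized tensor>"},
    }
    assert parse_image_embeds_part(part) == "<base64 of a torch.save-serialized tensor>"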
@@ -78,9 +103,9 @@ def _parse_chat_message_content_mm_part(
 
 
 def parse_chat_message_content_part(
-    part: ChatCompletionMessageParam,
+    part: ChatCompletionContentPartParam,
     mm_data_tracker: MultimodalDataTracker,
-) -> Optional[Any]:
+) -> str | MultimodalData | None:
     """Parse a single part of a chat message."""
     if isinstance(part, str):
         return part
@@ -110,7 +135,23 @@ def parse_chat_message_content_part(
                 logger.error(f"Failed to load image: {str(e)}")
                 return None
 
-        return MultimodalData(modality="image", data=load_image_async())
+        return MultimodalData(modality="image",
+                              data=load_image_async(),
+                              is_embedding=False)
+
+    if part_type == "image_embeds":
+        str_content = cast(str, content)
+
+        async def decode_image_embeds_async():
+            try:
+                return load_base64_image_embeds(str_content)
+            except Exception as e:
+                logger.error(f"Failed to decode image data: {str(e)}")
+                return None
+
+        return MultimodalData(modality="image",
+                              data=decode_image_embeds_async(),
+                              is_embedding=True)
 
     if part_type == "video_url":
         str_content = cast(str, content)
@@ -125,7 +166,9 @@ def parse_chat_message_content_part(
                 logger.error(f"Failed to load video: {str(e)}")
                 return None
 
-        return MultimodalData(modality="video", data=load_video_async())
+        return MultimodalData(modality="video",
+                              data=load_video_async(),
+                              is_embedding=False)
 
     if part_type == "audio_url":
         str_content = cast(str, content)
@@ -140,14 +183,16 @@
                 logger.error(f"Failed to load audio: {str(e)}")
                 return None
 
-        return MultimodalData(modality="audio", data=load_audio_async())
+        return MultimodalData(modality="audio",
+                              data=load_audio_async(),
+                              is_embedding=False)
 
     raise NotImplementedError(f"Unknown part type: {part_type}")
 
 
 def parse_chat_message_content_parts(
     role: str,
-    parts: Iterable[ChatCompletionMessageParam],
+    parts: Iterable[ChatCompletionContentPartParam],
     mm_data_tracker: MultimodalDataTracker,
 ) -> ConversationMessage:
     """Parse multiple parts of a chat message."""
@@ -224,8 +269,9 @@ def parse_chat_messages_coroutines(
     messages: List[ChatCompletionMessageParam],
     model_config: AutoConfig,
     multimodal_server_config: Optional[MultimodalServerConfig] = None
-) -> Tuple[List[ConversationMessage], Optional[Coroutine[
-        Any, Any, Optional[Dict[str, List[Any]]]]]]:
+) -> Tuple[List[ConversationMessage], Coroutine[Any, Any, tuple[Optional[Dict[
+        str, List[Any]]], Optional[Dict[str, List[Any]]]]], list[dict[str,
+                                                                      int]]]:
     """Parse multiple chat messages and return conversation and coroutine."""
     conversation = []
     mm_placeholder_counts = []
@@ -237,7 +283,9 @@ def parse_chat_messages_coroutines(
         conversation.append(parsed_msg)
         if parsed_msg["media"]:
             for mdata in parsed_msg["media"]:
-                mm_data_tracker.add_data(mdata["modality"], mdata["data"])
+                mm_data_tracker.add_data(mdata["modality"],
+                                         mdata["data"],
+                                         is_embedding=mdata["is_embedding"])
             mm_placeholder_count = mm_data_tracker.placeholder_counts()
             if mm_placeholder_count:
                 parsed_msg["content"] = add_multimodal_placeholders(
@@ -563,9 +563,13 @@ class OpenAIServer:
             )
             prompt = prompt_inputs(prompt)
 
-            mm_data = await mm_coroutines
-            if mm_data is not None:
+            mm_data, mm_embeddings = await mm_coroutines
+            if mm_data:
                 prompt["multi_modal_data"] = mm_data
+            if mm_embeddings:
+                prompt["multi_modal_embeddings"] = mm_embeddings
+            if mm_data and mm_embeddings:
+                raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.")
 
             postproc_args.reasoning_parser = self.llm.args.reasoning_parser
             postproc_args.tool_parser = self.tool_parser
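Finally, a hypothetical client-side call exercising this path; the endpoint URL, API key, and model name are placeholders, and the payload is built the same way as in the earlier base64 sketch (a random tensor, so the answer is meaningless):

    import base64
    import io

    import openai
    import torch

    buf = io.BytesIO()
    torch.save(torch.randn(16, 2048), buf)  # placeholder embeddings
    image_embeddings_base64 = base64.b64encode(buf.getvalue()).decode("ascii")

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
    resp = client.chat.completions.create(
        model="Qwen2.5-VL-3B-Instruct",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {"type": "image_embeds", "image_embeds": {"data": image_embeddings_base64}},
            ],
        }],
        max_completion_tokens=32,
    )
    print(resp.choices[0].message.content)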
@@ -666,7 +670,9 @@ class OpenAIServer:
             )
             prompt = prompt_inputs(prompt)
 
-            mm_data = await mm_coroutines
+            mm_data, mm_embeddings = await mm_coroutines
+            if mm_embeddings:
+                raise ValueError("Cannot use multimodal embeddings as input")
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data
 
@@ -1683,9 +1683,13 @@ def test_openai_lora(llm_root, llm_venv):
 
 def test_openai_chat_multimodal_example(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
-    llm_venv.run_cmd(
-        ["-m", "pytest",
-         str(test_root / "_test_openai_chat_multimodal.py")])
+    llm_venv.run_cmd([
+        "-m",
+        "pytest",
+        str(test_root / "_test_openai_chat_multimodal.py"),
+        "-m",
+        "not needs_l40s",
+    ])
 
 
 def test_openai_mmencoder_example(llm_root, llm_venv):
@@ -28,6 +28,7 @@ l0_l40s:
 - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-audio]
 - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image]
 - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio]
+- unittest/llmapi/apps/_test_openai_chat_multimodal.py::test_single_chat_session_image_embeds -m needs_l40s
 # MMMU sanity check
 - accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
 - accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype
@@ -0,0 +1,53 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
+
+import tempfile
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from tensorrt_llm._torch.models.modeling_qwen2vl import Qwen2VLInputProcessorBase
+from tensorrt_llm.inputs import ExtraProcessedInputs, TextPrompt
+from tensorrt_llm.sampling_params import SamplingParams
+
+_attach_multimodal_embeddings_orig = Qwen2VLInputProcessorBase.attach_multimodal_embeddings
+
+
+# signature taken from tensorrt_llm/inputs/registry.py
+def _attach_multimodal_embeddings(
+    self,
+    inputs: TextPrompt,
+    multimodal_embedding: dict[str, list[torch.Tensor]],
+    sampling_params: SamplingParams,
+) -> tuple[list[int], Optional[ExtraProcessedInputs]]:
+    try:
+        _attach_multimodal_embeddings_orig(self, inputs, multimodal_embedding, sampling_params)
+    except NotImplementedError:
+        pass
+    else:
+        raise ValueError(
+            "Remove this custom module, Qwen2VLInputProcessorBase implements attach_multimodal_embeddings"
+        )
+
+    tempdir = tempfile.gettempdir()
+    file_path = Path(tempdir) / "multimodal_embedding.pickle"
+    with open(file_path, "wb") as f:
+        torch.save(multimodal_embedding, f)
+    raise ValueError(file_path)
+
+
+Qwen2VLInputProcessorBase.attach_multimodal_embeddings = _attach_multimodal_embeddings
@@ -1,13 +1,18 @@
+import io
 import os
+import sys
 import tempfile
+from base64 import b64encode
 from pathlib import Path
 from typing import List
 
 import openai
 import pytest
+import torch
 import yaml
 from PIL import Image
 
+from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.inputs import encode_base64_image
 
 from ..test_llm import get_model_path
@@ -17,6 +22,13 @@ pytestmark = pytest.mark.threadleak(enabled=False)
 
 from utils.llm_data import llm_models_root
 
+from ._test_openai_mmencoder import RemoteMMEncoderServer
+from ._test_openai_mmencoder import server as mm_encoder_server
+from ._test_openai_mmencoder import \
+    test_multimodal_content_mm_encoder as _test_multimodal_content_mm_encoder
+
+assert mm_encoder_server is not None  # keep 'mm_encoder_server' fixture visible in this module
+
 
 @pytest.fixture(scope="module", ids=["Qwen2.5-VL-3B-Instruct"])
 def model_name():
@@ -25,7 +37,7 @@ def model_name():
 
 @pytest.fixture(scope="module")
 def temp_extra_llm_api_options_file(request):
-    temp_dir = tempfile.gettempdir()
+    temp_dir = tempfile.mkdtemp()
     temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
     try:
         extra_llm_api_options_dict = {
@@ -123,6 +135,98 @@ def test_single_chat_session_image(client: openai.OpenAI, model_name: str):
         == chat_completion.choices[0].message.content
 
 
+# used by mm_encoder_server
+@pytest.fixture(scope="module")
+def extra_encoder_options() -> bool:
+    return False
+
+
+# used by mm_encoder_server
+@pytest.fixture(scope="module")
+def temp_extra_encoder_options_file() -> str:
+    return "/dummy/path"
+
+
+@pytest.fixture(scope="module")
+def server_patched(model_name: str, temp_extra_llm_api_options_file: str):
+    # Custom module implements missing 'attach_multimodal_embeddings' to intercept
+    # embeddings.
+    model_path = get_model_path(model_name)
+    args = [
+        "--extra_llm_api_options",
+        temp_extra_llm_api_options_file,
+        "--max_batch_size",
+        "64",
+        "--max_num_tokens",
+        "16384",
+        "--custom_module_dirs",
+        str(
+            Path(sys.modules[test_single_chat_session_image_embeds.__module__].
+                 __file__).parent / "_attach_multimodal_embeddings_patch"),
+    ]
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.needs_l40s
+@pytest.mark.asyncio(loop_scope="module")
+def test_single_chat_session_image_embeds(
+        server_patched: RemoteOpenAIServer,
+        model_name: str,
+        mm_encoder_server: RemoteMMEncoderServer,
+):
+    client = server_patched.get_client()
+    messages, mm_embed_handle = _test_multimodal_content_mm_encoder(
+        mm_encoder_server.get_client(), model_name)
+
+    max_completion_tokens = 10
+
+    chat_completion_image = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=max_completion_tokens,
+        temperature=0.0,
+        logprobs=False)
+
+    mm_embed = SharedTensorContainer.from_dict(mm_embed_handle).get_local_view()
+    with io.BytesIO() as buf:
+        torch.save(mm_embed, buf)
+        mm_embed_bytes = buf.getvalue()
+
+    image_content = messages[0]["content"][1]
+    assert image_content["type"] == "image_url"
+    image_content.clear()
+    image_content["type"] = "image_embeds"
+    image_content["image_embeds"] = {
+        "data": b64encode(mm_embed_bytes).decode("ascii")
+    }
+
+    # test single completion
+    #
+    # FIXME: Remove try-except and use 'server' instead of 'server_patched',
+    # once Qwen2VLInputProcessorBase implements attach_multimodal_embeddings.
+    try:
+        chat_completion_embeds = client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            max_completion_tokens=max_completion_tokens,
+            temperature=0.0,
+            logprobs=False)
+
+        assert chat_completion_embeds.choices[
+            0].message == chat_completion_image.choices[0].message
+    except openai.BadRequestError as e:
+        assert isinstance(e.body, dict)
+        with open(Path(e.body["message"]), "rb") as f:
+            intercepted_embeddings = torch.load(f, weights_only=True)
+        assert list(intercepted_embeddings.keys()) == ["image"]
+        assert len(intercepted_embeddings["image"]) == 1
+        torch.testing.assert_close(intercepted_embeddings["image"][0],
+                                   mm_embed.cpu())
+        pytest.xfail(
+            reason="Model does not implement 'attach_multimodal_embeddings'")
+
+
 @pytest.mark.asyncio(loop_scope="module")
 def test_single_chat_session_multi_image(client: openai.OpenAI,
                                          model_name: str):
@@ -1,5 +1,6 @@
 import os
 import tempfile
+from typing import Any
 
 import openai
 import pytest
@@ -67,7 +68,9 @@ def async_client(server: RemoteMMEncoderServer):
     return server.get_async_client()
 
 
-def test_multimodal_content_mm_encoder(client: openai.OpenAI, model_name: str):
+def test_multimodal_content_mm_encoder(
+        client: openai.OpenAI,
+        model_name: str) -> tuple[list[dict[str, Any]], dict[str, Any]]:
 
     content_text = "Describe the natural environment in the image."
     image_url = str(llm_models_root() / "multimodals" / "test_data" /
@@ -105,6 +108,8 @@ def test_multimodal_content_mm_encoder(client: openai.OpenAI, model_name: str):
     assert mm_handle["tensor_size"][
         1] == 2048  # qwen2.5-vl: hidden_size of the vision encoder
 
+    return messages, mm_handle  # used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
+
 
 def test_health(server: RemoteMMEncoderServer):
     health_url = server.url_for("health")