From 5e3c26ebfb1ab51733548cd405bb5d45cefa7315 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Thu, 4 Dec 2025 17:39:03 +0100 Subject: [PATCH 01/11] feat: support image_embeds in OpenAI API Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- .../commands/trtllm-serve/trtllm-serve.rst | 18 ++++ tensorrt_llm/inputs/__init__.py | 4 +- tensorrt_llm/inputs/utils.py | 18 +++- tensorrt_llm/serve/chat_utils.py | 57 ++++++++-- tensorrt_llm/serve/openai_server.py | 15 ++- .../__init__.py | 44 ++++++++ .../apps/_test_openai_chat_multimodal.py | 102 +++++++++++++++++- .../llmapi/apps/_test_openai_mmencoder.py | 7 +- 8 files changed, 249 insertions(+), 16 deletions(-) create mode 100644 tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py diff --git a/docs/source/commands/trtllm-serve/trtllm-serve.rst b/docs/source/commands/trtllm-serve/trtllm-serve.rst index b26e45de92..a7aa9b803d 100644 --- a/docs/source/commands/trtllm-serve/trtllm-serve.rst +++ b/docs/source/commands/trtllm-serve/trtllm-serve.rst @@ -170,6 +170,24 @@ TRT-LLM multimodal supports the following modalities and data types (depending o `load_base64_image utility `__ for implementation details. +**Image embeddings** + +It is also possible to directly provide the image embeddings to use by the multimodal +model. + +* Using "image_embeds" with base64-encoded data: + + .. code-block:: json + + {"role": "user", "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_embeds", "image_embeds": "{image_embeddings_base64}"}} + ]} + +.. note:: + The contents of `image_embeddings_base64` can be generated by base64-encoding + the result of serializing a tensor via `torch.save`. + **Video** * Using "video_url": diff --git a/tensorrt_llm/inputs/__init__.py b/tensorrt_llm/inputs/__init__.py index 406a71d4f5..4587d39f03 100644 --- a/tensorrt_llm/inputs/__init__.py +++ b/tensorrt_llm/inputs/__init__.py @@ -16,7 +16,8 @@ from .utils import (ALL_SUPPORTED_AUDIO_MODELS, ALL_SUPPORTED_IMAGE_MODELS, async_load_audio, async_load_image, async_load_video, convert_image_mode, default_multimodal_input_loader, encode_base64_content_from_url, encode_base64_image, - get_cache_salt_id, load_image, load_video) + get_cache_salt_id, load_base64_image_embeds, load_image, + load_video) __all__ = [ "ALL_SUPPORTED_MULTIMODAL_MODELS", @@ -57,4 +58,5 @@ __all__ = [ "get_cache_salt_id", "compute_retained_tokens_count", "compute_retention_mask", + "load_base64_image_embeds", ] diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index bbbd5f4f8f..b7832a3084 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -114,6 +114,15 @@ def load_base64_image(parsed_url: str) -> Image.Image: return image +def load_base64_image_embeds(str_content: str) -> torch.Tensor: + content_bytes = base64.b64decode(str_content) + with BytesIO(content_bytes) as buf: + image_data: torch.Tensor = torch.load(buf, + weights_only=True, + map_location="cpu") + return image_data + + def load_image(image: Union[str, Image.Image], format: str = "pt", device: str = "cpu") -> Union[Image.Image, torch.Tensor]: @@ -468,10 +477,15 @@ class MultimodalDataTracker: return {modality: items for modality, items in self._data.items()} - def add_data(self, media_type: str, data: Union[Coroutine, Any]): + def add_data(self, + media_type: str, + data: Union[Coroutine, Any], + *, + modality: Optional[str] = None): + modality = modality or media_type current_count = 
len(self._data[media_type]) + 1 placeholder = retrieve_multimodal_placeholder(self._model_type, - media_type, current_count) + modality, current_count) self._data[media_type].append(data) if placeholder: self._placeholder_counts[placeholder] += 1 diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index 26ee17c4f4..581615a201 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -17,7 +17,8 @@ from typing_extensions import Required from tensorrt_llm.inputs import (ConversationMessage, MultimodalData, MultimodalDataTracker, add_multimodal_placeholders, async_load_audio, - async_load_image, async_load_video) + async_load_image, async_load_video, + load_base64_image_embeds) from tensorrt_llm.inputs.multimodal import MultimodalServerConfig from tensorrt_llm.logger import logger @@ -33,24 +34,38 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False): type: Required[Literal["video_url"]] +class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): + """Type definition for image embeddings passed in base64-encoded PyTorch tensor format.""" + image_embeds: Required[str] + type: Required[Literal["image_embeds"]] + + # Type Aliases and Constants ChatCompletionContentPartParam: TypeAlias = Union[ - OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam, - str] + OpenAIChatCompletionContentPartParam, + ChatCompletionContentPartVideoParam, + ChatCompletionContentPartImageEmbedsParam, + str, +] # TODO: Add "input_audio" to support byte_encoded audio input. VALID_MESSAGE_CONTENT_MM_PART_TYPES = [ - "text", "image_url", "video_url", "audio_url" + "text", + "image_url", + "video_url", + "audio_url", + "image_embeds", ] # Parser Functions _TextParser = partial(cast, ChatCompletionContentPartTextParam) _ImageParser = partial(cast, ChatCompletionContentPartImageParam) +_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam) _AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[ - str, dict[str, str]]]] = { + str, dict[str, str], None]]] = { "text": lambda part: _TextParser(part).get("text", None), "image_url": @@ -59,12 +74,20 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[ lambda part: _VideoParser(part).get("video_url", {}).get("url", None), "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), + "image_embeds": + lambda part: _ImageEmbedsParser(part).get("image_embeds", None), } +# Map from content part tags used to directly provide embeddings +# to the corresponding data modality. 
+MM_EMBEDDING_MAP: dict[str, str] = { + "image_embeds": "image", +} + def _parse_chat_message_content_mm_part( part: ChatCompletionContentPartParam -) -> tuple[str, Union[str, dict[str, str]]]: +) -> tuple[str, Union[str, dict[str, str], None]]: """Parse a single multimodal part of a chat message.""" assert isinstance(part, dict) part_type = part.get("type", None) @@ -78,7 +101,7 @@ def _parse_chat_message_content_mm_part( def parse_chat_message_content_part( - part: ChatCompletionMessageParam, + part: ChatCompletionContentPartParam, mm_data_tracker: MultimodalDataTracker, ) -> Optional[Any]: """Parse a single part of a chat message.""" @@ -112,6 +135,19 @@ def parse_chat_message_content_part( return MultimodalData(modality="image", data=load_image_async()) + if part_type == "image_embeds": + str_content = cast(str, content) + + async def decode_image_embeds_async(): + try: + return load_base64_image_embeds(str_content) + except Exception as e: + logger.error(f"Failed to decode image data: {str(e)}") + return None + + return MultimodalData(modality="image_embeds", + data=decode_image_embeds_async()) + if part_type == "video_url": str_content = cast(str, content) @@ -147,7 +183,7 @@ def parse_chat_message_content_part( def parse_chat_message_content_parts( role: str, - parts: Iterable[ChatCompletionMessageParam], + parts: Iterable[ChatCompletionContentPartParam], mm_data_tracker: MultimodalDataTracker, ) -> ConversationMessage: """Parse multiple parts of a chat message.""" @@ -237,7 +273,10 @@ def parse_chat_messages_coroutines( conversation.append(parsed_msg) if parsed_msg["media"]: for mdata in parsed_msg["media"]: - mm_data_tracker.add_data(mdata["modality"], mdata["data"]) + mm_data_tracker.add_data(mdata["modality"], + mdata["data"], + modality=MM_EMBEDDING_MAP.get( + mdata["modality"], None)) mm_placeholder_count = mm_data_tracker.placeholder_counts() if mm_placeholder_count: parsed_msg["content"] = add_multimodal_placeholders( diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 99f19db0f0..be72099ffd 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -36,7 +36,7 @@ from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig, from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.logger import logger from tensorrt_llm.metrics.collector import MetricsCollector -from tensorrt_llm.serve.chat_utils import (load_chat_template, +from tensorrt_llm.serve.chat_utils import (MM_EMBEDDING_MAP, load_chat_template, parse_chat_messages_coroutines) from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker @@ -558,7 +558,18 @@ class OpenAIServer: mm_data = await mm_coroutines if mm_data is not None: - prompt["multi_modal_data"] = mm_data + # single out directly provided embeddings + mm_embeds = {} + for tag in list(mm_data.keys()): + if (modality := MM_EMBEDDING_MAP.get(tag, None)) is not None: + mm_embeds[modality] = mm_data.pop(tag) + + if mm_data: + prompt["multi_modal_data"] = mm_data + if mm_embeds: + prompt["multi_modal_embeddings"] = mm_embeds + if mm_data and mm_embeds: + raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.") postproc_args.reasoning_parser = self.llm.args.reasoning_parser postproc_args.tool_parser = self.tool_parser diff --git a/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py 
b/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py new file mode 100644 index 0000000000..749481fea1 --- /dev/null +++ b/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py + +import pickle +import tempfile +from pathlib import Path +from typing import Optional + +import torch + +from tensorrt_llm._torch.models.modeling_qwen2vl import Qwen2VLInputProcessorBase +from tensorrt_llm.inputs import ExtraProcessedInputs, TextPrompt +from tensorrt_llm.sampling_params import SamplingParams + + +# signature taken from tensorrt_llm/inputs/registry.py +def _attach_multimodal_embeddings( + self, + inputs: TextPrompt, + multimodal_embedding: dict[str, list[torch.Tensor]], + sampling_params: SamplingParams, +) -> tuple[list[int], Optional[ExtraProcessedInputs]]: + tempdir = tempfile.gettempdir() + file_path = Path(tempdir) / "multimodal_embedding.pickle" + with open(file_path, "wb") as f: + pickle.dump(multimodal_embedding, f) + raise ValueError(file_path) + + +assert not hasattr(Qwen2VLInputProcessorBase, "attach_multimodal_embeddings") +setattr(Qwen2VLInputProcessorBase, "attach_multimodal_embeddings", _attach_multimodal_embeddings) diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py index fda0f8a493..15d25f3656 100644 --- a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py +++ b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py @@ -1,13 +1,19 @@ +import io import os +import pickle +import sys import tempfile +from base64 import b64encode from pathlib import Path from typing import List import openai import pytest +import torch import yaml from PIL import Image +from tensorrt_llm._torch.shared_tensor import SharedTensorContainer from tensorrt_llm.inputs import encode_base64_image from ..test_llm import get_model_path @@ -17,6 +23,12 @@ pytestmark = pytest.mark.threadleak(enabled=False) from utils.llm_data import llm_models_root +from ._test_openai_mmencoder import RemoteMMEncoderServer +from ._test_openai_mmencoder import server as mm_encoder_server +from ._test_openai_mmencoder import test_multimodal_content_mm_encoder + +assert mm_encoder_server is not None # keep 'mm_encoder_server' fixture visible in this module + @pytest.fixture(scope="module", ids=["Qwen2.5-VL-3B-Instruct"]) def model_name(): @@ -25,7 +37,7 @@ def model_name(): @pytest.fixture(scope="module") def temp_extra_llm_api_options_file(request): - temp_dir = tempfile.gettempdir() + temp_dir = tempfile.mkdtemp() temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml") try: extra_llm_api_options_dict = { @@ -123,6 +135,94 @@ def test_single_chat_session_image(client: openai.OpenAI, model_name: str): == chat_completion.choices[0].message.content +# used by mm_encoder_server 
+@pytest.fixture(scope="module") +def extra_encoder_options() -> bool: + return False + + +# used by mm_encoder_server +@pytest.fixture(scope="module") +def temp_extra_encoder_options_file() -> str: + return "/dummy/path" + + +@pytest.fixture(scope="module") +def server_patched(model_name: str, temp_extra_llm_api_options_file: str): + # Custom module implements missing 'attach_multimodal_embeddings' to intercept + # embeddings. + model_path = get_model_path(model_name) + args = [ + "--extra_llm_api_options", + temp_extra_llm_api_options_file, + "--max_batch_size", + "64", + "--max_num_tokens", + "16384", + "--custom_module_dirs", + str( + Path(sys.modules[test_single_chat_session_image_embeds.__module__]. + __file__).parent / "_attach_multimodal_embeddings_patch"), + ] + with RemoteOpenAIServer(model_path, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio(loop_scope="module") +def test_single_chat_session_image_embeds( + server_patched: RemoteOpenAIServer, + model_name: str, + mm_encoder_server: RemoteMMEncoderServer, +): + client = server_patched.get_client() + messages, mm_embed_handle = test_multimodal_content_mm_encoder( + mm_encoder_server.get_client(), model_name) + + max_completion_tokens = 10 + + chat_completion_image = client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=max_completion_tokens, + temperature=0.0, + logprobs=False) + + mm_embed = SharedTensorContainer.from_dict(mm_embed_handle).get_local_view() + with io.BytesIO() as buf: + torch.save(mm_embed, buf) + mm_embed_bytes = buf.getvalue() + + image_content = messages[0]["content"][1] + assert image_content["type"] == "image_url" + image_content.clear() + image_content["type"] = "image_embeds" + image_content["image_embeds"] = b64encode(mm_embed_bytes).decode("ascii") + + # test single completion + # + # FIXME: Remove try-except and use 'server' instead of 'server_patched', + # once Qwen2VLInputProcessorBase implements attach_multimodal_embeddings. 
+ try: + chat_completion_embeds = client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=max_completion_tokens, + temperature=0.0, + logprobs=False) + + assert chat_completion_embeds.choices[ + 0].message == chat_completion_image.choices[0].message + except openai.BadRequestError as e: + with open(Path(e.body["message"]), "rb") as f: + intercepted_embeddings = pickle.load(f) + assert list(intercepted_embeddings.keys()) == ["image"] + assert len(intercepted_embeddings["image"]) == 1 + torch.testing.assert_close(intercepted_embeddings["image"][0], + mm_embed.cpu()) + pytest.xfail( + reason="Model does not implement 'attach_multimodal_embeddings'") + + @pytest.mark.asyncio(loop_scope="module") def test_single_chat_session_multi_image(client: openai.OpenAI, model_name: str): diff --git a/tests/unittest/llmapi/apps/_test_openai_mmencoder.py b/tests/unittest/llmapi/apps/_test_openai_mmencoder.py index 312f9232d4..5a5df3fa71 100644 --- a/tests/unittest/llmapi/apps/_test_openai_mmencoder.py +++ b/tests/unittest/llmapi/apps/_test_openai_mmencoder.py @@ -1,5 +1,6 @@ import os import tempfile +from typing import Any, List import openai import pytest @@ -67,7 +68,9 @@ def async_client(server: RemoteMMEncoderServer): return server.get_async_client() -def test_multimodal_content_mm_encoder(client: openai.OpenAI, model_name: str): +def test_multimodal_content_mm_encoder( + client: openai.OpenAI, + model_name: str) -> tuple[list[dict[str, Any]], dict[str, Any]]: content_text = "Describe the natural environment in the image." image_url = str(llm_models_root() / "multimodals" / "test_data" / @@ -105,6 +108,8 @@ def test_multimodal_content_mm_encoder(client: openai.OpenAI, model_name: str): assert mm_handle["tensor_size"][ 1] == 2048 # qwen2.5-vl: hidden_size of the vision encoder + return messages, mm_handle # used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py + def test_health(server: RemoteMMEncoderServer): health_url = server.url_for("health") From 045331d494d08a9f3516632ce5df59733de210d7 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Fri, 5 Dec 2025 12:52:54 +0000 Subject: [PATCH 02/11] fix: run test_single_chat_session_image_embeds on L40S Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- cpp/kernels/fmha_v2/pytest.ini | 1 + tests/integration/defs/test_e2e.py | 10 +++++++--- tests/integration/test_lists/test-db/l0_l40s.yml | 1 + .../llmapi/apps/_test_openai_chat_multimodal.py | 1 + 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/cpp/kernels/fmha_v2/pytest.ini b/cpp/kernels/fmha_v2/pytest.ini index 1b7c950701..4ffcf349e9 100644 --- a/cpp/kernels/fmha_v2/pytest.ini +++ b/cpp/kernels/fmha_v2/pytest.ini @@ -6,6 +6,7 @@ markers = fmhca debug bench + needs_l40s # bin: unit tests # test: python script for invoking fmha.exe testpaths = bin test diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index a5acfabb43..75bed9553f 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1675,9 +1675,13 @@ def test_openai_lora(llm_root, llm_venv): def test_openai_chat_multimodal_example(llm_root, llm_venv): test_root = unittest_path() / "llmapi" / "apps" - llm_venv.run_cmd( - ["-m", "pytest", - str(test_root / "_test_openai_chat_multimodal.py")]) + llm_venv.run_cmd([ + "-m", + "pytest", + str(test_root / "_test_openai_chat_multimodal.py"), + "-m", + "not needs_l40s", + ]) def test_openai_mmencoder_example(llm_root, 
llm_venv): diff --git a/tests/integration/test_lists/test-db/l0_l40s.yml b/tests/integration/test_lists/test-db/l0_l40s.yml index c303789489..d10ba9fc2c 100644 --- a/tests/integration/test_lists/test-db/l0_l40s.yml +++ b/tests/integration/test_lists/test-db/l0_l40s.yml @@ -28,6 +28,7 @@ l0_l40s: - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-audio] - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image] - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio] + - unittest/llmapi/apps/_test_openai_chat_multimodal.py::test_single_chat_session_image_embeds -m needs_l40s # MMMU sanity check - accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype - accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py index 15d25f3656..5968c7681e 100644 --- a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py +++ b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py @@ -168,6 +168,7 @@ def server_patched(model_name: str, temp_extra_llm_api_options_file: str): yield remote_server +@pytest.mark.needs_l40s @pytest.mark.asyncio(loop_scope="module") def test_single_chat_session_image_embeds( server_patched: RemoteOpenAIServer, From b2a328c7068058ec462734061b98e6c0a3e9d4a1 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:01:39 +0000 Subject: [PATCH 03/11] add nested "data" Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tensorrt_llm/serve/chat_utils.py | 12 ++++++++++-- .../llmapi/apps/_test_openai_chat_multimodal.py | 4 +++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index 581615a201..3adee52b9e 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -34,9 +34,16 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False): type: Required[Literal["video_url"]] +class ImageEmbedsData(TypedDict): + """Type definition for serialized image embeddings structure.""" + data: Required[str] + + class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): """Type definition for image embeddings passed in base64-encoded PyTorch tensor format.""" - image_embeds: Required[str] + image_embeds: Required[ + # NB: Besides "data", could support "url" and "ipc_handle" in the future. 
+ ImageEmbedsData] type: Required[Literal["image_embeds"]] @@ -75,7 +82,8 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[ "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), "image_embeds": - lambda part: _ImageEmbedsParser(part).get("image_embeds", None), + lambda part: _ImageEmbedsParser(part).get("image_embeds", {}).get( + "data", None), } # Map from content part tags used to directly provide embeddings diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py index 5968c7681e..dd51407e0b 100644 --- a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py +++ b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py @@ -197,7 +197,9 @@ def test_single_chat_session_image_embeds( assert image_content["type"] == "image_url" image_content.clear() image_content["type"] = "image_embeds" - image_content["image_embeds"] = b64encode(mm_embed_bytes).decode("ascii") + image_content["image_embeds"] = { + "data": b64encode(mm_embed_bytes).decode("ascii") + } # test single completion # From bebc2e4317edf84f939ccdf252a73344d181eac4 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:33:59 +0000 Subject: [PATCH 04/11] do not run 'test_multimodal_content_mm_encoder' twice Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py index dd51407e0b..d697b24f9b 100644 --- a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py +++ b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py @@ -25,7 +25,8 @@ from utils.llm_data import llm_models_root from ._test_openai_mmencoder import RemoteMMEncoderServer from ._test_openai_mmencoder import server as mm_encoder_server -from ._test_openai_mmencoder import test_multimodal_content_mm_encoder +from ._test_openai_mmencoder import \ + test_multimodal_content_mm_encoder as _test_multimodal_content_mm_encoder assert mm_encoder_server is not None # keep 'mm_encoder_server' fixture visible in this module @@ -176,7 +177,7 @@ def test_single_chat_session_image_embeds( mm_encoder_server: RemoteMMEncoderServer, ): client = server_patched.get_client() - messages, mm_embed_handle = test_multimodal_content_mm_encoder( + messages, mm_embed_handle = _test_multimodal_content_mm_encoder( mm_encoder_server.get_client(), model_name) max_completion_tokens = 10 From 0543bf01fbd43b6bf2db350a45951b1420d29bde Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Mon, 8 Dec 2025 17:02:45 +0000 Subject: [PATCH 05/11] fix: update docs Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- docs/source/commands/trtllm-serve/trtllm-serve.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/commands/trtllm-serve/trtllm-serve.rst b/docs/source/commands/trtllm-serve/trtllm-serve.rst index a7aa9b803d..c73e903e6c 100644 --- a/docs/source/commands/trtllm-serve/trtllm-serve.rst +++ b/docs/source/commands/trtllm-serve/trtllm-serve.rst @@ -181,7 +181,7 @@ model. 
{"role": "user", "content": [ {"type": "text", "text": "What's in this image?"}, - {"type": "image_embeds", "image_embeds": "{image_embeddings_base64}"}} + {"type": "image_embeds", "image_embeds": {"data": "{image_embeddings_base64}"}}} ]} .. note:: From db14542c35e7e3f93efd7bb3d185e70732e50f6c Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Thu, 11 Dec 2025 12:58:03 +0000 Subject: [PATCH 06/11] chore: add is_embedding to MultimodalData Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tensorrt_llm/evaluate/lm_eval.py | 2 +- tensorrt_llm/inputs/utils.py | 148 ++++++++++++++++------------ tensorrt_llm/serve/chat_utils.py | 33 ++++--- tensorrt_llm/serve/openai_server.py | 27 +++-- 4 files changed, 115 insertions(+), 95 deletions(-) diff --git a/tensorrt_llm/evaluate/lm_eval.py b/tensorrt_llm/evaluate/lm_eval.py index 4a877d75f4..cf9b7f3007 100644 --- a/tensorrt_llm/evaluate/lm_eval.py +++ b/tensorrt_llm/evaluate/lm_eval.py @@ -244,7 +244,7 @@ class MultimodalLmEvalWrapper(LmEvalWrapper): # NOTE: Since we already have loaded images, for the placeholder purpose, we add data here. for _ in range(image_count): - mm_data_tracker.add_data("image", None) + mm_data_tracker.add_data("image", None, is_embedding=False) mm_placeholder_count = mm_data_tracker.placeholder_counts() if mm_placeholder_count: # TODO: This is an assumption of not interleaving text and image. Need to extend to interleaved texts. diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index b7832a3084..27be5a6ecc 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -434,13 +434,14 @@ class MultimodalData(TypedDict): """Type definition for multimodal data structure.""" modality: str data: Any + is_embedding: bool class ConversationMessage(TypedDict): """Type definition for conversation message structure.""" role: str content: List[dict[str, Any]] - media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]] + media: List[MultimodalData] # @classmethod # def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage": @@ -455,38 +456,54 @@ class MultimodalDataTracker: model_type: str, multimodal_server_config: Optional[MultimodalServerConfig] = None): self._model_type = model_type - self._data = defaultdict[str](list) - self._placeholder_counts = defaultdict[str](int) + self._data = defaultdict[str, list](list) + self._embeddings = defaultdict[str, list](list) + self._placeholder_counts = defaultdict[str, int](int) self._multimodal_server_config = multimodal_server_config if multimodal_server_config is not None else MultimodalServerConfig( ) - async def retrieve_all_async(self) -> Optional[Dict[str, List[Any]]]: - """Retrieve all collected multimodal data.""" - if not self._data: - return None + async def retrieve_all_async( + self + ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]: + """Retrieve all collected multimodal data and embeddings.""" - return { - modality: await asyncio.gather(*items) - for modality, items in self._data.items() - } + async def _retrieve( + data: Optional[dict[str, + list]]) -> Optional[Dict[str, List[Any]]]: + if not data: + return None + return { + modality: await asyncio.gather(*items) + for modality, items in data.items() if items + } - def retrieve_all_sync(self) -> Optional[Dict[str, List[Any]]]: - """Retrieve all collected multimodal data.""" - if not self._data: - return None + return await _retrieve(self._data), await _retrieve(self._embeddings) - 
return {modality: items for modality, items in self._data.items()} + def retrieve_all_sync( + self + ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]: + """Retrieve all collected multimodal data and embeddings.""" - def add_data(self, - media_type: str, - data: Union[Coroutine, Any], - *, - modality: Optional[str] = None): - modality = modality or media_type - current_count = len(self._data[media_type]) + 1 + def _retrieve( + data: Optional[dict[str, + list]]) -> Optional[Dict[str, List[Any]]]: + if not data: + return None + return { + modality: items + for modality, items in data.items() if items + } + + return _retrieve(self._data), _retrieve(self._embeddings) + + def add_data(self, media_type: str, data: Union[Coroutine, Any], *, + is_embedding: bool): + current_count = len(self._data[media_type]) + len( + self._embeddings[media_type]) + 1 placeholder = retrieve_multimodal_placeholder(self._model_type, - modality, current_count) - self._data[media_type].append(data) + media_type, current_count) + (self._embeddings + if is_embedding else self._data)[media_type].append(data) if placeholder: self._placeholder_counts[placeholder] += 1 @@ -657,33 +674,34 @@ def default_multimodal_input_loader( media = [media] if modality in ["image", "multiple_image"]: if is_embedding: + _load = lambda mm: mm + # each mm_embedding corresponds to each image placeholder if not isinstance(media, list): media = [media] - - mm_data = [{ - 'modality': modality, - 'mm_embedding_info': mm - } for mm in media] else: - mm_data = [ - MultimodalData(modality=modality, - data=load_image(i, - format=image_data_format, - device=device)) - for i in media - ] + _load = lambda mm: load_image( + mm, format=image_data_format, device=device) + + mm_data = [ + MultimodalData(modality=modality, + data=_load(mm), + is_embedding=is_embedding) for mm in media + ] elif modality == "video": if is_embedding: raise ValueError( "External embedding is not supported for video modality yet." ) mm_data = [ - MultimodalData(modality=modality, - data=load_video(i, - num_frames, - format=image_data_format, - device=device)) for i in media + MultimodalData( + modality=modality, + data=load_video(i, + num_frames, + format=image_data_format, + device=device), + is_embedding=False, + ) for i in media ] elif modality == "audio": if is_embedding: @@ -691,8 +709,11 @@ def default_multimodal_input_loader( "External embedding is not supported for audio modality yet." 
) mm_data = [ - MultimodalData(modality=modality, - data=load_audio(i, device=device)) for i in media + MultimodalData( + modality=modality, + data=load_audio(i, device=device), + is_embedding=False, + ) for i in media ] elif modality == "image_audio": if is_embedding: @@ -720,16 +741,22 @@ def default_multimodal_input_loader( pass if _modal is None: raise ValueError(f"Unknown matching modality: {modality}") - mm_data.append(MultimodalData(modality=_modal, data=data)) + mm_data.append( + MultimodalData(modality=_modal, + data=data, + is_embedding=False)) elif modality == "mixture_text_image": mm_data = [] for m in media: if m: mm_data.append( - MultimodalData(modality="image", - data=load_image(m, - format=image_data_format, - device=device))) + MultimodalData( + modality="image", + data=load_image(m, + format=image_data_format, + device=device), + is_embedding=False, + )) else: raise ValueError(f"Unknown modality: {modality}") return ConversationMessage(role="user", content=prompt, media=mm_data) @@ -763,17 +790,12 @@ def default_multimodal_input_loader( is_embedding) mm_data_tracker = MultimodalDataTracker(model_type) for mdata in conv["media"]: - # Check if mdata is a MultimodalData - if isinstance(mdata, - dict) and "modality" in mdata and "data" in mdata: - mdata_modality = mdata["modality"] - if modality == "multiple_image": - mdata_modality = "image" - mm_data_tracker.add_data(mdata_modality, mdata["data"]) - else: - # Add embeddings to the tracker for placeholder handling - mm_data_tracker.add_data(mdata["modality"], - mdata["mm_embedding_info"]) + mdata_modality = mdata["modality"] + if modality == "multiple_image": + mdata_modality = "image" + mm_data_tracker.add_data(mdata_modality, + mdata["data"], + is_embedding=is_embedding) mm_placeholder_counts = mm_data_tracker.placeholder_counts() prompt = conv["content"] if mm_placeholder_counts: @@ -790,11 +812,13 @@ def default_multimodal_input_loader( if mm_placeholder_counts: if mm_embeddings is not None: - input[ + _, input[ "multi_modal_embeddings"] = mm_data_tracker.retrieve_all_sync( ) else: - input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync() + input[ + "multi_modal_data"], _ = mm_data_tracker.retrieve_all_sync( + ) inputs.append(input) return inputs diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index 3adee52b9e..bdd0324c2a 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -86,12 +86,6 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[ "data", None), } -# Map from content part tags used to directly provide embeddings -# to the corresponding data modality. 
-MM_EMBEDDING_MAP: dict[str, str] = { - "image_embeds": "image", -} - def _parse_chat_message_content_mm_part( part: ChatCompletionContentPartParam @@ -111,7 +105,7 @@ def _parse_chat_message_content_mm_part( def parse_chat_message_content_part( part: ChatCompletionContentPartParam, mm_data_tracker: MultimodalDataTracker, -) -> Optional[Any]: +) -> str | MultimodalData | None: """Parse a single part of a chat message.""" if isinstance(part, str): return part @@ -141,7 +135,9 @@ def parse_chat_message_content_part( logger.error(f"Failed to load image: {str(e)}") return None - return MultimodalData(modality="image", data=load_image_async()) + return MultimodalData(modality="image", + data=load_image_async(), + is_embedding=False) if part_type == "image_embeds": str_content = cast(str, content) @@ -153,8 +149,9 @@ def parse_chat_message_content_part( logger.error(f"Failed to decode image data: {str(e)}") return None - return MultimodalData(modality="image_embeds", - data=decode_image_embeds_async()) + return MultimodalData(modality="image", + data=decode_image_embeds_async(), + is_embedding=True) if part_type == "video_url": str_content = cast(str, content) @@ -169,7 +166,9 @@ def parse_chat_message_content_part( logger.error(f"Failed to load video: {str(e)}") return None - return MultimodalData(modality="video", data=load_video_async()) + return MultimodalData(modality="video", + data=load_video_async(), + is_embedding=False) if part_type == "audio_url": str_content = cast(str, content) @@ -184,7 +183,9 @@ def parse_chat_message_content_part( logger.error(f"Failed to load audio: {str(e)}") return None - return MultimodalData(modality="audio", data=load_audio_async()) + return MultimodalData(modality="audio", + data=load_audio_async(), + is_embedding=False) raise NotImplementedError(f"Unknown part type: {part_type}") @@ -268,8 +269,9 @@ def parse_chat_messages_coroutines( messages: List[ChatCompletionMessageParam], model_config: AutoConfig, multimodal_server_config: Optional[MultimodalServerConfig] = None -) -> Tuple[List[ConversationMessage], Optional[Coroutine[ - Any, Any, Optional[Dict[str, List[Any]]]]]]: +) -> Tuple[List[ConversationMessage], Coroutine[Any, Any, tuple[Optional[Dict[ + str, List[Any]]], Optional[Dict[str, List[Any]]]]], list[dict[str, + int]]]: """Parse multiple chat messages and return conversation and coroutine.""" conversation = [] mm_placeholder_counts = [] @@ -283,8 +285,7 @@ def parse_chat_messages_coroutines( for mdata in parsed_msg["media"]: mm_data_tracker.add_data(mdata["modality"], mdata["data"], - modality=MM_EMBEDDING_MAP.get( - mdata["modality"], None)) + is_embedding=mdata["is_embedding"]) mm_placeholder_count = mm_data_tracker.placeholder_counts() if mm_placeholder_count: parsed_msg["content"] = add_multimodal_placeholders( diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index be72099ffd..108714190f 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -36,7 +36,7 @@ from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig, from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.logger import logger from tensorrt_llm.metrics.collector import MetricsCollector -from tensorrt_llm.serve.chat_utils import (MM_EMBEDDING_MAP, load_chat_template, +from tensorrt_llm.serve.chat_utils import (load_chat_template, parse_chat_messages_coroutines) from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client from tensorrt_llm.serve.disagg_auto_scaling import 
DisaggClusterWorker @@ -556,20 +556,13 @@ class OpenAIServer: ) prompt = prompt_inputs(prompt) - mm_data = await mm_coroutines - if mm_data is not None: - # single out directly provided embeddings - mm_embeds = {} - for tag in list(mm_data.keys()): - if (modality := MM_EMBEDDING_MAP.get(tag, None)) is not None: - mm_embeds[modality] = mm_data.pop(tag) - - if mm_data: - prompt["multi_modal_data"] = mm_data - if mm_embeds: - prompt["multi_modal_embeddings"] = mm_embeds - if mm_data and mm_embeds: - raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.") + mm_data, mm_embeddings = await mm_coroutines + if mm_data: + prompt["multi_modal_data"] = mm_data + if mm_embeddings: + prompt["multi_modal_embeddings"] = mm_embeddings + if mm_data and mm_embeddings: + raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.") postproc_args.reasoning_parser = self.llm.args.reasoning_parser postproc_args.tool_parser = self.tool_parser @@ -670,7 +663,9 @@ class OpenAIServer: ) prompt = prompt_inputs(prompt) - mm_data = await mm_coroutines + mm_data, mm_embeddings = await mm_coroutines + if mm_embeddings: + raise ValueError("Cannot use multimodal embeddings as input") if mm_data is not None: prompt["multi_modal_data"] = mm_data From e42a8f7d641c30d5a2454b824c006a3901fc4fee Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Thu, 11 Dec 2025 12:58:20 +0000 Subject: [PATCH 07/11] chore: use torch.save/load Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- .../apps/_attach_multimodal_embeddings_patch/__init__.py | 3 +-- tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py b/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py index 749481fea1..5662063aab 100644 --- a/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py +++ b/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py @@ -14,7 +14,6 @@ # used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py -import pickle import tempfile from pathlib import Path from typing import Optional @@ -36,7 +35,7 @@ def _attach_multimodal_embeddings( tempdir = tempfile.gettempdir() file_path = Path(tempdir) / "multimodal_embedding.pickle" with open(file_path, "wb") as f: - pickle.dump(multimodal_embedding, f) + torch.save(multimodal_embedding, f) raise ValueError(file_path) diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py index d697b24f9b..4183e1874e 100644 --- a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py +++ b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py @@ -1,6 +1,5 @@ import io import os -import pickle import sys import tempfile from base64 import b64encode @@ -217,8 +216,9 @@ def test_single_chat_session_image_embeds( assert chat_completion_embeds.choices[ 0].message == chat_completion_image.choices[0].message except openai.BadRequestError as e: + assert isinstance(e.body, dict) with open(Path(e.body["message"]), "rb") as f: - intercepted_embeddings = pickle.load(f) + intercepted_embeddings = torch.load(f, weights_only=True) assert list(intercepted_embeddings.keys()) == ["image"] assert len(intercepted_embeddings["image"]) == 1 
torch.testing.assert_close(intercepted_embeddings["image"][0], From e4233671a934f0d48feb4036ef9f040519a0cea0 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Thu, 11 Dec 2025 13:02:16 +0000 Subject: [PATCH 08/11] address remaining review comments Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tensorrt_llm/serve/chat_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index bdd0324c2a..e08caadaaf 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -42,7 +42,7 @@ class ImageEmbedsData(TypedDict): class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): """Type definition for image embeddings passed in base64-encoded PyTorch tensor format.""" image_embeds: Required[ - # NB: Besides "data", could support "url" and "ipc_handle" in the future. + # TODO: Besides "data", could support "url" and "ipc_handle" in the future. ImageEmbedsData] type: Required[Literal["image_embeds"]] From eb9a24e6f057c0dbf7b7e8b204949f9adc4b8130 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Fri, 12 Dec 2025 16:59:08 +0000 Subject: [PATCH 09/11] chore: refine MultimodalDataTracker.add_data Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tensorrt_llm/evaluate/lm_eval.py | 2 +- tensorrt_llm/inputs/utils.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/evaluate/lm_eval.py b/tensorrt_llm/evaluate/lm_eval.py index cf9b7f3007..4a877d75f4 100644 --- a/tensorrt_llm/evaluate/lm_eval.py +++ b/tensorrt_llm/evaluate/lm_eval.py @@ -244,7 +244,7 @@ class MultimodalLmEvalWrapper(LmEvalWrapper): # NOTE: Since we already have loaded images, for the placeholder purpose, we add data here. for _ in range(image_count): - mm_data_tracker.add_data("image", None, is_embedding=False) + mm_data_tracker.add_data("image", None) mm_placeholder_count = mm_data_tracker.placeholder_counts() if mm_placeholder_count: # TODO: This is an assumption of not interleaving text and image. Need to extend to interleaved texts. 
diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 27be5a6ecc..0dc09547b3 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -496,8 +496,11 @@ class MultimodalDataTracker: return _retrieve(self._data), _retrieve(self._embeddings) - def add_data(self, media_type: str, data: Union[Coroutine, Any], *, - is_embedding: bool): + def add_data(self, + media_type: str, + data: Union[Coroutine, Any], + *, + is_embedding: bool = False): current_count = len(self._data[media_type]) + len( self._embeddings[media_type]) + 1 placeholder = retrieve_multimodal_placeholder(self._model_type, From 240eff4bd865b03258e57b935104ed8b7b150600 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Sat, 13 Dec 2025 16:27:24 +0000 Subject: [PATCH 10/11] fix: conform to upstream Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- .../__init__.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py b/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py index 5662063aab..7d8281ecd2 100644 --- a/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py +++ b/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py @@ -24,6 +24,8 @@ from tensorrt_llm._torch.models.modeling_qwen2vl import Qwen2VLInputProcessorBas from tensorrt_llm.inputs import ExtraProcessedInputs, TextPrompt from tensorrt_llm.sampling_params import SamplingParams +_attach_multimodal_embeddings_orig = Qwen2VLInputProcessorBase.attach_multimodal_embeddings + # signature taken from tensorrt_llm/inputs/registry.py def _attach_multimodal_embeddings( @@ -32,6 +34,15 @@ def _attach_multimodal_embeddings( multimodal_embedding: dict[str, list[torch.Tensor]], sampling_params: SamplingParams, ) -> tuple[list[int], Optional[ExtraProcessedInputs]]: + try: + _attach_multimodal_embeddings_orig(self, inputs, multimodal_embedding, sampling_params) + except NotImplementedError: + pass + else: + raise ValueError( + "Remove this custom module, Qwen2VLInputProcessorBase implements attach_multimodal_embeddings" + ) + tempdir = tempfile.gettempdir() file_path = Path(tempdir) / "multimodal_embedding.pickle" with open(file_path, "wb") as f: @@ -39,5 +50,4 @@ def _attach_multimodal_embeddings( raise ValueError(file_path) -assert not hasattr(Qwen2VLInputProcessorBase, "attach_multimodal_embeddings") -setattr(Qwen2VLInputProcessorBase, "attach_multimodal_embeddings", _attach_multimodal_embeddings) +Qwen2VLInputProcessorBase.attach_multimodal_embeddings = _attach_multimodal_embeddings From 01cf98132a3d03183de0eea01a02c96bc59f5c7e Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Tue, 6 Jan 2026 16:59:35 +0100 Subject: [PATCH 11/11] fix: remove unused import Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tests/unittest/llmapi/apps/_test_openai_mmencoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest/llmapi/apps/_test_openai_mmencoder.py b/tests/unittest/llmapi/apps/_test_openai_mmencoder.py index 5a5df3fa71..483f9ad994 100644 --- a/tests/unittest/llmapi/apps/_test_openai_mmencoder.py +++ b/tests/unittest/llmapi/apps/_test_openai_mmencoder.py @@ -1,6 +1,6 @@ import os import tempfile -from typing import Any, List +from typing import Any import openai import pytest
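
Usage sketch for the "image_embeds" content part introduced in this series: a
minimal client-side example, assuming trtllm-serve is reachable at
http://localhost:8000/v1 and serves a multimodal model whose input processor
implements attach_multimodal_embeddings (per the FIXME above, Qwen2VL-based
models do not yet, hence the patched test server). The model name and the
embedding tensor below are placeholders; in practice, use the name the server
was started with and an embedding produced by a matching vision encoder (for
example, obtained via the multimodal encoder endpoint exercised in
_test_openai_mmencoder.py).

    import io
    from base64 import b64encode

    import openai
    import torch

    # Placeholder embedding; replace with real vision-encoder output for the model.
    image_embedding = torch.randn(256, 2048)

    # Serialize with torch.save and base64-encode, as described in the docs added above.
    with io.BytesIO() as buf:
        torch.save(image_embedding, buf)
        embeds_b64 = b64encode(buf.getvalue()).decode("ascii")

    model_name = "Qwen2.5-VL-3B-Instruct"  # placeholder; use the served model's name
    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
    response = client.chat.completions.create(
        model=model_name,
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                # Nested "data" key per the final payload format (PATCH 03/05).
                {"type": "image_embeds", "image_embeds": {"data": embeds_b64}},
            ],
        }],
        max_completion_tokens=64,
        temperature=0.0,
    )
    print(response.choices[0].message.content)

On the server side, the base64 payload is decoded by load_base64_image_embeds
(base64 decode followed by torch.load with weights_only=True, mapped to CPU)
and forwarded to the model as "multi_modal_embeddings" rather than
"multi_modal_data"; passing both in the same request is rejected.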