diff --git a/cpp/kernels/fmha_v2/pytest.ini b/cpp/kernels/fmha_v2/pytest.ini
index 1b7c950701..4ffcf349e9 100644
--- a/cpp/kernels/fmha_v2/pytest.ini
+++ b/cpp/kernels/fmha_v2/pytest.ini
@@ -6,6 +6,7 @@ markers =
     fmhca
     debug
     bench
+    needs_l40s
 # bin: unit tests
 # test: python script for invoking fmha.exe
 testpaths = bin test
diff --git a/docs/source/commands/trtllm-serve/trtllm-serve.rst b/docs/source/commands/trtllm-serve/trtllm-serve.rst
index b26e45de92..c73e903e6c 100644
--- a/docs/source/commands/trtllm-serve/trtllm-serve.rst
+++ b/docs/source/commands/trtllm-serve/trtllm-serve.rst
@@ -170,6 +170,24 @@ TRT-LLM multimodal supports the following modalities and data types (depending o
   `load_base64_image utility `__ for implementation details.
 
+**Image embeddings**
+
+It is also possible to directly provide the image embeddings to be used by the
+multimodal model.
+
+* Using "image_embeds" with base64-encoded data:
+
+  .. code-block:: json
+
+     {"role": "user", "content": [
+         {"type": "text", "text": "What's in this image?"},
+         {"type": "image_embeds", "image_embeds": {"data": "{image_embeddings_base64}"}}
+     ]}
+
+.. note::
+   The contents of `image_embeddings_base64` can be generated by base64-encoding
+   the result of serializing a tensor via `torch.save`.
+
 **Video**
 
 * Using "video_url":
diff --git a/tensorrt_llm/inputs/__init__.py b/tensorrt_llm/inputs/__init__.py
index 406a71d4f5..4587d39f03 100644
--- a/tensorrt_llm/inputs/__init__.py
+++ b/tensorrt_llm/inputs/__init__.py
@@ -16,7 +16,8 @@ from .utils import (ALL_SUPPORTED_AUDIO_MODELS, ALL_SUPPORTED_IMAGE_MODELS,
                     async_load_audio, async_load_image, async_load_video,
                     convert_image_mode, default_multimodal_input_loader,
                     encode_base64_content_from_url, encode_base64_image,
-                    get_cache_salt_id, load_image, load_video)
+                    get_cache_salt_id, load_base64_image_embeds, load_image,
+                    load_video)
 
 __all__ = [
     "ALL_SUPPORTED_MULTIMODAL_MODELS",
@@ -57,4 +58,5 @@ __all__ = [
     "get_cache_salt_id",
     "compute_retained_tokens_count",
     "compute_retention_mask",
+    "load_base64_image_embeds",
 ]
diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py
index bbbd5f4f8f..0dc09547b3 100644
--- a/tensorrt_llm/inputs/utils.py
+++ b/tensorrt_llm/inputs/utils.py
@@ -114,6 +114,15 @@ def load_base64_image(parsed_url: str) -> Image.Image:
     return image
 
 
+def load_base64_image_embeds(str_content: str) -> torch.Tensor:
+    content_bytes = base64.b64decode(str_content)
+    with BytesIO(content_bytes) as buf:
+        image_data: torch.Tensor = torch.load(buf,
+                                              weights_only=True,
+                                              map_location="cpu")
+    return image_data
+
+
 def load_image(image: Union[str, Image.Image],
                format: str = "pt",
                device: str = "cpu") -> Union[Image.Image, torch.Tensor]:
@@ -425,13 +434,14 @@ class MultimodalData(TypedDict):
     """Type definition for multimodal data structure."""
     modality: str
     data: Any
+    is_embedding: bool
 
 
 class ConversationMessage(TypedDict):
     """Type definition for conversation message structure."""
     role: str
     content: List[dict[str, Any]]
-    media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]]
+    media: List[MultimodalData]
 
     # @classmethod
     # def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage":
@@ -446,33 +456,57 @@ class MultimodalDataTracker:
                  model_type: str,
                  multimodal_server_config: Optional[MultimodalServerConfig] = None):
         self._model_type = model_type
-        self._data = defaultdict[str](list)
-        self._placeholder_counts = defaultdict[str](int)
+        self._data = defaultdict[str, list](list)
+        self._embeddings = defaultdict[str, list](list)
+        self._placeholder_counts = defaultdict[str, int](int)
         self._multimodal_server_config = multimodal_server_config if multimodal_server_config is not None else MultimodalServerConfig(
         )
 
-    async def retrieve_all_async(self) -> Optional[Dict[str, List[Any]]]:
-        """Retrieve all collected multimodal data."""
-        if not self._data:
-            return None
+    async def retrieve_all_async(
+        self
+    ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
+        """Retrieve all collected multimodal data and embeddings."""
 
-        return {
-            modality: await asyncio.gather(*items)
-            for modality, items in self._data.items()
-        }
+        async def _retrieve(
+                data: Optional[dict[str,
+                                    list]]) -> Optional[Dict[str, List[Any]]]:
+            if not data:
+                return None
+            return {
+                modality: await asyncio.gather(*items)
+                for modality, items in data.items() if items
+            }
 
-    def retrieve_all_sync(self) -> Optional[Dict[str, List[Any]]]:
-        """Retrieve all collected multimodal data."""
-        if not self._data:
-            return None
+        return await _retrieve(self._data), await _retrieve(self._embeddings)
 
-        return {modality: items for modality, items in self._data.items()}
+    def retrieve_all_sync(
+        self
+    ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
+        """Retrieve all collected multimodal data and embeddings."""
 
-    def add_data(self, media_type: str, data: Union[Coroutine, Any]):
-        current_count = len(self._data[media_type]) + 1
+        def _retrieve(
+                data: Optional[dict[str,
+                                    list]]) -> Optional[Dict[str, List[Any]]]:
+            if not data:
+                return None
+            return {
+                modality: items
+                for modality, items in data.items() if items
+            }
+
+        return _retrieve(self._data), _retrieve(self._embeddings)
+
+    def add_data(self,
+                 media_type: str,
+                 data: Union[Coroutine, Any],
+                 *,
+                 is_embedding: bool = False):
+        current_count = len(self._data[media_type]) + len(
+            self._embeddings[media_type]) + 1
         placeholder = retrieve_multimodal_placeholder(self._model_type,
                                                       media_type, current_count)
-        self._data[media_type].append(data)
+        (self._embeddings
+         if is_embedding else self._data)[media_type].append(data)
         if placeholder:
             self._placeholder_counts[placeholder] += 1
@@ -643,33 +677,34 @@ def default_multimodal_input_loader(
             media = [media]
         if modality in ["image", "multiple_image"]:
             if is_embedding:
+                _load = lambda mm: mm
+                # each mm_embedding corresponds to each image placeholder
                 if not isinstance(media, list):
                     media = [media]
-
-                mm_data = [{
-                    'modality': modality,
-                    'mm_embedding_info': mm
-                } for mm in media]
             else:
-                mm_data = [
-                    MultimodalData(modality=modality,
-                                   data=load_image(i,
-                                                   format=image_data_format,
-                                                   device=device))
-                    for i in media
-                ]
+                _load = lambda mm: load_image(
+                    mm, format=image_data_format, device=device)
+
+            mm_data = [
+                MultimodalData(modality=modality,
+                               data=_load(mm),
+                               is_embedding=is_embedding) for mm in media
+            ]
         elif modality == "video":
             if is_embedding:
                 raise ValueError(
                     "External embedding is not supported for video modality yet."
                 )
             mm_data = [
-                MultimodalData(modality=modality,
-                               data=load_video(i,
-                                               num_frames,
-                                               format=image_data_format,
-                                               device=device)) for i in media
+                MultimodalData(
+                    modality=modality,
+                    data=load_video(i,
+                                    num_frames,
+                                    format=image_data_format,
+                                    device=device),
+                    is_embedding=False,
+                ) for i in media
             ]
         elif modality == "audio":
             if is_embedding:
                 raise ValueError(
@@ -677,8 +712,11 @@ def default_multimodal_input_loader(
                     "External embedding is not supported for audio modality yet."
                 )
             mm_data = [
-                MultimodalData(modality=modality,
-                               data=load_audio(i, device=device)) for i in media
+                MultimodalData(
+                    modality=modality,
+                    data=load_audio(i, device=device),
+                    is_embedding=False,
+                ) for i in media
             ]
         elif modality == "image_audio":
             if is_embedding:
@@ -706,16 +744,22 @@ def default_multimodal_input_loader(
                    pass
                if _modal is None:
                    raise ValueError(f"Unknown matching modality: {modality}")
-                mm_data.append(MultimodalData(modality=_modal, data=data))
+                mm_data.append(
+                    MultimodalData(modality=_modal,
+                                   data=data,
+                                   is_embedding=False))
         elif modality == "mixture_text_image":
             mm_data = []
             for m in media:
                 if m:
                     mm_data.append(
-                        MultimodalData(modality="image",
-                                       data=load_image(m,
-                                                       format=image_data_format,
-                                                       device=device)))
+                        MultimodalData(
+                            modality="image",
+                            data=load_image(m,
+                                            format=image_data_format,
+                                            device=device),
+                            is_embedding=False,
+                        ))
         else:
             raise ValueError(f"Unknown modality: {modality}")
         return ConversationMessage(role="user", content=prompt, media=mm_data)
@@ -749,17 +793,12 @@ def default_multimodal_input_loader(
                                          is_embedding)
         mm_data_tracker = MultimodalDataTracker(model_type)
         for mdata in conv["media"]:
-            # Check if mdata is a MultimodalData
-            if isinstance(mdata,
-                          dict) and "modality" in mdata and "data" in mdata:
-                mdata_modality = mdata["modality"]
-                if modality == "multiple_image":
-                    mdata_modality = "image"
-                mm_data_tracker.add_data(mdata_modality, mdata["data"])
-            else:
-                # Add embeddings to the tracker for placeholder handling
-                mm_data_tracker.add_data(mdata["modality"],
-                                         mdata["mm_embedding_info"])
+            mdata_modality = mdata["modality"]
+            if modality == "multiple_image":
+                mdata_modality = "image"
+            mm_data_tracker.add_data(mdata_modality,
+                                     mdata["data"],
+                                     is_embedding=is_embedding)
         mm_placeholder_counts = mm_data_tracker.placeholder_counts()
         prompt = conv["content"]
         if mm_placeholder_counts:
@@ -776,11 +815,13 @@ def default_multimodal_input_loader(
 
         if mm_placeholder_counts:
             if mm_embeddings is not None:
-                input[
+                _, input[
                     "multi_modal_embeddings"] = mm_data_tracker.retrieve_all_sync(
                     )
             else:
-                input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync()
+                input[
+                    "multi_modal_data"], _ = mm_data_tracker.retrieve_all_sync(
+                    )
 
         inputs.append(input)
     return inputs
diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py
index 26ee17c4f4..e08caadaaf 100644
--- a/tensorrt_llm/serve/chat_utils.py
+++ b/tensorrt_llm/serve/chat_utils.py
@@ -17,7 +17,8 @@ from typing_extensions import Required
 from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
                                  MultimodalDataTracker,
                                  add_multimodal_placeholders, async_load_audio,
-                                 async_load_image, async_load_video)
+                                 async_load_image, async_load_video,
+                                 load_base64_image_embeds)
 from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
 from tensorrt_llm.logger import logger
 
@@ -33,24 +34,45 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
     type: Required[Literal["video_url"]]
 
 
+class ImageEmbedsData(TypedDict):
+    """Type definition for serialized image embeddings structure."""
+    data: Required[str]
+
+
+class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
+    """Type definition for image embeddings passed in base64-encoded PyTorch tensor format."""
+    image_embeds: Required[
+        # TODO: Besides "data", could support "url" and "ipc_handle" in the future.
+        ImageEmbedsData]
+    type: Required[Literal["image_embeds"]]
+
+
 # Type Aliases and Constants
 ChatCompletionContentPartParam: TypeAlias = Union[
-    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
-    str]
+    OpenAIChatCompletionContentPartParam,
+    ChatCompletionContentPartVideoParam,
+    ChatCompletionContentPartImageEmbedsParam,
+    str,
+]
 
 # TODO: Add "input_audio" to support byte_encoded audio input.
 VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
-    "text", "image_url", "video_url", "audio_url"
+    "text",
+    "image_url",
+    "video_url",
+    "audio_url",
+    "image_embeds",
 ]
 
 # Parser Functions
 _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
+_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
 _VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
 _AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 
 MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
-    str, dict[str, str]]]] = {
+    str, dict[str, str], None]]] = {
         "text":
         lambda part: _TextParser(part).get("text", None),
         "image_url":
@@ -59,12 +81,15 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
         lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
         "audio_url":
         lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
+        "image_embeds":
+        lambda part: _ImageEmbedsParser(part).get("image_embeds", {}).get(
+            "data", None),
     }
 
 
 def _parse_chat_message_content_mm_part(
     part: ChatCompletionContentPartParam
-) -> tuple[str, Union[str, dict[str, str]]]:
+) -> tuple[str, Union[str, dict[str, str], None]]:
     """Parse a single multimodal part of a chat message."""
     assert isinstance(part, dict)
     part_type = part.get("type", None)
@@ -78,9 +103,9 @@ def _parse_chat_message_content_mm_part(
 
 
 def parse_chat_message_content_part(
-    part: ChatCompletionMessageParam,
+    part: ChatCompletionContentPartParam,
     mm_data_tracker: MultimodalDataTracker,
-) -> Optional[Any]:
+) -> str | MultimodalData | None:
     """Parse a single part of a chat message."""
     if isinstance(part, str):
         return part
@@ -110,7 +135,23 @@ def parse_chat_message_content_part(
                 logger.error(f"Failed to load image: {str(e)}")
                 return None
 
-        return MultimodalData(modality="image", data=load_image_async())
+        return MultimodalData(modality="image",
+                              data=load_image_async(),
+                              is_embedding=False)
+
+    if part_type == "image_embeds":
+        str_content = cast(str, content)
+
+        async def decode_image_embeds_async():
+            try:
+                return load_base64_image_embeds(str_content)
+            except Exception as e:
+                logger.error(f"Failed to decode image data: {str(e)}")
+                return None
+
+        return MultimodalData(modality="image",
+                              data=decode_image_embeds_async(),
+                              is_embedding=True)
 
     if part_type == "video_url":
         str_content = cast(str, content)
@@ -125,7 +166,9 @@
                 logger.error(f"Failed to load video: {str(e)}")
                 return None
 
-        return MultimodalData(modality="video", data=load_video_async())
+        return MultimodalData(modality="video",
+                              data=load_video_async(),
+                              is_embedding=False)
 
     if part_type == "audio_url":
         str_content = cast(str, content)
@@ -140,14 +183,16 @@
                 logger.error(f"Failed to load audio: {str(e)}")
                 return None
 
-        return MultimodalData(modality="audio", data=load_audio_async())
+        return MultimodalData(modality="audio",
+                              data=load_audio_async(),
+                              is_embedding=False)
 
     raise NotImplementedError(f"Unknown part type: {part_type}")
 
 
 def parse_chat_message_content_parts(
     role: str,
-    parts: Iterable[ChatCompletionMessageParam],
+    parts: Iterable[ChatCompletionContentPartParam],
     mm_data_tracker: MultimodalDataTracker,
 ) -> ConversationMessage:
     """Parse multiple parts of a chat message."""
@@ -224,8 +269,9 @@ def parse_chat_messages_coroutines(
     messages: List[ChatCompletionMessageParam],
     model_config: AutoConfig,
     multimodal_server_config: Optional[MultimodalServerConfig] = None
-) -> Tuple[List[ConversationMessage], Optional[Coroutine[
-        Any, Any, Optional[Dict[str, List[Any]]]]]]:
+) -> Tuple[List[ConversationMessage], Coroutine[Any, Any, tuple[Optional[Dict[
+        str, List[Any]]], Optional[Dict[str, List[Any]]]]], list[dict[str,
+                                                                      int]]]:
     """Parse multiple chat messages and return conversation and coroutine."""
     conversation = []
     mm_placeholder_counts = []
@@ -237,7 +283,9 @@ def parse_chat_messages_coroutines(
         conversation.append(parsed_msg)
         if parsed_msg["media"]:
             for mdata in parsed_msg["media"]:
-                mm_data_tracker.add_data(mdata["modality"], mdata["data"])
+                mm_data_tracker.add_data(mdata["modality"],
+                                         mdata["data"],
+                                         is_embedding=mdata["is_embedding"])
         mm_placeholder_count = mm_data_tracker.placeholder_counts()
         if mm_placeholder_count:
             parsed_msg["content"] = add_multimodal_placeholders(
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 44983306dc..afb97aa6f0 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -563,9 +563,13 @@ class OpenAIServer:
             )
             prompt = prompt_inputs(prompt)
 
-            mm_data = await mm_coroutines
-            if mm_data is not None:
+            mm_data, mm_embeddings = await mm_coroutines
+            if mm_data:
                 prompt["multi_modal_data"] = mm_data
+            if mm_embeddings:
+                prompt["multi_modal_embeddings"] = mm_embeddings
+            if mm_data and mm_embeddings:
+                raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.")
 
             postproc_args.reasoning_parser = self.llm.args.reasoning_parser
             postproc_args.tool_parser = self.tool_parser
@@ -666,7 +670,9 @@ class OpenAIServer:
             )
             prompt = prompt_inputs(prompt)
 
-            mm_data = await mm_coroutines
+            mm_data, mm_embeddings = await mm_coroutines
+            if mm_embeddings:
+                raise ValueError("Cannot use multimodal embeddings as input")
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data
 
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index ab1690c87a..f1d12a3ffb 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -1683,9 +1683,13 @@ def test_openai_lora(llm_root, llm_venv):
 
 def test_openai_chat_multimodal_example(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
-    llm_venv.run_cmd(
-        ["-m", "pytest",
-         str(test_root / "_test_openai_chat_multimodal.py")])
+    llm_venv.run_cmd([
+        "-m",
+        "pytest",
+        str(test_root / "_test_openai_chat_multimodal.py"),
+        "-m",
+        "not needs_l40s",
+    ])
 
 
 def test_openai_mmencoder_example(llm_root, llm_venv):
diff --git a/tests/integration/test_lists/test-db/l0_l40s.yml b/tests/integration/test_lists/test-db/l0_l40s.yml
index c303789489..d10ba9fc2c 100644
--- a/tests/integration/test_lists/test-db/l0_l40s.yml
+++ b/tests/integration/test_lists/test-db/l0_l40s.yml
@@ -28,6 +28,7 @@ l0_l40s:
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-audio]
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image]
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio]
+  - unittest/llmapi/apps/_test_openai_chat_multimodal.py::test_single_chat_session_image_embeds -m needs_l40s
   # MMMU sanity check
   - accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
   - accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype
diff --git a/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py b/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py
new file mode 100644
index 0000000000..7d8281ecd2
--- /dev/null
+++ b/tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
+
+import tempfile
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from tensorrt_llm._torch.models.modeling_qwen2vl import Qwen2VLInputProcessorBase
+from tensorrt_llm.inputs import ExtraProcessedInputs, TextPrompt
+from tensorrt_llm.sampling_params import SamplingParams
+
+_attach_multimodal_embeddings_orig = Qwen2VLInputProcessorBase.attach_multimodal_embeddings
+
+
+# signature taken from tensorrt_llm/inputs/registry.py
+def _attach_multimodal_embeddings(
+    self,
+    inputs: TextPrompt,
+    multimodal_embedding: dict[str, list[torch.Tensor]],
+    sampling_params: SamplingParams,
+) -> tuple[list[int], Optional[ExtraProcessedInputs]]:
+    try:
+        _attach_multimodal_embeddings_orig(self, inputs, multimodal_embedding, sampling_params)
+    except NotImplementedError:
+        pass
+    else:
+        raise ValueError(
+            "Remove this custom module, Qwen2VLInputProcessorBase implements attach_multimodal_embeddings"
+        )
+
+    tempdir = tempfile.gettempdir()
+    file_path = Path(tempdir) / "multimodal_embedding.pickle"
+    with open(file_path, "wb") as f:
+        torch.save(multimodal_embedding, f)
+    raise ValueError(file_path)
+
+
+Qwen2VLInputProcessorBase.attach_multimodal_embeddings = _attach_multimodal_embeddings
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
index fda0f8a493..4183e1874e 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
@@ -1,13 +1,18 @@
+import io
 import os
+import sys
 import tempfile
+from base64 import b64encode
 from pathlib import Path
 from typing import List
 
 import openai
 import pytest
+import torch
 import yaml
 from PIL import Image
 
+from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.inputs import encode_base64_image
 
 from ..test_llm import get_model_path
@@ -17,6 +22,13 @@ pytestmark = pytest.mark.threadleak(enabled=False)
 
 from utils.llm_data import llm_models_root
 
+from ._test_openai_mmencoder import RemoteMMEncoderServer
+from ._test_openai_mmencoder import server as mm_encoder_server
+from ._test_openai_mmencoder import \
+    test_multimodal_content_mm_encoder as _test_multimodal_content_mm_encoder
+
+assert mm_encoder_server is not None  # keep 'mm_encoder_server' fixture visible in this module
+
 
 @pytest.fixture(scope="module", ids=["Qwen2.5-VL-3B-Instruct"])
 def model_name():
@@ -25,7 +37,7 @@ def model_name():
 
 @pytest.fixture(scope="module")
 def temp_extra_llm_api_options_file(request):
-    temp_dir = tempfile.gettempdir()
+    temp_dir = tempfile.mkdtemp()
     temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
     try:
         extra_llm_api_options_dict = {
@@ -123,6 +135,98 @@ def test_single_chat_session_image(client: openai.OpenAI, model_name: str):
         == chat_completion.choices[0].message.content
 
 
+# used by mm_encoder_server
+@pytest.fixture(scope="module")
+def extra_encoder_options() -> bool:
+    return False
+
+
+# used by mm_encoder_server
+@pytest.fixture(scope="module")
+def temp_extra_encoder_options_file() -> str:
+    return "/dummy/path"
+
+
+@pytest.fixture(scope="module")
+def server_patched(model_name: str, temp_extra_llm_api_options_file: str):
+    # Custom module implements missing 'attach_multimodal_embeddings' to intercept
+    # embeddings.
+    model_path = get_model_path(model_name)
+    args = [
+        "--extra_llm_api_options",
+        temp_extra_llm_api_options_file,
+        "--max_batch_size",
+        "64",
+        "--max_num_tokens",
+        "16384",
+        "--custom_module_dirs",
+        str(
+            Path(sys.modules[test_single_chat_session_image_embeds.__module__].
+                 __file__).parent / "_attach_multimodal_embeddings_patch"),
+    ]
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.needs_l40s
+@pytest.mark.asyncio(loop_scope="module")
+def test_single_chat_session_image_embeds(
+        server_patched: RemoteOpenAIServer,
+        model_name: str,
+        mm_encoder_server: RemoteMMEncoderServer,
+):
+    client = server_patched.get_client()
+    messages, mm_embed_handle = _test_multimodal_content_mm_encoder(
+        mm_encoder_server.get_client(), model_name)
+
+    max_completion_tokens = 10
+
+    chat_completion_image = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=max_completion_tokens,
+        temperature=0.0,
+        logprobs=False)
+
+    mm_embed = SharedTensorContainer.from_dict(mm_embed_handle).get_local_view()
+    with io.BytesIO() as buf:
+        torch.save(mm_embed, buf)
+        mm_embed_bytes = buf.getvalue()
+
+    image_content = messages[0]["content"][1]
+    assert image_content["type"] == "image_url"
+    image_content.clear()
+    image_content["type"] = "image_embeds"
+    image_content["image_embeds"] = {
+        "data": b64encode(mm_embed_bytes).decode("ascii")
+    }
+
+    # test single completion
+    #
+    # FIXME: Remove try-except and use 'server' instead of 'server_patched',
+    # once Qwen2VLInputProcessorBase implements attach_multimodal_embeddings.
+    try:
+        chat_completion_embeds = client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            max_completion_tokens=max_completion_tokens,
+            temperature=0.0,
+            logprobs=False)
+
+        assert chat_completion_embeds.choices[
+            0].message == chat_completion_image.choices[0].message
+    except openai.BadRequestError as e:
+        assert isinstance(e.body, dict)
+        with open(Path(e.body["message"]), "rb") as f:
+            intercepted_embeddings = torch.load(f, weights_only=True)
+        assert list(intercepted_embeddings.keys()) == ["image"]
+        assert len(intercepted_embeddings["image"]) == 1
+        torch.testing.assert_close(intercepted_embeddings["image"][0],
+                                   mm_embed.cpu())
+        pytest.xfail(
+            reason="Model does not implement 'attach_multimodal_embeddings'")
+
+
 @pytest.mark.asyncio(loop_scope="module")
 def test_single_chat_session_multi_image(client: openai.OpenAI,
                                          model_name: str):
diff --git a/tests/unittest/llmapi/apps/_test_openai_mmencoder.py b/tests/unittest/llmapi/apps/_test_openai_mmencoder.py
index 312f9232d4..483f9ad994 100644
--- a/tests/unittest/llmapi/apps/_test_openai_mmencoder.py
+++ b/tests/unittest/llmapi/apps/_test_openai_mmencoder.py
@@ -1,5 +1,6 @@
 import os
 import tempfile
+from typing import Any
 
 import openai
 import pytest
@@ -67,7 +68,9 @@ def async_client(server: RemoteMMEncoderServer):
     return server.get_async_client()
 
 
-def test_multimodal_content_mm_encoder(client: openai.OpenAI, model_name: str):
+def test_multimodal_content_mm_encoder(
+        client: openai.OpenAI,
+        model_name: str) -> tuple[list[dict[str, Any]], dict[str, Any]]:
     content_text = "Describe the natural environment in the image."
 
     image_url = str(llm_models_root() / "multimodals" / "test_data" /
@@ -105,6 +108,8 @@ def test_multimodal_content_mm_encoder(client: openai.OpenAI, model_name: str):
     assert mm_handle["tensor_size"][
         1] == 2048  # qwen2.5-vl: hidden_size of the vision encoder
 
+    return messages, mm_handle  # used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
+
 
 def test_health(server: RemoteMMEncoderServer):
     health_url = server.url_for("health")
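
Editor's note: the sketch below illustrates how a client could construct the base64-encoded "image_embeds" payload that `load_base64_image_embeds` decodes on the server side, mirroring the round trip exercised by `test_single_chat_session_image_embeds`. It is not part of this change; the server URL, model name, and embedding tensor are hypothetical placeholders.

```python
# Minimal client-side sketch (assumed setup, not from this patch): serialize an
# embedding tensor with torch.save, base64-encode it, and send it as an
# "image_embeds" content part to a running trtllm-serve endpoint.
import io
from base64 import b64encode

import torch
from openai import OpenAI

# Placeholder embedding; in practice this would come from a vision encoder
# (e.g. the multimodal encoder endpoint used in the tests above).
embedding = torch.randn(256, 2048)

buf = io.BytesIO()
torch.save(embedding, buf)  # server decodes this with torch.load(weights_only=True)
payload = b64encode(buf.getvalue()).decode("ascii")

# Hypothetical endpoint and model name.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
response = client.chat.completions.create(
    model="Qwen2.5-VL-3B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_embeds", "image_embeds": {"data": payload}},
        ],
    }],
    max_completion_tokens=10,
)
print(response.choices[0].message.content)
```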