mpikulski 2026-01-13 21:25:08 +08:00 committed by GitHub
commit 2d45b482e0
11 changed files with 363 additions and 80 deletions

View File

@ -6,6 +6,7 @@ markers =
fmhca
debug
bench
needs_l40s
# bin: unit tests
# test: python script for invoking fmha.exe
testpaths = bin test

View File

@ -170,6 +170,24 @@ TRT-LLM multimodal supports the following modalities and data types (depending o
`load_base64_image utility <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/utils/load_base64_image.py>`__
for implementation details.
**Image embeddings**
It is also possible to provide precomputed image embeddings directly to the multimodal model.
* Using "image_embeds" with base64-encoded data:
.. code-block:: json
{"role": "user", "content": [
{"type": "text", "text": "What's in this image?"},
{"type": "image_embeds", "image_embeds": {"data": "{image_embeddings_base64}"}}}
]}
.. note::
The contents of `image_embeddings_base64` can be generated by base64-encoding
the result of serializing a tensor via `torch.save`.
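For illustration, a minimal client-side sketch of producing such a payload; the tensor name `image_embeddings` and its shape are placeholders rather than part of the API:
.. code-block:: python

    import base64
    import io

    import torch

    # Placeholder embeddings tensor; in practice it comes from a vision encoder.
    image_embeddings = torch.randn(1, 2048)

    # Serialize the tensor with torch.save, then base64-encode the raw bytes.
    with io.BytesIO() as buf:
        torch.save(image_embeddings, buf)
        image_embeddings_base64 = base64.b64encode(buf.getvalue()).decode("ascii")
The resulting string is what goes into the `data` field shown above.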
**Video**
* Using "video_url":

View File

@ -16,7 +16,8 @@ from .utils import (ALL_SUPPORTED_AUDIO_MODELS, ALL_SUPPORTED_IMAGE_MODELS,
async_load_audio, async_load_image, async_load_video,
convert_image_mode, default_multimodal_input_loader,
encode_base64_content_from_url, encode_base64_image,
get_cache_salt_id, load_image, load_video)
get_cache_salt_id, load_base64_image_embeds, load_image,
load_video)
__all__ = [
"ALL_SUPPORTED_MULTIMODAL_MODELS",
@ -57,4 +58,5 @@ __all__ = [
"get_cache_salt_id",
"compute_retained_tokens_count",
"compute_retention_mask",
"load_base64_image_embeds",
]

View File

@ -114,6 +114,15 @@ def load_base64_image(parsed_url: str) -> Image.Image:
return image
def load_base64_image_embeds(str_content: str) -> torch.Tensor:
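"""Decode a base64-encoded image-embeddings tensor that was serialized via torch.save."""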
content_bytes = base64.b64decode(str_content)
with BytesIO(content_bytes) as buf:
image_data: torch.Tensor = torch.load(buf,
weights_only=True,
map_location="cpu")
return image_data
def load_image(image: Union[str, Image.Image],
format: str = "pt",
device: str = "cpu") -> Union[Image.Image, torch.Tensor]:
@ -425,13 +434,14 @@ class MultimodalData(TypedDict):
"""Type definition for multimodal data structure."""
modality: str
data: Any
is_embedding: bool
class ConversationMessage(TypedDict):
"""Type definition for conversation message structure."""
role: str
content: List[dict[str, Any]]
media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]]
media: List[MultimodalData]
# @classmethod
# def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage":
@ -446,33 +456,57 @@ class MultimodalDataTracker:
model_type: str,
multimodal_server_config: Optional[MultimodalServerConfig] = None):
self._model_type = model_type
self._data = defaultdict[str](list)
self._placeholder_counts = defaultdict[str](int)
self._data = defaultdict[str, list](list)
self._embeddings = defaultdict[str, list](list)
self._placeholder_counts = defaultdict[str, int](int)
self._multimodal_server_config = multimodal_server_config if multimodal_server_config is not None else MultimodalServerConfig(
)
async def retrieve_all_async(self) -> Optional[Dict[str, List[Any]]]:
"""Retrieve all collected multimodal data."""
if not self._data:
return None
async def retrieve_all_async(
self
) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
"""Retrieve all collected multimodal data and embeddings."""
return {
modality: await asyncio.gather(*items)
for modality, items in self._data.items()
}
async def _retrieve(
data: Optional[dict[str,
list]]) -> Optional[Dict[str, List[Any]]]:
if not data:
return None
return {
modality: await asyncio.gather(*items)
for modality, items in data.items() if items
}
def retrieve_all_sync(self) -> Optional[Dict[str, List[Any]]]:
"""Retrieve all collected multimodal data."""
if not self._data:
return None
return await _retrieve(self._data), await _retrieve(self._embeddings)
return {modality: items for modality, items in self._data.items()}
def retrieve_all_sync(
self
) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
"""Retrieve all collected multimodal data and embeddings."""
def add_data(self, media_type: str, data: Union[Coroutine, Any]):
current_count = len(self._data[media_type]) + 1
def _retrieve(
data: Optional[dict[str,
list]]) -> Optional[Dict[str, List[Any]]]:
if not data:
return None
return {
modality: items
for modality, items in data.items() if items
}
return _retrieve(self._data), _retrieve(self._embeddings)
def add_data(self,
media_type: str,
data: Union[Coroutine, Any],
*,
is_embedding: bool = False):
current_count = len(self._data[media_type]) + len(
self._embeddings[media_type]) + 1
placeholder = retrieve_multimodal_placeholder(self._model_type,
media_type, current_count)
self._data[media_type].append(data)
(self._embeddings
if is_embedding else self._data)[media_type].append(data)
if placeholder:
self._placeholder_counts[placeholder] += 1
@ -643,33 +677,34 @@ def default_multimodal_input_loader(
media = [media]
if modality in ["image", "multiple_image"]:
if is_embedding:
_load = lambda mm: mm
# each mm_embedding corresponds to each image placeholder
if not isinstance(media, list):
media = [media]
mm_data = [{
'modality': modality,
'mm_embedding_info': mm
} for mm in media]
else:
mm_data = [
MultimodalData(modality=modality,
data=load_image(i,
format=image_data_format,
device=device))
for i in media
]
_load = lambda mm: load_image(
mm, format=image_data_format, device=device)
mm_data = [
MultimodalData(modality=modality,
data=_load(mm),
is_embedding=is_embedding) for mm in media
]
elif modality == "video":
if is_embedding:
raise ValueError(
"External embedding is not supported for video modality yet."
)
mm_data = [
MultimodalData(modality=modality,
data=load_video(i,
num_frames,
format=image_data_format,
device=device)) for i in media
MultimodalData(
modality=modality,
data=load_video(i,
num_frames,
format=image_data_format,
device=device),
is_embedding=False,
) for i in media
]
elif modality == "audio":
if is_embedding:
@ -677,8 +712,11 @@ def default_multimodal_input_loader(
"External embedding is not supported for audio modality yet."
)
mm_data = [
MultimodalData(modality=modality,
data=load_audio(i, device=device)) for i in media
MultimodalData(
modality=modality,
data=load_audio(i, device=device),
is_embedding=False,
) for i in media
]
elif modality == "image_audio":
if is_embedding:
@ -706,16 +744,22 @@ def default_multimodal_input_loader(
pass
if _modal is None:
raise ValueError(f"Unknown matching modality: {modality}")
mm_data.append(MultimodalData(modality=_modal, data=data))
mm_data.append(
MultimodalData(modality=_modal,
data=data,
is_embedding=False))
elif modality == "mixture_text_image":
mm_data = []
for m in media:
if m:
mm_data.append(
MultimodalData(modality="image",
data=load_image(m,
format=image_data_format,
device=device)))
MultimodalData(
modality="image",
data=load_image(m,
format=image_data_format,
device=device),
is_embedding=False,
))
else:
raise ValueError(f"Unknown modality: {modality}")
return ConversationMessage(role="user", content=prompt, media=mm_data)
@ -749,17 +793,12 @@ def default_multimodal_input_loader(
is_embedding)
mm_data_tracker = MultimodalDataTracker(model_type)
for mdata in conv["media"]:
# Check if mdata is a MultimodalData
if isinstance(mdata,
dict) and "modality" in mdata and "data" in mdata:
mdata_modality = mdata["modality"]
if modality == "multiple_image":
mdata_modality = "image"
mm_data_tracker.add_data(mdata_modality, mdata["data"])
else:
# Add embeddings to the tracker for placeholder handling
mm_data_tracker.add_data(mdata["modality"],
mdata["mm_embedding_info"])
mdata_modality = mdata["modality"]
if modality == "multiple_image":
mdata_modality = "image"
mm_data_tracker.add_data(mdata_modality,
mdata["data"],
is_embedding=is_embedding)
mm_placeholder_counts = mm_data_tracker.placeholder_counts()
prompt = conv["content"]
if mm_placeholder_counts:
@ -776,11 +815,13 @@ def default_multimodal_input_loader(
if mm_placeholder_counts:
if mm_embeddings is not None:
input[
_, input[
"multi_modal_embeddings"] = mm_data_tracker.retrieve_all_sync(
)
else:
input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync()
input[
"multi_modal_data"], _ = mm_data_tracker.retrieve_all_sync(
)
inputs.append(input)
return inputs

View File

@ -17,7 +17,8 @@ from typing_extensions import Required
from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
MultimodalDataTracker,
add_multimodal_placeholders, async_load_audio,
async_load_image, async_load_video)
async_load_image, async_load_video,
load_base64_image_embeds)
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.logger import logger
@ -33,24 +34,45 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
type: Required[Literal["video_url"]]
class ImageEmbedsData(TypedDict):
"""Type definition for serialized image embeddings structure."""
data: Required[str]
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
"""Type definition for image embeddings passed in base64-encoded PyTorch tensor format."""
image_embeds: Required[
# TODO: Besides "data", could support "url" and "ipc_handle" in the future.
ImageEmbedsData]
type: Required[Literal["image_embeds"]]
# Type Aliases and Constants
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
str]
OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartVideoParam,
ChatCompletionContentPartImageEmbedsParam,
str,
]
# TODO: Add "input_audio" to support byte_encoded audio input.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
"text", "image_url", "video_url", "audio_url"
"text",
"image_url",
"video_url",
"audio_url",
"image_embeds",
]
# Parser Functions
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
str, dict[str, str]]]] = {
str, dict[str, str], None]]] = {
"text":
lambda part: _TextParser(part).get("text", None),
"image_url":
@ -59,12 +81,15 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
"image_embeds":
lambda part: _ImageEmbedsParser(part).get("image_embeds", {}).get(
"data", None),
}
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam
) -> tuple[str, Union[str, dict[str, str]]]:
) -> tuple[str, Union[str, dict[str, str], None]]:
"""Parse a single multimodal part of a chat message."""
assert isinstance(part, dict)
part_type = part.get("type", None)
@ -78,9 +103,9 @@ def _parse_chat_message_content_mm_part(
def parse_chat_message_content_part(
part: ChatCompletionMessageParam,
part: ChatCompletionContentPartParam,
mm_data_tracker: MultimodalDataTracker,
) -> Optional[Any]:
) -> str | MultimodalData | None:
"""Parse a single part of a chat message."""
if isinstance(part, str):
return part
@ -110,7 +135,23 @@ def parse_chat_message_content_part(
logger.error(f"Failed to load image: {str(e)}")
return None
return MultimodalData(modality="image", data=load_image_async())
return MultimodalData(modality="image",
data=load_image_async(),
is_embedding=False)
if part_type == "image_embeds":
str_content = cast(str, content)
async def decode_image_embeds_async():
try:
return load_base64_image_embeds(str_content)
except Exception as e:
logger.error(f"Failed to decode image data: {str(e)}")
return None
return MultimodalData(modality="image",
data=decode_image_embeds_async(),
is_embedding=True)
if part_type == "video_url":
str_content = cast(str, content)
@ -125,7 +166,9 @@ def parse_chat_message_content_part(
logger.error(f"Failed to load video: {str(e)}")
return None
return MultimodalData(modality="video", data=load_video_async())
return MultimodalData(modality="video",
data=load_video_async(),
is_embedding=False)
if part_type == "audio_url":
str_content = cast(str, content)
@ -140,14 +183,16 @@ def parse_chat_message_content_part(
logger.error(f"Failed to load audio: {str(e)}")
return None
return MultimodalData(modality="audio", data=load_audio_async())
return MultimodalData(modality="audio",
data=load_audio_async(),
is_embedding=False)
raise NotImplementedError(f"Unknown part type: {part_type}")
def parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionMessageParam],
parts: Iterable[ChatCompletionContentPartParam],
mm_data_tracker: MultimodalDataTracker,
) -> ConversationMessage:
"""Parse multiple parts of a chat message."""
@ -224,8 +269,9 @@ def parse_chat_messages_coroutines(
messages: List[ChatCompletionMessageParam],
model_config: AutoConfig,
multimodal_server_config: Optional[MultimodalServerConfig] = None
) -> Tuple[List[ConversationMessage], Optional[Coroutine[
Any, Any, Optional[Dict[str, List[Any]]]]]]:
) -> Tuple[List[ConversationMessage], Coroutine[Any, Any, tuple[Optional[Dict[
str, List[Any]]], Optional[Dict[str, List[Any]]]]], list[dict[str,
int]]]:
"""Parse multiple chat messages and return conversation and coroutine."""
conversation = []
mm_placeholder_counts = []
@ -237,7 +283,9 @@ def parse_chat_messages_coroutines(
conversation.append(parsed_msg)
if parsed_msg["media"]:
for mdata in parsed_msg["media"]:
mm_data_tracker.add_data(mdata["modality"], mdata["data"])
mm_data_tracker.add_data(mdata["modality"],
mdata["data"],
is_embedding=mdata["is_embedding"])
mm_placeholder_count = mm_data_tracker.placeholder_counts()
if mm_placeholder_count:
parsed_msg["content"] = add_multimodal_placeholders(

View File

@ -563,9 +563,13 @@ class OpenAIServer:
)
prompt = prompt_inputs(prompt)
mm_data = await mm_coroutines
if mm_data is not None:
mm_data, mm_embeddings = await mm_coroutines
if mm_data:
prompt["multi_modal_data"] = mm_data
if mm_embeddings:
prompt["multi_modal_embeddings"] = mm_embeddings
if mm_data and mm_embeddings:
raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.")
postproc_args.reasoning_parser = self.llm.args.reasoning_parser
postproc_args.tool_parser = self.tool_parser
@ -666,7 +670,9 @@ class OpenAIServer:
)
prompt = prompt_inputs(prompt)
mm_data = await mm_coroutines
mm_data, mm_embeddings = await mm_coroutines
if mm_embeddings:
raise ValueError("Cannot use multimodal embeddings as input")
if mm_data is not None:
prompt["multi_modal_data"] = mm_data

View File

@ -1683,9 +1683,13 @@ def test_openai_lora(llm_root, llm_venv):
def test_openai_chat_multimodal_example(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(
["-m", "pytest",
str(test_root / "_test_openai_chat_multimodal.py")])
llm_venv.run_cmd([
"-m",
"pytest",
str(test_root / "_test_openai_chat_multimodal.py"),
"-m",
"not needs_l40s",
])
def test_openai_mmencoder_example(llm_root, llm_venv):

View File

@ -28,6 +28,7 @@ l0_l40s:
- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-audio]
- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image]
- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio]
- unittest/llmapi/apps/_test_openai_chat_multimodal.py::test_single_chat_session_image_embeds -m needs_l40s
# MMMU sanity check
- accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
- accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype

View File

@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
import tempfile
from pathlib import Path
from typing import Optional
import torch
from tensorrt_llm._torch.models.modeling_qwen2vl import Qwen2VLInputProcessorBase
from tensorrt_llm.inputs import ExtraProcessedInputs, TextPrompt
from tensorrt_llm.sampling_params import SamplingParams
_attach_multimodal_embeddings_orig = Qwen2VLInputProcessorBase.attach_multimodal_embeddings
# signature taken from tensorrt_llm/inputs/registry.py
def _attach_multimodal_embeddings(
self,
inputs: TextPrompt,
multimodal_embedding: dict[str, list[torch.Tensor]],
sampling_params: SamplingParams,
) -> tuple[list[int], Optional[ExtraProcessedInputs]]:
try:
_attach_multimodal_embeddings_orig(self, inputs, multimodal_embedding, sampling_params)
except NotImplementedError:
pass
else:
raise ValueError(
"Remove this custom module, Qwen2VLInputProcessorBase implements attach_multimodal_embeddings"
)
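# Save the intercepted embeddings to a temp file and surface the path via the
# raised error message, so the calling test can load and verify them.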
tempdir = tempfile.gettempdir()
file_path = Path(tempdir) / "multimodal_embedding.pickle"
with open(file_path, "wb") as f:
torch.save(multimodal_embedding, f)
raise ValueError(file_path)
Qwen2VLInputProcessorBase.attach_multimodal_embeddings = _attach_multimodal_embeddings

View File

@ -1,13 +1,18 @@
import io
import os
import sys
import tempfile
from base64 import b64encode
from pathlib import Path
from typing import List
import openai
import pytest
import torch
import yaml
from PIL import Image
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
from tensorrt_llm.inputs import encode_base64_image
from ..test_llm import get_model_path
@ -17,6 +22,13 @@ pytestmark = pytest.mark.threadleak(enabled=False)
from utils.llm_data import llm_models_root
from ._test_openai_mmencoder import RemoteMMEncoderServer
from ._test_openai_mmencoder import server as mm_encoder_server
from ._test_openai_mmencoder import \
test_multimodal_content_mm_encoder as _test_multimodal_content_mm_encoder
assert mm_encoder_server is not None # keep 'mm_encoder_server' fixture visible in this module
@pytest.fixture(scope="module", ids=["Qwen2.5-VL-3B-Instruct"])
def model_name():
@ -25,7 +37,7 @@ def model_name():
@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file(request):
temp_dir = tempfile.gettempdir()
temp_dir = tempfile.mkdtemp()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {
@ -123,6 +135,98 @@ def test_single_chat_session_image(client: openai.OpenAI, model_name: str):
== chat_completion.choices[0].message.content
# used by mm_encoder_server
@pytest.fixture(scope="module")
def extra_encoder_options() -> bool:
return False
# used by mm_encoder_server
@pytest.fixture(scope="module")
def temp_extra_encoder_options_file() -> str:
return "/dummy/path"
@pytest.fixture(scope="module")
def server_patched(model_name: str, temp_extra_llm_api_options_file: str):
# The custom module implements the missing 'attach_multimodal_embeddings' method
# so the test can intercept the embeddings.
model_path = get_model_path(model_name)
args = [
"--extra_llm_api_options",
temp_extra_llm_api_options_file,
"--max_batch_size",
"64",
"--max_num_tokens",
"16384",
"--custom_module_dirs",
str(
Path(sys.modules[test_single_chat_session_image_embeds.__module__].
__file__).parent / "_attach_multimodal_embeddings_patch"),
]
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server
@pytest.mark.needs_l40s
@pytest.mark.asyncio(loop_scope="module")
def test_single_chat_session_image_embeds(
server_patched: RemoteOpenAIServer,
model_name: str,
mm_encoder_server: RemoteMMEncoderServer,
):
client = server_patched.get_client()
messages, mm_embed_handle = _test_multimodal_content_mm_encoder(
mm_encoder_server.get_client(), model_name)
max_completion_tokens = 10
chat_completion_image = client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=max_completion_tokens,
temperature=0.0,
logprobs=False)
mm_embed = SharedTensorContainer.from_dict(mm_embed_handle).get_local_view()
with io.BytesIO() as buf:
torch.save(mm_embed, buf)
mm_embed_bytes = buf.getvalue()
image_content = messages[0]["content"][1]
assert image_content["type"] == "image_url"
image_content.clear()
image_content["type"] = "image_embeds"
image_content["image_embeds"] = {
"data": b64encode(mm_embed_bytes).decode("ascii")
}
# test single completion
#
# FIXME: Remove try-except and use 'server' instead of 'server_patched',
# once Qwen2VLInputProcessorBase implements attach_multimodal_embeddings.
try:
chat_completion_embeds = client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=max_completion_tokens,
temperature=0.0,
logprobs=False)
assert chat_completion_embeds.choices[
0].message == chat_completion_image.choices[0].message
except openai.BadRequestError as e:
assert isinstance(e.body, dict)
with open(Path(e.body["message"]), "rb") as f:
intercepted_embeddings = torch.load(f, weights_only=True)
assert list(intercepted_embeddings.keys()) == ["image"]
assert len(intercepted_embeddings["image"]) == 1
torch.testing.assert_close(intercepted_embeddings["image"][0],
mm_embed.cpu())
pytest.xfail(
reason="Model does not implement 'attach_multimodal_embeddings'")
@pytest.mark.asyncio(loop_scope="module")
def test_single_chat_session_multi_image(client: openai.OpenAI,
model_name: str):

View File

@ -1,5 +1,6 @@
import os
import tempfile
from typing import Any
import openai
import pytest
@ -67,7 +68,9 @@ def async_client(server: RemoteMMEncoderServer):
return server.get_async_client()
def test_multimodal_content_mm_encoder(client: openai.OpenAI, model_name: str):
def test_multimodal_content_mm_encoder(
client: openai.OpenAI,
model_name: str) -> tuple[list[dict[str, Any]], dict[str, Any]]:
content_text = "Describe the natural environment in the image."
image_url = str(llm_models_root() / "multimodals" / "test_data" /
@ -105,6 +108,8 @@ def test_multimodal_content_mm_encoder(client: openai.OpenAI, model_name: str):
assert mm_handle["tensor_size"][
1] == 2048 # qwen2.5-vl: hidden_size of the vision encoder
return messages, mm_handle # used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
def test_health(server: RemoteMMEncoderServer):
health_url = server.url_for("health")