chore: add is_embedding to MultimodalData

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
ixlmar 2025-12-11 12:58:03 +00:00 committed by mpikulski
parent 0543bf01fb
commit db14542c35
4 changed files with 115 additions and 95 deletions

View File

@@ -244,7 +244,7 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
             # NOTE: Since we already have loaded images, for the placeholder purpose, we add data here.
             for _ in range(image_count):
-                mm_data_tracker.add_data("image", None)
+                mm_data_tracker.add_data("image", None, is_embedding=False)
             mm_placeholder_count = mm_data_tracker.placeholder_counts()
             if mm_placeholder_count:
                 # TODO: This is an assumption of not interleaving text and image. Need to extend to interleaved texts.
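Illustrative sketch (not part of this commit): the eval wrapper registers one image entry per already-loaded image, with no payload, purely so the tracker can report how many placeholder tokens the prompt needs. A minimal standalone version of that pattern, assuming MultimodalDataTracker is in scope; the model_type string "qwen2_vl" is a hypothetical example.

    image_count = 2  # images already loaded elsewhere
    mm_data_tracker = MultimodalDataTracker("qwen2_vl")
    for _ in range(image_count):
        # data=None: only the placeholder bookkeeping is needed here
        mm_data_tracker.add_data("image", None, is_embedding=False)
    mm_placeholder_count = mm_data_tracker.placeholder_counts()
    # -> maps the model-specific image placeholder token to its count (2 here)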

View File

@@ -434,13 +434,14 @@ class MultimodalData(TypedDict):
     """Type definition for multimodal data structure."""
     modality: str
     data: Any
+    is_embedding: bool


 class ConversationMessage(TypedDict):
     """Type definition for conversation message structure."""
     role: str
     content: List[dict[str, Any]]
-    media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]]
+    media: List[MultimodalData]

     # @classmethod
     # def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage":
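Illustrative sketch (not part of this commit): every media entry now states explicitly whether its data field is a raw asset to be loaded or an already-computed embedding to be passed through. The URL, tensor shape, and content-part layout below are hypothetical; MultimodalData and ConversationMessage are the TypedDicts defined above.

    import torch

    # Raw asset: data is something a loader can turn into model inputs.
    raw_image = MultimodalData(modality="image",
                               data="https://example.com/cat.png",
                               is_embedding=False)

    # Precomputed embedding: data is already encoder output and is passed through untouched.
    image_embedding = MultimodalData(modality="image",
                                     data=torch.zeros(256, 1024),
                                     is_embedding=True)

    message = ConversationMessage(role="user",
                                  content=[{"type": "text", "text": "Describe the image."}],
                                  media=[raw_image])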
@@ -455,38 +456,54 @@ class MultimodalDataTracker:
                  model_type: str,
                  multimodal_server_config: Optional[MultimodalServerConfig] = None):
         self._model_type = model_type
-        self._data = defaultdict[str](list)
-        self._placeholder_counts = defaultdict[str](int)
+        self._data = defaultdict[str, list](list)
+        self._embeddings = defaultdict[str, list](list)
+        self._placeholder_counts = defaultdict[str, int](int)
         self._multimodal_server_config = multimodal_server_config if multimodal_server_config is not None else MultimodalServerConfig(
         )

-    async def retrieve_all_async(self) -> Optional[Dict[str, List[Any]]]:
-        """Retrieve all collected multimodal data."""
-        if not self._data:
-            return None
+    async def retrieve_all_async(
+        self
+    ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
+        """Retrieve all collected multimodal data and embeddings."""
-        return {
-            modality: await asyncio.gather(*items)
-            for modality, items in self._data.items()
-        }
+        async def _retrieve(
+                data: Optional[dict[str,
+                                    list]]) -> Optional[Dict[str, List[Any]]]:
+            if not data:
+                return None
+            return {
+                modality: await asyncio.gather(*items)
+                for modality, items in data.items() if items
+            }
-    def retrieve_all_sync(self) -> Optional[Dict[str, List[Any]]]:
-        """Retrieve all collected multimodal data."""
-        if not self._data:
-            return None
+        return await _retrieve(self._data), await _retrieve(self._embeddings)
-        return {modality: items for modality, items in self._data.items()}
+    def retrieve_all_sync(
+        self
+    ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
+        """Retrieve all collected multimodal data and embeddings."""
-    def add_data(self,
-                 media_type: str,
-                 data: Union[Coroutine, Any],
-                 *,
-                 modality: Optional[str] = None):
-        modality = modality or media_type
-        current_count = len(self._data[media_type]) + 1
+        def _retrieve(
+                data: Optional[dict[str,
+                                    list]]) -> Optional[Dict[str, List[Any]]]:
+            if not data:
+                return None
+            return {
+                modality: items
+                for modality, items in data.items() if items
+            }
+        return _retrieve(self._data), _retrieve(self._embeddings)
+    def add_data(self, media_type: str, data: Union[Coroutine, Any], *,
+                 is_embedding: bool):
+        current_count = len(self._data[media_type]) + len(
+            self._embeddings[media_type]) + 1
         placeholder = retrieve_multimodal_placeholder(self._model_type,
-                                                      modality, current_count)
-        self._data[media_type].append(data)
+                                                      media_type, current_count)
+        (self._embeddings
+         if is_embedding else self._data)[media_type].append(data)
         if placeholder:
             self._placeholder_counts[placeholder] += 1
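Illustrative sketch (not part of this commit): the tracker now keeps raw data and embeddings in separate per-modality dicts and hands them back as a pair. A minimal synchronous example, assuming MultimodalDataTracker is in scope; the model_type string and tensor shapes are hypothetical.

    import torch

    loaded_image = torch.zeros(3, 224, 224)      # stand-in for a decoded image
    tracker = MultimodalDataTracker("qwen2_vl")  # hypothetical model_type
    tracker.add_data("image", loaded_image, is_embedding=False)
    tracker.add_data("image", torch.zeros(256, 1024), is_embedding=True)

    data, embeddings = tracker.retrieve_all_sync()
    # data       -> {"image": [loaded_image]}
    # embeddings -> {"image": [the 256x1024 tensor]}
    # When only one kind was added, the other element comes back None or empty.
    mm_placeholder_counts = tracker.placeholder_counts()  # placeholder token -> count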
@@ -657,33 +674,34 @@ def default_multimodal_input_loader(
             media = [media]
         if modality in ["image", "multiple_image"]:
             if is_embedding:
+                _load = lambda mm: mm
-                # each mm_embedding corresponds to each image placeholder
-                if not isinstance(media, list):
-                    media = [media]
-                mm_data = [{
-                    'modality': modality,
-                    'mm_embedding_info': mm
-                } for mm in media]
             else:
-                mm_data = [
-                    MultimodalData(modality=modality,
-                                   data=load_image(i,
-                                                   format=image_data_format,
-                                                   device=device))
-                    for i in media
-                ]
+                _load = lambda mm: load_image(
+                    mm, format=image_data_format, device=device)
+            mm_data = [
+                MultimodalData(modality=modality,
+                               data=_load(mm),
+                               is_embedding=is_embedding) for mm in media
+            ]
         elif modality == "video":
             if is_embedding:
                 raise ValueError(
                     "External embedding is not supported for video modality yet."
                 )
             mm_data = [
-                MultimodalData(modality=modality,
-                               data=load_video(i,
-                                               num_frames,
-                                               format=image_data_format,
-                                               device=device)) for i in media
+                MultimodalData(
+                    modality=modality,
+                    data=load_video(i,
+                                    num_frames,
+                                    format=image_data_format,
+                                    device=device),
+                    is_embedding=False,
+                ) for i in media
             ]
         elif modality == "audio":
             if is_embedding:
@@ -691,8 +709,11 @@ def default_multimodal_input_loader(
                     "External embedding is not supported for audio modality yet."
                 )
             mm_data = [
-                MultimodalData(modality=modality,
-                               data=load_audio(i, device=device)) for i in media
+                MultimodalData(
+                    modality=modality,
+                    data=load_audio(i, device=device),
+                    is_embedding=False,
+                ) for i in media
             ]
         elif modality == "image_audio":
             if is_embedding:
@@ -720,16 +741,22 @@ def default_multimodal_input_loader(
                     pass
                 if _modal is None:
                     raise ValueError(f"Unknown matching modality: {modality}")
-                mm_data.append(MultimodalData(modality=_modal, data=data))
+                mm_data.append(
+                    MultimodalData(modality=_modal,
+                                   data=data,
+                                   is_embedding=False))
         elif modality == "mixture_text_image":
             mm_data = []
             for m in media:
                 if m:
                     mm_data.append(
-                        MultimodalData(modality="image",
-                                       data=load_image(m,
-                                                       format=image_data_format,
-                                                       device=device)))
+                        MultimodalData(
+                            modality="image",
+                            data=load_image(m,
+                                            format=image_data_format,
+                                            device=device),
+                            is_embedding=False,
+                        ))
         else:
             raise ValueError(f"Unknown modality: {modality}")
         return ConversationMessage(role="user", content=prompt, media=mm_data)
@@ -763,17 +790,12 @@ def default_multimodal_input_loader(
                                                is_embedding)
         mm_data_tracker = MultimodalDataTracker(model_type)
         for mdata in conv["media"]:
-            # Check if mdata is a MultimodalData
-            if isinstance(mdata,
-                          dict) and "modality" in mdata and "data" in mdata:
-                mdata_modality = mdata["modality"]
-                if modality == "multiple_image":
-                    mdata_modality = "image"
-                mm_data_tracker.add_data(mdata_modality, mdata["data"])
-            else:
-                # Add embeddings to the tracker for placeholder handling
-                mm_data_tracker.add_data(mdata["modality"],
-                                         mdata["mm_embedding_info"])
+            mdata_modality = mdata["modality"]
+            if modality == "multiple_image":
+                mdata_modality = "image"
+            mm_data_tracker.add_data(mdata_modality,
+                                     mdata["data"],
+                                     is_embedding=is_embedding)
         mm_placeholder_counts = mm_data_tracker.placeholder_counts()
         prompt = conv["content"]
         if mm_placeholder_counts:
@@ -790,11 +812,13 @@ def default_multimodal_input_loader(
         if mm_placeholder_counts:
             if mm_embeddings is not None:
-                input[
+                _, input[
                     "multi_modal_embeddings"] = mm_data_tracker.retrieve_all_sync(
                     )
             else:
-                input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync()
+                input[
+                    "multi_modal_data"], _ = mm_data_tracker.retrieve_all_sync(
+                    )
         inputs.append(input)
     return inputs
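Illustrative sketch (not part of this commit): after retrieval, exactly one of the two prompt keys is populated. The helper below is hypothetical and simplified in that it keys off the retrieved tuple itself, whereas the loader above keys off its own mm_embeddings argument; the key names multi_modal_data / multi_modal_embeddings match the code above.

    def attach_multimodal(input: dict, tracker: MultimodalDataTracker) -> dict:
        data, embeddings = tracker.retrieve_all_sync()
        if embeddings:
            input["multi_modal_embeddings"] = embeddings
        elif data:
            input["multi_modal_data"] = data
        return input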

View File

@@ -86,12 +86,6 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
         "data", None),
 }

-# Map from content part tags used to directly provide embeddings
-# to the corresponding data modality.
-MM_EMBEDDING_MAP: dict[str, str] = {
-    "image_embeds": "image",
-}
-

 def _parse_chat_message_content_mm_part(
         part: ChatCompletionContentPartParam
@@ -111,7 +105,7 @@ def _parse_chat_message_content_mm_part(
 def parse_chat_message_content_part(
     part: ChatCompletionContentPartParam,
     mm_data_tracker: MultimodalDataTracker,
-) -> Optional[Any]:
+) -> str | MultimodalData | None:
     """Parse a single part of a chat message."""
     if isinstance(part, str):
         return part
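Illustrative sketch (not part of this commit): with the narrowed return type, a caller can branch on the three outcomes explicitly. split_parts, parts, text_parts, and media are hypothetical names; parse_chat_message_content_part and MultimodalData are the definitions above.

    def split_parts(parts, mm_data_tracker):
        """Hypothetical helper: separate text parts from media parts."""
        text_parts: list[str] = []
        media: list[MultimodalData] = []
        for part in parts:
            parsed = parse_chat_message_content_part(part, mm_data_tracker)
            if parsed is None:
                continue                  # load/decode failed; already logged
            if isinstance(parsed, str):
                text_parts.append(parsed)
            else:
                media.append(parsed)      # MultimodalData (raw data or embedding)
        return text_parts, media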
@@ -141,7 +135,9 @@ parse_chat_message_content_part(
                 logger.error(f"Failed to load image: {str(e)}")
                 return None

-        return MultimodalData(modality="image", data=load_image_async())
+        return MultimodalData(modality="image",
+                              data=load_image_async(),
+                              is_embedding=False)

     if part_type == "image_embeds":
         str_content = cast(str, content)
@@ -153,8 +149,9 @@
                 logger.error(f"Failed to decode image data: {str(e)}")
                 return None

-        return MultimodalData(modality="image_embeds",
-                              data=decode_image_embeds_async())
+        return MultimodalData(modality="image",
+                              data=decode_image_embeds_async(),
+                              is_embedding=True)

     if part_type == "video_url":
         str_content = cast(str, content)
@@ -169,7 +166,9 @@
                 logger.error(f"Failed to load video: {str(e)}")
                 return None

-        return MultimodalData(modality="video", data=load_video_async())
+        return MultimodalData(modality="video",
+                              data=load_video_async(),
+                              is_embedding=False)

     if part_type == "audio_url":
         str_content = cast(str, content)
@@ -184,7 +183,9 @@
                 logger.error(f"Failed to load audio: {str(e)}")
                 return None

-        return MultimodalData(modality="audio", data=load_audio_async())
+        return MultimodalData(modality="audio",
+                              data=load_audio_async(),
+                              is_embedding=False)

     raise NotImplementedError(f"Unknown part type: {part_type}")
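Illustrative sketch (not part of this commit): in this module the data field holds coroutines (load_image_async() and friends), which MultimodalDataTracker.retrieve_all_async later awaits via asyncio.gather. A toy end-to-end run with a stand-in coroutine instead of the real loaders; the model_type string is hypothetical.

    import asyncio

    async def fake_load_image():              # stand-in for load_image_async()
        return "decoded-image"

    async def main():
        tracker = MultimodalDataTracker("qwen2_vl")
        tracker.add_data("image", fake_load_image(), is_embedding=False)
        data, embeddings = await tracker.retrieve_all_async()
        print(data)        # {'image': ['decoded-image']}
        print(embeddings)  # None or empty: no embeddings were added

    asyncio.run(main())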
@@ -268,8 +269,9 @@ def parse_chat_messages_coroutines(
     messages: List[ChatCompletionMessageParam],
     model_config: AutoConfig,
     multimodal_server_config: Optional[MultimodalServerConfig] = None
-) -> Tuple[List[ConversationMessage], Optional[Coroutine[
-        Any, Any, Optional[Dict[str, List[Any]]]]]]:
+) -> Tuple[List[ConversationMessage], Coroutine[Any, Any, tuple[Optional[Dict[
+        str, List[Any]]], Optional[Dict[str, List[Any]]]]], list[dict[str,
+                                                                      int]]]:
     """Parse multiple chat messages and return conversation and coroutine."""
     conversation = []
     mm_placeholder_counts = []
@@ -283,8 +285,7 @@
         for mdata in parsed_msg["media"]:
             mm_data_tracker.add_data(mdata["modality"],
                                      mdata["data"],
-                                     modality=MM_EMBEDDING_MAP.get(
-                                         mdata["modality"], None))
+                                     is_embedding=mdata["is_embedding"])
         mm_placeholder_count = mm_data_tracker.placeholder_counts()
         if mm_placeholder_count:
             parsed_msg["content"] = add_multimodal_placeholders(
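Illustrative sketch (not part of this commit): the coroutine returned here now resolves to a (data, embeddings) pair, so a server-side caller awaits once and fills the matching prompt key. The function name build_prompt_inputs is hypothetical; parse_chat_messages_coroutines, the prompt keys, and the both-at-once ValueError mirror the OpenAIServer changes in the next file.

    from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines

    async def build_prompt_inputs(messages, model_config):
        conversation, mm_coroutines, mm_placeholder_counts = \
            parse_chat_messages_coroutines(messages, model_config)
        mm_data, mm_embeddings = await mm_coroutines
        if mm_data and mm_embeddings:
            raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' "
                             "at the same time is not supported.")
        prompt = {}
        if mm_data:
            prompt["multi_modal_data"] = mm_data
        if mm_embeddings:
            prompt["multi_modal_embeddings"] = mm_embeddings
        return conversation, prompt, mm_placeholder_counts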

View File

@@ -36,7 +36,7 @@ from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.metrics.collector import MetricsCollector
-from tensorrt_llm.serve.chat_utils import (MM_EMBEDDING_MAP, load_chat_template,
+from tensorrt_llm.serve.chat_utils import (load_chat_template,
                                            parse_chat_messages_coroutines)
 from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
 from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
@@ -556,20 +556,13 @@ class OpenAIServer:
             )
             prompt = prompt_inputs(prompt)

-            mm_data = await mm_coroutines
-            if mm_data is not None:
-                # single out directly provided embeddings
-                mm_embeds = {}
-                for tag in list(mm_data.keys()):
-                    if (modality := MM_EMBEDDING_MAP.get(tag, None)) is not None:
-                        mm_embeds[modality] = mm_data.pop(tag)
-                if mm_data:
-                    prompt["multi_modal_data"] = mm_data
-                if mm_embeds:
-                    prompt["multi_modal_embeddings"] = mm_embeds
-                if mm_data and mm_embeds:
-                    raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.")
+            mm_data, mm_embeddings = await mm_coroutines
+            if mm_data:
+                prompt["multi_modal_data"] = mm_data
+            if mm_embeddings:
+                prompt["multi_modal_embeddings"] = mm_embeddings
+            if mm_data and mm_embeddings:
+                raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.")

             postproc_args.reasoning_parser = self.llm.args.reasoning_parser
             postproc_args.tool_parser = self.tool_parser
@@ -670,7 +663,9 @@ class OpenAIServer:
             )
             prompt = prompt_inputs(prompt)

-            mm_data = await mm_coroutines
+            mm_data, mm_embeddings = await mm_coroutines
+            if mm_embeddings:
+                raise ValueError("Cannot use multimodal embeddings as input")
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data