mirror of https://github.com/NVIDIA/TensorRT-LLM.git
chore: add is_embedding to MultimodalData
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
This commit is contained in:
parent 0543bf01fb
commit db14542c35
@@ -244,7 +244,7 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
         # NOTE: Since we already have loaded images, for the placeholder purpose, we add data here.
         for _ in range(image_count):
-            mm_data_tracker.add_data("image", None)
+            mm_data_tracker.add_data("image", None, is_embedding=False)
         mm_placeholder_count = mm_data_tracker.placeholder_counts()
         if mm_placeholder_count:
             # TODO: This is an assumption of not interleaving text and image. Need to extend to interleaved texts.
@@ -434,13 +434,14 @@ class MultimodalData(TypedDict):
     """Type definition for multimodal data structure."""
     modality: str
     data: Any
+    is_embedding: bool


 class ConversationMessage(TypedDict):
     """Type definition for conversation message structure."""
     role: str
     content: List[dict[str, Any]]
-    media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]]
+    media: List[MultimodalData]

     # @classmethod
     # def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage":
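Aside: to make the new contract concrete, here is a self-contained restatement of the updated TypedDict (a local copy for this sketch only; the real definition is the one in the hunk above), showing how raw media and precomputed embeddings now share one shape and are told apart by the flag:

from typing import Any, TypedDict


class MultimodalData(TypedDict):
    modality: str
    data: Any
    is_embedding: bool


# A loaded image and a precomputed embedding now have the same type;
# only is_embedding distinguishes them downstream.
raw = MultimodalData(modality="image", data=b"<image bytes>", is_embedding=False)
emb = MultimodalData(modality="image", data=[0.1, 0.2], is_embedding=True)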
@@ -455,38 +456,54 @@ class MultimodalDataTracker:
                  model_type: str,
                  multimodal_server_config: Optional[MultimodalServerConfig] = None):
         self._model_type = model_type
-        self._data = defaultdict[str](list)
-        self._placeholder_counts = defaultdict[str](int)
+        self._data = defaultdict[str, list](list)
+        self._embeddings = defaultdict[str, list](list)
+        self._placeholder_counts = defaultdict[str, int](int)
         self._multimodal_server_config = multimodal_server_config if multimodal_server_config is not None else MultimodalServerConfig(
         )

-    async def retrieve_all_async(self) -> Optional[Dict[str, List[Any]]]:
-        """Retrieve all collected multimodal data."""
-        if not self._data:
-            return None
-        return {
-            modality: await asyncio.gather(*items)
-            for modality, items in self._data.items()
-        }
+    async def retrieve_all_async(
+        self
+    ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
+        """Retrieve all collected multimodal data and embeddings."""
+
+        async def _retrieve(
+                data: Optional[dict[str,
+                                    list]]) -> Optional[Dict[str, List[Any]]]:
+            if not data:
+                return None
+            return {
+                modality: await asyncio.gather(*items)
+                for modality, items in data.items() if items
+            }
+
+        return await _retrieve(self._data), await _retrieve(self._embeddings)

-    def retrieve_all_sync(self) -> Optional[Dict[str, List[Any]]]:
-        """Retrieve all collected multimodal data."""
-        if not self._data:
-            return None
-        return {modality: items for modality, items in self._data.items()}
+    def retrieve_all_sync(
+        self
+    ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
+        """Retrieve all collected multimodal data and embeddings."""
+
+        def _retrieve(
+                data: Optional[dict[str,
+                                    list]]) -> Optional[Dict[str, List[Any]]]:
+            if not data:
+                return None
+            return {
+                modality: items
+                for modality, items in data.items() if items
+            }
+
+        return _retrieve(self._data), _retrieve(self._embeddings)

-    def add_data(self,
-                 media_type: str,
-                 data: Union[Coroutine, Any],
-                 *,
-                 modality: Optional[str] = None):
-        modality = modality or media_type
-        current_count = len(self._data[media_type]) + 1
+    def add_data(self, media_type: str, data: Union[Coroutine, Any], *,
+                 is_embedding: bool):
+        current_count = len(self._data[media_type]) + len(
+            self._embeddings[media_type]) + 1
         placeholder = retrieve_multimodal_placeholder(self._model_type,
-                                                      modality, current_count)
-        self._data[media_type].append(data)
+                                                      media_type, current_count)
+        (self._embeddings
+         if is_embedding else self._data)[media_type].append(data)
         if placeholder:
             self._placeholder_counts[placeholder] += 1
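Aside: a minimal, runnable sketch of the tracker behavior introduced here (a simplified stand-in, not the real class; placeholder counting and async retrieval omitted): add_data routes on the keyword-only is_embedding flag into separate per-modality buckets, and retrieve_all_sync now returns a (data, embeddings) pair for callers to unpack.

from collections import defaultdict
from typing import Any, Dict, List, Optional


class TrackerSketch:
    """Simplified stand-in for MultimodalDataTracker (sketch only)."""

    def __init__(self) -> None:
        self._data: Dict[str, list] = defaultdict(list)
        self._embeddings: Dict[str, list] = defaultdict(list)

    def add_data(self, media_type: str, data: Any, *, is_embedding: bool) -> None:
        # Route on the keyword-only flag, as the real add_data now does.
        (self._embeddings if is_embedding else self._data)[media_type].append(data)

    def retrieve_all_sync(
        self
    ) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
        def _retrieve(d: Dict[str, list]) -> Optional[Dict[str, List[Any]]]:
            if not d:
                return None
            return {m: items for m, items in d.items() if items}

        # Callers now always unpack a (data, embeddings) pair.
        return _retrieve(self._data), _retrieve(self._embeddings)


tracker = TrackerSketch()
tracker.add_data("image", "raw-image", is_embedding=False)
tracker.add_data("image", "precomputed-embedding", is_embedding=True)
data, embeddings = tracker.retrieve_all_sync()
assert data == {"image": ["raw-image"]}
assert embeddings == {"image": ["precomputed-embedding"]}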
@@ -657,33 +674,34 @@ def default_multimodal_input_loader(
             media = [media]
         if modality in ["image", "multiple_image"]:
             if is_embedding:
-                # each mm_embedding corresponds to each image placeholder
-                if not isinstance(media, list):
-                    media = [media]
-                mm_data = [{
-                    'modality': modality,
-                    'mm_embedding_info': mm
-                } for mm in media]
+                _load = lambda mm: mm
             else:
-                mm_data = [
-                    MultimodalData(modality=modality,
-                                   data=load_image(i,
-                                                   format=image_data_format,
-                                                   device=device))
-                    for i in media
-                ]
+                _load = lambda mm: load_image(
+                    mm, format=image_data_format, device=device)
+
+            mm_data = [
+                MultimodalData(modality=modality,
+                               data=_load(mm),
+                               is_embedding=is_embedding) for mm in media
+            ]
         elif modality == "video":
             if is_embedding:
                 raise ValueError(
                     "External embedding is not supported for video modality yet."
                 )
             mm_data = [
-                MultimodalData(modality=modality,
-                               data=load_video(i,
-                                               num_frames,
-                                               format=image_data_format,
-                                               device=device)) for i in media
+                MultimodalData(
+                    modality=modality,
+                    data=load_video(i,
+                                    num_frames,
+                                    format=image_data_format,
+                                    device=device),
+                    is_embedding=False,
+                ) for i in media
             ]
         elif modality == "audio":
             if is_embedding:
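Aside: the image branch above now picks a loader once and builds mm_data uniformly instead of constructing two divergent list shapes. A stand-alone sketch of that dispatch follows; load_image here is a stub standing in for the real helper, and its signature is assumed for illustration only.

from typing import Any


def load_image(path: str, format: str = "pt", device: str = "cpu") -> Any:
    # Stub standing in for the real image loader.
    return f"loaded:{path}:{format}:{device}"


def build_image_mm_data(media, modality, is_embedding,
                        image_data_format="pt", device="cpu"):
    if is_embedding:
        _load = lambda mm: mm  # embeddings pass through unchanged
    else:
        _load = lambda mm: load_image(
            mm, format=image_data_format, device=device)
    # One uniform construction replaces the two divergent branches.
    return [{
        "modality": modality,
        "data": _load(mm),
        "is_embedding": is_embedding,
    } for mm in media]


print(build_image_mm_data(["a.png"], "image", is_embedding=False))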
@@ -691,8 +709,11 @@ def default_multimodal_input_loader(
                     "External embedding is not supported for audio modality yet."
                 )
             mm_data = [
-                MultimodalData(modality=modality,
-                               data=load_audio(i, device=device)) for i in media
+                MultimodalData(
+                    modality=modality,
+                    data=load_audio(i, device=device),
+                    is_embedding=False,
+                ) for i in media
             ]
         elif modality == "image_audio":
             if is_embedding:
@@ -720,16 +741,22 @@ def default_multimodal_input_loader(
                     pass
             if _modal is None:
                 raise ValueError(f"Unknown matching modality: {modality}")
-            mm_data.append(MultimodalData(modality=_modal, data=data))
+            mm_data.append(
+                MultimodalData(modality=_modal,
+                               data=data,
+                               is_embedding=False))
         elif modality == "mixture_text_image":
             mm_data = []
             for m in media:
                 if m:
                     mm_data.append(
-                        MultimodalData(modality="image",
-                                       data=load_image(m,
-                                                       format=image_data_format,
-                                                       device=device)))
+                        MultimodalData(
+                            modality="image",
+                            data=load_image(m,
+                                            format=image_data_format,
+                                            device=device),
+                            is_embedding=False,
+                        ))
         else:
             raise ValueError(f"Unknown modality: {modality}")
         return ConversationMessage(role="user", content=prompt, media=mm_data)
@@ -763,17 +790,12 @@ def default_multimodal_input_loader(
                                            is_embedding)
         mm_data_tracker = MultimodalDataTracker(model_type)
         for mdata in conv["media"]:
-            # Check if mdata is a MultimodalData
-            if isinstance(mdata,
-                          dict) and "modality" in mdata and "data" in mdata:
-                mdata_modality = mdata["modality"]
-                if modality == "multiple_image":
-                    mdata_modality = "image"
-                mm_data_tracker.add_data(mdata_modality, mdata["data"])
-            else:
-                # Add embeddings to the tracker for placeholder handling
-                mm_data_tracker.add_data(mdata["modality"],
-                                         mdata["mm_embedding_info"])
+            mdata_modality = mdata["modality"]
+            if modality == "multiple_image":
+                mdata_modality = "image"
+            mm_data_tracker.add_data(mdata_modality,
+                                     mdata["data"],
+                                     is_embedding=is_embedding)
         mm_placeholder_counts = mm_data_tracker.placeholder_counts()
         prompt = conv["content"]
         if mm_placeholder_counts:
@@ -790,11 +812,13 @@ def default_multimodal_input_loader(

         if mm_placeholder_counts:
             if mm_embeddings is not None:
-                input[
+                _, input[
                     "multi_modal_embeddings"] = mm_data_tracker.retrieve_all_sync(
                     )
             else:
-                input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync()
+                input[
+                    "multi_modal_data"], _ = mm_data_tracker.retrieve_all_sync(
+                    )
         inputs.append(input)

     return inputs
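Aside: with the tuple-returning tracker, each loader output now carries exactly one of the two multimodal keys, and the unused half of the pair is discarded with `_`. Illustrative shapes only; the values and the "prompt" key are stand-ins assumed for this sketch.

# Raw-media path: the embeddings half of the tuple is discarded.
input_with_data = {
    "prompt": "<prompt with placeholders>",
    "multi_modal_data": {"image": ["<loaded image tensor>"]},
}

# Embedding path: the data half is discarded instead.
input_with_embeddings = {
    "prompt": "<prompt with placeholders>",
    "multi_modal_embeddings": {"image": ["<precomputed embedding>"]},
}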
@@ -86,12 +86,6 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
         "data", None),
 }

-# Map from content part tags used to directly provide embeddings
-# to the corresponding data modality.
-MM_EMBEDDING_MAP: dict[str, str] = {
-    "image_embeds": "image",
-}
-

 def _parse_chat_message_content_mm_part(
         part: ChatCompletionContentPartParam
@@ -111,7 +105,7 @@ def _parse_chat_message_content_mm_part(
 def parse_chat_message_content_part(
         part: ChatCompletionContentPartParam,
         mm_data_tracker: MultimodalDataTracker,
-) -> Optional[Any]:
+) -> str | MultimodalData | None:
     """Parse a single part of a chat message."""
     if isinstance(part, str):
         return part
@@ -141,7 +135,9 @@ def parse_chat_message_content_part(
             logger.error(f"Failed to load image: {str(e)}")
             return None

-        return MultimodalData(modality="image", data=load_image_async())
+        return MultimodalData(modality="image",
+                              data=load_image_async(),
+                              is_embedding=False)

     if part_type == "image_embeds":
         str_content = cast(str, content)
@@ -153,8 +149,9 @@ def parse_chat_message_content_part(
             logger.error(f"Failed to decode image data: {str(e)}")
             return None

-        return MultimodalData(modality="image_embeds",
-                              data=decode_image_embeds_async())
+        return MultimodalData(modality="image",
+                              data=decode_image_embeds_async(),
+                              is_embedding=True)

     if part_type == "video_url":
         str_content = cast(str, content)
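Aside: on the request side, an "image_embeds" content part now parses into MultimodalData(modality="image", is_embedding=True) rather than a separate "image_embeds" modality. A hypothetical request part is sketched below; the nested "data" field holding the base64 payload is inferred from the MM_PARSER_MAP entry above and is an assumption of this sketch.

part = {
    "type": "image_embeds",
    "image_embeds": {
        # Base64-encoded tensor; decoded by decode_image_embeds_async().
        "data": "<base64 payload>",
    },
}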
@@ -169,7 +166,9 @@ def parse_chat_message_content_part(
             logger.error(f"Failed to load video: {str(e)}")
             return None

-        return MultimodalData(modality="video", data=load_video_async())
+        return MultimodalData(modality="video",
+                              data=load_video_async(),
+                              is_embedding=False)

     if part_type == "audio_url":
         str_content = cast(str, content)
@@ -184,7 +183,9 @@ def parse_chat_message_content_part(
             logger.error(f"Failed to load audio: {str(e)}")
             return None

-        return MultimodalData(modality="audio", data=load_audio_async())
+        return MultimodalData(modality="audio",
+                              data=load_audio_async(),
+                              is_embedding=False)

     raise NotImplementedError(f"Unknown part type: {part_type}")
@@ -268,8 +269,9 @@ def parse_chat_messages_coroutines(
     messages: List[ChatCompletionMessageParam],
     model_config: AutoConfig,
     multimodal_server_config: Optional[MultimodalServerConfig] = None
-) -> Tuple[List[ConversationMessage], Optional[Coroutine[
-        Any, Any, Optional[Dict[str, List[Any]]]]]]:
+) -> Tuple[List[ConversationMessage], Coroutine[Any, Any, tuple[Optional[Dict[
+        str, List[Any]]], Optional[Dict[str, List[Any]]]]], list[dict[str,
+                                                                      int]]]:
     """Parse multiple chat messages and return conversation and coroutine."""
     conversation = []
     mm_placeholder_counts = []
@@ -283,8 +285,7 @@ def parse_chat_messages_coroutines(
         for mdata in parsed_msg["media"]:
             mm_data_tracker.add_data(mdata["modality"],
                                      mdata["data"],
-                                     modality=MM_EMBEDDING_MAP.get(
-                                         mdata["modality"], None))
+                                     is_embedding=mdata["is_embedding"])
         mm_placeholder_count = mm_data_tracker.placeholder_counts()
         if mm_placeholder_count:
             parsed_msg["content"] = add_multimodal_placeholders(
@@ -36,7 +36,7 @@ from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.metrics.collector import MetricsCollector
-from tensorrt_llm.serve.chat_utils import (MM_EMBEDDING_MAP, load_chat_template,
+from tensorrt_llm.serve.chat_utils import (load_chat_template,
                                            parse_chat_messages_coroutines)
 from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
 from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
@@ -556,20 +556,13 @@ class OpenAIServer:
             )
             prompt = prompt_inputs(prompt)

-            mm_data = await mm_coroutines
-            if mm_data is not None:
-                # single out directly provided embeddings
-                mm_embeds = {}
-                for tag in list(mm_data.keys()):
-                    if (modality := MM_EMBEDDING_MAP.get(tag, None)) is not None:
-                        mm_embeds[modality] = mm_data.pop(tag)
-
-                if mm_data:
-                    prompt["multi_modal_data"] = mm_data
-                if mm_embeds:
-                    prompt["multi_modal_embeddings"] = mm_embeds
-                if mm_data and mm_embeds:
-                    raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.")
+            mm_data, mm_embeddings = await mm_coroutines
+            if mm_data:
+                prompt["multi_modal_data"] = mm_data
+            if mm_embeddings:
+                prompt["multi_modal_embeddings"] = mm_embeddings
+            if mm_data and mm_embeddings:
+                raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.")

             postproc_args.reasoning_parser = self.llm.args.reasoning_parser
             postproc_args.tool_parser = self.tool_parser
@@ -670,7 +663,9 @@ class OpenAIServer:
             )
             prompt = prompt_inputs(prompt)

-            mm_data = await mm_coroutines
+            mm_data, mm_embeddings = await mm_coroutines
+            if mm_embeddings:
+                raise ValueError("Cannot use multimodal embeddings as input")
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data
