[TRTLLM-9522][chore] implement default attach_multimodal_embeddings (#9664)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2025-12-06 07:12:16 +01:00 committed by GitHub
parent 7cd5a67e25
commit 8d2178d321
2 changed files with 22 additions and 4 deletions

tensorrt_llm/inputs/registry.py

@@ -141,6 +141,21 @@ class BaseMultimodalInputProcessor(ABC):
         self._use_fast: bool = kwargs.get('use_fast', True)
         self._multimodal_hashing_supported: Optional[bool] = None
+
+    def attach_multimodal_embeddings(
+        self,
+        inputs: TextPrompt,
+        multimodal_embedding: Dict[str, List[torch.Tensor]],
+        sampling_params: SamplingParams,
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        """
+        Handle externally provided multimodal input embeddings.
+
+        While inputs["multi_modal_data"] is handled by __call__, this method
+        is intended to process inputs["multi_modal_embeddings"].
+        """
+        raise NotImplementedError(
+            "Input processor does not support multimodal embedding input")
 
     @property
     @abstractmethod
     def processor(self) -> AutoProcessor:
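
For context, a model-specific input processor overrides this default rather than inheriting the NotImplementedError. The skeleton below is a minimal sketch, not code from this commit: the subclass name, the assumption that the processor exposes a tokenizer, and the extra-inputs dict key are all illustrative; Optional[dict] stands in for the ExtraProcessedInputs alias, and the other abstract members of BaseMultimodalInputProcessor are omitted.

from typing import Dict, List, Optional, Tuple

import torch

from tensorrt_llm.inputs.data import TextPrompt
from tensorrt_llm.inputs.registry import BaseMultimodalInputProcessor
from tensorrt_llm.sampling_params import SamplingParams


class MyVLMInputProcessor(BaseMultimodalInputProcessor):
    """Hypothetical processor accepting precomputed multimodal embeddings."""

    def attach_multimodal_embeddings(
        self,
        inputs: TextPrompt,
        multimodal_embedding: Dict[str, List[torch.Tensor]],
        sampling_params: SamplingParams,
    ) -> Tuple[List[int], Optional[dict]]:
        # Tokenize the text prompt; placeholder tokens for the embedded
        # items are assumed to already be present in the prompt string.
        prompt_token_ids = self.tokenizer.encode(inputs["prompt"])
        # Hand the precomputed embeddings through as extra processed inputs
        # instead of running a vision encoder; the dict key is an assumption.
        return prompt_token_ids, {"multimodal_embedding": multimodal_embedding}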

tensorrt_llm/llmapi/llm.py

@@ -8,7 +8,7 @@ import time
 import weakref
 from collections.abc import Mapping
 from pathlib import Path
-from typing import Any, List, Literal, Optional, Sequence, Union
+from typing import Any, List, Literal, Optional, Sequence, Union, cast
 
 import transformers
 from tqdm import tqdm
@@ -17,7 +17,8 @@ from transformers import PreTrainedTokenizerBase
 from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.inputs.data import TextPrompt
 from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams
-from tensorrt_llm.inputs.registry import DefaultInputProcessor
+from tensorrt_llm.inputs.registry import (BaseMultimodalInputProcessor,
+                                          DefaultInputProcessor)
 from tensorrt_llm.llmapi import tracing
 from tensorrt_llm.metrics.enums import MetricNames
@@ -458,8 +459,10 @@ class BaseLLM:
                     inputs, sampling_params)
            elif 'multi_modal_embeddings' in inputs:
                mm_embedding_info = inputs['multi_modal_embeddings']
-                prompt_token_ids, extra_processed_inputs = self.input_processor.attach_multimodal_embeddings(
-                    inputs, mm_embedding_info, sampling_params)
+                prompt_token_ids, extra_processed_inputs = cast(
+                    BaseMultimodalInputProcessor,
+                    self.input_processor).attach_multimodal_embeddings(
+                        inputs, mm_embedding_info, sampling_params)
            else:
                with nvtx_range_debug("input_processor"):
                    prompt_token_ids, extra_processed_inputs = self.input_processor(
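
On the caller side, this branch runs when the prompt dict carries a 'multi_modal_embeddings' entry; the cast only narrows the static type for the checker, and a processor that does not override the method hits the new NotImplementedError default at runtime. Below is a hedged usage sketch: the model path, prompt text, modality key, and tensor shapes are placeholders, not values from this commit.

import torch

from tensorrt_llm import LLM, SamplingParams

# Placeholder: embeddings precomputed by an external vision encoder.
# The "image" key and the (num_tokens, hidden_size) shape are assumptions.
image_embeddings = [torch.randn(576, 4096, dtype=torch.float16)]

llm = LLM(model="/path/to/a/multimodal/model")  # placeholder path
prompt = {
    "prompt": "<image>\nDescribe the picture.",
    # The presence of this key routes BaseLLM through
    # attach_multimodal_embeddings rather than the raw-media path.
    "multi_modal_embeddings": {"image": image_embeddings},
}
outputs = llm.generate([prompt], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)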