diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py
index 32832682e4..7737600e6f 100644
--- a/tensorrt_llm/inputs/registry.py
+++ b/tensorrt_llm/inputs/registry.py
@@ -141,6 +141,21 @@ class BaseMultimodalInputProcessor(ABC):
         self._use_fast: bool = kwargs.get('use_fast', True)
         self._multimodal_hashing_supported: Optional[bool] = None
 
+    def attach_multimodal_embeddings(
+        self,
+        inputs: TextPrompt,
+        multimodal_embedding: Dict[str, List[torch.Tensor]],
+        sampling_params: SamplingParams,
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        """
+        Handle externally provided multimodal input embeddings.
+
+        Whereas inputs["multi_modal_data"] is processed by __call__, this
+        method handles inputs["multi_modal_embeddings"].
+        """
+        raise NotImplementedError(
+            "Input processor does not support multimodal embedding input")
+
     @property
     @abstractmethod
     def processor(self) -> AutoProcessor:
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 4e5272e20f..41c9bdeeae 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -8,7 +8,7 @@ import time
 import weakref
 from collections.abc import Mapping
 from pathlib import Path
-from typing import Any, List, Literal, Optional, Sequence, Union
+from typing import Any, List, Literal, Optional, Sequence, Union, cast
 
 import transformers
 from tqdm import tqdm
@@ -17,7 +17,8 @@ from transformers import PreTrainedTokenizerBase
 
 from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.inputs.data import TextPrompt
 from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams
-from tensorrt_llm.inputs.registry import DefaultInputProcessor
+from tensorrt_llm.inputs.registry import (BaseMultimodalInputProcessor,
+                                          DefaultInputProcessor)
 from tensorrt_llm.llmapi import tracing
 from tensorrt_llm.metrics.enums import MetricNames
@@ -458,8 +459,10 @@ class BaseLLM:
                     inputs, sampling_params)
             elif 'multi_modal_embeddings' in inputs:
                 mm_embedding_info = inputs['multi_modal_embeddings']
-                prompt_token_ids, extra_processed_inputs = self.input_processor.attach_multimodal_embeddings(
-                    inputs, mm_embedding_info, sampling_params)
+                prompt_token_ids, extra_processed_inputs = cast(
+                    BaseMultimodalInputProcessor,
+                    self.input_processor).attach_multimodal_embeddings(
+                        inputs, mm_embedding_info, sampling_params)
             else:
                 with nvtx_range_debug("input_processor"):
                     prompt_token_ids, extra_processed_inputs = self.input_processor(
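
For illustration, here is a minimal sketch of how a concrete processor might override the new hook. `ExampleInputProcessor`, the `self.tokenizer` access, and the pass-through of embeddings via the extra-inputs dict are assumptions for this sketch, not part of the diff; a real subclass must also implement the remaining abstract members of `BaseMultimodalInputProcessor` (e.g. the `processor` property), and real models typically need model-specific placeholder-token handling here.

```python
# Hypothetical sketch -- class name, tokenizer access, and extra-inputs keys
# are illustrative assumptions, not part of this diff.
from typing import Dict, List, Optional, Tuple

import torch

from tensorrt_llm.inputs.data import TextPrompt
from tensorrt_llm.inputs.registry import BaseMultimodalInputProcessor
from tensorrt_llm.sampling_params import SamplingParams


class ExampleInputProcessor(BaseMultimodalInputProcessor):

    def attach_multimodal_embeddings(
        self,
        inputs: TextPrompt,
        multimodal_embedding: Dict[str, List[torch.Tensor]],
        sampling_params: SamplingParams,
    ) -> Tuple[List[int], Optional[dict]]:
        # Tokenize the text prompt as usual; plain `dict` stands in for
        # ExtraProcessedInputs in this sketch.
        token_ids = self.tokenizer.encode(inputs["prompt"])
        # Pass the precomputed embeddings through so the runtime can
        # substitute them at the model's multimodal placeholder positions.
        extra_processed_inputs = {"multimodal_embedding": multimodal_embedding}
        return token_ids, extra_processed_inputs
```

At the call site, `BaseLLM` dispatches to this hook whenever the prompt dict carries a `multi_modal_embeddings` key. Note that `cast(BaseMultimodalInputProcessor, self.input_processor)` only narrows the static type and has no runtime effect, so the base implementation's `NotImplementedError` remains the runtime guard for processors that do not support embedding input.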