mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[TRTLLM-9522][chore] implement default attach_multimodal_embeddings (#9664)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
This commit is contained in:
parent
7cd5a67e25
commit
8d2178d321
@ -141,6 +141,21 @@ class BaseMultimodalInputProcessor(ABC):
|
||||
self._use_fast: bool = kwargs.get('use_fast', True)
|
||||
self._multimodal_hashing_supported: Optional[bool] = None
|
||||
|
||||
def attach_multimodal_embeddings(
    self,
    inputs: TextPrompt,
    multimodal_embedding: Dict[str, List[torch.Tensor]],
    sampling_params: SamplingParams,
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
    """Process pre-computed multimodal embeddings supplied by the caller.

    Counterpart to ``__call__``: whereas ``__call__`` consumes
    ``inputs["multi_modal_data"]``, this hook is intended to consume
    ``inputs["multi_modal_embeddings"]``. Subclasses that accept
    externally provided embeddings must override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # Default behavior: external embedding input is unsupported unless a
    # subclass overrides this method.
    raise NotImplementedError(
        "Input processor does not support multimodal embedding input")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def processor(self) -> AutoProcessor:
|
||||
|
||||
@ -8,7 +8,7 @@ import time
|
||||
import weakref
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Literal, Optional, Sequence, Union
|
||||
from typing import Any, List, Literal, Optional, Sequence, Union, cast
|
||||
|
||||
import transformers
|
||||
from tqdm import tqdm
|
||||
@ -17,7 +17,8 @@ from transformers import PreTrainedTokenizerBase
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
from tensorrt_llm.inputs.data import TextPrompt
|
||||
from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams
|
||||
from tensorrt_llm.inputs.registry import DefaultInputProcessor
|
||||
from tensorrt_llm.inputs.registry import (BaseMultimodalInputProcessor,
|
||||
DefaultInputProcessor)
|
||||
from tensorrt_llm.llmapi import tracing
|
||||
from tensorrt_llm.metrics.enums import MetricNames
|
||||
|
||||
@ -458,7 +459,9 @@ class BaseLLM:
|
||||
inputs, sampling_params)
|
||||
elif 'multi_modal_embeddings' in inputs:
|
||||
mm_embedding_info = inputs['multi_modal_embeddings']
|
||||
prompt_token_ids, extra_processed_inputs = self.input_processor.attach_multimodal_embeddings(
|
||||
prompt_token_ids, extra_processed_inputs = cast(
|
||||
self.input_processor,
|
||||
BaseMultimodalInputProcessor).attach_multimodal_embeddings(
|
||||
inputs, mm_embedding_info, sampling_params)
|
||||
else:
|
||||
with nvtx_range_debug("input_processor"):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user