Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[TRTLLM-9522][chore] implement default attach_multimodal_embeddings (#9664)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
Parent: 7cd5a67e25
Commit: 8d2178d321
@@ -141,6 +141,21 @@ class BaseMultimodalInputProcessor(ABC):
         self._use_fast: bool = kwargs.get('use_fast', True)
         self._multimodal_hashing_supported: Optional[bool] = None
 
+    def attach_multimodal_embeddings(
+        self,
+        inputs: TextPrompt,
+        multimodal_embedding: Dict[str, List[torch.Tensor]],
+        sampling_params: SamplingParams,
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        """
+        Handle externally provided multimodal input embeddings.
+
+        While inputs["multi_modal_data"] is handled by __call__, this method is intended to process
+        inputs["multi_modal_embeddings"].
+        """
+        raise NotImplementedError(
+            "Input processor does not support multimodal embedding input")
+
     @property
     @abstractmethod
     def processor(self) -> AutoProcessor:
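The base-class default above only raises NotImplementedError, so model-specific input processors are expected to override it. Below is a minimal sketch of what such an override could look like; the subclass name, the constructor signature, the tokenizer wiring, the "image" modality key, and the layout of the returned extras are illustrative assumptions, not part of this commit (other abstract members of the base class, such as the processor property, are omitted).

```python
# Hedged sketch of a model-specific override. The class name, constructor,
# tokenizer attribute, "image" key, and extras dict layout are assumptions
# for illustration; they are not defined by this commit.
from typing import Dict, List

import torch

from tensorrt_llm import SamplingParams
from tensorrt_llm.inputs.data import TextPrompt
from tensorrt_llm.inputs.registry import BaseMultimodalInputProcessor


class ExampleVLMInputProcessor(BaseMultimodalInputProcessor):

    def __init__(self, tokenizer, **kwargs):
        super().__init__(**kwargs)
        self._tokenizer = tokenizer

    def attach_multimodal_embeddings(
        self,
        inputs: TextPrompt,
        multimodal_embedding: Dict[str, List[torch.Tensor]],
        sampling_params: SamplingParams,
    ):
        # Tokenize the text prompt and pass the precomputed image embeddings
        # through as extra processed inputs instead of recomputing them from
        # raw images (which is what __call__ does for multi_modal_data).
        token_ids: List[int] = self._tokenizer.encode(inputs["prompt"])
        extra = {"multimodal_embedding": multimodal_embedding.get("image", [])}
        return token_ids, extra
```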
@@ -8,7 +8,7 @@ import time
 import weakref
 from collections.abc import Mapping
 from pathlib import Path
-from typing import Any, List, Literal, Optional, Sequence, Union
+from typing import Any, List, Literal, Optional, Sequence, Union, cast
 
 import transformers
 from tqdm import tqdm
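For context, the newly imported cast is typing.cast: it takes the target type first and the value second, returns the value unchanged at runtime, and only affects static type checking. A small standalone illustration (the class names are generic, not from the commit):

```python
from typing import cast


class Processor:
    pass


class MultimodalProcessor(Processor):

    def attach(self) -> str:
        return "attached"


proc: Processor = MultimodalProcessor()
# cast(target_type, value) returns `value` unchanged; it only tells the type
# checker to treat `proc` as MultimodalProcessor from here on.
print(cast(MultimodalProcessor, proc).attach())  # prints "attached"
```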
@@ -17,7 +17,8 @@ from transformers import PreTrainedTokenizerBase
 from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.inputs.data import TextPrompt
 from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams
-from tensorrt_llm.inputs.registry import DefaultInputProcessor
+from tensorrt_llm.inputs.registry import (BaseMultimodalInputProcessor,
+                                          DefaultInputProcessor)
 from tensorrt_llm.llmapi import tracing
 from tensorrt_llm.metrics.enums import MetricNames
 
@@ -458,8 +459,10 @@ class BaseLLM:
                        inputs, sampling_params)
            elif 'multi_modal_embeddings' in inputs:
                mm_embedding_info = inputs['multi_modal_embeddings']
-                prompt_token_ids, extra_processed_inputs = self.input_processor.attach_multimodal_embeddings(
-                    inputs, mm_embedding_info, sampling_params)
+                prompt_token_ids, extra_processed_inputs = cast(
+                    BaseMultimodalInputProcessor,
+                    self.input_processor).attach_multimodal_embeddings(
+                        inputs, mm_embedding_info, sampling_params)
            else:
                with nvtx_range_debug("input_processor"):
                    prompt_token_ids, extra_processed_inputs = self.input_processor(
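With the dispatch above, a prompt that carries a multi_modal_embeddings entry is routed to attach_multimodal_embeddings rather than the ordinary input-processor __call__. A rough end-to-end usage sketch follows, assuming a multimodal model whose input processor implements the hook; the model name, placeholder token, tensor shape, and "image" key are placeholders rather than values taken from this commit.

```python
# Hedged usage sketch: supplying precomputed vision embeddings with a text
# prompt. The model name, placeholder token, tensor shape, and "image" key
# are assumptions for illustration.
import torch

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-VL-7B-Instruct")

# One embedding tensor per image in the prompt, shaped
# (num_image_tokens, hidden_size); values here are random placeholders.
image_embeds = torch.randn(576, 3584, dtype=torch.bfloat16)

outputs = llm.generate(
    [{
        "prompt": "<image>\nDescribe the picture.",
        "multi_modal_embeddings": {"image": [image_embeds]},
    }],
    sampling_params=SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```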