Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)
[TRTLLM-10858][feat] Multi-image support for EPD disagg (#11264)
* Why? Prior to this commit, we only supported a single multimodal input for E/P/D disaggregated serving.
* What? This commit does a minor refactor of the multimodal embedding handles that cross process boundaries to enable this. Existing unit tests are updated accordingly to test this. The `RequestOutput` has its `mm_embedding_handle` replaced in favor of `disaggregated_params`, addressing a previous TODO.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
This commit is contained in: parent 42648734b8, commit ca9537e17c
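The practical effect of this refactor is that consumers now read multimodal embedding handles from `RequestOutput.disaggregated_params.multimodal_embedding_handles` (a list, one handle per multimodal item) instead of the removed `mm_embedding_handle`. A minimal consumer-side sketch based on the updated tests; `encoder`, `inputs`, and the `SharedTensorContainer` import path are assumptions for illustration:

```python
import torch

# Import path assumed; adjust to your TensorRT-LLM version.
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer

# `encoder` is an encoder-only LLM and `inputs` are multimodal requests (illustrative).
encoder_outputs = encoder.generate(inputs)

for output in encoder_outputs:
    disagg_params = output.disaggregated_params
    assert disagg_params is not None
    # One handle per multimodal item in the request (new in this commit).
    handles = disagg_params.multimodal_embedding_handles
    # Rebuild per-item embeddings and concatenate them for the whole request.
    embeds = [SharedTensorContainer.from_dict(h).get_local_view() for h in handles]
    mm_embed = torch.cat(embeds, dim=0) if len(embeds) > 1 else embeds[0]
```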
@@ -176,7 +176,7 @@ class LlavaNextInputProcessor(BaseMultimodalInputProcessor,

         Args:
             inputs: Text prompt input container. Must contain a non-empty prompt string.
-            mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
+            mm_handles: List of multimodal embedding handles.

         Returns:
             Tuple[List[int], List[int], List[int]]:
@@ -192,12 +192,13 @@ class LlavaNextInputProcessor(BaseMultimodalInputProcessor,
         if not isinstance(mm_handles, list):
             raise ValueError("mm_handles must be a list")

-        if len(mm_handles) != 1:
-            # TODO: only support single multimodal item within a request for now
-            raise NotImplementedError(
-                "Only one mm_handle is supported for LlavaNext for now")
-        hidden_size = mm_handles[0]['tensor_size'][1]
-        assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
+        expected_hidden_size = self.config.text_config.hidden_size
+        for i, mm_handle in enumerate(mm_handles):
+            hidden_size = mm_handle['tensor_size'][1]
+            if hidden_size != expected_hidden_size:
+                raise RuntimeError(
+                    f"Multimodal embedding {i} hidden size {hidden_size} must match model hidden size {expected_hidden_size}"
+                )
         input_ids = self.tokenizer(text_prompt,
                                    return_tensors="pt").input_ids[0]

@@ -1061,7 +1061,7 @@ class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):

         Args:
             inputs: Text prompt input container. Must contain a non-empty prompt string.
-            mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
+            mm_handles: List of multimodal embedding handles.

         Returns:
             Tuple[List[int], List[int], List[int]]:
@@ -1077,12 +1077,13 @@ class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):
         if not isinstance(mm_handles, list):
             raise TypeError("mm_handles must be a list")

-        if len(mm_handles) != 1:
-            # TODO: only support single multimodal item within a request for now
-            raise NotImplementedError(
-                "Only one mm_handle is supported for Qwen2.5 VL for now")
-        hidden_size = mm_handles[0]['tensor_size'][1]
-        assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
+        expected_hidden_size = self.config.text_config.hidden_size
+        for i, mm_handle in enumerate(mm_handles):
+            hidden_size = mm_handle['tensor_size'][1]
+            if hidden_size != expected_hidden_size:
+                raise RuntimeError(
+                    f"Multimodal embedding {i} hidden size {hidden_size} must match model hidden size {expected_hidden_size}"
+                )
         input_ids = self.tokenizer(text_prompt,
                                    return_tensors="pt").input_ids[0]

@@ -360,7 +360,7 @@ class Qwen3VLInputProcessorBase(BaseMultimodalInputProcessor, BaseMultimodalDumm

         Args:
             inputs: Text prompt input container. Must contain a non-empty prompt string.
-            mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
+            mm_handles: List of multimodal embedding handles.

         Returns:
             Tuple[List[int], List[int], List[int]]:
@@ -376,19 +376,16 @@ class Qwen3VLInputProcessorBase(BaseMultimodalInputProcessor, BaseMultimodalDumm
         if not isinstance(mm_handles, list):
             raise TypeError("mm_handles must be a list")

-        if len(mm_handles) > 1:
-            # TODO: only support single multimodal item within a request for now
-            raise NotImplementedError("Only one mm_handle is supported for Qwen3 VL for now")
-
-        hidden_size = mm_handles[0]["tensor_size"][1]
         num_deepstack_levels = len(self.config.vision_config.deepstack_visual_indexes)
         # This is because, unlike previous Qwen VL models, the embeddings are concatenated with
         # feature maps from deepstack layers.
         expected_size = self.config.text_config.hidden_size * (1 + num_deepstack_levels)
-        if hidden_size != expected_size:
-            raise RuntimeError(
-                f"Expected multimodal embedding to have hidden size {expected_size}, got {hidden_size}."
-            )
+        for i, mm_handle in enumerate(mm_handles):
+            hidden_size = mm_handle["tensor_size"][1]
+            if hidden_size != expected_size:
+                raise RuntimeError(
+                    f"Expected multimodal embedding {i} to have hidden size {expected_size}, got {hidden_size}."
+                )

         input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids[0]

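For Qwen3-VL the per-handle check uses a larger expected hidden size because, as the comment in the hunk above notes, the embeddings are concatenated with feature maps from the deepstack layers. A small sketch of that arithmetic; the hidden size and deepstack indexes below are made-up values, not taken from a real checkpoint:

```python
# Illustrative numbers only; real values come from the model config.
text_hidden_size = 4096                 # self.config.text_config.hidden_size
deepstack_visual_indexes = [5, 11, 17]  # self.config.vision_config.deepstack_visual_indexes

num_deepstack_levels = len(deepstack_visual_indexes)
# Each embedding row carries the base hidden state plus one feature map per deepstack level.
expected_size = text_hidden_size * (1 + num_deepstack_levels)
assert expected_size == 16384  # for this made-up config

# A handle whose tensor_size is [num_tokens, hidden] passes the check only if
# hidden == expected_size.
mm_handle = {"tensor_size": [1024, expected_size]}
assert mm_handle["tensor_size"][1] == expected_size
```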
@@ -246,7 +246,7 @@ class PyResult:
                                list[float] | None] | None = None
     log_probs_list: list[tuple[list[TokenLogprobs], list[float]
                                | None]] = field(default_factory=list)
-    mm_embeddings: dict[str, Any] | None = None
+    mm_embeddings: list[dict[str, Any] | None] = None
     mrope_position_ids: dict[str, Any] | None = None
     mrope_position_deltas: dict[str, Any] | None = None
     additional_context_outputs_list: list[tuple[str, torch.Tensor]] = field(
@@ -289,7 +289,7 @@ class PyResult:
             use_chunked_generation_logits=use_chunked_generation_logits,
             chunk_size=self._chunk_size) if return_generation_logits else None
         self._log_probs = LogProbStorage() if return_log_probs else None
-        self._mm_embeddings = None
+        self._mm_embeddings: Optional[List[Dict[str, Any]]] = None
         self._mrope_position_ids = None
         self._mrope_position_deltas = None
         self._additional_context_outputs = {
@@ -362,9 +362,22 @@ class PyResult:
             self._log_probs.append(log_probs, cum_log_probs)
             self.diff.log_probs_list.append((log_probs, cum_log_probs))

-    def append_mm_embeddings(self, mm_embeddings: torch.Tensor):
-        self._mm_embeddings = SharedTensorContainer.from_tensor(
-            mm_embeddings).dump_to_dict()
+    def append_mm_embeddings(self, mm_embeddings: torch.Tensor,
+                             multimodal_lengths: List[int]):
+        """Split concatenated embeddings by multimodal_lengths and create handles for each.
+
+        Args:
+            mm_embeddings: Concatenated multimodal embeddings tensor of shape [total_tokens, hidden_dim]
+            multimodal_lengths: List of token lengths for each multimodal item
+        """
+        # Split the concatenated tensor by lengths to get per-item embeddings
+        split_embeddings = torch.split(mm_embeddings, multimodal_lengths, dim=0)
+
+        # Create a SharedTensorContainer handle for each split
+        self._mm_embeddings = [
+            SharedTensorContainer.from_tensor(emb).dump_to_dict()
+            for emb in split_embeddings
+        ]
         self.diff.mm_embeddings = self._mm_embeddings

     def set_mrope_position(
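The new `append_mm_embeddings` relies on `torch.split` with a list of per-item lengths to partition the concatenated embeddings, and the consumer side joins the reconstructed pieces back with `torch.cat`. A standalone sketch of that round trip; the lengths and hidden size below are illustrative, not taken from a real model:

```python
import torch

# Concatenated embeddings for one request with two multimodal items (illustrative sizes).
multimodal_lengths = [576, 1024]  # tokens per multimodal item
hidden_dim = 4096
mm_embeddings = torch.randn(sum(multimodal_lengths), hidden_dim)

# Producer side: split into one tensor per multimodal item, as append_mm_embeddings does.
split_embeddings = torch.split(mm_embeddings, multimodal_lengths, dim=0)
assert [t.shape[0] for t in split_embeddings] == multimodal_lengths

# Consumer side: rebuild the request-level tensor from the per-item pieces.
rebuilt = torch.cat(split_embeddings, dim=0)
assert torch.equal(rebuilt, mm_embeddings)
```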
@@ -441,7 +454,8 @@ class PyResult:
         return self._log_probs and self._log_probs.cum_log_probs

     @property
-    def mm_embedding_handle(self) -> Dict[str, Any] | None:
+    def mm_embedding_handles(self) -> List[Dict[str, Any]] | None:
+        """Returns a list of SharedTensorContainer handles, one per multimodal item."""
         return self._mm_embeddings

     @property
@@ -485,7 +499,7 @@ class LlmResult:
     """LlmResult wraps `bindings.executor.Result` but detour some features to Python implementation"""
     py_result_properties = frozenset(
         ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
-         'mm_embedding_handle', 'additional_context_outputs',
+         'mm_embedding_handles', 'additional_context_outputs',
          'additional_generation_outputs', 'mrope_position_ids_handle',
         'mrope_position_deltas_handle'))

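Only the renamed entry in `py_result_properties` comes from this diff; the sketch below illustrates the kind of delegation such a property list suggests, namely that the listed names are served from the Python-side result rather than the C++ bindings result. The classes and the `__getattr__` forwarding scheme are assumptions for illustration, not the actual `LlmResult` implementation:

```python
from typing import Any


class PyResultSketch:
    """Stand-in for PyResult with the renamed property (illustrative only)."""

    def __init__(self, handles):
        self._mm_embeddings = handles

    @property
    def mm_embedding_handles(self):
        return self._mm_embeddings


class LlmResultSketch:
    """Illustrative wrapper: names in py_result_properties are read from the Python result."""

    py_result_properties = frozenset(
        ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
         'mm_embedding_handles', 'additional_context_outputs',
         'additional_generation_outputs', 'mrope_position_ids_handle',
         'mrope_position_deltas_handle'))

    def __init__(self, cpp_result: Any, py_result: PyResultSketch):
        self._result = cpp_result
        self._py_result = py_result

    def __getattr__(self, name: str):
        # Assumed forwarding scheme: detour the listed names to the Python result.
        if name in self.py_result_properties:
            return getattr(self._py_result, name)
        return getattr(self._result, name)


result = LlmResultSketch(cpp_result=object(),
                         py_result=PyResultSketch([{"tensor_size": [64, 4096]}]))
assert result.mm_embedding_handles[0]["tensor_size"] == [64, 4096]
```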
@@ -326,7 +326,7 @@ class EarlyStopWithMMResult(Sampler[SampleStateWithMMResult]):
                 f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"
             )

-            request.py_result.append_mm_embeddings(mm_embedding)
+            request.py_result.append_mm_embeddings(mm_embedding, request.multimodal_lengths)

             # Store mrope data if available
             if mrope_position_ids is not None and mrope_position_deltas is not None:
@@ -170,7 +170,7 @@ class GenerationResultBase:
         self.id = id
         self.sampling_params = sampling_params
         self.postproc_params = postproc_params
-        self.disaggregated_params = None
+        self._disaggregated_params = None
         self.decoding_iter = 0
         self.cached_tokens = 0
         # Average decoded tokens per runtime iteration; set when the first LLM response arrives.
@@ -196,7 +196,6 @@ class GenerationResultBase:
             CompletionOutput(i) for i in range(self.sampling_params.best_of)
         ]
         self._context_logits: Optional[torch.Tensor] = None
-        self._mm_embedding_handle: Optional[Dict[str, Any]] = None

         self._background_error_handler = None
         if background_error_handler is not None:
@@ -234,9 +233,9 @@ class GenerationResultBase:
         return self._context_logits

     @property
-    # TODO: Keep this property only for backward compatibility. In the future, access multimodal embedding handles from disaggregated_params instead.
-    def mm_embedding_handle(self) -> Optional[Dict[str, Any]]:
-        return self._mm_embedding_handle
+    def disaggregated_params(self) -> Optional[DisaggregatedParams]:
+        """Returns the disaggregated params."""
+        return self._disaggregated_params

     def _handle_sequence(self,
                          finish_reasons,
@@ -420,7 +419,7 @@ class GenerationResultBase:
         # Use `replace` to preserve things like `mrope_position_ids_handle` and
         # `mrope_position_deltas_handle`. However, explicitly set
         # `multimodal_embedding_handles=None` since they should no longer be needed.
-        self.disaggregated_params = dataclasses.replace(
+        self._disaggregated_params = dataclasses.replace(
             existing_disagg_params or DisaggregatedParams(),
             request_type="context_only",
             first_gen_tokens=context_phase_params.first_gen_tokens,
@@ -447,19 +446,16 @@ class GenerationResultBase:
         if response_result.context_logits is not None:
             self._context_logits = response_result.context_logits

-        if hasattr(response_result, 'mm_embedding_handle'
-                   ) and response_result.mm_embedding_handle is not None:
-            self._mm_embedding_handle = response_result.mm_embedding_handle
-            if self.disaggregated_params is not None:
-                self.disaggregated_params.multimodal_embedding_handles = [
-                    response_result.mm_embedding_handle
-                ],
-                self.disaggregated_params.multimodal_hashes = self._multimodal_hashes
+        if hasattr(response_result, "mm_embedding_handles"
+                   ) and response_result.mm_embedding_handles is not None:
+            # mm_embedding_handles is a list of handles (one per multimodal item).
+            mm_embedding_handles = response_result.mm_embedding_handles
+            if self._disaggregated_params is not None:
+                self._disaggregated_params.multimodal_embedding_handles = mm_embedding_handles
+                self._disaggregated_params.multimodal_hashes = self._multimodal_hashes
             else:
-                self.disaggregated_params = DisaggregatedParams(
-                    multimodal_embedding_handles=[
-                        response_result.mm_embedding_handle
-                    ],
+                self._disaggregated_params = DisaggregatedParams(
+                    multimodal_embedding_handles=mm_embedding_handles,
                     multimodal_hashes=self._multimodal_hashes)

         # Handle mrope handles for both:
@@ -468,9 +464,9 @@ class GenerationResultBase:
         # were already computed in encode phase, but mrope still needs forwarding)
         if (getattr(response_result, "mrope_position_ids_handle", None)
                 is not None and self.disaggregated_params is not None):
-            self.disaggregated_params.mrope_position_ids_handle = (
+            self._disaggregated_params.mrope_position_ids_handle = (
                 response_result.mrope_position_ids_handle)
-            self.disaggregated_params.mrope_position_deltas_handle = (
+            self._disaggregated_params.mrope_position_deltas_handle = (
                 response_result.mrope_position_deltas_handle)

         # Processing background errors here ASAF during generation.
@@ -716,7 +712,7 @@ class GenerationResult(GenerationResultBase):
         )
         self._generation_request = generation_request
         self._streaming = generation_request.streaming
-        self.disaggregated_params = disaggregated_params
+        self._disaggregated_params = disaggregated_params
         # minimal sampling params needed for logprob calculation
         self._logprob_params = logprob_params
         self.trace_headers = generation_request.trace_headers
@@ -837,7 +833,7 @@ class GenerationResult(GenerationResultBase):
             'outputs',
             'finished',
             "context_logits",
-            "mm_embedding_handle",
+            "disaggregated_params",
         ]

     def __repr__(self) -> str:
@@ -60,7 +60,7 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
         prompt_token_ids (List[int]): The token ids of the prompt.
         outputs (List[CompletionOutput]): The output sequences of the request.
         context_logits (torch.Tensor, optional): The logits on the prompt token ids.
-        mm_embedding_handle (Dict[str, Any], optional): The multimodal embedding handle of the request.
+        disaggregated_params (DisaggregatedParams, optional): Parameters for disaggregated serving, including multimodal embedding handles.
         finished (bool): Whether the whole request is finished.
     """

@@ -94,7 +94,7 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
             "prompt_token_ids",
             "outputs",
             "finished",
-            "mm_embedding_handle",
+            "disaggregated_params",
         ]

@@ -630,14 +630,28 @@ class OpenAIServer:

         async def create_mm_embedding_response(promise: RequestOutput):
             await promise.aresult()
-            # TODO: Replace mm_embedding_handle with a dedicated OpenAIBaseModel(JSON-safe), when enable multimodal disagg E2E
-            mm_embedding_handle = getattr(promise, "mm_embedding_handle", None)
-            if not mm_embedding_handle or "tensor_size" not in mm_embedding_handle:
+            # TODO: Replace mm_embedding_handles with a dedicated OpenAIBaseModel(JSON-safe), when enable multimodal disagg E2E
+            mm_embedding_handles = (
+                promise.disaggregated_params.multimodal_embedding_handles
+                if promise.disaggregated_params
+                else None
+            )
+            if not mm_embedding_handles:
                 return self.create_error_response(
                     message="Multimodal embedding handle missing in response",
                     err_type="InternalServerError",
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-            num_tokens = int(mm_embedding_handle["tensor_size"][0])
+            if any("tensor_size" not in h for h in mm_embedding_handles):
+                return self.create_error_response(
+                    message="Multimodal embedding handle missing tensor_size",
+                    err_type="InternalServerError",
+                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+            mm_embedding_handle = (
+                mm_embedding_handles[0]
+                if len(mm_embedding_handles) == 1
+                else mm_embedding_handles
+            )
+            num_tokens = sum(int(h["tensor_size"][0]) for h in mm_embedding_handles)
             return ChatCompletionResponse(
                 id=str(promise.request_id),
                 model=self.model,
@@ -122,9 +122,13 @@ def test_get_num_tokens_per_image(model_key, multimodal_model_configs):
         image_width, image_height = test_image.size

         # Get actual embedding tensor for this image
+        disagg_params = encoder_outputs[image_idx].disaggregated_params
+        assert disagg_params is not None
+        mm_embedding_handles = disagg_params.multimodal_embedding_handles
+        assert mm_embedding_handles is not None
+        assert len(mm_embedding_handles) == 1
         actual_embedding = SharedTensorContainer.from_dict(
-            encoder_outputs[image_idx].mm_embedding_handle).get_local_view(
-            )
+            mm_embedding_handles[0]).get_local_view()

         # The first dimension should be the number of image tokens
         actual_num_tokens = actual_embedding.shape[0]
@@ -230,9 +234,13 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs):
         video_width, video_height = video_data.frames[0].size

         # Get actual embedding tensor for this image
+        disagg_params = encoder_outputs[video_idx].disaggregated_params
+        assert disagg_params is not None
+        mm_embedding_handles = disagg_params.multimodal_embedding_handles
+        assert mm_embedding_handles is not None
+        assert len(mm_embedding_handles) == 1
         actual_embedding = SharedTensorContainer.from_dict(
-            encoder_outputs[video_idx].mm_embedding_handle).get_local_view(
-            )
+            mm_embedding_handles[0]).get_local_view()

         # The first dimension should be the number of image tokens
         actual_num_tokens = actual_embedding.shape[0]

@@ -2,7 +2,6 @@ import copy
 import json
 import os
 import time
-from itertools import product
 from pathlib import Path
 from typing import Generator

@@ -193,8 +192,12 @@ def llms(model_dir: Path,
     """Get LLM for prefill and, if disagg, separate LLM for decode."""
     free_gpu_memory_fraction = 0.2
     disable_overlap_scheduler = pd_disagg
+    # NOTE: if the number of tokens that need to pass from P -> D exceeds `max_tokens_in_buffer`,
+    # one may see the following error:
+    # >>> tensorrt_llm.executor.utils.RequestError: Error in kv cache transfer for generation
+    # requests.
     cache_transceiver_cfg = CacheTransceiverConfig(
-        backend="DEFAULT") if pd_disagg else None
+        backend="DEFAULT", max_tokens_in_buffer=10240) if pd_disagg else None
     kv_cache_config = KvCacheConfig(
         enable_block_reuse=False,  # Disable for output 1:1 matching check
         free_gpu_memory_fraction=free_gpu_memory_fraction,
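For context, a minimal sketch of the prefill/decode configuration this fixture builds. The import path and the `llm_kwargs` name are assumptions for illustration, and the buffer size is simply the value chosen in this test, not a general recommendation:

```python
# Import path assumed; adjust to where these configs live in your TensorRT-LLM version.
from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig

pd_disagg = True

# Raise max_tokens_in_buffer so multi-image requests fit in the P -> D transfer buffer;
# otherwise the KV-cache transfer for generation requests can fail with a RequestError.
cache_transceiver_cfg = CacheTransceiverConfig(
    backend="DEFAULT", max_tokens_in_buffer=10240) if pd_disagg else None

kv_cache_config = KvCacheConfig(
    enable_block_reuse=False,  # disable reuse so outputs match 1:1 across runs
    free_gpu_memory_fraction=0.2,
)

# Hypothetical bundle of keyword arguments forwarded to the LLM constructor.
llm_kwargs = dict(
    kv_cache_config=kv_cache_config,
    cache_transceiver_config=cache_transceiver_cfg,
)
```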
@@ -290,14 +293,13 @@ def _assert_handles_are_different(x: dict | None, y: dict | None) -> None:
         assert x[key] != y[key]


-# TODO: Add multi-image in single chat test
 @pytest.mark.threadleak(enabled=False)
-def test_single_image_chat(
+def test_single_request_chat_multiple_images(
     pd_disagg: bool,
     model_dir: Path,
     llms: tuple[LLM, LLM | None],
 ):
-    """Test processing single image using encoder (pass mm_embeddings) + LLM API.
+    """Test processing a single request with multiple images.

     This test verifies that encoder (pass mm_embeddings) + LLM API produces identical
     results to standard llm generation (pass raw image) by comparing outputs.
@@ -309,8 +311,8 @@ def test_single_image_chat(
     max_batch_size = 1

     # Test data - OpenAI chat completion format
-    prompts = ["Describe the natural environment in the image."]
-    media = [example_images[0]]
+    prompts = ["Compare these 2 images."]
+    media = [example_images[0], example_images[1]]

     # Sampling configuration
     sampling_params = SamplingParams(max_tokens=max_tokens)
@@ -461,8 +463,14 @@ def test_pd_disagg_with_image_input(
             f"Generated text doesn't match for output {i}, generation {j}:\nReference: {ref_gen.text!r}\nTest: {test_gen.text!r}"


-@pytest.mark.parametrize("use_mm_embeddings,pass_embeddings_through_loader",
-                         product([False, True], [False, True]))
+# Explicit combinations instead of product([False, True], [False, True]) to avoid
+# having to call `pytest.skip` within the test code itself. This saves on CI time, since `llms`
+# take a long time to instantiate.
+@pytest.mark.parametrize("use_mm_embeddings,pass_embeddings_through_loader", [
+    (False, False),
+    (True, False),
+    (True, True),
+])
 @pytest.mark.threadleak(enabled=False)
 def test_multi_request_batch_chat(
     model_dir: Path,
@@ -490,6 +498,7 @@ def test_multi_request_batch_chat(
     if llm_decode is not None:
         pytest.skip("Disagg support not implemented in test case")

+    # Guard against accidental reintroduction of invalid parameter combinations.
     if pass_embeddings_through_loader and not use_mm_embeddings:
         pytest.skip("Redundant test configuration")

@@ -522,10 +531,18 @@ def test_multi_request_batch_chat(
     encoder_outputs = encoder.generate(inputs)
     if use_mm_embeddings:
         for input, encoder_output in zip(inputs, encoder_outputs):
-            mm_embed_handle = encoder_output.mm_embedding_handle
-            assert mm_embed_handle is not None
-            mm_embed = SharedTensorContainer.from_dict(
-                mm_embed_handle).get_local_view()
+            disagg_params = encoder_output.disaggregated_params
+            assert disagg_params is not None
+            mm_embed_handles = disagg_params.multimodal_embedding_handles
+            assert mm_embed_handles is not None
+            # `mm_embed_handles` is list of handles (one per multimodal item).
+            # Reconstruct and concatenate all embeddings for this request.
+            mm_embeds = [
+                SharedTensorContainer.from_dict(handle).get_local_view()
+                for handle in mm_embed_handles
+            ]
+            mm_embed = torch.cat(
+                mm_embeds, dim=0) if len(mm_embeds) > 1 else mm_embeds[0]
             input["multi_modal_embeddings"] = {"image": mm_embed}

     if pass_embeddings_through_loader:
@@ -19,7 +19,7 @@ import yaml
 from pydantic import BaseModel

 import tensorrt_llm
-from tensorrt_llm import LLM
+from tensorrt_llm import LLM, DisaggregatedParams
 # Import BaseCheckpointLoader for YAML processing
 from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
     BaseCheckpointLoader
@@ -27,6 +27,6 @@ properties:
   finished:
     annotation: bool
     default: inspect._empty
-  mm_embedding_handle:
-    annotation: Optional[Dict[str, Any]]
+  disaggregated_params:
+    annotation: Optional[DisaggregatedParams]
     default: inspect._empty