[TRTLLM-10858][feat] Multi-image support for EPD disagg (#11264)

* Why?

Prior to this commit, E/P/D disaggregated serving only supported a
single multimodal input per request.

* What?

This commit does a minor refactor of the multimodal embedding handles
that cross process boundaries so that a single request can carry
several of them. Existing unit tests are updated to cover the
multi-image case.

`RequestOutput` has its `mm_embedding_handle` property replaced by
`disaggregated_params`, which now carries the embedding handles; this
addresses a previous TODO.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Author: William Zhang, 2026-02-11 20:50:00 -08:00 (committed by GitHub)
Commit: ca9537e17c (parent: 42648734b8)
12 changed files with 128 additions and 80 deletions
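
For orientation, a minimal sketch of how a consumer now reads the per-item
embedding handles from the encoder's `RequestOutput` (mirroring the updated
test code below; `SharedTensorContainer` is assumed to be imported, and
`gather_mm_embeddings` is an illustrative helper, not part of this change):

import torch

def gather_mm_embeddings(encoder_output):
    # `encoder_output` is a RequestOutput from the encoder LLM. After this change
    # its embedding handles live on `disaggregated_params` rather than the removed
    # `mm_embedding_handle` property, one handle per multimodal item.
    disagg_params = encoder_output.disaggregated_params
    assert disagg_params is not None
    handles = disagg_params.multimodal_embedding_handles
    assert handles is not None
    # Reconstruct each per-item embedding and concatenate them for the request.
    embeds = [SharedTensorContainer.from_dict(h).get_local_view() for h in handles]
    return torch.cat(embeds, dim=0) if len(embeds) > 1 else embeds[0]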

View File

@@ -176,7 +176,7 @@ class LlavaNextInputProcessor(BaseMultimodalInputProcessor,
Args:
inputs: Text prompt input container. Must contain a non-empty prompt string.
mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
mm_handles: List of multimodal embedding handles.
Returns:
Tuple[List[int], List[int], List[int]]:
@@ -192,12 +192,13 @@ class LlavaNextInputProcessor(BaseMultimodalInputProcessor,
if not isinstance(mm_handles, list):
raise ValueError("mm_handles must be a list")
if len(mm_handles) != 1:
# TODO: only support single multimodal item within a request for now
raise NotImplementedError(
"Only one mm_handle is supported for LlavaNext for now")
hidden_size = mm_handles[0]['tensor_size'][1]
assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
expected_hidden_size = self.config.text_config.hidden_size
for i, mm_handle in enumerate(mm_handles):
hidden_size = mm_handle['tensor_size'][1]
if hidden_size != expected_hidden_size:
raise RuntimeError(
f"Multimodal embedding {i} hidden size {hidden_size} must match model hidden size {expected_hidden_size}"
)
input_ids = self.tokenizer(text_prompt,
return_tensors="pt").input_ids[0]

View File

@@ -1061,7 +1061,7 @@ class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):
Args:
inputs: Text prompt input container. Must contain a non-empty prompt string.
mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
mm_handles: List of multimodal embedding handles.
Returns:
Tuple[List[int], List[int], List[int]]:
@@ -1077,12 +1077,13 @@ class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):
if not isinstance(mm_handles, list):
raise TypeError("mm_handles must be a list")
if len(mm_handles) != 1:
# TODO: only support single multimodal item within a request for now
raise NotImplementedError(
"Only one mm_handle is supported for Qwen2.5 VL for now")
hidden_size = mm_handles[0]['tensor_size'][1]
assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
expected_hidden_size = self.config.text_config.hidden_size
for i, mm_handle in enumerate(mm_handles):
hidden_size = mm_handle['tensor_size'][1]
if hidden_size != expected_hidden_size:
raise RuntimeError(
f"Multimodal embedding {i} hidden size {hidden_size} must match model hidden size {expected_hidden_size}"
)
input_ids = self.tokenizer(text_prompt,
return_tensors="pt").input_ids[0]

View File

@@ -360,7 +360,7 @@ class Qwen3VLInputProcessorBase(BaseMultimodalInputProcessor, BaseMultimodalDumm
Args:
inputs: Text prompt input container. Must contain a non-empty prompt string.
mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
mm_handles: List of multimodal embedding handles.
Returns:
Tuple[List[int], List[int], List[int]]:
@@ -376,19 +376,16 @@ class Qwen3VLInputProcessorBase(BaseMultimodalInputProcessor, BaseMultimodalDumm
if not isinstance(mm_handles, list):
raise TypeError("mm_handles must be a list")
if len(mm_handles) > 1:
# TODO: only support single multimodal item within a request for now
raise NotImplementedError("Only one mm_handle is supported for Qwen3 VL for now")
hidden_size = mm_handles[0]["tensor_size"][1]
num_deepstack_levels = len(self.config.vision_config.deepstack_visual_indexes)
# This is because, unlike previous Qwen VL models, the embeddings are concatenated with
# feature maps from deepstack layers.
expected_size = self.config.text_config.hidden_size * (1 + num_deepstack_levels)
if hidden_size != expected_size:
raise RuntimeError(
f"Expected multimodal embedding to have hidden size {expected_size}, got {hidden_size}."
)
for i, mm_handle in enumerate(mm_handles):
hidden_size = mm_handle["tensor_size"][1]
if hidden_size != expected_size:
raise RuntimeError(
f"Expected multimodal embedding {i} to have hidden size {expected_size}, got {hidden_size}."
)
input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids[0]
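
To make the Qwen3 VL size check above concrete, a small worked example (the
numbers are illustrative, not actual Qwen3 VL config values):

# Hypothetical values for illustration only.
text_hidden_size = 2048
num_deepstack_levels = 3  # stands in for len(config.vision_config.deepstack_visual_indexes)

# Each embedding row carries the base hidden state plus one feature map per
# deepstack level, concatenated along the hidden dimension.
expected_size = text_hidden_size * (1 + num_deepstack_levels)  # 8192

# A handle whose tensor_size is [num_tokens, 8192] passes the check above;
# [num_tokens, 2048] would raise RuntimeError.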

View File

@@ -246,7 +246,7 @@ class PyResult:
list[float] | None] | None = None
log_probs_list: list[tuple[list[TokenLogprobs], list[float]
| None]] = field(default_factory=list)
mm_embeddings: dict[str, Any] | None = None
mm_embeddings: list[dict[str, Any] | None] = None
mrope_position_ids: dict[str, Any] | None = None
mrope_position_deltas: dict[str, Any] | None = None
additional_context_outputs_list: list[tuple[str, torch.Tensor]] = field(
@@ -289,7 +289,7 @@ class PyResult:
use_chunked_generation_logits=use_chunked_generation_logits,
chunk_size=self._chunk_size) if return_generation_logits else None
self._log_probs = LogProbStorage() if return_log_probs else None
self._mm_embeddings = None
self._mm_embeddings: Optional[List[Dict[str, Any]]] = None
self._mrope_position_ids = None
self._mrope_position_deltas = None
self._additional_context_outputs = {
@@ -362,9 +362,22 @@ class PyResult:
self._log_probs.append(log_probs, cum_log_probs)
self.diff.log_probs_list.append((log_probs, cum_log_probs))
def append_mm_embeddings(self, mm_embeddings: torch.Tensor):
self._mm_embeddings = SharedTensorContainer.from_tensor(
mm_embeddings).dump_to_dict()
def append_mm_embeddings(self, mm_embeddings: torch.Tensor,
multimodal_lengths: List[int]):
"""Split concatenated embeddings by multimodal_lengths and create handles for each.
Args:
mm_embeddings: Concatenated multimodal embeddings tensor of shape [total_tokens, hidden_dim]
multimodal_lengths: List of token lengths for each multimodal item
"""
# Split the concatenated tensor by lengths to get per-item embeddings
split_embeddings = torch.split(mm_embeddings, multimodal_lengths, dim=0)
# Create a SharedTensorContainer handle for each split
self._mm_embeddings = [
SharedTensorContainer.from_tensor(emb).dump_to_dict()
for emb in split_embeddings
]
self.diff.mm_embeddings = self._mm_embeddings
def set_mrope_position(
@@ -441,7 +454,8 @@ class PyResult:
return self._log_probs and self._log_probs.cum_log_probs
@property
def mm_embedding_handle(self) -> Dict[str, Any] | None:
def mm_embedding_handles(self) -> List[Dict[str, Any]] | None:
"""Returns a list of SharedTensorContainer handles, one per multimodal item."""
return self._mm_embeddings
@property
@@ -485,7 +499,7 @@ class LlmResult:
"""LlmResult wraps `bindings.executor.Result` but detour some features to Python implementation"""
py_result_properties = frozenset(
('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
'mm_embedding_handle', 'additional_context_outputs',
'mm_embedding_handles', 'additional_context_outputs',
'additional_generation_outputs', 'mrope_position_ids_handle',
'mrope_position_deltas_handle'))
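
As a sanity check on the splitting logic in `append_mm_embeddings` above, a
self-contained sketch of the split/reassemble round trip (plain torch only,
without the SharedTensorContainer handles; the lengths and hidden size are made up):

import torch

# Two multimodal items of 576 and 1024 tokens, concatenated as produced by the encoder.
multimodal_lengths = [576, 1024]
hidden_dim = 4096  # illustrative only
mm_embeddings = torch.randn(sum(multimodal_lengths), hidden_dim)

# What append_mm_embeddings now does: one chunk (and hence one handle) per item.
split_embeddings = torch.split(mm_embeddings, multimodal_lengths, dim=0)
assert [e.shape[0] for e in split_embeddings] == multimodal_lengths

# The consumer-side inverse: concatenating the per-item views restores the original.
assert torch.equal(torch.cat(split_embeddings, dim=0), mm_embeddings)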

View File

@@ -326,7 +326,7 @@ class EarlyStopWithMMResult(Sampler[SampleStateWithMMResult]):
f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"
)
request.py_result.append_mm_embeddings(mm_embedding)
request.py_result.append_mm_embeddings(mm_embedding, request.multimodal_lengths)
# Store mrope data if available
if mrope_position_ids is not None and mrope_position_deltas is not None:

View File

@@ -170,7 +170,7 @@ class GenerationResultBase:
self.id = id
self.sampling_params = sampling_params
self.postproc_params = postproc_params
self.disaggregated_params = None
self._disaggregated_params = None
self.decoding_iter = 0
self.cached_tokens = 0
# Average decoded tokens per runtime iteration; set when the first LLM response arrives.
@@ -196,7 +196,6 @@ class GenerationResultBase:
CompletionOutput(i) for i in range(self.sampling_params.best_of)
]
self._context_logits: Optional[torch.Tensor] = None
self._mm_embedding_handle: Optional[Dict[str, Any]] = None
self._background_error_handler = None
if background_error_handler is not None:
@@ -234,9 +233,9 @@ class GenerationResultBase:
return self._context_logits
@property
# TODO: Keep this property only for backward compatibility. In the future, access multimodal embedding handles from disaggregated_params instead.
def mm_embedding_handle(self) -> Optional[Dict[str, Any]]:
return self._mm_embedding_handle
def disaggregated_params(self) -> Optional[DisaggregatedParams]:
"""Returns the disaggregated params."""
return self._disaggregated_params
def _handle_sequence(self,
finish_reasons,
@@ -420,7 +419,7 @@ class GenerationResultBase:
# Use `replace` to preserve things like `mrope_position_ids_handle` and
# `mrope_position_deltas_handle`. However, explicitly set
# `multimodal_embedding_handles=None` since they should no longer be needed.
self.disaggregated_params = dataclasses.replace(
self._disaggregated_params = dataclasses.replace(
existing_disagg_params or DisaggregatedParams(),
request_type="context_only",
first_gen_tokens=context_phase_params.first_gen_tokens,
@@ -447,19 +446,16 @@ class GenerationResultBase:
if response_result.context_logits is not None:
self._context_logits = response_result.context_logits
if hasattr(response_result, 'mm_embedding_handle'
) and response_result.mm_embedding_handle is not None:
self._mm_embedding_handle = response_result.mm_embedding_handle
if self.disaggregated_params is not None:
self.disaggregated_params.multimodal_embedding_handles = [
response_result.mm_embedding_handle
],
self.disaggregated_params.multimodal_hashes = self._multimodal_hashes
if hasattr(response_result, "mm_embedding_handles"
) and response_result.mm_embedding_handles is not None:
# mm_embedding_handles is a list of handles (one per multimodal item).
mm_embedding_handles = response_result.mm_embedding_handles
if self._disaggregated_params is not None:
self._disaggregated_params.multimodal_embedding_handles = mm_embedding_handles
self._disaggregated_params.multimodal_hashes = self._multimodal_hashes
else:
self.disaggregated_params = DisaggregatedParams(
multimodal_embedding_handles=[
response_result.mm_embedding_handle
],
self._disaggregated_params = DisaggregatedParams(
multimodal_embedding_handles=mm_embedding_handles,
multimodal_hashes=self._multimodal_hashes)
# Handle mrope handles for both:
@@ -468,9 +464,9 @@ class GenerationResultBase:
# were already computed in encode phase, but mrope still needs forwarding)
if (getattr(response_result, "mrope_position_ids_handle", None)
is not None and self.disaggregated_params is not None):
self.disaggregated_params.mrope_position_ids_handle = (
self._disaggregated_params.mrope_position_ids_handle = (
response_result.mrope_position_ids_handle)
self.disaggregated_params.mrope_position_deltas_handle = (
self._disaggregated_params.mrope_position_deltas_handle = (
response_result.mrope_position_deltas_handle)
# Processing background errors here ASAP during generation.
@@ -716,7 +712,7 @@ class GenerationResult(GenerationResultBase):
)
self._generation_request = generation_request
self._streaming = generation_request.streaming
self.disaggregated_params = disaggregated_params
self._disaggregated_params = disaggregated_params
# minimal sampling params needed for logprob calculation
self._logprob_params = logprob_params
self.trace_headers = generation_request.trace_headers
@@ -837,7 +833,7 @@ class GenerationResult(GenerationResultBase):
'outputs',
'finished',
"context_logits",
"mm_embedding_handle",
"disaggregated_params",
]
def __repr__(self) -> str:

View File

@@ -60,7 +60,7 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
prompt_token_ids (List[int]): The token ids of the prompt.
outputs (List[CompletionOutput]): The output sequences of the request.
context_logits (torch.Tensor, optional): The logits on the prompt token ids.
mm_embedding_handle (Dict[str, Any], optional): The multimodal embedding handle of the request.
disaggregated_params (DisaggregatedParams, optional): Parameters for disaggregated serving, including multimodal embedding handles.
finished (bool): Whether the whole request is finished.
"""
@@ -94,7 +94,7 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
"prompt_token_ids",
"outputs",
"finished",
"mm_embedding_handle",
"disaggregated_params",
]

View File

@@ -630,14 +630,28 @@ class OpenAIServer:
async def create_mm_embedding_response(promise: RequestOutput):
await promise.aresult()
# TODO: Replace mm_embedding_handle with a dedicated OpenAIBaseModel(JSON-safe), when enable multimodal disagg E2E
mm_embedding_handle = getattr(promise, "mm_embedding_handle", None)
if not mm_embedding_handle or "tensor_size" not in mm_embedding_handle:
# TODO: Replace mm_embedding_handles with a dedicated OpenAIBaseModel (JSON-safe) once multimodal disagg is enabled E2E
mm_embedding_handles = (
promise.disaggregated_params.multimodal_embedding_handles
if promise.disaggregated_params
else None
)
if not mm_embedding_handles:
return self.create_error_response(
message="Multimodal embedding handle missing in response",
err_type="InternalServerError",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
num_tokens = int(mm_embedding_handle["tensor_size"][0])
if any("tensor_size" not in h for h in mm_embedding_handles):
return self.create_error_response(
message="Multimodal embedding handle missing tensor_size",
err_type="InternalServerError",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
mm_embedding_handle = (
mm_embedding_handles[0]
if len(mm_embedding_handles) == 1
else mm_embedding_handles
)
num_tokens = sum(int(h["tensor_size"][0]) for h in mm_embedding_handles)
return ChatCompletionResponse(
id=str(promise.request_id),
model=self.model,
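
The reported token count is now aggregated across all of the handles; a tiny
illustration with made-up tensor_size values:

# Hypothetical handles; only the tensor_size metadata is inspected here.
mm_embedding_handles = [
    {"tensor_size": [576, 4096]},   # image 1: 576 tokens
    {"tensor_size": [1024, 4096]},  # image 2: 1024 tokens
]
num_tokens = sum(int(h["tensor_size"][0]) for h in mm_embedding_handles)  # 1600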

View File

@@ -122,9 +122,13 @@ def test_get_num_tokens_per_image(model_key, multimodal_model_configs):
image_width, image_height = test_image.size
# Get actual embedding tensor for this image
disagg_params = encoder_outputs[image_idx].disaggregated_params
assert disagg_params is not None
mm_embedding_handles = disagg_params.multimodal_embedding_handles
assert mm_embedding_handles is not None
assert len(mm_embedding_handles) == 1
actual_embedding = SharedTensorContainer.from_dict(
encoder_outputs[image_idx].mm_embedding_handle).get_local_view(
)
mm_embedding_handles[0]).get_local_view()
# The first dimension should be the number of image tokens
actual_num_tokens = actual_embedding.shape[0]
@@ -230,9 +234,13 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs):
video_width, video_height = video_data.frames[0].size
# Get actual embedding tensor for this image
disagg_params = encoder_outputs[video_idx].disaggregated_params
assert disagg_params is not None
mm_embedding_handles = disagg_params.multimodal_embedding_handles
assert mm_embedding_handles is not None
assert len(mm_embedding_handles) == 1
actual_embedding = SharedTensorContainer.from_dict(
encoder_outputs[video_idx].mm_embedding_handle).get_local_view(
)
mm_embedding_handles[0]).get_local_view()
# The first dimension should be the number of image tokens
actual_num_tokens = actual_embedding.shape[0]

View File

@@ -2,7 +2,6 @@ import copy
import json
import os
import time
from itertools import product
from pathlib import Path
from typing import Generator
@@ -193,8 +192,12 @@ def llms(model_dir: Path,
"""Get LLM for prefill and, if disagg, separate LLM for decode."""
free_gpu_memory_fraction = 0.2
disable_overlap_scheduler = pd_disagg
# NOTE: if the number of tokens that need to pass from P -> D exceeds `max_tokens_in_buffer`,
# one may see the following error:
# >>> tensorrt_llm.executor.utils.RequestError: Error in kv cache transfer for generation
# requests.
cache_transceiver_cfg = CacheTransceiverConfig(
backend="DEFAULT") if pd_disagg else None
backend="DEFAULT", max_tokens_in_buffer=10240) if pd_disagg else None
kv_cache_config = KvCacheConfig(
enable_block_reuse=False, # Disable for output 1:1 matching check
free_gpu_memory_fraction=free_gpu_memory_fraction,
@@ -290,14 +293,13 @@ def _assert_handles_are_different(x: dict | None, y: dict | None) -> None:
assert x[key] != y[key]
# TODO: Add multi-image in single chat test
@pytest.mark.threadleak(enabled=False)
def test_single_image_chat(
def test_single_request_chat_multiple_images(
pd_disagg: bool,
model_dir: Path,
llms: tuple[LLM, LLM | None],
):
"""Test processing single image using encoder (pass mm_embeddings) + LLM API.
"""Test processing a single request with multiple images.
This test verifies that encoder (pass mm_embeddings) + LLM API produces identical
results to standard llm generation (pass raw image) by comparing outputs.
@@ -309,8 +311,8 @@ def test_single_image_chat(
max_batch_size = 1
# Test data - OpenAI chat completion format
prompts = ["Describe the natural environment in the image."]
media = [example_images[0]]
prompts = ["Compare these 2 images."]
media = [example_images[0], example_images[1]]
# Sampling configuration
sampling_params = SamplingParams(max_tokens=max_tokens)
@@ -461,8 +463,14 @@ def test_pd_disagg_with_image_input(
f"Generated text doesn't match for output {i}, generation {j}:\nReference: {ref_gen.text!r}\nTest: {test_gen.text!r}"
@pytest.mark.parametrize("use_mm_embeddings,pass_embeddings_through_loader",
product([False, True], [False, True]))
# Explicit combinations instead of product([False, True], [False, True]) to avoid
# having to call `pytest.skip` within the test code itself. This saves on CI time, since `llms`
# take a long time to instantiate.
@pytest.mark.parametrize("use_mm_embeddings,pass_embeddings_through_loader", [
(False, False),
(True, False),
(True, True),
])
@pytest.mark.threadleak(enabled=False)
def test_multi_request_batch_chat(
model_dir: Path,
@@ -490,6 +498,7 @@ def test_multi_request_batch_chat(
if llm_decode is not None:
pytest.skip("Disagg support not implemented in test case")
# Guard against accidental reintroduction of invalid parameter combinations.
if pass_embeddings_through_loader and not use_mm_embeddings:
pytest.skip("Redundant test configuration")
@@ -522,10 +531,18 @@ def test_multi_request_batch_chat(
encoder_outputs = encoder.generate(inputs)
if use_mm_embeddings:
for input, encoder_output in zip(inputs, encoder_outputs):
mm_embed_handle = encoder_output.mm_embedding_handle
assert mm_embed_handle is not None
mm_embed = SharedTensorContainer.from_dict(
mm_embed_handle).get_local_view()
disagg_params = encoder_output.disaggregated_params
assert disagg_params is not None
mm_embed_handles = disagg_params.multimodal_embedding_handles
assert mm_embed_handles is not None
# `mm_embed_handles` is a list of handles (one per multimodal item).
# Reconstruct and concatenate all embeddings for this request.
mm_embeds = [
SharedTensorContainer.from_dict(handle).get_local_view()
for handle in mm_embed_handles
]
mm_embed = torch.cat(
mm_embeds, dim=0) if len(mm_embeds) > 1 else mm_embeds[0]
input["multi_modal_embeddings"] = {"image": mm_embed}
if pass_embeddings_through_loader:

View File

@@ -19,7 +19,7 @@ import yaml
from pydantic import BaseModel
import tensorrt_llm
from tensorrt_llm import LLM
from tensorrt_llm import LLM, DisaggregatedParams
# Import BaseCheckpointLoader for YAML processing
from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
BaseCheckpointLoader

View File

@@ -27,6 +27,6 @@ properties:
finished:
annotation: bool
default: inspect._empty
mm_embedding_handle:
annotation: Optional[Dict[str, Any]]
disaggregated_params:
annotation: Optional[DisaggregatedParams]
default: inspect._empty