diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py
index 844c0d8958..1c27498f7a 100644
--- a/tensorrt_llm/_torch/models/modeling_llava_next.py
+++ b/tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -176,7 +176,7 @@ class LlavaNextInputProcessor(BaseMultimodalInputProcessor,
 
         Args:
            inputs: Text prompt input container. Must contain a non-empty prompt string.
-            mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
+            mm_handles: List of multimodal embedding handles.
 
        Returns:
            Tuple[List[int], List[int], List[int]]:
@@ -192,12 +192,13 @@ class LlavaNextInputProcessor(BaseMultimodalInputProcessor,
         if not isinstance(mm_handles, list):
            raise ValueError("mm_handles must be a list")
 
-        if len(mm_handles) != 1:
-            # TODO: only support single multimodal item within a request for now
-            raise NotImplementedError(
-                "Only one mm_handle is supported for LlavaNext for now")
-        hidden_size = mm_handles[0]['tensor_size'][1]
-        assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
+        expected_hidden_size = self.config.text_config.hidden_size
+        for i, mm_handle in enumerate(mm_handles):
+            hidden_size = mm_handle['tensor_size'][1]
+            if hidden_size != expected_hidden_size:
+                raise RuntimeError(
+                    f"Multimodal embedding {i} hidden size {hidden_size} must match model hidden size {expected_hidden_size}"
+                )
 
         input_ids = self.tokenizer(text_prompt,
                                    return_tensors="pt").input_ids[0]
diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
index 4ed21f34be..4b465bdd64 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -1061,7 +1061,7 @@ class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):
 
         Args:
            inputs: Text prompt input container. Must contain a non-empty prompt string.
-            mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
+            mm_handles: List of multimodal embedding handles.
 
        Returns:
            Tuple[List[int], List[int], List[int]]:
@@ -1077,12 +1077,13 @@ class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):
         if not isinstance(mm_handles, list):
            raise TypeError("mm_handles must be a list")
 
-        if len(mm_handles) != 1:
-            # TODO: only support single multimodal item within a request for now
-            raise NotImplementedError(
-                "Only one mm_handle is supported for Qwen2.5 VL for now")
-        hidden_size = mm_handles[0]['tensor_size'][1]
-        assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
+        expected_hidden_size = self.config.text_config.hidden_size
+        for i, mm_handle in enumerate(mm_handles):
+            hidden_size = mm_handle['tensor_size'][1]
+            if hidden_size != expected_hidden_size:
+                raise RuntimeError(
+                    f"Multimodal embedding {i} hidden size {hidden_size} must match model hidden size {expected_hidden_size}"
+                )
 
         input_ids = self.tokenizer(text_prompt,
                                    return_tensors="pt").input_ids[0]
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3vl.py b/tensorrt_llm/_torch/models/modeling_qwen3vl.py
index e9f77153d9..312e5ea745 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3vl.py
@@ -360,7 +360,7 @@ class Qwen3VLInputProcessorBase(BaseMultimodalInputProcessor, BaseMultimodalDumm
 
         Args:
            inputs: Text prompt input container. Must contain a non-empty prompt string.
-            mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
+            mm_handles: List of multimodal embedding handles.
 
        Returns:
            Tuple[List[int], List[int], List[int]]:
@@ -376,19 +376,16 @@ class Qwen3VLInputProcessorBase(BaseMultimodalInputProcessor, BaseMultimodalDumm
         if not isinstance(mm_handles, list):
            raise TypeError("mm_handles must be a list")
 
-        if len(mm_handles) > 1:
-            # TODO: only support single multimodal item within a request for now
-            raise NotImplementedError("Only one mm_handle is supported for Qwen3 VL for now")
-
-        hidden_size = mm_handles[0]["tensor_size"][1]
         num_deepstack_levels = len(self.config.vision_config.deepstack_visual_indexes)
         # This is because, unlike previous Qwen VL models, the embeddings are concatenated with
         # feature maps from deepstack layers.
         expected_size = self.config.text_config.hidden_size * (1 + num_deepstack_levels)
-        if hidden_size != expected_size:
-            raise RuntimeError(
-                f"Expected multimodal embedding to have hidden size {expected_size}, got {hidden_size}."
-            )
+        for i, mm_handle in enumerate(mm_handles):
+            hidden_size = mm_handle["tensor_size"][1]
+            if hidden_size != expected_size:
+                raise RuntimeError(
+                    f"Expected multimodal embedding {i} to have hidden size {expected_size}, got {hidden_size}."
+                )
 
         input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids[0]
 
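Note on the three input-processor changes above: they all replace the single-handle restriction with the same per-handle hidden-size check (for Qwen3-VL the expected size is `hidden_size * (1 + num_deepstack_levels)`, as in the hunk). The sketch below restates that pattern in isolation; it assumes only that each handle is a dict with a `tensor_size` entry of `[num_tokens, hidden_size]`, which is what the handles in this PR carry. The helper name and the literal sizes are illustrative, not part of the patch.

```python
from typing import Any


def validate_mm_handles(mm_handles: list[dict[str, Any]],
                        expected_hidden_size: int) -> None:
    """Reject any multimodal embedding handle whose hidden size differs from the model's."""
    if not isinstance(mm_handles, list):
        raise TypeError("mm_handles must be a list")
    for i, mm_handle in enumerate(mm_handles):
        hidden_size = mm_handle["tensor_size"][1]
        if hidden_size != expected_hidden_size:
            raise RuntimeError(
                f"Multimodal embedding {i} hidden size {hidden_size} "
                f"must match model hidden size {expected_hidden_size}")


# Two images (576 and 1024 tokens) checked against a hypothetical hidden size of 4096.
validate_mm_handles([{"tensor_size": [576, 4096]}, {"tensor_size": [1024, 4096]}],
                    expected_hidden_size=4096)
```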
diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index 490db48b08..bc10d00364 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -246,7 +246,7 @@ class PyResult:
                                            list[float] | None] | None = None
     log_probs_list: list[tuple[list[TokenLogprobs],
                                list[float] | None]] = field(default_factory=list)
-    mm_embeddings: dict[str, Any] | None = None
+    mm_embeddings: list[dict[str, Any]] | None = None
     mrope_position_ids: dict[str, Any] | None = None
     mrope_position_deltas: dict[str, Any] | None = None
     additional_context_outputs_list: list[tuple[str, torch.Tensor]] = field(
@@ -289,7 +289,7 @@ class PyResult:
             use_chunked_generation_logits=use_chunked_generation_logits,
             chunk_size=self._chunk_size) if return_generation_logits else None
         self._log_probs = LogProbStorage() if return_log_probs else None
-        self._mm_embeddings = None
+        self._mm_embeddings: Optional[List[Dict[str, Any]]] = None
         self._mrope_position_ids = None
         self._mrope_position_deltas = None
         self._additional_context_outputs = {
@@ -362,9 +362,22 @@ class PyResult:
         self._log_probs.append(log_probs, cum_log_probs)
         self.diff.log_probs_list.append((log_probs, cum_log_probs))
 
-    def append_mm_embeddings(self, mm_embeddings: torch.Tensor):
-        self._mm_embeddings = SharedTensorContainer.from_tensor(
-            mm_embeddings).dump_to_dict()
+    def append_mm_embeddings(self, mm_embeddings: torch.Tensor,
+                             multimodal_lengths: List[int]):
+        """Split concatenated embeddings by multimodal_lengths and create handles for each.
+
+        Args:
+            mm_embeddings: Concatenated multimodal embeddings tensor of shape [total_tokens, hidden_dim]
+            multimodal_lengths: List of token lengths for each multimodal item
+        """
+        # Split the concatenated tensor by lengths to get per-item embeddings
+        split_embeddings = torch.split(mm_embeddings, multimodal_lengths, dim=0)
+
+        # Create a SharedTensorContainer handle for each split
+        self._mm_embeddings = [
+            SharedTensorContainer.from_tensor(emb).dump_to_dict()
+            for emb in split_embeddings
+        ]
         self.diff.mm_embeddings = self._mm_embeddings
 
     def set_mrope_position(
@@ -441,7 +454,8 @@ class PyResult:
         return self._log_probs and self._log_probs.cum_log_probs
 
     @property
-    def mm_embedding_handle(self) -> Dict[str, Any] | None:
+    def mm_embedding_handles(self) -> List[Dict[str, Any]] | None:
+        """Returns a list of SharedTensorContainer handles, one per multimodal item."""
         return self._mm_embeddings
 
     @property
@@ -485,7 +499,7 @@ class LlmResult:
     """LlmResult wraps `bindings.executor.Result` but detour some features to Python implementation"""
     py_result_properties = frozenset(
         ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
-         'mm_embedding_handle', 'additional_context_outputs',
+         'mm_embedding_handles', 'additional_context_outputs',
          'additional_generation_outputs', 'mrope_position_ids_handle',
          'mrope_position_deltas_handle'))
 
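The core behavioral change in `PyResult.append_mm_embeddings` is split-then-wrap: one fused `[total_tokens, hidden]` tensor becomes one handle per multimodal item. Below is a self-contained sketch of just the splitting step; a stand-in dict replaces `SharedTensorContainer` so the snippet runs without the TensorRT-LLM runtime, and the sizes are made up for illustration.

```python
import torch


def split_and_wrap(mm_embeddings: torch.Tensor,
                   multimodal_lengths: list[int]) -> list[dict]:
    # torch.split with a list of lengths yields one view per multimodal item,
    # in the same order the items appeared in the request.
    split_embeddings = torch.split(mm_embeddings, multimodal_lengths, dim=0)
    # The real code wraps each view via SharedTensorContainer.from_tensor(...).dump_to_dict();
    # recording the per-item shape is enough to illustrate the handle layout.
    return [{"tensor_size": list(emb.shape)} for emb in split_embeddings]


# Two items of 576 and 1024 tokens fused into a single [1600, 4096] tensor.
handles = split_and_wrap(torch.zeros(1600, 4096), [576, 1024])
assert [h["tensor_size"] for h in handles] == [[576, 4096], [1024, 4096]]
```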
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 30943f978e..db6210da8d 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -326,7 +326,7 @@ class EarlyStopWithMMResult(Sampler[SampleStateWithMMResult]):
                     f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"
                 )
 
-            request.py_result.append_mm_embeddings(mm_embedding)
+            request.py_result.append_mm_embeddings(mm_embedding, request.multimodal_lengths)
 
             # Store mrope data if available
             if mrope_position_ids is not None and mrope_position_deltas is not None:
diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py
index 56f6598168..9291608c7f 100644
--- a/tensorrt_llm/executor/result.py
+++ b/tensorrt_llm/executor/result.py
@@ -170,7 +170,7 @@ class GenerationResultBase:
         self.id = id
         self.sampling_params = sampling_params
         self.postproc_params = postproc_params
-        self.disaggregated_params = None
+        self._disaggregated_params = None
         self.decoding_iter = 0
         self.cached_tokens = 0
         # Average decoded tokens per runtime iteration; set when the first LLM response arrives.
@@ -196,7 +196,6 @@ class GenerationResultBase:
             CompletionOutput(i) for i in range(self.sampling_params.best_of)
         ]
         self._context_logits: Optional[torch.Tensor] = None
-        self._mm_embedding_handle: Optional[Dict[str, Any]] = None
         self._background_error_handler = None
 
         if background_error_handler is not None:
@@ -234,9 +233,9 @@ class GenerationResultBase:
         return self._context_logits
 
     @property
-    # TODO: Keep this property only for backward compatibility. In the future, access multimodal embedding handles from disaggregated_params instead.
-    def mm_embedding_handle(self) -> Optional[Dict[str, Any]]:
-        return self._mm_embedding_handle
+    def disaggregated_params(self) -> Optional[DisaggregatedParams]:
+        """Returns the disaggregated params."""
+        return self._disaggregated_params
 
     def _handle_sequence(self,
                          finish_reasons,
@@ -420,7 +419,7 @@ class GenerationResultBase:
         # Use `replace` to preserve things like `mrope_position_ids_handle` and
         # `mrope_position_deltas_handle`. However, explicitly set
         # `multimodal_embedding_handles=None` since they should no longer be needed.
-        self.disaggregated_params = dataclasses.replace(
+        self._disaggregated_params = dataclasses.replace(
             existing_disagg_params or DisaggregatedParams(),
             request_type="context_only",
             first_gen_tokens=context_phase_params.first_gen_tokens,
@@ -447,19 +446,16 @@ class GenerationResultBase:
         if response_result.context_logits is not None:
             self._context_logits = response_result.context_logits
 
-        if hasattr(response_result, 'mm_embedding_handle'
-                   ) and response_result.mm_embedding_handle is not None:
-            self._mm_embedding_handle = response_result.mm_embedding_handle
-            if self.disaggregated_params is not None:
-                self.disaggregated_params.multimodal_embedding_handles = [
-                    response_result.mm_embedding_handle
-                ],
-                self.disaggregated_params.multimodal_hashes = self._multimodal_hashes
+        if hasattr(response_result, "mm_embedding_handles"
+                   ) and response_result.mm_embedding_handles is not None:
+            # mm_embedding_handles is a list of handles (one per multimodal item).
+            mm_embedding_handles = response_result.mm_embedding_handles
+            if self._disaggregated_params is not None:
+                self._disaggregated_params.multimodal_embedding_handles = mm_embedding_handles
+                self._disaggregated_params.multimodal_hashes = self._multimodal_hashes
            else:
-                self.disaggregated_params = DisaggregatedParams(
-                    multimodal_embedding_handles=[
-                        response_result.mm_embedding_handle
-                    ],
+                self._disaggregated_params = DisaggregatedParams(
+                    multimodal_embedding_handles=mm_embedding_handles,
                     multimodal_hashes=self._multimodal_hashes)
 
         # Handle mrope handles for both:
@@ -468,9 +464,9 @@
         # were already computed in encode phase, but mrope still needs forwarding)
         if (getattr(response_result, "mrope_position_ids_handle", None) is not None
                 and self.disaggregated_params is not None):
-            self.disaggregated_params.mrope_position_ids_handle = (
+            self._disaggregated_params.mrope_position_ids_handle = (
                 response_result.mrope_position_ids_handle)
-            self.disaggregated_params.mrope_position_deltas_handle = (
+            self._disaggregated_params.mrope_position_deltas_handle = (
                 response_result.mrope_position_deltas_handle)
 
         # Processing background errors here ASAF during generation.
@@ -716,7 +712,7 @@ class GenerationResult(GenerationResultBase):
         )
         self._generation_request = generation_request
         self._streaming = generation_request.streaming
-        self.disaggregated_params = disaggregated_params
+        self._disaggregated_params = disaggregated_params
         # minimal sampling params needed for logprob calculation
         self._logprob_params = logprob_params
         self.trace_headers = generation_request.trace_headers
@@ -837,7 +833,7 @@ class GenerationResult(GenerationResultBase):
             'outputs',
             'finished',
             "context_logits",
-            "mm_embedding_handle",
+            "disaggregated_params",
         ]
 
     def __repr__(self) -> str:
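The `executor/result.py` change removes the standalone `mm_embedding_handle` property; callers now reach the handles through the read-only `disaggregated_params` property. A hedged migration sketch follows; the `result` object is stubbed with `SimpleNamespace` so the snippet runs standalone, but the attribute access matches the names introduced by this patch.

```python
from types import SimpleNamespace

# Stand-in for a finished encoder GenerationResult.
result = SimpleNamespace(disaggregated_params=SimpleNamespace(
    multimodal_embedding_handles=[{"tensor_size": [576, 4096]}],
    multimodal_hashes=None))

# Old (removed by this patch):
#     handle = result.mm_embedding_handle
# New:
disagg_params = result.disaggregated_params
handles = (disagg_params.multimodal_embedding_handles
           if disagg_params is not None else None)
assert handles is not None and handles[0]["tensor_size"] == [576, 4096]
```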
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 415751e7f2..e95a0d4551 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -60,7 +60,7 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
         prompt_token_ids (List[int]): The token ids of the prompt.
         outputs (List[CompletionOutput]): The output sequences of the request.
         context_logits (torch.Tensor, optional): The logits on the prompt token ids.
-        mm_embedding_handle (Dict[str, Any], optional): The multimodal embedding handle of the request.
+        disaggregated_params (DisaggregatedParams, optional): Parameters for disaggregated serving, including multimodal embedding handles.
         finished (bool): Whether the whole request is finished.
     """
 
@@ -94,7 +94,7 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
             "prompt_token_ids",
             "outputs",
             "finished",
-            "mm_embedding_handle",
+            "disaggregated_params",
         ]
 
 
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index eee3b17de6..0301fa5b40 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -630,14 +630,28 @@ class OpenAIServer:
 
         async def create_mm_embedding_response(promise: RequestOutput):
             await promise.aresult()
-            # TODO: Replace mm_embedding_handle with a dedicated OpenAIBaseModel(JSON-safe), when enable multimodal disagg E2E
-            mm_embedding_handle = getattr(promise, "mm_embedding_handle", None)
-            if not mm_embedding_handle or "tensor_size" not in mm_embedding_handle:
+            # TODO: Replace mm_embedding_handles with a dedicated OpenAIBaseModel (JSON-safe) once multimodal disagg is enabled E2E
+            mm_embedding_handles = (
+                promise.disaggregated_params.multimodal_embedding_handles
+                if promise.disaggregated_params
+                else None
+            )
+            if not mm_embedding_handles:
                 return self.create_error_response(
                     message="Multimodal embedding handle missing in response",
                     err_type="InternalServerError",
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-            num_tokens = int(mm_embedding_handle["tensor_size"][0])
+            if any("tensor_size" not in h for h in mm_embedding_handles):
+                return self.create_error_response(
+                    message="Multimodal embedding handle missing tensor_size",
+                    err_type="InternalServerError",
+                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+            mm_embedding_handle = (
+                mm_embedding_handles[0]
+                if len(mm_embedding_handles) == 1
+                else mm_embedding_handles
+            )
+            num_tokens = sum(int(h["tensor_size"][0]) for h in mm_embedding_handles)
             return ChatCompletionResponse(
                 id=str(promise.request_id),
                 model=self.model,
diff --git a/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py b/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py
index 688e3d0ef4..f60954a3e1 100644
--- a/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py
+++ b/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py
@@ -122,9 +122,13 @@ def test_get_num_tokens_per_image(model_key, multimodal_model_configs):
         image_width, image_height = test_image.size
 
         # Get actual embedding tensor for this image
+        disagg_params = encoder_outputs[image_idx].disaggregated_params
+        assert disagg_params is not None
+        mm_embedding_handles = disagg_params.multimodal_embedding_handles
+        assert mm_embedding_handles is not None
+        assert len(mm_embedding_handles) == 1
         actual_embedding = SharedTensorContainer.from_dict(
-            encoder_outputs[image_idx].mm_embedding_handle).get_local_view(
-            )
+            mm_embedding_handles[0]).get_local_view()
 
         # The first dimension should be the number of image tokens
         actual_num_tokens = actual_embedding.shape[0]
@@ -230,9 +234,13 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs):
         video_width, video_height = video_data.frames[0].size
 
         # Get actual embedding tensor for this image
+        disagg_params = encoder_outputs[video_idx].disaggregated_params
+        assert disagg_params is not None
+        mm_embedding_handles = disagg_params.multimodal_embedding_handles
+        assert mm_embedding_handles is not None
+        assert len(mm_embedding_handles) == 1
         actual_embedding = SharedTensorContainer.from_dict(
-            encoder_outputs[video_idx].mm_embedding_handle).get_local_view(
-            )
+            mm_embedding_handles[0]).get_local_view()
 
         # The first dimension should be the number of image tokens
         actual_num_tokens = actual_embedding.shape[0]
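In `create_mm_embedding_response` above, the per-request token count is now the sum over all handles, with an explicit guard for a missing `tensor_size`. A minimal sketch of that aggregation, using plain dicts in place of real handles; the helper name and example sizes are illustrative only.

```python
def count_mm_tokens(mm_embedding_handles: list[dict]) -> int:
    """Total multimodal tokens across all handles, mirroring the server-side checks."""
    if not mm_embedding_handles:
        raise ValueError("Multimodal embedding handle missing in response")
    if any("tensor_size" not in h for h in mm_embedding_handles):
        raise ValueError("Multimodal embedding handle missing tensor_size")
    return sum(int(h["tensor_size"][0]) for h in mm_embedding_handles)


assert count_mm_tokens([{"tensor_size": [576, 4096]},
                        {"tensor_size": [1024, 4096]}]) == 1600
```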
diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
index 4cc6e2e19d..abdb45d1f5 100644
--- a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
+++ b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
@@ -2,7 +2,6 @@
 import copy
 import json
 import os
 import time
-from itertools import product
 from pathlib import Path
 from typing import Generator
diff (continued)
@@ -193,8 +192,12 @@ def llms(model_dir: Path,
     """Get LLM for prefill and, if disagg, separate LLM for decode."""
     free_gpu_memory_fraction = 0.2
     disable_overlap_scheduler = pd_disagg
+    # NOTE: if the number of tokens that must pass from P -> D exceeds `max_tokens_in_buffer`,
+    # the following error may be raised:
+    # >>> tensorrt_llm.executor.utils.RequestError: Error in kv cache transfer for generation
+    # requests.
     cache_transceiver_cfg = CacheTransceiverConfig(
-        backend="DEFAULT") if pd_disagg else None
+        backend="DEFAULT", max_tokens_in_buffer=10240) if pd_disagg else None
     kv_cache_config = KvCacheConfig(
         enable_block_reuse=False,  # Disable for output 1:1 matching check
         free_gpu_memory_fraction=free_gpu_memory_fraction,
@@ -290,14 +293,13 @@ def _assert_handles_are_different(x: dict | None, y: dict | None) -> None:
            assert x[key] != y[key]
 
 
-# TODO: Add multi-image in single chat test
 @pytest.mark.threadleak(enabled=False)
-def test_single_image_chat(
+def test_single_request_chat_multiple_images(
     pd_disagg: bool,
     model_dir: Path,
     llms: tuple[LLM, LLM | None],
 ):
-    """Test processing single image using encoder (pass mm_embeddings) + LLM API.
+    """Test processing a single request with multiple images.
 
     This test verifies that encoder (pass mm_embeddings) + LLM API produces identical results
     to standard llm generation (pass raw image) by comparing outputs.
@@ -309,8 +311,8 @@
     max_batch_size = 1
 
     # Test data - OpenAI chat completion format
-    prompts = ["Describe the natural environment in the image."]
-    media = [example_images[0]]
+    prompts = ["Compare these 2 images."]
+    media = [example_images[0], example_images[1]]
 
     # Sampling configuration
     sampling_params = SamplingParams(max_tokens=max_tokens)
@@ -461,8 +463,14 @@ def test_pd_disagg_with_image_input(
             f"Generated text doesn't match for output {i}, generation {j}:\nReference: {ref_gen.text!r}\nTest: {test_gen.text!r}"
 
 
-@pytest.mark.parametrize("use_mm_embeddings,pass_embeddings_through_loader",
-                         product([False, True], [False, True]))
+# Explicit combinations instead of product([False, True], [False, True]) to avoid having to
+# call `pytest.skip` inside the test body. This saves CI time, since the `llms` fixture takes
+# a long time to instantiate.
+@pytest.mark.parametrize("use_mm_embeddings,pass_embeddings_through_loader", [
+    (False, False),
+    (True, False),
+    (True, True),
+])
 @pytest.mark.threadleak(enabled=False)
 def test_multi_request_batch_chat(
     model_dir: Path,
@@ -490,6 +498,7 @@
 
     if llm_decode is not None:
         pytest.skip("Disagg support not implemented in test case")
 
+    # Guard against accidental reintroduction of invalid parameter combinations.
     if pass_embeddings_through_loader and not use_mm_embeddings:
         pytest.skip("Redundant test configuration")
 
@@ -522,10 +531,18 @@
     encoder_outputs = encoder.generate(inputs)
     if use_mm_embeddings:
         for input, encoder_output in zip(inputs, encoder_outputs):
-            mm_embed_handle = encoder_output.mm_embedding_handle
-            assert mm_embed_handle is not None
-            mm_embed = SharedTensorContainer.from_dict(
-                mm_embed_handle).get_local_view()
+            disagg_params = encoder_output.disaggregated_params
+            assert disagg_params is not None
+            mm_embed_handles = disagg_params.multimodal_embedding_handles
+            assert mm_embed_handles is not None
+            # `mm_embed_handles` is a list of handles (one per multimodal item).
+            # Reconstruct and concatenate all embeddings for this request.
+            mm_embeds = [
+                SharedTensorContainer.from_dict(handle).get_local_view()
+                for handle in mm_embed_handles
+            ]
+            mm_embed = torch.cat(
+                mm_embeds, dim=0) if len(mm_embeds) > 1 else mm_embeds[0]
             input["multi_modal_embeddings"] = {"image": mm_embed}
 
     if pass_embeddings_through_loader:
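The updated `test_multi_request_batch_chat` reconstructs one tensor per handle and concatenates them back into a single embedding for the LLM input. The round trip is lossless because `torch.cat` undoes the `torch.split` performed in `append_mm_embeddings`; a quick standalone check of that assumption (shapes invented for illustration):

```python
import torch

fused = torch.randn(1600, 4096)        # e.g. two images of 576 and 1024 tokens
lengths = [576, 1024]

per_item = torch.split(fused, lengths, dim=0)   # what append_mm_embeddings does
reassembled = torch.cat(per_item, dim=0)        # what the test does before generation

assert torch.equal(reassembled, fused)
assert [t.shape[0] for t in per_item] == lengths
```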
diff --git a/tests/unittest/api_stability/api_stability_core.py b/tests/unittest/api_stability/api_stability_core.py
index 53220f4468..d39d67b6a3 100644
--- a/tests/unittest/api_stability/api_stability_core.py
+++ b/tests/unittest/api_stability/api_stability_core.py
@@ -19,7 +19,7 @@
 import yaml
 from pydantic import BaseModel
 import tensorrt_llm
-from tensorrt_llm import LLM
+from tensorrt_llm import LLM, DisaggregatedParams
 # Import BaseCheckpointLoader for YAML processing
 from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
     BaseCheckpointLoader
diff --git a/tests/unittest/api_stability/references_committed/request_output.yaml b/tests/unittest/api_stability/references_committed/request_output.yaml
index 62e8ec2347..b20b585ac4 100644
--- a/tests/unittest/api_stability/references_committed/request_output.yaml
+++ b/tests/unittest/api_stability/references_committed/request_output.yaml
@@ -27,6 +27,6 @@ properties:
   finished:
     annotation: bool
     default: inspect._empty
-  mm_embedding_handle:
-    annotation: Optional[Dict[str, Any]]
+  disaggregated_params:
+    annotation: Optional[DisaggregatedParams]
     default: inspect._empty