diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 9f0fba928f..37b6fa1e99 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -2118,6 +2118,23 @@ class PyTorchModelEngine(ModelEngine):
                 request.py_multimodal_data = multimodal_params.multimodal_data
                 multimodal_params_list.append(multimodal_params)
 
+                # Re-register mrope tensors for context-only requests (EPD disaggregated serving).
+                # This creates new IPC handles owned by the prefill worker, so the decode worker
+                # can access them even after the encode worker's GC deallocates the original memory.
+                # Without this, the decode worker would receive handles pointing to freed memory.
+                if (request.is_context_only_request and self.use_mrope and
+                        "mrope_config" in multimodal_params.multimodal_data):
+                    mrope_config = multimodal_params.multimodal_data[
+                        "mrope_config"]
+                    _mrope_position_ids = mrope_config.get("mrope_position_ids")
+                    _mrope_position_deltas = mrope_config.get(
+                        "mrope_position_deltas")
+                    if _mrope_position_ids is not None and _mrope_position_deltas is not None:
+                        # Clone to allocate new memory owned by this (prefill) worker.
+                        request.py_result.set_mrope_position(
+                            _mrope_position_ids.clone(),
+                            _mrope_position_deltas.clone())
+
             request.py_batch_idx = request.py_seq_slot
 
         if len(multimodal_params_list) > 0:
diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py
index 0642c1aff1..56f6598168 100644
--- a/tensorrt_llm/executor/result.py
+++ b/tensorrt_llm/executor/result.py
@@ -450,8 +450,6 @@ class GenerationResultBase:
         if hasattr(response_result, 'mm_embedding_handle'
                    ) and response_result.mm_embedding_handle is not None:
             self._mm_embedding_handle = response_result.mm_embedding_handle
-            mrope_position_ids_handle = response_result.mrope_position_ids_handle
-            mrope_position_deltas_handle = response_result.mrope_position_deltas_handle
             if self.disaggregated_params is not None:
                 self.disaggregated_params.multimodal_embedding_handles = [
                     response_result.mm_embedding_handle
@@ -463,8 +461,17 @@
                         response_result.mm_embedding_handle
                     ],
                     multimodal_hashes=self._multimodal_hashes)
-            self.disaggregated_params.mrope_position_ids_handle = mrope_position_ids_handle
-            self.disaggregated_params.mrope_position_deltas_handle = mrope_position_deltas_handle
+
+        # Handle mrope handles for both:
+        # 1. Regular mm_embedding case (disaggregated_params was just created/updated above)
+        # 2. Prefill-only EPD requests (mm_embedding_handle is None because embeddings
+        #    were already computed in encode phase, but mrope still needs forwarding)
+        if (getattr(response_result, "mrope_position_ids_handle", None)
+                is not None and self.disaggregated_params is not None):
+            self.disaggregated_params.mrope_position_ids_handle = (
+                response_result.mrope_position_ids_handle)
+            self.disaggregated_params.mrope_position_deltas_handle = (
+                response_result.mrope_position_deltas_handle)
 
         # Processing background errors here ASAF during generation.
         if self._background_error_handler and (
diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
index 2ba776e971..2352329ccd 100644
--- a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
+++ b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
@@ -174,6 +174,41 @@
     return inputs
 
 
+def _assert_handles_are_different(x: dict | None, y: dict | None) -> None:
+    # Helper function for checking that two SharedTensorContainer dict representations of the same
+    # underlying data are different. Certain metadata should stay the same (basically those describing
+    # the tensor's contents), while others should actually differ (those pertaining to the underlying
+    # storage).
+    matching_keys = [
+        "dtype",
+        "event_sync_required",
+        "method_key",
+        "requires_grad",
+        # NOTE: this assumes the workers are on the same physical device, which is the case in
+        # the tests in this file since `LLM` API does not expose a way to select the device ID.
+        "storage_device",
+        "storage_size_bytes",
+        "tensor_offset",
+        "tensor_size",
+        "tensor_stride",
+    ]
+
+    different_keys = [
+        "event_handle",
+        "ref_counter_handle",
+        "ref_counter_offset",
+        "storage_handle",
+        "storage_offset_bytes",
+    ]
+
+    assert set(matching_keys + different_keys) == x.keys() == y.keys()
+
+    for key in matching_keys:
+        assert x[key] == y[key]
+    for key in different_keys:
+        assert x[key] != y[key]
+
+
 # TODO: Add multi-image in single chat test
 @pytest.mark.threadleak(enabled=False)
 def test_single_image_chat(
@@ -226,6 +261,7 @@
     assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None"
     ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only"
 
+
     outputs = llm.generate(inputs,
                            sampling_params=sampling_params,
                            disaggregated_params=ep_disaggregated_params)
@@ -234,6 +270,12 @@
     # Generation using llm_decode
     assert len(outputs) == 1
     pd_disaggregated_params = outputs[0].disaggregated_params
+
+    ep_handle = ep_disaggregated_params.mrope_position_ids_handle
+    pd_handle = pd_disaggregated_params.mrope_position_ids_handle
+    assert type(ep_handle) is type(pd_handle)
+    if ep_handle is not None:
+        _assert_handles_are_different(ep_handle, pd_handle)
    pd_disaggregated_params.request_type = "generation_only"
     sampling_params = SamplingParams(max_tokens=max_tokens)
     # remove multimodal data from input as decoder worker doesn't need it