From ffc0f54959411f714131fbc1ccab86b5e4fc4dcf Mon Sep 17 00:00:00 2001
From: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Date: Fri, 6 Feb 2026 19:37:42 -0800
Subject: [PATCH] [https://nvbugs/5848756][fix] Re-take ownership of mrope
 tensors in prefill worker (#11217)

* Why?

Previously, the mrope tensors' IPC handles would just be forwarded from
encode -> prefill -> decode workers. While this is fine for the prefill
worker, it is not for the decode worker, since by the time it tries to
rebuild those tensors, they could have been garbage collected due to
their refcounts reaching zero in the producer (encode) worker.

This could lead to nasty runtime errors when running E/P/D disaggregated
serving.

* What?

This commit fixes this by having the prefill worker take ownership of
those reconstructed tensors and stand up new copies for the decode
worker.

Closes: NvBug 5848756

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
---
 .../_torch/pyexecutor/model_engine.py         | 17 ++++++++
 tensorrt_llm/executor/result.py               | 15 +++++--
 .../multimodal/test_mm_encoder_standalone.py  | 42 +++++++++++++++++++
 3 files changed, 70 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 9f0fba928f..37b6fa1e99 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -2118,6 +2118,23 @@ class PyTorchModelEngine(ModelEngine):
                 request.py_multimodal_data = multimodal_params.multimodal_data
                 multimodal_params_list.append(multimodal_params)
 
+                # Re-register mrope tensors for context-only requests (EPD disaggregated serving).
+                # This creates new IPC handles owned by the prefill worker, so the decode worker
+                # can access them even after the encode worker's GC deallocates the original memory.
+                # Without this, the decode worker would receive handles pointing to freed memory.
+                if (request.is_context_only_request and self.use_mrope and
+                        "mrope_config" in multimodal_params.multimodal_data):
+                    mrope_config = multimodal_params.multimodal_data[
+                        "mrope_config"]
+                    _mrope_position_ids = mrope_config.get("mrope_position_ids")
+                    _mrope_position_deltas = mrope_config.get(
+                        "mrope_position_deltas")
+                    if _mrope_position_ids is not None and _mrope_position_deltas is not None:
+                        # Clone to allocate new memory owned by this (prefill) worker.
+                        request.py_result.set_mrope_position(
+                            _mrope_position_ids.clone(),
+                            _mrope_position_deltas.clone())
+
             request.py_batch_idx = request.py_seq_slot
 
         if len(multimodal_params_list) > 0:
diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py
index 0642c1aff1..56f6598168 100644
--- a/tensorrt_llm/executor/result.py
+++ b/tensorrt_llm/executor/result.py
@@ -450,8 +448,6 @@ class GenerationResultBase:
         if hasattr(response_result, 'mm_embedding_handle'
                    ) and response_result.mm_embedding_handle is not None:
             self._mm_embedding_handle = response_result.mm_embedding_handle
-            mrope_position_ids_handle = response_result.mrope_position_ids_handle
-            mrope_position_deltas_handle = response_result.mrope_position_deltas_handle
             if self.disaggregated_params is not None:
                 self.disaggregated_params.multimodal_embedding_handles = [
                     response_result.mm_embedding_handle
@@ -463,8 +461,17 @@ class GenerationResultBase:
                         response_result.mm_embedding_handle
                     ],
                     multimodal_hashes=self._multimodal_hashes)
-            self.disaggregated_params.mrope_position_ids_handle = mrope_position_ids_handle
-            self.disaggregated_params.mrope_position_deltas_handle = mrope_position_deltas_handle
+
+        # Handle mrope handles for both:
+        # 1. Regular mm_embedding case (disaggregated_params was just created/updated above)
+        # 2. Prefill-only EPD requests (mm_embedding_handle is None because embeddings
+        #    were already computed in encode phase, but mrope still needs forwarding)
+        if (getattr(response_result, "mrope_position_ids_handle", None)
+                is not None and self.disaggregated_params is not None):
+            self.disaggregated_params.mrope_position_ids_handle = (
+                response_result.mrope_position_ids_handle)
+            self.disaggregated_params.mrope_position_deltas_handle = (
+                response_result.mrope_position_deltas_handle)
 
         # Processing background errors here ASAF during generation.
         if self._background_error_handler and (
diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
index 2ba776e971..2352329ccd 100644
--- a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
+++ b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
@@ -174,6 +174,41 @@ def _load_inputs(llm: LLM, prompts, media, mm_embeddings=None):
     return inputs
 
 
+def _assert_handles_are_different(x: dict | None, y: dict | None) -> None:
+    # Helper function for checking that two SharedTensorContainer dict representations of the same
+    # underlying data are different. Certain metadata should stay the same (basically those describing
+    # the tensor's contents), while others should actually differ (those pertaining to the underlying
+    # storage).
+    matching_keys = [
+        "dtype",
+        "event_sync_required",
+        "method_key",
+        "requires_grad",
+        # NOTE: this assumes the workers are on the same physical device, which is the case in
+        # the tests in this file since `LLM` API does not expose a way to select the device ID.
+        "storage_device",
+        "storage_size_bytes",
+        "tensor_offset",
+        "tensor_size",
+        "tensor_stride",
+    ]
+
+    different_keys = [
+        "event_handle",
+        "ref_counter_handle",
+        "ref_counter_offset",
+        "storage_handle",
+        "storage_offset_bytes",
+    ]
+
+    assert set(matching_keys + different_keys) == x.keys() == y.keys()
+
+    for key in matching_keys:
+        assert x[key] == y[key]
+    for key in different_keys:
+        assert x[key] != y[key]
+
+
 # TODO: Add multi-image in single chat test
 @pytest.mark.threadleak(enabled=False)
 def test_single_image_chat(
@@ -226,6 +261,7 @@ def test_single_image_chat(
     assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None"
     ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only"
 
+
     outputs = llm.generate(inputs,
                            sampling_params=sampling_params,
                            disaggregated_params=ep_disaggregated_params)
@@ -234,6 +270,12 @@ def test_single_image_chat(
     # Generation using llm_decode
     assert len(outputs) == 1
     pd_disaggregated_params = outputs[0].disaggregated_params
+
+    ep_handle = ep_disaggregated_params.mrope_position_ids_handle
+    pd_handle = pd_disaggregated_params.mrope_position_ids_handle
+    assert type(ep_handle) is type(pd_handle)
+    if ep_handle is not None:
+        _assert_handles_are_different(ep_handle, pd_handle)
     pd_disaggregated_params.request_type = "generation_only"
     sampling_params = SamplingParams(max_tokens=max_tokens)
     # remove multimodal data from input as decoder worker doesn't need it