[https://nvbugs/5848756][fix] Re-take ownership of mrope tensors in prefill worker (#11217)

* Why?

Previously, the mrope tensors' IPC handles would just be forwarded from
encode -> prefill -> decode workers. While this is fine for the
prefill worker, it is not for the decode worker, since by the time it
tries to rebuild those tensors, they could have been garbage collected
due to their refcounts reaching zero in the producer (encode) worker.

This could lead to nasty runtime errors when running E/P/D
disaggregated serving.
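
As a rough sketch of the failure mode (a toy stand-in, not the real
SharedTensorContainer/IPC machinery -- `ToyIpcPool` and its methods are
invented purely for illustration):

```python
# Toy model: a "handle" is just a key into a pool owned by the producer, and
# freeing the producer's entry stands in for its refcount reaching zero.
class ToyIpcPool:
    def __init__(self):
        self._storage = {}

    def export(self, key, payload):
        self._storage[key] = payload
        return key                    # the "IPC handle"

    def free(self, key):
        del self._storage[key]        # producer-side GC

    def rebuild(self, key):
        return self._storage[key]     # KeyError == rebuilding freed memory


encode_pool = ToyIpcPool()

# Encode worker exports an mrope tensor and hands the handle to prefill.
handle = encode_pool.export("mrope_position_ids", [0, 1, 2, 3])

# Prefill worker rebuilds it, then forwards the *same* handle to decode.
_ = encode_pool.rebuild(handle)
forwarded_handle = handle

# Encode worker finishes the request; its copy is garbage collected.
encode_pool.free(handle)

# Decode worker now rebuilds from a handle pointing at freed memory.
try:
    encode_pool.rebuild(forwarded_handle)
except KeyError:
    print("decode worker hit freed memory")
```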

* What?

This commit fixes the issue by having the prefill worker take ownership of
the reconstructed tensors and stand up new copies for the decode worker.
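
Since the fix hinges on `torch.Tensor.clone()` allocating fresh storage owned
by the cloning (prefill) worker, here is a minimal sanity check of that
property (the tensor below is a stand-in, not real mrope data):

```python
import torch

# Stand-in for mrope_position_ids arriving from the encode worker.
mrope_position_ids = torch.arange(16).reshape(1, 16)

# What the prefill worker does before exporting new IPC handles.
owned_copy = mrope_position_ids.clone()

assert owned_copy.data_ptr() != mrope_position_ids.data_ptr()  # new storage
assert torch.equal(owned_copy, mrope_position_ids)             # same contents
```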

Closes: NvBug 5848756

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
William Zhang 2026-02-06 19:37:42 -08:00 committed by GitHub
parent 408d610877
commit ffc0f54959
3 changed files with 70 additions and 4 deletions


@@ -2118,6 +2118,23 @@ class PyTorchModelEngine(ModelEngine):
                request.py_multimodal_data = multimodal_params.multimodal_data
                multimodal_params_list.append(multimodal_params)

                # Re-register mrope tensors for context-only requests (EPD disaggregated serving).
                # This creates new IPC handles owned by the prefill worker, so the decode worker
                # can access them even after the encode worker's GC deallocates the original memory.
                # Without this, the decode worker would receive handles pointing to freed memory.
                if (request.is_context_only_request and self.use_mrope and
                        "mrope_config" in multimodal_params.multimodal_data):
                    mrope_config = multimodal_params.multimodal_data[
                        "mrope_config"]
                    _mrope_position_ids = mrope_config.get("mrope_position_ids")
                    _mrope_position_deltas = mrope_config.get(
                        "mrope_position_deltas")
                    if _mrope_position_ids is not None and _mrope_position_deltas is not None:
                        # Clone to allocate new memory owned by this (prefill) worker.
                        request.py_result.set_mrope_position(
                            _mrope_position_ids.clone(),
                            _mrope_position_deltas.clone())

            request.py_batch_idx = request.py_seq_slot

        if len(multimodal_params_list) > 0:


@@ -450,8 +450,6 @@ class GenerationResultBase:
        if hasattr(response_result, 'mm_embedding_handle'
                   ) and response_result.mm_embedding_handle is not None:
            self._mm_embedding_handle = response_result.mm_embedding_handle
            mrope_position_ids_handle = response_result.mrope_position_ids_handle
            mrope_position_deltas_handle = response_result.mrope_position_deltas_handle
            if self.disaggregated_params is not None:
                self.disaggregated_params.multimodal_embedding_handles = [
                    response_result.mm_embedding_handle
@@ -463,8 +461,17 @@ class GenerationResultBase:
                        response_result.mm_embedding_handle
                    ],
                    multimodal_hashes=self._multimodal_hashes)

        # Handle mrope handles for both:
        # 1. Regular mm_embedding case (disaggregated_params was just created/updated above)
        # 2. Prefill-only EPD requests (mm_embedding_handle is None because embeddings
        #    were already computed in encode phase, but mrope still needs forwarding)
        if (getattr(response_result, "mrope_position_ids_handle", None)
                is not None and self.disaggregated_params is not None):
            self.disaggregated_params.mrope_position_ids_handle = (
                response_result.mrope_position_ids_handle)
            self.disaggregated_params.mrope_position_deltas_handle = (
                response_result.mrope_position_deltas_handle)
        # Processing background errors here ASAP during generation.
        if self._background_error_handler and (


@@ -174,6 +174,41 @@ def _load_inputs(llm: LLM, prompts, media, mm_embeddings=None):
    return inputs


def _assert_handles_are_different(x: dict | None, y: dict | None) -> None:
    # Helper function for checking that two SharedTensorContainer dict representations of the same
    # underlying data are different. Certain metadata should stay the same (basically those describing
    # the tensor's contents), while others should actually differ (those pertaining to the underlying
    # storage).
    matching_keys = [
        "dtype",
        "event_sync_required",
        "method_key",
        "requires_grad",
        # NOTE: this assumes the workers are on the same physical device, which is the case in
        # the tests in this file since `LLM` API does not expose a way to select the device ID.
        "storage_device",
        "storage_size_bytes",
        "tensor_offset",
        "tensor_size",
        "tensor_stride",
    ]
    different_keys = [
        "event_handle",
        "ref_counter_handle",
        "ref_counter_offset",
        "storage_handle",
        "storage_offset_bytes",
    ]
    assert set(matching_keys + different_keys) == x.keys() == y.keys()
    for key in matching_keys:
        assert x[key] == y[key]
    for key in different_keys:
        assert x[key] != y[key]

# TODO: Add multi-image in single chat test
@pytest.mark.threadleak(enabled=False)
def test_single_image_chat(
@@ -226,6 +261,7 @@ def test_single_image_chat(
    assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None"
    ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only"

    outputs = llm.generate(inputs,
                           sampling_params=sampling_params,
                           disaggregated_params=ep_disaggregated_params)
@@ -234,6 +270,12 @@ def test_single_image_chat(
    # Generation using llm_decode
    assert len(outputs) == 1
    pd_disaggregated_params = outputs[0].disaggregated_params

    ep_handle = ep_disaggregated_params.mrope_position_ids_handle
    pd_handle = pd_disaggregated_params.mrope_position_ids_handle
    assert type(ep_handle) is type(pd_handle)
    if ep_handle is not None:
        _assert_handles_are_different(ep_handle, pd_handle)

    pd_disaggregated_params.request_type = "generation_only"
    sampling_params = SamplingParams(max_tokens=max_tokens)
    # remove multimodal data from input as decoder worker doesn't need it