Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 07:53:55 +08:00
[https://nvbugs/5848756][fix] Re-take ownership of mrope tensors in prefill worker (#11217)
* Why?

  Previously, the mrope tensors' IPC handles were simply forwarded from the encode worker to the prefill and decode workers. While this is fine for the prefill worker, it is not for the decode worker: by the time it tries to rebuild those tensors, they may already have been garbage collected because their refcounts reached zero in the producer (encode) worker. This could lead to nasty runtime errors when running E/P/D disaggregated serving.

* What?

  This commit fixes the issue by having the prefill worker take ownership of the reconstructed tensors and stand up new copies for the decode worker.

Closes: NvBug 5848756

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
parent 408d610877 · commit ffc0f54959
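The failure mode described above is essentially a dangling-reference problem: an IPC handle carries no ownership, so once the encode worker drops its last reference the backing memory can be freed while downstream workers still hold the handle. The pure-Python sketch below is not part of the commit; GpuBuffer is a made-up stand-in for a CUDA tensor's storage, used only to illustrate why forwarding a handle is unsafe while an owned copy is not.

import weakref

class GpuBuffer:
    """Stand-in for the encode worker's tensor storage (illustration only)."""

    def __init__(self, data):
        self.data = list(data)  # copying here plays the role of .clone()

producer_tensor = GpuBuffer(range(4))     # lives in the encode worker
handle = weakref.ref(producer_tensor)     # like an IPC handle: no ownership
owned_copy = GpuBuffer(handle().data)     # prefill worker takes ownership via a copy

del producer_tensor                       # encode worker's refcount hits zero, memory is freed
assert handle() is None                   # the forwarded "handle" now dangles
assert owned_copy.data == [0, 1, 2, 3]    # the owned copy is still valid for the decode worker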
@@ -2118,6 +2118,23 @@ class PyTorchModelEngine(ModelEngine):
                request.py_multimodal_data = multimodal_params.multimodal_data
                multimodal_params_list.append(multimodal_params)

                # Re-register mrope tensors for context-only requests (EPD disaggregated serving).
                # This creates new IPC handles owned by the prefill worker, so the decode worker
                # can access them even after the encode worker's GC deallocates the original memory.
                # Without this, the decode worker would receive handles pointing to freed memory.
                if (request.is_context_only_request and self.use_mrope and
                        "mrope_config" in multimodal_params.multimodal_data):
                    mrope_config = multimodal_params.multimodal_data[
                        "mrope_config"]
                    _mrope_position_ids = mrope_config.get("mrope_position_ids")
                    _mrope_position_deltas = mrope_config.get(
                        "mrope_position_deltas")
                    if _mrope_position_ids is not None and _mrope_position_deltas is not None:
                        # Clone to allocate new memory owned by this (prefill) worker.
                        request.py_result.set_mrope_position(
                            _mrope_position_ids.clone(),
                            _mrope_position_deltas.clone())

            request.py_batch_idx = request.py_seq_slot

        if len(multimodal_params_list) > 0:
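The key operation in the hunk above is the .clone(): it allocates fresh storage in the prefill process, so any IPC handle later exported for the decode worker refers to memory this worker owns rather than to the encode worker's allocation. A quick, hedged PyTorch check of that property (run on CPU for portability; not part of the commit):

import torch

rebuilt = torch.arange(6).reshape(2, 3)   # stand-in for a tensor rebuilt from a forwarded handle
view = rebuilt[:, :2]                     # still backed by the original storage
owned = rebuilt.clone()                   # fresh storage owned by this process

assert view.data_ptr() == rebuilt.data_ptr()   # a view shares the producer's memory
assert owned.data_ptr() != rebuilt.data_ptr()  # a clone does not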
@@ -450,8 +450,6 @@ class GenerationResultBase:
        if hasattr(response_result, 'mm_embedding_handle'
                   ) and response_result.mm_embedding_handle is not None:
            self._mm_embedding_handle = response_result.mm_embedding_handle
            mrope_position_ids_handle = response_result.mrope_position_ids_handle
            mrope_position_deltas_handle = response_result.mrope_position_deltas_handle
            if self.disaggregated_params is not None:
                self.disaggregated_params.multimodal_embedding_handles = [
                    response_result.mm_embedding_handle
@@ -463,8 +461,17 @@ class GenerationResultBase:
                        response_result.mm_embedding_handle
                    ],
                    multimodal_hashes=self._multimodal_hashes)
                self.disaggregated_params.mrope_position_ids_handle = mrope_position_ids_handle
                self.disaggregated_params.mrope_position_deltas_handle = mrope_position_deltas_handle

            # Handle mrope handles for both:
            # 1. Regular mm_embedding case (disaggregated_params was just created/updated above)
            # 2. Prefill-only EPD requests (mm_embedding_handle is None because embeddings
            #    were already computed in encode phase, but mrope still needs forwarding)
            if (getattr(response_result, "mrope_position_ids_handle", None)
                    is not None and self.disaggregated_params is not None):
                self.disaggregated_params.mrope_position_ids_handle = (
                    response_result.mrope_position_ids_handle)
                self.disaggregated_params.mrope_position_deltas_handle = (
                    response_result.mrope_position_deltas_handle)

            # Processing background errors here ASAF during generation.
            if self._background_error_handler and (
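The comment in the hunk above distinguishes two response shapes, and that is why the forwarding moved out of the mm_embedding_handle branch. A hedged illustration (SimpleNamespace stands in for the real response object; attribute names follow the diff, values are placeholders):

from types import SimpleNamespace

# Prefill-only EPD request: embeddings were produced by the encode worker, so the
# response carries no mm_embedding_handle, yet the mrope handles still need forwarding.
prefill_only = SimpleNamespace(mm_embedding_handle=None,
                               mrope_position_ids_handle={"storage_handle": b"..."},
                               mrope_position_deltas_handle={"storage_handle": b"..."})

old_guard = prefill_only.mm_embedding_handle is not None                          # False: handles were dropped
new_guard = getattr(prefill_only, "mrope_position_ids_handle", None) is not None  # True: handles get forwarded
assert not old_guard and new_guard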
@@ -174,6 +174,41 @@ def _load_inputs(llm: LLM, prompts, media, mm_embeddings=None):
    return inputs


def _assert_handles_are_different(x: dict | None, y: dict | None) -> None:
    # Helper function for checking that two SharedTensorContainer dict representations of the same
    # underlying data are different. Certain metadata should stay the same (basically those describing
    # the tensor's contents), while others should actually differ (those pertaining to the underlying
    # storage).
    matching_keys = [
        "dtype",
        "event_sync_required",
        "method_key",
        "requires_grad",
        # NOTE: this assumes the workers are on the same physical device, which is the case in
        # the tests in this file since `LLM` API does not expose a way to select the device ID.
        "storage_device",
        "storage_size_bytes",
        "tensor_offset",
        "tensor_size",
        "tensor_stride",
    ]

    different_keys = [
        "event_handle",
        "ref_counter_handle",
        "ref_counter_offset",
        "storage_handle",
        "storage_offset_bytes",
    ]

    assert set(matching_keys + different_keys) == x.keys() == y.keys()

    for key in matching_keys:
        assert x[key] == y[key]
    for key in different_keys:
        assert x[key] != y[key]


# TODO: Add multi-image in single chat test
@pytest.mark.threadleak(enabled=False)
def test_single_image_chat(
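For reference, a toy call showing the dict shape _assert_handles_are_different expects. The values below are placeholders, not real SharedTensorContainer output; in the test, the two dicts come from the encode and prefill workers' serialized handles for the same mrope tensor.

shared_metadata = {
    "dtype": "torch.float32", "event_sync_required": False, "method_key": "cuda",
    "requires_grad": False, "storage_device": 0, "storage_size_bytes": 4096,
    "tensor_offset": 0, "tensor_size": (1, 3, 1024), "tensor_stride": (3072, 1024, 1),
}
ep = {**shared_metadata, "event_handle": b"e0", "ref_counter_handle": b"r0",
      "ref_counter_offset": 0, "storage_handle": b"s0", "storage_offset_bytes": 0}
pd = {**shared_metadata, "event_handle": b"e1", "ref_counter_handle": b"r1",
      "ref_counter_offset": 16, "storage_handle": b"s1", "storage_offset_bytes": 64}
_assert_handles_are_different(ep, pd)  # passes: content metadata matches, storage identity differs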
@@ -226,6 +261,7 @@ def test_single_image_chat(

    assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None"
    ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only"

    outputs = llm.generate(inputs,
                           sampling_params=sampling_params,
                           disaggregated_params=ep_disaggregated_params)
@@ -234,6 +270,12 @@ def test_single_image_chat(
    # Generation using llm_decode
    assert len(outputs) == 1
    pd_disaggregated_params = outputs[0].disaggregated_params

    ep_handle = ep_disaggregated_params.mrope_position_ids_handle
    pd_handle = pd_disaggregated_params.mrope_position_ids_handle
    assert type(ep_handle) is type(pd_handle)
    if ep_handle is not None:
        _assert_handles_are_different(ep_handle, pd_handle)
    pd_disaggregated_params.request_type = "generation_only"
    sampling_params = SamplingParams(max_tokens=max_tokens)
    # remove multimodal data from input as decoder worker doesn't need it