[https://nvbugs/5826962][fix] Fix PD disaggregation for VLMs that use mrope (#10865)

* Why?

Commit a6a8898 enabled EPD disaggregation for VLMs that use mRoPE (e.g.,
Qwen). However, it broke PD disaggregation for these same models.

* What?

This commit restores PD disaggregation by propagating the mRoPE position
handles through DisaggregatedParams during the context-only phase, and adds
a unit test that guards against regressions.
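
In short, the fix rides the existing two-phase llmapi flow. A minimal sketch,
assuming a prefill LLM (`prefill_llm`), a decode LLM (`decode_llm`), and
preprocessed multimodal `inputs`/`decode_inputs` are already set up (the unit
test added below does the full setup):

    from tensorrt_llm.llmapi import DisaggregatedParams
    from tensorrt_llm.llmapi.llm import SamplingParams

    # Phase 1: context-only prefill (P). With this fix, the returned
    # DisaggregatedParams also carry the mRoPE position handles.
    ctx_out = prefill_llm.generate(
        inputs,
        sampling_params=SamplingParams(max_tokens=0),
        disaggregated_params=DisaggregatedParams(request_type="context_only"))

    # Phase 2: generation-only decode (D). The decode worker never sees the
    # raw image; it relies on the handles propagated in phase 1.
    gen_params = ctx_out[0].disaggregated_params
    gen_params.request_type = "generation_only"
    decode_llm.generate(decode_inputs,
                        sampling_params=SamplingParams(max_tokens=32),
                        disaggregated_params=gen_params)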

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
commit bc2487bc2c (parent 4d282bd7c1)
Author: William Zhang, committed by Yanchao Lu
Date:   2026-01-21 22:46:15 -08:00

2 changed files with 71 additions and 1 deletion


@@ -513,6 +513,23 @@ class BaseLLM:
                else:
                    # Convert to shared tensor handle to reduce IPC overhead
                    multimodal_params.to_handle("multimodal_data")
                    if (disaggregated_params is not None) and (
                            "mrope_config" in multimodal_params.multimodal_data):
                        # Propagate mRoPE handles during context-only P -> D so decode-only
                        # can rebuild mrope_config without raw multimodal inputs.
                        mrope_config = multimodal_params.multimodal_data[
                            "mrope_config"]
                        mrope_position_ids = mrope_config.get(
                            "mrope_position_ids")
                        mrope_position_deltas = mrope_config.get(
                            "mrope_position_deltas")
                        if (mrope_position_ids is not None
                                and mrope_position_deltas is not None):
                            disaggregated_params.mrope_position_ids_handle = (
                                mrope_position_ids)
                            disaggregated_params.mrope_position_deltas_handle = (
                                mrope_position_deltas)
        else:
            raise TypeError(
                f"The inputs must be type str or list of int, but got {type(inputs)}"


@@ -12,7 +12,8 @@ from utils.llm_data import llm_models_root
from tensorrt_llm import MultimodalEncoder
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig
from tensorrt_llm.llmapi import (CacheTransceiverConfig, DisaggregatedParams,
                                 KvCacheConfig)
from tensorrt_llm.llmapi.llm import LLM, SamplingParams

test_data_root = Path(
@@ -283,6 +284,58 @@
            f"Log probabilities don't match for output {i}, generation {j}"


@pytest.mark.parametrize("model_dir", [_QWEN_3_VL_DIR], indirect=True)
@pytest.mark.parametrize("pd_disagg", [True], indirect=True)
@pytest.mark.threadleak(enabled=False)
def test_pd_disagg_with_image_input(
    model_dir: Path,
    pd_disagg: bool,
    llms: tuple[LLM, LLM | None],
):
    """Test P/D disagg with image input."""
    llm, llm_decode = llms
    assert llm_decode is not None, "Disaggregated decode worker required."
    prompts = ["Describe the image."]
    media = [example_images[-1]]
    sampling_params = SamplingParams(max_tokens=32, temperature=0)

    # Reference outputs: use desired `max_tokens`.
    inputs = _load_inputs(llm, prompts, media)
    outputs_ref = llm.generate(inputs, sampling_params=sampling_params)
    assert outputs_ref is not None and len(outputs_ref) == len(prompts)

    # Prefill: `max_tokens=0`.
    prefill_disagg_params = DisaggregatedParams(request_type="context_only")
    outputs = llm.generate(inputs,
                           sampling_params=SamplingParams(max_tokens=0,
                                                          temperature=0),
                           disaggregated_params=prefill_disagg_params)
    assert len(outputs) == 1
    pd_disaggregated_params = outputs[0].disaggregated_params
    pd_disaggregated_params.request_type = "generation_only"

    # Decode: use desired `max_tokens`.
    decode_inputs = [{
        "prompt": inputs[0]["prompt"],
        "multi_modal_data": None,
        "prompt_token_ids": outputs[0].prompt_token_ids,
    }]
    outputs_pd = llm_decode.generate(
        decode_inputs,
        sampling_params=sampling_params,
        disaggregated_params=pd_disaggregated_params)
    assert len(outputs_pd) == len(prompts)

    for i, (ref_output, test_output) in enumerate(zip(outputs_ref, outputs_pd)):
        assert len(ref_output.outputs) == len(test_output.outputs), \
            f"Number of generated outputs don't match for output {i}: {len(ref_output.outputs)} vs {len(test_output.outputs)}"
        for j, (ref_gen, test_gen) in enumerate(
                zip(ref_output.outputs, test_output.outputs)):
            assert ref_gen.text == test_gen.text, \
                f"Generated text doesn't match for output {i}, generation {j}:\nReference: {ref_gen.text!r}\nTest: {test_gen.text!r}"
@pytest.mark.parametrize("use_mm_embeddings,pass_embeddings_through_loader",
product([False, True], [False, True]))
@pytest.mark.threadleak(enabled=False)