[https://nvbugs/5826962][fix] Fix PD disaggregation for VLMs that use mrope (#10865)
* Why?

Commit a6a8898 enabled EPD disaggregation for VLMs that use mrope (e.g.
Qwen). However, this broke PD disaggregation for these same models.

* What?

This commit fixes PD disaggregation for those models and adds a unit test
that guards against regressions.
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Parent: 4d282bd7c1
Commit: bc2487bc2c
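For orientation before the diff, here is a minimal sketch of the P -> D flow this commit restores, written against the llmapi surface the new test exercises. The checkpoint path, image path, direct worker construction, and the multimodal input dict format are placeholders/assumptions (the test drives the same flow through fixtures); only the generate/DisaggregatedParams calls mirror the diff below.

from tensorrt_llm.llmapi import CacheTransceiverConfig, DisaggregatedParams
from tensorrt_llm.llmapi.llm import LLM, SamplingParams

# Placeholder checkpoint: any mRoPE-based VLM (e.g. the Qwen VL family).
MODEL = "/models/qwen3-vl"
# KV-cache transfer between workers is required for disaggregated serving.
kv_transfer = CacheTransceiverConfig(backend="DEFAULT")
prefill_llm = LLM(model=MODEL, cache_transceiver_config=kv_transfer)  # P worker
decode_llm = LLM(model=MODEL, cache_transceiver_config=kv_transfer)  # D worker

inputs = [{
    "prompt": "Describe the image.",
    "multi_modal_data": {"image": ["/data/example.png"]},  # placeholder image
}]

# Prefill with max_tokens=0; with this fix, the returned disaggregated
# params also carry the mRoPE shared-tensor handles.
ctx_outputs = prefill_llm.generate(
    inputs,
    sampling_params=SamplingParams(max_tokens=0, temperature=0),
    disaggregated_params=DisaggregatedParams(request_type="context_only"))

gen_params = ctx_outputs[0].disaggregated_params
gen_params.request_type = "generation_only"

# Decode without the raw image: mrope_config must be rebuilt from the
# propagated handles, which is exactly the path that broke before this fix.
outputs = decode_llm.generate(
    [{
        "prompt": inputs[0]["prompt"],
        "multi_modal_data": None,
        "prompt_token_ids": ctx_outputs[0].prompt_token_ids,
    }],
    sampling_params=SamplingParams(max_tokens=32, temperature=0),
    disaggregated_params=gen_params)
print(outputs[0].outputs[0].text)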
@@ -513,6 +513,23 @@ class BaseLLM:
             else:
                 # Convert to shared tensor handle to reduce IPC overhead
                 multimodal_params.to_handle("multimodal_data")
+                if (disaggregated_params is not None) and (
+                        "mrope_config"
+                        in multimodal_params.multimodal_data):
+                    # Propagate mRoPE handles during context-only P -> D so decode-only
+                    # can rebuild mrope_config without raw multimodal inputs.
+                    mrope_config = multimodal_params.multimodal_data[
+                        "mrope_config"]
+                    mrope_position_ids = mrope_config.get(
+                        "mrope_position_ids")
+                    mrope_position_deltas = mrope_config.get(
+                        "mrope_position_deltas")
+                    if (mrope_position_ids is not None
+                            and mrope_position_deltas is not None):
+                        disaggregated_params.mrope_position_ids_handle = (
+                            mrope_position_ids)
+                        disaggregated_params.mrope_position_deltas_handle = (
+                            mrope_position_deltas)
         else:
             raise TypeError(
                 f"The inputs must be type str or list of int, but got {type(inputs)}"
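The diff does not show the decode-side consumer of these handles. As a hypothetical sketch of that end (not part of this commit), assuming the handles are serialized SharedTensorContainer dicts produced by to_handle and that the from_dict/get_local_view pattern applies, the same container class the test file imports:

from tensorrt_llm._torch.shared_tensor import SharedTensorContainer


def rebuild_mrope_config(disaggregated_params):
    # Hypothetical helper: turn the propagated mRoPE handles back into
    # local tensors on the decode-only worker. Attribute names match the
    # hunk above; the deserialization calls are assumptions.
    return {
        "mrope_position_ids":
        SharedTensorContainer.from_dict(
            disaggregated_params.mrope_position_ids_handle).get_local_view(),
        "mrope_position_deltas":
        SharedTensorContainer.from_dict(
            disaggregated_params.mrope_position_deltas_handle).get_local_view(),
    }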
@@ -12,7 +12,8 @@ from utils.llm_data import llm_models_root
 from tensorrt_llm import MultimodalEncoder
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.inputs import default_multimodal_input_loader
-from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig
+from tensorrt_llm.llmapi import (CacheTransceiverConfig, DisaggregatedParams,
+                                 KvCacheConfig)
 from tensorrt_llm.llmapi.llm import LLM, SamplingParams

 test_data_root = Path(
@@ -283,6 +284,58 @@ def test_single_image_chat(
                 f"Log probabilities don't match for output {i}, generation {j}"


+@pytest.mark.parametrize("model_dir", [_QWEN_3_VL_DIR], indirect=True)
+@pytest.mark.parametrize("pd_disagg", [True], indirect=True)
+@pytest.mark.threadleak(enabled=False)
+def test_pd_disagg_with_image_input(
+        model_dir: Path,
+        pd_disagg: bool,
+        llms: tuple[LLM, LLM | None],
+):
+    """Test P/D disagg with image input."""
+    llm, llm_decode = llms
+    assert llm_decode is not None, "Disaggregated decode worker required."
+
+    prompts = ["Describe the image."]
+    media = [example_images[-1]]
+    sampling_params = SamplingParams(max_tokens=32, temperature=0)
+
+    # Reference outputs: use desired `max_tokens`.
+    inputs = _load_inputs(llm, prompts, media)
+    outputs_ref = llm.generate(inputs, sampling_params=sampling_params)
+    assert outputs_ref is not None and len(outputs_ref) == len(prompts)
+
+    # Prefill: `max_tokens=0`.
+    prefill_disagg_params = DisaggregatedParams(request_type="context_only")
+    outputs = llm.generate(inputs,
+                           sampling_params=SamplingParams(max_tokens=0,
+                                                          temperature=0),
+                           disaggregated_params=prefill_disagg_params)
+    assert len(outputs) == 1
+    pd_disaggregated_params = outputs[0].disaggregated_params
+    pd_disaggregated_params.request_type = "generation_only"
+
+    # Decode: use desired `max_tokens`.
+    decode_inputs = [{
+        "prompt": inputs[0]["prompt"],
+        "multi_modal_data": None,
+        "prompt_token_ids": outputs[0].prompt_token_ids,
+    }]
+    outputs_pd = llm_decode.generate(
+        decode_inputs,
+        sampling_params=sampling_params,
+        disaggregated_params=pd_disaggregated_params)
+
+    assert len(outputs_pd) == len(prompts)
+    for i, (ref_output, test_output) in enumerate(zip(outputs_ref, outputs_pd)):
+        assert len(ref_output.outputs) == len(test_output.outputs), \
+            f"Number of generated outputs don't match for output {i}: {len(ref_output.outputs)} vs {len(test_output.outputs)}"
+        for j, (ref_gen, test_gen) in enumerate(
+                zip(ref_output.outputs, test_output.outputs)):
+            assert ref_gen.text == test_gen.text, \
+                f"Generated text doesn't match for output {i}, generation {j}:\nReference: {ref_gen.text!r}\nTest: {test_gen.text!r}"
+
+
 @pytest.mark.parametrize("use_mm_embeddings,pass_embeddings_through_loader",
                          product([False, True], [False, True]))
 @pytest.mark.threadleak(enabled=False)