diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 5ff6dc8074..be37c89dc0 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -513,6 +513,23 @@ class BaseLLM:
             else:
                 # Convert to shared tensor handle to reduce IPC overhead
                 multimodal_params.to_handle("multimodal_data")
+                if (disaggregated_params is not None) and (
+                        "mrope_config"
+                        in multimodal_params.multimodal_data):
+                    # Propagate mRoPE handles during context-only P -> D so decode-only
+                    # can rebuild mrope_config without raw multimodal inputs.
+                    mrope_config = multimodal_params.multimodal_data[
+                        "mrope_config"]
+                    mrope_position_ids = mrope_config.get(
+                        "mrope_position_ids")
+                    mrope_position_deltas = mrope_config.get(
+                        "mrope_position_deltas")
+                    if (mrope_position_ids is not None
+                            and mrope_position_deltas is not None):
+                        disaggregated_params.mrope_position_ids_handle = (
+                            mrope_position_ids)
+                        disaggregated_params.mrope_position_deltas_handle = (
+                            mrope_position_deltas)
         else:
             raise TypeError(
                 f"The inputs must be type str or list of int, but got {type(inputs)}"
diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
index aa5c6c5434..2ba776e971 100644
--- a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
+++ b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
@@ -12,7 +12,8 @@ from utils.llm_data import llm_models_root
 from tensorrt_llm import MultimodalEncoder
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.inputs import default_multimodal_input_loader
-from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig
+from tensorrt_llm.llmapi import (CacheTransceiverConfig, DisaggregatedParams,
+                                 KvCacheConfig)
 from tensorrt_llm.llmapi.llm import LLM, SamplingParams
 
 test_data_root = Path(
@@ -283,6 +284,58 @@ def test_single_image_chat(
             f"Log probabilities don't match for output {i}, generation {j}"
 
 
+@pytest.mark.parametrize("model_dir", [_QWEN_3_VL_DIR], indirect=True)
+@pytest.mark.parametrize("pd_disagg", [True], indirect=True)
+@pytest.mark.threadleak(enabled=False)
+def test_pd_disagg_with_image_input(
+    model_dir: Path,
+    pd_disagg: bool,
+    llms: tuple[LLM, LLM | None],
+):
+    """Test P/D disagg with image input."""
+    llm, llm_decode = llms
+    assert llm_decode is not None, "Disaggregated decode worker required."
+
+    prompts = ["Describe the image."]
+    media = [example_images[-1]]
+    sampling_params = SamplingParams(max_tokens=32, temperature=0)
+
+    # Reference outputs: use desired `max_tokens`.
+    inputs = _load_inputs(llm, prompts, media)
+    outputs_ref = llm.generate(inputs, sampling_params=sampling_params)
+    assert outputs_ref is not None and len(outputs_ref) == len(prompts)
+
+    # Prefill: `max_tokens=0`.
+    prefill_disagg_params = DisaggregatedParams(request_type="context_only")
+    outputs = llm.generate(inputs,
+                           sampling_params=SamplingParams(max_tokens=0,
+                                                          temperature=0),
+                           disaggregated_params=prefill_disagg_params)
+    assert len(outputs) == 1
+    pd_disaggregated_params = outputs[0].disaggregated_params
+    pd_disaggregated_params.request_type = "generation_only"
+
+    # Decode: use desired `max_tokens`.
+    decode_inputs = [{
+        "prompt": inputs[0]["prompt"],
+        "multi_modal_data": None,
+        "prompt_token_ids": outputs[0].prompt_token_ids,
+    }]
+    outputs_pd = llm_decode.generate(
+        decode_inputs,
+        sampling_params=sampling_params,
+        disaggregated_params=pd_disaggregated_params)
+
+    assert len(outputs_pd) == len(prompts)
+    for i, (ref_output, test_output) in enumerate(zip(outputs_ref, outputs_pd)):
+        assert len(ref_output.outputs) == len(test_output.outputs), \
+            f"Number of generated outputs doesn't match for output {i}: {len(ref_output.outputs)} vs {len(test_output.outputs)}"
+        for j, (ref_gen, test_gen) in enumerate(
+                zip(ref_output.outputs, test_output.outputs)):
+            assert ref_gen.text == test_gen.text, \
+                f"Generated text doesn't match for output {i}, generation {j}:\nReference: {ref_gen.text!r}\nTest: {test_gen.text!r}"
+
+
 @pytest.mark.parametrize("use_mm_embeddings,pass_embeddings_through_loader",
                          product([False, True], [False, True]))
 @pytest.mark.threadleak(enabled=False)
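
Note: the sketch below condenses, outside the pytest harness, the context-only -> generation-only handoff that the llm.py change enables. It is illustrative only: `prefill_llm`, `decode_llm`, and `inputs` are assumed to already exist (two LLM workers wired for disaggregated serving plus multimodal inputs, as the `llms` fixture and `_load_inputs` helper provide in the test above).

from tensorrt_llm.llmapi import DisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams

# 1) Context-only prefill (max_tokens=0). With this patch, the prefill side
#    also attaches mrope_position_ids_handle / mrope_position_deltas_handle
#    to the returned disaggregated params whenever the request carries an
#    mrope_config.
ctx_out = prefill_llm.generate(
    inputs,
    sampling_params=SamplingParams(max_tokens=0, temperature=0),
    disaggregated_params=DisaggregatedParams(request_type="context_only"))[0]

# 2) Generation-only decode: reuse the returned params with the request type
#    flipped. The decode worker rebuilds mrope_config from the handles, so no
#    raw image tensors need to be shipped to it.
gen_params = ctx_out.disaggregated_params
gen_params.request_type = "generation_only"
final = decode_llm.generate(
    [{
        "prompt": inputs[0]["prompt"],
        "multi_modal_data": None,
        "prompt_token_ids": ctx_out.prompt_token_ids,
    }],
    sampling_params=SamplingParams(max_tokens=32, temperature=0),
    disaggregated_params=gen_params)[0]
print(final.outputs[0].text)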