[https://nvbugs/5625380][chore] Remove multimodal related fields from decoder llm input (#8846)

This commit is contained in:
Chang Liu 2025-11-02 17:44:08 -08:00 committed by GitHub
parent 0f42a24f45
commit f57dc01e6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 6 deletions

View File

@@ -426,10 +426,8 @@ class BaseLLM:
prompt = inputs.get("prompt", None)
query_token_ids = inputs.get("query_token_ids", None)
if is_gen_only:
# TODO: support generation-only mode for multimodal disaggregated inference
# Need to set multimodal_params = None; but not tested yet
raise ValueError(
"Multimodal disaggregated inference is not supported for generation-only mode"
"Generation-only mode should not need multimodal parameters"
)
else:
mm_hashes = disaggregated_params.multimodal_hashes

View File

@@ -150,6 +150,10 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
pd_disaggregated_params = outputs[0].disaggregated_params
pd_disaggregated_params.request_type = "generation_only"
sampling_params = SamplingParams(max_tokens=max_tokens)
inputs[0][
'multi_modal_data'] = None # remove multimodal data from input as decoder worker doesn't need it
inputs[0]['prompt_token_ids'] = outputs[
0].prompt_token_ids # use prompt token ids from encoder output
outputs = llm_decode.generate(
inputs,
@@ -169,9 +173,9 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
), f"Number of outputs don't match: {len(outputs_ref)} vs {len(outputs)}"
for i, (ref_output, test_output) in enumerate(zip(outputs_ref, outputs)):
# Compare prompts
assert ref_output.prompt == test_output.prompt, \
f"Prompts don't match for output {i}:\nReference: {ref_output.prompt!r}\nTest: {test_output.prompt!r}"
# Cannot compare prompts as decoder worker would void it
#assert ref_output.prompt == test_output.prompt, \
# f"Prompts don't match for output {i}:\nReference: {ref_output.prompt!r}\nTest: {test_output.prompt!r}"
# Compare number of generated outputs
assert len(ref_output.outputs) == len(test_output.outputs), \