mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5625380][chore] Remove multimodal related fields from decoder llm input (#8846)
This commit is contained in:
parent
0f42a24f45
commit
f57dc01e6f
@ -426,10 +426,8 @@ class BaseLLM:
|
||||
prompt = inputs.get("prompt", None)
|
||||
query_token_ids = inputs.get("query_token_ids", None)
|
||||
if is_gen_only:
|
||||
# TODO: support generation-only mode for multimodal disaggregated inference
|
||||
# Need to set multimodal_params = None; but not tested yet
|
||||
raise ValueError(
|
||||
"Multimodal disaggregated inference is not supported for generation-only mode; "
|
||||
"generation-only mode should not need multimodal parameters"
|
||||
)
|
||||
else:
|
||||
mm_hashes = disaggregated_params.multimodal_hashes
|
||||
|
||||
@ -150,6 +150,10 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
|
||||
pd_disaggregated_params = outputs[0].disaggregated_params
|
||||
pd_disaggregated_params.request_type = "generation_only"
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens)
|
||||
inputs[0][
|
||||
'multi_modal_data'] = None # remove multimodal data from input as decoder worker doesn't need it
|
||||
inputs[0]['prompt_token_ids'] = outputs[
|
||||
0].prompt_token_ids # use prompt token ids from encoder output
|
||||
|
||||
outputs = llm_decode.generate(
|
||||
inputs,
|
||||
@ -169,9 +173,9 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
|
||||
), f"Number of outputs don't match: {len(outputs_ref)} vs {len(outputs)}"
|
||||
|
||||
for i, (ref_output, test_output) in enumerate(zip(outputs_ref, outputs)):
|
||||
# Compare prompts
|
||||
assert ref_output.prompt == test_output.prompt, \
|
||||
f"Prompts don't match for output {i}:\nReference: {ref_output.prompt!r}\nTest: {test_output.prompt!r}"
|
||||
# Cannot compare prompts, as the decoder worker discards the original prompt text
|
||||
# assert ref_output.prompt == test_output.prompt, \
|
||||
# f"Prompts don't match for output {i}:\nReference: {ref_output.prompt!r}\nTest: {test_output.prompt!r}"
|
||||
|
||||
# Compare number of generated outputs
|
||||
assert len(ref_output.outputs) == len(test_output.outputs), \
|
||||
|
||||
Loading…
Reference in New Issue
Block a user