[TRTLLM-6975][test] Add multi-turn test cases for VLM models (#6749)

Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
Ivy Zhang 2025-08-13 13:10:13 +08:00 committed by GitHub
parent cf00003f3d
commit 15bcf80596
GPG Key ID: B5690EEEBB952194
7 changed files with 208 additions and 5 deletions


@@ -108,6 +108,15 @@ def add_multimodal_args(parser):
                        type=str,
                        default="cpu",
                        help="The device to have the input on.")
    # Add multiturn conversation related parameters
    parser.add_argument("--multiturn",
                        action="store_true",
                        help="Enable multi-turn conversation mode.")
    parser.add_argument(
        "--conversation_turns",
        type=int,
        default=2,
        help="Number of conversation turns for automated testing.")
    return parser
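
As a quick reference, a minimal standalone sketch of how the two new flags parse (a toy parser with a hypothetical argv, not the example script itself):

import argparse

# Toy parser exercising only the two new flags (hypothetical argv below).
parser = argparse.ArgumentParser()
parser.add_argument("--multiturn",
                    action="store_true",
                    help="Enable multi-turn conversation mode.")
parser.add_argument("--conversation_turns",
                    type=int,
                    default=2,
                    help="Number of conversation turns for automated testing.")
args = parser.parse_args(["--multiturn", "--conversation_turns", "3"])
assert args.multiturn and args.conversation_turns == 3
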
@@ -162,6 +171,80 @@ def main():
        open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
    assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"
    # If multiturn mode is enabled
    if args.multiturn:
        # Run predefined multiturn conversation examples
        assert args.prompt is not None, "Please provide a prompt for multiturn conversation."
        assert args.media is not None, "Please provide media for multiturn conversation."
        # Determine how many turns to run
        max_turns = min(args.conversation_turns, len(args.prompt))
        generated_outputs = []  # Store generated outputs for return
        # Initialize conversation history with the first prompt
        conversation_history = args.prompt[0] if args.prompt else ""
        for i in range(max_turns):
            print(f"\n--- Turn {i+1} ---")
            try:
                # Use multimodal input loader to process input with conversation context
                # Use accumulated conversation history instead of just the current prompt
                cur_prompt = conversation_history
                inputs = default_multimodal_input_loader(
                    tokenizer=llm.tokenizer,
                    model_dir=llm._hf_model_dir,
                    model_type=model_type,
                    modality=args.modality,
                    prompts=[cur_prompt],
                    media=args.media,
                    image_data_format="pt",
                    num_frames=8,
                    device="cpu")
                lora_request = None
                if args.load_lora:
                    if model_class is None:
                        raise ValueError(
                            "model_class must be provided when load_lora is True"
                        )
                    lora_request = model_class.lora_request(
                        len(inputs), args.modality, llm._hf_model_dir)
                # Generate response
                outputs = llm.generate(inputs,
                                       sampling_params,
                                       lora_request=lora_request)
                assert outputs and len(
                    outputs) > 0 and outputs[0].outputs and len(
                        outputs[0].outputs) > 0
                response = outputs[0].outputs[0].text.strip()
                # Store generated output
                generated_outputs.append({
                    "turn": i + 1,
                    "user_input": cur_prompt,
                    "assistant_response": response,
                    "media": args.media
                })
                conversation_history = conversation_history + "\n" + response
                if i + 1 < len(args.prompt):
                    conversation_history = conversation_history + "\n" + args.prompt[
                        i + 1]
            except Exception as e:
                print(f"Error in turn {i+1}: {e}")
                import traceback
                traceback.print_exc()
                continue
        for i, output in enumerate(generated_outputs):
            print(
                f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
            )
        return
    # Original single-turn processing logic
    # set prompts and media to example prompts and images if they are not provided
    if args.prompt is None:
        args.prompt = example_medias_and_prompts[args.modality]["prompt"]
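
The heart of the new multiturn path above is the history handling: the first prompt seeds the history, each generated response is appended, and the next user prompt is appended before the following turn. A minimal sketch of that accumulation with a stubbed generator (generate_fn is a placeholder for the real llm.generate call):

def run_multiturn(prompts, conversation_turns, generate_fn):
    # Mirrors the loop above: history starts at prompts[0] and grows by
    # one response (plus the next prompt, if any) per turn.
    max_turns = min(conversation_turns, len(prompts))
    history = prompts[0] if prompts else ""
    results = []
    for turn in range(max_turns):
        response = generate_fn(history)
        results.append({"turn": turn + 1, "user_input": history,
                        "assistant_response": response})
        history += "\n" + response
        if turn + 1 < len(prompts):
            history += "\n" + prompts[turn + 1]
    return results

# Trivial stand-in for the model: echoes how much context it saw.
turns = run_multiturn(["Describe the image.", "What is the mood?"], 2,
                      lambda ctx: f"(reply given {len(ctx)} chars of context)")
assert len(turns) == 2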


@@ -19,6 +19,13 @@ meta-llama/Llama-3.3-70B-Instruct:
    accuracy: 84.08
meta-llama/Llama-4-Maverick-17B-128E-Instruct:
  - accuracy: 92.20
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    accuracy: 92.20
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    spec_dec_algo: Eagle
    accuracy: 92.20
meta-llama/Llama-4-Scout-17B-16E-Instruct:
  - accuracy: 89.70
  - quant_algo: NVFP4
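
Each list item above pairs a reference accuracy with the quantization setup it was measured under. A hedged sketch of how such entries could be selected for a given run configuration (the field names mirror the YAML; the lookup logic itself is an assumption, not the accuracy harness's actual code):

def pick_reference(entries, quant_algo=None, kv_cache_quant_algo=None,
                   spec_dec_algo=None):
    # Assumed rule: an entry matches when every optional field it declares
    # (or omits) agrees with the run; the bare {"accuracy": ...} entry acts
    # as the unquantized default.
    wanted = {"quant_algo": quant_algo,
              "kv_cache_quant_algo": kv_cache_quant_algo,
              "spec_dec_algo": spec_dec_algo}
    for entry in entries:
        if all(entry.get(key) == value for key, value in wanted.items()):
            return entry["accuracy"]
    raise KeyError("no matching accuracy reference")

maverick = [
    {"accuracy": 92.20},
    {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8", "accuracy": 92.20},
    {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8",
     "spec_dec_algo": "Eagle", "accuracy": 92.20},
]
assert pick_reference(maverick, "FP8", "FP8", "Eagle") == 92.20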


@@ -73,6 +73,9 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
    kv_cache_quant_algo: FP8
    spec_dec_algo: Eagle
    accuracy: 86.40
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    accuracy: 86.40
meta-llama/Llama-4-Scout-17B-16E-Instruct:
  - accuracy: 80.00
  - quant_algo: NVFP4


@@ -246,7 +246,7 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
total_gen_gpus = gen_tp * gen_pp * gen_instances
if total_ctx_gpus + total_gen_gpus > get_device_count():
pytest.fail(
pytest.skip(
f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
)
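
The guard above now skips instead of failing when the host simply has too few GPUs for the requested topology. A minimal sketch of the pattern (the counts are passed in explicitly here; the real test derives them from the parallel configuration):

import pytest

def require_devices(needed: int, available: int) -> None:
    # Treat an under-provisioned machine as "not applicable", not as a failure.
    if needed > available:
        pytest.skip(f"Not enough devices: need {needed}, have {available}")
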
@@ -378,6 +378,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@skip_pre_hopper
@parametrize_with_ids("overlap_scheduler", [True, False])
@parametrize_with_ids("eagle3_one_model", [True, False])
def test_eagle3(self, overlap_scheduler, eagle3_one_model):
@@ -461,6 +462,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600)
@pytest.mark.skip_less_device(4)
class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
@@ -540,6 +542,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
@parametrize_with_ids("overlap_scheduler", [True, False])
@parametrize_with_ids("mtp_nextn",
[0, pytest.param(2, marks=skip_pre_hopper)])
@pytest.mark.skip_less_device(4)
def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
ctx_server_config = {"disable_overlap_scheduler": True}
gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler}
@@ -671,6 +674,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
task.evaluate(llm)
@pytest.mark.parametrize("overlap_scheduler", [False, True])
@skip_pre_hopper
def test_auto_dtype(self, overlap_scheduler):
ctx_server_config = {
"disable_overlap_scheduler": True,


@@ -695,6 +695,7 @@ class TestMistralSmall24B(LlmapiAccuracyTestHarness):
MODEL_NAME = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
MODEL_PATH = f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503"
@pytest.mark.skip_less_device_memory(80000)
def test_auto_dtype(self):
with LLM(self.MODEL_PATH) as llm:
task = CnnDailymail(self.MODEL_NAME)
@@ -1033,7 +1034,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
max_num_streams=3) if torch_compile else None)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(backend="CUTEDSL"),
)
@@ -1191,7 +1192,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
max_num_streams=3) if torch_compile else None)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(backend="CUTEDSL"),
)
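
The two TestDeepSeekV3Lite hunks above swap the old boolean `use_cuda_graph` knob for the newer `cuda_graph_config` object. A hedged sketch of the updated pattern, assuming `CudaGraphConfig` is importable from `tensorrt_llm.llmapi` as elsewhere in these tests:

from tensorrt_llm.llmapi import CudaGraphConfig  # assumed import path

def build_pytorch_config(cuda_graph: bool, overlap_scheduler: bool) -> dict:
    # None disables CUDA graph capture; a CudaGraphConfig() instance enables it,
    # replacing the old use_cuda_graph boolean flag.
    return dict(
        disable_overlap_scheduler=not overlap_scheduler,
        cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
    )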


@@ -2051,7 +2051,7 @@ def test_ptp_quickstart_advanced_8gpus(llm_root, llm_venv, model_name,
def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
llm_root, llm_venv, model_name, model_path, cuda_graph):
print(f"Testing {model_name} on 8 GPUs.")
example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
cmd = [
str(example_root / "quickstart_advanced.py"),
"--enable_chunked_prefill",
@@ -2076,10 +2076,12 @@ def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("model_name,model_path", [
("Llama3.1-70B-BF16", "llama-3.1-model/Meta-Llama-3.1-70B"),
('Nemotron-Super-49B-v1-BF16',
'nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1'),
("Mixtral-8x7B-BF16", "Mixtral-8x7B-Instruct-v0.1"),
pytest.param('Llama3.1-70B-BF16',
'llama-3.1-model/Meta-Llama-3.1-70B',
marks=pytest.mark.skip_less_device_memory(95000)),
])
def test_ptp_quickstart_advanced_2gpus_sm120(llm_root, llm_venv, model_name,
model_path):
@@ -2521,6 +2523,106 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
    print("All answers are correct!")


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("model_name,model_path", [
    ("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
    ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
    ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
])
def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
                                             model_path):
    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
    test_data_root = Path(
        os.path.join(llm_models_root(), "multimodals", "test_data"))
    print(f"Accuracy test {model_name} image mode with example inputs.")

    # Define accuracy inputs for image modality
    accuracy_inputs = {
        "image": {
            "prompt": [
                "Describe what you see in this image.",
                "How would you describe the atmosphere of this scene?",
            ],
            "media": [
                str(test_data_root / "inpaint.png"),
            ],
        }
    }

    # Define expected keywords for each model
    expected_keywords = {
        "gemma-3-27b-it": {
            "image": [
                ["half", "dome", "yosemite", "landmark", "rounded"],
                ["atmosphere", "peaceful", "majestic", "calm", "quiet"],
            ],
        },
        "mistral-small-3.1-24b-instruct": {
            "image": [
                ["depicts", "landscape", "rock", "sky", "high", "altitude"],
                ["atmosphere", "serene", "majestic", "sense", "tranquility"],
            ],
        },
        "Phi-4-multimodal-instruct": {
            "image": [
                ["depicts", "landscape", "mountain", "half", "dome"],
                ["atmosphere", "serene", "sense", "tranquility", "peace."],
            ],
        },
    }

    # Build command for image modality
    cmd = [
        str(example_root / "quickstart_multimodal.py"),
        "--model_dir",
        f"{llm_models_root()}/{model_path}",
        "--modality",
        "image",
        "--multiturn",
        "--prompt",
        *accuracy_inputs["image"]["prompt"],
        "--media",
        *accuracy_inputs["image"]["media"],
    ]

    # Add model-specific configurations
    if model_name == "gemma-3-27b-it":
        # Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently.
        # Custom mask involves bidirectional masking of image tokens in context phase. To get this
        # correct, chunked prefill and kv cache reuse need to be turned off.
        cmd.append("--image_format=pil")
        cmd.append("--attention_backend=FLASHINFER")
        cmd.append("--disable_kv_cache_reuse")
    elif model_name == "Phi-4-multimodal-instruct":
        # Set max_seq_len to 4096 to use short rope factor.
        cmd.append("--max_seq_len=4096")
        cmd.append("--load_lora")
        cmd.append("--auto_model_name")
        cmd.append("Phi4MMForCausalLM")

    output = llm_venv.run_cmd(cmd, caller=check_output)
    print("output:", output)

    # Set match ratio based on model
    match_ratio = 4.0 / 5
    if model_name == "Phi-4-multimodal-instruct":
        match_ratio = 0.6

    # Check output accuracy
    for prompt_output, prompt_keywords in zip(
            parse_output(output), expected_keywords[model_name]["image"]):
        matches = [
            keyword in prompt_output.lower() for keyword in prompt_keywords
        ]
        obs_match_ratio = 1. * sum(matches) / len(matches)
        print("prompt_output:", prompt_output)
        print("prompt_keywords:", prompt_keywords)
        print("matches:", matches)
        print("obs_match_ratio:", obs_match_ratio)
        assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"

    print("All answers are correct!")
@pytest.mark.parametrize("model_name,model_path", [
    ("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"),
])


@@ -602,6 +602,9 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[gemma-3-27b-it-gemma/gemma-3-27b-it]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[gemma-3-27b-it-gemma/gemma-3-27b-it]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]