[TRTLLM-6975][test] Add multi-turn test cases for VLM models (#6749)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
parent: cf00003f3d
commit: 15bcf80596
@@ -108,6 +108,15 @@ def add_multimodal_args(parser):
                        type=str,
                        default="cpu",
                        help="The device to have the input on.")
    # Add multiturn conversation related parameters
    parser.add_argument("--multiturn",
                        action="store_true",
                        help="Enable multi-turn conversation mode.")
    parser.add_argument(
        "--conversation_turns",
        type=int,
        default=2,
        help="Number of conversation turns for automated testing.")
    return parser
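Reviewer note: for reference, a minimal stand-alone sketch of how the two new flags parse. This toy parser mirrors only the options added in this hunk, not the full quickstart argument set.

import argparse

# Toy parser carrying only the new multi-turn options (illustrative, not the real CLI).
parser = argparse.ArgumentParser()
parser.add_argument("--multiturn",
                    action="store_true",
                    help="Enable multi-turn conversation mode.")
parser.add_argument("--conversation_turns",
                    type=int,
                    default=2,
                    help="Number of conversation turns for automated testing.")

# e.g. quickstart_multimodal.py ... --multiturn --conversation_turns 3
args = parser.parse_args(["--multiturn", "--conversation_turns", "3"])
assert args.multiturn is True and args.conversation_turns == 3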
@@ -162,6 +171,80 @@ def main():
        open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
    assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"

    # If multiturn mode is enabled
    if args.multiturn:
        # Run predefined multiturn conversation examples
        assert args.prompt is not None, "Please provide a prompt for multiturn conversation."
        assert args.media is not None, "Please provide media for multiturn conversation."
        # Determine how many turns to run
        max_turns = min(args.conversation_turns, len(args.prompt))
        generated_outputs = []  # Store generated outputs for return

        # Initialize conversation history with the first prompt
        conversation_history = args.prompt[0] if args.prompt else ""

        for i in range(max_turns):
            print(f"\n--- Turn {i+1} ---")

            try:
                # Use multimodal input loader to process input with conversation context
                # Use accumulated conversation history instead of just the current prompt
                cur_prompt = conversation_history
                inputs = default_multimodal_input_loader(
                    tokenizer=llm.tokenizer,
                    model_dir=llm._hf_model_dir,
                    model_type=model_type,
                    modality=args.modality,
                    prompts=[cur_prompt],
                    media=args.media,
                    image_data_format="pt",
                    num_frames=8,
                    device="cpu")

                lora_request = None
                if args.load_lora:
                    if model_class is None:
                        raise ValueError(
                            "model_class must be provided when load_lora is True"
                        )
                    lora_request = model_class.lora_request(
                        len(inputs), args.modality, llm._hf_model_dir)

                # Generate response
                outputs = llm.generate(inputs,
                                       sampling_params,
                                       lora_request=lora_request)
                assert outputs and len(
                    outputs) > 0 and outputs[0].outputs and len(
                        outputs[0].outputs) > 0
                response = outputs[0].outputs[0].text.strip()

                # Store generated output
                generated_outputs.append({
                    "turn": i + 1,
                    "user_input": cur_prompt,
                    "assistant_response": response,
                    "media": args.media
                })

                conversation_history = conversation_history + "\n" + response
                if i + 1 < len(args.prompt):
                    conversation_history = conversation_history + "\n" + args.prompt[
                        i + 1]

            except Exception as e:
                print(f"Error in turn {i+1}: {e}")
                import traceback
                traceback.print_exc()
                continue

        for i, output in enumerate(generated_outputs):
            print(
                f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
            )
        return

    # Original single-turn processing logic
    # set prompts and media to example prompts and images if they are not provided
    if args.prompt is None:
        args.prompt = example_medias_and_prompts[args.modality]["prompt"]
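Reviewer note: as a rough, self-contained sketch of the history-accumulation pattern in the loop above. The generate callable and the prompts here are stand-ins, not the real LLM API.

def run_multiturn(prompts, generate, max_turns=2):
    # Accumulate a plain-text history: first prompt, response, next prompt, ...
    conversation_history = prompts[0] if prompts else ""
    generated_outputs = []
    for i in range(min(max_turns, len(prompts))):
        response = generate(conversation_history)  # stand-in for llm.generate(...)
        generated_outputs.append({
            "turn": i + 1,
            "user_input": conversation_history,
            "assistant_response": response,
        })
        conversation_history = conversation_history + "\n" + response
        if i + 1 < len(prompts):
            conversation_history = conversation_history + "\n" + prompts[i + 1]
    return generated_outputs

# Toy generator that ignores the history and returns a canned answer.
print(run_multiturn(["Describe the image.", "What is the mood?"],
                    lambda history: "A rocky landscape under a clear sky."))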
@@ -19,6 +19,13 @@ meta-llama/Llama-3.3-70B-Instruct:
    accuracy: 84.08
meta-llama/Llama-4-Maverick-17B-128E-Instruct:
  - accuracy: 92.20
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    accuracy: 92.20
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    spec_dec_algo: Eagle
    accuracy: 92.20
meta-llama/Llama-4-Scout-17B-16E-Instruct:
  - accuracy: 89.70
  - quant_algo: NVFP4
@@ -73,6 +73,9 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
    kv_cache_quant_algo: FP8
    spec_dec_algo: Eagle
    accuracy: 86.40
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    accuracy: 86.40
meta-llama/Llama-4-Scout-17B-16E-Instruct:
  - accuracy: 80.00
  - quant_algo: NVFP4
@@ -246,7 +246,7 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
    total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
    total_gen_gpus = gen_tp * gen_pp * gen_instances
    if total_ctx_gpus + total_gen_gpus > get_device_count():
        pytest.fail(
        pytest.skip(
            f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
        )
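Reviewer note: the hunk above swaps pytest.fail for pytest.skip, so an under-provisioned machine skips the case instead of failing it. A minimal sketch of that guard pattern follows; the helper name and counts are illustrative.

import pytest

def require_total_gpus(needed: int, available: int) -> None:
    # Skipping keeps the suite green on hosts that simply lack the hardware.
    if needed > available:
        pytest.skip(f"Not enough devices: need {needed}, have {available}")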
@@ -378,6 +378,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
        task = GSM8K(self.MODEL_NAME)
        task.evaluate(llm)

    @skip_pre_hopper
    @parametrize_with_ids("overlap_scheduler", [True, False])
    @parametrize_with_ids("eagle3_one_model", [True, False])
    def test_eagle3(self, overlap_scheduler, eagle3_one_model):
@@ -461,6 +462,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):

@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600)
@pytest.mark.skip_less_device(4)
class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
    MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
    MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
@@ -540,6 +542,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
    @parametrize_with_ids("overlap_scheduler", [True, False])
    @parametrize_with_ids("mtp_nextn",
                          [0, pytest.param(2, marks=skip_pre_hopper)])
    @pytest.mark.skip_less_device(4)
    def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
        ctx_server_config = {"disable_overlap_scheduler": True}
        gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler}
@@ -671,6 +674,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
        task.evaluate(llm)

    @pytest.mark.parametrize("overlap_scheduler", [False, True])
    @skip_pre_hopper
    def test_auto_dtype(self, overlap_scheduler):
        ctx_server_config = {
            "disable_overlap_scheduler": True,
@@ -695,6 +695,7 @@ class TestMistralSmall24B(LlmapiAccuracyTestHarness):
    MODEL_NAME = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
    MODEL_PATH = f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503"

    @pytest.mark.skip_less_device_memory(80000)
    def test_auto_dtype(self):
        with LLM(self.MODEL_PATH) as llm:
            task = CnnDailymail(self.MODEL_NAME)
@@ -1033,7 +1034,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
            max_num_streams=3) if torch_compile else None)
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            use_cuda_graph=cuda_graph,
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
            torch_compile_config=torch_compile_config,
            moe_config=MoeConfig(backend="CUTEDSL"),
        )
@@ -1191,7 +1192,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
            max_num_streams=3) if torch_compile else None)
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            use_cuda_graph=cuda_graph,
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
            torch_compile_config=torch_compile_config,
            moe_config=MoeConfig(backend="CUTEDSL"),
        )
@@ -2051,7 +2051,7 @@ def test_ptp_quickstart_advanced_8gpus(llm_root, llm_venv, model_name,
def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
        llm_root, llm_venv, model_name, model_path, cuda_graph):
    print(f"Testing {model_name} on 8 GPUs.")
    example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
    cmd = [
        str(example_root / "quickstart_advanced.py"),
        "--enable_chunked_prefill",
@@ -2076,10 +2076,12 @@ def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("model_name,model_path", [
    ("Llama3.1-70B-BF16", "llama-3.1-model/Meta-Llama-3.1-70B"),
    ('Nemotron-Super-49B-v1-BF16',
     'nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1'),
    ("Mixtral-8x7B-BF16", "Mixtral-8x7B-Instruct-v0.1"),
    pytest.param('Llama3.1-70B-BF16',
                 'llama-3.1-model/Meta-Llama-3.1-70B',
                 marks=pytest.mark.skip_less_device_memory(95000)),
])
def test_ptp_quickstart_advanced_2gpus_sm120(llm_root, llm_venv, model_name,
                                             model_path):
@@ -2521,6 +2523,106 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
    print("All answers are correct!")


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("model_name,model_path", [
    ("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
    ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
    ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
])
def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
                                             model_path):
    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
    test_data_root = Path(
        os.path.join(llm_models_root(), "multimodals", "test_data"))

    print(f"Accuracy test {model_name} image mode with example inputs.")

    # Define accuracy inputs for image modality
    accuracy_inputs = {
        "image": {
            "prompt": [
                "Describe what you see in this image.",
                "How would you describe the atmosphere of this scene?",
            ],
            "media": [
                str(test_data_root / "inpaint.png"),
            ],
        }
    }

    # Define expected keywords for each model
    expected_keywords = {
        "gemma-3-27b-it": {
            "image": [
                ["half", "dome", "yosemite", "landmark", "rounded"],
                ["atmosphere", "peaceful", "majestic", "calm", "quiet"],
            ],
        },
        "mistral-small-3.1-24b-instruct": {
            "image": [
                ["depicts", "landscape", "rock", "sky", "high", "altitude"],
                ["atmosphere", "serene", "majestic", "sense", "tranquility"],
            ],
        },
        "Phi-4-multimodal-instruct": {
            "image": [
                ["depicts", "landscape", "mountain", "half", "dome"],
                ["atmosphere", "serene", "sense", "tranquility", "peace."],
            ],
        },
    }
    # Build command for image modality
    cmd = [
        str(example_root / "quickstart_multimodal.py"),
        "--model_dir",
        f"{llm_models_root()}/{model_path}",
        "--modality",
        "image",
        "--multiturn",
        "--prompt",
        *accuracy_inputs["image"]["prompt"],
        "--media",
        *accuracy_inputs["image"]["media"],
    ]

    # Add model-specific configurations
    if model_name == "gemma-3-27b-it":
        # Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently.
        # Custom mask involves bidirectional masking of image tokens in context phase. To get this
        # correct, chunked prefill and kv cache reuse need to be turned off.
        cmd.append("--image_format=pil")
        cmd.append("--attention_backend=FLASHINFER")
        cmd.append("--disable_kv_cache_reuse")
    elif model_name == "Phi-4-multimodal-instruct":
        # Set max_seq_len to 4096 to use short rope factor.
        cmd.append("--max_seq_len=4096")
        cmd.append("--load_lora")
        cmd.append("--auto_model_name")
        cmd.append("Phi4MMForCausalLM")

    output = llm_venv.run_cmd(cmd, caller=check_output)
    print("output:", output)
    # Set match ratio based on model
    match_ratio = 4.0 / 5
    if model_name == "Phi-4-multimodal-instruct":
        match_ratio = 0.6

    # Check output accuracy
    for prompt_output, prompt_keywords in zip(
            parse_output(output), expected_keywords[model_name]["image"]):
        matches = [
            keyword in prompt_output.lower() for keyword in prompt_keywords
        ]
        obs_match_ratio = 1. * sum(matches) / len(matches)
        print("prompt_output:", prompt_output)
        print("prompt_keywords:", prompt_keywords)
        print("matches:", matches)
        print("obs_match_ratio:", obs_match_ratio)
        assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"

    print("All answers are correct!")


@pytest.mark.parametrize("model_name,model_path", [
    ("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"),
])
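Reviewer note: the accuracy check in the new multi-turn test above boils down to a per-prompt keyword match ratio. A small self-contained sketch of that check follows; the sample text and threshold are illustrative.

def keyword_match_ratio(prompt_output: str, prompt_keywords) -> float:
    # Count how many expected keywords appear anywhere in the lowercased output.
    matches = [keyword in prompt_output.lower() for keyword in prompt_keywords]
    return 1. * sum(matches) / len(matches)

# 4 of the 5 keywords appear, so the 4/5 threshold used above is met.
assert keyword_match_ratio(
    "The image depicts a rocky landscape under a clear sky.",
    ["depicts", "landscape", "rock", "sky", "altitude"]) >= 4.0 / 5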
@@ -602,6 +602,9 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[gemma-3-27b-it-gemma/gemma-3-27b-it]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[gemma-3-27b-it-gemma/gemma-3-27b-it]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
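Reviewer note: the entries above are pytest node IDs. A hedged example of selecting one of the new multi-turn cases locally through pytest's Python entry point; the test file path prefix is an assumption about the repository layout.

import pytest

# Run a single new multi-turn case; "-s" streams the keyword-match logging.
pytest.main([
    "tests/integration/defs/test_e2e.py::test_ptp_quickstart_multimodal_multiturn"
    "[gemma-3-27b-it-gemma/gemma-3-27b-it]",
    "-s",
])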