[https://nvbugs/5410279][test] resubmit timeout refactor (#6337)

Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
This commit is contained in:
Ivy Zhang 2025-08-05 16:39:25 +08:00 committed by GitHub
parent 7cbe30e17d
commit d101a6cebc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 608 additions and 228 deletions

View File

@ -701,26 +701,59 @@ class CliFlowAccuracyTestHarness:
extra_build_args: Optional[list] = None,
extra_summarize_args: Optional[list] = None,
extra_eval_long_context_args: Optional[list] = None,
env: Optional[Dict[str, str]] = None):
self.install_requirements()
self.initialize_case(
tasks=tasks,
dtype=dtype,
quant_algo=quant_algo,
kv_cache_quant_algo=kv_cache_quant_algo,
spec_dec_algo=spec_dec_algo,
extra_acc_spec=extra_acc_spec,
tp_size=tp_size,
pp_size=pp_size,
cp_size=cp_size,
extra_convert_args=extra_convert_args,
extra_build_args=extra_build_args,
extra_summarize_args=extra_summarize_args,
extra_eval_long_context_args=extra_eval_long_context_args,
env=env)
self.convert()
self.build()
self.evaluate()
env: Optional[Dict[str, str]] = None,
timeout_manager=None):
"""
Run all accuracy test phases with timeout management.
If timeout_manager is provided, each phase will be wrapped to track and deduct remaining timeout.
"""
# Use timeout_manager to manage timeout for each phase
if timeout_manager is not None:
with timeout_manager.timed_operation("install_requirements"):
self.install_requirements()
with timeout_manager.timed_operation("initialize_case"):
self.initialize_case(
tasks=tasks,
dtype=dtype,
quant_algo=quant_algo,
kv_cache_quant_algo=kv_cache_quant_algo,
spec_dec_algo=spec_dec_algo,
extra_acc_spec=extra_acc_spec,
tp_size=tp_size,
pp_size=pp_size,
cp_size=cp_size,
extra_convert_args=extra_convert_args,
extra_build_args=extra_build_args,
extra_summarize_args=extra_summarize_args,
extra_eval_long_context_args=extra_eval_long_context_args,
env=env)
with timeout_manager.timed_operation("convert"):
self.convert()
with timeout_manager.timed_operation("build"):
self.build()
with timeout_manager.timed_operation("evaluate"):
self.evaluate()
else:
# fallback: no timeout management
self.install_requirements()
self.initialize_case(
tasks=tasks,
dtype=dtype,
quant_algo=quant_algo,
kv_cache_quant_algo=kv_cache_quant_algo,
spec_dec_algo=spec_dec_algo,
extra_acc_spec=extra_acc_spec,
tp_size=tp_size,
pp_size=pp_size,
cp_size=cp_size,
extra_convert_args=extra_convert_args,
extra_build_args=extra_build_args,
extra_summarize_args=extra_summarize_args,
extra_eval_long_context_args=extra_eval_long_context_args,
env=env)
self.convert()
self.build()
self.evaluate()
class LlmapiAccuracyTestHarness:

View File

@ -1167,14 +1167,15 @@ class TestMixtral8x22B(CliFlowAccuracyTestHarness):
@skip_pre_ada
@pytest.mark.skip_less_device(4)
@pytest.mark.skip_less_device_memory(80000)
def test_fp8_tp2pp2(self):
def test_fp8_tp2pp2(self, timeout_manager):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.FP8,
tp_size=2,
pp_size=2,
extra_convert_args=["--calib_size=32"],
extra_build_args=["--gemm_plugin=auto"])
extra_build_args=["--gemm_plugin=auto"],
timeout_manager=timeout_manager)
@skip_post_blackwell
@pytest.mark.skip_less_device(8)
@ -1184,7 +1185,8 @@ class TestMixtral8x22B(CliFlowAccuracyTestHarness):
ids=['expert_parallel', 'mixed_parallel', 'tensor_parallel'])
@pytest.mark.parametrize("moe_renorm_mode", [0, 1],
ids=['no_renormalize', 'renormalize'])
def test_int8_plugin_tp8(self, moe_tp_size, moe_renorm_mode):
def test_int8_plugin_tp8(self, moe_tp_size, moe_renorm_mode,
timeout_manager):
self.run(quant_algo=QuantAlgo.W8A16,
tp_size=8,
extra_convert_args=[
@ -1195,7 +1197,8 @@ class TestMixtral8x22B(CliFlowAccuracyTestHarness):
extra_build_args=[
"--max_beam_width=4", "--gemm_plugin=auto",
"--moe_plugin=auto", f"--max_seq_len={8192}"
])
],
timeout_manager=timeout_manager)
class TestGemma2B(CliFlowAccuracyTestHarness):

View File

@ -44,7 +44,7 @@ def venv_check_output(venv, cmd, env=None, **kwargs):
return venv.run_cmd(cmd, caller=_war_check_output, env=env, **kwargs)
def venv_mpi_check_call(venv, mpi_cmd, python_cmd):
def venv_mpi_check_call(venv, mpi_cmd, python_cmd, **kwargs):
"""
This function WAR check_call() to run python_cmd with mpi.
If mpi_cmd = ["mpirun", "-n", "2"] and python_cmd = ["run.py"], the command will be:
@ -61,10 +61,10 @@ def venv_mpi_check_call(venv, mpi_cmd, python_cmd):
kwargs["cwd"] = venv.get_working_directory()
return check_call(merged_cmd, **kwargs)
venv.run_cmd(python_cmd, caller=_war_check_call)
venv.run_cmd(python_cmd, caller=_war_check_call, **kwargs)
def venv_mpi_check_output(venv, mpi_cmd, python_cmd, env=None):
def venv_mpi_check_output(venv, mpi_cmd, python_cmd, env=None, **kwargs):
"""
This function WAR check_output() to run python_cmd with mpi.
If mpi_cmd = ["mpirun", "-n", "2"] and python_cmd = ["run.py"], the command will be:
@ -81,7 +81,7 @@ def venv_mpi_check_output(venv, mpi_cmd, python_cmd, env=None):
kwargs["cwd"] = venv.get_working_directory()
return check_output(merged_cmd, **kwargs)
return venv.run_cmd(python_cmd, caller=_war_check_output, env=env)
return venv.run_cmd(python_cmd, caller=_war_check_output, env=env, **kwargs)
def parse_mpi_cmd(cmd):
@ -506,6 +506,7 @@ def convert_weights(llm_venv,
convert_cmd.append(f"--quant_ckpt_path={quant_ckpt_path}")
if per_group:
convert_cmd.append("--per_group")
timeout = kwargs.pop('timeout', None)
for key, value in kwargs.items():
if isinstance(value, bool):
@ -515,7 +516,7 @@ def convert_weights(llm_venv,
convert_cmd.extend([f"--{key}={value}"])
if llm_venv:
venv_check_call(llm_venv, convert_cmd)
venv_check_call(llm_venv, convert_cmd, timeout=timeout)
return model_dir
else:
return convert_cmd, model_dir
@ -607,6 +608,7 @@ def quantize_data(llm_venv,
if kv_cache_dtype:
quantize_cmd.append(f"--kv_cache_dtype={kv_cache_dtype}")
timeout = kwargs.pop('timeout', None)
for key, value in kwargs.items():
if isinstance(value, bool):
@ -617,7 +619,7 @@ def quantize_data(llm_venv,
if llm_venv:
if not exists(output_dir):
venv_check_call(llm_venv, quantize_cmd)
venv_check_call(llm_venv, quantize_cmd, timeout=timeout)
return output_dir
else:
return quantize_cmd, output_dir

View File

@ -2351,3 +2351,38 @@ def tritonserver_test_root(llm_root):
"tests/integration/defs/triton_server")
return tritonserver_root
@pytest.fixture
def timeout_from_marker(request):
"""Get timeout value from pytest timeout marker."""
timeout_marker = request.node.get_closest_marker('timeout')
if timeout_marker:
return timeout_marker.args[0] if timeout_marker.args else None
return None
@pytest.fixture
def timeout_from_command_line(request):
"""Get timeout value from command line --timeout parameter."""
# Get timeout from command line argument
timeout_arg = request.config.getoption("--timeout", default=None)
if timeout_arg is not None:
return float(timeout_arg)
return None
@pytest.fixture
def timeout_manager(timeout_from_command_line, timeout_from_marker):
"""Create a TimeoutManager instance with priority: marker > cmdline > config."""
from defs.utils.timeout_manager import TimeoutManager
# Priority: marker > command line
timeout_value = None
if timeout_from_marker is not None:
timeout_value = timeout_from_marker
elif timeout_from_command_line is not None:
timeout_value = timeout_from_command_line
return TimeoutManager(timeout_value)

View File

@ -94,22 +94,27 @@ def test_llm_commandr_plus_4gpus_summary(commandr_example_root,
llm_commandr_plus_model_root,
llm_datasets_root, llm_rouge_root,
llm_venv, cmodel_dir, engine_dir,
use_weight_only):
use_weight_only, timeout_manager):
"Build & run Command-R+ with smoothquant on 4 gpus."
dtype = 'float16'
tp_size = 4
model_name = os.path.basename(llm_commandr_plus_model_root)
print("Converting checkpoint...")
ckpt_dir = convert_weights(llm_venv=llm_venv,
example_root=commandr_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llm_commandr_plus_model_root,
data_type=dtype,
tp_size=tp_size,
gpus=tp_size,
use_weight_only=use_weight_only)
# Convert checkpoint with timeout management
print("Converting checkpoint...")
with timeout_manager.timed_operation("convert"):
ckpt_dir = convert_weights(llm_venv=llm_venv,
example_root=commandr_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llm_commandr_plus_model_root,
data_type=dtype,
tp_size=tp_size,
gpus=tp_size,
use_weight_only=use_weight_only,
timeout=timeout_manager.remaining_timeout)
# Build engines with timeout management
print("Building engines...")
build_cmd = [
"trtllm-build",
@ -130,12 +135,23 @@ def test_llm_commandr_plus_4gpus_summary(commandr_example_root,
f"--engine_dir={engine_dir}",
]
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
with timeout_manager.timed_operation("build"):
check_call(" ".join(build_cmd),
shell=True,
env=llm_venv._new_env,
timeout=timeout_manager.remaining_timeout)
venv_mpi_check_call(
llm_venv,
["mpirun", "-n", str(tp_size), "--allow-run-as-root"], run_cmd)
# Run engines with timeout management
print("Running engines...")
with timeout_manager.timed_operation("run"):
venv_mpi_check_call(
llm_venv, ["mpirun", "-n",
str(tp_size), "--allow-run-as-root"],
run_cmd,
timeout=timeout_manager.remaining_timeout)
# Run summary with timeout management
print("Running summary...")
summary_cmd = generate_summary_cmd(
commandr_example_root,
hf_model_dir=llm_commandr_plus_model_root,
@ -144,6 +160,9 @@ def test_llm_commandr_plus_4gpus_summary(commandr_example_root,
dataset_dir=llm_datasets_root,
rouge_dir=llm_rouge_root)
venv_mpi_check_call(
llm_venv,
["mpirun", "-n", str(tp_size), "--allow-run-as-root"], summary_cmd)
with timeout_manager.timed_operation("summary"):
venv_mpi_check_call(
llm_venv, ["mpirun", "-n",
str(tp_size), "--allow-run-as-root"],
summary_cmd,
timeout=timeout_manager.remaining_timeout)

View File

@ -40,28 +40,37 @@ if get_sm_version() >= 103:
def test_llm_exaone_1gpu(data_type, exaone_example_root, llm_exaone_model_root,
llama_example_root, llm_datasets_root, llm_rouge_root,
llm_venv, cmodel_dir, engine_dir, num_beams,
use_weight_only):
use_weight_only, timeout_manager):
print("Build engines...")
model_name = "exaone"
model_dir = convert_weights(
llm_venv=llm_venv,
# NOTE
# EXAONE is based on llama so reuse llama's checkpoint converter
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llm_exaone_model_root,
data_type=data_type,
use_weight_only=use_weight_only)
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={model_dir}",
f"--output_dir={engine_dir}",
f"--max_beam_width={num_beams}",
]
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
# Convert weights with timeout management
with timeout_manager.timed_operation("convert"):
model_dir = convert_weights(
llm_venv=llm_venv,
# NOTE
# EXAONE is based on llama so reuse llama's checkpoint converter
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llm_exaone_model_root,
data_type=data_type,
use_weight_only=use_weight_only,
timeout=timeout_manager.remaining_timeout)
# Build engines with timeout management
with timeout_manager.timed_operation("build"):
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={model_dir}",
f"--output_dir={engine_dir}",
f"--max_beam_width={num_beams}",
]
check_call(" ".join(build_cmd),
shell=True,
env=llm_venv._new_env,
timeout=timeout_manager.remaining_timeout)
rouge1_threshold = {
1: 22,
@ -69,6 +78,7 @@ def test_llm_exaone_1gpu(data_type, exaone_example_root, llm_exaone_model_root,
4: 23,
}[num_beams]
# Run summary with timeout management
print("Run summarize...")
summary_cmd = generate_summary_cmd(
exaone_example_root,
@ -82,7 +92,10 @@ def test_llm_exaone_1gpu(data_type, exaone_example_root, llm_exaone_model_root,
num_beams=num_beams,
)
venv_check_call(llm_venv, summary_cmd)
with timeout_manager.timed_operation("summary"):
venv_check_call(llm_venv,
summary_cmd,
timeout=timeout_manager.remaining_timeout)
@pytest.mark.skip_less_device(2)
@ -94,29 +107,40 @@ def test_llm_exaone_1gpu(data_type, exaone_example_root, llm_exaone_model_root,
indirect=True)
def test_llm_exaone_2gpu(data_type, exaone_example_root, llm_exaone_model_root,
llama_example_root, llm_datasets_root, llm_rouge_root,
llm_venv, cmodel_dir, engine_dir, num_beams):
llm_venv, cmodel_dir, engine_dir, num_beams,
timeout_manager):
tp_size = 2
print("Build engines...")
model_name = "exaone"
model_dir = convert_weights(
llm_venv=llm_venv,
# NOTE
# EXAONE is based on llama so reuse llama's checkpoint converter
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llm_exaone_model_root,
data_type=data_type,
tp_size=tp_size,
pp_size=1)
build_cmd = [
"trtllm-build", f"--checkpoint_dir={model_dir}",
f"--output_dir={engine_dir}", f"--max_beam_width={num_beams}"
]
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
# Convert weights with timeout management
with timeout_manager.timed_operation("convert"):
model_dir = convert_weights(
llm_venv=llm_venv,
# NOTE
# EXAONE is based on llama so reuse llama's checkpoint converter
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llm_exaone_model_root,
data_type=data_type,
tp_size=tp_size,
pp_size=1,
timeout=timeout_manager.remaining_timeout)
# Build engines with timeout management
with timeout_manager.timed_operation("build"):
build_cmd = [
"trtllm-build", f"--checkpoint_dir={model_dir}",
f"--output_dir={engine_dir}", f"--max_beam_width={num_beams}"
]
check_call(" ".join(build_cmd),
shell=True,
env=llm_venv._new_env,
timeout=timeout_manager.remaining_timeout)
# Run summary with timeout management
print("Run summarize...")
summary_cmd = generate_summary_cmd(
exaone_example_root,
@ -130,6 +154,8 @@ def test_llm_exaone_2gpu(data_type, exaone_example_root, llm_exaone_model_root,
num_beams=num_beams,
)
venv_mpi_check_call(llm_venv,
["mpirun", "-n", f"{tp_size}", "--allow-run-as-root"],
summary_cmd)
with timeout_manager.timed_operation("summary"):
venv_mpi_check_call(
llm_venv, ["mpirun", "-n", f"{tp_size}", "--allow-run-as-root"],
summary_cmd,
timeout=timeout_manager.remaining_timeout)

View File

@ -644,55 +644,69 @@ def test_llm_gpt3_175b_96layers_build_only(gpt_example_root, llm_venv,
ids=["parallel_build", "serial_build"])
def test_llm_gpt3_175b_1node_8gpus(gpt_example_root, llm_venv, engine_dir,
use_attention_plugin, use_gemm_plugin,
context_fmha, parallel_build):
context_fmha, parallel_build,
timeout_manager):
"Build & Run GPT-3 175B: 96 layer w/ plugins"
dtype = 'float16'
convert_cmd = [
f"{gpt_example_root}/../../../generate_checkpoint_config.py",
f"--output_path={engine_dir}/ckpt_config.json",
"--architecture=GPTForCausalLM", f"--dtype={dtype}",
"--num_hidden_layers=96", "--num_attention_heads=96",
"--hidden_size=12288", "--vocab_size=51200", "--tp_size=8"
]
venv_check_call(llm_venv, convert_cmd)
# Convert checkpoint with timeout management
with timeout_manager.timed_operation("convert"):
convert_cmd = [
f"{gpt_example_root}/../../../generate_checkpoint_config.py",
f"--output_path={engine_dir}/ckpt_config.json",
"--architecture=GPTForCausalLM", f"--dtype={dtype}",
"--num_hidden_layers=96", "--num_attention_heads=96",
"--hidden_size=12288", "--vocab_size=51200", "--tp_size=8"
]
venv_check_call(llm_venv,
convert_cmd,
timeout=timeout_manager.remaining_timeout)
# Build engines with timeout management
print("Building engines...")
build_cmd = [
"trtllm-build",
f"--model_config={engine_dir}/ckpt_config.json",
f"--output_dir={engine_dir}",
f"--max_batch_size={32}",
f"--max_input_len={924}",
f"--max_seq_len={1024}",
]
with timeout_manager.timed_operation("build"):
build_cmd = [
"trtllm-build",
f"--model_config={engine_dir}/ckpt_config.json",
f"--output_dir={engine_dir}",
f"--max_batch_size={32}",
f"--max_input_len={924}",
f"--max_seq_len={1024}",
]
if use_attention_plugin:
build_cmd.extend([f"--gpt_attention_plugin={dtype}"])
if context_fmha:
build_cmd.extend(["--context_fmha=enable"])
if use_attention_plugin:
build_cmd.extend([f"--gpt_attention_plugin={dtype}"])
if context_fmha:
build_cmd.extend(["--context_fmha=enable"])
else:
build_cmd.extend(["--context_fmha=disable"])
else:
build_cmd.extend(["--context_fmha=disable"])
else:
build_cmd.extend([
"--gpt_attention_plugin=disable",
"--context_fmha=disable",
"--paged_kv_cache=disable",
"--remove_input_padding=disable",
])
if use_gemm_plugin:
build_cmd.extend([f"--gemm_plugin={dtype}"])
if parallel_build:
build_cmd.extend(["--workers=8"])
build_cmd.extend([
"--gpt_attention_plugin=disable",
"--context_fmha=disable",
"--paged_kv_cache=disable",
"--remove_input_padding=disable",
])
if use_gemm_plugin:
build_cmd.extend([f"--gemm_plugin={dtype}"])
if parallel_build:
build_cmd.extend(["--workers=8"])
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
check_call(" ".join(build_cmd),
shell=True,
env=llm_venv._new_env,
timeout=timeout_manager.remaining_timeout)
# Run inference with timeout management
print('Run gpt3-175b...')
venv_mpi_check_call(
llm_venv,
["mpirun", "--allow-run-as-root", "--oversubscribe", "-np", "8"], [
f"{gpt_example_root}/../../../run.py", "--max_output_len=8",
f"--engine_dir={engine_dir}", "--no_add_special_tokens"
])
with timeout_manager.timed_operation("run"):
venv_mpi_check_call(
llm_venv,
["mpirun", "--allow-run-as-root", "--oversubscribe", "-np", "8"], [
f"{gpt_example_root}/../../../run.py", "--max_output_len=8",
f"--engine_dir={engine_dir}", "--no_add_special_tokens"
],
timeout=timeout_manager.remaining_timeout)
@skip_post_blackwell

View File

@ -3050,7 +3050,8 @@ def test_llm_llama_v3_8b_1048k_long_context_ppl(llama_example_root,
@pytest.mark.timeout(10800 if get_sm_version() < 89 else 3600)
def test_llm_llama_v3_1m_long_context_8gpus(llama_example_root,
llama_model_root, llm_venv,
engine_dir, cmodel_dir):
engine_dir, cmodel_dir,
timeout_manager):
"Build & run llama-3-8B-1048k on long context."
model_name = os.path.basename(llama_model_root)
dtype = 'float16'
@ -3059,51 +3060,66 @@ def test_llm_llama_v3_1m_long_context_8gpus(llama_example_root,
max_seq_len = 1048576
max_batch_size = 256
# Generate evaluation dataset with timeout management
print("Generate evaluation dataset for passkey.")
gen_cmd = [
f"{llama_example_root}/../../../infinitebench/construct_synthetic_dataset.py",
"--test_case=build_passkey",
"--test_level=7",
]
venv_check_call(llm_venv, gen_cmd)
with timeout_manager.timed_operation("gen"):
gen_cmd = [
f"{llama_example_root}/../../../infinitebench/construct_synthetic_dataset.py",
"--test_case=build_passkey",
"--test_level=7",
]
venv_check_call(llm_venv,
gen_cmd,
timeout=timeout_manager.remaining_timeout)
# Convert checkpoint with timeout management
print("Converting checkpoint...")
ckpt_dir = convert_weights(llm_venv=llm_venv,
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llama_model_root,
data_type=dtype,
tp_size=tp_size,
pp_size=pp_size)
with timeout_manager.timed_operation("convert"):
ckpt_dir = convert_weights(llm_venv=llm_venv,
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llama_model_root,
data_type=dtype,
tp_size=tp_size,
pp_size=pp_size,
timeout=timeout_manager.remaining_timeout)
# Build engines with timeout management
print("Building engines...")
build_cmd = [
"trtllm-build", f"--checkpoint_dir={ckpt_dir}",
f"--output_dir={engine_dir}", f"--gemm_plugin={dtype}",
f"--workers={world_size}", f"--max_seq_len={max_seq_len}",
"--max_num_tokens=4096", "--use_paged_context_fmha=enable",
f'--max_batch_size={max_batch_size}'
]
with timeout_manager.timed_operation("build"):
build_cmd = [
"trtllm-build", f"--checkpoint_dir={ckpt_dir}",
f"--output_dir={engine_dir}", f"--gemm_plugin={dtype}",
f"--workers={world_size}", f"--max_seq_len={max_seq_len}",
"--max_num_tokens=4096", "--use_paged_context_fmha=enable",
f'--max_batch_size={max_batch_size}'
]
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
check_call(" ".join(build_cmd),
shell=True,
env=llm_venv._new_env,
timeout=timeout_manager.remaining_timeout)
# Run passkey evaluation with timeout management
print("Run passkey evaluation...")
eval_cmd = [
f"{llama_example_root}/../../../eval_long_context.py",
f"--engine_dir={engine_dir}",
f"--tokenizer_dir={llama_model_root}",
f"--max_input_length={max_seq_len-10}",
"--max_tokens_in_paged_kv_cache=1100000",
"--task=passkey",
"--stop_idx=10",
"--enable_chunked_context",
"--tensorrt_llm_accuracy_threshold=0.9",
]
with timeout_manager.timed_operation("eval"):
eval_cmd = [
f"{llama_example_root}/../../../eval_long_context.py",
f"--engine_dir={engine_dir}",
f"--tokenizer_dir={llama_model_root}",
f"--max_input_length={max_seq_len-10}",
"--max_tokens_in_paged_kv_cache=1100000",
"--task=passkey",
"--stop_idx=10",
"--enable_chunked_context",
"--tensorrt_llm_accuracy_threshold=0.9",
]
venv_mpi_check_call(
llm_venv, ["mpirun", "-n", f"{world_size}", "--allow-run-as-root"],
eval_cmd)
venv_mpi_check_call(
llm_venv, ["mpirun", "-n", f"{world_size}", "--allow-run-as-root"],
eval_cmd,
timeout=timeout_manager.remaining_timeout)
@pytest.mark.skip_less_device_memory(80000)
@ -3411,7 +3427,8 @@ def test_llm_llama_v3_2_smoothquant_1node_single_gpu(
def test_llm_llama_v3_1_1node_multi_gpus(llama_example_root, llama_model_root,
llm_venv, cmodel_dir,
mmlu_dataset_root, engine_dir,
fp8_quant, gemm_allreduce):
fp8_quant, gemm_allreduce,
timeout_manager):
"Run llama3.1 test on 1 node."
if ("8B" not in llama_model_root) and (get_host_total_memory() < 1000000):
pytest.skip("Host memory is insufficient.")
@ -3429,70 +3446,90 @@ def test_llm_llama_v3_1_1node_multi_gpus(llama_example_root, llama_model_root,
if not fp8_quant and "Meta-Llama-3.1-405B" == model_name:
pytest.skip("Build engine will be OOM on 1 node.")
# Convert weights with timeout management
print("Convert weight...")
model_dir = convert_weights(llm_venv=llm_venv,
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llama_model_root,
data_type=data_type,
tp_size=tp_size,
pp_size=pp_size,
use_fp8_rowwise=fp8_quant,
load_by_shard=True,
workers=world_size)
with timeout_manager.timed_operation("convert"):
model_dir = convert_weights(llm_venv=llm_venv,
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llama_model_root,
data_type=data_type,
tp_size=tp_size,
pp_size=pp_size,
use_fp8_rowwise=fp8_quant,
load_by_shard=True,
workers=world_size,
timeout=timeout_manager.remaining_timeout)
# Build engines with timeout management
print("Build engines...")
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={model_dir}",
f"--output_dir={engine_dir}",
f"--workers={world_size}",
f"--max_batch_size={256}",
"--use_paged_context_fmha=enable",
"--max_num_tokens=4096",
"--max_input_len=64000",
"--max_seq_len=65000",
]
with timeout_manager.timed_operation("build"):
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={model_dir}",
f"--output_dir={engine_dir}",
f"--workers={world_size}",
f"--max_batch_size={256}",
"--use_paged_context_fmha=enable",
"--max_num_tokens=4096",
"--max_input_len=64000",
"--max_seq_len=65000",
]
if gemm_allreduce:
build_cmd += [f"--gemm_allreduce_plugin={data_type}"]
if gemm_allreduce:
build_cmd += [f"--gemm_allreduce_plugin={data_type}"]
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
check_call(" ".join(build_cmd),
shell=True,
env=llm_venv._new_env,
timeout=timeout_manager.remaining_timeout)
gen_cmd = [
f"{llama_example_root}/../../../infinitebench/construct_synthetic_dataset.py",
"--test_case=build_passkey",
"--test_level=3",
]
# Generate dataset with timeout management
with timeout_manager.timed_operation("gen"):
gen_cmd = [
f"{llama_example_root}/../../../infinitebench/construct_synthetic_dataset.py",
"--test_case=build_passkey",
"--test_level=3",
]
venv_check_call(llm_venv, gen_cmd)
venv_check_call(llm_venv,
gen_cmd,
timeout=timeout_manager.remaining_timeout)
# Run evaluation with timeout management
print("Run eval...")
eval_cmd = [
f"{llama_example_root}/../../../eval_long_context.py",
"--task=passkey",
f"--engine_dir={engine_dir}",
f"--tokenizer_dir={llama_model_root}",
"--stop_idx=6",
"--max_input_length=64000",
"--enable_chunked_context",
"--kv_cache_free_gpu_memory_fraction=0.999",
"--max_tokens_in_paged_kv_cache=65064",
"--output_dir=64k_context_tp8",
]
with timeout_manager.timed_operation("eval"):
eval_cmd = [
f"{llama_example_root}/../../../eval_long_context.py",
"--task=passkey",
f"--engine_dir={engine_dir}",
f"--tokenizer_dir={llama_model_root}",
"--stop_idx=6",
"--max_input_length=64000",
"--enable_chunked_context",
"--kv_cache_free_gpu_memory_fraction=0.999",
"--max_tokens_in_paged_kv_cache=65064",
"--output_dir=64k_context_tp8",
]
venv_mpi_check_call(
llm_venv, ["mpirun", "-n", f"{world_size}", "--allow-run-as-root"],
eval_cmd)
venv_mpi_check_call(
llm_venv, ["mpirun", "-n", f"{world_size}", "--allow-run-as-root"],
eval_cmd,
timeout=timeout_manager.remaining_timeout)
# Run MMLU with timeout management
print("Run mmlu...")
mmlu_cmd = [
"trtllm-eval", f"--model={engine_dir}",
f"--tokenizer={llama_model_root}", "--backend=tensorrt", "mmlu",
f"--dataset_path={mmlu_dataset_root}", "--check_accuracy"
]
check_call(" ".join(mmlu_cmd), shell=True, env=llm_venv._new_env)
with timeout_manager.timed_operation("mmlu"):
mmlu_cmd = [
"trtllm-eval", f"--model={engine_dir}",
f"--tokenizer={llama_model_root}", "--backend=tensorrt", "mmlu",
f"--dataset_path={mmlu_dataset_root}", "--check_accuracy"
]
check_call(" ".join(mmlu_cmd),
shell=True,
env=llm_venv._new_env,
timeout=timeout_manager.remaining_timeout)
@pytest.mark.skip_less_device_memory(80000)

View File

@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility modules for TensorRT-LLM integration tests.
This package provides various utilities to simplify test development and reduce
boilerplate code.
"""
from .timeout_manager import (TimeoutManager, create_timeout_manager,
with_timeout_management)
__all__ = [
'TimeoutManager', 'with_timeout_management', 'create_timeout_manager'
]

View File

@ -0,0 +1,184 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from contextlib import contextmanager
from typing import Any, Callable, Optional
class TimeoutManager:
"""
A utility class for managing timeout in test cases.
This class helps reduce boilerplate code for timeout handling in test cases
by providing a simple interface to track remaining time and execute operations
with automatic timeout checking.
"""
def __init__(self, initial_timeout: Optional[float] = None):
"""
Initialize the timeout manager.
Args:
initial_timeout: Initial timeout value in seconds. If None, no timeout is enforced.
"""
self._initial_timeout = initial_timeout
self._remaining_timeout = initial_timeout
self._start_time = None
@property
def remaining_timeout(self) -> Optional[float]:
"""Get the remaining timeout value."""
return self._remaining_timeout
def reset(self, timeout: Optional[float] = None) -> None:
"""
Reset the timeout manager with a new timeout value.
Args:
timeout: New timeout value. If None, uses the initial timeout.
"""
self._remaining_timeout = timeout if timeout is not None else self._initial_timeout
self._start_time = None
def check_timeout(self, phase_name: str = "operation") -> None:
"""
Check if timeout has been exceeded and raise TimeoutError if so.
Args:
phase_name: Name of the current phase for error message.
Raises:
TimeoutError: If timeout has been exceeded.
"""
if self._remaining_timeout is not None and self._remaining_timeout <= 0:
raise TimeoutError(f"Timeout exceeded after {phase_name} phase!")
@contextmanager
def timed_operation(self, phase_name: str = "operation"):
"""
Context manager for timing an operation and updating remaining timeout.
Args:
phase_name: Name of the phase for timeout checking.
Yields:
None
Raises:
TimeoutError: If timeout is exceeded after the operation.
"""
if self._remaining_timeout is None:
# No timeout enforcement
yield
return
start_time = time.time()
try:
yield
finally:
operation_time = time.time() - start_time
self._remaining_timeout -= operation_time
self.check_timeout(phase_name)
def execute_with_timeout(self,
operation: Callable[[], Any],
phase_name: str = "operation",
**kwargs) -> Any:
"""
Execute an operation with timeout tracking.
Args:
operation: The operation to execute.
phase_name: Name of the phase for timeout checking.
**kwargs: Additional arguments to pass to the operation.
Returns:
The result of the operation.
Raises:
TimeoutError: If timeout is exceeded after the operation.
"""
with self.timed_operation(phase_name):
return operation(**kwargs)
def call_with_timeout(self,
func: Callable,
*args,
phase_name: str = "operation",
**kwargs) -> Any:
"""
Call a function with timeout tracking.
Args:
func: The function to call.
*args: Positional arguments for the function.
phase_name: Name of the phase for timeout checking.
**kwargs: Keyword arguments for the function.
Returns:
The result of the function call.
Raises:
TimeoutError: If timeout is exceeded after the function call.
"""
with self.timed_operation(phase_name):
return func(*args, **kwargs)
def create_timeout_manager(
timeout_from_marker: Optional[float] = None) -> TimeoutManager:
"""
Create a TimeoutManager instance from a timeout marker value.
Args:
timeout_from_marker: Timeout value from pytest marker.
Returns:
A TimeoutManager instance.
"""
return TimeoutManager(timeout_from_marker)
# Convenience decorator for test functions
def with_timeout_management(func: Callable) -> Callable:
"""
Decorator to automatically inject timeout management into test functions.
This decorator expects the test function to have a 'timeout_from_marker' parameter
and automatically creates a TimeoutManager instance.
Args:
func: The test function to decorate.
Returns:
The decorated function.
"""
import functools
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Extract timeout_from_marker from kwargs
timeout_from_marker = kwargs.get('timeout_from_marker')
# Create timeout manager
timeout_manager = create_timeout_manager(timeout_from_marker)
# Add timeout_manager to kwargs
kwargs['timeout_manager'] = timeout_manager
return func(*args, **kwargs)
return wrapper

View File

@ -15,20 +15,20 @@ examples/test_chatglm.py::test_llm_glm_4_9b_single_gpu_summary[glm-4-9b-enable_w
examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[disable_weight_only]
examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[enable_weight_only]
examples/test_commandr.py::test_llm_commandr_plus_4gpus_summary[disable_weight_only] TIMEOUT (120)
examples/test_commandr.py::test_llm_commandr_plus_4gpus_summary[enable_weight_only]
examples/test_commandr.py::test_llm_commandr_plus_4gpus_summary[enable_weight_only] TIMEOUT (120)
examples/test_eagle.py::test_llm_eagle_1gpu_modelopt_ckpt[llama3.1-eagle-8b-hf_v0.5-float16-bs8]
examples/test_eagle.py::test_llm_eagle_1gpu[EAGLE-Vicuna-7B-v1.3-float16-bs1-eagle1]
examples/test_eagle.py::test_llm_eagle_1gpu[EAGLE-Vicuna-7B-v1.3-float16-bs1-eagle2]
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-enable_fp8] TIMEOUT (60)
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-byt5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-enable_fp8]
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-flan-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8]
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-flan-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:2-nb:1-enable_fp8]
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-mbart-large-50-many-to-one-mmt-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8]
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-mbart-large-50-many-to-one-mmt-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:2-nb:1-enable_fp8]
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8]
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:1-nb:1-enable_fp8]
examples/test_enc_dec.py::test_llm_enc_dec_general[no_compare_hf-byt5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-enable_fp8]
examples/test_enc_dec.py::test_llm_enc_dec_general[no_compare_hf-byt5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:1-nb:1-disable_fp8]
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-enable_fp8] TIMEOUT (90)
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-byt5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-enable_fp8] TIMEOUT (90)
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-flan-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8] TIMEOUT (90)
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-flan-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:2-nb:1-enable_fp8] TIMEOUT (90)
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-mbart-large-50-many-to-one-mmt-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8] TIMEOUT (90)
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-mbart-large-50-many-to-one-mmt-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:2-nb:1-enable_fp8] TIMEOUT (90)
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8] TIMEOUT (90)
examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:1-nb:1-enable_fp8] TIMEOUT (90)
examples/test_enc_dec.py::test_llm_enc_dec_general[no_compare_hf-byt5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-enable_fp8] TIMEOUT (90)
examples/test_enc_dec.py::test_llm_enc_dec_general[no_compare_hf-byt5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:1-nb:1-disable_fp8] TIMEOUT (90)
examples/test_exaone.py::test_llm_exaone_1gpu[disable_weight_only-exaone_3.0_7.8b_instruct-float16-nb:1] TIMEOUT (90)
examples/test_exaone.py::test_llm_exaone_1gpu[disable_weight_only-exaone_3.0_7.8b_instruct-float16-nb:4] TIMEOUT (90)
examples/test_exaone.py::test_llm_exaone_1gpu[disable_weight_only-exaone_3.0_7.8b_instruct-float16-nb:4] TIMEOUT (90)