Add GPQA accuracy test config for WIDEEP

Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
Zhenhuan Chen 2026-01-11 19:19:31 -08:00
parent 6a15b493e6
commit 9900551173
5 changed files with 179 additions and 78 deletions

View File

@@ -1398,7 +1398,7 @@ repos:
exclude: |
(?x)^(.*cubin.cpp | .*cubin.h)$
- id: check-yaml
args: [--allow-multiple-documents]
args: [--allow-multiple-documents, --unsafe]
exclude: ".*/gitlab/.*.yml"
- id: trailing-whitespace
exclude: '\.(patch|md)$'

View File

@@ -138,6 +138,36 @@ def convert_allocations_to_server_config(allocations, server_port=8333):
return server_config
def convert_envs_to_str(env_vars: Dict[str, str]) -> str:
    return ','.join([f"{key}='{value}'" for key, value in env_vars.items()])
def replace_env_in_file(log_dir, file_path, env_var):
    with open(file_path, 'r', encoding='utf-8') as f:
        config_content = f.read()
    # Substitute every env placeholder, accumulating replacements on the same string
    for env_name, env_value in env_var.items():
        config_content = config_content.replace(env_name, env_value)
    tmp_dir = os.path.join(log_dir, "lm_eval_configs")
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_file = os.path.join(tmp_dir, os.path.basename(file_path))
    # Write modified config to temp file
    with open(tmp_file, 'w', encoding='utf-8') as f:
        f.write(config_content)
    # Check whether a custom utils.py exists in the same directory
    # Needed for the GPQA task
    custom_utils_path = os.path.join(os.path.dirname(file_path), 'utils.py')
    if os.path.exists(custom_utils_path):
        # copy utils.py to temp directory
        shutil.copy(custom_utils_path, tmp_dir)
    # Return temp directory
    return tmp_dir
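# ---------------------------------------------------------------------------
# Editor's illustration (not part of this change): expected behavior of
# replace_env_in_file() for the GPQA task config. The paths and env values
# below are made up for the example.
#
#   env_var = {"HF_HOME": "/data/hf_cache"}
#   include_dir = replace_env_in_file(
#       log_dir="/tmp/run_logs",
#       file_path="tests/integration/lm_eval_configs/gpqa_diamond_local.yaml",
#       env_var=env_var)
#   # include_dir -> "/tmp/run_logs/lm_eval_configs", holding a copy of the YAML
#   # with "HF_HOME/datasets/..." rewritten to "/data/hf_cache/datasets/..."
#   # plus the task's utils.py, ready to pass to lm_eval via --include_path.
# ---------------------------------------------------------------------------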
def submit_job(config, log_dir, dry_run):
# Extract configurations
slurm_config = config['slurm']
@@ -208,33 +238,34 @@ def submit_job(config, log_dir, dry_run):
gen_batch_size = worker_config['gen']['max_batch_size']
gen_enable_attention_dp = worker_config['gen']['enable_attention_dp']
# Get eplb num_slots for gen worker
load_balancer_config = worker_config['gen'].get('moe_config', {}).get(
'load_balancer', {})
if isinstance(load_balancer_config, str):
with open(load_balancer_config, 'r') as f:
load_balancer_config = yaml.safe_load(f)
eplb_num_slots = load_balancer_config.get('num_slots', 0)
# Get mtp_size from gen config's speculative_config
mtp_size = worker_config['gen'].get('speculative_config',
{}).get('num_nextn_predict_layers', 0)
# Create base log directory path
if log_dir is None:
# Create base log directory path
date_prefix = datetime.now().strftime("%Y%m%d")
log_base = os.path.join(env_config['work_dir'],
f"logs/{date_prefix}/{isl}-{osl}")
log_base = os.path.join(env_config['work_dir'], "logs")
else:
log_base = log_dir
date_prefix = datetime.now().strftime("%Y%m%d-%H%M%S")
log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}")
# Get eplb num_slots for gen worker
load_balancer_config = worker_config['gen'].get('moe_config', {}).get(
'load_balancer', {})
if isinstance(load_balancer_config, str):
with open(load_balancer_config, 'r') as f:
load_balancer_config = yaml.safe_load(f)
eplb_num_slots = load_balancer_config.get('num_slots', 0)
# Determine directory suffix based on attention_dp
if gen_enable_attention_dp:
dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
else:
dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
# Get mtp_size from gen config's speculative_config
mtp_size = worker_config['gen'].get('speculative_config',
{}).get('num_nextn_predict_layers',
0)
# Determine directory suffix based on attention_dp
if gen_enable_attention_dp:
dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
else:
dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
# Create full log directory path
log_dir = os.path.join(log_base, dir_suffix)
# Create full log directory path
log_dir = os.path.join(log_base, dir_suffix)
# Remove existing directory if it exists
if os.path.exists(log_dir):
@@ -340,32 +371,44 @@ def submit_job(config, log_dir, dry_run):
f"--container-mounts={env_config['container_mount']}",
f"--mpi=pmix --overlap -N 1 -n 1",
]
env_var = config['benchmark'].get('env_var', '')
benchmark_prefix = client_slurm_prefix + [f"--export \"{env_var}\""]
if benchmark_config['use_nv_sa_benchmark']:
benchmark_cmd = [
f"bash {env_config['work_dir']}/run_benchmark_nv_sa.sh",
f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
f"&> {log_dir}/6_bench.log"
# Append benchmark commands
if benchmark_config.get('enable_benchmark', True):
env_var = config['benchmark'].get('env_var', {})
benchmark_prefix = client_slurm_prefix + [
f"--export \"{convert_envs_to_str(env_var)}\""
]
client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd))
else:
benchmark_cmd = [
f"bash {env_config['work_dir']}/run_benchmark.sh",
f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
f"&> {log_dir}/6_bench.log"
]
client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd))
if benchmark_config['use_nv_sa_benchmark']:
benchmark_cmd = [
f"bash {env_config['work_dir']}/run_benchmark_nv_sa.sh",
f"'{env_config['model_path']}' {isl} {osl} {benchmark_config['benchmark_ratio']} {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
f"&> {log_dir}/6_bench.log"
]
client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd))
else:
benchmark_cmd = [
f"bash {env_config['work_dir']}/run_benchmark.sh",
f"'{env_config['model_path']}' '{benchmark_config['dataset_file']}' {benchmark_config['multi_round']} {gen_num} '{benchmark_config['concurrency_list']}' {benchmark_config['streaming']} '{log_dir}' {disagg_server_hostname} {disagg_server_port}",
f"&> {log_dir}/6_bench.log"
]
client_cmds.append(" ".join(benchmark_prefix + benchmark_cmd))
# Append accuracy test commands
if config['accuracy']['enable_accuracy_test']:
env_var = config['accuracy'].get('env_var', '')
accuracy_prefix = client_slurm_prefix + [f"--export \"{env_var}\""]
env_var = config['accuracy'].get('env_var', {})
accuracy_prefix = client_slurm_prefix + [
f"--export \"{convert_envs_to_str(env_var)}\""
]
for task in config['accuracy']['tasks']:
extra_kwargs = config['accuracy']['tasks'][task].get('extra_kwargs', {})
extra_kwargs = config['accuracy']['tasks'][task].get(
'extra_kwargs', {})
extra_kwargs_str = ""
for key, value in extra_kwargs.items():
if isinstance(value, bool):
if value:
extra_kwargs_str += f" --{key}"
elif key == "custom_config":
extra_kwargs_str += f" --include_path={replace_env_in_file(log_dir, value, env_var)}"
else:
extra_kwargs_str += f" --{key}='{value}'"
end_point_map = {
@@ -374,13 +417,10 @@ def submit_job(config, log_dir, dry_run):
}
model = config['accuracy']['tasks'][task]['model']
accuracy_cmd = [
'lm_eval',
'--model', model,
'--tasks', task,
'--model_args', f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}",
'--log_samples',
'--output_path', f'{log_dir}/accuracy_eval_{task}',
extra_kwargs_str,
'lm_eval', '--model', model, '--tasks', task, '--model_args',
f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}",
'--log_samples', '--output_path',
f'{log_dir}/accuracy_eval_{task}', extra_kwargs_str,
f"&> {log_dir}/7_accuracy_eval_{task}.log"
]
client_cmds.append(" ".join(accuracy_prefix + accuracy_cmd))
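
For orientation, here is a minimal sketch of how the extra_kwargs block of the new GPQA task is expected to expand into lm_eval CLI flags. The kwargs_to_flags helper, the placeholder include directory, and the printed output are illustrative assumptions, not part of the commit.

def kwargs_to_flags(extra_kwargs, include_dir):
    flags = ""
    for key, value in extra_kwargs.items():
        if isinstance(value, bool):
            if value:
                flags += f" --{key}"  # True booleans become bare flags
        elif key == "custom_config":
            # replace_env_in_file() stages the YAML (with HF_HOME expanded) and
            # its utils.py under <log_dir>/lm_eval_configs and returns that dir
            flags += f" --include_path={include_dir}"
        else:
            flags += f" --{key}='{value}'"
    return flags

print(kwargs_to_flags(
    {"apply_chat_template": True, "trust_remote_code": True,
     "custom_config": "gpqa_diamond_local.yaml"},
    include_dir="<log_dir>/lm_eval_configs"))
# -> " --apply_chat_template --trust_remote_code --include_path=<log_dir>/lm_eval_configs"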

View File

@@ -10,10 +10,10 @@ metadata:
dataset_file: disagg_datasets/kimi-k2-1024-1024-20000-ratio-1_for_serve.json
accuracy:
datasets:
- dataset_name: gsm8k
expected_value: 0.9454
- dataset_name: gpqa_diamond_cot_zeroshot
expected_value: 0.65
threshold_type: hypothesis_test
filter_type: flexible-extract
filter_type: strict-match
slurm:
script_file: disaggr_torch.slurm
partition: <partition>
@@ -23,7 +23,8 @@ slurm:
extra_args: "--gres=gpu:4"
numa_bind: true
benchmark:
mode: gen_only
enable_benchmark: false
mode: e2e
use_nv_sa_benchmark: false
multi_round: 8
benchmark_ratio: 1.0
@@ -47,9 +48,16 @@ profiling:
nsys_on: false
accuracy:
enable_accuracy_test: true
model: local-completions
tasks: gsm8k
model_args_extra: num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096
env_var:
HF_HOME: <hf_home_path>
tasks:
gpqa_diamond_local:
model: "local-chat-completions"
model_args_extra: "num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=7200,max_gen_toks=16384"
extra_kwargs:
apply_chat_template: true
trust_remote_code: true
custom_config: <repo_path>/tests/integration/lm_eval_configs/gpqa_diamond_local.yaml
worker_config:
gen:
enable_layerwise_nvtx_marker: true
@@ -58,30 +66,16 @@ worker_config:
enable_attention_dp: true
enable_lm_head_tp_in_adp: false
pipeline_parallel_size: 1
max_batch_size: 1024
max_num_tokens: 1024
max_seq_len: 5120
max_batch_size: 512
max_num_tokens: 512
max_seq_len: 16384
cuda_graph_config:
enable_padding: true
batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
- 512
- 768
- 1024
- 2048
- 1024
max_batch_size: 512
print_iter_log: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.8
free_gpu_memory_fraction: 0.85
dtype: fp8
moe_config:
backend: WIDEEP
@@ -97,9 +91,9 @@ worker_config:
trust_remote_code: true
ctx:
enable_layerwise_nvtx_marker: true
max_batch_size: 8
max_batch_size: 32
max_num_tokens: 8448
max_seq_len: 5120
max_seq_len: 8448
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: true

View File

@@ -0,0 +1,33 @@
# Modified from tensorrt_llm/evaluate/lm_eval_tasks/gpqa/cot_zeroshot_aa/gpqa_diamond_cot_zeroshot_aa.yaml
task: gpqa_diamond_local
dataset_path: HF_HOME/datasets/Idavidrein___gpqa
tag: gpqa
output_type: generate_until
process_docs: !function utils.process_gpqa_docs
training_split: train
# Because the huggingface dataset only has a train split
validation_split: train
test_split: null
doc_to_text: "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n{{Question}}\nA) {{choice1}}\nB) {{choice2}}\nC) {{choice3}}\nD) {{choice4}}"
doc_to_target: answer
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        regex_pattern: '(?i)Answer[ \t]*:[ \t]*([A-D])'
        group_select: 0
      - function: "take_first"
generation_kwargs:
  until:
    - "</s>"
  do_sample: false
  temperature: 0.0
num_fewshot: 0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
metadata:
  version: 1.0
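
As a quick sanity check on the strict-match filter above: the regex is meant to pull the final answer letter out of a chain-of-thought response. A short illustration follows; the response text is invented.

import re

response = "The compound is chiral, so only one option fits.\nAnswer: C"
match = re.search(r"(?i)Answer[ \t]*:[ \t]*([A-D])", response)
print(match.group(1) if match else None)  # -> C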

View File

@@ -0,0 +1,34 @@
import random

import datasets


def preprocess(text):
    if text is None:
        return " "
    return text.strip()


def process_gpqa_docs(dataset: datasets.Dataset) -> datasets.Dataset:

    def _process_doc(doc):
        choices = [
            preprocess(doc["Incorrect Answer 1"]),
            preprocess(doc["Incorrect Answer 2"]),
            preprocess(doc["Incorrect Answer 3"]),
            preprocess(doc["Correct Answer"]),
        ]
        random.shuffle(choices)
        correct_answer_index = choices.index(preprocess(doc["Correct Answer"]))
        out_doc = {
            "choice1": choices[0],
            "choice2": choices[1],
            "choice3": choices[2],
            "choice4": choices[3],
            "choices": [choices[0], choices[1], choices[2], choices[3]],
            "answer": f"{chr(65 + correct_answer_index)}",
        }
        return out_doc

    return dataset.map(_process_doc)
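
A brief usage sketch of the helper above; the GPQA record is fabricated for illustration, and the exact answer letter depends on the shuffle.

import datasets

sample = datasets.Dataset.from_list([{
    "Question": "Which particle mediates the electromagnetic interaction?",
    "Correct Answer": "Photon",
    "Incorrect Answer 1": "Gluon",
    "Incorrect Answer 2": "W boson",
    "Incorrect Answer 3": "Graviton",
}])

row = process_gpqa_docs(sample)[0]
# choice1..choice4 hold the shuffled options; "answer" is the letter (A-D)
# that points at "Photon" after the shuffle.
print(row["choices"], row["answer"])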