#!/usr/bin/env python3
"""
LongBench v1 evaluation script with TensorRT-LLM and sparse attention.

Usage:
    # Run a single LongBench dataset
    python longbench_rocket_eval.py --dataset narrativeqa --model_path /path/to/model --longbench_path ./LongBench --output_dir results/

    # Run all LongBench tasks with RocketKV sparse attention
    python longbench_rocket_eval.py --run_all_tasks --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --token_budget 2048 --rocket_sparse
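
    # Illustrative extra example (not from the original script): --dataset takes
    # several names at once (nargs='+'), so a small subset can be run in one go
    python longbench_rocket_eval.py --dataset hotpotqa triviaqa --model_path /path/to/model --longbench_path ./LongBench --output_dir results/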
"""

import argparse
import json
import os
import sys
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple

import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

# Add tensorrt_llm imports
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig, MoeConfig,
                                 RocketSparseAttentionConfig)
from tensorrt_llm.logger import logger

LONGBENCH_DATASETS = [
    "narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa",
    "2wikimqa", "musique", "dureader", "gov_report", "qmsum", "multi_news",
    "vcsum", "trec", "triviaqa", "samsum", "lsht", "passage_count",
    "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"
]

# Task categorization
TASK_DATASETS = {
    'single_doc_qa':
    ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh"],
    'multi_doc_qa': ["hotpotqa", "2wikimqa", "musique", "dureader"],
    'summarization': ["gov_report", "qmsum", "multi_news", "vcsum"],
    'few_shots': ["trec", "triviaqa", "samsum", "lsht"],
    'synthetic':
    ["passage_count", "passage_retrieval_en", "passage_retrieval_zh"],
    'code': ["lcc", "repobench-p"]
}

# Chat templates mapping
CHAT_TEMPLATES = {
    "llama3.1-8b-instruct": "llama3",
    "llama3-8b-instruct": "llama3",
    "mistral-7b-instruct-v0.2": "mistral",
    "longchat-7b-v1.5-32k": "vicuna"
}
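
# Illustrative note (added, not from the original source): determine_chat_template()
# below lowercases the model path and strips '-' and '.', so a path such as
# "/models/Llama-3.1-8B-Instruct" matches the "llama3.1-8b-instruct" key above
# and resolves to the "llama3" template.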


def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="LongBench evaluation with TensorRT-LLM and RocketKV")

    # Model and data arguments
    parser.add_argument('--model_path',
                        type=str,
                        required=True,
                        help='Path to model (HF model name or local path)')
    parser.add_argument('--dataset',
                        type=str,
                        nargs='+',
                        choices=LONGBENCH_DATASETS,
                        help='LongBench datasets to evaluate on')
    parser.add_argument('--run_all_tasks',
                        action='store_true',
                        help='Run evaluation on all LongBench tasks')
    parser.add_argument('--longbench_path',
                        type=str,
                        default='./LongBench',
                        help='Path to LongBench directory')

    # Output arguments
    parser.add_argument('--output_dir',
                        type=str,
                        required=True,
                        help='Directory to save results')
    parser.add_argument('--exp_name',
                        type=str,
                        default=None,
                        help='Experiment name (auto-generated if not provided)')

    # Model configuration
    parser.add_argument('--attention_backend',
                        type=str,
                        default='VANILLA',
                        choices=['VANILLA', 'TRTLLM', 'FLASHINFER'],
                        help='Attention backend to use')
    parser.add_argument('--backend',
                        type=str,
                        default='pytorch',
                        choices=['pytorch', 'tensorrt'],
                        help='LLM backend to use')
    parser.add_argument('--chat_template',
                        type=str,
                        default='auto',
                        help='Chat template to use (auto-detect if "auto")')

    # Sequence and batch configuration
    parser.add_argument('--max_seq_len',
                        type=int,
                        default=133120,
                        help='Maximum sequence length')
    parser.add_argument('--max_batch_size',
                        type=int,
                        default=1,
                        help='Maximum batch size')
    parser.add_argument('--max_new_tokens',
                        type=int,
                        default=256,
                        help='Maximum new tokens to generate')
    parser.add_argument(
        '--max_num_tokens',
        type=int,
        default=133120,
        help='Maximum total tokens across all sequences in a batch')

    # Parallelism
    parser.add_argument('--moe_backend',
                        type=str,
                        default='CUTLASS',
                        choices=[
                            'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
                            'DEEPGEMM', 'CUTEDSL', 'TRITON'
                        ])
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--moe_ep_size', type=int, default=-1)
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true')

    # RocketKV configuration
    parser.add_argument('--rocket_sparse',
                        action='store_true',
                        help='Use rocket sparse attention')
    parser.add_argument('--token_budget',
                        type=int,
                        default=2048,
                        help='Token budget for RocketKV (prompt_budget)')
    parser.add_argument('--window_size',
                        type=int,
                        default=32,
                        help='Window size for RocketKV')
    parser.add_argument('--kernel_size',
                        type=int,
                        default=63,
                        help='Kernel size for RocketKV')
    parser.add_argument('--topk',
                        type=int,
                        default=64,
                        help='Top-k for RocketKV')
    parser.add_argument('--kt_cache_dtype',
                        type=str,
                        default='float8_e5m2',
                        choices=['bfloat16', 'float8_e5m2'],
                        help='KT cache data type')

    # KV cache configuration
    parser.add_argument('--kv_cache_dtype',
                        type=str,
                        default='auto',
                        help='KV cache data type')
    parser.add_argument('--tokens_per_block', type=int, default=32)
    parser.add_argument('--kv_cache_fraction',
                        type=float,
                        default=0.7,
                        help='Fraction of GPU memory for KV cache')

    # Runtime
    parser.add_argument('--print_iter_log',
                        default=False,
                        action='store_true',
                        help='Print iteration logs during execution')
    parser.add_argument('--use_cuda_graph', default=True, action='store_true')
    parser.add_argument('--cuda_graph_padding_enabled',
                        default=True,
                        action='store_true')
    parser.add_argument('--cuda_graph_batch_sizes',
                        nargs='+',
                        type=int,
                        default=None)

    # Evaluation parameters
    parser.add_argument('--num_samples',
                        type=int,
                        default=None,
                        help='Number of samples to evaluate (None for all)')
    parser.add_argument('--start_idx',
                        type=int,
                        default=0,
                        help='Start index for evaluation')

    # System arguments
    parser.add_argument('--log_level',
                        type=str,
                        default='info',
                        choices=['debug', 'info', 'warning', 'error'],
                        help='Logging level')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    args = parser.parse_args()

    # Validation
    if not args.run_all_tasks and not args.dataset:
        parser.error("Must specify either --dataset or --run_all_tasks")

    return args


def setup_longbench_imports(longbench_path: str):
    """Add the LongBench directory to sys.path so its modules (e.g. metrics) can be imported."""
    longbench_dir = os.path.join(longbench_path, "LongBench")  # for v1
    if not os.path.exists(longbench_dir):
        raise FileNotFoundError(
            f"LongBench directory not found: {longbench_dir}")

    # Add to path
    if longbench_dir not in sys.path:
        sys.path.insert(0, longbench_dir)


def load_longbench_config(
        longbench_path: str) -> Tuple[Dict[str, str], Dict[str, int]]:
    """Load LongBench configuration files."""
    config_dir = os.path.join(longbench_path, "LongBench", "config")

    # Load dataset2prompt.json
    prompt_file = os.path.join(config_dir, "dataset2prompt.json")
    with open(prompt_file, 'r', encoding='utf-8') as f:
        dataset2prompt = json.load(f)

    # Load dataset2maxlen.json
    maxlen_file = os.path.join(config_dir, "dataset2maxlen.json")
    with open(maxlen_file, 'r', encoding='utf-8') as f:
        dataset2maxlen = json.load(f)

    return dataset2prompt, dataset2maxlen


# LongBench's build_chat function (simplified version)
def build_chat(tokenizer, prompt, chat_template):
    """Build chat prompt following LongBench's approach."""
    if chat_template == "vicuna" or chat_template == "longchat":
        try:
            from fastchat.model import get_conversation_template
            conv = get_conversation_template("vicuna")
            conv.append_message(conv.roles[0], prompt)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
        except ImportError:
            # Fallback if fastchat is not available
            prompt = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: {prompt}\nASSISTANT:"
    elif chat_template == "llama2":
        prompt = f"[INST]{prompt}[/INST]"
    elif chat_template == "llama3":
        prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    elif chat_template == "mistral":
        prompt = f"[INST] {prompt} [/INST]"
    # For other templates or "none", return prompt as-is
    return prompt


def determine_chat_template(model_path: str, chat_template: str) -> str:
    """Determine chat template based on model path."""
    if chat_template != 'auto':
        return chat_template

    model_path_lower = model_path.lower()
    normalized_path = model_path_lower.replace('-', '').replace('.', '')

    for model_key, template in CHAT_TEMPLATES.items():
        normalized_key = model_key.replace('-', '').replace('.', '')
        if normalized_key in normalized_path:
            return template

    # Default fallback
    if 'llama' in model_path_lower:
        return 'llama3'
    elif 'mistral' in model_path_lower:
        return 'mistral'
    else:
        return 'none'  # No special formatting


def post_process(pred: str, chat_template: str, dataset: str) -> str:
    """Post-process prediction following LongBench's approach."""
    pred = pred.split("</s")[0].strip()
    if chat_template == "qwen":
        pred = pred.split("<|im_end|>")[0]
    elif "llama2" in chat_template.lower():
        pred = (pred.split("(Document")[0].split("\n\nQuestion")[0].split(
            "\n\nAnswer")[0].split("[INST]")[0].split("[/INST]")[0].split(
                "(Passage")[0].strip())
    if dataset == "samsum":
        pred = pred.split("\n")[0].strip()

    return pred


def format_prompt_style(sample: Dict[str, Any], instruction: str,
                        chat_template: str, dataset: str, tokenizer) -> str:
    """Format prompt following LongBench's approach."""
    # First format the instruction using the sample data
    prompt = instruction.format(**sample)

    if dataset not in [
            "trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"
    ]:
        prompt = build_chat(tokenizer, prompt, chat_template)

    return prompt


def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
    """Initialize LLM and tokenizer."""
    logger.info(f"Initializing LLM with model: {args.model_path}")

    try:
        # Configure KV cache
        kv_cache_config = KvCacheConfig(
            # sparse attention doesn't support KV cache reuse
            enable_block_reuse=False,
            free_gpu_memory_fraction=args.kv_cache_fraction,
            tokens_per_block=args.tokens_per_block,
        )

        # Configure CUDA graph
        cuda_graph_config = CudaGraphConfig(
            batch_sizes=args.cuda_graph_batch_sizes,
            enable_padding=args.cuda_graph_padding_enabled,
        ) if args.use_cuda_graph else None

        # Configure sparse attention
        if args.rocket_sparse:
            # Configure RocketKV sparse attention
            sparse_attention_config = RocketSparseAttentionConfig(
                window_size=args.window_size,
                kernel_size=args.kernel_size,
                prompt_budget=args.token_budget,
                topk=args.topk,
                kt_cache_dtype=args.kt_cache_dtype,
            )
            logger.info("Using RocketKV sparse attention")
        else:
            sparse_attention_config = None
            logger.info("Using standard attention")

        # Initialize LLM
        llm = LLM(
            model=args.model_path,
            backend=args.backend,
            kv_cache_config=kv_cache_config,
            max_batch_size=args.max_batch_size,
            attn_backend=args.attention_backend,
            sparse_attention_config=sparse_attention_config,
            tensor_parallel_size=args.tp_size,
            moe_expert_parallel_size=args.moe_ep_size,
            enable_attention_dp=args.enable_attention_dp,
            max_seq_len=args.max_seq_len,
            max_num_tokens=args.max_num_tokens,
            cuda_graph_config=cuda_graph_config,
            print_iter_log=args.print_iter_log,
            moe_config=MoeConfig(backend=args.moe_backend),
        )

        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)

        logger.info("LLM and tokenizer initialized successfully")

        return llm, tokenizer

    except Exception as e:
        logger.error(f"Failed to initialize LLM: {e}")
        raise e


def evaluate_single_dataset(
        dataset: str, llm: LLM, tokenizer: AutoTokenizer,
        args: argparse.Namespace) -> Tuple[List[Dict], float]:
    """Evaluate a single dataset."""
    setup_longbench_imports(args.longbench_path)

    dataset2prompt, dataset2maxlen = load_longbench_config(args.longbench_path)

    # Load dataset
    logger.info(f"Loading dataset: {dataset}")
    data = [
        data_sample for data_sample in load_dataset(
            'THUDM/LongBench', dataset, split='test', trust_remote_code=True)
    ]

    # Apply data filtering
    if args.num_samples:
        end_idx = min(args.start_idx + args.num_samples, len(data))
        filtered_data = data[args.start_idx:end_idx]
    else:
        filtered_data = data[args.start_idx:]

    logger.info(f"Dataset {dataset}: {len(filtered_data)} samples to evaluate")

    # Determine chat template
    chat_template = determine_chat_template(args.model_path,
                                            args.chat_template)
    logger.info(f"Using chat template: {chat_template}")

    # Create sampling parameters
    max_new_tokens = dataset2maxlen[dataset]
    prompt_format = dataset2prompt[dataset]

    # Set up extra end token ids
    extra_end_token_ids = []
    if chat_template == "llama3":
        eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]
        extra_end_token_ids.append(eot_id)
        logger.info(f"Added llama3 end token: {eot_id}")

    if chat_template == "qwen":
        im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
        extra_end_token_ids.append(im_end_id)
        logger.info(f"Added qwen end token: {im_end_id}")

    if dataset == "samsum":
        newline_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
        extra_end_token_ids.append(newline_id)
        logger.info(f"Added samsum newline token: {newline_id}")

    # Prepare prompts
    prompts = []
    for sample in filtered_data:
        formatted_prompt = format_prompt_style(sample, prompt_format,
                                               chat_template, dataset,
                                               tokenizer)
        # Truncate prompt if it's too long
        token_ids = tokenizer.encode(formatted_prompt, truncation=False)
        if len(token_ids) > args.max_seq_len:
            half = (args.max_seq_len - max_new_tokens) // 2
            formatted_prompt = tokenizer.decode(
                token_ids[:half], skip_special_tokens=True) + tokenizer.decode(
                    token_ids[-half:], skip_special_tokens=True)

        prompts.append(formatted_prompt)

    if len(prompts) == 0:
        logger.warning(f"No prompts to evaluate for dataset {dataset}")
        return [], 0.0

    # Run inference
    logger.info(f"Starting inference for {len(prompts)} samples...")
    start_time = time.time()

    sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        temperature=0.8,
        top_p=0.95,
        stop_token_ids=extra_end_token_ids if extra_end_token_ids else None,
    )

    outputs = llm.generate(prompts, sampling_params)

    inference_time = time.time() - start_time
    logger.info(
        f"Inference completed in {inference_time:.2f} seconds, average time per sample: {inference_time/len(prompts):.3f} seconds"
    )

    # Prepare results
    results = []
    for i, (sample, output) in enumerate(zip(filtered_data, outputs)):
        prediction = output.outputs[0].text.strip()
        processed_prediction = post_process(prediction, chat_template, dataset)

        result = {
            'sample_id': args.start_idx + i,
            'input': sample.get('input', ''),
            'context': sample.get('context', ''),
            'answers': sample.get('answers', []),
            'all_classes': sample.get('all_classes', []),
            'prediction': processed_prediction,
            'raw_prediction': prediction,
            'prompt_length': len(output.prompt_token_ids),
            'output_length': len(output.outputs[0].token_ids),
            'inference_time': getattr(output, 'inference_time', None),
            'length': sample.get('length', 0)
        }
        results.append(result)

    return results, inference_time


def calculate_metrics(
        dataset: str, predictions: List[str], answers_list: List[List[str]],
        all_classes_list: List[List[str]],
        longbench_path: str) -> Tuple[Dict[str, float], List[float]]:
    """Calculate evaluation metrics for a dataset following LongBench's implementation."""

    # Setup LongBench imports
    setup_longbench_imports(longbench_path)

    # Import LongBench metrics
    from metrics import (classification_score, code_sim_score, count_score,
                         qa_f1_score, qa_f1_zh_score, retrieval_score,
                         retrieval_zh_score, rouge_score, rouge_zh_score)

    # Mapping of datasets to their metric functions (from LongBench)
    dataset2metric = {
        "narrativeqa": qa_f1_score,
        "qasper": qa_f1_score,
        "multifieldqa_en": qa_f1_score,
        "multifieldqa_zh": qa_f1_zh_score,
        "hotpotqa": qa_f1_score,
        "2wikimqa": qa_f1_score,
        "musique": qa_f1_score,
        "dureader": rouge_zh_score,
        "gov_report": rouge_score,
        "qmsum": rouge_score,
        "multi_news": rouge_score,
        "vcsum": rouge_zh_score,
        "trec": classification_score,
        "triviaqa": qa_f1_score,
        "samsum": rouge_score,
        "lsht": classification_score,
        "passage_retrieval_en": retrieval_score,
        "passage_count": count_score,
        "passage_retrieval_zh": retrieval_zh_score,
        "lcc": code_sim_score,
        "repobench-p": code_sim_score,
    }

    if dataset not in dataset2metric:
        # Fallback to simple exact match with cleaning
        total_score = 0
        scores = []
        for pred, answers in zip(predictions, answers_list):
            cleaned_pred = pred.lstrip('\n').split('\n')[0].strip()
            score = max([
                1.0 if cleaned_pred.lower() == ans.strip().lower() else 0.0
                for ans in answers
            ])
            scores.append(score)
            total_score += score
        return {
            "exact_match": round(100 * total_score / len(predictions), 2)
        }, scores

    metric_func = dataset2metric[dataset]
    total_score = 0.0
    scores = []

    # Follow LongBench's scorer function exactly
    for pred, ground_truths, all_classes in zip(predictions, answers_list,
                                                all_classes_list):
        score = 0.0

        # Apply the same prediction cleaning as LongBench
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            pred = pred.lstrip('\n').split('\n')[0]

        # For code datasets, apply additional cleaning
        if dataset in ["lcc", "repobench-p"]:
            # This cleaning is done inside code_sim_score, but apply it here as well for consistency
            all_lines = pred.lstrip('\n').split('\n')
            for line in all_lines:
                if ('`' not in line) and ('#' not in line) and ('//' not in line):
                    pred = line
                    break

        # Calculate max score across all reference answers (exactly as in LongBench)
        for ground_truth in ground_truths:
            score = max(
                score, metric_func(pred, ground_truth,
                                   all_classes=all_classes))

        scores.append(score)
        total_score += score

    final_score = round(100 * total_score / len(predictions), 2)
    return {metric_func.__name__: final_score}, scores


def calculate_task_summary(all_results: Dict[str, Dict]) -> Dict[str, Any]:
    """Calculate task-level summary statistics following long_bench_tasks_summary.py approach."""
    logger.info("Calculating task-level summary statistics...")

    summary = {}
    ind_dataset_result = {}
    task_ave_result = {}

    NA_flag = False

    # Get individual dataset results
    for dataset in LONGBENCH_DATASETS:
        if dataset in all_results and 'metrics' in all_results[dataset]:
            metrics = all_results[dataset]['metrics']
            # Get the first (and usually only) metric value
            if metrics:
                metric_key = list(metrics.keys())[0]
                val = metrics[metric_key]
                ind_dataset_result[dataset] = val
            else:
                ind_dataset_result[dataset] = 'N/A'
                NA_flag = True
        else:
            ind_dataset_result[dataset] = 'N/A'
            NA_flag = True

    summary['individual_dataset_result'] = ind_dataset_result

    # Calculate task-average results
    for task, datasets in TASK_DATASETS.items():
        task_NA_flag = False
        task_ave_result[task] = 0
        valid_count = 0

        for dataset in datasets:
            if dataset in ind_dataset_result and ind_dataset_result[
                    dataset] != 'N/A':
                task_ave_result[task] += ind_dataset_result[dataset]
                valid_count += 1
            else:
                task_NA_flag = True

        if task_NA_flag or valid_count == 0:
            task_ave_result[task] = 'N/A'
        else:
            task_ave_result[task] = np.round(task_ave_result[task] /
                                             valid_count,
                                             decimals=2)

    summary['task_average_result'] = task_ave_result

    # Calculate overall LongBench average result (excluding passage_count as in original)
    if NA_flag:
        summary['LB_average_result'] = 'N/A'
    else:
        average_result = 0
        valid_count = 0
        for dataset in LONGBENCH_DATASETS:
            if dataset != 'passage_count' and dataset in ind_dataset_result:
                if ind_dataset_result[dataset] != 'N/A':
                    average_result += ind_dataset_result[dataset]
                    valid_count += 1

        if valid_count > 0:
            summary['LB_average_result'] = np.round(average_result /
                                                    valid_count,
                                                    decimals=2)
        else:
            summary['LB_average_result'] = 'N/A'

    # Log summary statistics
    logger.info("Task Summary Results:")
    logger.info(f"Overall LongBench Average: {summary['LB_average_result']}")
    for task, score in task_ave_result.items():
        logger.info(f"{task}: {score}")

    return summary


def save_results(results: List[Dict], dataset: str, args: argparse.Namespace,
                 inference_time: float, output_dir: str):
    """Save evaluation results in a format compatible with LongBench."""
    task_output_dir = os.path.join(output_dir, dataset)
    os.makedirs(task_output_dir, exist_ok=True)

    # Extract predictions, answers, and all_classes for evaluation
    predictions = [r['prediction'] for r in results]
    answers_list = [r['answers'] for r in results]
    all_classes_list = [r.get('all_classes', []) for r in results]

    # Calculate metrics
    processed_results, scores = calculate_metrics(dataset, predictions,
                                                  answers_list,
                                                  all_classes_list,
                                                  args.longbench_path)
    logger.info(f"Evaluation metrics: {processed_results}")

    # Save detailed results for manual inspection
    results_file = os.path.join(task_output_dir, f"{dataset}_results.jsonl")
    with open(results_file, 'w', encoding='utf-8') as f:
        for result in results:
            json.dump(result, f, ensure_ascii=False)
            f.write('\n')

    # Save prediction results in LongBench format for evaluation
    pred_dir = os.path.join(task_output_dir, "pred")
    os.makedirs(pred_dir, exist_ok=True)
    pred_file = os.path.join(pred_dir, f"{dataset}.jsonl")

    with open(pred_file, 'w', encoding='utf-8') as f:
        for idx, result in enumerate(results):
            pred_data = {
                "pred": result['prediction'],
                "answers": result['answers'],
                "all_classes": result.get('all_classes', []),
                "length": result.get('length', 0),
                "score": scores[idx]
            }
            json.dump(pred_data, f, ensure_ascii=False)
            f.write('\n')

    # Create summary following LongBench format
    config = {
        'pipeline_params': {
            'model_name': args.model_path,
            'method': args.attention_backend,
            'token_budget': args.token_budget,
            'max_seq_len': args.max_seq_len,
            'max_new_tokens': args.max_new_tokens,
            'window_size': args.window_size,
            'kernel_size': args.kernel_size,
            'num_processes': 1,  # Single process
            'devices': "0"  # Single device
        },
        'eval_params': {
            'dataset': dataset,
            'num_samples': len(results)
        },
        'eval_results': {
            'processed_results': processed_results
        },
        'management': {
            'output_folder_dir': task_output_dir,
            'exp_desc':
            f'{dataset}_{os.path.basename(args.model_path)}_{args.attention_backend}_{args.token_budget}',
            'total_inference_time': inference_time,
            'avg_inference_time': inference_time / len(results),
            'evaluation_timestamp': datetime.now().isoformat()
        }
    }

    # Save summary
    summary_file = os.path.join(task_output_dir, f"{dataset}_summary.json")
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    logger.info(f"The results of {dataset} are saved to {task_output_dir}")

    return processed_results


def main():
    """Main evaluation function."""
    args = parse_arguments()
    logger.set_level(args.log_level)

    # Setup experiment name
    if not args.exp_name:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_name = os.path.basename(args.model_path).replace('/', '_')
        args.exp_name = f"longbench_{model_name}_{timestamp}"

    output_dir = os.path.join(args.output_dir, args.exp_name)

    logger.info(
        "=========== LongBench Evaluation with TensorRT-LLM ===========")

    os.makedirs(output_dir, exist_ok=True)
    # Save configuration
    config_file = os.path.join(output_dir, "config.json")
    with open(config_file, 'w') as f:
        json.dump(vars(args), f, indent=2)
    logger.info(f"Configuration saved to {config_file}")

    # Determine datasets to evaluate
    if args.run_all_tasks:
        datasets = LONGBENCH_DATASETS
        logger.info("Running evaluation on all LongBench datasets")
    else:
        datasets = args.dataset
        logger.info(f"Running evaluation on datasets: {args.dataset}")

    # Initialize LLM and tokenizer
    llm, tokenizer = initialize_llm(args)

    # Process datasets sequentially
    all_results = {}
    for dataset_idx, dataset in enumerate(datasets):
        logger.info(f"{'='*30}")
        logger.info(
            f"Processing dataset {dataset_idx+1}/{len(datasets)}: {dataset}...")

        # Evaluate the dataset
        results, inference_time = evaluate_single_dataset(
            dataset, llm, tokenizer, args)

        # Skip datasets that produced no results to avoid dividing by zero
        # when averaging in save_results
        if not results:
            logger.warning(f"Dataset {dataset} produced no results, skipping")
            continue

        # Save results and get metrics
        processed_results = save_results(results, dataset, args,
                                         inference_time, output_dir)

        all_results[dataset] = {
            'num_samples': len(results),
            'inference_time': inference_time,
            'output_dir': output_dir,
            'metrics': processed_results
        }
        logger.info(f"Dataset {dataset} completed successfully")

    total_time = sum(all_results[d]['inference_time'] for d in all_results)

    # Calculate task-level summary
    task_summary = calculate_task_summary(all_results)

    # Save overall summary with task statistics
    overall_summary = {
        'experiment_name': args.exp_name,
        'total_evaluation_time': total_time,
        'evaluated_datasets': list(all_results.keys()),
        'successful_datasets':
        [d for d, r in all_results.items() if 'error' not in r],
        'failed_datasets': [d for d, r in all_results.items() if 'error' in r],
        'results_by_dataset': all_results,
        'task_summary': task_summary,  # Add task-level summary
        'configuration': vars(args)
    }

    overall_summary_file = os.path.join(output_dir, "overall_summary.json")
    with open(overall_summary_file, 'w') as f:
        json.dump(overall_summary, f, indent=2)

    logger.info(f"\n{'-'*80}")
    logger.info(
        f"Evaluation completed. Overall summary saved to: {overall_summary_file}"
    )
    logger.info(f"Total time: {total_time:.2f} seconds")

    # Print final summary
    if task_summary['LB_average_result'] != 'N/A':
        logger.info("FINAL RESULTS:")
        logger.info(
            f"LongBench Overall Average: {task_summary['LB_average_result']}")
        logger.info("Task-level results:")
        for task, score in task_summary['task_average_result'].items():
            logger.info(f"  {task}: {score}")

    if overall_summary['failed_datasets']:
        logger.warning(f"Failed datasets: {overall_summary['failed_datasets']}")


if __name__ == '__main__':
    main()