Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
test [TRTLLM-4477,TRTLLM-4481]: Accuracy test improvement (Part 3.5): Support GSM8K and GPQA (#3483)
* add gsm8k
* fix gsm8k
* add gpqa
* conditional import lm_eval
* gpqa in lm_eval
* system prompt
* shuffle
* update AA prompt and regex
* revert AA prompt and regex
* integration to tests
* fix
* add DS-R1
* fix and clean
* fix
* update tests
* update
* clean up
* free_gpu_memory_fraction=0.8
* fix

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
This commit is contained in:
parent 0c07d4dc21
commit 3fa19ffa4e
examples/gpqa_llmapi.py (deleted by this commit):
@@ -1,523 +0,0 @@
# MIT License
#
# Copyright (c) 2020 Dan Hendrycks
# Copyright (c) 2023 Deep Cognition and Language Research (DeCLaRe) Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Not a contribution
# Changes made by NVIDIA CORPORATION & AFFILIATES or otherwise documented as
# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""A duplication of examples/mmlu_llmapi.py and tensorrt_llm/bench/benchmark/utils/asynchronous.py, but targeting the GPQA task.

The duplication is used to get a quick GPQA score in the CI test.
TODO: Should be merged with examples/mmlu_llmapi.py

Example usage:
    python gpqa.py --hf_model_dir <HF model path> --data_dir <GPQA csv data path>
or with more optimizations:
    python gpqa.py --hf_model_dir <HF model path> --data_dir <GPQA csv data path> \
        --limit 0.1 --tp_size 8 --ep_size 8 --use_cuda_graph --enable_overlap_scheduler \
        --concurrency 8 --mtp_nextn 3 --print_iter_log --batch_size 32 --max_num_tokens 4096
"""

import argparse
import asyncio
import os
import random
import re
import time
from contextlib import asynccontextmanager
from typing import List, Optional, Set, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.llm import LLM as PyTorchLLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.builder import BuildConfig
from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Template for multiple choice questions
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

# Pattern to extract the answer from the response
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*([A-D])"


class RandomSeedGenerator:
    """A deterministic seed generator for reproducible random number generation.

    This implementation guarantees consistent results across different machines,
    Python versions, and platforms by using integer-based seed generation.
    """

    def __init__(self, initial_seed: int = 42):
        self.initial_seed = initial_seed
        self.random_generator = random.Random(initial_seed)

    def gen_seed(self, idx: int, sub_idx: Optional[int] = None) -> int:
        # This ensures consistent behavior across platforms
        if sub_idx is not None:
            # Combine seeds using prime numbers and bit operations
            # to minimize collisions and maintain reproducibility
            complex_seed = self.initial_seed
            complex_seed = (complex_seed * 2147483647) + idx  # Use prime number
            complex_seed = (complex_seed * 2147483647) + (sub_idx if sub_idx is not None else 0)
        else:
            complex_seed = (self.initial_seed * 2147483647) + idx

        self.random_generator.seed(complex_seed)
        return self.random_generator.randint(0, 2**32 - 1)


class DataShuffle:
    '''
    A class to shuffle the data with a fixed seed.
    '''

    def __init__(self, seed: int = 42):
        self.seed = seed
        self.random_generator = random.Random(self.seed)

    def shuffle(self, data: List[dict]) -> List[dict]:
        self.random_generator.shuffle(data)
        return data


# Class to manage tasks for processing requests
class TaskManager:

    def __init__(self,
                 model: PyTorchLLM,
                 outbox: asyncio.Queue,
                 concurrency: int = -1) -> None:
        self.model = model
        self._inbox = asyncio.Queue()
        self._outbox = outbox

        self._stop = asyncio.Event()
        self._tasks: Set[asyncio.Task] = set()
        self._backend_task = None
        self._concurrency_semaphore = asyncio.Semaphore(
            concurrency) if concurrency > 0 else None

    # Function to extract the answer from the response and calculate the score
    def get_answer(self, response: str, answer: str) -> float:
        match = re.search(ANSWER_PATTERN_MULTICHOICE, response)
        extracted_answer = match.group(1) if match else None
        score = 1.0 if extracted_answer == answer else 0.0
        return score

    # Function to process a single request
    async def process_request(self, idx: int, request: str, answer: str,
                              sampling_params: SamplingParams) -> float:
        async with semaphore_guard(self._concurrency_semaphore):
            output = self.model.generate_async(request,
                                               sampling_params=sampling_params)
            gen_output = await output.aresult()
            # Extract generated tokens
            response = gen_output.outputs[0].text
            score = self.get_answer(response, answer)
            await self._outbox.put((idx, score))

    # Worker function to continuously process requests
    async def worker(self) -> None:
        while not self._stop.is_set():
            idx, request, answer, sampling_params = await self._inbox.get()
            task = asyncio.create_task(
                self.process_request(idx, request, answer, sampling_params))
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)

    # Function to stop the worker
    def stop(self) -> None:
        self._stop.set()
        for task in self._tasks:
            task.cancel()
        self._backend_task.cancel()

    # Property to check if the worker is busy
    @property
    def busy(self) -> bool:
        return bool(self._tasks)

    # Function to start the worker
    def run(self) -> None:
        self._backend_task = asyncio.create_task(self.worker())

    # Function to enqueue a request
    async def enqueue(self, idx: int, request: str, answer: str,
                      sampling_params: SamplingParams) -> None:
        await self._inbox.put((idx, request, answer, sampling_params))


def format_multichoice_question(row: dict) -> str:
    return QUERY_TEMPLATE_MULTICHOICE.format(**row)


def load_data(data_dir: str,
              dataset_shuffle: DataShuffle,
              limit: Optional[float] = None,
              num_runs: int = 1) -> List[List[dict]]:
    assert data_dir.endswith('.csv'), "The provided file is not a CSV file."
    df = pd.read_csv(data_dir)
    dataset = [row.to_dict() for _, row in df.iterrows()]
    if limit is not None:
        dataset = dataset[:int(len(dataset) * limit) + 1]
    shuffled_datasets = []
    for _ in range(num_runs):
        shuffled_datasets.append(dataset_shuffle.shuffle(dataset.copy()))
    return shuffled_datasets


# Function to generate a prompt and the correct answer
def gen_prompt(row: dict, tokenizer: AutoTokenizer,
               dataset_shuffle: DataShuffle) -> Tuple[str, str]:
    choices = dataset_shuffle.shuffle([
        row["Correct Answer"],
        row["Incorrect Answer 1"],
        row["Incorrect Answer 2"],
        row["Incorrect Answer 3"],
    ])
    correct_index = choices.index(row["Correct Answer"])
    answer = "ABCD"[correct_index]
    choices_dict = dict(A=choices[0],
                        B=choices[1],
                        C=choices[2],
                        D=choices[3],
                        Question=row["Question"])
    msg = [{
        "role": "user",
        "content": str(format_multichoice_question(choices_dict))
    }]
    prompt = tokenizer.apply_chat_template(msg,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return prompt, answer


# Async context manager for semaphore
@asynccontextmanager
async def semaphore_guard(semaphore: Optional[asyncio.Semaphore] = None):
    if semaphore is not None:
        await semaphore.acquire()
    try:
        yield
    finally:
        if semaphore is not None:
            semaphore.release()


# Function to enqueue messages for processing
async def enqueue_messages(backend: TaskManager, dataset: List[dict],
                           tokenizer: AutoTokenizer,
                           sampling_params: SamplingParams,
                           submit_finished: asyncio.Event,
                           seed_generator: RandomSeedGenerator,
                           dataset_shuffle: DataShuffle) -> None:
    for idx, row in enumerate(dataset):
        prompt, answer = gen_prompt(row, tokenizer, dataset_shuffle)
        idx_seed = seed_generator.gen_seed(idx=idx)
        sampling_params.seed = idx_seed
        await backend.enqueue(idx, prompt, answer, sampling_params)
    submit_finished.set()


# Function to benchmark the model asynchronously
async def async_benchmark(
    model: PyTorchLLM,
    sampling_params: SamplingParams,
    dataset: List[dict],
    tokenizer: AutoTokenizer,
    seed_generator: RandomSeedGenerator,
    dataset_shuffle: DataShuffle,
    concurrency: int = -1,
) -> List[float]:
    outbox = asyncio.Queue()
    submit_finished = asyncio.Event()
    results = []

    try:
        backend = TaskManager(model, outbox, concurrency=concurrency)
        backend.run()

        num_requests = len(dataset)
        enqueue_task = asyncio.create_task(
            enqueue_messages(backend, dataset, tokenizer, sampling_params,
                             submit_finished, seed_generator, dataset_shuffle))

        with tqdm(total=num_requests, desc="Processing requests") as pbar:
            while not submit_finished.is_set() or not outbox.empty() or len(
                    results) < num_requests:
                try:
                    idx, item = await asyncio.wait_for(outbox.get(), timeout=3600)
                    results.append((idx, item))
                    pbar.update(1)
                except asyncio.TimeoutError:
                    print("No items in queue. Continuing.")
                    if not backend.busy:
                        break
        results.sort(key=lambda x: x[0])
        return results

    except asyncio.CancelledError:
        enqueue_task.cancel()

    finally:
        backend.stop()


# Function to parse command line arguments
def parse_args():
    # Model args
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_model_dir", type=str, required=True, default=None, help="HF model dir")
    parser.add_argument("--tokenizer_dir", type=str, default=None, help="Tokenizer dir")
    parser.add_argument('--load_format', type=str, default='auto', help='Load format for the model')
    parser.add_argument("--top_p", type=float, default=1e-5, help="Top-p for sampling")
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")

    # PyTorch backend settings
    parser.add_argument("--backend", type=str, choices=["pytorch"], default="pytorch", help="Choose the backend to run the model")
    parser.add_argument('--attn_backend', type=str, default='TRTLLM', choices=['TRTLLM', 'FLASHINFER'], help='Attention kernel for PyTorch flow.')
    parser.add_argument("--max_generation_tokens", type=int, default=32768, help="Maximum generation tokens")
    parser.add_argument("--concurrency", type=int, default=-1, help="Concurrency for dataset items")
    parser.add_argument('--batch_size', type=int, default=32, help="Max batch size")
    parser.add_argument("--max_num_tokens", type=int, default=4096, help="Maximum number of tokens")
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size (only for pytorch backend)")
    parser.add_argument("--ep_size", type=int, default=1, help="Expert Parallel size (only for pytorch backend)")

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto', help='KV cache dtype')
    parser.add_argument('--kv_cache_disable_block_reuse', default=False, action='store_true', help='Disable block reuse for KV cache')
    # TODO: change the default value back to 0.95
    parser.add_argument("--kv_cache_fraction", type=float, default=0.85, help='Fraction of KV cache to use')

    # Optimizations
    parser.add_argument('--use_cuda_graph', default=False, action='store_true', help='Use CUDA graph for inference')
    parser.add_argument('--torch_compile', action="store_true", help="Enable torch compile for pytorch backend")
    parser.add_argument("--enable_attention_dp", default=False, action='store_true')
    parser.add_argument('--print_iter_log', default=False, action='store_true', help='Print iteration logs during execution')
    parser.add_argument('--enable_overlap_scheduler', default=False, action='store_true')

    # Speculative decoding
    parser.add_argument('--mtp_nextn', type=int, default=0, help='Number of next-n layers to predict')

    # GPQA args
    parser.add_argument(
        "--data_dir",
        type=str,
        default=None,
        help="Path to the data directory. If not available, "
        "download from https://huggingface.co/datasets/Idavidrein/gpqa")
    parser.add_argument("--limit", type=float, default=None, help="Limit the number of samples to run")
    parser.add_argument('--check_accuracy', action='store_true')
    parser.add_argument('--accuracy_threshold', type=float, default=0.67)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--out_dir', type=str, default=None)
    parser.add_argument('--num_runs', type=int, default=1)

    args = parser.parse_args()

    return args


def main():
    args = parse_args()
    if args.tokenizer_dir is None:
        args.tokenizer_dir = args.hf_model_dir
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_dir)

    # Configure the PyTorch model
    build_config = BuildConfig(max_batch_size=args.batch_size,
                               max_num_tokens=args.max_num_tokens)
    pytorch_config = PyTorchConfig(
        attn_backend=args.attn_backend,
        enable_overlap_scheduler=args.enable_overlap_scheduler,
        torch_compile_enabled=args.torch_compile,
        kv_cache_dtype=args.kv_cache_dtype,
        use_cuda_graph=args.use_cuda_graph,
        load_format=args.load_format,
        print_iter_log=args.print_iter_log,
        # TODO: there is a known issue in autotuner_enabled warmup,
        # and it will be fixed in the near future
        autotuner_enabled=False)
    kv_cache_config = KvCacheConfig(
        enable_block_reuse=not args.kv_cache_disable_block_reuse,
        free_gpu_memory_fraction=args.kv_cache_fraction)
    mtp_config = MTPDecodingConfig(
        num_nextn_predict_layers=args.mtp_nextn) if args.mtp_nextn > 0 else None

    model = PyTorchLLM(model=args.hf_model_dir,
                       tokenizer=tokenizer,
                       tensor_parallel_size=args.tp_size,
                       kv_cache_config=kv_cache_config,
                       speculative_config=mtp_config,
                       moe_expert_parallel_size=args.ep_size,
                       pytorch_backend_config=pytorch_config,
                       build_config=build_config,
                       enable_attention_dp=args.enable_attention_dp)

    # Configure the sampling params
    sampling_params = SamplingParams(max_tokens=args.max_generation_tokens,
                                     top_p=args.top_p,
                                     temperature=args.temperature,
                                     end_id=tokenizer.eos_token_id,
                                     pad_id=tokenizer.pad_token_id)

    # Load the dataset
    seed_generator = RandomSeedGenerator(initial_seed=args.seed)
    dataset_shuffle = DataShuffle(seed=args.seed)
    datasets = load_data(args.data_dir,
                         dataset_shuffle,
                         limit=args.limit,
                         num_runs=args.num_runs)

    t = time.time()
    try:
        # Run the benchmark
        results = []
        for i in range(args.num_runs):
            dataset = datasets[i]
            result = asyncio.run(
                async_benchmark(model,
                                sampling_params,
                                dataset,
                                tokenizer,
                                seed_generator,
                                dataset_shuffle,
                                concurrency=args.concurrency))
            results.append(result)
    finally:
        if model is not None:
            model.__exit__(None, None, None)
    t = time.time() - t
    print(f"Finished in {t:.3f} seconds")

    # Calculate and print the accuracy
    acc = [np.mean([res[1] for res in result]) for result in results]
    acc_mean = np.mean(acc)
    for i in range(args.num_runs):
        print(f"Run {i+1} accuracy: {acc[i]:.3f}")
    print("Average accuracy: {:.3f}".format(acc_mean))
    if args.check_accuracy:
        assert acc_mean >= args.accuracy_threshold, f"Expected accuracy >= {args.accuracy_threshold} while got {acc_mean}"

    return acc


if __name__ == "__main__":
    main()
@@ -88,22 +88,48 @@ python quickstart_advanced.py --model_dir <YOUR_MODEL_DIR> --spec_decode_algo MTP

`N` is the number of MTP modules. When `N` is `0`, MTP is not used (the default). When `N` is greater than `0`, `N` MTP modules are enabled. In the current implementation, the weights of the MTP modules are shared; a minimal LLM-API sketch of the same setting follows.
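The same MTP setting can also be exercised through the LLM API directly. A minimal sketch, assuming the PyTorch-backend `LLM` and `MTPDecodingConfig` imports shown elsewhere in this change (the model path is a placeholder):

```python
# Hedged sketch: enable MTP speculative decoding through the LLM API.
from tensorrt_llm._torch.llm import LLM as PyTorchLLM
from tensorrt_llm.llmapi import MTPDecodingConfig

mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3)  # N = 3 MTP modules
llm = PyTorchLLM(model="<YOUR_MODEL_DIR>",                   # placeholder path
                 speculative_config=mtp_config)
outputs = llm.generate(["Hello, my name is"])                # ordinary generation, MTP handled internally
```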
### Run evaluation on GPQA dataset
Download the dataset first:
1. Sign up for a Hugging Face account and request access to the GPQA dataset: https://huggingface.co/datasets/Idavidrein/gpqa
2. Download the CSV file from https://huggingface.co/datasets/Idavidrein/gpqa/blob/main/gpqa_diamond.csv

## Evaluation

Evaluate on GPQA dataset.
Evaluate the model accuracy using `trtllm-eval`.

1. (Optional) Prepare an advanced configuration file:
```bash
cat >./extra-llm-api-config.yml <<EOF
pytorch_backend_config:
  use_cuda_graph: true
  enable_overlap_scheduler: true
  enable_attention_dp: true
EOF
```
python examples/gpqa_llmapi.py \
    --hf_model_dir <YOUR_MODEL_DIR> \
    --data_dir <DATASET_PATH> \

2. Evaluate accuracy on the [MMLU](https://people.eecs.berkeley.edu/~hendrycks/data.tar) dataset:
```bash
trtllm-eval --model <YOUR_MODEL_DIR> \
    --tp_size 8 \
    --use_cuda_graph \
    --enable_overlap_scheduler \
    --concurrency 32 \
    --batch_size 32 \
    --max_num_tokens 4096
    --kv_cache_free_gpu_memory_fraction 0.8 \
    --extra_llm_api_options ./extra-llm-api-config.yml \
    mmlu
```

3. Evaluate accuracy on the [GSM8K](https://huggingface.co/datasets/openai/gsm8k) dataset:
```bash
trtllm-eval --model <YOUR_MODEL_DIR> \
    --tp_size 8 \
    --kv_cache_free_gpu_memory_fraction 0.8 \
    --extra_llm_api_options ./extra-llm-api-config.yml \
    gsm8k
```

4. Evaluate accuracy on the [GPQA](https://huggingface.co/datasets/Idavidrein/gpqa) dataset:
```bash
# Make sure you have a Hugging Face account with access to the GPQA dataset

trtllm-eval --model <YOUR_MODEL_DIR> \
    --tp_size 8 \
    --kv_cache_free_gpu_memory_fraction 0.8 \
    --extra_llm_api_options ./extra-llm-api-config.yml \
    gpqa_diamond \
    --apply_chat_template
```
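If the gated GPQA dataset is not already available locally, one way to fetch it is via `huggingface_hub`; a sketch only, assuming the package is installed, `huggingface-cli login` has been run, and access to the dataset has been granted:

```python
# Hedged sketch: download the gated GPQA dataset for offline evaluation.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="Idavidrein/gpqa", repo_type="dataset",
                  local_dir="./gpqa")  # then point --dataset_path / DATASET_DIR here
```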
## Serving
@@ -20,7 +20,8 @@ import tensorrt_llm.profiler as profiler

from .._torch.llm import LLM as PyTorchLLM
from .._torch.pyexecutor.config import PyTorchConfig
from ..evaluate import MMLU, CnnDailymail
from ..evaluate import (GSM8K, MMLU, CnnDailymail, GPQADiamond, GPQAExtended,
                        GPQAMain)
from ..llmapi import LLM, BuildConfig, KvCacheConfig
from ..llmapi.llm_utils import update_llm_args_with_extra_options
from ..logger import logger, severity_map
@@ -142,6 +143,7 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
    profiler.stop("trtllm init")
    elapsed_time = profiler.elapsed_time_in_sec("trtllm init")
    logger.info(f"TRTLLM initialization time: {elapsed_time:.3f} seconds.")
    profiler.reset("trtllm init")

    # Pass llm to subcommands
    ctx.obj = llm
@@ -149,6 +151,10 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,

main.add_command(CnnDailymail.command)
main.add_command(MMLU.command)
main.add_command(GSM8K.command)
main.add_command(GPQADiamond.command)
main.add_command(GPQAMain.command)
main.add_command(GPQAExtended.command)

if __name__ == "__main__":
    main()
@@ -14,6 +14,9 @@
# limitations under the License.

from .cnn_dailymail import CnnDailymail
from .lm_eval import GSM8K, GPQADiamond, GPQAExtended, GPQAMain
from .mmlu import MMLU

__all__ = ["CnnDailymail", "MMLU"]
__all__ = [
    "CnnDailymail", "MMLU", "GSM8K", "GPQADiamond", "GPQAMain", "GPQAExtended"
]
@@ -29,12 +29,13 @@ class CnnDailymail(Evaluator):

    def __init__(self,
                 dataset_path: str = "ccdv/cnn_dailymail",
                 num_samples: int = None,
                 num_samples: Optional[int] = None,
                 random_seed: int = 0,
                 rouge_path: str = "rouge",
                 apply_chat_template: bool = False,
                 system_prompt: Optional[str] = None):
        super().__init__(apply_chat_template=apply_chat_template,
        super().__init__(random_seed=random_seed,
                         apply_chat_template=apply_chat_template,
                         system_prompt=system_prompt)
        self.data = datasets.load_dataset(dataset_path,
                                          "3.0.0",
@@ -73,15 +74,16 @@ class CnnDailymail(Evaluator):
    @click.option("--num_samples", type=int, default=None)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--rouge_path", type=str, default="rouge")
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=Optional[str], default=None)
    @click.option("--max_input_length", type=int, default=924)
    @click.option("--max_output_length", type=int, default=100)
    @click.option("--check_accuracy", is_flag=True, default=False)
    @click.option("--accuracy_threshold", type=float, default=15)
    @click.pass_context
    @staticmethod
    def command(ctx, dataset_path: str, num_samples: int, random_seed: int,
                rouge_path: str, max_input_length: int, max_output_length: int,
                check_accuracy: bool, accuracy_threshold: float) -> None:
                rouge_path: str, apply_chat_template: bool,
                system_prompt: Optional[str], max_input_length: int,
                max_output_length: int) -> None:
        llm: Union[LLM, PyTorchLLM] = ctx.obj
        sampling_params = SamplingParams(
            max_tokens=max_output_length,
@@ -89,9 +91,8 @@ class CnnDailymail(Evaluator):
        evaluator = CnnDailymail(dataset_path,
                                 num_samples=num_samples,
                                 random_seed=random_seed,
                                 rouge_path=rouge_path)
        accuracy = evaluator.evaluate(llm, sampling_params)
                                 rouge_path=rouge_path,
                                 apply_chat_template=apply_chat_template,
                                 system_prompt=system_prompt)
        evaluator.evaluate(llm, sampling_params)
        llm.shutdown()

        if check_accuracy:
            assert accuracy >= accuracy_threshold, f"Expected accuracy >= {accuracy_threshold}, but got {accuracy}"
@@ -12,9 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod, abstractstaticmethod
import random
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional, Union

import numpy as np
import torch
from tqdm import tqdm

import tensorrt_llm.profiler as profiler
@@ -28,8 +31,12 @@ from ..sampling_params import SamplingParams
class Evaluator(ABC):

    def __init__(self,
                 random_seed: int = 0,
                 apply_chat_template: bool = False,
                 system_prompt: Optional[str] = None):
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        self.apply_chat_template = apply_chat_template
        self.system_prompt = system_prompt

@@ -72,10 +79,12 @@ class Evaluator(ABC):
        profiler.stop("trtllm exec")
        elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
        logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
        profiler.reset("trtllm exec")

        score = self.compute_score(outputs, references, *zip(*auxiliaries))
        return score

    @abstractstaticmethod
    @staticmethod
    @abstractmethod
    def command(ctx, *args, **kwargs) -> None:
        raise NotImplementedError()
tensorrt_llm/evaluate/lm_eval.py (new file, 296 lines)
@@ -0,0 +1,296 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from contextlib import contextmanager
from typing import Dict, Iterable, List, Optional, Tuple, Union

import click
import numpy as np
from tqdm import tqdm

import tensorrt_llm.profiler as profiler

try:
    from lm_eval.api.model import TemplateLM
except ImportError:
    TemplateLM = object

from .._torch import LLM as PyTorchLLM
from ..llmapi import LLM, RequestOutput
from ..logger import logger
from ..sampling_params import SamplingParams
from .interface import Evaluator


class LmEvalWrapper(TemplateLM):

    def __init__(self,
                 llm: Union[LLM, PyTorchLLM],
                 sampling_params: Optional[SamplingParams] = None):
        super().__init__()
        self.llm = llm
        self.sampling_params = sampling_params

    @property
    def eot_token_id(self) -> int:
        return self.llm.tokenizer.eos_token_id

    def apply_chat_template(self,
                            chat_history: List[Dict[str, str]],
                            add_generation_prompt: bool = True) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
        """
        return self.llm.tokenizer.apply_chat_template(
            chat_history,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
        )

    @property
    def tokenizer_name(self) -> str:
        return self.llm.tokenizer.name_or_path.replace("/", "__")

    def tok_encode(self, string: str, **kwargs) -> List[int]:
        return self.llm.tokenizer.encode(string, **kwargs)

    def _loglikelihood_tokens(self, requests,
                              **kwargs) -> List[Tuple[float, bool]]:
        raise NotImplementedError()

    def loglikelihood_rolling(self,
                              requests,
                              disable_tqdm: bool = False) -> List[float]:
        raise NotImplementedError()

    def _get_sampling_params(self, gen_kwargs: dict) -> SamplingParams:
        params_mapping = {
            "temperature": "temperature",
            "top_p": "top_p",
            "max_gen_toks": "max_tokens",
            "until": "stop",
        }
        if self.sampling_params is None:
            sampling_params = SamplingParams()
        else:
            sampling_params = copy.deepcopy(self.sampling_params)
        for lm_eval_key, trtllm_key in params_mapping.items():
            value = gen_kwargs.pop(lm_eval_key, None)
            if value is not None:
                setattr(sampling_params, trtllm_key, value)
        return sampling_params

    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        profiler.start("trtllm exec")
        outputs = []
        for request in tqdm(requests,
                            desc="Submitting requests",
                            disable=disable_tqdm):
            prompt, gen_kwargs = request.args
            sampling_params = self._get_sampling_params(gen_kwargs)
            output = self.llm.generate_async(prompt,
                                             sampling_params=sampling_params)
            outputs.append(output)

        for output in tqdm(outputs,
                           desc="Fetching responses",
                           disable=disable_tqdm):
            output.result()
        profiler.stop("trtllm exec")
        elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
        logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
        profiler.reset("trtllm exec")

        return [output.outputs[0].text for output in outputs]


class LmEvalEvaluator(Evaluator):

    def __init__(self,
                 task_name: str,
                 dataset_path: str = None,
                 num_samples: Optional[int] = None,
                 random_seed: int = 0,
                 apply_chat_template: bool = False,
                 system_prompt: Optional[str] = None):
        try:
            import lm_eval
        except ImportError as e:
            raise ImportError(
                f"Evaluation task {self.__class__.__name__} requires `lm_eval`. "
                "Please install the package first, e.g., `pip install lm_eval`."
            ) from e
        super().__init__(random_seed=random_seed,
                         apply_chat_template=apply_chat_template,
                         system_prompt=system_prompt)
        self.task_name = task_name
        self.dataset_path = dataset_path
        self.num_samples = num_samples

        task_manager = lm_eval.tasks.TaskManager(
            include_path=f"{os.path.dirname(__file__)}/lm_eval_tasks")
        with self._patch_lm_eval():
            self.task_dict = lm_eval.tasks.get_task_dict(
                task_name, task_manager=task_manager)
        # Few-shot random seed
        self.task_dict[self.task_name].set_fewshot_seed(random_seed)
        # Shuffle dataset
        data = self.task_dict[self.task_name].dataset
        for split in data.keys():
            data[split] = data[split].shuffle(random_seed)

    @contextmanager
    def _patch_lm_eval(self):
        if self.dataset_path is None:
            yield
            return

        import lm_eval
        self._task_config_post_init = lm_eval.api.task.TaskConfig.__post_init__

        def _patched(task_config, *args, **kwargs):
            task_config.dataset_path = self.dataset_path
            self._task_config_post_init(task_config, *args, **kwargs)

        lm_eval.api.task.TaskConfig.__post_init__ = _patched

        try:
            yield
        finally:
            lm_eval.api.task.TaskConfig.__post_init__ = self._task_config_post_init

    def generate_samples(self) -> Iterable[tuple]:
        raise NotImplementedError()

    def compute_score(self, outputs: List[RequestOutput], references: List[str],
                      *auxiliaries) -> float:
        raise NotImplementedError()

    def evaluate(self,
                 llm: Union[LLM, PyTorchLLM],
                 sampling_params: Optional[SamplingParams] = None) -> float:
        import lm_eval
        results = lm_eval.evaluate(lm=LmEvalWrapper(llm, sampling_params),
                                   task_dict=self.task_dict,
                                   limit=self.num_samples,
                                   apply_chat_template=self.apply_chat_template,
                                   system_instruction=self.system_prompt)
        # Normalize scores to range 0~100
        scores = results["results"][self.task_name]
        for metric in scores.keys():
            if isinstance(scores[metric], (float, int)):
                scores[metric] *= 100
        logger.info(
            f"lm-eval {self.task_name} results (scores normalized to range 0~100):\n{lm_eval.utils.make_table(results)}"
        )

        average_acc = np.mean(
            [acc for m, acc in scores.items() if "_stderr" not in m])
        logger.info(
            f"lm-eval {self.task_name} average accuracy: {average_acc:.2f}")
        return average_acc

    @classmethod
    def command_harness(cls, ctx, **kwargs):
        llm: Union[LLM, PyTorchLLM] = ctx.obj
        evaluator = cls(dataset_path=kwargs.pop("dataset_path", None),
                        num_samples=kwargs.pop("num_samples", None),
                        random_seed=kwargs.pop("random_seed", 0),
                        apply_chat_template=kwargs.pop("apply_chat_template", False),
                        system_prompt=kwargs.pop("system_prompt", None))
        sampling_params = SamplingParams(
            max_tokens=kwargs.pop("max_output_length"),
            truncate_prompt_tokens=kwargs.pop("max_input_length"))
        evaluator.evaluate(llm, sampling_params)
        llm.shutdown()


class GSM8K(LmEvalEvaluator):

    def __init__(self, **kwargs):
        super().__init__("gsm8k", **kwargs)

    @click.command("gsm8k")
    @click.option("--dataset_path", type=str, default=None)
    @click.option("--num_samples", type=int, default=None)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=Optional[str], default=None)
    @click.option("--max_input_length", type=int, default=4096)
    @click.option("--max_output_length", type=int, default=256)
    @click.pass_context
    @staticmethod
    def command(ctx, **kwargs) -> None:
        GSM8K.command_harness(ctx, **kwargs)


class GPQADiamond(LmEvalEvaluator):

    def __init__(self, **kwargs):
        super().__init__("gpqa_diamond_cot_zeroshot_aa", **kwargs)

    @click.command("gpqa_diamond")
    @click.option("--dataset_path", type=str, default=None)
    @click.option("--num_samples", type=int, default=None)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=Optional[str], default=None)
    @click.option("--max_input_length", type=int, default=4096)
    @click.option("--max_output_length", type=int, default=32768)
    @click.pass_context
    @staticmethod
    def command(ctx, **kwargs) -> None:
        GPQADiamond.command_harness(ctx, **kwargs)


class GPQAMain(LmEvalEvaluator):

    def __init__(self, **kwargs):
        super().__init__("gpqa_main_cot_zeroshot_aa", **kwargs)

    @click.command("gpqa_main")
    @click.option("--dataset_path", type=str, default=None)
    @click.option("--num_samples", type=int, default=None)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=Optional[str], default=None)
    @click.option("--max_input_length", type=int, default=4096)
    @click.option("--max_output_length", type=int, default=32768)
    @click.pass_context
    @staticmethod
    def command(ctx, **kwargs) -> None:
        GPQAMain.command_harness(ctx, **kwargs)


class GPQAExtended(LmEvalEvaluator):

    def __init__(self, **kwargs):
        super().__init__("gpqa_extended_cot_zeroshot_aa", **kwargs)

    @click.command("gpqa_extended")
    @click.option("--dataset_path", type=str, default=None)
    @click.option("--num_samples", type=int, default=None)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=Optional[str], default=None)
    @click.option("--max_input_length", type=int, default=4096)
    @click.option("--max_output_length", type=int, default=32768)
    @click.pass_context
    @staticmethod
    def command(ctx, **kwargs) -> None:
        GPQAExtended.command_harness(ctx, **kwargs)
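For reference, a minimal sketch of driving these evaluators programmatically, mirroring `command_harness` above (the model path is a placeholder and `lm_eval` must be installed):

```python
# Hedged sketch based on the classes added in lm_eval.py; values are illustrative.
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM as PyTorchLLM
from tensorrt_llm.evaluate import GSM8K

llm = PyTorchLLM(model="<YOUR_MODEL_DIR>")              # placeholder path
evaluator = GSM8K(random_seed=0)                        # downloads/loads the gsm8k task via lm_eval
sampling_params = SamplingParams(max_tokens=256, truncate_prompt_tokens=4096)
score = evaluator.evaluate(llm, sampling_params)        # average accuracy, normalized to 0~100
llm.shutdown()
```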
@@ -0,0 +1,26 @@
import yaml
from tqdm import tqdm


def main() -> None:
    subset = ["extended", "diamond", "main"]
    setting = "cot_zeroshot_aa"
    for task in tqdm(subset):
        file_name = f"gpqa_{task}_{setting}.yaml"
        try:
            with open(f"{file_name}", "w") as f:
                f.write("# Generated by _generate_configs.py\n")
                yaml.dump(
                    {
                        "include": f"_gpqa_{setting}_yaml",
                        "task": f"gpqa_{task}_{setting}",
                        "dataset_name": f"gpqa_{task}",
                    },
                    f,
                )
        except FileExistsError:
            pass


if __name__ == "__main__":
    main()
@@ -0,0 +1,31 @@
dataset_path: Idavidrein/gpqa
tag: gpqa
output_type: generate_until
process_docs: !function utils.process_docs
training_split: train
# Because huggingface dataset only has train split
validation_split: train
test_split: null
doc_to_text: "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n{{Question}}\nA) {{choice1}}\nB) {{choice2}}\nC) {{choice3}}\nD) {{choice4}}"
doc_to_target: answer
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        regex_pattern: '(?i)Answer[ \t]*:[ \t]*([A-D])'
        group_select: 0
      - function: "take_first"
generation_kwargs:
  until:
    - "</s>"
  do_sample: false
  temperature: 0.0
num_fewshot: 0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
metadata:
  version: 1.0
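As a quick illustration, the `strict-match` filter above extracts the final answer letter the same way the standalone GPQA script did; a small, self-contained sketch (the response text is made up):

```python
# Hedged sketch: the regex filter used by the GPQA task config.
import re

ANSWER_PATTERN = r"(?i)Answer[ \t]*:[ \t]*([A-D])"
response = "Reasoning about the reaction order...\nAnswer: C"
match = re.search(ANSWER_PATTERN, response)
print(match.group(1) if match else None)  # -> "C"
```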
@@ -0,0 +1,4 @@
# Generated by _generate_configs.py
dataset_name: gpqa_diamond
include: _gpqa_cot_zeroshot_aa_yaml
task: gpqa_diamond_cot_zeroshot_aa
@@ -0,0 +1,4 @@
# Generated by _generate_configs.py
dataset_name: gpqa_extended
include: _gpqa_cot_zeroshot_aa_yaml
task: gpqa_extended_cot_zeroshot_aa
@@ -0,0 +1,4 @@
# Generated by _generate_configs.py
dataset_name: gpqa_main
include: _gpqa_cot_zeroshot_aa_yaml
task: gpqa_main_cot_zeroshot_aa
@@ -0,0 +1,35 @@
import random

import datasets


def preprocess(text):
    if text is None:
        return " "
    return text.strip()


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:

    def _process_doc(doc):
        choices = [
            preprocess(doc["Incorrect Answer 1"]),
            preprocess(doc["Incorrect Answer 2"]),
            preprocess(doc["Incorrect Answer 3"]),
            preprocess(doc["Correct Answer"]),
        ]

        random.shuffle(choices)
        correct_answer_index = choices.index(preprocess(doc["Correct Answer"]))

        out_doc = {
            "choice1": choices[0],
            "choice2": choices[1],
            "choice3": choices[2],
            "choice4": choices[3],
            "choices": [choices[0], choices[1], choices[2], choices[3]],
            "answer": f"{chr(65 + correct_answer_index)}",
        }
        return out_doc

    return dataset.map(_process_doc)
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
from typing import Iterable, List, Optional, Union

import click
@@ -110,12 +109,13 @@ class MMLU(Evaluator):

    def __init__(self,
                 dataset_path: str,
                 num_samples: int = None,
                 num_samples: Optional[int] = None,
                 num_train: int = 5,
                 random_seed: int = 0,
                 apply_chat_template: bool = False,
                 system_prompt: Optional[str] = None):
        super().__init__(apply_chat_template=apply_chat_template,
        super().__init__(random_seed=random_seed,
                         apply_chat_template=apply_chat_template,
                         system_prompt=system_prompt)
        self.dataset_path = dataset_path
        if num_samples is None:
@@ -124,8 +124,6 @@ class MMLU(Evaluator):
            self.num_samples_per_subject = math.ceil(
                num_samples / len(self.SUBJECT_TO_SUBCATEGORIES))
        self.num_train = num_train
        random.seed(random_seed)
        np.random.seed(random_seed)

    def format_subject(self, subject):
        line = subject.split("_")
@@ -227,15 +225,16 @@ class MMLU(Evaluator):
    @click.option("--num_samples", type=int, default=None)
    @click.option("--num_train", type=int, default=5)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=Optional[str], default=None)
    @click.option("--max_input_length", type=int, default=4094)
    @click.option("--max_output_length", type=int, default=2)
    @click.option("--check_accuracy", is_flag=True, default=False)
    @click.option("--accuracy_threshold", type=float, default=30)
    @click.pass_context
    @staticmethod
    def command(ctx, dataset_path: str, num_samples: int, num_train: int,
                random_seed: int, max_input_length: int, max_output_length: int,
                check_accuracy: bool, accuracy_threshold: float) -> None:
                random_seed: int, apply_chat_template: bool,
                system_prompt: Optional[str], max_input_length: int,
                max_output_length: int) -> None:
        llm: Union[LLM, PyTorchLLM] = ctx.obj
        sampling_params = SamplingParams(
            max_tokens=max_output_length,
@@ -243,9 +242,8 @@ class MMLU(Evaluator):
        evaluator = MMLU(dataset_path,
                         num_samples=num_samples,
                         num_train=num_train,
                         random_seed=random_seed)
        accuracy = evaluator.evaluate(llm, sampling_params)
                         random_seed=random_seed,
                         apply_chat_template=apply_chat_template,
                         system_prompt=system_prompt)
        evaluator.evaluate(llm, sampling_params)
        llm.shutdown()

        if check_accuracy:
            assert accuracy >= accuracy_threshold, f"Expected accuracy >= {accuracy_threshold}, but got {accuracy}"
@@ -28,6 +28,10 @@ class TransformersTokenizer(TokenizerBase):
    def pad_token_id(self) -> int:
        return self.tokenizer.pad_token_id

    @property
    def name_or_path(self) -> str:
        return self.tokenizer.name_or_path

    def encode(self, text: str, *args, **kwargs) -> List[int]:
        return self.tokenizer.encode(text, *args, **kwargs)
@@ -68,9 +68,13 @@ class Timer:
            return None
        return self._total_elapsed_times[tag]

    def reset(self):
        self._start_times.clear()
        self._total_elapsed_times.clear()
    def reset(self, tag=None) -> None:
        if tag is None:
            self._start_times.clear()
            self._total_elapsed_times.clear()
        else:
            self._start_times.pop(tag, None)
            self._total_elapsed_times.pop(tag, None)

    def summary(self):
        logger.info('Profile Results')
@@ -93,8 +97,8 @@ def elapsed_time_in_sec(tag):
    return _default_timer.elapsed_time_in_sec(tag)


def reset():
    _default_timer.reset()
def reset(tag=None):
    _default_timer.reset(tag=tag)


def summary():
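The tag-aware `reset` lets a caller clear a single timer between runs without disturbing the others; a minimal usage sketch (the measured work is a placeholder):

```python
# Hedged sketch of the profiler calls used elsewhere in this change.
import tensorrt_llm.profiler as profiler

profiler.start("trtllm exec")
run_evaluation()                                   # placeholder for the measured work
profiler.stop("trtllm exec")
print(profiler.elapsed_time_in_sec("trtllm exec"))
profiler.reset("trtllm exec")                      # clears only this tag; other timers keep accumulating
```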
@@ -261,6 +261,38 @@ class MMLU(AccuracyTask):
    EVALUATOR_KWARGS = dict(dataset_path=DATASET_DIR, random_seed=0)


class GSM8K(AccuracyTask):
    DATASET = "gsm8k"
    DATASET_DIR = f"{llm_models_root()}/datasets/openai/gsm8k"

    ALPHA = 0.02
    BETA = 0.2
    SIGMA = 50
    NUM_SAMPLES = 1319  # Full sample

    MAX_INPUT_LEN = 4096
    MAX_OUTPUT_LEN = 256

    EVALUATOR_CLS = tensorrt_llm.evaluate.GSM8K
    EVALUATOR_KWARGS = dict(dataset_path=DATASET_DIR, random_seed=0)


class GPQADiamond(AccuracyTask):
    DATASET = "gpqa_diamond"
    DATASET_DIR = f"{llm_models_root()}/datasets/gpqa"

    ALPHA = 0.05
    BETA = 0.2
    SIGMA = 50
    NUM_SAMPLES = 198  # Full sample

    MAX_INPUT_LEN = 4096
    MAX_OUTPUT_LEN = 32768

    EVALUATOR_CLS = tensorrt_llm.evaluate.GPQADiamond
    EVALUATOR_KWARGS = dict(dataset_path=DATASET_DIR, random_seed=0)


class PassKeyRetrieval64k(AccuracyTask):
    DATASET = "passkey_retrieval_64k"
    LEVEL = 3
tests/integration/defs/accuracy/references/gpqa_diamond.yaml (new file, 12 lines)
@@ -0,0 +1,12 @@
meta-llama/Llama-3.3-70B-Instruct:
- quant_algo: NVFP4
  kv_cache_quant_algo: FP8
  accuracy: 45.55
- quant_algo: FP8
  accuracy: 48.03
deepseek-ai/DeepSeek-R1:
- quant_algo: NVFP4
  accuracy: 70.45
- quant_algo: NVFP4
  spec_dec_algo: MTP
  accuracy: 70.06
tests/integration/defs/accuracy/references/gsm8k.yaml (new file, 19 lines)
@@ -0,0 +1,19 @@
meta-llama/Llama-3.1-8B-Instruct:
- accuracy: 74.20
- quant_algo: FP8
  accuracy: 74.30
- quant_algo: FP8
  kv_cache_quant_algo: FP8
  accuracy: 72.85
meta-llama/Llama-3.3-70B-Instruct:
- quant_algo: NVFP4
  kv_cache_quant_algo: FP8
  accuracy: 75.61
- quant_algo: FP8
  accuracy: 83.30
deepseek-ai/DeepSeek-R1:
- quant_algo: NVFP4
  accuracy: 95.42
- quant_algo: NVFP4
  spec_dec_algo: MTP
  accuracy: 95.42
@@ -68,3 +68,9 @@ deepseek-ai/DeepSeek-V3-Lite:
- quant_algo: FP8_BLOCK_SCALES
  spec_dec_algo: MTP
  accuracy: 71.29
deepseek-ai/DeepSeek-R1:
- quant_algo: NVFP4
  accuracy: 87.33
- quant_algo: NVFP4
  spec_dec_algo: MTP
  accuracy: 87.33
@@ -19,8 +19,11 @@ import pandas as pd

metric_regex = {
    "rouge1": r"(?<=rouge1: )\d+\.\d+",
    "perplexity": r"(?<=Per-token perplexity: )\d+\.\d+",
    "mmlu": r"(?<=MMLU weighted average accuracy: )\d+\.\d+",
    "gsm8k": r"(?<=gsm8k average accuracy: )\d+\.\d+",
    "gpqa_diamond":
    r"(?<=gpqa_diamond_cot_zeroshot_aa average accuracy: )\d+\.\d+",
    "perplexity": r"(?<=Per-token perplexity: )\d+\.\d+",
    "passkey": r"(?<=passkey accuracy: )\d+\.\d+"
}
@@ -22,7 +22,8 @@ from tensorrt_llm.quantization import QuantAlgo

from ..conftest import (llm_models_root, parametrize_with_ids, skip_pre_ada,
                        skip_pre_blackwell, skip_pre_hopper)
from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness
from .accuracy_core import (GSM8K, MMLU, CnnDailymail, GPQADiamond,
                            LlmapiAccuracyTestHarness)


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
@@ -67,10 +68,10 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
        )
        llm = LLM(self.MODEL_PATH, pytorch_backend_config=pytorch_config)
        with llm:
            task = CnnDailymail(self.MODEL_NAME)
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

    @parametrize_with_ids("torch_compile", [False, True])
    @parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
@@ -96,10 +97,10 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
                  pipeline_parallel_size=pp_size,
                  pytorch_backend_config=pytorch_config)
        with llm:
            task = CnnDailymail(self.MODEL_NAME)
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

    @skip_pre_ada
    @parametrize_with_ids("torch_compile", [False, True])
@@ -126,10 +127,10 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
        if fp8kv:
            assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
        with llm:
            task = CnnDailymail(self.MODEL_NAME)
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

    @skip_pre_ada
    @parametrize_with_ids("torch_compile", [False, True])
@@ -166,10 +167,10 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
        if fp8kv:
            assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
        with llm:
            task = CnnDailymail(self.MODEL_NAME)
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)


class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
@@ -185,6 +186,11 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)
            task = GPQADiamond(self.MODEL_NAME)
            task.evaluate(llm,
                          extra_evaluator_kwargs=dict(apply_chat_template=True))

    @pytest.mark.skip_less_device(4)
    @pytest.mark.skip_device_not_contain(["B200"])
@@ -197,6 +203,11 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)
            task = GPQADiamond(self.MODEL_NAME)
            task.evaluate(llm,
                          extra_evaluator_kwargs=dict(apply_chat_template=True))


class TestMistral7B(LlmapiAccuracyTestHarness):
@@ -404,6 +415,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)

    @pytest.mark.skip_less_device(4)
    @skip_pre_blackwell
    @parametrize_with_ids("attention_dp,cuda_graph,overlap_scheduler",
                          [(False, False, False), (True, False, False),
@@ -431,6 +443,48 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
            task.evaluate(llm)


class TestDeepSeekR1(LlmapiAccuracyTestHarness):
    MODEL_NAME = "deepseek-ai/DeepSeek-R1"
    MODEL_PATH = f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1"

    @pytest.mark.skip_less_device(8)
    @skip_pre_blackwell
    @parametrize_with_ids("overlap_scheduler", [False, True])
    @parametrize_with_ids("cuda_graph", [False, True])
    @parametrize_with_ids("attention_dp", [False, True])
    @parametrize_with_ids("mtp_nextn", [None, 2])
    @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
                                                         (8, 1, 8)],
                             ids=["tp8", "tp8ep4", "tp8ep8"])
    def test_nvfp4_8gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
                         attention_dp, cuda_graph, overlap_scheduler):
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
        pytorch_config = PyTorchConfig(
            enable_overlap_scheduler=overlap_scheduler,
            use_cuda_graph=cuda_graph)
        if mtp_nextn is not None and mtp_nextn > 0:
            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
        else:
            mtp_config = None
        llm = LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4",
                  tensor_parallel_size=tp_size,
                  pipeline_parallel_size=pp_size,
                  moe_expert_parallel_size=ep_size,
                  kv_cache_config=kv_cache_config,
                  pytorch_backend_config=pytorch_config,
                  enable_attention_dp=attention_dp,
                  speculative_config=mtp_config)
        assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
        with llm:
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)
            task = GPQADiamond(self.MODEL_NAME)
            task.evaluate(llm,
                          extra_evaluator_kwargs=dict(apply_chat_template=True))


class TestMinitron4BBaseInstruct(LlmapiAccuracyTestHarness):
    MODEL_NAME = "nvidia/Nemotron-Mini-4B-Instruct"
    MODEL_PATH = f"{llm_models_root()}/nemotron/nemotron-mini-4b-instruct_vfp8-fp8-bf16-export"
examples/test_deepseek.py::test_deepseek_gpqa_llmapi (deleted by this commit):
@@ -1,80 +0,0 @@
from pathlib import Path

import pytest
import torch
from defs.common import venv_check_call
from defs.conftest import get_sm_version, llm_models_root


@pytest.mark.parametrize("model_name", ["DeepSeek-R1"], ids=["deepseek_r1"])
@pytest.mark.parametrize("quant", ["fp4", "fp8"])
@pytest.mark.parametrize("tp_size", [8], ids=["tp8"])
@pytest.mark.parametrize("pp_size", [1], ids=["pp1"])
@pytest.mark.parametrize("ep_size", [1, 4, 8], ids=["ep1", "ep4", "ep8"])
@pytest.mark.parametrize("mtp_nextn", [0, 1, 2],
                         ids=["nextn0", "nextn1", "nextn2"])
@pytest.mark.parametrize("enable_dp", [True, False],
                         ids=["enable_dp", "disable_dp"])
@pytest.mark.parametrize("enable_cuda_graph", [True, False],
                         ids=["enable_cuda_graph", "disable_cuda_graph"])
@pytest.mark.parametrize(
    "enable_overlap_scheduler", [True, False],
    ids=["enable_overlap_scheduler", "disable_overlap_scheduler"])
def test_deepseek_gpqa_llmapi(llmapi_example_root, llm_datasets_root, llm_venv,
                              model_name, quant, tp_size, pp_size, ep_size,
                              mtp_nextn, enable_dp, enable_cuda_graph,
                              enable_overlap_scheduler):
    model_path = {
        "fp8": "DeepSeek-R1",
        "fp4": "DeepSeek-R1-FP4",
    }
    assert quant in model_path.keys()

    is_fp8 = quant == "fp8"
    is_fp4 = quant == "fp4"

    if ep_size > tp_size:
        pytest.skip(
            f"Expert parallel size {ep_size} must be less than or equal to tensor parallel size {tp_size}"
        )

    if torch.cuda.device_count() < tp_size * pp_size:
        pytest.skip(f"Not enough GPUs available, need {tp_size * pp_size} "
                    f"but only have {torch.cuda.device_count()}")

    if is_fp8:
        pytest.skip(
            "FP8 is not supported for gpqa test, and it will be added in the near future"
        )

    if is_fp4 and (get_sm_version() < 100 or get_sm_version() >= 120):
        pytest.skip(
            f"FP4 is not supported in this SM version {get_sm_version()}")

    if pp_size > 1:
        pytest.skip(
            "PP is not supported for gpqa test, and it will be added in the near future"
        )

    model_dir = str(Path(llm_models_root()) / model_name / model_path[quant])
    gpqa_data_path = str(Path(llm_datasets_root) / "gpqa/gpqa_diamond.csv")

    assert Path(model_dir).exists()

    print("Run GPQA test")
    gpqa_cmd = [
        f"{llmapi_example_root}/../gpqa_llmapi.py",
        f"--hf_model_dir={model_dir}", f"--data_dir={gpqa_data_path}",
        f"--tp_size={tp_size}", f"--ep_size={ep_size}", "--concurrency=8",
        f"--mtp_nextn={mtp_nextn}", "--print_iter_log", "--batch_size=32",
        "--max_num_tokens=4096", "--check_accuracy",
        "--accuracy_threshold=0.65", "--num_runs=3"
    ]
    if enable_cuda_graph:
        gpqa_cmd.append("--use_cuda_graph")
    if enable_overlap_scheduler:
        gpqa_cmd.append("--enable_overlap_scheduler")
    if enable_dp:
        gpqa_cmd.append("--enable_attention_dp")

    venv_check_call(llm_venv, gpqa_cmd)
@@ -296,10 +296,6 @@ examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-
examples/test_whisper.py::test_llm_whisper_general[large-v3-enable_gemm_plugin-enable_attention_plugin-disable_weight_only-float16-nb:1-use_python_runtime]
examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-enable_attention_plugin-int8-float16-nb:1-use_cpp_runtime]
examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-enable_attention_plugin-int4-float16-nb:1-use_cpp_runtime]
examples/test_deepseek.py::test_deepseek_gpqa_llmapi[enable_overlap_scheduler-enable_cuda_graph-disable_dp-nextn0-ep4-pp1-tp8-fp4-deepseek_r1]
examples/test_deepseek.py::test_deepseek_gpqa_llmapi[enable_overlap_scheduler-enable_cuda_graph-disable_dp-nextn2-ep4-pp1-tp8-fp4-deepseek_r1]
examples/test_deepseek.py::test_deepseek_gpqa_llmapi[enable_overlap_scheduler-enable_cuda_graph-disable_dp-nextn0-ep8-pp1-tp8-fp4-deepseek_r1]
examples/test_deepseek.py::test_deepseek_gpqa_llmapi[enable_overlap_scheduler-enable_cuda_graph-disable_dp-nextn2-ep8-pp1-tp8-fp4-deepseek_r1]

# Accuracy test list
accuracy/test_cli_flow.py::TestGpt2::test_auto_dtype
@@ -434,6 +430,12 @@ accuracy/test_llm_api_pytorch.py::TestMistral7B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestMinitron4BBaseInstruct::test_fp8_prequantized
accuracy/test_llm_api_pytorch.py::TestNemotronNas::test_auto_dtype_tp8
accuracy/test_llm_api_pytorch.py::TestQwen2_7BInstruct::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[tp8-cuda_graph-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[tp8-mtp_nextn=2-cuda_graph-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[tp8ep4-cuda_graph-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[tp8ep4-mtp_nextn=2-cuda_graph-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[tp8ep8-cuda_graph-overlap_scheduler]
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[tp8ep8-mtp_nextn=2-cuda_graph-overlap_scheduler]

test_e2e.py::test_benchmark_sanity[bert_base] # 127.18s
test_e2e.py::test_benchmark_sanity[gpt_350m] # 64.06s