Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[None][feat] improve dataloading for benchmark_dataset by using batch… (#6548)
Signed-off-by: Zero Zeng <38289304+zerollzeng@users.noreply.github.com>
This commit is contained in:
parent 60073a7ad9
commit 4b4b91ab51
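The heart of the change: instead of calling the tokenizer once per prompt, each sampler now hands a whole list of prompts to the tokenizer in chunks. A minimal sketch of that idea, assuming a Hugging Face tokenizer (the model name and prompts below are illustrative only, not part of the commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
    prompts = ["hello world", "a longer benchmark prompt", "short"]

    # Before: one tokenizer call per prompt (slow for large datasets).
    per_prompt_lengths = [len(tokenizer(p).input_ids) for p in prompts]

    # After: a single batched call; .input_ids is a list of token-ID lists.
    batch = tokenizer(prompts, padding=False, truncation=False)
    batched_lengths = [len(ids) for ids in batch.input_ids]

    print(per_prompt_lengths == batched_lengths)  # same lengths, far fewer round trips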
@@ -20,6 +20,7 @@ SampleRequest instances, similar to the approach used in ShareGPT.
 import json
 import logging
 import random
+import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any, Callable, Optional, Union
@@ -38,6 +39,162 @@ logger = logging.getLogger(__name__)
 # -----------------------------------------------------------------------------


+def timing_decorator(method_name: str):
+    """
+    Decorator to time method execution and print the results.
+
+    Args:
+        method_name: Name to display in timing output (e.g., 'load_data', 'sample')
+    """
+
+    def decorator(func):
+
+        def wrapper(self, *args, **kwargs):
+            dataset_name = self.__class__.__name__
+            start_time = time.perf_counter()
+            print(f"{dataset_name}.{method_name}() started...")
+
+            try:
+                result = func(self, *args, **kwargs)
+                end_time = time.perf_counter()
+                duration = end_time - start_time
+                print(
+                    f"{dataset_name}.{method_name}() completed in {duration:.4f} seconds"
+                )
+                return result
+            except Exception as e:
+                end_time = time.perf_counter()
+                duration = end_time - start_time
+                print(
+                    f"{dataset_name}.{method_name}() failed after {duration:.4f} seconds: {str(e)}"
+                )
+                raise
+
+        return wrapper
+
+    return decorator
+
+
+def auto_time_methods(*method_names):
+    """
+    Class decorator that automatically applies timing to specified methods
+    in the class and all its subclasses.
+
+    Usage:
+        @auto_time_methods("load_data", "sample")
+        class MyDataset(BenchmarkDataset):
+            def load_data(self):  # Will be automatically timed
+                pass
+            def sample(self):  # Will be automatically timed
+                pass
+    """
+
+    def class_decorator(cls):
+        # Store the method names that should be timed
+        cls._timed_methods = method_names
+
+        # Override __init_subclass__ to automatically apply timing to subclasses
+        original_init_subclass = getattr(cls, '__init_subclass__',
+                                         lambda **kwargs: None)
+
+        @classmethod
+        def __init_subclass__(subcls, **kwargs):
+            original_init_subclass(**kwargs)
+
+            # Apply timing to the specified methods if they exist in the subclass
+            for method_name in method_names:
+                if hasattr(subcls, method_name):
+                    original_method = getattr(subcls, method_name)
+
+                    # Only wrap if not already wrapped (check for our wrapper's signature)
+                    if not hasattr(original_method, '_is_timed'):
+                        timed_method = timing_decorator(method_name)(
+                            original_method)
+                        timed_method._is_timed = True
+                        setattr(subcls, method_name, timed_method)
+
+        cls.__init_subclass__ = __init_subclass__
+
+        # Also apply timing to methods in the current class
+        for method_name in method_names:
+            if hasattr(cls, method_name):
+                original_method = getattr(cls, method_name)
+                if not hasattr(original_method, '_is_timed'):
+                    timed_method = timing_decorator(method_name)(
+                        original_method)
+                    timed_method._is_timed = True
+                    setattr(cls, method_name, timed_method)
+
+        return cls
+
+    return class_decorator
+
+
+def batch_tokenize_prompts(
+        prompts: list[str],
+        tokenizer: PreTrainedTokenizerBase,
+        batch_size: int = 1000,
+        progress_name: str = "prompts") -> tuple[list[int], list[list[int]]]:
+    """
+    Efficiently tokenize a list of prompts using batch processing.
+
+    Args:
+        prompts: List of text prompts to tokenize
+        tokenizer: The tokenizer to use
+        batch_size: Number of prompts to process in each batch
+        progress_name: Name to show in progress messages
+
+    Returns:
+        Tuple of (prompt_lengths, prompt_token_ids) where:
+        - prompt_lengths: List of prompt lengths (number of tokens per prompt)
+        - prompt_token_ids: List of token ID lists for each prompt
+    """
+    import time
+
+    if not prompts:
+        return [], []
+
+    print(
+        f"Batch tokenizing {len(prompts)} {progress_name} (batch_size={batch_size})..."
+    )
+
+    prompt_lengths = []
+    prompt_token_ids = []
+    total_time = 0
+
+    for i in range(0, len(prompts), batch_size):
+        batch_prompts = prompts[i:i + batch_size]
+
+        # Batch tokenization
+        start_time = time.perf_counter()
+        batch_encoded = tokenizer(batch_prompts,
+                                  padding=False,
+                                  truncation=False)
+        batch_time = time.perf_counter() - start_time
+        total_time += batch_time
+
+        # Extract lengths and token IDs
+        for j in range(len(batch_prompts)):
+            token_ids = batch_encoded.input_ids[j]
+            prompt_lengths.append(len(token_ids))
+            prompt_token_ids.append(token_ids)
+
+        # Progress reporting
+        if (i + batch_size) % 5000 == 0 or (i + batch_size) >= len(prompts):
+            processed = min(i + batch_size, len(prompts))
+            avg_time = total_time / processed * 1000
+            print(
+                f" Processed {processed}/{len(prompts)} {progress_name} - Avg: {avg_time:.2f}ms per item"
+            )
+
+    avg_time_per_prompt = total_time / len(prompts) * 1000
+    print(
+        f"Batch tokenization completed: {total_time:.4f}s total ({avg_time_per_prompt:.2f}ms per {progress_name[:-1]})"
+    )
+
+    return prompt_lengths, prompt_token_ids
+
+
 @dataclass
 class SampleRequest:
     """
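For reference, a minimal usage sketch of the batch_tokenize_prompts helper added above. The tokenizer choice and prompt list are illustrative, and the helper is assumed to be importable from this module:

    from transformers import AutoTokenizer

    # batch_tokenize_prompts: the helper defined in the hunk above.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
    prompts = [f"benchmark prompt {i}" for i in range(2500)]

    # Tokenizes in chunks of 1000 and returns per-prompt lengths plus token-ID lists.
    prompt_lengths, prompt_token_ids = batch_tokenize_prompts(
        prompts, tokenizer, batch_size=1000, progress_name="demo prompts")

    assert len(prompt_lengths) == len(prompts)
    assert prompt_lengths[0] == len(prompt_token_ids[0])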
@@ -54,6 +211,7 @@ class SampleRequest:
 # -----------------------------------------------------------------------------


+@auto_time_methods("load_data", "sample")
 class BenchmarkDataset(ABC):
     DEFAULT_SEED = 0
     IS_MULTIMODAL = False
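A self-contained sketch of what the decorator above does for this class hierarchy: any subclass that defines load_data or sample gets those methods wrapped automatically through __init_subclass__. The toy classes below are illustrative only; the real code decorates BenchmarkDataset itself, and auto_time_methods is assumed to be in scope from this module:

    @auto_time_methods("load_data", "sample")
    class ToyDataset:

        def load_data(self):
            return ["a", "b"]


    class ToySubclassDataset(ToyDataset):

        def sample(self, n):
            return list(range(n))


    ToySubclassDataset().load_data()  # prints ToySubclassDataset.load_data() started/completed
    ToySubclassDataset().sample(3)    # the subclass method is wrapped via __init_subclass__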
@@ -256,23 +414,25 @@ class RandomDataset(BenchmarkDataset):
         # Shuffle the dataset.
         random.shuffle(dataset)

+        # Batch tokenize all prompts first for efficiency
+        prompt_lengths, prompt_token_ids = batch_tokenize_prompts(
+            dataset, tokenizer, progress_name="random dataset prompts")
+
         # Filter out sequences that are too long or too short
         requests = []
-        for prompt in dataset:
+        for prompt, initial_prompt_len, cached_token_ids in zip(
+                dataset, prompt_lengths, prompt_token_ids):
             i = len(requests)
             if i == num_requests:
                 break

-            # Tokenize the prompts and completions.
-            prompt_token_ids = tokenizer.encode(prompt)
-            prompt_len = len(prompt_token_ids)
-
             # Skip empty prompt
-            if prompt_len == 0:
+            if initial_prompt_len == 0:
                 continue

-            if prompt_len > input_lens[i]:
-                input_ids = prompt_token_ids[:input_lens[i]]
+            if initial_prompt_len > input_lens[i]:
+                # Use cached token IDs to avoid re-encoding
+                input_ids = cached_token_ids[:input_lens[i]]
             else:
                 # Re-calculate the prompt length to exclude special tokens.
                 prompt_len = len(
@@ -281,11 +441,12 @@ class RandomDataset(BenchmarkDataset):
                     continue
                 ratio = (input_lens[i] + prompt_len) // prompt_len
                 prompt = " ".join([prompt] * ratio)
-                prompt_token_ids = tokenizer.encode(prompt)
-                while len(prompt_token_ids) < input_lens[i]:
+                prompt_token_ids_for_truncation = tokenizer.encode(prompt)
+                while len(prompt_token_ids_for_truncation) < input_lens[i]:
                     prompt += " " + prompt
-                    prompt_token_ids = tokenizer.encode(prompt)
-                input_ids = prompt_token_ids[:input_lens[i]]
+                    prompt_token_ids_for_truncation = tokenizer.encode(
+                        prompt)
+                input_ids = prompt_token_ids_for_truncation[:input_lens[i]]

             prompt = prefix_token_ids + input_ids

@@ -363,20 +524,36 @@ class CustomDataset(BenchmarkDataset):

     def sample(self, tokenizer: PreTrainedTokenizerBase,
                num_requests: int) -> list[SampleRequest]:
-        samples: list = []
-        for entry in self.data:
-            if len(samples) >= num_requests:
+        """
+        Optimized version using batch tokenization for better performance.
+        """
+        # Collect all prompts and metadata
+        prompts = []
+        max_tokens_list = []
+
+        for i, entry in enumerate(self.data):
+            if len(prompts) >= num_requests:
                 break
             prompt = entry["input"]["messages"][1]["content"]
-            prompt_ids = tokenizer(prompt).input_ids
-            prompt_len = len(prompt_ids)
             max_tokens = entry["input"]["max_tokens"]
+            prompts.append(prompt)
+            max_tokens_list.append(max_tokens)
+
+        # Use batch tokenization utility
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="custom dataset prompts")
+
+        # Create SampleRequest objects
+        samples = []
+        for prompt, prompt_len, max_tokens in zip(prompts, prompt_lengths,
+                                                   max_tokens_list):
             samples.append(
                 SampleRequest(
                     prompt=prompt,
                     prompt_len=prompt_len,
                     expected_output_len=max_tokens,
                 ))

         return samples

@@ -428,33 +605,47 @@ class ShareGPTDataset(BenchmarkDataset):
         enable_multimodal_chat: bool = False,
         **kwargs,
     ) -> list:
-        samples: list = []
+        if enable_multimodal_chat:
+            raise NotImplementedError
+
+        # Collect prompts and completions for batch processing
+        prompts = []
+        completions = []
+
         for entry in self.data:
-            if len(samples) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
             prompt, completion = (
                 entry["conversations"][0]["value"],
                 entry["conversations"][1]["value"],
             )
+            prompts.append(prompt)
+            completions.append(completion)

-            prompt_ids = tokenizer(prompt).input_ids
-            completion_ids = tokenizer(completion).input_ids
-            prompt_len = len(prompt_ids)
-            new_output_len = (len(completion_ids)
-                              if output_len is None else output_len)
+        # Batch tokenize prompts and completions
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="ShareGPT prompts")
+        completion_lengths, _ = batch_tokenize_prompts(
+            completions, tokenizer, progress_name="ShareGPT completions")
+
+        # Filter and create samples
+        samples: list = []
+        for prompt, completion, prompt_len, completion_len in zip(
+                prompts, completions, prompt_lengths, completion_lengths):
+            new_output_len = completion_len if output_len is None else output_len
             if not is_valid_sequence(prompt_len,
                                      new_output_len,
                                      skip_min_output_len_check=output_len
                                      is not None):
                 continue
-            if enable_multimodal_chat:
-                raise NotImplementedError

             samples.append(
                 SampleRequest(
                     prompt=prompt,
                     prompt_len=prompt_len,
                     expected_output_len=new_output_len,
                 ))

         self.maybe_oversample_requests(samples, num_requests)
         return samples

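The samplers below repeat the same restructuring: collect the prompt/completion pairs, batch tokenize each list once, then zip the lengths back while filtering. Distilled into a hypothetical helper (not part of the commit), assuming SampleRequest and batch_tokenize_prompts from this module are in scope:

    def build_requests_from_pairs(pairs, tokenizer, num_requests, output_len=None):
        """Hypothetical helper distilling the collect -> batch-tokenize -> zip pattern."""
        prompts = [p for p, _ in pairs[:num_requests]]
        completions = [c for _, c in pairs[:num_requests]]

        # One batched pass per list instead of two tokenizer calls per pair.
        prompt_lengths, _ = batch_tokenize_prompts(prompts, tokenizer)
        completion_lengths, _ = batch_tokenize_prompts(completions, tokenizer)

        samples = []
        for prompt, prompt_len, completion_len in zip(prompts, prompt_lengths,
                                                      completion_lengths):
            expected = completion_len if output_len is None else output_len
            samples.append(
                SampleRequest(prompt=prompt,
                              prompt_len=prompt_len,
                              expected_output_len=expected))
        return samples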
@@ -498,10 +689,11 @@ class SonnetDataset(BenchmarkDataset):
         return_prompt_formatted: bool = False,
         **kwargs,
     ) -> list:
-        # Calculate average token length for a poem line.
-        tokenized_lines = [tokenizer(line).input_ids for line in self.data]
-        avg_len = sum(len(tokens)
-                      for tokens in tokenized_lines) / len(tokenized_lines)
+        # Calculate average token length for poem lines using batch tokenization
+        line_lengths, _ = batch_tokenize_prompts(self.data,
+                                                  tokenizer,
+                                                  progress_name="sonnet lines")
+        avg_len = sum(line_lengths) / len(line_lengths)

         # Build the base prompt.
         base_prompt = "Pick as many lines as you can from these poem lines:\n"
@@ -658,34 +850,47 @@ class ConversationDataset(HuggingFaceDataset):
               output_len: Optional[int] = None,
               enable_multimodal_chat: bool = False,
               **kwargs) -> list:
-        # Filter examples with at least 2 conversations
+        if enable_multimodal_chat:
+            raise NotImplementedError
+
+        # Filter examples with at least 2 conversations and collect data
         filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
-        sampled_requests = []
+        prompts = []
+        completions = []
         dynamic_output = output_len is None

         for item in filtered_data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
             conv = item["conversations"]
             prompt, completion = conv[0]["value"], conv[1]["value"]
+            prompts.append(prompt)
+            completions.append(completion)

-            prompt_ids = tokenizer(prompt).input_ids
-            completion_ids = tokenizer(completion).input_ids
-            prompt_len = len(prompt_ids)
-            completion_len = len(completion_ids)
-            output_len = completion_len if dynamic_output else output_len
-            assert isinstance(output_len, int) and output_len > 0
+        # Batch tokenize prompts and completions
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="conversation prompts")
+        completion_lengths, _ = batch_tokenize_prompts(
+            completions, tokenizer, progress_name="conversation completions")
+
+        # Filter and create samples
+        sampled_requests = []
+        for prompt, completion, prompt_len, completion_len in zip(
+                prompts, completions, prompt_lengths, completion_lengths):
+            current_output_len = completion_len if dynamic_output else output_len
+            assert isinstance(current_output_len,
+                              int) and current_output_len > 0
             if dynamic_output and not is_valid_sequence(prompt_len,
                                                         completion_len):
                 continue
-            if enable_multimodal_chat:
-                raise NotImplementedError

             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
                     prompt_len=prompt_len,
-                    expected_output_len=output_len,
+                    expected_output_len=current_output_len,
                 ))

         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests

@@ -717,20 +922,31 @@ class VisionArenaDataset(HuggingFaceDataset):
         enable_multimodal_chat: bool = False,
         **kwargs,
     ) -> list:
+        if enable_multimodal_chat:
+            raise NotImplementedError
+
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
-        sampled_requests = []
+
+        # Collect prompts for batch processing
+        prompts = []
+        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
+        if parser_fn is None:
+            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
+
         for item in self.data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
-            parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
-            if parser_fn is None:
-                raise ValueError(
-                    f"Unsupported dataset path: {self.dataset_path}")
             prompt = parser_fn(item)
-            prompt_len = len(tokenizer(prompt).input_ids)
-            if enable_multimodal_chat:
-                raise NotImplementedError
+            prompts.append(prompt)
+
+        # Batch tokenize prompts
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="vision arena prompts")
+
+        # Create samples
+        sampled_requests = []
+        for prompt, prompt_len in zip(prompts, prompt_lengths):
             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -769,12 +985,22 @@ class InstructCoderDataset(HuggingFaceDataset):
               **kwargs) -> list:
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
-        sampled_requests = []
+
+        # Collect prompts for batch processing
+        prompts = []
         for item in self.data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
             prompt = f"{item['instruction']}:\n{item['input']}"
-            prompt_len = len(tokenizer(prompt).input_ids)
+            prompts.append(prompt)
+
+        # Batch tokenize prompts
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="instruct coder prompts")
+
+        # Create samples
+        sampled_requests = []
+        for prompt, prompt_len in zip(prompts, prompt_lengths):
             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -813,22 +1039,31 @@ class MTBenchDataset(HuggingFaceDataset):
               **kwargs) -> list:
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
-        sampled_requests = []
+
+        # Collect prompts for batch processing
+        prompts = []
         for item in self.data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
-            prompt = item['turns'][0]
+            raw_prompt = item['turns'][0]

             # apply template
-            prompt = tokenizer.apply_chat_template([{
-                "role": "user",
-                "content": prompt
-            }],
-                                                   add_generation_prompt=True,
-                                                   tokenize=False)
+            formatted_prompt = tokenizer.apply_chat_template(
+                [{
+                    "role": "user",
+                    "content": raw_prompt
+                }],
+                add_generation_prompt=True,
+                tokenize=False)
+            prompts.append(formatted_prompt)

-            prompt_len = len(tokenizer(prompt).input_ids)
+        # Batch tokenize prompts
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="MT-Bench prompts")
+
+        # Create samples
+        sampled_requests = []
+        for prompt, prompt_len in zip(prompts, prompt_lengths):
             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -858,20 +1093,32 @@ class AIMODataset(HuggingFaceDataset):
               num_requests: int,
               output_len: Optional[int] = None,
               **kwargs) -> list:
-        sampled_requests = []
         dynamic_output = output_len is None

+        # Collect prompts and completions for batch processing
+        prompts = []
+        completions = []
         for item in self.data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
             prompt, completion = item['problem'], item["solution"]
+            prompts.append(prompt)
+            completions.append(completion)

-            prompt_ids = tokenizer(prompt).input_ids
-            completion_ids = tokenizer(completion).input_ids
-            prompt_len = len(prompt_ids)
-            completion_len = len(completion_ids)
-            output_len = completion_len if dynamic_output else output_len
-            assert isinstance(output_len, int) and output_len > 0
+        # Batch tokenize prompts and completions
+        prompt_lengths, _ = batch_tokenize_prompts(prompts,
+                                                    tokenizer,
+                                                    progress_name="AIMO prompts")
+        completion_lengths, _ = batch_tokenize_prompts(
+            completions, tokenizer, progress_name="AIMO completions")
+
+        # Filter and create samples
+        sampled_requests = []
+        for prompt, completion, prompt_len, completion_len in zip(
+                prompts, completions, prompt_lengths, completion_lengths):
+            current_output_len = completion_len if dynamic_output else output_len
+            assert isinstance(current_output_len,
+                              int) and current_output_len > 0
             if dynamic_output and not is_valid_sequence(prompt_len,
                                                         completion_len,
                                                         max_prompt_len=2048,
@@ -881,7 +1128,7 @@ class AIMODataset(HuggingFaceDataset):
                 SampleRequest(
                     prompt=prompt,
                     prompt_len=prompt_len,
-                    expected_output_len=output_len,
+                    expected_output_len=current_output_len,
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests