[None][feat] improve dataloading for benchmark_dataset by using batch… (#6548)

Signed-off-by: Zero Zeng <38289304+zerollzeng@users.noreply.github.com>
Zero Zeng 2025-08-11 09:50:41 +08:00 committed by GitHub
parent 60073a7ad9
commit 4b4b91ab51


@@ -20,6 +20,7 @@ SampleRequest instances, similar to the approach used in ShareGPT.
import json
import logging
import random
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
@@ -38,6 +39,162 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def timing_decorator(method_name: str):
"""
Decorator to time method execution and print the results.
Args:
method_name: Name to display in timing output (e.g., 'load_data', 'sample')
"""
def decorator(func):
def wrapper(self, *args, **kwargs):
dataset_name = self.__class__.__name__
start_time = time.perf_counter()
print(f"{dataset_name}.{method_name}() started...")
try:
result = func(self, *args, **kwargs)
end_time = time.perf_counter()
duration = end_time - start_time
print(
f"{dataset_name}.{method_name}() completed in {duration:.4f} seconds"
)
return result
except Exception as e:
end_time = time.perf_counter()
duration = end_time - start_time
print(
f"{dataset_name}.{method_name}() failed after {duration:.4f} seconds: {str(e)}"
)
raise
return wrapper
return decorator
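A minimal sketch of applying timing_decorator by hand to a single method; the TinyDataset class and its data below are illustrative assumptions, not part of this commit:

# Illustration only: a hypothetical class whose method is timed directly.
class TinyDataset:
    def __init__(self, items):
        self.items = items

    @timing_decorator("load_data")
    def load_data(self):
        # Any per-instance work; the wrapper prints start, duration, and failures.
        return [item.upper() for item in self.items]

TinyDataset(["a", "b", "c"]).load_data()
# Prints roughly:
#   TinyDataset.load_data() started...
#   TinyDataset.load_data() completed in 0.0001 seconds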
def auto_time_methods(*method_names):
"""
Class decorator that automatically applies timing to specified methods
in the class and all its subclasses.
Usage:
@auto_time_methods("load_data", "sample")
class MyDataset(BenchmarkDataset):
def load_data(self): # Will be automatically timed
pass
def sample(self): # Will be automatically timed
pass
"""
def class_decorator(cls):
# Store the method names that should be timed
cls._timed_methods = method_names
# Override __init_subclass__ to automatically apply timing to subclasses
original_init_subclass = getattr(cls, '__init_subclass__',
lambda **kwargs: None)
@classmethod
def __init_subclass__(subcls, **kwargs):
original_init_subclass(**kwargs)
# Apply timing to the specified methods if they exist in the subclass
for method_name in method_names:
if hasattr(subcls, method_name):
original_method = getattr(subcls, method_name)
# Only wrap if not already wrapped (check for our _is_timed marker)
if not hasattr(original_method, '_is_timed'):
timed_method = timing_decorator(method_name)(
original_method)
timed_method._is_timed = True
setattr(subcls, method_name, timed_method)
cls.__init_subclass__ = __init_subclass__
# Also apply timing to methods in the current class
for method_name in method_names:
if hasattr(cls, method_name):
original_method = getattr(cls, method_name)
if not hasattr(original_method, '_is_timed'):
timed_method = timing_decorator(method_name)(
original_method)
timed_method._is_timed = True
setattr(cls, method_name, timed_method)
return cls
return class_decorator
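And a short sketch of the subclass propagation, with hypothetical TimedBase/TimedChild classes standing in for BenchmarkDataset and its subclasses:

# Illustration only: timing is applied to the decorated base class and, via
# __init_subclass__, to any subclass that defines the named methods.
@auto_time_methods("load_data", "sample")
class TimedBase:
    def load_data(self):
        return "base data"

class TimedChild(TimedBase):
    def sample(self):
        return list(range(3))

TimedChild().load_data()   # prints TimedChild.load_data() started/completed
TimedChild().sample()      # prints TimedChild.sample() started/completed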
def batch_tokenize_prompts(
prompts: list[str],
tokenizer: PreTrainedTokenizerBase,
batch_size: int = 1000,
progress_name: str = "prompts") -> tuple[list[int], list[list[int]]]:
"""
Efficiently tokenize a list of prompts using batch processing.
Args:
prompts: List of text prompts to tokenize
tokenizer: The tokenizer to use
batch_size: Number of prompts to process in each batch
progress_name: Name to show in progress messages
Returns:
Tuple of (prompt_lengths, prompt_token_ids) where:
- prompt_lengths: List of prompt lengths (number of tokens per prompt)
- prompt_token_ids: List of token ID lists for each prompt
"""
import time
if not prompts:
return [], []
print(
f"Batch tokenizing {len(prompts)} {progress_name} (batch_size={batch_size})..."
)
prompt_lengths = []
prompt_token_ids = []
total_time = 0
for i in range(0, len(prompts), batch_size):
batch_prompts = prompts[i:i + batch_size]
# Batch tokenization
start_time = time.perf_counter()
batch_encoded = tokenizer(batch_prompts,
padding=False,
truncation=False)
batch_time = time.perf_counter() - start_time
total_time += batch_time
# Extract lengths and token IDs
for j in range(len(batch_prompts)):
token_ids = batch_encoded.input_ids[j]
prompt_lengths.append(len(token_ids))
prompt_token_ids.append(token_ids)
# Progress reporting
if (i + batch_size) % 5000 == 0 or (i + batch_size) >= len(prompts):
processed = min(i + batch_size, len(prompts))
avg_time = total_time / processed * 1000
print(
f" Processed {processed}/{len(prompts)} {progress_name} - Avg: {avg_time:.2f}ms per item"
)
avg_time_per_prompt = total_time / len(prompts) * 1000
print(
f"Batch tokenization completed: {total_time:.4f}s total ({avg_time_per_prompt:.2f}ms per {progress_name[:-1]})"
)
return prompt_lengths, prompt_token_ids
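A minimal usage sketch for batch_tokenize_prompts, assuming transformers is installed; the gpt2 checkpoint is only an example model choice, not something this commit prescribes:

# Illustration only: tokenize a couple of prompts in one batched call.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")   # any tokenizer works; gpt2 is an assumption
texts = ["Hello world", "Batch tokenization amortizes per-call overhead."]
lengths, token_ids = batch_tokenize_prompts(texts, tok, batch_size=2,
                                            progress_name="demo prompts")
print(lengths)        # one token count per prompt
print(token_ids[0])   # token IDs for the first prompt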
@dataclass
class SampleRequest:
"""
@@ -54,6 +211,7 @@ class SampleRequest:
# -----------------------------------------------------------------------------
@auto_time_methods("load_data", "sample")
class BenchmarkDataset(ABC):
DEFAULT_SEED = 0
IS_MULTIMODAL = False
@@ -256,23 +414,25 @@ class RandomDataset(BenchmarkDataset):
# Shuffle the dataset.
random.shuffle(dataset)
# Batch tokenize all prompts first for efficiency
prompt_lengths, prompt_token_ids = batch_tokenize_prompts(
dataset, tokenizer, progress_name="random dataset prompts")
# Filter out sequences that are too long or too short
requests = []
for prompt in dataset:
for prompt, initial_prompt_len, cached_token_ids in zip(
dataset, prompt_lengths, prompt_token_ids):
i = len(requests)
if i == num_requests:
break
# Tokenize the prompts and completions.
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
# Skip empty prompt
if prompt_len == 0:
if initial_prompt_len == 0:
continue
if prompt_len > input_lens[i]:
input_ids = prompt_token_ids[:input_lens[i]]
if initial_prompt_len > input_lens[i]:
# Use cached token IDs to avoid re-encoding
input_ids = cached_token_ids[:input_lens[i]]
else:
# Re-calculate the prompt length to exclude special tokens.
prompt_len = len(
@@ -281,11 +441,12 @@ class RandomDataset(BenchmarkDataset):
continue
ratio = (input_lens[i] + prompt_len) // prompt_len
prompt = " ".join([prompt] * ratio)
prompt_token_ids = tokenizer.encode(prompt)
while len(prompt_token_ids) < input_lens[i]:
prompt_token_ids_for_truncation = tokenizer.encode(prompt)
while len(prompt_token_ids_for_truncation) < input_lens[i]:
prompt += " " + prompt
prompt_token_ids = tokenizer.encode(prompt)
input_ids = prompt_token_ids[:input_lens[i]]
prompt_token_ids_for_truncation = tokenizer.encode(
prompt)
input_ids = prompt_token_ids_for_truncation[:input_lens[i]]
prompt = prefix_token_ids + input_ids
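For intuition, a rough sketch comparing the old per-prompt encode loop with the single batched call now used to pre-compute the cached token IDs; the model choice and prompt list are illustrative assumptions, and the measured gap will vary by tokenizer and hardware:

# Illustration only: per-prompt encoding vs. one batched tokenizer call.
import time
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")   # example model
texts = ["a moderately long benchmark prompt"] * 2000

start = time.perf_counter()
looped = [tok.encode(t) for t in texts]       # old pattern: one call per prompt
t_loop = time.perf_counter() - start

start = time.perf_counter()
batched = tok(texts, padding=False, truncation=False).input_ids   # new pattern
t_batch = time.perf_counter() - start

print(f"per-prompt loop: {t_loop:.3f}s   batched call: {t_batch:.3f}s")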
@@ -363,20 +524,36 @@ class CustomDataset(BenchmarkDataset):
def sample(self, tokenizer: PreTrainedTokenizerBase,
num_requests: int) -> list[SampleRequest]:
samples: list = []
for entry in self.data:
if len(samples) >= num_requests:
"""
Optimized version using batch tokenization for better performance.
"""
# Collect all prompts and metadata
prompts = []
max_tokens_list = []
for i, entry in enumerate(self.data):
if len(prompts) >= num_requests:
break
prompt = entry["input"]["messages"][1]["content"]
prompt_ids = tokenizer(prompt).input_ids
prompt_len = len(prompt_ids)
max_tokens = entry["input"]["max_tokens"]
prompts.append(prompt)
max_tokens_list.append(max_tokens)
# Use batch tokenization utility
prompt_lengths, _ = batch_tokenize_prompts(
prompts, tokenizer, progress_name="custom dataset prompts")
# Create SampleRequest objects
samples = []
for prompt, prompt_len, max_tokens in zip(prompts, prompt_lengths,
max_tokens_list):
samples.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=max_tokens,
))
return samples
@@ -428,33 +605,47 @@ class ShareGPTDataset(BenchmarkDataset):
enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
samples: list = []
if enable_multimodal_chat:
raise NotImplementedError
# Collect prompts and completions for batch processing
prompts = []
completions = []
for entry in self.data:
if len(samples) >= num_requests:
if len(prompts) >= num_requests:
break
prompt, completion = (
entry["conversations"][0]["value"],
entry["conversations"][1]["value"],
)
prompts.append(prompt)
completions.append(completion)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
new_output_len = (len(completion_ids)
if output_len is None else output_len)
# Batch tokenize prompts and completions
prompt_lengths, _ = batch_tokenize_prompts(
prompts, tokenizer, progress_name="ShareGPT prompts")
completion_lengths, _ = batch_tokenize_prompts(
completions, tokenizer, progress_name="ShareGPT completions")
# Filter and create samples
samples: list = []
for prompt, completion, prompt_len, completion_len in zip(
prompts, completions, prompt_lengths, completion_lengths):
new_output_len = completion_len if output_len is None else output_len
if not is_valid_sequence(prompt_len,
new_output_len,
skip_min_output_len_check=output_len
is not None):
continue
if enable_multimodal_chat:
raise NotImplementedError
samples.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=new_output_len,
))
self.maybe_oversample_requests(samples, num_requests)
return samples
@@ -498,10 +689,11 @@ class SonnetDataset(BenchmarkDataset):
return_prompt_formatted: bool = False,
**kwargs,
) -> list:
# Calculate average token length for a poem line.
tokenized_lines = [tokenizer(line).input_ids for line in self.data]
avg_len = sum(len(tokens)
for tokens in tokenized_lines) / len(tokenized_lines)
# Calculate average token length for poem lines using batch tokenization
line_lengths, _ = batch_tokenize_prompts(self.data,
tokenizer,
progress_name="sonnet lines")
avg_len = sum(line_lengths) / len(line_lengths)
# Build the base prompt.
base_prompt = "Pick as many lines as you can from these poem lines:\n"
@@ -658,34 +850,47 @@ class ConversationDataset(HuggingFaceDataset):
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
# Filter examples with at least 2 conversations
if enable_multimodal_chat:
raise NotImplementedError
# Filter examples with at least 2 conversations and collect data
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
sampled_requests = []
prompts = []
completions = []
dynamic_output = output_len is None
for item in filtered_data:
if len(sampled_requests) >= num_requests:
if len(prompts) >= num_requests:
break
conv = item["conversations"]
prompt, completion = conv[0]["value"], conv[1]["value"]
prompts.append(prompt)
completions.append(completion)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0
# Batch tokenize prompts and completions
prompt_lengths, _ = batch_tokenize_prompts(
prompts, tokenizer, progress_name="conversation prompts")
completion_lengths, _ = batch_tokenize_prompts(
completions, tokenizer, progress_name="conversation completions")
# Filter and create samples
sampled_requests = []
for prompt, completion, prompt_len, completion_len in zip(
prompts, completions, prompt_lengths, completion_lengths):
current_output_len = completion_len if dynamic_output else output_len
assert isinstance(current_output_len,
int) and current_output_len > 0
if dynamic_output and not is_valid_sequence(prompt_len,
completion_len):
continue
if enable_multimodal_chat:
raise NotImplementedError
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
expected_output_len=current_output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@@ -717,20 +922,31 @@ class VisionArenaDataset(HuggingFaceDataset):
enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
if enable_multimodal_chat:
raise NotImplementedError
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
# Collect prompts for batch processing
prompts = []
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
if parser_fn is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
for item in self.data:
if len(sampled_requests) >= num_requests:
if len(prompts) >= num_requests:
break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
if parser_fn is None:
raise ValueError(
f"Unsupported dataset path: {self.dataset_path}")
prompt = parser_fn(item)
prompt_len = len(tokenizer(prompt).input_ids)
if enable_multimodal_chat:
raise NotImplementedError
prompts.append(prompt)
# Batch tokenize prompts
prompt_lengths, _ = batch_tokenize_prompts(
prompts, tokenizer, progress_name="vision arena prompts")
# Create samples
sampled_requests = []
for prompt, prompt_len in zip(prompts, prompt_lengths):
sampled_requests.append(
SampleRequest(
prompt=prompt,
@@ -769,12 +985,22 @@ class InstructCoderDataset(HuggingFaceDataset):
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
# Collect prompts for batch processing
prompts = []
for item in self.data:
if len(sampled_requests) >= num_requests:
if len(prompts) >= num_requests:
break
prompt = f"{item['instruction']}:\n{item['input']}"
prompt_len = len(tokenizer(prompt).input_ids)
prompts.append(prompt)
# Batch tokenize prompts
prompt_lengths, _ = batch_tokenize_prompts(
prompts, tokenizer, progress_name="instruct coder prompts")
# Create samples
sampled_requests = []
for prompt, prompt_len in zip(prompts, prompt_lengths):
sampled_requests.append(
SampleRequest(
prompt=prompt,
@@ -813,22 +1039,31 @@ class MTBenchDataset(HuggingFaceDataset):
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
# Collect prompts for batch processing
prompts = []
for item in self.data:
if len(sampled_requests) >= num_requests:
if len(prompts) >= num_requests:
break
prompt = item['turns'][0]
raw_prompt = item['turns'][0]
# apply template
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
formatted_prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": raw_prompt
}],
add_generation_prompt=True,
tokenize=False)
prompts.append(formatted_prompt)
prompt_len = len(tokenizer(prompt).input_ids)
# Batch tokenize prompts
prompt_lengths, _ = batch_tokenize_prompts(
prompts, tokenizer, progress_name="MT-Bench prompts")
# Create samples
sampled_requests = []
for prompt, prompt_len in zip(prompts, prompt_lengths):
sampled_requests.append(
SampleRequest(
prompt=prompt,
@@ -858,20 +1093,32 @@ class AIMODataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
**kwargs) -> list:
sampled_requests = []
dynamic_output = output_len is None
# Collect prompts and completions for batch processing
prompts = []
completions = []
for item in self.data:
if len(sampled_requests) >= num_requests:
if len(prompts) >= num_requests:
break
prompt, completion = item['problem'], item["solution"]
prompts.append(prompt)
completions.append(completion)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0
# Batch tokenize prompts and completions
prompt_lengths, _ = batch_tokenize_prompts(prompts,
tokenizer,
progress_name="AIMO prompts")
completion_lengths, _ = batch_tokenize_prompts(
completions, tokenizer, progress_name="AIMO completions")
# Filter and create samples
sampled_requests = []
for prompt, completion, prompt_len, completion_len in zip(
prompts, completions, prompt_lengths, completion_lengths):
current_output_len = completion_len if dynamic_output else output_len
assert isinstance(current_output_len,
int) and current_output_len > 0
if dynamic_output and not is_valid_sequence(prompt_len,
completion_len,
max_prompt_len=2048,
@@ -881,7 +1128,7 @@ class AIMODataset(HuggingFaceDataset):
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
expected_output_len=current_output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests