diff --git a/tensorrt_llm/serve/scripts/benchmark_dataset.py b/tensorrt_llm/serve/scripts/benchmark_dataset.py
index 35d2744aea..485e6354ef 100644
--- a/tensorrt_llm/serve/scripts/benchmark_dataset.py
+++ b/tensorrt_llm/serve/scripts/benchmark_dataset.py
@@ -20,6 +20,7 @@ SampleRequest instances, similar to the approach used in ShareGPT.
 import json
 import logging
 import random
+import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any, Callable, Optional, Union
@@ -38,6 +39,162 @@ logger = logging.getLogger(__name__)
 # -----------------------------------------------------------------------------
 
 
+def timing_decorator(method_name: str):
+    """
+    Decorator to time method execution and print the results.
+
+    Args:
+        method_name: Name to display in timing output (e.g., 'load_data', 'sample')
+    """
+
+    def decorator(func):
+
+        def wrapper(self, *args, **kwargs):
+            dataset_name = self.__class__.__name__
+            start_time = time.perf_counter()
+            print(f"{dataset_name}.{method_name}() started...")
+
+            try:
+                result = func(self, *args, **kwargs)
+                end_time = time.perf_counter()
+                duration = end_time - start_time
+                print(
+                    f"{dataset_name}.{method_name}() completed in {duration:.4f} seconds"
+                )
+                return result
+            except Exception as e:
+                end_time = time.perf_counter()
+                duration = end_time - start_time
+                print(
+                    f"{dataset_name}.{method_name}() failed after {duration:.4f} seconds: {str(e)}"
+                )
+                raise
+
+        return wrapper
+
+    return decorator
+
+
+def auto_time_methods(*method_names):
+    """
+    Class decorator that automatically applies timing to specified methods
+    in the class and all its subclasses.
+
+    Usage:
+        @auto_time_methods("load_data", "sample")
+        class MyDataset(BenchmarkDataset):
+            def load_data(self):  # Will be automatically timed
+                pass
+            def sample(self):  # Will be automatically timed
+                pass
+    """
+
+    def class_decorator(cls):
+        # Store the method names that should be timed
+        cls._timed_methods = method_names
+
+        # Override __init_subclass__ to automatically apply timing to subclasses
+        original_init_subclass = getattr(cls, '__init_subclass__',
+                                         lambda **kwargs: None)
+
+        @classmethod
+        def __init_subclass__(subcls, **kwargs):
+            original_init_subclass(**kwargs)
+
+            # Apply timing to the specified methods if they exist in the subclass
+            for method_name in method_names:
+                if hasattr(subcls, method_name):
+                    original_method = getattr(subcls, method_name)
+
+                    # Only wrap if not already wrapped (check for our wrapper's signature)
+                    if not hasattr(original_method, '_is_timed'):
+                        timed_method = timing_decorator(method_name)(
+                            original_method)
+                        timed_method._is_timed = True
+                        setattr(subcls, method_name, timed_method)
+
+        cls.__init_subclass__ = __init_subclass__
+
+        # Also apply timing to methods in the current class
+        for method_name in method_names:
+            if hasattr(cls, method_name):
+                original_method = getattr(cls, method_name)
+                if not hasattr(original_method, '_is_timed'):
+                    timed_method = timing_decorator(method_name)(
+                        original_method)
+                    timed_method._is_timed = True
+                    setattr(cls, method_name, timed_method)
+
+        return cls
+
+    return class_decorator
+
+
+def batch_tokenize_prompts(
+        prompts: list[str],
+        tokenizer: PreTrainedTokenizerBase,
+        batch_size: int = 1000,
+        progress_name: str = "prompts") -> tuple[list[int], list[list[int]]]:
+    """
+    Efficiently tokenize a list of prompts using batch processing.
+
+    Args:
+        prompts: List of text prompts to tokenize
+        tokenizer: The tokenizer to use
+        batch_size: Number of prompts to process in each batch
+        progress_name: Name to show in progress messages
+
+    Returns:
+        Tuple of (prompt_lengths, prompt_token_ids) where:
+        - prompt_lengths: List of prompt lengths (number of tokens per prompt)
+        - prompt_token_ids: List of token ID lists for each prompt
+    """
+    import time
+
+    if not prompts:
+        return [], []
+
+    print(
+        f"Batch tokenizing {len(prompts)} {progress_name} (batch_size={batch_size})..."
+    )
+
+    prompt_lengths = []
+    prompt_token_ids = []
+    total_time = 0
+
+    for i in range(0, len(prompts), batch_size):
+        batch_prompts = prompts[i:i + batch_size]
+
+        # Batch tokenization
+        start_time = time.perf_counter()
+        batch_encoded = tokenizer(batch_prompts,
+                                  padding=False,
+                                  truncation=False)
+        batch_time = time.perf_counter() - start_time
+        total_time += batch_time
+
+        # Extract lengths and token IDs
+        for j in range(len(batch_prompts)):
+            token_ids = batch_encoded.input_ids[j]
+            prompt_lengths.append(len(token_ids))
+            prompt_token_ids.append(token_ids)
+
+        # Progress reporting
+        if (i + batch_size) % 5000 == 0 or (i + batch_size) >= len(prompts):
+            processed = min(i + batch_size, len(prompts))
+            avg_time = total_time / processed * 1000
+            print(
+                f"  Processed {processed}/{len(prompts)} {progress_name} - Avg: {avg_time:.2f}ms per item"
+            )
+
+    avg_time_per_prompt = total_time / len(prompts) * 1000
+    print(
+        f"Batch tokenization completed: {total_time:.4f}s total ({avg_time_per_prompt:.2f}ms per {progress_name[:-1]})"
+    )
+
+    return prompt_lengths, prompt_token_ids
+
+
 @dataclass
 class SampleRequest:
     """
@@ -54,6 +211,7 @@ class SampleRequest:
 
 
 # -----------------------------------------------------------------------------
+@auto_time_methods("load_data", "sample")
 class BenchmarkDataset(ABC):
     DEFAULT_SEED = 0
     IS_MULTIMODAL = False
@@ -256,23 +414,25 @@ class RandomDataset(BenchmarkDataset):
         # Shuffle the dataset.
         random.shuffle(dataset)
 
+        # Batch tokenize all prompts first for efficiency
+        prompt_lengths, prompt_token_ids = batch_tokenize_prompts(
+            dataset, tokenizer, progress_name="random dataset prompts")
+
         # Filter out sequences that are too long or too short
         requests = []
-        for prompt in dataset:
+        for prompt, initial_prompt_len, cached_token_ids in zip(
+                dataset, prompt_lengths, prompt_token_ids):
             i = len(requests)
             if i == num_requests:
                 break
 
-            # Tokenize the prompts and completions.
-            prompt_token_ids = tokenizer.encode(prompt)
-            prompt_len = len(prompt_token_ids)
-
             # Skip empty prompt
-            if prompt_len == 0:
+            if initial_prompt_len == 0:
                 continue
 
-            if prompt_len > input_lens[i]:
-                input_ids = prompt_token_ids[:input_lens[i]]
+            if initial_prompt_len > input_lens[i]:
+                # Use cached token IDs to avoid re-encoding
+                input_ids = cached_token_ids[:input_lens[i]]
             else:
                 # Re-calculate the prompt length to exclude special tokens.
                 prompt_len = len(
@@ -281,11 +441,12 @@ class RandomDataset(BenchmarkDataset):
                     continue
                 ratio = (input_lens[i] + prompt_len) // prompt_len
                 prompt = " ".join([prompt] * ratio)
-                prompt_token_ids = tokenizer.encode(prompt)
-                while len(prompt_token_ids) < input_lens[i]:
+                prompt_token_ids_for_truncation = tokenizer.encode(prompt)
+                while len(prompt_token_ids_for_truncation) < input_lens[i]:
                     prompt += " " + prompt
-                    prompt_token_ids = tokenizer.encode(prompt)
-                input_ids = prompt_token_ids[:input_lens[i]]
+                    prompt_token_ids_for_truncation = tokenizer.encode(
+                        prompt)
+                input_ids = prompt_token_ids_for_truncation[:input_lens[i]]
 
                 prompt = prefix_token_ids + input_ids
@@ -363,20 +524,36 @@ class CustomDataset(BenchmarkDataset):
 
     def sample(self, tokenizer: PreTrainedTokenizerBase,
                num_requests: int) -> list[SampleRequest]:
-        samples: list = []
-        for entry in self.data:
-            if len(samples) >= num_requests:
+        """
+        Optimized version using batch tokenization for better performance.
+        """
+        # Collect all prompts and metadata
+        prompts = []
+        max_tokens_list = []
+
+        for i, entry in enumerate(self.data):
+            if len(prompts) >= num_requests:
                 break
             prompt = entry["input"]["messages"][1]["content"]
-            prompt_ids = tokenizer(prompt).input_ids
-            prompt_len = len(prompt_ids)
             max_tokens = entry["input"]["max_tokens"]
+            prompts.append(prompt)
+            max_tokens_list.append(max_tokens)
+
+        # Use batch tokenization utility
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="custom dataset prompts")
+
+        # Create SampleRequest objects
+        samples = []
+        for prompt, prompt_len, max_tokens in zip(prompts, prompt_lengths,
+                                                  max_tokens_list):
             samples.append(
                 SampleRequest(
                     prompt=prompt,
                     prompt_len=prompt_len,
                     expected_output_len=max_tokens,
                 ))
+
         return samples
@@ -428,33 +605,47 @@ class ShareGPTDataset(BenchmarkDataset):
         enable_multimodal_chat: bool = False,
         **kwargs,
     ) -> list:
-        samples: list = []
+        if enable_multimodal_chat:
+            raise NotImplementedError
+
+        # Collect prompts and completions for batch processing
+        prompts = []
+        completions = []
+
         for entry in self.data:
-            if len(samples) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
             prompt, completion = (
                 entry["conversations"][0]["value"],
                 entry["conversations"][1]["value"],
             )
+            prompts.append(prompt)
+            completions.append(completion)
 
-            prompt_ids = tokenizer(prompt).input_ids
-            completion_ids = tokenizer(completion).input_ids
-            prompt_len = len(prompt_ids)
-            new_output_len = (len(completion_ids)
-                              if output_len is None else output_len)
+        # Batch tokenize prompts and completions
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="ShareGPT prompts")
+        completion_lengths, _ = batch_tokenize_prompts(
+            completions, tokenizer, progress_name="ShareGPT completions")
+
+        # Filter and create samples
+        samples: list = []
+        for prompt, completion, prompt_len, completion_len in zip(
+                prompts, completions, prompt_lengths, completion_lengths):
+            new_output_len = completion_len if output_len is None else output_len
             if not is_valid_sequence(prompt_len,
                                      new_output_len,
                                      skip_min_output_len_check=output_len
                                      is not None):
                 continue
-            if enable_multimodal_chat:
-                raise NotImplementedError
+
             samples.append(
                 SampleRequest(
                     prompt=prompt,
                     prompt_len=prompt_len,
                     expected_output_len=new_output_len,
                 ))
+
         self.maybe_oversample_requests(samples, num_requests)
         return samples
@@ -498,10 +689,11 @@ class SonnetDataset(BenchmarkDataset):
         return_prompt_formatted: bool = False,
         **kwargs,
     ) -> list:
-        # Calculate average token length for a poem line.
-        tokenized_lines = [tokenizer(line).input_ids for line in self.data]
-        avg_len = sum(len(tokens)
-                      for tokens in tokenized_lines) / len(tokenized_lines)
+        # Calculate average token length for poem lines using batch tokenization
+        line_lengths, _ = batch_tokenize_prompts(self.data,
+                                                 tokenizer,
+                                                 progress_name="sonnet lines")
+        avg_len = sum(line_lengths) / len(line_lengths)
 
         # Build the base prompt.
         base_prompt = "Pick as many lines as you can from these poem lines:\n"
@@ -658,34 +850,47 @@ class ConversationDataset(HuggingFaceDataset):
                output_len: Optional[int] = None,
                enable_multimodal_chat: bool = False,
                **kwargs) -> list:
-        # Filter examples with at least 2 conversations
+        if enable_multimodal_chat:
+            raise NotImplementedError
+
+        # Filter examples with at least 2 conversations and collect data
         filtered_data = self.data.filter(
             lambda x: len(x["conversations"]) >= 2)
-        sampled_requests = []
+        prompts = []
+        completions = []
         dynamic_output = output_len is None
 
         for item in filtered_data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
             conv = item["conversations"]
            prompt, completion = conv[0]["value"], conv[1]["value"]
+            prompts.append(prompt)
+            completions.append(completion)
 
-            prompt_ids = tokenizer(prompt).input_ids
-            completion_ids = tokenizer(completion).input_ids
-            prompt_len = len(prompt_ids)
-            completion_len = len(completion_ids)
-            output_len = completion_len if dynamic_output else output_len
-            assert isinstance(output_len, int) and output_len > 0
+        # Batch tokenize prompts and completions
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="conversation prompts")
+        completion_lengths, _ = batch_tokenize_prompts(
+            completions, tokenizer, progress_name="conversation completions")
+
+        # Filter and create samples
+        sampled_requests = []
+        for prompt, completion, prompt_len, completion_len in zip(
+                prompts, completions, prompt_lengths, completion_lengths):
+            current_output_len = completion_len if dynamic_output else output_len
+            assert isinstance(current_output_len,
+                              int) and current_output_len > 0
            if dynamic_output and not is_valid_sequence(prompt_len,
                                                        completion_len):
                 continue
-            if enable_multimodal_chat:
-                raise NotImplementedError
+
             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
                     prompt_len=prompt_len,
-                    expected_output_len=output_len,
+                    expected_output_len=current_output_len,
                 ))
+
         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
@@ -717,20 +922,31 @@ class VisionArenaDataset(HuggingFaceDataset):
         enable_multimodal_chat: bool = False,
         **kwargs,
     ) -> list:
+        if enable_multimodal_chat:
+            raise NotImplementedError
+
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
-        sampled_requests = []
+
+        # Collect prompts for batch processing
+        prompts = []
+        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
+        if parser_fn is None:
+            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
+
         for item in self.data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
-            parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
-            if parser_fn is None:
-                raise ValueError(
-                    f"Unsupported dataset path: {self.dataset_path}")
             prompt = parser_fn(item)
-            prompt_len = len(tokenizer(prompt).input_ids)
-            if enable_multimodal_chat:
-                raise NotImplementedError
+            prompts.append(prompt)
+
+        # Batch tokenize prompts
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="vision arena prompts")
+
+        # Create samples
+        sampled_requests = []
+        for prompt, prompt_len in zip(prompts, prompt_lengths):
             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -769,12 +985,22 @@ class InstructCoderDataset(HuggingFaceDataset):
                **kwargs) -> list:
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
-        sampled_requests = []
+
+        # Collect prompts for batch processing
+        prompts = []
         for item in self.data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
             prompt = f"{item['instruction']}:\n{item['input']}"
-            prompt_len = len(tokenizer(prompt).input_ids)
+            prompts.append(prompt)
+
+        # Batch tokenize prompts
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="instruct coder prompts")
+
+        # Create samples
+        sampled_requests = []
+        for prompt, prompt_len in zip(prompts, prompt_lengths):
             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -813,22 +1039,31 @@ class MTBenchDataset(HuggingFaceDataset):
                **kwargs) -> list:
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
-        sampled_requests = []
+        # Collect prompts for batch processing
+        prompts = []
         for item in self.data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
-            prompt = item['turns'][0]
+            raw_prompt = item['turns'][0]
 
             # apply template
-            prompt = tokenizer.apply_chat_template([{
-                "role": "user",
-                "content": prompt
-            }],
-                                                   add_generation_prompt=True,
-                                                   tokenize=False)
+            formatted_prompt = tokenizer.apply_chat_template(
+                [{
+                    "role": "user",
+                    "content": raw_prompt
+                }],
+                add_generation_prompt=True,
+                tokenize=False)
+            prompts.append(formatted_prompt)
 
-            prompt_len = len(tokenizer(prompt).input_ids)
+        # Batch tokenize prompts
+        prompt_lengths, _ = batch_tokenize_prompts(
+            prompts, tokenizer, progress_name="MT-Bench prompts")
+
+        # Create samples
+        sampled_requests = []
+        for prompt, prompt_len in zip(prompts, prompt_lengths):
             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -858,20 +1093,32 @@ class AIMODataset(HuggingFaceDataset):
                num_requests: int,
                output_len: Optional[int] = None,
                **kwargs) -> list:
-        sampled_requests = []
         dynamic_output = output_len is None
+        # Collect prompts and completions for batch processing
+        prompts = []
+        completions = []
         for item in self.data:
-            if len(sampled_requests) >= num_requests:
+            if len(prompts) >= num_requests:
                 break
             prompt, completion = item['problem'], item["solution"]
+            prompts.append(prompt)
+            completions.append(completion)
 
-            prompt_ids = tokenizer(prompt).input_ids
-            completion_ids = tokenizer(completion).input_ids
-            prompt_len = len(prompt_ids)
-            completion_len = len(completion_ids)
-            output_len = completion_len if dynamic_output else output_len
-            assert isinstance(output_len, int) and output_len > 0
+        # Batch tokenize prompts and completions
+        prompt_lengths, _ = batch_tokenize_prompts(prompts,
+                                                   tokenizer,
+                                                   progress_name="AIMO prompts")
+        completion_lengths, _ = batch_tokenize_prompts(
+            completions, tokenizer, progress_name="AIMO completions")
+
+        # Filter and create samples
+        sampled_requests = []
+        for prompt, completion, prompt_len, completion_len in zip(
+                prompts, completions, prompt_lengths, completion_lengths):
+            current_output_len = completion_len if dynamic_output else output_len
+            assert isinstance(current_output_len,
+                              int) and current_output_len > 0
            if dynamic_output and not is_valid_sequence(prompt_len,
                                                        completion_len,
                                                        max_prompt_len=2048,
@@ -881,7 +1128,7 @@ class AIMODataset(HuggingFaceDataset):
                 SampleRequest(
                     prompt=prompt,
                     prompt_len=prompt_len,
-                    expected_output_len=output_len,
+                    expected_output_len=current_output_len,
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
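
A quick way to sanity-check the new helper in isolation, outside any dataset class: the sketch below is not part of the patch; the import path simply mirrors the file location, and the "gpt2" tokenizer is only an illustrative choice.

    # Standalone sketch (not part of the patch). Assumes benchmark_dataset.py is
    # importable at this path and that the "gpt2" tokenizer is available locally;
    # both are illustrative assumptions.
    from transformers import AutoTokenizer

    from tensorrt_llm.serve.scripts.benchmark_dataset import batch_tokenize_prompts

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    prompts = ["Hello world", "The quick brown fox jumps over the lazy dog"]

    # Tokenizes the list in batches; lengths[i] == len(token_ids[i]) for every prompt.
    lengths, token_ids = batch_tokenize_prompts(prompts,
                                                tokenizer,
                                                batch_size=2,
                                                progress_name="demo prompts")
    print(lengths)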
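
Likewise, a small sketch of how the @auto_time_methods class decorator propagates the timing wrapper to subclasses. ToyBase and ToyDataset are invented stand-ins for BenchmarkDataset and its subclasses; only the methods named in the decorator call get wrapped.

    # Standalone sketch (not part of the patch). ToyBase/ToyDataset are hypothetical;
    # in the patch the decorated base class is BenchmarkDataset.
    import time

    from tensorrt_llm.serve.scripts.benchmark_dataset import auto_time_methods


    @auto_time_methods("load_data", "sample")
    class ToyBase:

        def load_data(self):
            time.sleep(0.01)
            return "loaded"


    class ToyDataset(ToyBase):

        def sample(self, n):
            return list(range(n))


    ds = ToyDataset()
    ds.load_data()  # prints "ToyDataset.load_data() started..." and the elapsed time
    ds.sample(3)    # prints "ToyDataset.sample() started..." and the elapsed time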