TensorRT-LLMs/tensorrt_llm/bench/dataclasses/general.py
rakib-hasan ff3b741045
feat: adding multimodal (only image for now) support in trtllm-bench (#3490)
* feat: adding multimodal (only image for now) support in trtllm-bench

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* fix: add `trust_remote_code=True` in load_dataset() calls to maintain the v2.19.2 behavior

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* re-adding prompt_token_ids and using that for prompt_len

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* updating the datasets version in examples as well

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* api changes are not needed

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* moving datasets requirement and removing a missed api change

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* addressing review comments

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* refactoring the quickstart example

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

---------

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
2025-04-18 07:06:16 +08:00

98 lines
3.7 KiB
Python

from __future__ import annotations
from pathlib import Path
from typing import Any, List, Optional, Union
from pydantic import (AliasChoices, BaseModel, Field, computed_field,
model_validator)
from tensorrt_llm.bench.dataclasses.statistics import PercentileStats
class BenchmarkEnvironment(BaseModel):
    """Pydantic model grouping the model name and paths for a benchmark.

    NOTE(review): the field descriptions below are inferred from the field
    names and types only; confirm them against the code that builds this
    model.
    """

    # Identifier of the model (presumably a name or path -- verify).
    model: str
    # Declared without a default, so it must always be supplied, but an
    # explicit ``None`` is accepted.
    checkpoint_path: Optional[Path]
    # Filesystem location (presumably a working directory -- verify).
    workspace: Path
class InferenceRequest(BaseModel):
    """One request entry, identified by ``task_id``.

    A request carries a ``prompt``, a list of ``input_ids``, or both; a
    request with neither is rejected during validation.
    """

    task_id: int
    prompt: Optional[Union[str, Any]] = None
    output_tokens: int
    # Populated from either the "input_ids" or the "logits" key of the input.
    # NOTE(review): ``AliasChoices`` is passed through ``alias=`` here; the
    # pydantic documentation describes it for ``validation_alias=`` --
    # confirm this spelling is intentional.
    input_ids: Optional[List[int]] = Field(
        alias=AliasChoices("input_ids", "logits"))

    @model_validator(mode="after")
    def verify_prompt_and_logits(self) -> InferenceRequest:
        """Ensure the request carries a prompt or token ids.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If ``prompt`` and ``input_ids`` are both ``None``.
        """
        has_prompt = self.prompt is not None
        has_token_ids = self.input_ids is not None
        if has_prompt or has_token_ids:
            return self
        raise ValueError(
            f"Both prompt and input_ids for {self.task_id} are both None.")
class DatasetMetadata(BaseModel):
    """Length statistics for a dataset, plus a printable summary of them.

    ``isl_stats``, ``osl_stats`` and ``seq_len_stats`` supply the "Input",
    "Output" and "Seq. Length" columns of :meth:`get_summary_for_print`.
    The ``max_*`` and ``avg_*`` computed fields expose the maximum and
    average of each, truncated to ``int``.
    """

    isl_stats: PercentileStats
    osl_stats: PercentileStats
    seq_len_stats: PercentileStats
    num_requests: int
    dataset_path: Optional[Path] = None

    @computed_field
    @property
    def max_isl(self) -> int:
        """Maximum of ``isl_stats``, truncated to an integer."""
        return int(self.isl_stats.maximum)

    @computed_field
    @property
    def max_osl(self) -> int:
        """Maximum of ``osl_stats``, truncated to an integer."""
        return int(self.osl_stats.maximum)

    @computed_field
    @property
    def max_sequence_length(self) -> int:
        """Maximum of ``seq_len_stats``, truncated to an integer."""
        return int(self.seq_len_stats.maximum)

    @computed_field
    @property
    def avg_isl(self) -> int:
        """Average of ``isl_stats``, truncated to an integer."""
        return int(self.isl_stats.average)

    @computed_field
    @property
    def avg_osl(self) -> int:
        """Average of ``osl_stats``, truncated to an integer."""
        return int(self.osl_stats.average)

    @computed_field
    @property
    def avg_sequence_length(self) -> int:
        """Average of ``seq_len_stats``, truncated to an integer."""
        return int(self.seq_len_stats.average)

    def _format_number(self, value: float) -> str:
        """Render ``value`` right-aligned in a 9-character table cell.

        Args:
            value: Number to render.

        Returns:
            Scientific notation with two decimals for values of 100000 or
            more. Otherwise fixed-point text with four decimals when that
            fits in nine characters, three decimals when only that fits,
            and scientific notation as the last resort.
        """
        width = 9
        if value >= 100000:
            return f"{value:{width}.2e}"
        # The width in a format spec is only a minimum: with four decimals a
        # value such as 32768.0 renders as 10 characters and would push the
        # table columns out of alignment, so drop one decimal before falling
        # back to scientific notation.
        for decimals in (4, 3):
            text = f"{value:{width}.{decimals}f}"
            if len(text) <= width:
                return text
        return f"{value:{width}.2e}"

    def get_summary_for_print(self) -> str:
        """Build the multi-line, human-readable dataset summary.

        Returns:
            A banner with the dataset path and the number of requests,
            followed by a table with one row per statistic (MIN, MAX, AVG,
            P50, P90, P95, P99) and one column each for ``isl_stats``,
            ``osl_stats`` and ``seq_len_stats``.
        """
        form = self._format_number
        columns = (self.isl_stats, self.osl_stats, self.seq_len_stats)
        # (row label, PercentileStats attribute name) in display order.
        rows = (
            ("MIN", "minimum"),
            ("MAX", "maximum"),
            ("AVG", "average"),
            ("P50", "p50"),
            ("P90", "p90"),
            ("P95", "p95"),
            ("P99", "p99"),
        )
        table = "".join(
            f"{label}: "
            + " ".join(form(getattr(stats, attr)) for stats in columns)
            + "\n" for label, attr in rows)
        return (
            "\n===========================================================\n"
            "= DATASET DETAILS\n"
            "===========================================================\n"
            f"Dataset Path: {self.dataset_path}\n"
            f"Number of Sequences: {self.num_requests}\n"
            "\n-- Percentiles statistics ---------------------------------\n\n"
            " Input Output Seq. Length\n"
            "-----------------------------------------------------------\n"
            f"{table}"
            "===========================================================\n")