mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* feat: adding multimodal (only image for now) support in trtllm-bench Signed-off-by: Rakib Hasan <rhasan@nvidia.com> * fix: add in load_dataset() calls to maintain the v2.19.2 behavior Signed-off-by: Rakib Hasan <rhasan@nvidia.com> * re-adding prompt_token_ids and using that for prompt_len Signed-off-by: Rakib Hasan <rhasan@nvidia.com> * updating the datasets version in examples as well Signed-off-by: Rakib Hasan <rhasan@nvidia.com> * api changes are not needed Signed-off-by: Rakib Hasan <rhasan@nvidia.com> * moving datasets requirement and removing a missed api change Signed-off-by: Rakib Hasan <rhasan@nvidia.com> * addressing review comments Signed-off-by: Rakib Hasan <rhasan@nvidia.com> * refactoring the quickstart example Signed-off-by: Rakib Hasan <rhasan@nvidia.com> --------- Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
98 lines
3.7 KiB
Python
98 lines
3.7 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any, List, Optional, Union
|
|
|
|
from pydantic import (AliasChoices, BaseModel, Field, computed_field,
|
|
model_validator)
|
|
|
|
from tensorrt_llm.bench.dataclasses.statistics import PercentileStats
|
|
|
|
|
|
class BenchmarkEnvironment(BaseModel):
    """Model/checkpoint/workspace settings for a benchmark run.

    Every field is required at construction time. ``checkpoint_path`` may be
    ``None``, but since no default is declared it must still be passed
    explicitly.
    """

    # Model identifier supplied by the caller.
    model: str
    # Checkpoint location, or None when none is provided.
    checkpoint_path: Union[Path, None]
    # Workspace directory path.
    workspace: Path
|
|
|
|
|
|
class InferenceRequest(BaseModel):
    """A single benchmark request.

    A request must carry a ``prompt``, pre-tokenized ``input_ids``, or both;
    ``verify_prompt_and_logits`` rejects requests where both are missing.
    """

    # Identifier of this request.
    task_id: int
    # Prompt payload: a string or any other object (typed as Any).
    prompt: Optional[Union[str, Any]] = None
    # Output token count for this request (presumably the generation length;
    # confirm against callers).
    output_tokens: int
    # Prompt token IDs, accepted under either the "input_ids" or "logits" key.
    # Defaults to None so that a prompt-only request validates: with no
    # default, pydantic v2 makes the field required despite ``Optional``,
    # which contradicts the prompt/input_ids check in the validator below.
    input_ids: Optional[List[int]] = Field(
        default=None, alias=AliasChoices("input_ids", "logits"))

    @model_validator(mode="after")
    def verify_prompt_and_logits(self) -> InferenceRequest:
        """Ensure at least one of ``prompt`` and ``input_ids`` is provided.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: if both ``prompt`` and ``input_ids`` are None
                (pydantic surfaces this to callers as a ValidationError).
        """
        if self.prompt is None and self.input_ids is None:
            raise ValueError(
                f"Both prompt and input_ids for task {self.task_id} are None.")
        return self
|
|
|
|
|
|
class DatasetMetadata(BaseModel):
    """Aggregate statistics for a benchmark dataset.

    ``isl_stats``, ``osl_stats`` and ``seq_len_stats`` feed the "Input",
    "Output" and "Seq. Length" columns of the table produced by
    ``get_summary_for_print``; the computed fields below expose their maxima
    and averages as integers.
    """

    # Input-length statistics ("Input" column of the summary).
    isl_stats: PercentileStats
    # Output-length statistics ("Output" column of the summary).
    osl_stats: PercentileStats
    # Sequence-length statistics ("Seq. Length" column of the summary).
    seq_len_stats: PercentileStats
    # Number of requests; printed as "Number of Sequences".
    num_requests: int
    # Dataset location if known; printed verbatim (may be None).
    dataset_path: Optional[Path] = None

    @computed_field
    @property
    def max_isl(self) -> int:
        """Maximum input length, converted with ``int()`` (truncates)."""
        return int(self.isl_stats.maximum)

    @computed_field
    @property
    def max_osl(self) -> int:
        """Maximum output length, converted with ``int()`` (truncates)."""
        return int(self.osl_stats.maximum)

    @computed_field
    @property
    def max_sequence_length(self) -> int:
        """Maximum sequence length, converted with ``int()`` (truncates)."""
        return int(self.seq_len_stats.maximum)

    @computed_field
    @property
    def avg_isl(self) -> int:
        """Average input length, converted with ``int()`` (truncates)."""
        return int(self.isl_stats.average)

    @computed_field
    @property
    def avg_osl(self) -> int:
        """Average output length, converted with ``int()`` (truncates)."""
        return int(self.osl_stats.average)

    @computed_field
    @property
    def avg_sequence_length(self) -> int:
        """Average sequence length, converted with ``int()`` (truncates)."""
        return int(self.seq_len_stats.average)

    def _format_number(self, value: float) -> str:
        """Format ``value`` to fit a 9-character table column.

        Values of 100000 or more use scientific notation with two decimals.
        Smaller values use fixed-point notation with up to four decimals,
        dropping decimals one at a time until the text fits in 9 characters
        (e.g. 32768 -> "32768.000" rather than the 10-character
        "32768.0000"). The result is right-aligned to at least 9 characters.
        """
        width = 9
        if value >= 100000:
            # Scientific notation for large numbers.
            return f"{value:{width}.2e}"
        for decimals in range(4, -1, -1):
            text = f"{value:{width}.{decimals}f}"
            if len(text) <= width:
                return text
        # Only reached for negative values with nine or more integer digits.
        return f"{value:{width}.2e}"

    def get_summary_for_print(self) -> str:
        """Return the dataset statistics as a printable multi-line table.

        Rows are MIN/MAX/AVG/P50/P90/P95/P99; each cell is formatted with
        ``_format_number``.
        """
        form = self._format_number
        return (
            "\n===========================================================\n"
            "= DATASET DETAILS\n"
            "===========================================================\n"
            f"Dataset Path: {self.dataset_path}\n"
            f"Number of Sequences: {self.num_requests}\n"
            "\n-- Percentiles statistics ---------------------------------\n\n"
            " Input Output Seq. Length\n"
            "-----------------------------------------------------------\n"
            f"MIN: {form(self.isl_stats.minimum)} {form(self.osl_stats.minimum)} {form(self.seq_len_stats.minimum)}\n"
            f"MAX: {form(self.isl_stats.maximum)} {form(self.osl_stats.maximum)} {form(self.seq_len_stats.maximum)}\n"
            f"AVG: {form(self.isl_stats.average)} {form(self.osl_stats.average)} {form(self.seq_len_stats.average)}\n"
            f"P50: {form(self.isl_stats.p50)} {form(self.osl_stats.p50)} {form(self.seq_len_stats.p50)}\n"
            f"P90: {form(self.isl_stats.p90)} {form(self.osl_stats.p90)} {form(self.seq_len_stats.p90)}\n"
            f"P95: {form(self.isl_stats.p95)} {form(self.osl_stats.p95)} {form(self.seq_len_stats.p95)}\n"
            f"P99: {form(self.isl_stats.p99)} {form(self.osl_stats.p99)} {form(self.seq_len_stats.p99)}\n"
            "===========================================================\n")
|