mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
98 lines
3.7 KiB
Python
98 lines
3.7 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
|
|
from pydantic import (AliasChoices, BaseModel, Field, computed_field,
|
|
model_validator)
|
|
|
|
from tensorrt_llm.bench.dataclasses.statistics import PercentileStats
|
|
|
|
|
|
class BenchmarkEnvironment(BaseModel):
    """Settings describing where a benchmark run finds its model and workspace.

    The field declarations are unchanged; only the class layout (indentation,
    stray separator lines) was repaired so the definition parses again.
    """

    # Model identifier string.
    model: str
    # Declared without a default, so pydantic v2 treats the key as required;
    # the supplied value itself may be None.
    checkpoint_path: Optional[Path]
    # Workspace path for the run.
    workspace: Path
|
|
|
|
|
|
class InferenceRequest(BaseModel):
    """A single benchmark request, given as prompt text and/or token ids.

    At least one of ``prompt`` and ``input_ids`` must be non-None; this is
    enforced by :meth:`verify_prompt_and_logits` after field validation.
    """

    task_id: int
    prompt: Optional[str] = None
    output_tokens: int
    # Accept the token ids under either the "input_ids" or the "logits" key.
    # AliasChoices is documented for `validation_alias` (plain `alias` is
    # documented as a string), so the choices are passed there. No default is
    # given, so one of the two keys must be present in the input.
    input_ids: Optional[List[int]] = Field(
        validation_alias=AliasChoices("input_ids", "logits"))

    @model_validator(mode="after")
    def verify_prompt_and_logits(self) -> InferenceRequest:
        """Reject requests that carry neither a prompt nor token ids.

        Returns:
            The validated instance, as required of an "after" model validator.

        Raises:
            ValueError: If both ``prompt`` and ``input_ids`` are None.
        """
        if self.prompt is None and self.input_ids is None:
            raise ValueError(
                f"Both prompt and input_ids for {self.task_id} are both None.")
        return self
|
|
|
|
|
|
class DatasetMetadata(BaseModel):
    """Summary statistics for a benchmark dataset.

    Holds percentile statistics for the input, output and total sequence
    lengths (the "Input", "Output" and "Seq. Length" columns of the printed
    summary), plus the request count and an optional dataset path.
    """

    isl_stats: PercentileStats
    osl_stats: PercentileStats
    seq_len_stats: PercentileStats
    num_requests: int
    dataset_path: Optional[Path] = None

    @computed_field
    @property
    def max_isl(self) -> int:
        """Maximum of ``isl_stats``, converted with ``int()``."""
        return int(self.isl_stats.maximum)

    @computed_field
    @property
    def max_osl(self) -> int:
        """Maximum of ``osl_stats``, converted with ``int()``."""
        return int(self.osl_stats.maximum)

    @computed_field
    @property
    def max_sequence_length(self) -> int:
        """Maximum of ``seq_len_stats``, converted with ``int()``."""
        return int(self.seq_len_stats.maximum)

    @computed_field
    @property
    def avg_isl(self) -> int:
        """Average of ``isl_stats``; ``int()`` drops the fractional part."""
        return int(self.isl_stats.average)

    @computed_field
    @property
    def avg_osl(self) -> int:
        """Average of ``osl_stats``; ``int()`` drops the fractional part."""
        return int(self.osl_stats.average)

    @computed_field
    @property
    def avg_sequence_length(self) -> int:
        """Average of ``seq_len_stats``; ``int()`` drops the fractional part."""
        return int(self.seq_len_stats.average)

    def _format_number(self, value: float) -> str:
        """Format ``value`` as a right-aligned string of exactly 9 characters.

        Values below 100000 use fixed-point notation with up to four decimal
        places. A plain ``9.4f`` needs 10 or more characters once the value
        reaches 10000 (e.g. ``12345.6789``), which would push the table
        columns out of alignment, so decimals are dropped one at a time until
        the text fits. Anything that still does not fit, and every value of
        100000 or more, is rendered in scientific notation with two decimals.

        Args:
            value: The number to render.

        Returns:
            The formatted text, padded on the left to 9 characters.
        """
        width = 9
        if value < 100000:
            for decimals in range(4, -1, -1):
                text = f"{value:{width}.{decimals}f}"
                if len(text) <= width:
                    return text
        # Large magnitudes (and NaN/inf, which fail the comparison above or
        # are already short) end up here.
        return f"{value:{width}.2e}"

    def get_summary_for_print(self) -> str:
        """Build the multi-line, human-readable dataset summary.

        Returns:
            A string with the dataset path, the number of sequences and one
            row per statistic (MIN, MAX, AVG, P50, P90, P95, P99) for the
            input, output and sequence-length statistics.
        """
        form = self._format_number
        return (
            "\n===========================================================\n"
            "= DATASET DETAILS\n"
            "===========================================================\n"
            f"Dataset Path: {self.dataset_path}\n"
            f"Number of Sequences: {self.num_requests}\n"
            "\n-- Percentiles statistics ---------------------------------\n\n"
            " Input Output Seq. Length\n"
            "-----------------------------------------------------------\n"
            f"MIN: {form(self.isl_stats.minimum)} {form(self.osl_stats.minimum)} {form(self.seq_len_stats.minimum)}\n"
            f"MAX: {form(self.isl_stats.maximum)} {form(self.osl_stats.maximum)} {form(self.seq_len_stats.maximum)}\n"
            f"AVG: {form(self.isl_stats.average)} {form(self.osl_stats.average)} {form(self.seq_len_stats.average)}\n"
            f"P50: {form(self.isl_stats.p50)} {form(self.osl_stats.p50)} {form(self.seq_len_stats.p50)}\n"
            f"P90: {form(self.isl_stats.p90)} {form(self.osl_stats.p90)} {form(self.seq_len_stats.p90)}\n"
            f"P95: {form(self.isl_stats.p95)} {form(self.osl_stats.p95)} {form(self.seq_len_stats.p95)}\n"
            f"P99: {form(self.isl_stats.p99)} {form(self.osl_stats.p99)} {form(self.seq_len_stats.p99)}\n"
            "===========================================================\n")
|