TensorRT-LLMs/tensorrt_llm/commands/bench.py
Aurelien Chartier f2f197360d
[#9463][feat] Add revision option to trtllm commands (#9498)
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
2025-11-27 09:30:01 +08:00

71 lines
2.1 KiB
Python

from pathlib import Path
from typing import Optional
import click
from tensorrt_llm.bench.benchmark.low_latency import latency_command
from tensorrt_llm.bench.benchmark.throughput import throughput_command
from tensorrt_llm.bench.build.build import build_command
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.logger import logger, severity_map
# Top-level ``trtllm-bench`` command group and its shared options.
#
# NOTE: this description is deliberately a comment and not a docstring on
# ``main``: click uses a callback's docstring as the command help text, so a
# docstring here would change the ``trtllm-bench --help`` output.
@click.group(name="trtllm-bench", context_settings={'show_default': True})
@click.option(
    "--model",
    "-m",
    required=True,
    type=str,
    help="The Huggingface name of the model to benchmark.",
)
@click.option(
    "--model_path",
    required=False,
    default=None,
    type=click.Path(writable=False, readable=True, path_type=Path),
    help=
    "Path to a Huggingface checkpoint directory for loading model components.",
)
@click.option(
    "--workspace",
    "-w",
    required=False,
    type=click.Path(writable=True, readable=True, path_type=Path),
    default="/tmp",  # nosec B108
    help="The directory to store benchmarking intermediate files.",
)
@click.option(
    '--log_level',
    # click documents that Choice should be given a list or tuple, so the
    # dict view is materialized instead of being passed through directly.
    type=click.Choice(list(severity_map.keys())),
    default='info',
    help="The logging level.")
@click.option("--revision",
              type=str,
              default=None,
              help="The revision to use for the HuggingFace model "
              "(branch name, tag name, or commit id).")
@click.pass_context
def main(
    ctx: click.Context,
    model: str,
    model_path: Optional[Path],
    workspace: Path,
    log_level: str,
    revision: Optional[str],
) -> None:
    # Group callback. Parameters mirror the options declared above:
    #   ctx        -- click context injected by ``@click.pass_context``.
    #   model      -- value of ``--model`` / ``-m`` (required).
    #   model_path -- value of ``--model_path``; ``None`` when not supplied.
    #   workspace  -- value of ``--workspace`` / ``-w``.
    #   log_level  -- value of ``--log_level``, one of the ``severity_map`` keys.
    #   revision   -- value of ``--revision``; ``None`` when not supplied.
    logger.set_level(log_level)

    # Bundle the shared settings and store them on ``ctx.obj``, the context
    # slot click provides for passing state on to subcommands.
    ctx.obj = BenchmarkEnvironment(model=model,
                                   checkpoint_path=model_path,
                                   workspace=workspace,
                                   revision=revision)

    # Create the workspace where we plan to store intermediate files.
    ctx.obj.workspace.mkdir(parents=True, exist_ok=True)
# Attach the benchmark subcommands to the ``trtllm-bench`` group, in the same
# order as the original one-call-per-command form.
for _subcommand in (build_command, throughput_command, latency_command):
    main.add_command(_subcommand)

# Invoke the CLI only when this module is executed as a script.
if __name__ == "__main__":
    main()