TensorRT-LLMs/tensorrt_llm/bench/benchmark/throughput.py
Kaiyu Xie aaacc9bd68
Update TensorRT-LLM (#2562)
* Update TensorRT-LLM

---------

Co-authored-by: Starrick Liu <73152103+StarrickLiu@users.noreply.github.com>
2024-12-11 00:31:05 -08:00

213 lines
7.2 KiB
Python

from __future__ import annotations
import asyncio
import multiprocessing as mp
import os
from pathlib import Path
import click
from click_option_group import optgroup
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
# isort: off
from tensorrt_llm.bench.benchmark.utils.general import (get_executor_requests,
get_settings_from_engine
)
# isort: on
from tensorrt_llm.bench.benchmark.utils.multiproc import ThroughputBenchmark
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.enums import IFBSchedulingPolicy
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
initialize_tokenizer)
from tensorrt_llm.logger import logger
def _env_flag_enabled(name: str) -> bool:
    """Return True when environment variable *name* holds a truthy value.

    Unset, empty, or conventional "off" spellings ("0", "false", "no", "off",
    case-insensitive) count as disabled, so ``NAME=0`` does not turn the
    feature on the way a plain ``bool(os.environ.get(name))`` check would.
    """
    value = os.environ.get(name, "")
    return value.strip().lower() not in ("", "0", "false", "no", "off")


def _load_dataset(tokenizer, dataset_path, num_requests: int):
    """Parse benchmark requests from *dataset_path*, or stdin when it is None.

    Returns the result of ``create_dataset_from_stream`` unchanged (the caller
    unpacks it as a ``(metadata, requests)`` pair).
    """
    if dataset_path is None:
        # --dataset is optional and its help text promises stdin as the
        # fallback; previously open(None) raised TypeError here.
        return create_dataset_from_stream(tokenizer,
                                          click.get_text_stream("stdin"),
                                          num_requests=num_requests)
    with open(dataset_path, "r") as dataset:
        return create_dataset_from_stream(tokenizer,
                                          dataset,
                                          num_requests=num_requests)


@click.command(name="throughput")
@optgroup.group("Engine run configuration.",
                help="Runtime settings for executing a TensorRT-LLM engine.")
@optgroup.option(
    "--engine_dir",
    type=click.Path(exists=True,
                    readable=True,
                    path_type=Path,
                    resolve_path=True),
    required=True,
    help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option(
    "--max_batch_size",
    type=int,
    help="Maximum runtime batch size to run the engine with.",
)
@optgroup.option(
    "--max_num_tokens",
    type=int,
    help="Maximum runtime tokens that an engine can accept.",
)
@optgroup.option(
    "--beam_width",
    type=int,
    default=1,
    help="Number of search beams.",
)
@optgroup.option(
    "--kv_cache_free_gpu_mem_fraction",
    type=float,
    default=.90,
    help="The fraction of memory to use for KV Cache after model load.",
)
@optgroup.group(
    "Engine Input Configuration",
    help="Input configuration for driving the engine.",
)
@optgroup.option(
    "--dataset",
    type=click.Path(exists=True,
                    readable=True,
                    path_type=Path,
                    resolve_path=True),
    default=None,
    help="Pass in a dataset file for parsing instead of stdin.",
)
@optgroup.option(
    "--request_rate",
    type=int,
    default=-1,
    help="Desired input request rate (number of messages per second).",
    hidden=True,
)
@optgroup.option(
    "--num_requests",
    type=int,
    default=0,
    help="Number of requests to cap benchmark run at. Minimum between value and "
    "length of dataset.",
)
@click.option(
    "--streaming",
    is_flag=True,
    default=False,
    help="Enable streaming mode for requests.",
)
@click.option(
    "--iteration_log",
    type=click.Path(dir_okay=False,
                    writable=True,
                    readable=False,
                    path_type=Path,
                    resolve_path=True),
    required=False,
    help="Path where iteration stats should be written to.",
)
@click.pass_obj
def throughput_command(
    bench_env: BenchmarkEnvironment,
    **params,
) -> None:
    """Run a throughput test on a TRT-LLM engine.

    \f
    Everything after the form feed above is hidden from ``--help`` by click.

    Flow: read engine settings from ``--engine_dir``, override them with any
    runtime CLI options, tokenize the dataset (file or stdin), verify that the
    engine's ``max_seq_len`` covers the dataset, then run either the
    experimental asyncio benchmark (``TRTLLM_BENCH_EXPERIMENT`` set to a truthy
    value) or the multiprocess ``ThroughputBenchmark``.

    Args:
        bench_env: Shared CLI context carrying ``model`` / ``model_path``.
        **params: Parsed click options declared by the decorators above.

    Raises:
        ValueError: If the engine was built with ``max_batch_size == 1``.
        RuntimeError: If the dataset's longest sequence exceeds the engine's
            ``max_seq_len``.
    """
    use_experimental = _env_flag_enabled("TRTLLM_BENCH_EXPERIMENT")
    logger.set_level("info")
    logger.info("Preparing to run throughput benchmark...")

    # Parameters from CLI: model, experiment, and engine params.
    dataset_path: Path = params.pop("dataset")
    request_rate: int = params.pop("request_rate")
    num_requests: int = params.pop("num_requests")
    model: str = bench_env.model
    checkpoint_path: Path = bench_env.model_path or bench_env.model
    engine_dir: Path = params.pop("engine_dir")
    iteration_log: Path = params.pop("iteration_log")

    # Engine configuration parsing.
    exec_settings, build_cfg = get_settings_from_engine(engine_dir)
    exec_settings["model"] = model
    engine_bs = exec_settings["settings_config"]["max_batch_size"]
    engine_tokens = exec_settings["settings_config"]["max_num_tokens"]
    engine_max_seq_len = build_cfg["max_seq_len"]

    # Reject low-latency engines; for now that is decided by max batch size.
    if engine_bs == 1:
        raise ValueError(
            "An engine with a batch size greater than 1 should be used for "
            "throughput benchmarking. Exiting.")

    # Runtime options: an unset (or 0) CLI value falls back to the engine's.
    runtime_max_bs = params.pop("max_batch_size")
    runtime_max_bs = runtime_max_bs if runtime_max_bs else engine_bs
    runtime_max_tokens = params.pop("max_num_tokens")
    # Previously this used runtime_max_bs by mistake, so a user-supplied
    # --max_num_tokens was replaced by the batch size.
    runtime_max_tokens = (runtime_max_tokens
                          if runtime_max_tokens else engine_tokens)
    kv_cache_percent = params.pop("kv_cache_free_gpu_mem_fraction")
    beam_width = params.pop("beam_width")
    streaming: bool = params.pop("streaming")

    # Update configuration with runtime options.
    settings_config = exec_settings["settings_config"]
    settings_config["kv_cache_percent"] = kv_cache_percent
    settings_config["max_batch_size"] = runtime_max_bs
    settings_config["max_num_tokens"] = runtime_max_tokens
    settings_config["beam_width"] = beam_width
    settings_config["scheduler_policy"] = IFBSchedulingPolicy.NO_EVICT
    # Dynamic runtime features.
    settings_config["dynamic_max_batch_size"] = True

    # Construct the runtime configuration dataclass.
    runtime_config = RuntimeConfig(**exec_settings)

    # Initialize the HF tokenizer for the specified model.
    tokenizer = initialize_tokenizer(checkpoint_path)

    # Dataset loading and preparation (file, or stdin when --dataset absent).
    metadata, requests = _load_dataset(tokenizer, dataset_path, num_requests)

    # TODO: Verify that the engine can handle the max/min ISL/OSL.
    if metadata.max_sequence_length > engine_max_seq_len:
        raise RuntimeError(
            f"Engine supports a max sequence of {engine_max_seq_len}. Provided "
            "dataset contains a maximum sequence of "
            f"{metadata.max_sequence_length}. Please rebuild a new engine to "
            "support this dataset.")

    if use_experimental:
        # NOTE(review): the experimental path does not use request_rate or
        # iteration_log.
        asyncio.run(async_benchmark(runtime_config, requests, streaming))
        return

    # Convert parsed requests into executor requests for the worker process.
    executor_requests = get_executor_requests(
        requests,
        streaming,
        eos_id=-1,
        pad_id=-1,
    )
    # Drop the parsed requests now that they have been converted.
    del requests

    logger.info("Setting up benchmarker and infrastructure.")
    new_request_queue = mp.Queue()
    response_queue = mp.Queue()
    benchmark = ThroughputBenchmark(
        dataset=executor_requests,
        request_rate=request_rate,
        runtime_cfg=runtime_config,
        request_queue=new_request_queue,
        response_queue=response_queue,
        streaming=streaming,
        iteration_log=iteration_log,
    )

    try:
        logger.info("Ready to start benchmark.")
        benchmark.start_benchmark()
        benchmark.wait()
        benchmark.stop_benchmark()
        benchmark.dump_extra_stats()
        benchmark.report_statistics()
    except KeyboardInterrupt:
        benchmark.stop_benchmark()
    finally:
        # Always release benchmark resources, even after Ctrl-C.
        benchmark.shutdown()