from __future__ import annotations

import json
import multiprocessing as mp
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from threading import Event, Thread
from time import monotonic_ns, sleep
from typing import Generator, List, Tuple

import click
from click_option_group import optgroup

import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm.bench.benchmark.dataclasses import (BenchmarkStatistics,
                                                      RuntimeConfig)
from tensorrt_llm.bench.benchmark.utils import (ResponseTuple, StatsKeeper,
                                                get_executor_requests,
                                                get_settings_from_engine)
from tensorrt_llm.bench.dataclasses import BenchmarkEnvironment
from tensorrt_llm.bench.enums import IFBSchedulingPolicy
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
                                           initialize_tokenizer)
from tensorrt_llm.logger import logger


@click.command(name="throughput")
@optgroup.group("Engine run configuration.",
                help="Runtime settings for executing a TensorRT-LLM engine.")
@optgroup.option(
    "--engine_dir",
    type=click.Path(exists=True,
                    readable=True,
                    path_type=Path,
                    resolve_path=True),
    required=True,
    help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option(
    "--max_batch_size",
    type=int,
    help="Maximum runtime batch size to run the engine with.",
)
@optgroup.option(
    "--max_num_tokens",
    type=int,
    help="Maximum number of runtime tokens that the engine can accept.",
)
@optgroup.option(
    "--beam_width",
    type=int,
    default=1,
    help="Number of search beams.",
)
@optgroup.option(
    "--kv_cache_free_gpu_mem_fraction",
    type=float,
    default=.90,
    help="Fraction of free GPU memory to use for the KV cache after model "
    "load.",
)
@optgroup.group(
    "Engine Input Configuration",
    help="Input configuration for driving the engine.",
)
@optgroup.option(
    "--dataset",
    type=click.Path(exists=True,
                    readable=True,
                    path_type=Path,
                    resolve_path=True),
    default=None,
    help="Pass in a dataset file for parsing instead of stdin.",
)
@optgroup.option(
    "--request_rate",
    type=int,
    default=-1,
    help="Desired input request rate (requests per second).",
    hidden=True,
)
@optgroup.option(
    "--num_requests",
    type=int,
    default=0,
    help="Number of requests to cap the benchmark run at. The minimum of "
    "this value and the dataset length is used.",
)
@click.option(
    "--streaming",
    is_flag=True,
    default=False,
    help="Enable streaming mode for requests.",
)
@click.option(
    "--iteration_log",
    type=click.Path(dir_okay=False,
                    writable=True,
                    readable=False,
                    path_type=Path,
                    resolve_path=True),
    required=False,
    help="Path where iteration stats should be written.",
)
@click.pass_obj
def throughput_command(
    bench_env: BenchmarkEnvironment,
    **params,
) -> None:
    """Run a throughput test on a TRT-LLM engine."""

    logger.set_level("info")
    logger.info("Preparing to run throughput benchmark...")
    # Parameters from CLI
    # Model, experiment, and engine params
    dataset_path: Path = params.pop("dataset")
    request_rate: int = params.pop("request_rate")
    num_requests: int = params.pop("num_requests")
    model: str = bench_env.model
    engine_dir: Path = params.pop("engine_dir")
    iteration_log: Path = params.pop("iteration_log")

    # Engine configuration parsing
    exec_settings, build_cfg = get_settings_from_engine(engine_dir)
    exec_settings["model"] = model
    engine_bs = exec_settings["settings_config"]["max_batch_size"]
    engine_tokens = exec_settings["settings_config"]["max_num_tokens"]
    engine_max_seq_len = build_cfg["max_seq_len"]

    # Check that we are not using a low-latency engine.
    # Right now, this is based on max batch size.
    if engine_bs == 1:
        raise ValueError(
            "An engine with a batch size greater than 1 should be used for "
            "throughput benchmarking. Exiting.")

    # Runtime options
    runtime_max_bs = params.pop("max_batch_size")
    runtime_max_bs = runtime_max_bs if runtime_max_bs else engine_bs
    runtime_max_tokens = params.pop("max_num_tokens")
    runtime_max_tokens = (runtime_max_tokens
                          if runtime_max_tokens else engine_tokens)
    kv_cache_percent = params.pop("kv_cache_free_gpu_mem_fraction")
    beam_width = params.pop("beam_width")
    streaming = params.pop("streaming")

    # Update configuration with runtime options
    exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent
    exec_settings["settings_config"]["max_batch_size"] = runtime_max_bs
    exec_settings["settings_config"]["max_num_tokens"] = runtime_max_tokens
    exec_settings["settings_config"]["beam_width"] = beam_width
    exec_settings["settings_config"][
        "scheduler_policy"] = IFBSchedulingPolicy.NO_EVICT

    # Dynamic runtime features.
    exec_settings["settings_config"]["dynamic_max_batch_size"] = True

    # Construct the runtime configuration dataclass.
    runtime_config = RuntimeConfig(**exec_settings)

    # Initialize the HF tokenizer for the specified model.
    tokenizer = initialize_tokenizer(bench_env.model)

    # Dataset loading and preparation.
    with open(dataset_path, "r") as dataset:
        metadata, requests = create_dataset_from_stream(
            tokenizer, dataset, num_requests=num_requests)
    # TODO: Verify that the engine can handle the max/min ISL/OSL.
    if metadata.max_sequence_length > engine_max_seq_len:
        raise RuntimeError(
            f"Engine supports a max sequence length of {engine_max_seq_len}. "
            "Provided dataset contains a maximum sequence length of "
            f"{metadata.max_sequence_length}. Please rebuild a new engine to "
            "support this dataset.")

    # Convert the parsed dataset into TRT-LLM executor requests.
    executor_requests = get_executor_requests(
        requests,
        streaming,
        eos_id=-1,
        pad_id=-1,
    )
    del requests

    logger.info("Setting up benchmarker and infrastructure.")
    new_request_queue = mp.Queue()
    response_queue = mp.Queue()
    benchmark = ThroughputBenchmark(
        dataset=executor_requests,
        request_rate=request_rate,
        runtime_cfg=runtime_config,
        request_queue=new_request_queue,
        response_queue=response_queue,
        streaming=streaming,
        iteration_log=iteration_log,
    )

    try:
        logger.info("Ready to start benchmark.")
        benchmark.start_benchmark()
        benchmark.wait()
        benchmark.stop_benchmark()
        benchmark.dump_extra_stats()
        benchmark.report_statistics()
    except KeyboardInterrupt:
        benchmark.stop_benchmark()
    finally:
        benchmark.shutdown()


class ExecutorManager:
    """Utility class for managing a TRT-LLM Executor instance."""

    def __init__(self,
                 runtime_cfg: RuntimeConfig,
                 response_queue: mp.Queue,
                 iteration_log: Path = None) -> None:
        """Initialize the ExecutorManager.

        Args:
            runtime_cfg (RuntimeConfig): Execution runtime configuration.
            response_queue (mp.Queue): Process-safe queue for passing request
                responses to the main process.
            iteration_log (Path): Path to the iteration log written at the end
                of the run.
        """
        logger.info("Initializing Executor.")
        # Runtime related properties.
        self.runtime_config: RuntimeConfig = runtime_cfg
        # Runtime tracking and multiprocessing.
        self.responses = response_queue
        self._shutdown = Event()
        self.backend_ready = Event()
        self._resp_daemon_finished = Event()
        self.iteration_log = iteration_log
        config = self.runtime_config.get_config()
        config.iter_stats_max_iterations = 100000000 if self.iteration_log else 0
        self.executor = trtllm.Executor(self.runtime_config.engine_dir,
                                        trtllm.ModelType.DECODER_ONLY,
                                        executor_config=config)

        logger.info("WAITING ON EXECUTOR...")
        while not self.executor.can_enqueue_requests():
            logger.info("Waiting for executor to stand up...")
            sleep(1)

        self.backend_ready.set()

        self.response_thread = Thread(target=self.response_daemon)
        self.response_thread.start()

    def enqueue(
        self, *requests: trtllm.Request
    ) -> Generator[Tuple[int, int], None, None]:
        """Enqueue requests with the executor and yield their identifiers.

        Yields:
            Tuple[int, int]: The request identifier and input length (in
                tokens) of each enqueued request.
        """
        for request in requests:
            req_id = self.executor.enqueue_request(request)
            yield req_id, len(request.input_token_ids)

    def stop(self) -> None:
        """Stop a running manager."""

        logger.info("Stopping response parsing.")
        self._shutdown.set()
        self.response_thread.join()
        logger.info("Parsing stopped.")

    def shutdown(self) -> None:
        """Shutdown daemon components."""

        if self.executor is not None:
            logger.info("Shutting down ExecutorServer.")
            self.executor.shutdown()

    def dump_extra_stats(self) -> None:
        """Write per-iteration executor statistics to the iteration log."""
        if self.iteration_log is not None:
            with open(self.iteration_log, "w") as iter_log:
                for iteration in self.executor.get_latest_iteration_stats():
                    iter_log.write(f"{iteration.to_json_str()}\n")

    def response_daemon(self) -> None:
        """Daemon method for retrieving messages from the Executor."""

        logger.info("Starting response daemon...")

        def _process_response() -> None:
            # Poll with an effectively zero timeout so the loop never blocks.
            responses = self.executor.await_responses(timeout=timedelta(
                microseconds=0.00000000000001))
            now = monotonic_ns()
            if len(responses) > 0:
                self.responses.put([
                    ResponseTuple(now, r.request_id, r.result.is_final,
                                  r.has_error(), r.result.output_token_ids[0],
                                  r.result.decoding_iter) for r in responses
                ])

        while not self._shutdown.is_set():
            _process_response()

        logger.info("Collecting last responses before shutdown.")
        # Reap the last messages before shutting down.
        _process_response()
        self._resp_daemon_finished.set()
        logger.info("Completed request parsing.")


class ThroughputBenchmark:
    """Throughput benchmark utility class."""

    def __init__(
        self,
        dataset: List[trtllm.Request],
        request_rate: int,
        runtime_cfg: RuntimeConfig,
        request_queue: mp.Queue,
        response_queue: mp.Queue,
        streaming: bool,
        iteration_log: Path = None,
    ) -> None:
        """Initialize the throughput benchmark.

        Args:
            dataset (List[trtllm.Request]): A dataset of TRT-LLM requests to
                benchmark against.
            request_rate (int): Rate to deliver input requests to the backend.
            runtime_cfg (RuntimeConfig): Runtime configuration.
            request_queue (mp.Queue): Process-safe queue of request
                identifiers.
            response_queue (mp.Queue): Process-safe queue for passing request
                responses to the main process.
            streaming (bool): Enable/disable streaming mode.
            iteration_log (Path): Path to the iteration log written at the end
                of the run.
        """
        logger.info(
            f"Initializing Throughput Benchmark. [rate={request_rate} req/s]")
        # Dataset and input properties.
        self.requests = dataset
        self.delay_func = ((lambda x: sleep(x)) if request_rate > 0 else
                           (lambda x: None))
        self.request_delay = 1.0 / request_rate if request_rate > 0 else 0.0
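        # Worked example: request_rate=100 gives a 0.01 s pause between
        # submissions; with the default rate of -1 the delay function is a
        # no-op and requests are enqueued as fast as possible.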

        # Runtime configuration for the Executor.
        self.runtime_config = deepcopy(runtime_cfg)
        self.streaming = streaming
        self.executor = None
        self.iteration_log = iteration_log

        # Request and response reporting structures.
        self.new_request_queue = request_queue
        self.response_queue = response_queue

        # Benchmark stats and time tracking.
        self.start_time = None
        self.end_time = None
        self.submitted_requests = 0
        self.statistics = StatsKeeper()

        # Multiprocessing events and worker threads for handling request load
        # generation and response parsing.
        self.stop = mp.Event()
        self.parsing_complete = mp.Event()
        self.request_thread: Thread = Thread(target=self.enqueue_process)
        self.stats_process: Thread = Thread(target=self.collect_statistics)

    def enqueue_process(self) -> None:
        """Enqueue requests to the executor at the configured request rate."""
        logger.info("WAITING ON BACKEND TO BE READY...")
        self.executor.backend_ready.wait()
        logger.info("Request serving started.")
        request_generator = self.executor.enqueue(*self.requests)
        # Iterate the generator until we run out of requests.
        # Note the walrus operator.
        while ((request := next(request_generator, False))
               and not self.stop.is_set()):
            self.submitted_requests += 1
            timestamp = monotonic_ns()
            self.new_request_queue.put((timestamp, request[0], request[1]))
            self.delay_func(self.request_delay)
        logger.info("Request serving stopped.")

    def start_benchmark(self) -> None:
        """Start the benchmark."""
        # Start the ExecutorManager for running the backend.
        self.executor = ExecutorManager(
            self.runtime_config,
            self.response_queue,
            iteration_log=self.iteration_log,
        )
        logger.info("Executor started.")
        # Note the time we started the thread.
        self.start_time = monotonic_ns()
        self.request_thread.start()
        # Start the statistics thread.
        self.stats_process.start()
        logger.info("Benchmark started.")

    def stop_benchmark(self) -> None:
        """Stop the benchmark and clean up backend and threads."""
        logger.info("Stop received.")
        self.stop.set()
        self.executor.stop()
        self.request_thread.join()
        logger.info("Request generator successfully joined.")
        self.stats_process.join()
        logger.info("Statistics process successfully joined.")

    def shutdown(self) -> None:
        """Shutdown the backend."""
        logger.info("Benchmark Shutdown called!")
        if self.executor is not None:
            self.executor.shutdown()
            logger.info("Executor shutdown.")

    def wait(self) -> bool:
        """Wait (blocking) until statistics collection completes.

        Returns:
            bool: Whether the parsing-complete event was set.
        """
        return self.parsing_complete.wait()

    def dump_extra_stats(self) -> None:
        """Write extended stats to a file."""
        self.executor.dump_extra_stats()

    def collect_statistics(self) -> None:
        """Collect statistics (daemon method)."""
        logger.info("Starting statistics collection.")

        def _process_requests() -> None:
            while not self.new_request_queue.empty():
                new_request: Tuple[
                    int, int, int] = self.new_request_queue.get_nowait()
                self.statistics.register_request(new_request[1],
                                                 new_request[0],
                                                 new_request[2])

            while not self.response_queue.empty():
                responses: List[
                    ResponseTuple] = self.response_queue.get_nowait()
                for response in responses:
                    self.statistics.register_response(
                        response.request_id,
                        response.timestamp,
                        response.final,
                        response.error,
                        response.decoding_iteration,
                        response.tokens,
                    )

        logger.info("Collecting live stats...")
        # TODO: Revisit this conditional; if the request rate is slow enough,
        # this will probably prematurely trip. We will likely need a condition
        # that captures a new event for submission being complete, with the
        # stop event overriding it if detected.
        while not self.stop.is_set(
        ) and self.statistics.num_complete < self.submitted_requests:
            _process_requests()

        logger.info("Collecting last stats...")
        _process_requests()
        self.end_time = monotonic_ns()
        self.parsing_complete.set()
        logger.info("Ending statistics collection.")

    def report_statistics(self) -> BenchmarkStatistics:
        """Report internal statistics about the benchmark."""

        config_path = self.runtime_config.engine_dir / "config.json"
        with open(config_path, "r") as config:
            engine_config = json.load(config)

        stats = self.statistics.generate_statistics_summary()
        rt_cfg = self.runtime_config
        build_cfg = engine_config["build_config"]
        pretrain_cfg = engine_config["pretrained_config"]
        total_latency_s = stats.total_latency_ns / 1.0e9

        logging_info = (
            "\n\n===========================================================\n"
            "= ENGINE DETAILS\n"
            "===========================================================\n"
            f"Model:\t\t\t{rt_cfg.model}\n"
            f"Engine Directory:\t{rt_cfg.engine_dir}\n"
            f"TensorRT-LLM Version:\t{rt_cfg.sw_version}\n"
            f"Dtype:\t\t\t{pretrain_cfg['dtype']}\n"
            f"KV Cache Dtype:\t\t{pretrain_cfg['quantization']['kv_cache_quant_algo']}\n"
            f"Quantization:\t\t{pretrain_cfg['quantization']['quant_algo']}\n"
            f"Max Sequence Length:\t{build_cfg['max_seq_len']}\n"
            f"\n"
            "===========================================================\n"
            "= WORLD + RUNTIME INFORMATION \n"
            "===========================================================\n"
            f"TP Size:\t\t{rt_cfg.world_config.tp_size}\n"
            f"PP Size:\t\t{rt_cfg.world_config.pp_size}\n"
            f"Max Runtime Batch Size:\t{rt_cfg.settings_config.max_batch_size}\n"
            f"Max Runtime Tokens:\t{rt_cfg.settings_config.max_num_tokens}\n"
            f"Scheduling Policy:\t{rt_cfg.settings_config.scheduler_policy.values[1]}\n"
            f"KV Memory Percentage:\t{rt_cfg.settings_config.kv_cache_percent * 100.0:.2f}%\n"
            f"Issue Rate (req/sec):\t{stats.issue_rate_ns * 1e9:.4E}\n"
            f"\n"
            "===========================================================\n"
            "= PERFORMANCE OVERVIEW \n"
            "===========================================================\n"
            f"Number of requests:\t\t{stats.num_requests}\n"
            f"Average Input Length (tokens):\t{stats.average_input_length:.4f}\n"
            f"Average Output Length (tokens):\t{stats.average_output_length:.4f}\n"
            f"Token Throughput (tokens/sec):\t{stats.total_output_tokens / total_latency_s:.4f}\n"
            f"Request Throughput (req/sec):\t{stats.num_requests / total_latency_s:.4f}\n"
            f"Total Latency (ms):\t\t{stats.total_latency_ns * 1.0e-6:.4f}\n")

        if self.streaming:
            logging_info = (
                f"{logging_info}"
                "\n"
                "===========================================================\n"
                "= STREAMING STATISTICS \n"
                "===========================================================\n"
                f"Average request latency (ms):\t\t{stats.request_latency_percentiles.average * 1.0e-6:.4f}\n"
                f"Average time-to-first-token (ms):\t{stats.ttft_percentiles.average * 1.0e-6:.4f}\n"
                f"Average inter-token latency (ms):\t{stats.itl_percentiles.average * 1.0e-6:.4f}\n"
            )

        logging_info = (
            f"{logging_info}"
            "\n===========================================================\n")

        logger.info(logging_info)
        return stats