[None][fix] Update to pull LLM from a central location. (#6458)

Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
Authored by Frank on 2025-08-25 13:07:29 -07:00; committed by GitHub
parent 6a44e5b9d1
commit 788fc62d23
3 changed files with 287 additions and 202 deletions

View File

@ -0,0 +1,159 @@
import json
from pathlib import Path
from typing import Callable, Dict, Optional
from pydantic import AliasChoices, BaseModel, Field
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.logger import logger


class GeneralExecSettings(BaseModel):
model_config = {
"extra": "ignore"
} # Ignore extra fields not defined in the model
backend: str = Field(
default="pytorch",
description="The backend to use when running benchmarking")
beam_width: int = Field(default=1, description="Number of search beams")
model_path: Optional[Path] = Field(default=None,
description="Path to model checkpoint")
concurrency: int = Field(
default=-1, description="Desired concurrency rate, <=0 for no limit")
dataset_path: Optional[Path] = Field(default=None,
validation_alias=AliasChoices(
"dataset_path", "dataset"),
description="Path to dataset file")
engine_dir: Optional[Path] = Field(
default=None, description="Path to a serialized TRT-LLM engine")
eos_id: int = Field(
default=-1, description="End-of-sequence token ID, -1 to disable EOS")
iteration_log: Optional[Path] = Field(
default=None, description="Path where iteration logging is written")
kv_cache_percent: float = Field(
default=0.90,
validation_alias=AliasChoices("kv_cache_percent",
"kv_cache_free_gpu_mem_fraction"),
description="Percentage of memory for KV Cache after model load")
max_input_len: int = Field(default=4096,
description="Maximum input sequence length")
max_seq_len: Optional[int] = Field(default=None,
description="Maximum sequence length")
modality: Optional[str] = Field(
default=None, description="Modality of multimodal requests")
model: Optional[str] = Field(default=None, description="Model name or path")
num_requests: int = Field(
default=0, description="Number of requests to cap benchmark run at")
output_json: Optional[Path] = Field(
default=None, description="Path where output should be written")
report_json: Optional[Path] = Field(
default=None, description="Path where report should be written")
request_json: Optional[Path] = Field(
default=None,
description="Path where per request information is written")
streaming: bool = Field(default=False,
description="Whether to use streaming mode")
warmup: int = Field(default=2,
description="Number of requests to warm up benchmark")

    @property
def iteration_writer(self) -> IterationWriter:
return IterationWriter(self.iteration_log)

    @property
def model_type(self) -> str:
return get_model_config(self.model, self.checkpoint_path).model_type

    @property
def checkpoint_path(self) -> Path:
return self.model_path or self.model
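
For reference, a minimal sketch of how the alias choices and extra="ignore" behave when the settings are built from a CLI-style parameter dict. This is illustrative only and not part of the changed file; the values are placeholders, and the import path follows the imports used elsewhere in this commit.

    from pathlib import Path

    from tensorrt_llm.bench.benchmark import GeneralExecSettings

    # "dataset" and "kv_cache_free_gpu_mem_fraction" are accepted alias
    # spellings; keys the model does not declare are dropped by extra="ignore".
    opts = GeneralExecSettings(
        backend="pytorch",
        dataset=Path("dataset.jsonl"),           # populates dataset_path
        kv_cache_free_gpu_mem_fraction=0.85,     # populates kv_cache_percent
        some_unrelated_flag=True,                # ignored
    )
    assert opts.dataset_path == Path("dataset.jsonl")
    assert opts.kv_cache_percent == 0.85
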
def ignore_trt_only_args(kwargs: dict, backend: str):
"""Ignore TensorRT-only arguments for non-TensorRT backends.

    Args:
kwargs: Dictionary of keyword arguments to be passed to the LLM constructor.
backend: The backend type (e.g., "pytorch", "_autodeploy").
"""
trt_only_args = [
"batching_type",
"normalize_log_probs",
"extended_runtime_perf_knob_config",
]
for arg in trt_only_args:
if kwargs.pop(arg, None):
logger.warning(f"Ignore {arg} for {backend} backend.")
def get_llm(runtime_config: RuntimeConfig, kwargs: dict):
"""Create and return an appropriate LLM instance based on the backend configuration.

    Args:
        runtime_config: Runtime configuration containing backend selection and settings.
        kwargs: Additional keyword arguments to pass to the LLM constructor.

    Returns:
An instance of the appropriate LLM class for the specified backend.
"""
llm_cls = LLM
if runtime_config.backend != "tensorrt":
ignore_trt_only_args(kwargs, runtime_config.backend)
if runtime_config.backend == 'pytorch':
llm_cls = PyTorchLLM
if runtime_config.iteration_log is not None:
kwargs["enable_iter_perf_stats"] = True
elif runtime_config.backend == "_autodeploy":
kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
llm_cls = AutoDeployLLM
llm = llm_cls(**kwargs)
return llm
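
For context, the refactored latency and throughput commands below build the LLM through this helper; their call pattern is the fragment sketched here (variable names as they appear in those commands, not a standalone script):

    runtime_config = RuntimeConfig(**exec_settings)
    kwargs = kwargs | runtime_config.get_llm_args()
    kwargs["backend"] = options.backend
    llm = get_llm(runtime_config, kwargs)  # LLM, PyTorchLLM, or AutoDeployLLM
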
def get_general_cli_options(
params: Dict, bench_env: BenchmarkEnvironment) -> GeneralExecSettings:
"""Get general execution settings from command line parameters.

    Args:
        params: Dictionary of command line parameters.
        bench_env: Benchmark environment containing model and checkpoint information.

    Returns:
An instance of GeneralExecSettings containing general execution settings.
"""
# Create a copy of params to avoid modifying the original
settings_dict = params.copy()
# Add derived values that need to be computed from bench_env
model_path = bench_env.checkpoint_path
model = bench_env.model
# Override/add the computed values
settings_dict.update({
"model_path": model_path,
"model": model,
})
# Create and return the settings object, ignoring any extra fields
return GeneralExecSettings(**settings_dict)
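
A minimal sketch of driving this helper outside the CLI (illustrative only, not part of the changed file): BenchmarkEnvironment is stood in for by a plain namespace carrying the two attributes read here, and the model id is a placeholder.

    from types import SimpleNamespace

    from tensorrt_llm.bench.benchmark import get_general_cli_options

    cli_params = {"backend": "pytorch", "warmup": 4, "report_json": None}
    fake_env = SimpleNamespace(model="org/example-model",  # placeholder HF id
                               checkpoint_path=None)

    opts = get_general_cli_options(cli_params, fake_env)
    assert opts.model == "org/example-model"
    assert opts.checkpoint_path == "org/example-model"  # falls back to model
    assert opts.warmup == 4
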
def generate_json_report(report_path: Optional[Path], func: Callable):
if report_path is None:
logger.debug("No report path provided, skipping report generation.")
else:
logger.info(f"Writing report information to {report_path}...")
with open(report_path, "w") as f:
f.write(json.dumps(func(), indent=4))
logger.info(f"Report information written to {report_path}.")

View File

@ -1,8 +1,8 @@
from __future__ import annotations
import asyncio
import json
import os
from functools import partial
from pathlib import Path
import click
@ -11,13 +11,10 @@ from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
optgroup)
from huggingface_hub import snapshot_download
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.bench.benchmark import (generate_json_report,
get_general_cli_options, get_llm)
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
from tensorrt_llm.bench.benchmark.utils.general import generate_warmup_dataset
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
@ -25,7 +22,9 @@ from tensorrt_llm.llmapi import CapacitySchedulerPolicy
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
# isort: off
from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings, update_sampler_args_with_extra_options, ALL_SUPPORTED_BACKENDS
from tensorrt_llm.bench.benchmark.utils.general import (
get_settings_from_engine, get_settings,
update_sampler_args_with_extra_options, ALL_SUPPORTED_BACKENDS)
# isort: on
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
initialize_tokenizer,
@ -46,6 +45,13 @@ from tensorrt_llm.sampling_params import SamplingParams
default=None,
help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option(
"--extra_llm_api_options",
type=str,
default=None,
help=
"Path to a YAML file that overwrites the parameters specified by trtllm-bench."
)
@optgroup.option("--backend",
type=click.Choice(ALL_SUPPORTED_BACKENDS),
default="pytorch",
@ -184,50 +190,30 @@ def latency_command(
**params,
) -> None:
"""Run a latency test on a TRT-LLM engine."""
logger.info("Preparing to run latency benchmark...")
# Parameters from CLI
# Model, experiment, and engine params
dataset_path: Path = params.get("dataset")
num_requests: int = params.get("num_requests")
model: str = bench_env.model
checkpoint_path: Path = bench_env.checkpoint_path or bench_env.model
engine_dir: Path = params.get("engine_dir")
concurrency: int = params.get("concurrency")
beam_width: int = params.get("beam_width")
warmup: int = params.get("warmup")
modality: str = params.get("modality")
max_input_len: int = params.get("max_input_len")
max_seq_len: int = params.get("max_seq_len")
backend: str = params.get("backend")
model_type = get_model_config(model, checkpoint_path).model_type
options = get_general_cli_options(params, bench_env)
# Runtime Options
kv_cache_percent = params.get("kv_cache_free_gpu_mem_fraction")
# Speculative Decode Options
medusa_choices = params.get("medusa_choices")
# Reporting Options
report_json: Path = params.pop("report_json")
iteration_log: Path = params.pop("iteration_log")
iteration_writer = IterationWriter(iteration_log)
# Initialize the HF tokenizer for the specified model.
tokenizer = initialize_tokenizer(checkpoint_path)
tokenizer = initialize_tokenizer(options.checkpoint_path)
# Dataset Loading and Preparation
with open(dataset_path, "r") as dataset:
with open(options.dataset_path, "r") as dataset:
metadata, requests = create_dataset_from_stream(
tokenizer,
dataset,
num_requests=num_requests,
model_dir=checkpoint_path,
model_type=model_type,
modality=modality,
max_input_seq_len_for_multimodal=max_input_len)
num_requests=options.num_requests,
model_dir=options.checkpoint_path,
model_type=options.model_type,
modality=options.modality,
max_input_seq_len_for_multimodal=options.max_input_len)
metadata.dataset_path = dataset_path
metadata.dataset_path = options.dataset_path
if modality is None:
if options.modality is None:
# Log dataset info
# NOTE: This table is only accurate for non-multimodal models.
# The accurate table for multimodal models will be logged after the benchmark is done.
@ -235,20 +221,20 @@ def latency_command(
# Engine configuration parsing for PyTorch backend
kwargs = {}
if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
) != "tensorrt":
if options.backend and options.backend.lower(
) in ALL_SUPPORTED_BACKENDS and options.backend.lower() != "tensorrt":
if bench_env.checkpoint_path is None:
snapshot_download(model)
snapshot_download(options.model)
exec_settings = get_settings(params, metadata, bench_env.model,
bench_env.checkpoint_path)
kwargs_max_sql = max_seq_len or metadata.max_sequence_length
kwargs_max_sql = options.max_seq_len or metadata.max_sequence_length
logger.info(f"Setting PyTorch max sequence length to {kwargs_max_sql}")
kwargs["max_seq_len"] = kwargs_max_sql
elif backend.lower() == "tensorrt":
assert max_seq_len is None, (
elif options.backend.lower() == "tensorrt":
assert options.max_seq_len is None, (
"max_seq_len is not a runtime parameter for C++ backend")
exec_settings, build_cfg = get_settings_from_engine(engine_dir)
exec_settings, build_cfg = get_settings_from_engine(options.engine_dir)
engine_max_seq_len = build_cfg["max_seq_len"]
if metadata.max_sequence_length > engine_max_seq_len:
@ -259,17 +245,18 @@ def latency_command(
"support this dataset.")
else:
raise RuntimeError(
f"Invalid backend: {backend}, please use one of the following: "
f"Invalid backend: {options.backend}, please use one of the following: "
f"{ALL_SUPPORTED_BACKENDS}")
exec_settings["model"] = model
exec_settings["model"] = options.model
engine_tokens = exec_settings["settings_config"]["max_num_tokens"]
# Update configuration with runtime options
exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent
exec_settings["settings_config"][
"kv_cache_percent"] = options.kv_cache_percent
exec_settings["settings_config"]["max_batch_size"] = 1
exec_settings["settings_config"]["max_num_tokens"] = engine_tokens
exec_settings["settings_config"]["beam_width"] = beam_width
exec_settings["settings_config"]["beam_width"] = options.beam_width
exec_settings["settings_config"]["chunking"] = False
exec_settings["settings_config"][
"scheduler_policy"] = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
@ -286,6 +273,8 @@ def latency_command(
exec_settings["performance_options"]["cuda_graphs"] = True
exec_settings["performance_options"]["multi_block_mode"] = True
exec_settings["extra_llm_api_options"] = params.get("extra_llm_api_options")
# Decoding Options
if medusa_choices is not None:
with open(medusa_choices, "r") as medusa_yml:
@ -297,31 +286,12 @@ def latency_command(
llm = None
kwargs = kwargs | runtime_config.get_llm_args()
kwargs['backend'] = backend
kwargs['backend'] = options.backend
try:
logger.info("Setting up latency benchmark.")
if "pytorch_backend_config" in kwargs and iteration_log is not None:
kwargs["pytorch_backend_config"].enable_iter_perf_stats = True
if runtime_config.backend == 'pytorch':
if kwargs.pop("extended_runtime_perf_knob_config", None):
logger.warning(
"Ignore extended_runtime_perf_knob_config for pytorch backend."
)
llm = PyTorchLLM(**kwargs)
elif runtime_config.backend == "_autodeploy":
if kwargs.pop("extended_runtime_perf_knob_config", None):
logger.warning(
"Ignore extended_runtime_perf_knob_config for _autodeploy backend."
)
kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
kwargs.pop("pipeline_parallel_size", None)
llm = AutoDeployLLM(**kwargs)
else:
llm = LLM(**kwargs)
llm = get_llm(runtime_config, kwargs)
ignore_eos = True if runtime_config.decoding_config.decoding_mode == SpeculativeDecodingMode.NONE else False
eos_id = tokenizer.eos_token_id if not ignore_eos else -1
@ -330,9 +300,10 @@ def latency_command(
sampler_args = {
"end_id": eos_id,
"pad_id": pad_id,
"n": beam_width,
"use_beam_search": beam_width > 1
"n": options.beam_width,
"use_beam_search": options.beam_width > 1
}
sampler_args = update_sampler_args_with_extra_options(
sampler_args, params.pop("sampler_options"))
sampling_params = SamplingParams(**sampler_args)
@ -340,9 +311,9 @@ def latency_command(
post_proc_params = None # No detokenization
# Perform warmup if requested.
if warmup > 0:
if options.warmup > 0:
logger.info("Setting up for warmup...")
warmup_dataset = generate_warmup_dataset(requests, warmup)
warmup_dataset = generate_warmup_dataset(requests, options.warmup)
logger.info("Running warmup.")
asyncio.run(
async_benchmark(llm,
@ -350,14 +321,15 @@ def latency_command(
post_proc_params,
warmup_dataset,
False,
concurrency,
modality=modality))
options.concurrency,
modality=options.modality))
# WAR: IterationResult is a singleton tied to the executor.
# Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
# we must reset it to ensure it attaches to the correct event loop.
llm._executor._iter_stats_result = None
logger.info("Warmup done.")
iteration_writer = options.iteration_writer
with iteration_writer.capture():
statistics = asyncio.run(
async_benchmark(llm,
@ -365,23 +337,27 @@ def latency_command(
post_proc_params,
requests,
True,
concurrency,
options.concurrency,
iteration_writer.full_address,
modality=modality))
modality=options.modality))
logger.info(f"Benchmark done. Reporting results...")
logger.info("Benchmark done. Reporting results...")
if modality is not None:
if options.modality is not None:
# For multimodal models, we need to update the metadata with the correct input lengths
metadata = update_metadata_for_multimodal(metadata, statistics)
report_utility = ReportUtility(statistics, metadata, runtime_config,
logger, kwargs, True)
if report_json:
logger.info(f"Writing report to '{report_json}'.")
with open(report_json, "w") as f:
f.write(
json.dumps(report_utility.get_statistics_dict(), indent=4))
# Generate reports for statistics, output tokens, and request info.
generate_json_report(options.report_json,
report_utility.get_statistics_dict)
generate_json_report(
options.output_json,
partial(report_utility.get_output_tokens, tokenizer))
generate_json_report(
options.request_json,
partial(report_utility.get_request_info, tokenizer))
report_utility.report_statistics()
except KeyboardInterrupt:
logger.info("Keyboard interrupt, exiting benchmark...")

View File

@ -1,8 +1,8 @@
from __future__ import annotations
import asyncio
import json
import sys
from functools import partial
from pathlib import Path
import click
@ -10,18 +10,16 @@ from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
optgroup)
from huggingface_hub import snapshot_download
from tensorrt_llm.bench.benchmark import (GeneralExecSettings,
generate_json_report,
get_general_cli_options, get_llm)
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir
# isort: off
from tensorrt_llm.bench.benchmark.utils.general import (
get_settings_from_engine, get_settings, ALL_SUPPORTED_BACKENDS)
# isort: on
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.bench.benchmark.utils.general import (
generate_warmup_dataset, update_sampler_args_with_extra_options)
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
@ -293,10 +291,22 @@ def throughput_command(
**params,
) -> None:
"""Run a throughput test on a TRT-LLM engine."""
logger.info("Preparing to run throughput benchmark...")
# Parameters from CLI
# Model, experiment, and engine params
image_data_format: str = params.get("image_data_format", "pt")
data_device: str = params.get("data_device", "cpu")
no_skip_tokenizer_init: bool = params.get("no_skip_tokenizer_init", False)
# Get general CLI options using the centralized function
options: GeneralExecSettings = get_general_cli_options(params, bench_env)
tokenizer = initialize_tokenizer(options.checkpoint_path)
# Extract throughput-specific options not handled by GeneralExecSettings
max_batch_size = params.get("max_batch_size")
max_num_tokens = params.get("max_num_tokens")
enable_chunked_context: bool = params.get("enable_chunked_context")
scheduler_policy: str = params.get("scheduler_policy")
custom_module_dirs: list[Path] = params.pop("custom_module_dirs", [])
for custom_module_dir in custom_module_dirs:
try:
@ -306,77 +316,50 @@ def throughput_command(
f"Failed to import custom module from {custom_module_dir}: {e}")
raise e
dataset_path: Path = params.get("dataset")
no_skip_tokenizer_init: bool = params.get("no_skip_tokenizer_init", False)
eos_id: int = params.get("eos_id")
warmup: int = params.get("warmup")
num_requests: int = params.get("num_requests")
max_seq_len: int = params.get("max_seq_len")
model: str = bench_env.model
checkpoint_path: Path = bench_env.checkpoint_path or bench_env.model
engine_dir: Path = params.get("engine_dir")
concurrency: int = params.get("concurrency")
backend: str = params.get("backend")
modality: str = params.get("modality")
max_input_len: int = params.get("max_input_len")
image_data_format: str = params.get("image_data_format", "pt")
data_device: str = params.get("data_device", "cpu")
model_type = get_model_config(model, checkpoint_path).model_type
# Reporting options
report_json: Path = params.get("report_json")
output_json: Path = params.get("output_json")
request_json: Path = params.get("request_json")
iteration_log: Path = params.get("iteration_log")
iteration_writer = IterationWriter(iteration_log)
# Runtime kwargs and option tracking.
kwargs = {}
# Initialize the HF tokenizer for the specified model. This is only used for data preparation.
tokenizer = initialize_tokenizer(checkpoint_path)
# Dataset Loading and Preparation
with open(dataset_path, "r") as dataset:
with open(options.dataset_path, "r") as dataset:
metadata, requests = create_dataset_from_stream(
tokenizer,
dataset,
num_requests=num_requests,
model_dir=checkpoint_path,
model_type=model_type,
modality=modality,
num_requests=options.num_requests,
model_dir=options.checkpoint_path,
model_type=options.model_type,
modality=options.modality,
image_data_format=image_data_format,
data_device=data_device,
max_input_seq_len_for_multimodal=max_input_len)
metadata.dataset_path = dataset_path
max_input_seq_len_for_multimodal=options.max_input_len)
metadata.dataset_path = options.dataset_path
params["target_input_len"] = params.get(
"target_input_len") or metadata.avg_isl
params["target_output_len"] = params.get(
"target_output_len") or metadata.avg_osl
if modality is None:
if options.modality is None:
# Log dataset info
# NOTE: This table is only accurate for non-multimodal models.
# The accurate table for multimodal models will be logged after the benchmark is done.
logger.info(metadata.get_summary_for_print())
# Engine configuration parsing
if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
) != "tensorrt":
if options.backend and options.backend.lower(
) in ALL_SUPPORTED_BACKENDS and options.backend.lower() != "tensorrt":
# If we're dealing with a model name, perform a snapshot download to
# make sure we have a local copy of the model.
if bench_env.checkpoint_path is None:
snapshot_download(model)
snapshot_download(options.model)
exec_settings = get_settings(params, metadata, bench_env.model,
bench_env.checkpoint_path)
kwargs_max_sql = max_seq_len or metadata.max_sequence_length
kwargs_max_sql = options.max_seq_len or metadata.max_sequence_length
logger.info(f"Setting PyTorch max sequence length to {kwargs_max_sql}")
kwargs["max_seq_len"] = kwargs_max_sql
elif backend.lower() == "tensorrt":
assert max_seq_len is None, (
elif options.backend.lower() == "tensorrt":
assert options.max_seq_len is None, (
"max_seq_len is not a runtime parameter for C++ backend")
exec_settings, build_cfg = get_settings_from_engine(engine_dir)
exec_settings, build_cfg = get_settings_from_engine(options.engine_dir)
engine_max_seq_len = build_cfg["max_seq_len"]
# TODO: Verify that the engine can handle the max/min ISL/OSL.
@ -388,29 +371,23 @@ def throughput_command(
"to support this dataset.")
else:
raise RuntimeError(
f"Invalid backend: {backend}, please use one of the following: "
f"Invalid backend: {options.backend}, please use one of the following: "
"pytorch, tensorrt, _autodeploy.")
exec_settings["model"] = model
exec_settings["model"] = options.model
engine_bs = exec_settings["settings_config"]["max_batch_size"]
engine_tokens = exec_settings["settings_config"]["max_num_tokens"]
# Runtime Options
runtime_max_bs = params.get("max_batch_size")
runtime_max_tokens = params.get("max_num_tokens")
runtime_max_bs = runtime_max_bs or engine_bs
runtime_max_tokens = runtime_max_tokens or engine_tokens
kv_cache_percent = params.get("kv_cache_free_gpu_mem_fraction")
beam_width = params.get("beam_width")
streaming: bool = params.get("streaming")
enable_chunked_context: bool = params.get("enable_chunked_context")
scheduler_policy: str = params.get("scheduler_policy")
runtime_max_bs = max_batch_size or engine_bs
runtime_max_tokens = max_num_tokens or engine_tokens
# Update configuration with runtime options
exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent
exec_settings["settings_config"][
"kv_cache_percent"] = options.kv_cache_percent
exec_settings["settings_config"]["max_batch_size"] = runtime_max_bs
exec_settings["settings_config"]["max_num_tokens"] = runtime_max_tokens
exec_settings["settings_config"]["beam_width"] = beam_width
exec_settings["settings_config"]["beam_width"] = options.beam_width
exec_settings["settings_config"][
"scheduler_policy"] = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT if scheduler_policy == "guaranteed_no_evict" else CapacitySchedulerPolicy.MAX_UTILIZATION
exec_settings["settings_config"]["chunking"] = enable_chunked_context
@ -420,48 +397,25 @@ def throughput_command(
# LlmArgs
exec_settings["extra_llm_api_options"] = params.pop("extra_llm_api_options")
exec_settings["iteration_log"] = iteration_log
exec_settings["iteration_log"] = options.iteration_log
# Construct the runtime configuration dataclass.
runtime_config = RuntimeConfig(**exec_settings)
llm = None
def ignore_trt_only_args(kwargs: dict):
trt_only_args = [
"batching_type",
"normalize_log_probs",
"extended_runtime_perf_knob_config",
]
for arg in trt_only_args:
if kwargs.pop(arg, None):
logger.warning(
f"Ignore {arg} for {runtime_config.backend} backend.")
try:
logger.info("Setting up throughput benchmark.")
kwargs = kwargs | runtime_config.get_llm_args()
kwargs['backend'] = backend
kwargs['skip_tokenizer_init'] = not no_skip_tokenizer_init
kwargs['backend'] = options.backend
if backend == "pytorch" and iteration_log is not None:
kwargs["enable_iter_perf_stats"] = True
if runtime_config.backend == 'pytorch':
ignore_trt_only_args(kwargs)
llm = PyTorchLLM(**kwargs)
elif runtime_config.backend == "_autodeploy":
ignore_trt_only_args(kwargs)
kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
llm = AutoDeployLLM(**kwargs)
else:
llm = LLM(**kwargs)
llm = get_llm(runtime_config, kwargs)
sampler_args = {
"end_id": eos_id,
"pad_id": eos_id,
"n": beam_width,
"use_beam_search": beam_width > 1
"end_id": options.eos_id,
"pad_id": options.eos_id,
"n": options.beam_width,
"use_beam_search": options.beam_width > 1
}
sampler_args = update_sampler_args_with_extra_options(
sampler_args, params.pop("sampler_options"))
@ -470,9 +424,9 @@ def throughput_command(
post_proc_params = None # No detokenization
# Perform warmup if requested.
if warmup > 0:
if options.warmup > 0:
logger.info("Setting up for warmup...")
warmup_dataset = generate_warmup_dataset(requests, warmup)
warmup_dataset = generate_warmup_dataset(requests, options.warmup)
logger.info("Running warmup.")
asyncio.run(
async_benchmark(llm,
@ -480,46 +434,42 @@ def throughput_command(
post_proc_params,
warmup_dataset,
False,
concurrency,
modality=modality))
options.concurrency,
modality=options.modality))
# WAR: IterationResult is a singleton tied to the executor.
# Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
# we must reset it to ensure it attaches to the correct event loop.
llm._executor._iter_stats_result = None
logger.info("Warmup done.")
iteration_writer = options.iteration_writer
with iteration_writer.capture():
statistics = asyncio.run(
async_benchmark(llm,
sampling_params,
post_proc_params,
requests,
streaming,
concurrency,
options.streaming,
options.concurrency,
iteration_writer.full_address,
modality=modality))
modality=options.modality))
logger.info(f"Benchmark done. Reporting results...")
if modality is not None:
logger.info("Benchmark done. Reporting results...")
if options.modality is not None:
# For multimodal models, we need to update the metadata with the correct input lengths
metadata = update_metadata_for_multimodal(metadata, statistics)
report_utility = ReportUtility(statistics, metadata, runtime_config,
logger, kwargs, streaming)
if report_json:
logger.info(f"Writing report to '{report_json}'.")
with open(report_json, "w") as f:
f.write(
json.dumps(report_utility.get_statistics_dict(), indent=4))
if output_json:
logger.info(f"Writing output to {output_json}.")
with open(output_json, "w") as f:
output_token_info = report_utility.get_output_tokens(tokenizer)
f.write(json.dumps(output_token_info, indent=4))
if request_json:
logger.info(f"Writing request information to {request_json}.")
with open(request_json, "w") as f:
f.write(json.dumps(report_utility.get_request_info(tokenizer)))
logger, kwargs, options.streaming)
# Generate reports for statistics, output tokens, and request info.
generate_json_report(options.report_json,
report_utility.get_statistics_dict)
generate_json_report(
options.output_json,
partial(report_utility.get_output_tokens, tokenizer))
generate_json_report(
options.request_json,
partial(report_utility.get_request_info, tokenizer))
report_utility.report_statistics()
except KeyboardInterrupt:
logger.info("Keyboard interrupt, exiting benchmark...")