Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[None][fix] Update to pull LLM from a central location. (#6458)
Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
parent 6a44e5b9d1
commit 788fc62d23
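
In short, this commit has both the latency and throughput benchmark commands delegate option parsing and LLM construction to shared helpers in tensorrt_llm.bench.benchmark instead of duplicating the backend branching. A condensed sketch of the new call pattern (not literal code from the diff; params, bench_env, runtime_config, and kwargs stand for each command's existing local state):

from tensorrt_llm.bench.benchmark import get_general_cli_options, get_llm

options = get_general_cli_options(params, bench_env)  # validated, backend-agnostic CLI options
llm = get_llm(runtime_config, kwargs)                 # dispatches to LLM, PyTorchLLM, or AutoDeployLLM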
@@ -0,0 +1,159 @@
import json
from pathlib import Path
from typing import Callable, Dict, Optional

from pydantic import AliasChoices, BaseModel, Field

from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.logger import logger


class GeneralExecSettings(BaseModel):
    model_config = {
        "extra": "ignore"
    }  # Ignore extra fields not defined in the model

    backend: str = Field(
        default="pytorch",
        description="The backend to use when running benchmarking")
    beam_width: int = Field(default=1, description="Number of search beams")
    model_path: Optional[Path] = Field(default=None,
                                       description="Path to model checkpoint")
    concurrency: int = Field(
        default=-1, description="Desired concurrency rate, <=0 for no limit")
    dataset_path: Optional[Path] = Field(default=None,
                                         validation_alias=AliasChoices(
                                             "dataset_path", "dataset"),
                                         description="Path to dataset file")
    engine_dir: Optional[Path] = Field(
        default=None, description="Path to a serialized TRT-LLM engine")
    eos_id: int = Field(
        default=-1, description="End-of-sequence token ID, -1 to disable EOS")
    iteration_log: Optional[Path] = Field(
        default=None, description="Path where iteration logging is written")
    kv_cache_percent: float = Field(
        default=0.90,
        validation_alias=AliasChoices("kv_cache_percent",
                                      "kv_cache_free_gpu_mem_fraction"),
        description="Percentage of memory for KV Cache after model load")
    max_input_len: int = Field(default=4096,
                               description="Maximum input sequence length")
    max_seq_len: Optional[int] = Field(default=None,
                                       description="Maximum sequence length")
    modality: Optional[str] = Field(
        default=None, description="Modality of multimodal requests")
    model: Optional[str] = Field(default=None, description="Model name or path")
    num_requests: int = Field(
        default=0, description="Number of requests to cap benchmark run at")
    output_json: Optional[Path] = Field(
        default=None, description="Path where output should be written")
    report_json: Optional[Path] = Field(
        default=None, description="Path where report should be written")
    request_json: Optional[Path] = Field(
        default=None,
        description="Path where per request information is written")
    streaming: bool = Field(default=False,
                            description="Whether to use streaming mode")
    warmup: int = Field(default=2,
                        description="Number of requests to warm up benchmark")

    @property
    def iteration_writer(self) -> IterationWriter:
        return IterationWriter(self.iteration_log)

    @property
    def model_type(self) -> str:
        return get_model_config(self.model, self.checkpoint_path).model_type

    @property
    def checkpoint_path(self) -> Path:
        return self.model_path or self.model


def ignore_trt_only_args(kwargs: dict, backend: str):
    """Ignore TensorRT-only arguments for non-TensorRT backends.

    Args:
        kwargs: Dictionary of keyword arguments to be passed to the LLM constructor.
        backend: The backend type (e.g., "pytorch", "_autodeploy").
    """
    trt_only_args = [
        "batching_type",
        "normalize_log_probs",
        "extended_runtime_perf_knob_config",
    ]
    for arg in trt_only_args:
        if kwargs.pop(arg, None):
            logger.warning(f"Ignore {arg} for {backend} backend.")


def get_llm(runtime_config: RuntimeConfig, kwargs: dict):
    """Create and return an appropriate LLM instance based on the backend configuration.

    Args:
        runtime_config: Runtime configuration containing backend selection and settings.
        kwargs: Additional keyword arguments to pass to the LLM constructor.

    Returns:
        An instance of the appropriate LLM class for the specified backend.
    """
    llm_cls = LLM

    if runtime_config.backend != "tensorrt":
        ignore_trt_only_args(kwargs, runtime_config.backend)

        if runtime_config.backend == 'pytorch':
            llm_cls = PyTorchLLM

            if runtime_config.iteration_log is not None:
                kwargs["enable_iter_perf_stats"] = True

        elif runtime_config.backend == "_autodeploy":
            kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
            llm_cls = AutoDeployLLM

    llm = llm_cls(**kwargs)
    return llm


def get_general_cli_options(
        params: Dict, bench_env: BenchmarkEnvironment) -> GeneralExecSettings:
    """Get general execution settings from command line parameters.

    Args:
        params: Dictionary of command line parameters.
        bench_env: Benchmark environment containing model and checkpoint information.

    Returns:
        An instance of GeneralExecSettings containing general execution settings.
    """
    # Create a copy of params to avoid modifying the original
    settings_dict = params.copy()

    # Add derived values that need to be computed from bench_env
    model_path = bench_env.checkpoint_path
    model = bench_env.model
    # Override/add the computed values
    settings_dict.update({
        "model_path": model_path,
        "model": model,
    })

    # Create and return the settings object, ignoring any extra fields
    return GeneralExecSettings(**settings_dict)


def generate_json_report(report_path: Optional[Path], func: Callable):
    if report_path is None:
        logger.debug("No report path provided, skipping report generation.")
    else:
        logger.info(f"Writing report information to {report_path}...")
        with open(report_path, "w") as f:
            f.write(json.dumps(func(), indent=4))
        logger.info(f"Report information written to {report_path}.")
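
The hunks that follow update the latency and throughput benchmark commands to consume these shared helpers. As a quick orientation, here is a minimal usage sketch of GeneralExecSettings (not part of the commit; the parameter values below are made up) showing how raw CLI spellings such as "dataset" and "kv_cache_free_gpu_mem_fraction" are accepted through AliasChoices while unknown keys are dropped by extra="ignore":

from tensorrt_llm.bench.benchmark import GeneralExecSettings

# Hypothetical subset of the click parameters a trtllm-bench command collects.
params = {
    "backend": "pytorch",
    "dataset": "/tmp/dataset.jsonl",            # accepted via the "dataset" alias
    "kv_cache_free_gpu_mem_fraction": 0.85,     # accepted via its alias as well
    "medusa_choices": None,                     # not a model field; silently ignored
}

options = GeneralExecSettings(**params)
assert str(options.dataset_path) == "/tmp/dataset.jsonl"
assert options.kv_cache_percent == 0.85
assert options.warmup == 2  # defaults fill in anything not supplied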
@@ -1,8 +1,8 @@
from __future__ import annotations

import asyncio
import json
import os
from functools import partial
from pathlib import Path

import click
@@ -11,13 +11,10 @@ from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
                                optgroup)
from huggingface_hub import snapshot_download

from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.bench.benchmark import (generate_json_report,
                                          get_general_cli_options, get_llm)
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
from tensorrt_llm.bench.benchmark.utils.general import generate_warmup_dataset
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
@@ -25,7 +22,9 @@ from tensorrt_llm.llmapi import CapacitySchedulerPolicy
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode

# isort: off
from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings, update_sampler_args_with_extra_options, ALL_SUPPORTED_BACKENDS
from tensorrt_llm.bench.benchmark.utils.general import (
    get_settings_from_engine, get_settings,
    update_sampler_args_with_extra_options, ALL_SUPPORTED_BACKENDS)
# isort: on
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
                                           initialize_tokenizer,
@@ -46,6 +45,13 @@ from tensorrt_llm.sampling_params import SamplingParams
    default=None,
    help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option(
    "--extra_llm_api_options",
    type=str,
    default=None,
    help=
    "Path to a YAML file that overwrites the parameters specified by trtllm-bench."
)
@optgroup.option("--backend",
                 type=click.Choice(ALL_SUPPORTED_BACKENDS),
                 default="pytorch",
@@ -184,50 +190,30 @@ def latency_command(
    **params,
) -> None:
    """Run a latency test on a TRT-LLM engine."""

    logger.info("Preparing to run latency benchmark...")
    # Parameters from CLI
    # Model, experiment, and engine params
    dataset_path: Path = params.get("dataset")
    num_requests: int = params.get("num_requests")
    model: str = bench_env.model
    checkpoint_path: Path = bench_env.checkpoint_path or bench_env.model
    engine_dir: Path = params.get("engine_dir")
    concurrency: int = params.get("concurrency")
    beam_width: int = params.get("beam_width")
    warmup: int = params.get("warmup")
    modality: str = params.get("modality")
    max_input_len: int = params.get("max_input_len")
    max_seq_len: int = params.get("max_seq_len")
    backend: str = params.get("backend")
    model_type = get_model_config(model, checkpoint_path).model_type
    options = get_general_cli_options(params, bench_env)

    # Runtime Options
    kv_cache_percent = params.get("kv_cache_free_gpu_mem_fraction")
    # Speculative Decode Options
    medusa_choices = params.get("medusa_choices")

    # Reporting Options
    report_json: Path = params.pop("report_json")
    iteration_log: Path = params.pop("iteration_log")
    iteration_writer = IterationWriter(iteration_log)

    # Initialize the HF tokenizer for the specified model.
    tokenizer = initialize_tokenizer(checkpoint_path)
    tokenizer = initialize_tokenizer(options.checkpoint_path)

    # Dataset Loading and Preparation
    with open(dataset_path, "r") as dataset:
    with open(options.dataset_path, "r") as dataset:
        metadata, requests = create_dataset_from_stream(
            tokenizer,
            dataset,
            num_requests=num_requests,
            model_dir=checkpoint_path,
            model_type=model_type,
            modality=modality,
            max_input_seq_len_for_multimodal=max_input_len)
            num_requests=options.num_requests,
            model_dir=options.checkpoint_path,
            model_type=options.model_type,
            modality=options.modality,
            max_input_seq_len_for_multimodal=options.max_input_len)

    metadata.dataset_path = dataset_path
    metadata.dataset_path = options.dataset_path

    if modality is None:
    if options.modality is None:
        # Log dataset info
        # NOTE: This table is only accurate for non-multimodal models.
        # The accurate table for multimodal models will be logged after the benchmark is done.
@@ -235,20 +221,20 @@ def latency_command(

    # Engine configuration parsing for PyTorch backend
    kwargs = {}
    if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
    ) != "tensorrt":
    if options.backend and options.backend.lower(
    ) in ALL_SUPPORTED_BACKENDS and options.backend.lower() != "tensorrt":
        if bench_env.checkpoint_path is None:
            snapshot_download(model)
            snapshot_download(options.model)

        exec_settings = get_settings(params, metadata, bench_env.model,
                                     bench_env.checkpoint_path)
        kwargs_max_sql = max_seq_len or metadata.max_sequence_length
        kwargs_max_sql = options.max_seq_len or metadata.max_sequence_length
        logger.info(f"Setting PyTorch max sequence length to {kwargs_max_sql}")
        kwargs["max_seq_len"] = kwargs_max_sql
    elif backend.lower() == "tensorrt":
        assert max_seq_len is None, (
    elif options.backend.lower() == "tensorrt":
        assert options.max_seq_len is None, (
            "max_seq_len is not a runtime parameter for C++ backend")
        exec_settings, build_cfg = get_settings_from_engine(engine_dir)
        exec_settings, build_cfg = get_settings_from_engine(options.engine_dir)
        engine_max_seq_len = build_cfg["max_seq_len"]

        if metadata.max_sequence_length > engine_max_seq_len:
@@ -259,17 +245,18 @@ def latency_command(
                "support this dataset.")
    else:
        raise RuntimeError(
            f"Invalid backend: {backend}, please use one of the following: "
            f"Invalid backend: {options.backend}, please use one of the following: "
            f"{ALL_SUPPORTED_BACKENDS}")

    exec_settings["model"] = model
    exec_settings["model"] = options.model
    engine_tokens = exec_settings["settings_config"]["max_num_tokens"]

    # Update configuration with runtime options
    exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent
    exec_settings["settings_config"][
        "kv_cache_percent"] = options.kv_cache_percent
    exec_settings["settings_config"]["max_batch_size"] = 1
    exec_settings["settings_config"]["max_num_tokens"] = engine_tokens
    exec_settings["settings_config"]["beam_width"] = beam_width
    exec_settings["settings_config"]["beam_width"] = options.beam_width
    exec_settings["settings_config"]["chunking"] = False
    exec_settings["settings_config"][
        "scheduler_policy"] = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
@@ -286,6 +273,8 @@ def latency_command(
    exec_settings["performance_options"]["cuda_graphs"] = True
    exec_settings["performance_options"]["multi_block_mode"] = True

    exec_settings["extra_llm_api_options"] = params.get("extra_llm_api_options")

    # Decoding Options
    if medusa_choices is not None:
        with open(medusa_choices, "r") as medusa_yml:
@@ -297,31 +286,12 @@ def latency_command(

    llm = None
    kwargs = kwargs | runtime_config.get_llm_args()
    kwargs['backend'] = backend
    kwargs['backend'] = options.backend

    try:
        logger.info("Setting up latency benchmark.")

        if "pytorch_backend_config" in kwargs and iteration_log is not None:
            kwargs["pytorch_backend_config"].enable_iter_perf_stats = True

        if runtime_config.backend == 'pytorch':
            if kwargs.pop("extended_runtime_perf_knob_config", None):
                logger.warning(
                    "Ignore extended_runtime_perf_knob_config for pytorch backend."
                )
            llm = PyTorchLLM(**kwargs)
        elif runtime_config.backend == "_autodeploy":
            if kwargs.pop("extended_runtime_perf_knob_config", None):
                logger.warning(
                    "Ignore extended_runtime_perf_knob_config for _autodeploy backend."
                )
            kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
            kwargs.pop("pipeline_parallel_size", None)

            llm = AutoDeployLLM(**kwargs)
        else:
            llm = LLM(**kwargs)
        llm = get_llm(runtime_config, kwargs)

        ignore_eos = True if runtime_config.decoding_config.decoding_mode == SpeculativeDecodingMode.NONE else False
        eos_id = tokenizer.eos_token_id if not ignore_eos else -1
@@ -330,9 +300,10 @@ def latency_command(
        sampler_args = {
            "end_id": eos_id,
            "pad_id": pad_id,
            "n": beam_width,
            "use_beam_search": beam_width > 1
            "n": options.beam_width,
            "use_beam_search": options.beam_width > 1
        }

        sampler_args = update_sampler_args_with_extra_options(
            sampler_args, params.pop("sampler_options"))
        sampling_params = SamplingParams(**sampler_args)
@@ -340,9 +311,9 @@ def latency_command(
        post_proc_params = None  # No detokenization

        # Perform warmup if requested.
        if warmup > 0:
        if options.warmup > 0:
            logger.info("Setting up for warmup...")
            warmup_dataset = generate_warmup_dataset(requests, warmup)
            warmup_dataset = generate_warmup_dataset(requests, options.warmup)
            logger.info("Running warmup.")
            asyncio.run(
                async_benchmark(llm,
@@ -350,14 +321,15 @@ def latency_command(
                                post_proc_params,
                                warmup_dataset,
                                False,
                                concurrency,
                                modality=modality))
                                options.concurrency,
                                modality=options.modality))
            # WAR: IterationResult is a singleton tied to the executor.
            # Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
            # we must reset it to ensure it attaches to the correct event loop.
            llm._executor._iter_stats_result = None
            logger.info("Warmup done.")

        iteration_writer = options.iteration_writer
        with iteration_writer.capture():
            statistics = asyncio.run(
                async_benchmark(llm,
@@ -365,23 +337,27 @@ def latency_command(
                                post_proc_params,
                                requests,
                                True,
                                concurrency,
                                options.concurrency,
                                iteration_writer.full_address,
                                modality=modality))
                                modality=options.modality))

        logger.info(f"Benchmark done. Reporting results...")
        logger.info("Benchmark done. Reporting results...")

        if modality is not None:
        if options.modality is not None:
            # For multimodal models, we need to update the metadata with the correct input lengths
            metadata = update_metadata_for_multimodal(metadata, statistics)

        report_utility = ReportUtility(statistics, metadata, runtime_config,
                                       logger, kwargs, True)
        if report_json:
            logger.info(f"Writing report to '{report_json}'.")
            with open(report_json, "w") as f:
                f.write(
                    json.dumps(report_utility.get_statistics_dict(), indent=4))
        # Generate reports for statistics, output tokens, and request info.
        generate_json_report(options.report_json,
                             report_utility.get_statistics_dict)
        generate_json_report(
            options.output_json,
            partial(report_utility.get_output_tokens, tokenizer))
        generate_json_report(
            options.request_json,
            partial(report_utility.get_request_info, tokenizer))
        report_utility.report_statistics()
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt, exiting benchmark...")

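The latency command above now routes all of its JSON output through generate_json_report, deferring the actual work with functools.partial. A small illustrative sketch of that pattern (not from the commit; the callable and path below are made up): the callable runs and a file is written only when a path was actually supplied.

from functools import partial
from pathlib import Path

from tensorrt_llm.bench.benchmark import generate_json_report

def collect_stats(scale: int) -> dict:
    # Stand-in for ReportUtility.get_statistics_dict / get_output_tokens.
    return {"requests": 8 * scale, "tokens": 1024 * scale}

generate_json_report(None, partial(collect_stats, 1))  # no path: skipped, callable never runs
generate_json_report(Path("/tmp/report.json"), partial(collect_stats, 2))  # writes the JSON report
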
@@ -1,8 +1,8 @@
from __future__ import annotations

import asyncio
import json
import sys
from functools import partial
from pathlib import Path

import click
@@ -10,18 +10,16 @@ from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
                                optgroup)
from huggingface_hub import snapshot_download

from tensorrt_llm.bench.benchmark import (GeneralExecSettings,
                                          generate_json_report,
                                          get_general_cli_options, get_llm)
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir

# isort: off
from tensorrt_llm.bench.benchmark.utils.general import (
    get_settings_from_engine, get_settings, ALL_SUPPORTED_BACKENDS)
# isort: on
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.bench.benchmark.utils.general import (
    generate_warmup_dataset, update_sampler_args_with_extra_options)
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
@@ -293,10 +291,22 @@ def throughput_command(
    **params,
) -> None:
    """Run a throughput test on a TRT-LLM engine."""

    logger.info("Preparing to run throughput benchmark...")
    # Parameters from CLI
    # Model, experiment, and engine params
    image_data_format: str = params.get("image_data_format", "pt")
    data_device: str = params.get("data_device", "cpu")
    no_skip_tokenizer_init: bool = params.get("no_skip_tokenizer_init", False)

    # Get general CLI options using the centralized function
    options: GeneralExecSettings = get_general_cli_options(params, bench_env)
    tokenizer = initialize_tokenizer(options.checkpoint_path)

    # Extract throughput-specific options not handled by GeneralExecSettings
    max_batch_size = params.get("max_batch_size")
    max_num_tokens = params.get("max_num_tokens")
    enable_chunked_context: bool = params.get("enable_chunked_context")
    scheduler_policy: str = params.get("scheduler_policy")

    custom_module_dirs: list[Path] = params.pop("custom_module_dirs", [])
    for custom_module_dir in custom_module_dirs:
        try:
@@ -306,77 +316,50 @@ def throughput_command(
                f"Failed to import custom module from {custom_module_dir}: {e}")
            raise e

    dataset_path: Path = params.get("dataset")
    no_skip_tokenizer_init: bool = params.get("no_skip_tokenizer_init", False)
    eos_id: int = params.get("eos_id")
    warmup: int = params.get("warmup")
    num_requests: int = params.get("num_requests")
    max_seq_len: int = params.get("max_seq_len")
    model: str = bench_env.model
    checkpoint_path: Path = bench_env.checkpoint_path or bench_env.model
    engine_dir: Path = params.get("engine_dir")
    concurrency: int = params.get("concurrency")
    backend: str = params.get("backend")
    modality: str = params.get("modality")
    max_input_len: int = params.get("max_input_len")
    image_data_format: str = params.get("image_data_format", "pt")
    data_device: str = params.get("data_device", "cpu")
    model_type = get_model_config(model, checkpoint_path).model_type

    # Reporting options
    report_json: Path = params.get("report_json")
    output_json: Path = params.get("output_json")
    request_json: Path = params.get("request_json")
    iteration_log: Path = params.get("iteration_log")
    iteration_writer = IterationWriter(iteration_log)

    # Runtime kwargs and option tracking.
    kwargs = {}

    # Initialize the HF tokenizer for the specified model. This is only used for data preparation.
    tokenizer = initialize_tokenizer(checkpoint_path)

    # Dataset Loading and Preparation
    with open(dataset_path, "r") as dataset:
    with open(options.dataset_path, "r") as dataset:
        metadata, requests = create_dataset_from_stream(
            tokenizer,
            dataset,
            num_requests=num_requests,
            model_dir=checkpoint_path,
            model_type=model_type,
            modality=modality,
            num_requests=options.num_requests,
            model_dir=options.checkpoint_path,
            model_type=options.model_type,
            modality=options.modality,
            image_data_format=image_data_format,
            data_device=data_device,
            max_input_seq_len_for_multimodal=max_input_len)
    metadata.dataset_path = dataset_path
            max_input_seq_len_for_multimodal=options.max_input_len)
    metadata.dataset_path = options.dataset_path
    params["target_input_len"] = params.get(
        "target_input_len") or metadata.avg_isl
    params["target_output_len"] = params.get(
        "target_output_len") or metadata.avg_osl

    if modality is None:
    if options.modality is None:
        # Log dataset info
        # NOTE: This table is only accurate for non-multimodal models.
        # The accurate table for multimodal models will be logged after the benchmark is done.
        logger.info(metadata.get_summary_for_print())

    # Engine configuration parsing
    if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
    ) != "tensorrt":
    if options.backend and options.backend.lower(
    ) in ALL_SUPPORTED_BACKENDS and options.backend.lower() != "tensorrt":
        # If we're dealing with a model name, perform a snapshot download to
        # make sure we have a local copy of the model.
        if bench_env.checkpoint_path is None:
            snapshot_download(model)
            snapshot_download(options.model)

        exec_settings = get_settings(params, metadata, bench_env.model,
                                     bench_env.checkpoint_path)
        kwargs_max_sql = max_seq_len or metadata.max_sequence_length
        kwargs_max_sql = options.max_seq_len or metadata.max_sequence_length
        logger.info(f"Setting PyTorch max sequence length to {kwargs_max_sql}")
        kwargs["max_seq_len"] = kwargs_max_sql
    elif backend.lower() == "tensorrt":
        assert max_seq_len is None, (
    elif options.backend.lower() == "tensorrt":
        assert options.max_seq_len is None, (
            "max_seq_len is not a runtime parameter for C++ backend")
        exec_settings, build_cfg = get_settings_from_engine(engine_dir)
        exec_settings, build_cfg = get_settings_from_engine(options.engine_dir)
        engine_max_seq_len = build_cfg["max_seq_len"]

        # TODO: Verify that the engine can handle the max/min ISL/OSL.
@@ -388,29 +371,23 @@ def throughput_command(
                "to support this dataset.")
    else:
        raise RuntimeError(
            f"Invalid backend: {backend}, please use one of the following: "
            f"Invalid backend: {options.backend}, please use one of the following: "
            "pytorch, tensorrt, _autodeploy.")

    exec_settings["model"] = model
    exec_settings["model"] = options.model
    engine_bs = exec_settings["settings_config"]["max_batch_size"]
    engine_tokens = exec_settings["settings_config"]["max_num_tokens"]

    # Runtime Options
    runtime_max_bs = params.get("max_batch_size")
    runtime_max_tokens = params.get("max_num_tokens")
    runtime_max_bs = runtime_max_bs or engine_bs
    runtime_max_tokens = runtime_max_tokens or engine_tokens
    kv_cache_percent = params.get("kv_cache_free_gpu_mem_fraction")
    beam_width = params.get("beam_width")
    streaming: bool = params.get("streaming")
    enable_chunked_context: bool = params.get("enable_chunked_context")
    scheduler_policy: str = params.get("scheduler_policy")
    runtime_max_bs = max_batch_size or engine_bs
    runtime_max_tokens = max_num_tokens or engine_tokens

    # Update configuration with runtime options
    exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent
    exec_settings["settings_config"][
        "kv_cache_percent"] = options.kv_cache_percent
    exec_settings["settings_config"]["max_batch_size"] = runtime_max_bs
    exec_settings["settings_config"]["max_num_tokens"] = runtime_max_tokens
    exec_settings["settings_config"]["beam_width"] = beam_width
    exec_settings["settings_config"]["beam_width"] = options.beam_width
    exec_settings["settings_config"][
        "scheduler_policy"] = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT if scheduler_policy == "guaranteed_no_evict" else CapacitySchedulerPolicy.MAX_UTILIZATION
    exec_settings["settings_config"]["chunking"] = enable_chunked_context
@@ -420,48 +397,25 @@ def throughput_command(

    # LlmArgs
    exec_settings["extra_llm_api_options"] = params.pop("extra_llm_api_options")
    exec_settings["iteration_log"] = iteration_log
    exec_settings["iteration_log"] = options.iteration_log

    # Construct the runtime configuration dataclass.
    runtime_config = RuntimeConfig(**exec_settings)
    llm = None

    def ignore_trt_only_args(kwargs: dict):
        trt_only_args = [
            "batching_type",
            "normalize_log_probs",
            "extended_runtime_perf_knob_config",
        ]
        for arg in trt_only_args:
            if kwargs.pop(arg, None):
                logger.warning(
                    f"Ignore {arg} for {runtime_config.backend} backend.")

    try:
        logger.info("Setting up throughput benchmark.")
        kwargs = kwargs | runtime_config.get_llm_args()
        kwargs['backend'] = backend
        kwargs['skip_tokenizer_init'] = not no_skip_tokenizer_init
        kwargs['backend'] = options.backend

        if backend == "pytorch" and iteration_log is not None:
            kwargs["enable_iter_perf_stats"] = True

        if runtime_config.backend == 'pytorch':
            ignore_trt_only_args(kwargs)
            llm = PyTorchLLM(**kwargs)
        elif runtime_config.backend == "_autodeploy":
            ignore_trt_only_args(kwargs)
            kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)

            llm = AutoDeployLLM(**kwargs)
        else:
            llm = LLM(**kwargs)
        llm = get_llm(runtime_config, kwargs)

        sampler_args = {
            "end_id": eos_id,
            "pad_id": eos_id,
            "n": beam_width,
            "use_beam_search": beam_width > 1
            "end_id": options.eos_id,
            "pad_id": options.eos_id,
            "n": options.beam_width,
            "use_beam_search": options.beam_width > 1
        }
        sampler_args = update_sampler_args_with_extra_options(
            sampler_args, params.pop("sampler_options"))
@@ -470,9 +424,9 @@ def throughput_command(
        post_proc_params = None  # No detokenization

        # Perform warmup if requested.
        if warmup > 0:
        if options.warmup > 0:
            logger.info("Setting up for warmup...")
            warmup_dataset = generate_warmup_dataset(requests, warmup)
            warmup_dataset = generate_warmup_dataset(requests, options.warmup)
            logger.info("Running warmup.")
            asyncio.run(
                async_benchmark(llm,
@@ -480,46 +434,42 @@ def throughput_command(
                                post_proc_params,
                                warmup_dataset,
                                False,
                                concurrency,
                                modality=modality))
                                options.concurrency,
                                modality=options.modality))
            # WAR: IterationResult is a singleton tied to the executor.
            # Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
            # we must reset it to ensure it attaches to the correct event loop.
            llm._executor._iter_stats_result = None
            logger.info("Warmup done.")

        iteration_writer = options.iteration_writer
        with iteration_writer.capture():
            statistics = asyncio.run(
                async_benchmark(llm,
                                sampling_params,
                                post_proc_params,
                                requests,
                                streaming,
                                concurrency,
                                options.streaming,
                                options.concurrency,
                                iteration_writer.full_address,
                                modality=modality))
                                modality=options.modality))

        logger.info(f"Benchmark done. Reporting results...")
        if modality is not None:
        logger.info("Benchmark done. Reporting results...")
        if options.modality is not None:
            # For multimodal models, we need to update the metadata with the correct input lengths
            metadata = update_metadata_for_multimodal(metadata, statistics)

        report_utility = ReportUtility(statistics, metadata, runtime_config,
                                       logger, kwargs, streaming)
        if report_json:
            logger.info(f"Writing report to '{report_json}'.")
            with open(report_json, "w") as f:
                f.write(
                    json.dumps(report_utility.get_statistics_dict(), indent=4))
        if output_json:
            logger.info(f"Writing output to {output_json}.")
            with open(output_json, "w") as f:
                output_token_info = report_utility.get_output_tokens(tokenizer)
                f.write(json.dumps(output_token_info, indent=4))
        if request_json:
            logger.info(f"Writing request information to {request_json}.")
            with open(request_json, "w") as f:
                f.write(json.dumps(report_utility.get_request_info(tokenizer)))
                                       logger, kwargs, options.streaming)
        # Generate reports for statistics, output tokens, and request info.
        generate_json_report(options.report_json,
                             report_utility.get_statistics_dict)
        generate_json_report(
            options.output_json,
            partial(report_utility.get_output_tokens, tokenizer))
        generate_json_report(
            options.request_json,
            partial(report_utility.get_request_info, tokenizer))
        report_utility.report_statistics()
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt, exiting benchmark...")

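For completeness, a rough sketch of what the shared TensorRT-only argument filtering does before an LLM is constructed for a non-TensorRT backend (illustrative only; the kwargs below are placeholders, and ignore_trt_only_args is assumed to be importable from the same module as the other helpers):

from tensorrt_llm.bench.benchmark import ignore_trt_only_args

# Placeholder LLM kwargs; only the TensorRT-only entries matter here.
kwargs = {
    "max_seq_len": 4096,
    "normalize_log_probs": True,                # TensorRT-only and truthy: dropped with a warning
    "extended_runtime_perf_knob_config": None,  # TensorRT-only but falsy: dropped silently
}
ignore_trt_only_args(kwargs, backend="pytorch")
assert kwargs == {"max_seq_len": 4096}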