[TRTLLM-5233][feat]: Add chunking to PyT heuristic for trtllm-bench. (#4133)

* Add chunking to PyT heuristic.

Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>

* Cast tokens and batch size to ints.

Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>

---------

Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
Frank 2025-05-13 09:47:06 -04:00 committed by GitHub
parent 44d6adfb68
commit c0c3c7f68c

@@ -83,6 +83,8 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
         Dict[str, Union[str, int]]: Properties for runtime config.
     """
     extra_llm_api_options = params.get("extra_llm_api_options")
+    enable_chunked_prefill = params.get("enable_chunked_prefill", False)
+
     kv_cache_dtype = "auto"
     if extra_llm_api_options:
         with open(extra_llm_api_options, 'r') as f:
@@ -92,6 +94,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
             kv_cache_dtype = llm_args_dict["pytorch_backend_config"][
                 "kv_cache_dtype"]
+
+        enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
+                                                   enable_chunked_prefill)
 
     world_config = {
         "pp_size": params.get("pp"),
         "tp_size": params.get("tp"),
@@ -133,6 +138,13 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
         f"Using heuristics or pre-defined settings: max_batch_size={max_batch_size}, max_num_tokens={max_num_tokens}."
     )
 
+    # If chunked prefill is disabled, we need to ensure that the max_num_tokens is at least the max_isl
+    if not enable_chunked_prefill and max_num_tokens < dataset_metadata.max_isl:
+        logger.warning(
+            f"Chunked prefill is disabled, but max_num_tokens ({max_num_tokens}) is less than the max ISL ({dataset_metadata.max_isl}). "
+            f"Forcing max_num_tokens to {dataset_metadata.max_isl}.")
+        max_num_tokens = dataset_metadata.max_isl
+
     pyt_options = {
         "use_cuda_graph": True,
         "cuda_graph_padding_enabled": True,
@@ -146,9 +158,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
         "sw_version": version("tensorrt_llm"),
         "model_path": model_path,
         "settings_config": {
-            "max_batch_size": max_batch_size,
-            "max_num_tokens": max_num_tokens,
-            "chunking": False,
+            "max_batch_size": int(max_batch_size),
+            "max_num_tokens": int(max_num_tokens),
+            "chunking": enable_chunked_prefill,
         },
         "world_config": world_config,
         "backend": backend,