[#9463][feat] Add revision option to trtllm commands (#9498)

Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Aurelien Chartier 2025-11-26 17:30:01 -08:00 committed by GitHub
parent e76e149861
commit f2f197360d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 41 additions and 11 deletions


@@ -225,7 +225,7 @@ def latency_command(
if options.backend and options.backend.lower(
) in ALL_SUPPORTED_BACKENDS and options.backend.lower() != "tensorrt":
if bench_env.checkpoint_path is None:
snapshot_download(options.model)
snapshot_download(options.model, revision=bench_env.revision)
exec_settings = get_settings(params, metadata, bench_env.model,
bench_env.checkpoint_path)
@@ -250,6 +250,7 @@ def latency_command(
param_hint="backend")
exec_settings["model"] = options.model
exec_settings["revision"] = bench_env.revision
engine_tokens = exec_settings["settings_config"]["max_num_tokens"]
# Update configuration with runtime options
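For reference, a minimal sketch of the download path the bench commands now take when no local checkpoint is given; snapshot_download and its revision parameter are the standard huggingface_hub API, and the model id below is only an example:

from huggingface_hub import snapshot_download

# Download (or reuse from cache) the checkpoint at a specific ref; revision may be
# a branch name, tag name, or commit id. None falls back to the default branch.
local_path = snapshot_download("TinyLlama/TinyLlama-1.1B-Chat-v1.0", revision="main")
print(local_path)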


@@ -350,7 +350,7 @@ def throughput_command(
# If we're dealing with a model name, perform a snapshot download to
# make sure we have a local copy of the model.
if bench_env.checkpoint_path is None:
snapshot_download(options.model)
snapshot_download(options.model, revision=bench_env.revision)
exec_settings = get_settings(params, metadata, bench_env.model,
bench_env.checkpoint_path)
@@ -376,6 +376,7 @@ def throughput_command(
param_hint="backend")
exec_settings["model"] = options.model
exec_settings["revision"] = bench_env.revision
engine_bs = exec_settings["settings_config"]["max_batch_size"]
engine_tokens = exec_settings["settings_config"]["max_num_tokens"]


@@ -25,6 +25,7 @@ class RuntimeConfig(BaseModel):
model: str
model_path: Optional[Path] = None
engine_dir: Optional[Path] = None
revision: Optional[str] = None
sw_version: str
settings_config: ExecutorSettingsConfig
# TODO: this is a dict corresponding to the Mapping class, the type should be


@@ -14,6 +14,7 @@ class BenchmarkEnvironment(BaseModel):
model: str
checkpoint_path: Optional[Path]
workspace: Path
revision: Optional[str] = None
class InferenceRequest(BaseModel):


@@ -306,6 +306,7 @@ class ReportUtility:
"model": self.rt_cfg.model,
"model_path": str(self.rt_cfg.model_path),
"engine_dir": str(self.rt_cfg.engine_dir),
"revision": self.rt_cfg.revision,
"version": self.rt_cfg.sw_version,
},
}
@@ -539,6 +540,7 @@ class ReportUtility:
"===========================================================\n"
f"Model:\t\t\t{engine['model']}\n"
f"Model Path:\t\t{engine['model_path']}\n"
f"Revision:\t\t{engine['revision'] or 'N/A'}\n"
f"Engine Directory:\t{engine['engine_dir']}\n"
f"TensorRT LLM Version:\t{engine['version']}\n"
f"Dtype:\t\t\t{pretrain_cfg['dtype']}\n"
@@ -554,6 +556,7 @@ class ReportUtility:
"===========================================================\n"
f"Model:\t\t\t{engine['model']}\n"
f"Model Path:\t\t{engine['model_path']}\n"
f"Revision:\t\t{engine['revision'] or 'N/A'}\n"
f"TensorRT LLM Version:\t{engine['version']}\n"
f"Dtype:\t\t\t{engine['dtype']}\n"
f"KV Cache Dtype:\t\t{engine['kv_cache_dtype']}\n"


@@ -89,21 +89,21 @@ def create_dataset_from_stream(
while (line := stream.readline()) and len(task_ids) < max_requests:
# We expect the data to come in as a JSON string.
# For example:
# {"prompt": "Generate an infinite response to the following:
# {"task_id": 1, "prompt": "Generate an infinite response to the following:
# There once was a man who.", "output_tokens": 1000}
#
# For multimodal data, the data should be of the form:
# {"prompt": "Generate an infinite response to the following:
# {"task_id": 1, "prompt": "Generate an infinite response to the following:
# There once was a man who.", "output_tokens": 1000,
# "media_paths": ["/path/to/image1.jpg", "/path/to/image2.jpg"]}
#
# For LoRA data, the data should be of the form:
# {"prompt": "Generate an infinite response to the following:
# {"task_id": 1, "prompt": "Generate an infinite response to the following:
# There once was a man who.", "output_tokens": 1000,
# "lora_request": {"lora_name": "my_lora", "lora_int_id": 1, "lora_path": "/path/to/lora"}}
#
# Each line should be a complete JSON dictionary with no indentation
# or newline characters.
# or newline characters. The task_id field is required.
data = json.loads(line)
prompts.append(data.get("prompt"))
media_paths.append(data.get("media_paths", None))
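A hedged example of a dataset line that satisfies this format (prompt text and token count are illustrative), including the now-required task_id:

import json

request = {
    "task_id": 1,  # required
    "prompt": "Generate an infinite response to the following: There once was a man who.",
    "output_tokens": 1000,
}
print(json.dumps(request))  # one compact JSON object per line, no embedded newlines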


@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Optional
import click
@@ -37,6 +38,11 @@ from tensorrt_llm.logger import logger, severity_map
type=click.Choice(severity_map.keys()),
default='info',
help="The logging level.")
@click.option("--revision",
type=str,
default=None,
help="The revision to use for the HuggingFace model "
"(branch name, tag name, or commit id).")
@click.pass_context
def main(
ctx,
@@ -44,11 +50,13 @@ def main(
model_path: Path,
workspace: Path,
log_level: str,
revision: Optional[str],
) -> None:
logger.set_level(log_level)
ctx.obj = BenchmarkEnvironment(model=model,
checkpoint_path=model_path,
workspace=workspace)
workspace=workspace,
revision=revision)
# Create the workspace where we plan to store intermediate files.
ctx.obj.workspace.mkdir(parents=True, exist_ok=True)
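Taken together, the group-level --revision flag lands on the BenchmarkEnvironment and is later consumed by snapshot_download. A stand-alone sketch that mirrors the fields from the diff rather than importing the real class:

from pathlib import Path
from typing import Optional
from pydantic import BaseModel

class BenchmarkEnvironment(BaseModel):  # stand-in mirroring the fields shown above
    model: str
    checkpoint_path: Optional[Path]
    workspace: Path
    revision: Optional[str] = None

env = BenchmarkEnvironment(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                           checkpoint_path=None,       # no local copy -> download later
                           workspace=Path("/tmp/trtllm_bench"),
                           revision="main")            # branch name, tag name, or commit id
print(env.revision)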


@@ -92,6 +92,11 @@ from ..logger import logger, severity_map
is_flag=True,
default=False,
help="Flag for HF transformers.")
@click.option("--revision",
type=str,
default=None,
help="The revision to use for the HuggingFace model "
"(branch name, tag name, or commit id).")
@click.option("--extra_llm_api_options",
type=str,
default=None,
@@ -106,7 +111,8 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int,
ep_size: Optional[int], gpus_per_node: Optional[int],
kv_cache_free_gpu_memory_fraction: float, trust_remote_code: bool,
extra_llm_api_options: Optional[str], disable_kv_cache_reuse: bool):
revision: Optional[str], extra_llm_api_options: Optional[str],
disable_kv_cache_reuse: bool):
logger.set_level(log_level)
build_config = BuildConfig(max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
@@ -125,6 +131,7 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
"moe_expert_parallel_size": ep_size,
"gpus_per_node": gpus_per_node,
"trust_remote_code": trust_remote_code,
"revision": revision,
"build_config": build_config,
"kv_cache_config": kv_cache_config,
}


@@ -95,6 +95,7 @@ def get_llm_args(
free_gpu_memory_fraction: float = 0.9,
num_postprocess_workers: int = 0,
trust_remote_code: bool = False,
revision: Optional[str] = None,
reasoning_parser: Optional[str] = None,
fail_fast_on_attention_window_too_large: bool = False,
otlp_traces_endpoint: Optional[str] = None,
@@ -129,6 +130,7 @@
"moe_expert_parallel_size": moe_expert_parallel_size,
"gpus_per_node": gpus_per_node,
"trust_remote_code": trust_remote_code,
"revision": revision,
"build_config": build_config,
"max_batch_size": max_batch_size,
"max_num_tokens": max_num_tokens,
@@ -317,6 +319,11 @@ class ChoiceWithAlias(click.Choice):
is_flag=True,
default=False,
help="Flag for HF transformers.")
@click.option("--revision",
type=str,
default=None,
help="The revision to use for the HuggingFace model "
"(branch name, tag name, or commit id).")
@click.option(
"--extra_llm_api_options",
type=str,
@@ -381,9 +388,9 @@ def serve(
ep_size: Optional[int], cluster_size: Optional[int],
gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
num_postprocess_workers: int, trust_remote_code: bool,
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
tool_parser: Optional[str], metadata_server_config_file: Optional[str],
server_role: Optional[str],
revision: Optional[str], extra_llm_api_options: Optional[str],
reasoning_parser: Optional[str], tool_parser: Optional[str],
metadata_server_config_file: Optional[str], server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool,
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
@@ -418,6 +425,7 @@ def serve(
free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
num_postprocess_workers=num_postprocess_workers,
trust_remote_code=trust_remote_code,
revision=revision,
reasoning_parser=reasoning_parser,
fail_fast_on_attention_window_too_large=
fail_fast_on_attention_window_too_large,
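Since a revision can be a moving ref (branch or tag), it can also be pinned to an exact commit up front using the regular huggingface_hub API; a small sketch with an example model id:

from huggingface_hub import HfApi

info = HfApi().model_info("TinyLlama/TinyLlama-1.1B-Chat-v1.0", revision="main")
print(info.sha)  # the commit id the branch or tag currently resolves to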