[fix] Update get_trtllm_bench_build_command to handle batch size and tokens (#6313)

Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
This commit is contained in:
Venky 2025-07-31 21:08:09 -07:00 committed by GitHub
parent 4472f11bb7
commit ad5742b105
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -998,7 +998,6 @@ class MultiMetricPerfTest(AbstractPerfScriptTestClass):
def get_trtllm_bench_build_command(self, engine_dir) -> list:
model_dir = self.get_trtllm_bench_model()
dataset_path = os.path.join(engine_dir, "synthetic_data.json")
if model_dir == "":
pytest.skip("Model Name is not supported by trtllm-bench")
model_name = self._config.model_name
@@ -1008,13 +1007,19 @@ class MultiMetricPerfTest(AbstractPerfScriptTestClass):
build_cmd = [
self._build_script, f"--log_level=info",
f"--workspace={engine_dir}", f"--model={hf_model_name}",
f"--model_path={model_dir}", "build", f"--dataset={dataset_path}",
f"--model_path={model_dir}", "build",
f"--tp_size={self._config.tp_size}",
f"--pp_size={self._config.pp_size}"
]
max_seq_len = max(self._config.input_lens) + max(
self._config.output_lens)
build_cmd.append(f"--max_seq_len={max_seq_len}")
# Add max_batch_size and max_num_tokens to ensure build matches runtime configuration
# Note: trtllm-bench requires both to be specified together (option group constraint)
assert self._config.max_batch_size > 0, f"max_batch_size must be > 0, got {self._config.max_batch_size}"
assert self._config.max_num_tokens > 0, f"max_num_tokens must be > 0, got {self._config.max_num_tokens}"
build_cmd.append(f"--max_batch_size={self._config.max_batch_size}")
build_cmd.append(f"--max_num_tokens={self._config.max_num_tokens}")
if self._config.quantization:
build_cmd.append(
f"--quantization={self._config.quantization.upper()}")