From 137fe35539ea182f1495f5021bfda97c729e50c3 Mon Sep 17 00:00:00 2001
From: Yukun He <23156053+hyukn@users.noreply.github.com>
Date: Mon, 9 Jun 2025 19:19:16 +0800
Subject: [PATCH] fix: Fix warmup phase batch size out of range. (#4986)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
---
 3rdparty/cutlass                               |  2 +-
 tensorrt_llm/_torch/pyexecutor/model_engine.py | 13 +++++++------
 tests/integration/test_lists/waives.txt        |  1 -
 tests/unittest/llmapi/test_llm.py              |  1 -
 4 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index 8206e7a0f5..afa1772203 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit 8206e7a0f57a9a057cdd2c3bb4899bd5154a82e1
+Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 80bbcea186..f04425d9a6 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -584,16 +584,17 @@ class PyTorchModelEngine(ModelEngine):
 
             available_blocks = kv_cache_manager.get_num_free_blocks()
 
+            maximum_tunable_num_tokens = min(
+                self.batch_size * num_tokens_per_request, self.max_num_tokens,
+                available_blocks * kv_cache_manager.tokens_per_block)
+
             # Calculate number of full-length requests and remaining tokens
             # Each request has num_tokens_per_request tokens, except possibly the last one
-            full_len_request_num = self.max_num_tokens // num_tokens_per_request
-            remaining_tokens = self.max_num_tokens % num_tokens_per_request
+            full_len_request_num = maximum_tunable_num_tokens // num_tokens_per_request
+            remaining_tokens = maximum_tunable_num_tokens % num_tokens_per_request
 
             request_num = full_len_request_num if remaining_tokens == 0 else full_len_request_num + 1
 
-            if self.max_num_tokens > available_blocks * kv_cache_manager.tokens_per_block:
-                return None, None
-
             requests = kv_cache_manager.add_dummy_requests(
                 request_ids=list(range(full_len_request_num)),
                 token_nums=[num_tokens_per_request] * full_len_request_num,
@@ -617,7 +618,7 @@ class PyTorchModelEngine(ModelEngine):
 
             result.context_requests = requests
             result.generation_requests = []
-            return result, _create_extra_inputs(1, self.max_num_tokens)
+            return result, _create_extra_inputs(1, maximum_tunable_num_tokens)
 
         @contextlib.contextmanager
         def release_batch(result):
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index b4f14c182e..9ae913763e 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -428,7 +428,6 @@ test_e2e.py::test_ptq_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-
 unittest/_torch/auto_deploy/integration/test_ad_build.py SKIP (https://nvbugs/5318103)
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5318143)
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=True] SKIP (https://nvbugs/5318143)
-test_e2e.py::test_openai_reasoning SKIP (https://nvbugs/5310329)
 examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization_long] SKIP (https://nvbugs/5324976)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5322354)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP (https://nvbugs/5322354)
diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py
index 43cdd7ced8..bb44ba57d4 100644
--- a/tests/unittest/llmapi/test_llm.py
+++ b/tests/unittest/llmapi/test_llm.py
@@ -1930,7 +1930,6 @@ def test_llm_get_stats(return_context_logits, enable_iter_req_stats):
 
 
 def test_llm_get_queued_stats():
-    pytest.skip("https://nvbugspro.nvidia.com/bug/5325642")
     enable_iter_req_stats = True
     use_overlap = False
     tp_size = 1
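
Note (not part of the patch): the model_engine.py change replaces the old early return with a clamp, so the warmup batch never asks for more tokens than the configured batch size, max_num_tokens, or the free KV-cache blocks can hold. The snippet below is a minimal standalone sketch of that clamping logic; plan_warmup_requests is a hypothetical helper written only for illustration, its variable names mirror the patched code, and the values in the example call are made up.

    def plan_warmup_requests(batch_size, max_num_tokens, num_tokens_per_request,
                             available_blocks, tokens_per_block):
        # Cap the warmup token budget by the batch size, the engine token limit,
        # and the free KV-cache capacity (the three terms min()'d in the patch).
        maximum_tunable_num_tokens = min(batch_size * num_tokens_per_request,
                                         max_num_tokens,
                                         available_blocks * tokens_per_block)
        # Split the budget into full-length dummy requests plus an optional
        # shorter trailing request, mirroring the patched warmup logic.
        full_len_request_num = maximum_tunable_num_tokens // num_tokens_per_request
        remaining_tokens = maximum_tunable_num_tokens % num_tokens_per_request
        request_num = (full_len_request_num
                       if remaining_tokens == 0 else full_len_request_num + 1)
        return request_num, maximum_tunable_num_tokens

    # Example with made-up numbers: the batch-size cap (8 * 128 = 1024 tokens)
    # is the binding limit here, yielding 8 full-length warmup requests.
    print(plan_warmup_requests(8, 2048, 128, 512, 32))  # -> (8, 1024)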