From 0c7dd660d826e48cce3019aea2e459842ee3173a Mon Sep 17 00:00:00 2001
From: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Date: Sat, 7 Jun 2025 04:14:07 +0800
Subject: [PATCH] fix:https://nvbugs/5324248 (#4973)

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/model_engine.py     |  3 ++-
 tests/unittest/_torch/test_pytorch_model_engine.py | 11 ++++++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index aab4b1e823..ddba0c9d6a 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -313,7 +313,8 @@ class PyTorchModelEngine(ModelEngine):
         if mapping.has_pp():
             init_pp_comm(mapping)
         self.dist = dist
-        ExpertStatistic.create(self.dist.rank)
+        if dist is not None:
+            ExpertStatistic.create(self.dist.rank)
         self.pytorch_backend_config = pytorch_backend_config
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
diff --git a/tests/unittest/_torch/test_pytorch_model_engine.py b/tests/unittest/_torch/test_pytorch_model_engine.py
index 672da52723..b52884a9b9 100644
--- a/tests/unittest/_torch/test_pytorch_model_engine.py
+++ b/tests/unittest/_torch/test_pytorch_model_engine.py
@@ -96,6 +96,9 @@ def create_model_engine_and_kvcache(config: PyTorchConfig = None):
     config = config if config else PyTorchConfig(
         use_cuda_graph=True, cuda_graph_padding_enabled=True)
 
+    config.cuda_graph_batch_sizes = [
+        1, 2, 4, 8, 16, 32, 64, 128
+    ] if config.cuda_graph_batch_sizes is None else config.cuda_graph_batch_sizes
     test_batches = (5, 13)
     for batch_size in test_batches:
         assert batch_size not in config.cuda_graph_batch_sizes
@@ -153,6 +156,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
         batch.context_requests = []
         batch.generation_requests = requests
         pages_before = kv_cache_manager.get_num_free_blocks()
+        new_dummy_block = 1 if model_engine.cuda_graph_dummy_request is None else 0
         with model_engine._maybe_pad_batch(
                 batch, kv_cache_manager) as padded_batch:
             if batch_size < 8 and max_seq_len < 25:
@@ -165,8 +169,9 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
                 # The seqlen check makes sure we don't exceed the KV cache memory
                 # budget.
                 self.assertIs(batch, padded_batch)
-                self.assertEqual(kv_cache_manager.get_num_free_blocks(),
-                                 pages_before)
+                self.assertEqual(
+                    kv_cache_manager.get_num_free_blocks() + new_dummy_block,
+                    pages_before)
 
         kv_cache_manager.shutdown()
 
@@ -205,7 +210,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
             model_engine.forward(batch, resource_manager)
 
         expected_gen_pos_id = torch.tensor([prompt_len],
-                                           dtype=torch.int64,
+                                           dtype=torch.int32,
                                            device='cuda').unsqueeze(0)
         torch.testing.assert_close(model_engine.model.recorded_position_ids,
                                    expected_gen_pos_id,
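
The model_engine.py hunk guards ExpertStatistic.create() so that constructing
the engine with dist=None no longer dereferences self.dist.rank. Below is a
minimal, self-contained sketch of that failure mode and the guard;
ExpertStatistic here is a stub and EngineSketch is a hypothetical stand-in for
PyTorchModelEngine's constructor logic, not the real TensorRT-LLM classes.

    class ExpertStatistic:
        """Stub for tensorrt_llm's ExpertStatistic; only create() matters here."""

        @staticmethod
        def create(rank: int) -> None:
            print(f"expert statistics enabled for rank {rank}")


    class EngineSketch:
        """Hypothetical stand-in for PyTorchModelEngine.__init__."""

        def __init__(self, dist=None):
            self.dist = dist
            # Before the fix, ExpertStatistic.create(self.dist.rank) ran
            # unconditionally and raised AttributeError when dist was None.
            if dist is not None:
                ExpertStatistic.create(self.dist.rank)


    EngineSketch(dist=None)  # previously: AttributeError; now: a no-op

The test-side hunks follow the same defensive pattern: cuda_graph_batch_sizes
receives an explicit default only when it is unset, and the free-block
assertion adds new_dummy_block to account for the single KV-cache block that
_maybe_pad_batch consumes when it has to create the CUDA graph dummy request.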