mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
fix: https://nvbugs/5324248 (#4973)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
parent 20d0649f19
commit 0c7dd660d8
@@ -313,7 +313,8 @@ class PyTorchModelEngine(ModelEngine):
         if mapping.has_pp():
             init_pp_comm(mapping)
         self.dist = dist
-        ExpertStatistic.create(self.dist.rank)
+        if dist is not None:
+            ExpertStatistic.create(self.dist.rank)
         self.pytorch_backend_config = pytorch_backend_config
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
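
The guard above matters when the engine is constructed without a distributed context. A minimal, self-contained sketch of the failure mode; ExpertStatistic here is a stand-in stub for illustration, not the TensorRT-LLM class:

# Stand-in stub, not the TensorRT-LLM ExpertStatistic.
class ExpertStatistic:
    @staticmethod
    def create(rank):
        print(f"collecting expert statistics for rank {rank}")

class Engine:
    def __init__(self, dist=None):
        self.dist = dist
        # Before the fix, ExpertStatistic.create(self.dist.rank) ran
        # unconditionally and raised AttributeError when dist was None.
        if dist is not None:
            ExpertStatistic.create(self.dist.rank)

Engine(dist=None)  # single-process construction no longer crashes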
@@ -96,6 +96,9 @@ def create_model_engine_and_kvcache(config: PyTorchConfig = None):
 
     config = config if config else PyTorchConfig(
         use_cuda_graph=True, cuda_graph_padding_enabled=True)
+    config.cuda_graph_batch_sizes = [
+        1, 2, 4, 8, 16, 32, 64, 128
+    ] if config.cuda_graph_batch_sizes is None else config.cuda_graph_batch_sizes
     test_batches = (5, 13)
     for batch_size in test_batches:
         assert batch_size not in config.cuda_graph_batch_sizes
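
The default list above deliberately excludes the test batch sizes 5 and 13, so CUDA-graph padding has to round each one up to the next captured size. A hedged sketch of that rounding rule (padded_size is a hypothetical helper for illustration, not TensorRT-LLM API):

import bisect

def padded_size(batch_size, cuda_graph_batch_sizes):
    # Round a runtime batch size up to the nearest captured CUDA graph size;
    # batches larger than the largest captured graph run unpadded.
    sizes = sorted(cuda_graph_batch_sizes)
    i = bisect.bisect_left(sizes, batch_size)
    return sizes[i] if i < len(sizes) else batch_size

assert padded_size(5, [1, 2, 4, 8, 16, 32, 64, 128]) == 8
assert padded_size(13, [1, 2, 4, 8, 16, 32, 64, 128]) == 16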
@@ -153,6 +156,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
         batch.context_requests = []
         batch.generation_requests = requests
         pages_before = kv_cache_manager.get_num_free_blocks()
+        new_dummy_block = 1 if model_engine.cuda_graph_dummy_request is None else 0
         with model_engine._maybe_pad_batch(
                 batch, kv_cache_manager) as padded_batch:
             if batch_size < 8 and max_seq_len < 25:
@@ -165,8 +169,9 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
                 # The seqlen check makes sure we don't exceed the KV cache memory
                 # budget.
                 self.assertIs(batch, padded_batch)
-            self.assertEqual(kv_cache_manager.get_num_free_blocks(),
-                             pages_before)
+            self.assertEqual(
+                kv_cache_manager.get_num_free_blocks() + new_dummy_block,
+                pages_before)
 
         kv_cache_manager.shutdown()
 
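
The new new_dummy_block term encodes the block accounting the assertion now tolerates: the first padded batch creates a persistent cuda_graph_dummy_request that pins one KV-cache block, and later paddings reuse it. An illustrative sketch of that invariant (the block counts are made up):

pages_before = 100

# First padding: no dummy request exists yet, so one block gets pinned.
new_dummy_block = 1
free_after_pad = pages_before - new_dummy_block
assert free_after_pad + new_dummy_block == pages_before

# Later paddings: the dummy request is reused, no additional block is taken.
pages_before = free_after_pad
new_dummy_block = 0
assert free_after_pad + new_dummy_block == pages_before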
@@ -205,7 +210,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
 
         model_engine.forward(batch, resource_manager)
         expected_gen_pos_id = torch.tensor([prompt_len],
-                                           dtype=torch.int64,
+                                           dtype=torch.int32,
                                            device='cuda').unsqueeze(0)
         torch.testing.assert_close(model_engine.model.recorded_position_ids,
                                    expected_gen_pos_id,
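
torch.testing.assert_close compares dtypes as well as values by default, so the expected tensor must be built with the dtype the engine actually records for position ids. A standalone CPU sketch of why the stale int64 expectation fails; recorded is a stand-in for model_engine.model.recorded_position_ids:

import torch

prompt_len = 4
recorded = torch.tensor([[prompt_len]], dtype=torch.int32)
expected = torch.tensor([prompt_len], dtype=torch.int32).unsqueeze(0)
torch.testing.assert_close(recorded, expected)  # passes: same dtype and value

stale = torch.tensor([prompt_len], dtype=torch.int64).unsqueeze(0)
try:
    torch.testing.assert_close(recorded, stale)
except AssertionError:
    print("assert_close rejects the int32 vs. int64 dtype mismatch")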