Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
This commit is contained in:
nv-guomingz 2025-06-07 04:14:07 +08:00 committed by GitHub
parent 20d0649f19
commit 0c7dd660d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 4 deletions

View File

@ -313,7 +313,8 @@ class PyTorchModelEngine(ModelEngine):
if mapping.has_pp():
init_pp_comm(mapping)
self.dist = dist
ExpertStatistic.create(self.dist.rank)
if dist is not None:
ExpertStatistic.create(self.dist.rank)
self.pytorch_backend_config = pytorch_backend_config
self.spec_config = spec_config
self.is_spec_decode = spec_config is not None

View File

@ -96,6 +96,9 @@ def create_model_engine_and_kvcache(config: PyTorchConfig = None):
config = config if config else PyTorchConfig(
use_cuda_graph=True, cuda_graph_padding_enabled=True)
config.cuda_graph_batch_sizes = [
1, 2, 4, 8, 16, 32, 64, 128
] if config.cuda_graph_batch_sizes is None else config.cuda_graph_batch_sizes
test_batches = (5, 13)
for batch_size in test_batches:
assert batch_size not in config.cuda_graph_batch_sizes
@ -153,6 +156,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
batch.context_requests = []
batch.generation_requests = requests
pages_before = kv_cache_manager.get_num_free_blocks()
new_dummy_block = 1 if model_engine.cuda_graph_dummy_request is None else 0
with model_engine._maybe_pad_batch(
batch, kv_cache_manager) as padded_batch:
if batch_size < 8 and max_seq_len < 25:
@ -165,8 +169,9 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
# The seqlen check makes sure we don't exceed the KV cache memory
# budget.
self.assertIs(batch, padded_batch)
self.assertEqual(kv_cache_manager.get_num_free_blocks(),
pages_before)
self.assertEqual(
kv_cache_manager.get_num_free_blocks() + new_dummy_block,
pages_before)
kv_cache_manager.shutdown()
@ -205,7 +210,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
model_engine.forward(batch, resource_manager)
expected_gen_pos_id = torch.tensor([prompt_len],
dtype=torch.int64,
dtype=torch.int32,
device='cuda').unsqueeze(0)
torch.testing.assert_close(model_engine.model.recorded_position_ids,
expected_gen_pos_id,