[TRTLLM-5516] perf: replicate dummy request for cuda graph padding (cherry-pick #4729) (#5190)

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
Committed by Kaiyu Xie on 2025-06-14 00:36:15 +08:00 (via GitHub)
parent bfa3b59bb6
commit 746394e990

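For context on the change below: instead of allocating one dummy generation request per padding slot (each of which consumed KV cache blocks that later had to be freed), the engine now lazily creates a single cached dummy request and replicates it to fill the batch up to the nearest captured CUDA graph batch size, then simply trims the padding afterwards. The following is a minimal, self-contained sketch of that idea; `PaddingHelper`, `DummyKVCacheManager`, `get_padding`, and `maybe_pad` are illustrative stand-ins, not the real TensorRT-LLM classes or API.

```python
import contextlib

MAX_UINT64 = (1 << 64) - 1


class DummyKVCacheManager:
    """Stand-in for the real KV cache manager (illustrative only)."""

    def __init__(self, free_blocks: int):
        self.free_blocks = free_blocks

    def get_num_free_blocks(self) -> int:
        return self.free_blocks

    def add_dummy_requests(self, request_ids):
        # The real manager allocates KV pages; here we just build placeholders.
        return [{"request_id": rid, "is_cuda_graph_dummy": False} for rid in request_ids]


class PaddingHelper:
    """Caches one dummy request and replicates it to pad a batch."""

    def __init__(self, cuda_graph_batch_sizes):
        self.cuda_graph_batch_sizes = sorted(cuda_graph_batch_sizes)
        self.cuda_graph_dummy_request = None

    def _round_up_batch_size(self, batch_size: int) -> int:
        # Pick the smallest captured CUDA graph batch size >= batch_size.
        for candidate in self.cuda_graph_batch_sizes:
            if candidate >= batch_size:
                return candidate
        return batch_size  # no captured graph is large enough; no padding

    def get_padding(self, generation_requests, kv_cache_manager) -> int:
        batch_size = len(generation_requests)
        padding_size = self._round_up_batch_size(batch_size) - batch_size
        if padding_size == 0:
            return 0
        if self.cuda_graph_dummy_request is None:
            # Only one dummy request is ever allocated, so a single free
            # KV cache block is enough regardless of padding_size.
            if kv_cache_manager.get_num_free_blocks() < 1:
                return 0
            self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
                [MAX_UINT64 - 1])[0]
            self.cuda_graph_dummy_request["is_cuda_graph_dummy"] = True
        # Replicate the same dummy request instead of allocating new ones.
        generation_requests.extend([self.cuda_graph_dummy_request] * padding_size)
        return padding_size

    @contextlib.contextmanager
    def maybe_pad(self, generation_requests, kv_cache_manager):
        padding_size = self.get_padding(generation_requests, kv_cache_manager)
        try:
            yield generation_requests
        finally:
            # The replicated entries are simply dropped; nothing to free.
            if padding_size > 0:
                del generation_requests[-padding_size:]


if __name__ == "__main__":
    helper = PaddingHelper(cuda_graph_batch_sizes=[1, 2, 4, 8])
    kv = DummyKVCacheManager(free_blocks=16)
    batch = [{"request_id": i, "is_cuda_graph_dummy": False} for i in range(3)]
    with helper.maybe_pad(batch, kv) as padded:
        print(len(padded))  # 4: padded up to the nearest captured graph size
    print(len(batch))       # 3: padding removed afterwards
```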

@@ -446,6 +446,7 @@ class PyTorchModelEngine(ModelEngine):
         # with different KV cache managers.
         self.kv_cache_manager_key = KV_CACHE_MANAGER_KEY
         self.lora_model_config: Optional[LoraModelConfig] = None
+        self.cuda_graph_dummy_request = None
 
     def set_lora_model_config(self, lora_target_modules: list[str],
                               trtllm_modules_to_hf_modules: dict[str, str]):
@@ -464,6 +465,10 @@ class PyTorchModelEngine(ModelEngine):
             logger.info("Skipping warm up as no KV Cache manager allocated.")
             return
 
+        # The lifetime of model engine and kv cache manager can be different.
+        # Reset the global cuda graph dummy request to None in warmup.
+        self.cuda_graph_dummy_request = None
+
         def get_cuda_graph_warmup_request(batch_size):
             available_blocks = kv_cache_manager.get_num_free_blocks()
             if available_blocks >= batch_size:
@@ -781,7 +786,7 @@ class PyTorchModelEngine(ModelEngine):
         return self.spec_metadata
 
     def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
-                          kv_cache_manager):
+                          kv_cache_manager) -> int:
         can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
         batch_size = scheduled_requests.batch_size
         new_batch_size = batch_size
@@ -797,35 +802,35 @@ class PyTorchModelEngine(ModelEngine):
         if (not self._run_cuda_graphs or not self._cuda_graph_padding_enabled
                 or not can_run_cuda_graph
                 or new_batch_size > self._max_cuda_graph_batch_size):
-            return None
+            return 0
 
         padded_batch_size = self._round_up_batch_size(new_batch_size)
         if batch_size == padded_batch_size:
-            return None
+            return 0
 
         padding_size = padded_batch_size - batch_size
+        if padding_size + scheduled_requests.batch_size > self.batch_size:
+            return 0
 
-        available_blocks = kv_cache_manager.get_num_free_blocks()
-
-        # No padding if:
-        # 1) Not enough KV cache space.
-        # 2) It would create too many concurrent requests.
-        # 2 is not strictly required, but we should probably
+        # No padding if it would create too many concurrent requests.
+        # This is not strictly required, but we should probably
         # respect the requirement just in case that changes in the future.
-        if available_blocks < padding_size or padding_size + scheduled_requests.batch_size > self.batch_size:
-            return None
+        if self.cuda_graph_dummy_request is None:
+            available_blocks = kv_cache_manager.get_num_free_blocks()
+            # No padding if not enough KV cache space
+            if available_blocks < 1:
+                return 0
 
-        # Set the dummy request ids starting at (uint64 max value - padding_size - 1) to avoid conflict with
-        # active request IDs
-        max_req_id = MAX_UINT64 - padding_size - 1
-        generation_requests = kv_cache_manager.add_dummy_requests(
-            [max_req_id + i + 1 for i in range(padding_size)],
-            is_gen=True,
-            max_num_draft_tokens=self.max_draft_len)
-        for req in generation_requests:
-            req.is_cuda_graph_dummy = True
+            self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
+                [MAX_UINT64 - 1],
+                is_gen=True,
+                max_num_draft_tokens=self.max_draft_len)[0]
+            self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
 
-        scheduled_requests.generation_requests.extend(generation_requests)
-        return generation_requests
+        scheduled_requests.generation_requests.extend(
+            [self.cuda_graph_dummy_request] * padding_size)
+        return padding_size
 
     @contextlib.contextmanager
     def _maybe_pad_batch(self, scheduled_requests: ScheduledRequests,
@@ -835,20 +840,16 @@ class PyTorchModelEngine(ModelEngine):
         If using CUDA graphs, this method will add dummy requests to the given
         batch so we can always use a CUDA graph. It is a context manager
-        because the padded requests allocate KV pages that should be freed
-        when you're done with them.
+        because the padded requests will be removed from scheduled requests.
         """
-        padding_requests = self._get_padded_batch(scheduled_requests,
-                                                  kv_cache_manager)
+        padding_size = self._get_padded_batch(scheduled_requests,
+                                              kv_cache_manager)
 
         try:
             yield scheduled_requests
         finally:
-            if padding_requests is not None:
-                padding_len = len(padding_requests)
+            if padding_size > 0:
                 scheduled_requests.generation_requests = scheduled_requests.generation_requests[:
-                                                                                                -padding_len]
-                for req in padding_requests:
-                    kv_cache_manager.free_resources(req)
+                                                                                                -padding_size]
 
     def _round_up_batch_size(self, batch_size: int) -> int:
         """
@@ -1217,11 +1218,11 @@ class PyTorchModelEngine(ModelEngine):
                 range(len(position_ids),
                       len(position_ids) + len(generation_requests))))
         for request in generation_requests:
-            if new_tokens_device is None or request.py_batch_idx is None:
+            if new_tokens_device is None or request.py_batch_idx is None or request.is_cuda_graph_dummy:
                 # the request has no previous tensor:
                 # (1) new_tokens_device is None, which means overlap scheduler is disabled; or
-                # (2) request.py_batch_idx is None, which means the request has no previous batch.
-                # the second condition includes dummy generation requests created for CUDA graph padding.
+                # (2) request.py_batch_idx is None, which means the request has no previous batch; or
+                # (3) request.is_cuda_graph_dummy, which means dummy generation requests created for CUDA graph padding.
                 # these dummy generation requests should be at the end of generation_requests.
                 # skip adding their input_ids so that new_tokens_device can be aligned to the correct positions.
                 if not request.is_cuda_graph_dummy:
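The comment block in the last hunk argues that the replicated dummy requests must sit at the end of `generation_requests` and must not contribute input ids, so that tokens gathered from the previous step stay aligned with the real requests. A minimal sketch of that alignment argument, under simplified assumptions (plain Python lists in place of device tensors; `Request` and `gather_input_ids` are hypothetical stand-ins, not the engine's actual data structures):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Request:
    py_batch_idx: Optional[int]      # index into the previous batch's outputs
    is_cuda_graph_dummy: bool = False
    last_token: int = 0              # fallback token when no previous tensor exists


def gather_input_ids(generation_requests, new_tokens_device):
    """Build the next step's input ids, keeping dummy padding out of the way."""
    input_ids = []
    for request in generation_requests:
        if new_tokens_device is None or request.py_batch_idx is None or request.is_cuda_graph_dummy:
            # No previous tensor to read from. Real requests fall back to their
            # last sampled token; CUDA graph dummies (which sit at the end of
            # generation_requests) contribute nothing, so the entries taken
            # from new_tokens_device stay aligned with the real requests.
            if not request.is_cuda_graph_dummy:
                input_ids.append(request.last_token)
        else:
            input_ids.append(new_tokens_device[request.py_batch_idx])
    return input_ids


# Two real requests that ran in the previous step, plus one replicated dummy pad.
reqs = [Request(py_batch_idx=0), Request(py_batch_idx=1),
        Request(py_batch_idx=None, is_cuda_graph_dummy=True)]
print(gather_input_ids(reqs, new_tokens_device=[101, 202]))  # [101, 202]
```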