Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[https://nvbugs/5402719][fix]: Add cuda graph dummy requests to the spec_resource_manager (#6258)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
parent 96d004d800
commit d853811190
@@ -810,8 +810,11 @@ class PyTorchModelEngine(ModelEngine):
             is_draft_model=self.is_draft_model)
         return self.spec_metadata
 
-    def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
-                          kv_cache_manager) -> int:
+    def _get_padded_batch(
+            self,
+            scheduled_requests: ScheduledRequests,
+            kv_cache_manager,
+            spec_resource_manager: Optional[BaseResourceManager] = None) -> int:
         can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
         batch_size = scheduled_requests.batch_size
         # The number of sequences in the batch is the number of prompts times the beam width.
@@ -847,13 +850,17 @@ class PyTorchModelEngine(ModelEngine):
             if available_blocks < 1:
                 return 0
 
+            cuda_graph_dummy_request_ids = [MAX_UINT64 - 1]
             self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
-                [MAX_UINT64 - 1],
+                cuda_graph_dummy_request_ids,
                 is_gen=True,
                 max_num_draft_tokens=self.max_draft_len,
                 use_mrope=self.use_mrope,
                 max_beam_width=self.max_beam_width)[0]
             self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
+            if spec_resource_manager is not None:
+                spec_resource_manager.add_dummy_requests(
+                    request_ids=cuda_graph_dummy_request_ids)
 
         scheduled_requests.generation_requests.extend(
             [self.cuda_graph_dummy_request] * padding_size)
@@ -861,8 +868,11 @@ class PyTorchModelEngine(ModelEngine):
         return padding_size
 
     @contextlib.contextmanager
-    def _maybe_pad_batch(self, scheduled_requests: ScheduledRequests,
-                         kv_cache_manager):
+    def _maybe_pad_batch(
+            self,
+            scheduled_requests: ScheduledRequests,
+            kv_cache_manager,
+            spec_resource_manager: Optional[BaseResourceManager] = None):
         """
         CUDA graphs can only be used for specific batch sizes.
 
@@ -871,7 +881,8 @@ class PyTorchModelEngine(ModelEngine):
         because the padded requests will be removed from scheduled requests.
         """
         padding_size = self._get_padded_batch(scheduled_requests,
-                                              kv_cache_manager)
+                                              kv_cache_manager,
+                                              spec_resource_manager)
         try:
             yield scheduled_requests
         finally:
@@ -2072,6 +2083,7 @@ class PyTorchModelEngine(ModelEngine):
                     spec_metadata.is_spec_dec_dynamic_tree,
                     spec_metadata.max_draft_len)
             else:
+                spec_resource_manager = None
                 spec_metadata = None
 
             moe_load_balancer = None
@@ -2090,8 +2102,8 @@ class PyTorchModelEngine(ModelEngine):
             with MoeLoadBalancerIterContext(moe_load_balancer):
                 return self._forward_step(inputs, gather_ids,
                                           gather_context_logits)
-        with self._maybe_pad_batch(scheduled_requests,
-                                   kv_cache_manager) as scheduled_requests:
+        with self._maybe_pad_batch(scheduled_requests, kv_cache_manager,
+                                   spec_resource_manager) as scheduled_requests:
             maybe_graph = self._maybe_get_cuda_graph(
                 scheduled_requests, spec_config=self.spec_config)
             if maybe_graph is not None:
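
For context, the minimal sketch below (not TensorRT-LLM code; ToySpecResourceManager, pad_batch, and slot_of are made-up stand-ins) illustrates the idea behind the fix: when a batch is padded with dummy generation requests so it can run under a CUDA graph, the same dummy request ids should also be registered with the speculative-decoding resource manager; otherwise any lookup keyed by request id can fail for the padding requests.

from typing import Optional

MAX_UINT64 = (1 << 64) - 1  # same sentinel range used for dummy request ids


class ToySpecResourceManager:
    """Hypothetical stand-in for a speculative-decoding resource manager."""

    def __init__(self):
        self._slots = {}  # request_id -> slot index

    def add_dummy_requests(self, request_ids):
        # Reserve a slot per dummy request so lookups during CUDA graph
        # capture/replay do not hit a missing entry.
        for rid in request_ids:
            self._slots.setdefault(rid, len(self._slots))

    def slot_of(self, request_id):
        return self._slots[request_id]  # KeyError if the id was never added


def pad_batch(generation_request_ids, padding_size,
              spec_resource_manager: Optional[ToySpecResourceManager] = None):
    # Toy version of the padding step: append a shared dummy request id.
    dummy_ids = [MAX_UINT64 - 1]
    # (the real _get_padded_batch also registers the ids with the KV cache
    # manager via kv_cache_manager.add_dummy_requests)
    if spec_resource_manager is not None:
        spec_resource_manager.add_dummy_requests(request_ids=dummy_ids)
    return generation_request_ids + dummy_ids * padding_size


spec_mgr = ToySpecResourceManager()
padded = pad_batch([101, 102, 103], padding_size=1,
                   spec_resource_manager=spec_mgr)
# Without the add_dummy_requests() call above, this lookup would raise
# KeyError once the padded batch reaches code that indexes by request id.
assert spec_mgr.slot_of(MAX_UINT64 - 1) == 0
assert padded[-1] == MAX_UINT64 - 1

In the actual diff, _get_padded_batch does this registration only when a spec_resource_manager is provided, and _maybe_pad_batch plus the forward path simply thread the manager through (passing None when speculative decoding is disabled).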