[https://nvbugs/5652062][fix] Rewind kv_cache and reset draft tokens (#10160)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
parent 46e4af5688
commit d8b5aeb061
@@ -1567,6 +1567,24 @@ class PyExecutor:
                 # For generation requests which have completed KV cache transfer
                 self._prepare_disagg_gen_transmission_complete(
                     scheduled_batch)
+
+            has_draft_batch = self.drafter is not None and self.previous_batch is not None and self.use_spec_decode and self.drafter.should_forward_draft_model(
+                scheduled_batch)
+            # Reset the draft tokens to avoid preparing resources for the draft model.
+            if self.drafter is not None and self.use_spec_decode and not has_draft_batch:
+                self.use_spec_decode = False
+                # We are not running the draft model. Remove the draft tokens and turn off spec
+                # decode so that the requests get handled correctly.
+                # One corner case: when we have at least one context request, we have to keep spec
+                # dec on. This ensures that we capture hidden states for requests that haven't done
+                # prefill yet.
+                self.use_spec_decode = False
+                self.model_engine.enable_spec_decode = len(
+                    scheduled_batch.context_requests) > 0
+                if not self.model_engine.enable_spec_decode:
+                    for request in scheduled_batch.all_requests():
+                        request.py_draft_tokens = []
+
             self.resource_manager.prepare_resources(scheduled_batch)

             self._kv_connector_start_batch(scheduled_batch)
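Note: the following is a minimal sketch of the decision the added block makes before prepare_resources, assuming a simplified executor object with the same drafter, use_spec_decode, and model_engine attributes as in the diff. The free-standing helper is illustrative only, not part of TensorRT-LLM.

def reset_spec_decode_if_no_draft_batch(executor, scheduled_batch):
    """Disable spec decode for this iteration when the draft model will not run."""
    has_draft_batch = (executor.drafter is not None
                       and executor.previous_batch is not None
                       and executor.use_spec_decode
                       and executor.drafter.should_forward_draft_model(scheduled_batch))
    if executor.drafter is not None and executor.use_spec_decode and not has_draft_batch:
        executor.use_spec_decode = False
        # Keep the engine's spec-decode path on only while context (prefill) requests remain,
        # so their hidden states are still captured for the drafter.
        executor.model_engine.enable_spec_decode = len(scheduled_batch.context_requests) > 0
        if not executor.model_engine.enable_spec_decode:
            # Drop stale draft tokens so resource preparation does not reserve KV space for them.
            for request in scheduled_batch.all_requests():
                request.py_draft_tokens = []
    return has_draft_batch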
@@ -1602,8 +1620,11 @@ class PyExecutor:
             # so we'll set the target model's input to None and skip updating the target requests after target model forward.
             use_previous_draft_tokens = self.has_previous_draft_tokens
-            num_accepted_tokens_device = None
             if self.drafter is not None and (self.use_spec_decode or
                                              use_previous_draft_tokens):
+
+                target_inputs = None
+                num_accepted_tokens_device = None
+
+                if has_draft_batch:
                     target_inputs, num_accepted_tokens_device = self._handle_speculative_decoding(
                         scheduled_batch, previous_tensors,
                         previous_tensors_device)
@@ -2746,44 +2767,20 @@ class PyExecutor:
    ) -> Tuple[Optional[SampleStateTensorsMTP], Optional[torch.Tensor]]:
        with request_context(is_draft=self.draft_model_engine is not None,
                             scheduled_requests=scheduled_batch):
-            # Do an early checking to see if we need to forward the draft model.
-            # If needed, the overlap should happen between the target requests and the draft requests.
-            # Otherwise, we can still do overlap between the previous target requests and the current target requests.
-            has_draft_batch = (
-                self.previous_batch is not None and self.use_spec_decode
-                and self.drafter.should_forward_draft_model(scheduled_batch))
-
-            new_target_inputs = None
-            num_accepted_tokens_device = None
-            if has_draft_batch:
-                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
-                assert target_outputs is not None, "target_outputs should not be None"
-                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
-                    scheduled_batch=scheduled_batch,
-                    target_inputs=target_inputs,
-                    target_outputs=target_outputs)
-
-            if has_draft_batch:
-                self.drafter.generate_draft_tokens_with_overlap(
-                    scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None,
-                    new_target_inputs, num_accepted_tokens_device)
-
-                # Pad draft tokens to the max draft length for CUDA graph compatibility
-                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
-            else:
-                self.has_previous_draft_tokens = False
-                # We are not running the draft model. Remove the draft tokens and turn off spec
-                # decode so that the requests get handled correctly.
-                # One corner case: when we have at least one context request, we have to keep spec
-                # dec on. This ensures that we capture hidden states for requests that haven't done
-                # prefill yet.
-                self.use_spec_decode = False
-                self.model_engine.enable_spec_decode = len(
-                    scheduled_batch.context_requests) > 0
-                if not self.model_engine.enable_spec_decode:
-                    for request in scheduled_batch.all_requests():
-                        request.py_draft_tokens = []
+            target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+            assert target_outputs is not None, "target_outputs should not be None"
+            new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                scheduled_batch=scheduled_batch,
+                target_inputs=target_inputs,
+                target_outputs=target_outputs)
+
+            self.drafter.generate_draft_tokens_with_overlap(
+                scheduled_batch, self.resource_manager,
+                previous_tensors.device if previous_tensors else None,
+                new_target_inputs, num_accepted_tokens_device)
+
+            # Pad draft tokens to the max draft length for CUDA graph compatibility
+            self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None

            return new_target_inputs, num_accepted_tokens_device
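Note: a hedged sketch of the simplified flow above. The caller (see the earlier hunk) only enters this path when has_draft_batch is true, so the helper can unconditionally accept tokens from the previous target step and launch the drafter overlapped with the next target step. The attribute and method names mirror the diff; the wrapper function itself is illustrative, not the upstream method.

def run_draft_overlap(executor, scheduled_batch, target_inputs, previous_tensors):
    """Illustrative wrapper: accept tokens from the previous target step, then draft overlapped."""
    # The previous target forward must have produced device-side sample state.
    target_outputs = executor.previous_batch.sample_state and executor.previous_batch.sample_state.device
    assert target_outputs is not None, "target_outputs should not be None"
    # Fold the accepted draft tokens into the inputs of the next target forward.
    new_target_inputs, num_accepted_tokens_device = executor._accept_draft_tokens(
        scheduled_batch=scheduled_batch,
        target_inputs=target_inputs,
        target_outputs=target_outputs)
    # Run the draft model overlapped with the target model's next step.
    executor.drafter.generate_draft_tokens_with_overlap(
        scheduled_batch, executor.resource_manager,
        previous_tensors.device if previous_tensors else None,
        new_target_inputs, num_accepted_tokens_device)
    # Remember whether draft tokens exist so the next iteration can pad them to the
    # max draft length for CUDA graph compatibility.
    executor.has_previous_draft_tokens = (new_target_inputs is not None
                                          and new_target_inputs.next_draft_tokens is not None)
    return new_target_inputs, num_accepted_tokens_device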
@@ -599,11 +599,11 @@ class KVCacheManager(BaseResourceManager):
            self.update_kv_cache_draft_token_location(scheduled_batch,
                                                      attn_metadata,
                                                      kv_cache_dtype_byte_size)
-        # rewind kv cache
-        for request in scheduled_batch.generation_requests:
-            if request.state != LlmRequestState.GENERATION_COMPLETE:
-                if request.py_rewind_len > 0:
-                    self.rewind_kv_cache(request, request.py_rewind_len)
+        # rewind kv cache
+        for request in scheduled_batch.generation_requests:
+            if request.state != LlmRequestState.GENERATION_COMPLETE:
+                if request.py_rewind_len > 0:
+                    self.rewind_kv_cache(request, request.py_rewind_len)

        # For context requests, we store the blocks for reuse.
        for request in scheduled_batch.context_requests:
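Note: a minimal sketch of the rewind step this hunk touches, assuming LlmRequestState is in scope as it is in the surrounding module; the helper function is illustrative only. Each still-active generation request gives back the KV-cache positions that were reserved for draft tokens but are no longer needed (py_rewind_len of them).

def rewind_draft_kv(kv_cache_manager, scheduled_batch):
    # Give back KV positions reserved for rejected or unused draft tokens.
    for request in scheduled_batch.generation_requests:
        if request.state != LlmRequestState.GENERATION_COMPLETE:  # LlmRequestState assumed in scope
            if request.py_rewind_len > 0:
                kv_cache_manager.rewind_kv_cache(request, request.py_rewind_len)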
@@ -588,7 +588,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
     max_batch_size = 4
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
-                                    max_tokens=8192)
+                                    max_tokens=4096)
     cuda_graph_config = CudaGraphConfig(batch_sizes=[1, 2, 4],
                                         enable_padding=True)
@@ -599,7 +599,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
         cuda_graph_config=cuda_graph_config,
         max_batch_size=max_batch_size,
         kv_cache_config=kv_cache_config,
-        max_seq_len=8192,
+        max_seq_len=2048,
         enable_chunked_prefill=enable_chunked_prefill,
     )
@@ -617,7 +617,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
         "The future of AI is"
     ]

-    sampling_params = SamplingParams(max_tokens=20, temperature=0)
+    sampling_params = SamplingParams(max_tokens=2048, temperature=0)
     llm_spec.generate(prompts, sampling_params)
     llm_spec.shutdown()