[https://nvbugs/5652062][fix] Rewind kv_cache and reset draft tokens (#10160)

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Ziyi Xiong 2025-12-25 22:13:51 +08:00 committed by GitHub
parent 46e4af5688
commit d8b5aeb061
3 changed files with 43 additions and 46 deletions


@@ -1567,6 +1567,24 @@ class PyExecutor:
                 # For generation requests which have completed KV cache transfer
                 self._prepare_disagg_gen_transmission_complete(
                     scheduled_batch)
+                has_draft_batch = self.drafter is not None and self.previous_batch is not None and self.use_spec_decode and self.drafter.should_forward_draft_model(
+                    scheduled_batch)
+                # Reset the draft tokens to avoid preparing resources for the draft model.
+                if self.drafter is not None and self.use_spec_decode and not has_draft_batch:
+                    # We are not running the draft model. Remove the draft tokens and turn off spec
+                    # decode so that the requests get handled correctly.
+                    # One corner case: when we have at least one context request, we have to keep spec
+                    # dec on. This ensures that we capture hidden states for requests that haven't done
+                    # prefill yet.
+                    self.use_spec_decode = False
+                    self.model_engine.enable_spec_decode = len(
+                        scheduled_batch.context_requests) > 0
+                    if not self.model_engine.enable_spec_decode:
+                        for request in scheduled_batch.all_requests():
+                            request.py_draft_tokens = []
                 self.resource_manager.prepare_resources(scheduled_batch)
                 self._kv_connector_start_batch(scheduled_batch)
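
With this change the executor decides, before resources are prepared, whether the draft model will run at all this iteration; when it will not, spec decode is switched off and stale draft tokens are dropped unless context requests still need their hidden states captured during prefill. A minimal, runnable sketch of that decision (DummyRequest and DummyBatch are hypothetical stand-ins for the real request/batch classes, not code from this patch):

from dataclasses import dataclass, field
from typing import List


@dataclass
class DummyRequest:
    # Stand-in for LlmRequest: only the draft-token list matters here.
    py_draft_tokens: List[int] = field(default_factory=list)


@dataclass
class DummyBatch:
    context_requests: List[DummyRequest]
    generation_requests: List[DummyRequest]

    def all_requests(self) -> List[DummyRequest]:
        return self.context_requests + self.generation_requests


def reset_draft_state(batch: DummyBatch, has_draft_batch: bool) -> bool:
    """Return whether spec decode should stay enabled on the model engine."""
    if has_draft_batch:
        return True
    # No draft forward this iteration: keep spec decode on only while context
    # requests still need hidden states captured during prefill; otherwise
    # drop any leftover draft tokens so no draft resources get prepared.
    enable_spec_decode = len(batch.context_requests) > 0
    if not enable_spec_decode:
        for request in batch.all_requests():
            request.py_draft_tokens = []
    return enable_spec_decode


# Example: a pure-generation batch with stale draft tokens gets them cleared.
batch = DummyBatch(context_requests=[],
                   generation_requests=[DummyRequest(py_draft_tokens=[7, 8])])
assert reset_draft_state(batch, has_draft_batch=False) is False
assert batch.generation_requests[0].py_draft_tokens == []
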
@@ -1602,8 +1620,11 @@ class PyExecutor:
                 # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                 use_previous_draft_tokens = self.has_previous_draft_tokens
-                num_accepted_tokens_device = None
-                if self.drafter is not None and (self.use_spec_decode or
-                                                 use_previous_draft_tokens):
+                target_inputs = None
+                num_accepted_tokens_device = None
+                if has_draft_batch:
                     target_inputs, num_accepted_tokens_device = self._handle_speculative_decoding(
                         scheduled_batch, previous_tensors,
                         previous_tensors_device)
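
Per the comment above, a None target model input is what tells the rest of the iteration to skip updating the target requests after the target forward. A simplified sketch of that call-site pattern (run_iteration, handle_speculative_decoding, and update_target_requests are hypothetical stand-ins, not the executor's real methods):

from typing import Callable, Optional, Tuple


def run_iteration(
    has_draft_batch: bool,
    handle_speculative_decoding: Callable[[], Tuple[object, object]],
    update_target_requests: Callable[[object, object], None],
) -> None:
    # Sentinel: None means the draft model did not run this iteration.
    target_inputs: Optional[object] = None
    num_accepted_tokens_device: Optional[object] = None
    if has_draft_batch:
        target_inputs, num_accepted_tokens_device = handle_speculative_decoding()
    if target_inputs is not None:
        update_target_requests(target_inputs, num_accepted_tokens_device)


# Example: with no draft batch, the target-request update is never invoked.
calls = []
run_iteration(False, lambda: ("inputs", "accepted"), lambda *a: calls.append(a))
assert calls == []
run_iteration(True, lambda: ("inputs", "accepted"), lambda *a: calls.append(a))
assert calls == [("inputs", "accepted")]
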
@@ -2746,44 +2767,20 @@ class PyExecutor:
     ) -> Tuple[Optional[SampleStateTensorsMTP], Optional[torch.Tensor]]:
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
-            # Do an early checking to see if we need to forward the draft model.
-            # If needed, the overlap should happen between the target requests and the draft requests.
-            # Otherwise, we can still do overlap between the previous target requests and the current target requests.
-            has_draft_batch = (
-                self.previous_batch is not None and self.use_spec_decode
-                and self.drafter.should_forward_draft_model(scheduled_batch))
+            target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+            assert target_outputs is not None, "target_outputs should not be None"
+            new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                scheduled_batch=scheduled_batch,
+                target_inputs=target_inputs,
+                target_outputs=target_outputs)
-            new_target_inputs = None
-            num_accepted_tokens_device = None
-            if has_draft_batch:
-                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
-                assert target_outputs is not None, "target_outputs should not be None"
-                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
-                    scheduled_batch=scheduled_batch,
-                    target_inputs=target_inputs,
-                    target_outputs=target_outputs)
+            self.drafter.generate_draft_tokens_with_overlap(
+                scheduled_batch, self.resource_manager,
+                previous_tensors.device if previous_tensors else None,
+                new_target_inputs, num_accepted_tokens_device)
-            if has_draft_batch:
-                self.drafter.generate_draft_tokens_with_overlap(
-                    scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None,
-                    new_target_inputs, num_accepted_tokens_device)
-                # Pad draft tokens to the max draft length for CUDA graph compatibility
-                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
-            else:
-                self.has_previous_draft_tokens = False
-                # We are not running the draft model. Remove the draft tokens and turn off spec
-                # decode so that the requests get handled correctly.
-                # One corner case: when we have at least one context request, we have to keep spec
-                # dec on. This ensures that we capture hidden states for requests that haven't done
-                # prefill yet.
-                self.use_spec_decode = False
-                self.model_engine.enable_spec_decode = len(
-                    scheduled_batch.context_requests) > 0
-                if not self.model_engine.enable_spec_decode:
-                    for request in scheduled_batch.all_requests():
-                        request.py_draft_tokens = []
+            # Pad draft tokens to the max draft length for CUDA graph compatibility
+            self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
             return new_target_inputs, num_accepted_tokens_device
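
The drafter pads its proposed tokens because a captured CUDA graph replays fixed tensor shapes, so a batch that proposes fewer tokens than the maximum draft length must still present the full-length layout. An illustrative padding helper, not taken from the drafter itself (the pad_id default is an assumed placeholder value):

import torch


def pad_draft_tokens(draft_tokens: torch.Tensor, max_draft_len: int,
                     pad_id: int = 0) -> torch.Tensor:
    """Pad (or truncate) a 1-D draft-token tensor to exactly max_draft_len."""
    num_draft = draft_tokens.shape[0]
    if num_draft >= max_draft_len:
        return draft_tokens[:max_draft_len]
    padding = torch.full((max_draft_len - num_draft, ), pad_id,
                         dtype=draft_tokens.dtype, device=draft_tokens.device)
    return torch.cat([draft_tokens, padding])


# Example: 2 proposed tokens padded to a fixed length of 4 so the captured
# graph always sees the same shape.
print(pad_draft_tokens(torch.tensor([11, 12]), max_draft_len=4))
# tensor([11, 12,  0,  0])
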


@@ -599,11 +599,11 @@ class KVCacheManager(BaseResourceManager):
             self.update_kv_cache_draft_token_location(scheduled_batch,
                                                       attn_metadata,
                                                       kv_cache_dtype_byte_size)
-            # rewind kv cache
-            for request in scheduled_batch.generation_requests:
-                if request.state != LlmRequestState.GENERATION_COMPLETE:
-                    if request.py_rewind_len > 0:
-                        self.rewind_kv_cache(request, request.py_rewind_len)
+        # rewind kv cache
+        for request in scheduled_batch.generation_requests:
+            if request.state != LlmRequestState.GENERATION_COMPLETE:
+                if request.py_rewind_len > 0:
+                    self.rewind_kv_cache(request, request.py_rewind_len)
         # For context requests, we store the blocks for reuse.
         for request in scheduled_batch.context_requests:
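
With the rewind applied to every unfinished generation request that has a positive py_rewind_len, rejected draft tokens no longer linger in the cache. A rough illustration of what a rewind does to a request's KV footprint (the helper and the block-granular accounting below are illustrative assumptions, not the real KVCacheManager logic):

def rewind_kv_cache_len(seq_len_kv: int, rewind_len: int,
                        tokens_per_block: int = 32) -> tuple:
    """Drop the last rewind_len (rejected draft) tokens from the KV view.

    Returns (new_kv_len, blocks_freed) under a simple block-granular model.
    """
    assert 0 <= rewind_len <= seq_len_kv
    new_len = seq_len_kv - rewind_len
    blocks_before = -(-seq_len_kv // tokens_per_block)  # ceil division
    blocks_after = -(-new_len // tokens_per_block)
    return new_len, blocks_before - blocks_after


# Example: 4 draft tokens were appended but only 1 was accepted, so the 3
# rejected tokens are rewound; here that also releases one cache block.
print(rewind_kv_cache_len(seq_len_kv=130, rewind_len=3))  # (127, 1)
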


@@ -588,7 +588,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
     max_batch_size = 4
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
-                                    max_tokens=8192)
+                                    max_tokens=4096)
     cuda_graph_config = CudaGraphConfig(batch_sizes=[1, 2, 4],
                                         enable_padding=True)
@@ -599,7 +599,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
         cuda_graph_config=cuda_graph_config,
         max_batch_size=max_batch_size,
         kv_cache_config=kv_cache_config,
-        max_seq_len=8192,
+        max_seq_len=2048,
         enable_chunked_prefill=enable_chunked_prefill,
     )
@@ -617,7 +617,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
         "The future of AI is"
     ]
-    sampling_params = SamplingParams(max_tokens=20, temperature=0)
+    sampling_params = SamplingParams(max_tokens=2048, temperature=0)
     llm_spec.generate(prompts, sampling_params)
     llm_spec.shutdown()
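
The tightened limits presumably push the test through long generations under KV-cache pressure rather than a 20-token smoke run, which is where padded CUDA graphs and draft-token rewinds are more likely to be exercised. A back-of-the-envelope check of the new numbers (illustrative only):

# New limits taken from the test above.
max_batch_size = 4
max_seq_len = 2048          # per-sequence cap
kv_cache_max_tokens = 4096  # total KV-cache token budget
max_new_tokens = 2048       # SamplingParams(max_tokens=2048)

# Up to 2048 decode steps per request instead of 20, while the KV budget can
# hold only a fraction of the batch at full length, so cache-pressure paths
# (block reuse/eviction and draft-token rewinds) are far more likely to run.
full_length_sequences_that_fit = kv_cache_max_tokens // max_seq_len
print(max_new_tokens, "decode steps;",
      full_length_sequences_that_fit, "of", max_batch_size,
      "full-length sequences fit in the KV budget")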