Fix DFlash prefix cache corruption due to missing lookahead block (#42971)

Signed-off-by: Shreyas Kulkarni <shreyas.gp269@gmail.com>
This commit is contained in:
Shreyas Kulkarni
2026-06-02 08:06:33 -04:00
committed by GitHub
parent f69ede495b
commit 0eeba5eec1
2 changed files with 171 additions and 3 deletions
@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
import torch
from tests.v1.core.utils import create_requests
from vllm.config import (
CacheConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
# Matches defaults from tests/v1/spec_decode/test_eagle.py
DFLASH_TARGET_DIR = "Qwen/Qwen3-8B"
DFLASH_DRAFT_DIR = "z-lab/Qwen3-8B-DFlash-b16"
BLOCK_SIZE = 16
NUM_BLOCKS = 8
NUM_SPECULATIVE_TOKENS = 3
def _dflash_speculative_config(num_speculative_tokens: int) -> SpeculativeConfig:
model_config = ModelConfig(
model=DFLASH_TARGET_DIR,
runner="generate",
max_model_len=100,
trust_remote_code=True,
)
return SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
model=DFLASH_DRAFT_DIR,
method="dflash",
num_speculative_tokens=num_speculative_tokens,
)
def _create_dflash_scheduler(num_speculative_tokens: int) -> Scheduler:
speculative_config = _dflash_speculative_config(num_speculative_tokens)
model_config = speculative_config.target_model_config
scheduler_config = SchedulerConfig(
max_num_seqs=16,
max_num_batched_tokens=8192,
max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder,
)
cache_config = CacheConfig(
block_size=BLOCK_SIZE,
gpu_memory_utilization=0.9,
cache_dtype="auto",
enable_prefix_caching=False,
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
parallel_config=ParallelConfig(),
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=NUM_BLOCKS,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(
block_size=BLOCK_SIZE,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
)
],
)
cache_config.num_gpu_blocks = NUM_BLOCKS
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
block_size=BLOCK_SIZE,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def test_dflash_prefill_reserves_lookahead_blocks():
scheduler = _create_dflash_scheduler(NUM_SPECULATIVE_TOKENS)
assert scheduler.num_lookahead_tokens == NUM_SPECULATIVE_TOKENS + 1
(request,) = create_requests(
num_requests=1,
num_tokens=BLOCK_SIZE,
block_size=BLOCK_SIZE,
)
scheduler.add_request(request)
output = scheduler.schedule()
assert output.num_scheduled_tokens[request.request_id] == BLOCK_SIZE
# prefill block + one lookahead block
assert len(output.scheduled_new_reqs[0].block_ids[0]) == 2
def test_dflash_first_prefill_query_window_fits_allocated_blocks():
scheduler = _create_dflash_scheduler(NUM_SPECULATIVE_TOKENS)
(request,) = create_requests(
num_requests=1,
num_tokens=BLOCK_SIZE,
block_size=BLOCK_SIZE,
)
scheduler.add_request(request)
output = scheduler.schedule()
block_ids = output.scheduled_new_reqs[0].block_ids[0]
query_positions = range(BLOCK_SIZE, BLOCK_SIZE + scheduler.num_lookahead_tokens)
assert all(pos // BLOCK_SIZE < len(block_ids) for pos in query_positions)
def test_dflash_drafter_window_reserves_bonus_token():
# DFlash's drafter window is num_spec + 1 (the extra slot is the bonus token),
# so max_seq_len + num_spec + 1 must stay within the draft model's max len.
input_fits_in_drafter = GPUModelRunner._input_fits_in_drafter
dflash_runner = SimpleNamespace(
num_spec_tokens=NUM_SPECULATIVE_TOKENS,
effective_drafter_max_model_len=100,
speculative_config=_dflash_speculative_config(NUM_SPECULATIVE_TOKENS),
)
# window = 4, so 96 fits (96 + 4 == 100) but 97 does not (97 + 4 == 101)
assert input_fits_in_drafter(dflash_runner, SimpleNamespace(max_seq_len=96))
assert not input_fits_in_drafter(dflash_runner, SimpleNamespace(max_seq_len=97))
assert not input_fits_in_drafter(dflash_runner, None) # no metadata
# Other drafters don't reserve the bonus token, so 97 fits (97 + 3 == 100).
plain_runner = SimpleNamespace(
num_spec_tokens=NUM_SPECULATIVE_TOKENS,
effective_drafter_max_model_len=100,
speculative_config=SimpleNamespace(use_dflash=lambda: False),
)
assert input_fits_in_drafter(plain_runner, SimpleNamespace(max_seq_len=97))
+17 -3
View File
@@ -4362,6 +4362,21 @@ class GPUModelRunner(
return None
def _input_fits_in_drafter(
self, common_attn_metadata: CommonAttentionMetadata | None
) -> bool:
if common_attn_metadata is None:
return False
assert self.speculative_config is not None
# DFlash queries one extra token (the bonus token) beyond num_spec_tokens
num_drafter_query_tokens = self.num_spec_tokens + (
1 if self.speculative_config.use_dflash() else 0
)
return (
common_attn_metadata.max_seq_len + num_drafter_query_tokens
<= self.effective_drafter_max_model_len
)
@torch.inference_mode
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
@@ -4441,9 +4456,8 @@ class GPUModelRunner(
propose_drafts_after_bookkeeping = False
if spec_config is not None:
# Decide whether to run the drafter or zero out draft tokens.
input_fits_in_drafter = spec_decode_common_attn_metadata is not None and (
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= self.effective_drafter_max_model_len
input_fits_in_drafter = self._input_fits_in_drafter(
spec_decode_common_attn_metadata
)
use_gpu_toks = (
spec_config.use_eagle()