[None][fix] Assign [] to req.py_draft_tokens instead of None when spec decode is off (#7511)

Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
Zheyu Fu 2025-09-23 06:54:18 -07:00 committed by GitHub
parent 16bb76c31d
commit 34963ec39c
2 changed files with 71 additions and 2 deletions


@@ -977,7 +977,7 @@ class PyExecutor:
        # When the overlap scheduler is enabled and we have already prepared the draft tokens in the previous batch,
        # we don't need to initialize py_draft_tokens at this stage because we haven't appended the accepted tokens to the request yet.
        if not self.has_previous_draft_tokens:
-           # If speculation is off, this function sets py_draft_tokens to None
+           # If speculation is off, this function sets py_draft_tokens to []
            # for all active requests. If it's on, we initialize py_draft_tokens
            # with dummy draft tokens to make the scheduler aware of the fact
            # that speculation is about to happen.
@@ -1136,7 +1136,7 @@ class PyExecutor:
                    req.py_draft_tokens = [0] * max_draft_len
                    req.py_draft_pages_allocated = max_draft_len
                else:
-                   req.py_draft_tokens = None
+                   req.py_draft_tokens = []
                    req.py_draft_pages_allocated = 0
            except Exception as e:

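Why an empty list rather than None: downstream consumers can then treat py_draft_tokens uniformly as a (possibly empty) sequence instead of guarding every call site against None. A minimal sketch of the difference, using hypothetical names (FakeRequest and tokens_needed below are illustrative stand-ins, not the real TensorRT-LLM classes):

    # Illustrative sketch only: with the old None convention every consumer of
    # py_draft_tokens needs a guard, while an empty list keeps len() and
    # iteration uniform whether speculation is on or off.

    class FakeRequest:  # hypothetical stand-in, not the real LlmRequest
        def __init__(self, draft_tokens):
            self.py_draft_tokens = draft_tokens

    def tokens_needed(req: FakeRequest) -> int:
        # With None this would raise TypeError and would need
        # `req.py_draft_tokens or []`; with [] no special case is required.
        return 1 + len(req.py_draft_tokens)  # 1 real token + any draft tokens

    print(tokens_needed(FakeRequest([])))            # spec decode off -> 1
    print(tokens_needed(FakeRequest([0, 0, 0, 0])))  # max_draft_len = 4 -> 5
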

@@ -108,6 +108,75 @@ def test_dynamic_spec_decode(enforce_single_worker,
        assert text_spec == text_ref


# This test supplements test_dynamic_spec_decode, because forcing a single
# process prevents using SIGINT (Ctrl-C) while testing.
# Dynamic spec decode in this test is expected to start with mode OFF and then
# naturally turn ON once the number of remaining effective requests is at most
# self.max_concurrency.
# Example (logic in tensorrt_llm._torch.speculative.drafter.should_use_spec_decode):
# At start: len(requests): 3, max_batch_size: 3, token_cap: 1638 -> num_effective_requests: 3, self.max_concurrency: 2 -> spec decode OFF
# Later: len(requests): 1, max_batch_size: 3, token_cap: 1638 -> num_effective_requests: 1, self.max_concurrency: 2 -> spec decode ON
@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
@pytest.mark.high_cuda_memory
def test_dynamic_spec_decode_without_force_single_process(
        disable_overlap_scheduler: bool):
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 35:
        pytest.skip("Not enough memory to load target + draft model")

    models_path = llm_models_root()
    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

    # Allow up to 3 concurrent requests
    max_batch_size = 3
    max_draft_len = 4
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
    cuda_graph_config = CudaGraphConfig(batch_sizes=[1])

    llm_common_config = dict(
        model=target_model_dir,
        attn_backend="TRTLLM",
        disable_overlap_scheduler=disable_overlap_scheduler,
        cuda_graph_config=cuda_graph_config,
        max_batch_size=max_batch_size,
        kv_cache_config=kv_cache_config,
        max_seq_len=4096,
    )
    spec_config = EagleDecodingConfig(
        max_draft_len=max_draft_len,
        speculative_model_dir=eagle_model_dir,
        # Llama 3 does not support one-model EAGLE.
        eagle3_one_model=False,
        # Allow speculation only when there are <= 2 effective requests.
        max_concurrency=2,
    )

    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

    # Output tests
    prompts = [
        "The capital of France is",
        "The president of the United States is",
        "What is the capital of Australia?",
        "Explain in one sentence why the sky is blue.",
        "Who wrote the book 'Pride and Prejudice'?",
        "List three U.S. national holidays in the year 2025.",
        "Who painted the Mona Lisa?",
    ]
    sampling_params = SamplingParams(max_tokens=10, temperature=0)

    results_spec = llm_spec.generate(prompts, sampling_params)
    generated_text_spec = [result.outputs[0].text for result in results_spec]
    llm_spec.shutdown()

    llm_ref = LLM(**llm_common_config)
    results_ref = llm_ref.generate(prompts, sampling_params)
    generated_text_ref = [result.outputs[0].text for result in results_ref]
    llm_ref.shutdown()

    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
        # The spec decode algorithm currently guarantees identical results
        assert text_spec == text_ref


def test_should_use_spec_decode():
    from tensorrt_llm._torch.speculative.drafter import Drafter