diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 3c44ae0e54..fe8964a1db 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -977,7 +977,7 @@ class PyExecutor:
         # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
         # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
         if not self.has_previous_draft_tokens:
-            # If speculation is off, this function sets py_draft_tokens to None
+            # If speculation is off, this function sets py_draft_tokens to []
            # for all active requests. If it's on, we initialize py_draft_tokens
            # with dummy draft tokens to make the scheduler aware of the fact
            # that speculation is about to happen.
@@ -1136,7 +1136,7 @@ class PyExecutor:
                    req.py_draft_tokens = [0] * max_draft_len
                    req.py_draft_pages_allocated = max_draft_len
                else:
-                    req.py_draft_tokens = None
+                    req.py_draft_tokens = []
                    req.py_draft_pages_allocated = 0

        except Exception as e:
diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
index 0c6cc35eaf..92937b3483 100644
--- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
+++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
@@ -108,6 +108,75 @@ def test_dynamic_spec_decode(enforce_single_worker,
     assert text_spec == text_ref


+# This test supplements test_dynamic_spec_decode, because forcing a single process prevents using SIGINT (Ctrl-C) while testing.
+# Dynamic spec decode in this test is expected to start with mode OFF,
+# then naturally turn ON once the number of remaining effective requests drops to at most self.max_concurrency.
+# Example (logic in tensorrt_llm._torch.speculative.drafter.should_use_spec_decode):
+#   At start: len(requests): 3, max_batch_size: 3, token_cap: 1638 -> num_effective_requests: 3, self.max_concurrency: 2 -> spec decode OFF
+#   Later: len(requests): 1, max_batch_size: 3, token_cap: 1638 -> num_effective_requests: 1, self.max_concurrency: 2 -> spec decode ON
+@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
+@pytest.mark.high_cuda_memory
+def test_dynamic_spec_decode_without_force_single_process(
+        disable_overlap_scheduler: bool):
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 35:
+        pytest.skip("Not enough memory to load target + draft model")
+    models_path = llm_models_root()
+    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
+    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+
+    # Allow up to 3 concurrent requests
+    max_batch_size = 3
+    max_draft_len = 4
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
+    cuda_graph_config = CudaGraphConfig(batch_sizes=[1])
+
+    llm_common_config = dict(
+        model=target_model_dir,
+        attn_backend="TRTLLM",
+        disable_overlap_scheduler=disable_overlap_scheduler,
+        cuda_graph_config=cuda_graph_config,
+        max_batch_size=max_batch_size,
+        kv_cache_config=kv_cache_config,
+        max_seq_len=4096,
+    )
+
+    spec_config = EagleDecodingConfig(
+        max_draft_len=max_draft_len,
+        speculative_model_dir=eagle_model_dir,
+        # Llama 3 does not support one-model eagle.
+        eagle3_one_model=False,
+        # Allow speculation only when there are <= 2 effective requests.
+        max_concurrency=2,
+    )
+
+    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+    # Output tests
+    prompts = [
+        "The capital of France is",
+        "The president of the United States is",
+        "What is the capital of Australia?",
+        "Explain in one sentence why the sky is blue.",
+        "Who wrote the book 'Pride and Prejudice'?",
+        "List three U.S. national holidays in the year 2025.",
+        "Who painted the Mona Lisa?",
+    ]
+    sampling_params = SamplingParams(max_tokens=10, temperature=0)
+
+    results_spec = llm_spec.generate(prompts, sampling_params)
+    generated_text_spec = [result.outputs[0].text for result in results_spec]
+    llm_spec.shutdown()
+
+    llm_ref = LLM(**llm_common_config)
+    results_ref = llm_ref.generate(prompts, sampling_params)
+    generated_text_ref = [result.outputs[0].text for result in results_ref]
+    llm_ref.shutdown()
+
+    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
+        # The spec decode algorithm currently guarantees identical results
+        assert text_spec == text_ref
+
+
 def test_should_use_spec_decode():
     from tensorrt_llm._torch.speculative.drafter import Drafter
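For reference, a minimal sketch of the gating logic the test comments above describe. It assumes token_cap is derived as max_num_tokens // (1 + max_draft_len) (8192 // 5 = 1638 in this test) and that the check is num_effective_requests <= max_concurrency; the function name and parameters below are illustrative, not the actual Drafter.should_use_spec_decode signature.

def should_use_spec_decode_sketch(requests, max_batch_size, max_num_tokens,
                                  max_draft_len, max_concurrency):
    # Illustrative only: mirrors the numbers quoted in the test comments.
    if max_concurrency is None:
        # No concurrency cap configured -> always speculate.
        return True
    # Each speculated request can consume up to 1 + max_draft_len tokens,
    # so the token budget caps how many requests can speculate at once.
    token_cap = max_num_tokens // (1 + max_draft_len)
    num_effective_requests = min(len(requests), max_batch_size, token_cap)
    # OFF while the effective batch exceeds the cap (3 > 2 at the start of
    # the test), ON once it fits (1 <= 2 near the end).
    return num_effective_requests <= max_concurrency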