From 5bc8c894f7cdb70c439af1e62666f93c6073368f Mon Sep 17 00:00:00 2001
From: Mike Iovine
Date: Wed, 25 Jun 2025 15:51:20 -0400
Subject: [PATCH] [chore] Disable block reuse when draft model speculation is
 being used (#5448)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
---
 .../_torch/pyexecutor/py_executor_creator.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 354981680e..ff25d84fdb 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -158,6 +158,21 @@ def _mangle_executor_config(executor_config: ExecutorConfig):
         )
         executor_config.kv_cache_config.enable_block_reuse = False
 
+    spec_config = executor_config.speculative_config
+    if spec_config is not None and spec_config.spec_dec_mode.has_draft_model():
+        # The draft and target models have different KV cache managers to support
+        # different head sizes, dtypes, etc in the generic case.
+        # However, this line will set context_current_position > 0 if there are
+        # cached blocks: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/resource_manager.py#L310.
+        # It actually mutates the LLM request! As a result, when we try to allocate KV cache
+        # pages for the draft model, is_first_context_chunk returns False and
+        # no pages are allocated.
+        # We need to refactor LLMRequest to fix this. Disable block reuse for now.
+        logger.warning(
+            f"Disabling block reuse for speculation algorithm {spec_config.spec_dec_mode}"
+        )
+        executor_config.kv_cache_config.enable_block_reuse = False
+
     if pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION" and executor_config.enable_chunked_context:
         logger.warning(
             f"Disabling chunked context for {pytorch_backend_config.attn_backend} backend"
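
Note (not part of the patch): the comments above describe how target-side block reuse
mutates the shared request and starves the draft model of KV cache pages. The toy
sketch below is a self-contained illustration of that failure mode using hypothetical
classes (ToyRequest, ToyKvCacheManager) rather than the real TensorRT-LLM resource
manager; it only models the simplified logic the patch comments describe, not the
actual implementation.

from dataclasses import dataclass


@dataclass
class ToyRequest:
    # Simplified stand-in for the LLM request object shared by both models.
    prompt_len: int
    context_current_position: int = 0

    @property
    def is_first_context_chunk(self) -> bool:
        return self.context_current_position == 0


class ToyKvCacheManager:
    # Simplified stand-in for a per-model KV cache manager.
    def __init__(self, enable_block_reuse: bool):
        self.enable_block_reuse = enable_block_reuse
        self.pages_allocated = 0
        self.pages_reused = 0

    def prepare_resources(self, req: ToyRequest, cached_tokens: int) -> None:
        if self.enable_block_reuse and cached_tokens > 0:
            # Reuse path: attach cached pages and advance the *shared* request
            # past the cached prefix. This is the mutation the patch comments
            # point at (context_current_position > 0).
            self.pages_reused = cached_tokens
            req.context_current_position = cached_tokens
        elif req.is_first_context_chunk:
            # Fresh pages are only allocated for the first context chunk.
            self.pages_allocated = req.prompt_len


# With block reuse on, a reuse hit on the target model poisons the draft model.
req = ToyRequest(prompt_len=128)
target = ToyKvCacheManager(enable_block_reuse=True)
draft = ToyKvCacheManager(enable_block_reuse=True)
target.prepare_resources(req, cached_tokens=64)  # target gets a reuse hit
draft.prepare_resources(req, cached_tokens=0)    # draft cache has nothing cached
print(draft.pages_allocated)  # 0: is_first_context_chunk is already False (the bug)

# With reuse disabled (the patch's workaround) the request is never advanced,
# so the draft model still sees the first context chunk and allocates its pages.
req2 = ToyRequest(prompt_len=128)
target2 = ToyKvCacheManager(enable_block_reuse=False)
draft2 = ToyKvCacheManager(enable_block_reuse=False)
target2.prepare_resources(req2, cached_tokens=64)
draft2.prepare_resources(req2, cached_tokens=0)
print(draft2.pages_allocated)  # 128: draft pages allocated as expected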