Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-18 08:45:05 +08:00)
[chore] Disable block reuse when draft model speculation is being used (#5448)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Commit 5bc8c894f7 (parent 205c97a4ae)
```diff
@@ -158,6 +158,21 @@ def _mangle_executor_config(executor_config: ExecutorConfig):
         )
         executor_config.kv_cache_config.enable_block_reuse = False
 
+    spec_config = executor_config.speculative_config
+    if spec_config is not None and spec_config.spec_dec_mode.has_draft_model():
+        # The draft and target models have different KV cache managers to support
+        # different head sizes, dtypes, etc in the generic case.
+        # However, this line will set context_current_position > 0 if there are
+        # cached blocks: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/resource_manager.py#L310.
+        # It actually mutates the LLM request! As a result, when we try to allocate KV cache
+        # pages for the draft model, is_first_context_chunk returns False and
+        # no pages are allocated.
+        # We need to refactor LLMRequest to fix this. Disable block reuse for now.
+        logger.warning(
+            f"Disabling block reuse for speculation algorithm {spec_config.spec_dec_mode}"
+        )
+        executor_config.kv_cache_config.enable_block_reuse = False
+
     if pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION" and executor_config.enable_chunked_context:
         logger.warning(
             f"Disabling chunked context for {pytorch_backend_config.attn_backend} backend"
```
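For readers without the full source at hand, the sketch below reproduces the new guard in isolation. The dataclasses are simplified stand-ins, not TensorRT-LLM's real `ExecutorConfig`/`KvCacheConfig`/`SpeculativeConfig` types, and the mode name `"draft_target"` is hypothetical; the point is only that a user-requested `enable_block_reuse=True` is overridden once a draft-model speculation mode is configured.

```python
# Minimal, self-contained sketch of the override added in this commit.
# These classes are stand-ins for illustration, not TensorRT-LLM's actual API.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class KvCacheConfig:
    enable_block_reuse: bool = True


@dataclass
class SpecDecMode:
    name: str
    uses_draft_model: bool = False  # hypothetical flag standing in for has_draft_model()

    def has_draft_model(self) -> bool:
        return self.uses_draft_model


@dataclass
class SpeculativeConfig:
    spec_dec_mode: SpecDecMode


@dataclass
class ExecutorConfig:
    kv_cache_config: KvCacheConfig = field(default_factory=KvCacheConfig)
    speculative_config: Optional[SpeculativeConfig] = None


def mangle_executor_config(executor_config: ExecutorConfig) -> None:
    # Mirror of the check in the diff above: if the configured speculation mode
    # uses a separate draft model, force KV cache block reuse off.
    spec_config = executor_config.speculative_config
    if spec_config is not None and spec_config.spec_dec_mode.has_draft_model():
        print(f"Disabling block reuse for speculation algorithm {spec_config.spec_dec_mode.name}")
        executor_config.kv_cache_config.enable_block_reuse = False


# A draft/target setup loses block reuse even if the user asked for it.
cfg = ExecutorConfig(
    kv_cache_config=KvCacheConfig(enable_block_reuse=True),
    speculative_config=SpeculativeConfig(SpecDecMode("draft_target", uses_draft_model=True)),
)
mangle_executor_config(cfg)
assert cfg.kv_cache_config.enable_block_reuse is False
```

Per the in-code comment, this is a stopgap: once `LLMRequest` is refactored so that block reuse for the target model no longer mutates the request's chunking state, the override can presumably be lifted and reuse re-enabled for draft/target speculation.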
||||