TensorRT-LLMs/tensorrt_llm/_torch/speculative/drafter.py
Ziyi Xiong 536e8776cd
[TRTLLM-6668][feat] Enable overlap scheduler for two-model spec decoding (#7651)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
2025-09-16 07:33:44 +08:00

from abc import ABC, abstractmethod
from typing import List, Optional, final

from ..pyexecutor.llm_request import LlmRequest, get_draft_token_length
from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests


class Drafter(ABC):
    """Abstract base class for all drafter implementations."""

    def __init__(self, max_concurrency: Optional[int] = None) -> None:
        self.max_concurrency = max_concurrency

    @abstractmethod
    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
    ) -> None:
        """
        Prepare the draft tokens for the forward computation this step.

        Args:
            scheduled_requests: The scheduled requests for this iteration
        """
        raise NotImplementedError

    @final
    def should_use_spec_decode(self, requests: List[LlmRequest],
                               max_batch_size: int, max_num_tokens: int,
                               max_draft_len: int) -> bool:
        """
        You probably don't want to override this. ModelEngine assumes that
        speculation is always on if max_concurrency is not specified in the
        user's spec config.
        """
        # Inputs are typically validated upstream:
        # max_batch_size > 0, max_num_tokens > 0, max_draft_len >= 0.
        if self.max_concurrency is None:
            return True

        # Defensive guards; keep the behavior explicit for zero/empty cases.
        if not requests or max_batch_size <= 0 or max_num_tokens <= 0:
            return False

        # Each request needs one token for the target model plus up to
        # max_draft_len draft tokens, so the per-step token budget caps the
        # effective batch size.
        tokens_per_request = 1 + max_draft_len
        token_cap = max_num_tokens // tokens_per_request
        if token_cap <= 0:
            return False

        num_effective_requests = min(len(requests), max_batch_size, token_cap)
        return num_effective_requests <= self.max_concurrency

    @final
    def pad_draft_tokens_for_cuda_graph(
            self, scheduled_requests: ScheduledRequests) -> None:
        """
        Pad draft tokens to the max draft length for CUDA graph compatibility.

        Args:
            scheduled_requests: The scheduled requests to pad
        """
        # self.max_draft_tokens is not defined on this base class; it is
        # expected to be set by the concrete drafter implementation.
        max_draft_tokens = self.max_draft_tokens
        for req in scheduled_requests.generation_requests:
            num_draft_tokens = get_draft_token_length(req)
            req.py_draft_tokens.extend(
                0 for _ in range(max_draft_tokens - num_draft_tokens))
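

# ---------------------------------------------------------------------------
# Illustrative sketch: a minimal concrete drafter plus a worked example of
# the token-budget math in should_use_spec_decode(). The class _NoopDrafter
# and the numbers below (max_concurrency=4, max_batch_size=64,
# max_num_tokens=8192, max_draft_len=3) are hypothetical stand-ins chosen
# only to show how the gating behaves; real drafters implement
# prepare_draft_tokens() with an actual draft model.


class _NoopDrafter(Drafter):
    """Toy drafter that proposes no draft tokens; for illustration only."""

    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
    ) -> None:
        # A real implementation would run its draft model here and populate
        # req.py_draft_tokens for each scheduled request.
        pass


if __name__ == "__main__":
    drafter = _NoopDrafter(max_concurrency=4)

    # 3 active requests, max_draft_len=3 -> 1 + 3 = 4 tokens per request,
    # token_cap = 8192 // 4 = 2048, effective = min(3, 64, 2048) = 3 <= 4,
    # so speculation stays on.
    few_requests = [object()] * 3  # only len() is consulted by the check
    assert drafter.should_use_spec_decode(few_requests,
                                          max_batch_size=64,
                                          max_num_tokens=8192,
                                          max_draft_len=3)

    # 16 active requests -> effective = min(16, 64, 2048) = 16 > 4,
    # so speculation is switched off for this step.
    many_requests = [object()] * 16
    assert not drafter.should_use_spec_decode(many_requests,
                                              max_batch_size=64,
                                              max_num_tokens=8192,
                                              max_draft_len=3)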