[None][refactor] Move draft token padding out of Drafter (#7134)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Author: Mike Iovine, 2025-08-27 05:07:50 -04:00 (committed by GitHub)
parent dbd4f21687
commit 8b216135f0
3 changed files with 14 additions and 22 deletions

Changed file 1 of 3 (PyExecutor):

@@ -42,7 +42,7 @@ from .guided_decoder import GuidedDecoder
 from .handle_logits import HandleLogits
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
-                          LlmResponse)
+                          LlmResponse, get_draft_token_length)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors
 from .scheduler import RequestScheduler, ScheduledRequests
@@ -966,6 +966,15 @@ class PyExecutor:
             self.drafter.prepare_draft_tokens(
                 scheduled_batch, self.resource_manager)
+            # Pad draft tokens to the max draft length. This is for CUDA
+            # graph compatibility.
+            for req in scheduled_batch.generation_requests:
+                max_draft_tokens = self.max_draft_len
+                num_draft_tokens = get_draft_token_length(req)
+                req.py_draft_tokens.extend(
+                    0 for _ in range(max_draft_tokens -
+                                     num_draft_tokens))
         batch_outputs = self._forward_step(scheduled_batch)
         self._execute_guided_decoder(scheduled_batch,
                                      batch_outputs['logits'])
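For context, here is a minimal, self-contained sketch of the padding step that now lives in the executor. GenerationRequest and pad_draft_tokens are illustrative stand-ins rather than the TensorRT-LLM API, and len(py_draft_tokens) stands in for get_draft_token_length; the point is that padding every request to a fixed max_draft_len keeps per-step tensor shapes constant, which is what lets a captured CUDA graph be replayed across iterations.

from dataclasses import dataclass, field
from typing import List


@dataclass
class GenerationRequest:
    # Draft tokens proposed by the speculative drafter for this request.
    py_draft_tokens: List[int] = field(default_factory=list)


def pad_draft_tokens(requests: List[GenerationRequest], max_draft_len: int) -> None:
    # Pad with token id 0 so every request carries exactly max_draft_len draft
    # tokens; a fixed draft length keeps tensor shapes static across steps.
    for req in requests:
        num_draft_tokens = len(req.py_draft_tokens)  # stand-in for get_draft_token_length(req)
        req.py_draft_tokens.extend([0] * (max_draft_len - num_draft_tokens))


if __name__ == "__main__":
    batch = [GenerationRequest([11, 12]), GenerationRequest(), GenerationRequest([7, 8, 9])]
    pad_draft_tokens(batch, max_draft_len=4)
    print([r.py_draft_tokens for r in batch])
    # -> [[11, 12, 0, 0], [0, 0, 0, 0], [7, 8, 9, 0]]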

Changed file 2 of 3 (ModelDrafter):

@@ -10,8 +10,7 @@ from tensorrt_llm.logger import logger
 from ..pyexecutor.guided_decoder import GuidedDecoder
 from ..pyexecutor.handle_logits import HandleLogits
-from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState,
-                                      get_draft_token_length)
+from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
 from ..pyexecutor.sampler import Sampler, SampleState, TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
@@ -326,15 +325,6 @@ class ModelDrafter(Drafter):
         return new_requests
-    def _pad_to_max_draft_tokens(self,
-                                 scheduled_requests: ScheduledRequests) -> None:
-        """Pad draft tokens to maximum length for all generation requests."""
-        for req in scheduled_requests.generation_requests:
-            max_draft_tokens = self.max_draft_tokens
-            num_draft_tokens = get_draft_token_length(req)
-            req.py_draft_tokens.extend(
-                0 for _ in range(max_draft_tokens - num_draft_tokens))
     def _execute_guided_decoder(self,
                                 scheduled_batch: ScheduledRequests,
                                 logits: torch.Tensor,
@@ -418,7 +408,6 @@
         self._update_requests(previous_batch)
         self._process_decoded_tokens(previous_batch.scheduled_requests,
                                      req_id_to_old_request)
-        self._pad_to_max_draft_tokens(scheduled_requests)
         if self.guided_decoder is not None:
             self.guided_decoder.rollback_draft_tokens(scheduled_requests)
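Both the removed helper and the new executor-side loop key off get_draft_token_length, which is why its import moves from the drafter module into the executor. Centralizing the padding in the executor means any drafter (the model-based one here, or the n-gram one below) can hand back only the tokens it actually drafted. A hedged sketch of the behavior that helper presumably provides, given as an assumption for illustration rather than the actual llm_request implementation:

from typing import List, Optional


def get_draft_token_length(request) -> int:
    # Hypothetical sketch, not the real implementation: treat an unset
    # draft-token list as length zero, otherwise count its entries.
    draft_tokens: Optional[List[int]] = getattr(request, "py_draft_tokens", None)
    return len(draft_tokens) if draft_tokens is not None else 0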

Changed file 3 of 3 (NGramPoolManager / NGramDrafter):

@@ -87,13 +87,13 @@ class NGramPoolManager(BaseResourceManager):
         self,
         prefix: list[int],
         request_id: int,
-        padding_id: int,
         max_sequence_length: int,
     ):
         prefix_len = len(prefix)
         max_draft_token_length_this_step = max_sequence_length - 1 - prefix_len
         if max_draft_token_length_this_step <= 0:  # No draft token is need if the prefix is long enough
-            return [padding_id]
+            return []
         if request_id not in self.start_index:  # Extend start_index and pool for a new request
             self.start_index[request_id] = 0
             if not self.is_public_pool:
@@ -125,8 +125,7 @@ class NGramPoolManager(BaseResourceManager):
                 pool[pattern].remove(match)
                 pool[pattern].add(new_match)
         # Find match
-        draft_tokens = [padding_id]  # fallback value
+        draft_tokens = []
         for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
                           -1):
             pattern = tuple(prefix[-size:])
@@ -194,12 +193,7 @@ class NGramDrafter(Drafter):
             draft_tokens = self.spec_resource_manager.get_draft_tokens(
                 prefix,
                 request.request_id,
-                padding_id=0,
                 max_sequence_length=request.py_orig_prompt_len +
                 request.py_max_new_tokens,
             )
-            # Pad length to `self.max_draft_len`
-            if len(draft_tokens) > 0:
-                pad_length = self.max_draft_len - len(draft_tokens)
-                draft_tokens.extend([0] * pad_length)
             request.py_draft_tokens = draft_tokens
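To make the NGram change concrete, below is a toy, self-contained lookup in the spirit of what get_draft_tokens does after this commit: it returns only the continuation tokens it actually finds, possibly an empty list, and leaves padding to the executor. The pool construction and names are simplified illustrations under that assumption, not the real NGramPoolManager implementation.

from typing import Dict, List, Tuple


def toy_ngram_draft(prefix: List[int], max_matching_ngram_size: int,
                    max_draft_len: int) -> List[int]:
    # Build a tiny pool mapping each n-gram pattern seen in the prefix to the
    # tokens that followed it (simplified: keep only the latest continuation).
    pool: Dict[Tuple[int, ...], List[int]] = {}
    for size in range(1, max_matching_ngram_size + 1):
        for start in range(len(prefix) - size):
            pattern = tuple(prefix[start:start + size])
            pool[pattern] = prefix[start + size:start + size + max_draft_len]

    # Prefer the longest suffix of the prefix, falling back to shorter ones.
    for size in range(min(max_matching_ngram_size, len(prefix) - 1), 0, -1):
        pattern = tuple(prefix[-size:])
        if pattern in pool:
            return pool[pattern][:max_draft_len]

    # No match: return an empty list; the executor pads it to the max length.
    return []


if __name__ == "__main__":
    # The suffix (2, 3) appeared earlier followed by 4, 5, 2, so those tokens
    # become the speculative draft for the next step.
    print(toy_ngram_draft([1, 2, 3, 4, 5, 2, 3],
                          max_matching_ngram_size=2, max_draft_len=3))  # -> [4, 5, 2]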