Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

commit 8b216135f0
parent dbd4f21687

[None][refactor] Move draft token padding out of Drafter (#7134)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
@@ -42,7 +42,7 @@ from .guided_decoder import GuidedDecoder
 from .handle_logits import HandleLogits
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
-                          LlmResponse)
+                          LlmResponse, get_draft_token_length)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors
 from .scheduler import RequestScheduler, ScheduledRequests
@@ -966,6 +966,15 @@ class PyExecutor:
                     self.drafter.prepare_draft_tokens(
                         scheduled_batch, self.resource_manager)
 
+                    # Pad draft tokens to the max draft length. This is for CUDA
+                    # graph compatibility.
+                    for req in scheduled_batch.generation_requests:
+                        max_draft_tokens = self.max_draft_len
+                        num_draft_tokens = get_draft_token_length(req)
+                        req.py_draft_tokens.extend(
+                            0 for _ in range(max_draft_tokens -
+                                             num_draft_tokens))
+
                 batch_outputs = self._forward_step(scheduled_batch)
                 self._execute_guided_decoder(scheduled_batch,
                                              batch_outputs['logits'])
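
For reference, here is a minimal standalone sketch of the padding step that now lives in PyExecutor rather than in the Drafter subclasses. The _Request stub and the body of get_draft_token_length are assumptions for illustration (standing in for LlmRequest and the helper imported from .llm_request); only the padding loop mirrors the hunk above.

from dataclasses import dataclass, field

@dataclass
class _Request:
    # Stand-in for LlmRequest: py_draft_tokens is assumed to be a plain Python list.
    py_draft_tokens: list = field(default_factory=list)

def get_draft_token_length(req) -> int:
    # Assumed behavior of the helper imported from .llm_request.
    return len(req.py_draft_tokens)

def pad_draft_tokens(generation_requests, max_draft_len: int) -> None:
    # Pad every request's draft tokens with zeros up to max_draft_len so the
    # draft-token shape stays fixed across iterations (CUDA graph compatibility).
    for req in generation_requests:
        num_draft_tokens = get_draft_token_length(req)
        req.py_draft_tokens.extend(0 for _ in range(max_draft_len - num_draft_tokens))

reqs = [_Request([11, 12]), _Request()]
pad_draft_tokens(reqs, max_draft_len=4)
assert [r.py_draft_tokens for r in reqs] == [[11, 12, 0, 0], [0, 0, 0, 0]]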
@@ -10,8 +10,7 @@ from tensorrt_llm.logger import logger
 
 from ..pyexecutor.guided_decoder import GuidedDecoder
 from ..pyexecutor.handle_logits import HandleLogits
-from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState,
-                                      get_draft_token_length)
+from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
 from ..pyexecutor.sampler import Sampler, SampleState, TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
@@ -326,15 +325,6 @@ class ModelDrafter(Drafter):
 
         return new_requests
 
-    def _pad_to_max_draft_tokens(self,
-                                 scheduled_requests: ScheduledRequests) -> None:
-        """Pad draft tokens to maximum length for all generation requests."""
-        for req in scheduled_requests.generation_requests:
-            max_draft_tokens = self.max_draft_tokens
-            num_draft_tokens = get_draft_token_length(req)
-            req.py_draft_tokens.extend(
-                0 for _ in range(max_draft_tokens - num_draft_tokens))
-
     def _execute_guided_decoder(self,
                                 scheduled_batch: ScheduledRequests,
                                 logits: torch.Tensor,
@@ -418,7 +408,6 @@ class ModelDrafter(Drafter):
         self._update_requests(previous_batch)
         self._process_decoded_tokens(previous_batch.scheduled_requests,
                                      req_id_to_old_request)
-        self._pad_to_max_draft_tokens(scheduled_requests)
 
         if self.guided_decoder is not None:
             self.guided_decoder.rollback_draft_tokens(scheduled_requests)
@@ -87,13 +87,13 @@ class NGramPoolManager(BaseResourceManager):
         self,
         prefix: list[int],
         request_id: int,
         padding_id: int,
         max_sequence_length: int,
     ):
         prefix_len = len(prefix)
         max_draft_token_length_this_step = max_sequence_length - 1 - prefix_len
         if max_draft_token_length_this_step <= 0:  # No draft token is need if the prefix is long enough
-            return [padding_id]
+            return []
         if request_id not in self.start_index:  # Extend start_index and pool for a new request
             self.start_index[request_id] = 0
             if not self.is_public_pool:
@@ -125,8 +125,7 @@ class NGramPoolManager(BaseResourceManager):
                 pool[pattern].remove(match)
                 pool[pattern].add(new_match)
 
-        # Find match
-        draft_tokens = [padding_id]  # fallback value
+        draft_tokens = []
         for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
                           -1):
             pattern = tuple(prefix[-size:])
@@ -194,12 +193,7 @@ class NGramDrafter(Drafter):
             draft_tokens = self.spec_resource_manager.get_draft_tokens(
                 prefix,
                 request.request_id,
-                padding_id=0,
                 max_sequence_length=request.py_orig_prompt_len +
                 request.py_max_new_tokens,
             )
-            # Pad length to `self.max_draft_len`
-            if len(draft_tokens) > 0:
-                pad_length = self.max_draft_len - len(draft_tokens)
-                draft_tokens.extend([0] * pad_length)
             request.py_draft_tokens = draft_tokens
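
Taken together, the NGram path now returns only genuine draft tokens (possibly an empty list), and the executor-level loop shown earlier restores the fixed draft length. A small sketch of that contract, where get_draft_tokens is a simplified stand-in for NGramPoolManager.get_draft_tokens (the real n-gram pool lookup is replaced by a dummy match, an assumption for illustration):

def get_draft_tokens(prefix, max_sequence_length):
    # New contract: return an unpadded, possibly empty list of draft tokens.
    max_draft_len_this_step = max_sequence_length - 1 - len(prefix)
    if max_draft_len_this_step <= 0:
        return []  # no room for draft tokens; previously this returned [padding_id]
    return [42] * min(2, max_draft_len_this_step)  # dummy "pool match" of two tokens

def pad_to_max_draft_len(draft_tokens, max_draft_len):
    # The padding formerly done in NGramDrafter / ModelDrafter, now a single
    # executor-side step (see the PyExecutor hunk above).
    return draft_tokens + [0] * (max_draft_len - len(draft_tokens))

# A request at its sequence-length limit gets no draft tokens but still ends up
# with a fixed-length, all-padding list; a normal request gets its tokens padded.
assert pad_to_max_draft_len(get_draft_tokens([1] * 9, 10), 4) == [0, 0, 0, 0]
assert pad_to_max_draft_len(get_draft_tokens([1, 2, 3], 10), 4) == [42, 42, 0, 0]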