from dataclasses import dataclass
from typing import List

from ordered_set import OrderedSet

# Needed for the logger.debug calls below
from tensorrt_llm.logger import logger

from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpeculativeDecodingMode


@dataclass
class NGramConfig(SpecConfig):
    """
    Configuration for N-gram drafter.
    """
    # The name of speculative decoding.
    spec_dec_name = "NGRAM"

    num_extra_kv_tokens: int = 0
    max_draft_tokens: int = 0

    prompt_lookup_num_tokens: int = 5
    max_matching_ngram_size: int = 5
    end_id: int = -1
    is_keep_all: bool = True
    is_use_oldest: bool = True
    is_public_pool: bool = True

    def __post_init__(self) -> None:
        self.spec_dec_mode = SpeculativeDecodingMode.from_string(
            self.spec_dec_name)
        self.max_draft_tokens = self.prompt_lookup_num_tokens

    def update_from_model_config(self, model_config):
        pass
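

# A minimal construction sketch (hedged; in practice the config is built from
# the user-facing speculative-decoding settings rather than by hand):
#
#   config = NGramConfig(prompt_lookup_num_tokens=4, max_matching_ngram_size=2)
#   config.max_draft_tokens  # == 4, copied from prompt_lookup_num_tokens
#   config.spec_dec_mode     # the NGRAM member of SpeculativeDecodingMode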


class NGramPoolManager(BaseResourceManager):
    """
    This class maintains the pattern-matches pairs for the NGram drafter.

    For example, one of the existing pairs could be:

        ["I", "love"] -> [["apple", "because", "it", "is"], ["banana", "and"]]

    Here we call ["I", "love"] the `pattern`, and [["apple", "because", "it", "is"], ["banana", "and"]] the `matches`.

    A `pattern` is a list of token ids. If the pattern appears at the tail of the sequence during generation, the pool provides the corresponding draft tokens from its matches.

    A `match` is a list of candidate draft token ids attached to a pattern.

    Arguments:
        prompt_lookup_num_tokens: int
            The maximum length of the draft tokens (i.e., the maximum number of output draft tokens).

        max_matching_ngram_size: int
            The maximum length of the pattern to search for (i.e., the maximum number of input tokens to match).

        is_keep_all: bool = True
            Whether to keep all candidate matches for each pattern; if False, only one match is kept per pattern.

        is_use_oldest: bool = True
            Whether to provide the oldest match when a pattern is hit; if False, the newest one is provided.

        is_public_pool: bool = True
            Whether to use one common pool for all requests; if False, each request keeps a private pool.

    Members:
        pool: dict[tuple[int], OrderedSet[tuple[int]]] or dict[int, dict[tuple[int], OrderedSet[tuple[int]]]]
            If is_public_pool == True, it maps from patterns to matches.
            If is_public_pool == False, it maps from request ID to the request-specific pool.

        start_index: dict[int, int]
            It maps from request ID to the index in the prompt from which the pool will be updated in the next step.
    """

    def __init__(self, config: NGramConfig, max_num_requests: int):
        self.max_num_requests = max_num_requests
        self.max_num_draft_tokens = config.max_draft_tokens

        self.prompt_lookup_num_tokens = config.prompt_lookup_num_tokens
        self.max_matching_ngram_size = config.max_matching_ngram_size
        self.is_keep_all = config.is_keep_all
        self.is_use_oldest = config.is_use_oldest  # TODO: remove this once an updating strategy is supported
        self.is_public_pool = config.is_public_pool
        self.pool = {}
        self.start_index = {}

    def prepare_resources(self, scheduled_batch: ScheduledRequests):
        # Update the pool and provide draft tokens for the generation requests
        for request in scheduled_batch.generation_requests:
            num_draft_tokens = 0 if request.py_last_draft_tokens is None else \
                len(request.py_last_draft_tokens)
            num_accepted_tokens = getattr(request,
                                          "py_num_accepted_draft_tokens", 0)
            num_rejected_tokens = num_draft_tokens - num_accepted_tokens
            assert num_rejected_tokens >= 0

            # Generate draft tokens
            draft_tokens = self._get_draft_tokens(
                request.get_tokens()[0],
                request.request_id,
                request.py_end_id,
                request.py_orig_prompt_len + request.py_max_new_tokens,
            )

            # Pad to max_draft_tokens
            if draft_tokens is not None:
                pad_length = self.max_num_draft_tokens - len(draft_tokens)
                draft_tokens.extend([request.py_end_id] * pad_length)
            request.py_draft_tokens = draft_tokens
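
    # Per-step flow (a rough sketch; the exact call sites live in the executor):
    #   1. The scheduler builds a ScheduledRequests batch.
    #   2. prepare_resources() updates the pool with the newly generated tokens
    #      and attaches py_draft_tokens to each generation request.
    #   3. The target model verifies those draft tokens in its next forward pass.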

    def update_resources(self, scheduled_batch: ScheduledRequests):
        pass

    def free_resources(self, request: LlmRequest):
        if self.is_public_pool:
            return  # TODO: need a strategy to swap out the pairs
        request_id = request.request_id
        if request_id in self.pool:
            self.pool.pop(request_id)
            self.start_index.pop(request_id)

    def add_dummy_requests(self, request_ids: List[int]):
        pass

    def shutdown(self):
        pass

    def get_max_resource_count(self) -> int:
        return self.max_num_requests

    def get_needed_resource_to_completion(self, request: LlmRequest):
        return 0

    def print_pool(self):  # For debug
        if self.is_public_pool:
            logger.debug(f"Using public pool, size = {len(self.pool)}")
            self._print_line(self.pool)
        else:
            logger.debug("Using private pools")
            for request_id, request_map in self.pool.items():
                logger.debug(f"Request {request_id}, size={len(request_map)}")
                self._print_line(request_map, 4)

    def _print_line(self, local_map, indentation=0):  # For debug
        for pattern, matches in local_map.items():
            output = " " * indentation + str(pattern) + "->"
            for match in matches:
                output += str(match) + ", "
            logger.debug(output)

    def _get_draft_tokens(
        self,
        prefix: list[int],
        request_id: int,
        end_id: int,
        max_sequence_length: int,
    ):
        prefix_len = len(prefix)
        # Reserve one position for the token the target model produces itself
        max_draft_token_length = max_sequence_length - 1 - prefix_len
        if max_draft_token_length <= 0:  # Skip the search if the sequence is already long enough
            return None

        if request_id not in self.start_index:  # A new request
            self.start_index[request_id] = 0
            if not self.is_public_pool:
                assert len(self.pool) + 1 <= self.max_num_requests
                self.pool[request_id] = {}
        pool = (self.pool if self.is_public_pool else self.pool[request_id])

        # Update the pool with the tokens that arrived since the last update
        sequence = prefix[self.start_index[request_id]:]
        for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
                          -1):
            # Find each possible pattern-match pair, using tuples so they are hashable
            for l in range(len(sequence) - size):
                r = min(l + size + self.prompt_lookup_num_tokens, len(sequence))
                pattern = tuple(sequence[l:l + size])
                new_match = tuple(sequence[l + size:r])
                if pattern not in pool or \
                        (not self.is_keep_all
                         and len(new_match) > len(pool[pattern][0])):
                    # Replace the match if
                    # 1. the pattern does not exist in the pool yet, or
                    # 2. only one match is kept per pattern and the new match is longer (MRU)
                    pool[pattern] = OrderedSet((new_match, ))
                elif new_match not in pool[pattern]:
                    # Update the matches if the pattern already exists.
                    # TODO: need a strategy to keep short candidates; for now we just remove them.
                    # Drop all existing matches shorter than the new one
                    # (iterate over a copy since the set is mutated).
                    for match in list(pool[pattern]):
                        if len(match) < len(new_match):
                            pool[pattern].remove(match)
                    pool[pattern].add(new_match)
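
        # Worked example (an illustrative sketch, assuming
        # prompt_lookup_num_tokens >= 2): with sequence [3, 7, 11, 12] and
        # size = 2, the first window yields pattern (3, 7) and
        # new_match (11, 12), so pool[(3, 7)] becomes OrderedSet([(11, 12)]).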

        # Find a match for the tail of the prefix, preferring the longest pattern
        draft_tokens = [end_id]
        for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
                          -1):
            pattern = tuple(prefix[-size:])
            if pattern not in pool:
                continue
            draft_tokens = pool[pattern][0 if self.is_use_oldest else -1]
            draft_tokens = list(draft_tokens)[:max_draft_token_length]
            break
        # The next update only needs to revisit the tail that can still form
        # new pattern-match pairs
        self.start_index[request_id] = max(
            0, prefix_len -
            (self.prompt_lookup_num_tokens + self.max_matching_ngram_size - 1))

        return draft_tokens
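

# End-to-end sketch (hedged, for illustration only; in the real system the
# manager is driven by the executor and operates on LlmRequest objects):
#
#   config = NGramConfig(prompt_lookup_num_tokens=4, max_matching_ngram_size=2)
#   manager = NGramPoolManager(config, max_num_requests=8)
#   # For prefix [1, 2, 3, 1, 2], the pattern (1, 2) was once followed by
#   # (3, 1, 2), so the pool proposes those tokens when the tail is (1, 2):
#   drafts = manager._get_draft_tokens([1, 2, 3, 1, 2], request_id=0,
#                                      end_id=-1, max_sequence_length=32)
#   # drafts == [3, 1, 2]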