TensorRT-LLMs/tensorrt_llm/_torch/speculative/ngram.py

from dataclasses import dataclass
from typing import List

from ordered_set import OrderedSet

from tensorrt_llm.logger import logger

from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpeculativeDecodingMode

@dataclass
class NGramConfig(SpecConfig):
    """
    Configuration for the N-gram drafter.
    """
    # The name of the speculative decoding scheme.
    spec_dec_name = "NGRAM"

    num_extra_kv_tokens: int = 0
    max_draft_tokens: int = 0
    prompt_lookup_num_tokens: int = 5
    max_matching_ngram_size: int = 5
    end_id: int = -1
    is_keep_all: bool = True
    is_use_oldest: bool = True
    is_public_pool: bool = True

    def __post_init__(self) -> None:
        self.spec_dec_mode = SpeculativeDecodingMode.from_string(
            self.spec_dec_name)
        self.max_draft_tokens = self.prompt_lookup_num_tokens

    def update_from_model_config(self, model_config):
        pass
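
# Illustrative usage (an assumption, not part of the original file): building
# a config for the drafter. This assumes SpecConfig supplies defaults for any
# required fields of its own.
#
#     config = NGramConfig(
#         prompt_lookup_num_tokens=4,   # propose at most 4 draft tokens
#         max_matching_ngram_size=3,    # match suffixes of at most 3 tokens
#         is_public_pool=True,          # one pool shared across all requests
#     )
#     # __post_init__ ties max_draft_tokens to prompt_lookup_num_tokens:
#     assert config.max_draft_tokens == 4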

class NGramPoolManager(BaseResourceManager):
    """
    This class maintains the pattern-matches pairs for the NGram drafter.

    For example, one of the existing pairs could be:

        ["I", "love"] -> [["apple", "because", "it", "is"], ["banana", "and"]]

    Here we call ["I", "love"] the `pattern` and the right-hand list its
    `matches`.

    `pattern` is a tuple of token_ids. If the pattern appears at the tail of
    the sequence during generation, the pool provides the corresponding draft
    tokens from its matches.

    `matches` is a list of candidate draft token_id sequences attached to a
    pattern.

    Arguments:
        prompt_lookup_num_tokens: int
            The maximum length of the draft tokens (i.e., the maximum number
            of draft tokens proposed per step).

        max_matching_ngram_size: int
            The maximum length of the pattern searched for (i.e., the maximum
            number of trailing prompt tokens used for the lookup).

        is_keep_all: bool = True
            Whether to keep all candidate pattern-matches pairs; only one
            match is kept per pattern if False.

        is_use_oldest: bool = True
            Whether to provide the oldest match when a pattern is hit; the
            newest one is provided if False.

        is_public_pool: bool = True
            Whether to use a common pool for all requests; the pool is private
            to each request if False.

    Members:
        pool: dict[tuple[int], OrderedSet[tuple[int]]] or dict[int, dict[tuple[int], OrderedSet[tuple[int]]]]
            If is_public_pool == True, it maps from patterns to matches.
            If is_public_pool == False, it maps from request ID to the
            request-specific pool.

        start_index: dict[int, int]
            Maps from request ID to the index in the prompt from which the
            pool will be updated in the next step.
    """

    def __init__(self, config: NGramConfig, max_num_requests: int):
        self.max_num_requests = max_num_requests
        self.max_num_draft_tokens = config.max_draft_tokens
        self.prompt_lookup_num_tokens = config.prompt_lookup_num_tokens
        self.max_matching_ngram_size = config.max_matching_ngram_size
        self.is_keep_all = config.is_keep_all
        self.is_use_oldest = config.is_use_oldest  # TODO: remove this if updating strategy is supported
        self.is_public_pool = config.is_public_pool
        self.pool = {}
        self.start_index = {}

    def prepare_resources(self, scheduled_batch: ScheduledRequests):
        # Update the pool and provide draft tokens for the requests
        for request in scheduled_batch.generation_requests:
            num_draft_tokens = 0 if request.py_last_draft_tokens is None else \
                len(request.py_last_draft_tokens)
            num_accepted_tokens = getattr(request,
                                          "py_num_accepted_draft_tokens", 0)
            num_rejected_tokens = num_draft_tokens - num_accepted_tokens
            assert num_rejected_tokens >= 0

            # Generate draft tokens
            draft_tokens = self._get_draft_tokens(
                request.get_tokens()[0],
                request.request_id,
                request.py_end_id,
                request.py_orig_prompt_len + request.py_max_new_tokens,
            )
            # Pad to max_draft_tokens
            if draft_tokens is not None:
                pad_length = self.max_num_draft_tokens - len(draft_tokens)
                draft_tokens.extend([request.py_end_id] * pad_length)
            request.py_draft_tokens = draft_tokens

    def update_resources(self, scheduled_batch: ScheduledRequests):
        pass

    def free_resources(self, request: LlmRequest):
        if self.is_public_pool:
            return  # TODO: need to have a strategy to swap out the pairs
        request_id = request.request_id
        if request_id in self.pool:
            self.pool.pop(request_id)
            self.start_index.pop(request_id)

    def add_dummy_requests(self, request_ids: List[int]):
        pass

    def shutdown(self):
        pass

    def get_max_resource_count(self) -> int:
        return self.max_num_requests

    def get_needed_resource_to_completion(self, request: LlmRequest):
        return 0

    def print_pool(self):  # For debug
        if self.is_public_pool:
            logger.debug(f"Using public pool, size = {len(self.pool)}")
            self._print_line(self.pool)
        else:
            logger.debug("Using private pool")
            for request_id, request_map in self.pool.items():
                logger.debug(f"Request {request_id}, size={len(request_map)}")
                self._print_line(request_map, 4)

    def _print_line(self, local_map, indentation=0):  # For debug
        for pattern, matches in local_map.items():
            output = " " * indentation + str(pattern) + "->"
            for match in matches:
                output += str(match) + ", "
            logger.debug(output)

    def _get_draft_tokens(
        self,
        prefix: list[int],
        request_id: int,
        end_id: int,
        max_sequence_length: int,
    ):
        prefix_len = len(prefix)
        max_draft_token_length = max_sequence_length - 1 - prefix_len
        if max_draft_token_length <= 0:  # Skip the search if the prefix is already long enough
            return None
        if request_id not in self.start_index:  # A new request
            self.start_index[request_id] = 0
            if not self.is_public_pool:
                assert len(self.pool) + 1 <= self.max_num_requests
                self.pool[request_id] = {}
        pool = (self.pool if self.is_public_pool else self.pool[request_id])

        # Update the pool
        sequence = prefix[self.start_index[request_id]:]
        for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
                          -1):
            # Find each possible pattern-match combination, using tuples for hashing
            for l in range(len(sequence) - size):
                r = min(l + size + self.prompt_lookup_num_tokens, len(sequence))
                pattern = tuple(sequence[l:l + size])
                new_match = tuple(sequence[l + size:r])
                if pattern not in pool or \
                        (not self.is_keep_all
                         and len(new_match) > len(pool[pattern][0])):
                    # Replace the match if
                    # 1. the pattern does not exist in the pool, or
                    # 2. only one match is kept, and the new match is longer (MRU)
                    pool[pattern] = OrderedSet((new_match, ))
                elif new_match not in pool[pattern]:
                    # Update the matches if the pattern already exists.
                    # TODO: need a strategy to retain short candidates; for now
                    # drop all existing matches shorter than the new one.
                    # Iterate over a copy since the set is mutated while scanning.
                    for match in list(pool[pattern]):
                        if len(match) < len(new_match):
                            pool[pattern].remove(match)
                    pool[pattern].add(new_match)

        # Find a match, trying the longest pattern first
        draft_tokens = [end_id]
        for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
                          -1):
            pattern = tuple(prefix[-size:])
            if pattern not in pool:
                continue
            draft_tokens = pool[pattern][0 if self.is_use_oldest else -1]
            draft_tokens = list(draft_tokens)[:max_draft_token_length]
            break
        self.start_index[request_id] = max(
            0, prefix_len -
            (self.prompt_lookup_num_tokens + self.max_matching_ngram_size - 1))
        return draft_tokens
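

# A minimal, self-contained sketch of the lookup step above (illustrative
# only; the function name and the plain-dict pool are assumptions, not part
# of the original module API). It mirrors the "find a match" loop of
# _get_draft_tokens on an already-built {pattern: [matches]} mapping.
def _lookup_sketch(prefix, pool, max_matching_ngram_size, max_draft_len):
    """Return the draft tokens proposed for `prefix`, or [] if nothing is hit.

    >>> pool = {(3, 5): [(7, 9)], (5,): [(7, 2)]}
    >>> _lookup_sketch([1, 3, 5], pool, max_matching_ngram_size=2, max_draft_len=2)
    [7, 9]
    """
    # Try the longest suffix of the prefix first, exactly like the loop above
    for size in range(min(max_matching_ngram_size, len(prefix) - 1), 0, -1):
        pattern = tuple(prefix[-size:])
        if pattern in pool:
            # Take the oldest recorded match (the is_use_oldest behavior)
            return list(pool[pattern][0])[:max_draft_len]
    return []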