[TRTLLM-5000][feat] Pytorch implementation of ngram drafter (#3936)

* v1.5

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

v1.5.4 Add back draft_overhead to spec dec stats

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

* v1.5.5: fix CI error

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

* v1.6: fix CI error 8196 > 8192

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

* Address reviewer concerns

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

* Address reviewer concerns

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

* precommit run

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

* v2.0: Address reviewer concerns

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

* v2.1: add fix from wili

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

* Revert changes that require use of TypeAlias because that requires python version >= 3.10

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

---------

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
Co-authored-by: wili-65535 <wili-65535@users.noreply.github.com>
Thor Johnsen 2025-05-20 21:40:00 -05:00 committed by GitHub
parent 9199793848
commit 5d438be59a
27 changed files with 690 additions and 125 deletions

View File

@ -252,6 +252,11 @@ public:
static void serialize(InflightBatchingStats const& inflightBatchingStats, std::ostream& os);
static size_t serializedSize(InflightBatchingStats const& inflightBatchingStats);
// SpecDecodingStats
static SpecDecodingStats deserializeSpecDecodingStats(std::istream& is);
static void serialize(SpecDecodingStats const& specDecStats, std::ostream& os);
static size_t serializedSize(SpecDecodingStats const& specDecStats);
// IterationStats
static IterationStats deserializeIterationStats(std::vector<char>& buffer);
static IterationStats deserializeIterationStats(std::istream& is);

View File

@ -46,6 +46,7 @@ class Tensor;
using TensorPtr = std::shared_ptr<Tensor>;
using SizeType32 = std::int32_t;
using SizeType64 = std::int64_t;
using FloatType = float;
using TokenIdType = std::int32_t;
using VecTokens = std::vector<TokenIdType>;
@ -294,6 +295,24 @@ struct InflightBatchingStats
float avgNumDecodedTokensPerIter;
};
/// @brief Struct that holds speculative decoding stats
struct SpecDecodingStats
{
/// @brief Total number of proposed draft tokens for all requests
SizeType64 numDraftTokens;
/// @brief Total number of accepted draft tokens for all requests
SizeType64 numAcceptedTokens;
/// @brief Number of requests with at least one draft token in batch
SizeType64 numRequestsWithDraftTokens;
/// @brief Acceptance length, defined as average number of tokens produced per step for all requests with at least
/// one draft token
double acceptanceLength;
/// @brief Iteration latency for draft token generation only (ms)
double iterLatencyMS;
/// @brief Draft overhead, defined as iterLatencyMS (specdec) / iterLatencyMS (total)
double draftOverhead;
};
/// @brief Struct that holds the stats of a single iteration
struct IterationStats
{
@ -341,6 +360,8 @@ struct IterationStats
std::optional<StaticBatchingStats> staticBatchingStats;
/// @brief Stats specific to inflight batching
std::optional<InflightBatchingStats> inflightBatchingStats;
/// @brief Stats specific to speculative decoding
std::optional<SpecDecodingStats> specDecStats;
};
/// @brief Enum class that represents the state of a request

View File

@ -31,11 +31,13 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(
StaticBatchingStats, numScheduledRequests, numContextRequests, numCtxTokens, numGenTokens, emptyGenSlots);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(InflightBatchingStats, numScheduledRequests, numContextRequests, numGenRequests,
numPausedRequests, numCtxTokens, microBatchId, avgNumDecodedTokensPerIter);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SpecDecodingStats, numDraftTokens, numAcceptedTokens, numRequestsWithDraftTokens,
acceptanceLength, iterLatencyMS, draftOverhead);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(IterationStats, timestamp, iter, iterLatencyMS, newActiveRequestsQueueLatencyMS,
numNewActiveRequests, numActiveRequests, numQueuedRequests, numCompletedRequests, maxNumActiveRequests,
maxBatchSizeStatic, maxBatchSizeTunerRecommended, maxBatchSizeRuntime, maxNumTokensStatic,
maxNumTokensTunerRecommended, maxNumTokensRuntime, gpuMemUsage, cpuMemUsage, pinnedMemUsage, kvCacheStats,
staticBatchingStats, inflightBatchingStats);
staticBatchingStats, inflightBatchingStats, specDecStats);
NLOHMANN_JSON_SERIALIZE_ENUM(RequestStage,
{{RequestStage::kQUEUED, "QUEUED"}, {RequestStage::kCONTEXT_IN_PROGRESS, "CONTEXT_IN_PROGRESS"},
{RequestStage::kGENERATION_IN_PROGRESS, "GENERATION_IN_PROGRESS"},

View File

@ -1727,6 +1727,42 @@ size_t Serialization::serializedSize(InflightBatchingStats const& inflightBatchi
return totalSize;
}
// SpecDecodingStats
SpecDecodingStats Serialization::deserializeSpecDecodingStats(std::istream& is)
{
auto numDraftTokens = su::deserialize<SizeType64>(is);
auto numAcceptedTokens = su::deserialize<SizeType64>(is);
auto numRequestsWithDraftTokens = su::deserialize<SizeType64>(is);
auto acceptanceLength = su::deserialize<double>(is);
auto iterLatencyMS = su::deserialize<double>(is);
auto draftOverhead = su::deserialize<double>(is);
return SpecDecodingStats{
numDraftTokens, numAcceptedTokens, numRequestsWithDraftTokens, acceptanceLength, iterLatencyMS, draftOverhead};
}
void Serialization::serialize(SpecDecodingStats const& state, std::ostream& os)
{
su::serialize(state.numDraftTokens, os);
su::serialize(state.numAcceptedTokens, os);
su::serialize(state.numRequestsWithDraftTokens, os);
su::serialize(state.acceptanceLength, os);
su::serialize(state.iterLatencyMS, os);
su::serialize(state.draftOverhead, os);
}
size_t Serialization::serializedSize(SpecDecodingStats const& state)
{
size_t totalSize = 0;
totalSize += su::serializedSize(state.numDraftTokens);
totalSize += su::serializedSize(state.numAcceptedTokens);
totalSize += su::serializedSize(state.numRequestsWithDraftTokens);
totalSize += su::serializedSize(state.acceptanceLength);
totalSize += su::serializedSize(state.iterLatencyMS);
totalSize += su::serializedSize(state.draftOverhead);
return totalSize;
}
// IterationStats
IterationStats Serialization::deserializeIterationStats(std::istream& is)
@ -1754,12 +1790,13 @@ IterationStats Serialization::deserializeIterationStats(std::istream& is)
auto crossKvCacheStats = su::deserialize<std::optional<KvCacheStats>>(is);
auto staticBatchingStats = su::deserialize<std::optional<StaticBatchingStats>>(is);
auto inflightBatchingStats = su::deserialize<std::optional<InflightBatchingStats>>(is);
auto specdecStats = su::deserialize<std::optional<SpecDecodingStats>>(is);
return IterationStats{timestamp, iter, iterLatencyMS, newActiveRequestsQueueLatencyMS, numNewActiveRequests,
numActiveRequests, numQueuedRequests, numCompletedRequests, maxNumActiveRequests, maxBatchSizeStatic,
maxBatchSizeTunerRecommended, maxBatchSizeRuntime, maxNumTokensStatic, maxNumTokensTunerRecommended,
maxNumTokensRuntime, gpuMemUsage, cpuMemUsage, pinnedMemUsage, kvCacheStats, crossKvCacheStats,
staticBatchingStats, inflightBatchingStats};
staticBatchingStats, inflightBatchingStats, specdecStats};
}
IterationStats Serialization::deserializeIterationStats(std::vector<char>& buffer)
@ -1797,6 +1834,7 @@ size_t Serialization::serializedSize(IterationStats const& iterStats)
totalSize += su::serializedSize(iterStats.crossKvCacheStats);
totalSize += su::serializedSize(iterStats.staticBatchingStats);
totalSize += su::serializedSize(iterStats.inflightBatchingStats);
totalSize += su::serializedSize(iterStats.specDecStats);
return totalSize;
}
@ -1825,6 +1863,7 @@ void Serialization::serialize(IterationStats const& iterStats, std::ostream& os)
su::serialize(iterStats.crossKvCacheStats, os);
su::serialize(iterStats.staticBatchingStats, os);
su::serialize(iterStats.inflightBatchingStats, os);
su::serialize(iterStats.specDecStats, os);
}
std::vector<char> Serialization::serialize(IterationStats const& iterStats)

View File

@ -465,6 +465,10 @@ T deserialize(std::istream& is)
{
return Serialization::deserializeInflightBatchingStats(is);
}
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::SpecDecodingStats>)
{
return Serialization::deserializeSpecDecodingStats(is);
}
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::IterationStats>)
{
return Serialization::deserializeIterationStats(is);

View File

@ -132,6 +132,15 @@ void initBindings(pybind11::module_& m)
.def_readwrite("micro_batch_id", &tle::InflightBatchingStats::microBatchId)
.def_readwrite("avg_num_decoded_tokens_per_iter", &tle::InflightBatchingStats::avgNumDecodedTokensPerIter);
py::class_<tle::SpecDecodingStats>(m, "SpecDecodingStats")
.def(py::init<>())
.def_readwrite("num_draft_tokens", &tle::SpecDecodingStats::numDraftTokens)
.def_readwrite("num_accepted_tokens", &tle::SpecDecodingStats::numAcceptedTokens)
.def_readwrite("num_requests_with_draft_tokens", &tle::SpecDecodingStats::numRequestsWithDraftTokens)
.def_readwrite("acceptance_length", &tle::SpecDecodingStats::acceptanceLength)
.def_readwrite("iter_latency_ms", &tle::SpecDecodingStats::iterLatencyMS)
.def_readwrite("draft_overhead", &tle::SpecDecodingStats::draftOverhead);
py::class_<tle::IterationStats>(m, "IterationStats")
.def(py::init<>())
.def_readwrite("timestamp", &tle::IterationStats::timestamp)
@ -150,6 +159,7 @@ void initBindings(pybind11::module_& m)
.def_readwrite("cross_kv_cache_stats", &tle::IterationStats::crossKvCacheStats)
.def_readwrite("static_batching_stats", &tle::IterationStats::staticBatchingStats)
.def_readwrite("inflight_batching_stats", &tle::IterationStats::inflightBatchingStats)
.def_readwrite("specdec_stats", &tle::IterationStats::specDecStats)
.def("to_json_str",
[](tle::IterationStats const& iterationStats)
{ return tle::JsonSerialization::toJsonStr(iterationStats); });
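The new `SpecDecodingStats` fields are exposed to Python through these bindings. Below is a minimal, hypothetical sketch of reading them from an `IterationStats` object; how `stats` is obtained from the runtime (e.g. from an iteration-stats queue) is assumed and not part of this diff.
```python
# Hedged sketch: summarize the speculative decoding fields of one iteration.
from tensorrt_llm.bindings.executor import IterationStats


def summarize_spec_decoding(stats: IterationStats) -> str:
    sd = stats.specdec_stats  # None when no spec-dec resource manager is active
    if sd is None:
        return "no speculative decoding stats for this iteration"
    return (f"draft={sd.num_draft_tokens} accepted={sd.num_accepted_tokens} "
            f"requests_with_drafts={sd.num_requests_with_draft_tokens} "
            f"acceptance_length={sd.acceptance_length:.2f} "
            f"draft_latency_ms={sd.iter_latency_ms:.2f} "
            f"draft_overhead={sd.draft_overhead:.2%}")
```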

View File

@ -4,10 +4,28 @@ This document shows how to build and run a model using Prompt-Lookup speculative
## Overview
We now provide two workflow styles to run Prompt-Lookup, named V1 and V2 respectively. V1 uses the TRT workflow and is similar to the Draft-Target-Model workflow: it runs in orchestrator mode and calls `runner.generate()` multiple times to get outputs, which is more flexible for customization but adds slightly more overhead. V2 uses the PyTorch workflow and is similar to the Look-Ahead workflow: it runs in leader mode and calls `runner.generate()` only once to get outputs, which provides higher performance but a fixed process.
Prompt-Lookup has 3 additional hyperparameters that you need to specify to control the generation process:
- `prompt_lookup_num_tokens`: the number of tokens we extract from the input prompt or previously generated output as draft tokens in one iteration, typically from 4 to 10 in common usage. Empirically, the larger the value, the higher the acceptance ratio but also the higher the overhead, so the right balance needs to be found based on the models and application scenarios.
- `max_matching_ngram_size`: the number of tokens we take from the tail of the generated output as a pattern, which is used to match against the input prompt or previously generated output. Empirically, the larger the value, the more precise the context that can be matched from the existing sequence, indicating a higher acceptance ratio, but also a higher probability of mismatch and higher overhead, which falls back to normal generation (one token per iteration).
- `device_list`: the index list of device(s) to run the model. Its length must be the same as the TP size of the draft model engine. For instance, `device_list=[0]` means using tp_size=1 and GPU 0 for the model, and `device_list=[4,5,6,7]` means using tp_size=4 and GPUs 4 to 7 for the model.
- `prompt_lookup_num_tokens`: the maximum number of tokens provided as draft tokens in one iteration, usually from 4 to 10 in common usage (default value: 4). Empirically, the larger the value, the higher the acceptance rate but also the higher the overhead, so the right balance needs to be found based on the models and application scenarios.
- `max_matching_ngram_size`: the maximum number of tokens extracted from the tail of the input prompt or generated output as a pattern, which is used to search for corresponding draft tokens (default value: 2). Empirically, the larger the value, the more precise the context that can be matched from the existing sequence, indicating a higher acceptance rate, but also a higher probability of mismatch and higher overhead, in which case generation falls back to normal decoding (one token per iteration).
- `device_list`: the index list of device(s) to run the model in the V1 workflow. Its length must be the same as the TP size of the draft model engine. For instance, `device_list=[0]` means using tp_size=1 and GPU 0 for the model, and `device_list=[4,5,6,7]` means using tp_size=4 and GPUs 4 to 7 for the model. This parameter is not needed in the V2 workflow.
+ For example, with `max_matching_ngram_size=2` and `prompt_lookup_num_tokens=4`, the process of getting draft tokens for a sequence `prefix=[..., t1, t2, t3, t4]` looks like the pseudocode below:
```Python
pattern = prefix[-2:]  # pattern=[t3, t4] (length=2)
if pattern in pool and len(pool[pattern]) == 4:  # assuming it is {(t3, t4): (t5, t6, t7, t8)}
    return pool[pattern]  # draft tokens = [t5, t6, t7, t8]
elif pattern in pool and len(pool[pattern]) < 4:  # assuming it is {(t3, t4): (t9, t10, t11)}
    return pool[pattern]  # draft tokens = [t9, t10, t11]
pattern = prefix[-1:]  # Try a shorter pattern if no candidate of length=2 exists, pattern=[t4] (length=1)
if pattern in pool and len(pool[pattern]) == 4:  # The same process as above
    return pool[pattern]
elif pattern in pool and len(pool[pattern]) < 4:
    return pool[pattern]
return None  # No candidate exists
```
## Support Matrix
* GPU Compute Capability >= 8.0 (Ampere or newer)
@ -17,17 +35,21 @@ The Prompt-Lookup has 3 additional hyperparameters that you need to specify to c
## Usage
### Build engines
### V1 workflow
+ We use an open-source `llama-v2-13B` model in this example.
+ `--use_paged_context_fmha=enable` must be specified since we need KV cache reuse in this approach.
+ `--speculative_decoding_mode=draft_tokens_external` must be specified.
+ `--max_draft_len` must be set greater than or equal to `prompt_lookup_num_tokens`.
+ `--prompt_lookup_config` is the corresponding configuration of Prompt-Lookup; its usage can be seen in [util.py](../util.py).
+ As an example, `[10,2,[0]]` means `prompt_lookup_num_tokens=10`, `max_matching_ngram_size=2`, and the target model runs on `GPU0`.
+ `--kv_cache_enable_block_reuse` must be specified for this approach.
+ Only the C++ session is supported, so `--use_py_session` must not be specified.
+ `--num_beams` cannot be set larger than 1 since beam search is not supported in this approach yet.
```bash
cd examples/models/core/llama
python3 convert_checkpoint.py \
# Build engine
python3 examples/models/core/llama/convert_checkpoint.py \
--model_dir=<Path To Llama-v2-13B repo> \
--output_dir=./ckpt-target \
--dtype=float16
@ -42,34 +64,18 @@ trtllm-build \
--max_batch_size=4 \
--max_input_len=3200 \
--max_seq_len=4800
```
### Run decoding
+ `--prompt_lookup_config` is the corresponding configuration of Prompt-Lookup; its usage can be seen in [util.py](../util.py).
+ As an example, `[10,2,[0]]` means `prompt_lookup_num_tokens=10`, `max_matching_ngram_size=2`, and the target model runs on `GPU0`.
+ `--kv_cache_enable_block_reuse` must be specified for this approach.
+ Only the C++ session is supported, so `--use_py_session` must not be specified.
+ `--num_beams` cannot be set larger than 1 since beam search is not supported in this approach yet.
```bash
cd examples/models/core/llama
python3 ../../../run.py \
# Run decoding
python3 examples/run.py \
--tokenizer_dir <Path To Llama-v2-7B repo> \
--engine_dir ./target-engine \
--prompt_lookup_config="[10,2,[0]]" \
--max_output_len=256 \
--kv_cache_enable_block_reuse \
--input_text="How does Draft-Sampling work?"
```
## Run summarization tasks
```bash
cd examples/models/core/llama
python ../../../summarize.py \
# Run summarization tasks
python examples/summarize.py \
--test_hf \
--test_trt_llm \
--check_accuracy \
@ -79,3 +85,11 @@ python ../../../summarize.py \
--prompt_lookup_config="[10,2,[0]]" \
--kv_cache_enable_block_reuse
```
### V2 workflow
```bash
python3 examples/pytorch/quickstart_advanced.py \
--spec_decode_algo=NGRAM \
--max_matching_ngram_size=2 \
--spec_decode_nextn=4
```
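The same configuration can also be passed through the LLM API. The sketch below mirrors the unit test added in this PR; the model path is a placeholder and the remaining constructor arguments are left at their defaults.
```python
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import NGramDecodingConfig

spec_config = NGramDecodingConfig(
    prompt_lookup_num_tokens=4,  # maximum draft tokens proposed per iteration
    max_matching_ngram_size=2,   # maximum pattern length matched at the tail
    is_keep_all=True,
    is_use_oldest=True,
    is_public_pool=True,
)
llm = LLM(model="<path to target model>", speculative_config=spec_config)
outputs = llm.generate(["How does Draft-Sampling work?"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```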

View File

@ -36,8 +36,8 @@ class PLDPool: # Ngrams pool for Prompt-Lookup-Decoding
is_use_oldest: bool = True,
):
self.input_batch_size = input_batch_size
self.plnt = prompt_lookup_num_tokens # Shorter name
self.mmns = max_matching_ngram_size # Shorter name
self.prompt_lookup_num_tokens = prompt_lookup_num_tokens
self.max_matching_ngram_size = max_matching_ngram_size
self.end_id = end_id
self.max_seq_len = max_seq_len
self.is_keep_all = is_keep_all
@ -45,9 +45,25 @@ class PLDPool: # Ngrams pool for Prompt-Lookup-Decoding
self.pool = [{} for _ in range(input_batch_size)]
self.start_index = [0 for _ in range(input_batch_size)]
# modified from `transformers/generation/candidate_generator.py`
assert self.prompt_lookup_num_tokens > 0, f"prompt_lookup_num_tokens must be greater than 0, but got {self.prompt_lookup_num_tokens}"
assert self.max_matching_ngram_size > 0, f"max_matching_ngram_size must be greater than 0, but got {self.max_matching_ngram_size}"
def print_pool(self):
"""
For debug
"""
logger.info(f"Batch size = {self.input_batch_size}")
for i, map in enumerate(self.pool):
logger.info(f"Slot {i}, size = {len(map)}")
for key, values in map.items():
logger.info(f" {key}->{values}")
def get_draft_tokens(self, prefix: list[torch.Tensor],
batch_slot: list[int]):
"""
Get draft tokens from a batch of requests
modified from `transformers/generation/candidate_generator.py`
"""
batch_size = len(prefix)
prefix_len = [len(prefix[bi]) for bi in range(batch_size)]
draft_tokens = [] # `logits` is useless yet
@ -61,25 +77,30 @@ class PLDPool: # Ngrams pool for Prompt-Lookup-Decoding
# Update pool
sequence = prefix[bi][self.start_index[gbi]:].tolist()
for size in range(min(self.mmns, prefix_len[bi] - 1), 0, -1):
for size in range(
min(self.max_matching_ngram_size, prefix_len[bi] - 1), 0,
-1):
# Find each possible key-value combination, and use tuple for hash
for l in range(len(sequence) - size):
r = min(l + size + self.plnt, len(sequence))
r = min(l + size + self.prompt_lookup_num_tokens,
len(sequence))
key = tuple(sequence[l:l + size])
value = tuple(sequence[l + size:r])
if key not in self.pool[gbi] or not self.is_keep_all or \
len(self.pool[gbi][key][0]) < self.plnt:
len(self.pool[gbi][key][0]) < self.prompt_lookup_num_tokens:
# Update the value if
# 1. the key does not exist
# 2. we only keep one value for each key
# 2. we only keep the newest value for each key (MRU)
# 3. the length of the value saved before is less than `prompt_lookup_num_tokens`
self.pool[gbi][key] = OrderedSet((value, ))
elif value not in self.pool[gbi][key]:
# Extend the value if the key is already existed and we want to keep all of them
# Extend the values if the key already exists but the number of saved values is not enough
self.pool[gbi][key].add(value)
# Find match
for size in range(min(self.mmns, prefix_len[bi] - 1), 0, -1):
for size in range(
min(self.max_matching_ngram_size, prefix_len[bi] - 1), 0,
-1):
pattern = tuple(prefix[bi][-size:].tolist())
if pattern not in self.pool[gbi]:
continue
@ -92,7 +113,8 @@ class PLDPool: # Ngrams pool for Prompt-Lookup-Decoding
break
draft_tokens.append(chosen_ids)
self.start_index[gbi] = max(
0, prefix_len[bi] - (self.plnt + self.mmns - 1))
0, prefix_len[bi] - (self.prompt_lookup_num_tokens +
self.max_matching_ngram_size - 1))
return draft_tokens, None
@ -108,20 +130,21 @@ def run_dtm_pld(batch_input_ids,
*,
target_runner=None):
# `dtm` for Draft-Target-Model, `pld` for Prompt-Lookup-Decoding
assert (args.draft_target_model_config is not None) ^ (args.prompt_lookup_config is not None), \
"`--draft_target_model_config` and `--prompt_lookup_config` can not be specified at the same time."
if args.draft_target_model_config is not None:
assert args.draft_engine_dir is not None, "`--draft_engine_dir` must be specified in Draft-Target-Model."
is_dtm = (args.draft_target_model_config is not None)
is_pld = (args.prompt_lookup_config is not None)
assert is_dtm ^ is_pld, "`--draft_target_model_config` and `--prompt_lookup_config` can not be specified at the same time."
if is_dtm:
assert args.draft_engine_dir is not None, "`--draft_engine_dir` must be specified in Draft-Target-Model."
draft_len, draft_device_list, target_device_list, use_logits = ast.literal_eval(
args.draft_target_model_config)
logger.info(f"Using Draft-Target-Model speculative decoding")
logger.info(f"draft_len: {draft_len}")
logger.info(f"Device(s) for draft model: {draft_device_list}")
logger.info(f"Device(s) for target model: {target_device_list}")
logger.info(f"Use logits to accept tokens: {use_logits}")
if is_pld:
logger.info(
f"Using Prompt-Lookup-Decoding speculative decoding V1 workflow")
prompt_lookup_num_tokens, max_matching_ngram_size, target_device_list = ast.literal_eval(
args.prompt_lookup_config)
logger.info(f"prompt_lookup_num_tokens: {prompt_lookup_num_tokens}")
@ -206,9 +229,8 @@ def run_dtm_pld(batch_input_ids,
device_ids=target_device_list)
target_runner = ModelRunnerCpp.from_dir(**target_runner_kwargs)
if is_dtm and use_logits and \
not (draft_runner.gather_generation_logits and target_runner.gather_generation_logits):
assert False, "`--gather_generation_logits` must be specified while building draft/target models for using logits to accept"
if is_dtm and use_logits:
assert draft_runner.gather_generation_logits and target_runner.gather_generation_logits, "`--gather_generation_logits` must be specified while building draft/target models for using logits to accept"
common_generaion_kwargs = dict(
max_attention_window_size=args.max_attention_window_size,

View File

@ -1,12 +1,12 @@
# TRT-LLM with PyTorch
Run the quick start script:
## Run the quick start script:
```bash
python3 quickstart.py
```
Run the advanced usage example script:
## Run the advanced usage example script:
```bash
# BF16
@ -29,8 +29,9 @@ python3 quickstart_advanced.py --model_dir nvidia/Llama-3_1-Nemotron-Ultra-253B-
# Nemotron-H requires disabling cache reuse in kv cache
python3 quickstart_advanced.py --model_dir nvidia/Nemotron-H-8B-Base-8K --disable_kv_cache_reuse --max_batch_size 8
```
Run the multimodal example script:
## Run the multimodal example script:
```bash
# default inputs
@ -43,29 +44,39 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
# Note: media should be either image or video. Mixing image and video is not supported.
python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --modality video --prompt "Tell me what you see in the video briefly." "Describe the scene in the video briefly." --media "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4" "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4" --max_tokens 128 [--use_cuda_graph]
```
## Supported Models
| Architecture | Model | HuggingFace Example | Modality |
|--------------|-------|---------------------|----------|
| `BertForSequenceClassification` | BERT-based | `textattack/bert-base-uncased-yelp-polarity` | L |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3` | L |
| `Gemma3ForCausalLM` | Gemma3 | `google/gemma-3-1b-it` | L |
| `LlavaLlamaModel` | VILA | `Efficient-Large-Model/NVILA-8B` | L + V |
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | `llava-hf/llava-v1.6-mistral-7b-hf` | L + V |
| `LlamaForCausalLM` | Llama 3 <br> Llama 3.1 <br> Llama 2 <br> LLaMA | `meta-llama/Meta-Llama-3.1-70B` | L |
| `Llama4ForConditionalGeneration` | Llama 4 Scout <br> Llama 4 Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct` <br> `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L + V |
| `MistralForCausalLM` | Mistral | `mistralai/Mistral-7B-v0.1` | L |
| `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
| `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |
| `NemotronForCausalLM` | Nemotron-3 <br> Nemotron-4 <br> Minitron | `nvidia/Minitron-8B-Base` | L |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K` <br> `nvidia/Nemotron-H-47B-Base-8K` <br> `nvidia/Nemotron-H-56B-Base-8K` | L |
| `NemotronNASForCausalLM` | LLamaNemotron <br> LlamaNemotron Super <br> LlamaNemotron Ultra | `nvidia/Llama-3_1-Nemotron-51B-Instruct` <br> `nvidia/Llama-3_3-Nemotron-Super-49B-v1` <br> `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | L |
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/Qwen2-7B-Instruct` | L |
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B` | L |
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B` | L |
| `Qwen2VLForConditionalGeneration` | Qwen2-VL | `Qwen/Qwen2-VL-7B-Instruct` | L + V |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | `Qwen/Qwen2.5-VL-7B-Instruct` | L + V |
### Supported Models
| Architecture | Model | HuggingFace Example | Modality |
| :----------------------------------: | :----------------------------------------------------------- | :----------------------------------------------------------- | :------: |
| `BertForSequenceClassification` | BERT-based | `textattack/bert-base-uncased-yelp-polarity` | L |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3 ` | L |
| `Gemma3ForCausalLM` | Gemma3 | `google/gemma-3-1b-it` | L |
| `LlavaLlamaModel` | VILA | `Efficient-Large-Model/NVILA-8B` | L + V |
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | `llava-hf/llava-v1.6-mistral-7b-hf` | L + V |
| `LlamaForCausalLM` | Llama 3 <br> Llama 3.1 <br> Llama 2 <br> LLaMA | `meta-llama/Meta-Llama-3.1-70B` | L |
| `Llama4ForConditionalGeneration` | Llama 4 Scout <br> Llama 4 Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct` <br> `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L + V |
| `MistralForCausalLM` | Mistral | `mistralai/Mistral-7B-v0.1` | L |
| `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
| `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |
| `NemotronForCausalLM` | Nemotron-3 <br> Nemotron-4 <br> Minitron | `nvidia/Minitron-8B-Base` | L |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K` <br> `nvidia/Nemotron-H-47B-Base-8K` <br> `nvidia/Nemotron-H-56B-Base-8K` | L |
| `NemotronNASForCausalLM` | LLamaNemotron <br> LlamaNemotron Super <br> LlamaNemotron Ultra | `nvidia/Llama-3_1-Nemotron-51B-Instruct` <br> `nvidia/Llama-3_3-Nemotron-Super-49B-v1` <br> `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | L |
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/Qwen2-7B-Instruct` | L |
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B` | L |
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B` | L |
| `Qwen2VLForConditionalGeneration` | Qwen2-VL | `Qwen/Qwen2-VL-7B-Instruct` | L + V |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | `Qwen/Qwen2.5-VL-7B-Instruct` | L + V |
Note:
- L: Language only
- L + V: Language and Vision multimodal support
- Llama 3.2 accepts vision input, but our support is currently limited to text only.
## Run the speculative decoding script:
```bash
# NGram drafter
python3 examples/pytorch/quickstart_advanced.py \
--spec_decode_algo=NGRAM \
--max_matching_ngram_size=2 \
--spec_decode_nextn=4
```

View File

@ -4,7 +4,7 @@ from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
MTPDecodingConfig)
MTPDecodingConfig, NGramDecodingConfig)
example_prompts = [
"Hello, my name is",
@ -103,6 +103,7 @@ def add_llm_args(parser):
parser.add_argument('--spec_decode_algo', type=str, default=None)
parser.add_argument('--spec_decode_nextn', type=int, default=1)
parser.add_argument('--eagle_model_dir', type=str, default=None)
parser.add_argument('--max_matching_ngram_size', type=int, default=5)
# Relaxed acceptance
parser.add_argument('--use_relaxed_acceptance_for_thinking',
@ -130,6 +131,7 @@ def setup_llm(args):
use_cuda_graph=args.use_cuda_graph,
load_format=args.load_format,
print_iter_log=args.print_iter_log,
enable_iter_perf_stats=args.print_iter_log,
torch_compile_enabled=args.use_torch_compile,
torch_compile_piecewise_cuda_graph=args.use_piecewise_cuda_graph,
moe_backend=args.moe_backend,
@ -154,6 +156,14 @@ def setup_llm(args):
spec_config = EagleDecodingConfig(
max_draft_len=args.spec_decode_nextn,
pytorch_eagle_weights_path=args.eagle_model_dir)
elif spec_decode_algo == "NGRAM":
spec_config = NGramDecodingConfig(
prompt_lookup_num_tokens=args.spec_decode_nextn,
max_matching_ngram_size=args.max_matching_ngram_size,
is_keep_all=True,
is_use_oldest=True,
is_public_pool=True,
)
else:
spec_config = None

View File

@ -454,7 +454,8 @@ def main(args):
lora_ckpt_source=args.lora_ckpt_source,
gpu_weights_percent=args.gpu_weights_percent,
max_output_len=args.max_output_len,
enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc)
enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc,
)
if args.medusa_choices is not None:
args.medusa_choices = ast.literal_eval(args.medusa_choices)
assert args.temperature == 1.0, "Medusa should use temperature == 1.0"

View File

@ -491,10 +491,34 @@ def main(args):
f"Using {'Python' if args.use_py_session else 'C++'} session")
runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp
runner_kwargs = dict(engine_dir=args.engine_dir,
rank=runtime_rank,
debug_mode=args.debug_mode,
gpu_weights_percent=args.gpu_weights_percent)
runner_kwargs = dict(
engine_dir=args.engine_dir,
rank=runtime_rank,
debug_mode=args.debug_mode,
gpu_weights_percent=args.gpu_weights_percent,
enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc,
)
if not args.use_py_session:
runner_kwargs.update(
lora_dir=args.lora_dir,
lora_ckpt_source=args.lora_ckpt_source,
max_batch_size=max_batch_size,
max_input_len=test_token_num,
max_output_len=output_len,
max_beam_width=num_beams,
max_attention_window_size=max_attention_window_size,
sink_token_length=sink_token_length,
max_tokens_in_paged_kv_cache=args.max_tokens_in_paged_kv_cache,
kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
kv_cache_free_gpu_memory_fraction=args.
kv_cache_free_gpu_memory_fraction,
enable_chunked_context=args.enable_chunked_context,
multi_block_mode=args.multi_block_mode,
cuda_graph_mode=args.cuda_graph_mode,
gather_generation_logits=args.eval_ppl,
use_gpu_direct_storage=args.use_gpu_direct_storage,
)
if args.medusa_choices is not None:
args.medusa_choices = ast.literal_eval(args.medusa_choices)
assert args.temperature == 1.0, "Medusa should use temperature == 1.0"
@ -523,38 +547,16 @@ def main(args):
if args.prompt_lookup_config is not None:
assert args.kv_cache_enable_block_reuse, "`--kv_cache_enable_block_reuse` must be specified in speculative decoding."
assert not args.use_py_session, "`--use_py_session` is not supported in Speculative decoding."
assert not is_enc_dec, "Encoder-Decoder model is not supported in Speculative decoding."
assert args.num_beams == 1, "`--num_beams>1` is not supported in Speculative decoding."
prompt_lookup_num_tokens, _, target_device_list = ast.literal_eval(
args.prompt_lookup_config)
args.max_output_len = output_len # Specialization for PLD
runner_kwargs.update(is_orchestrator_mode=True,
device_ids=target_device_list)
if not args.use_py_session:
runner_kwargs.update(
lora_dir=args.lora_dir,
lora_ckpt_source=args.lora_ckpt_source,
max_batch_size=max_batch_size,
max_input_len=test_token_num,
max_output_len=output_len,
max_beam_width=num_beams,
max_attention_window_size=max_attention_window_size,
sink_token_length=sink_token_length,
max_tokens_in_paged_kv_cache=args.max_tokens_in_paged_kv_cache,
kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
kv_cache_free_gpu_memory_fraction=args.
kv_cache_free_gpu_memory_fraction,
enable_chunked_context=args.enable_chunked_context,
multi_block_mode=args.multi_block_mode,
cuda_graph_mode=args.cuda_graph_mode,
gather_generation_logits=args.eval_ppl,
use_gpu_direct_storage=args.use_gpu_direct_storage)
runner_kwargs.update(
enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc)
if args.prompt_lookup_config is not None:
# Specialization for PLD since many call of `generate()` is needed
runner_kwargs.update(max_input_len=test_token_num +
device_ids=target_device_list,
max_input_len=test_token_num +
prompt_lookup_num_tokens + output_len)
runner = runner_cls.from_dir(**runner_kwargs)
assert not (args.eval_ppl and not runner.gather_context_logits), \
"PPL evaluation requires engine built with gather_context_logits enabled"

View File

@ -463,7 +463,8 @@ def instantiate_sampler(model_engine: PyTorchModelEngine,
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
sampler = TorchStarAttentionSampler(
max_seq_len=model_engine.max_seq_len)
elif model_engine.spec_config is not None:
elif model_engine.spec_config is not None and model_engine.spec_config.spec_dec_mode.has_spec_decoder(
):
sampler = get_spec_decoder(max_seq_len=model_engine.max_seq_len,
spec_config=model_engine.spec_config)
elif pytorch_backend_config.enable_trtllm_sampler:

View File

@ -1042,7 +1042,6 @@ class PyTorchModelEngine(ModelEngine):
request_ids.append(request.py_request_id)
all_prompt_tokens = request.get_tokens(0)
draft_lens.append(0)
begin_compute = request.context_current_position
end_compute = begin_compute + request.context_chunk_size
prompt_tokens = all_prompt_tokens[begin_compute:end_compute]

View File

@ -24,7 +24,8 @@ from tensorrt_llm.bindings.executor import (DisServingRequestStats,
FinishReason, InflightBatchingStats,
IterationStats, KvCacheStats,
RequestStage, RequestStats,
RequestType, StaticBatchingStats)
RequestType, SpecDecodingStats,
StaticBatchingStats)
from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
ReqIdsSet)
from tensorrt_llm.logger import logger
@ -512,6 +513,10 @@ class PyExecutor:
stats.inflight_batching_stats = InflightBatchingStats()
# staticBatchingStats is not used in pytorch path
stats.static_batching_stats = StaticBatchingStats()
spec_resource_manager = self.resource_manager.resource_managers.get(
"spec_resource_manager")
if spec_resource_manager is not None:
stats.specdec_stats = SpecDecodingStats()
return stats
def _populate_req_stats(
@ -611,6 +616,9 @@ class PyExecutor:
scheduled_batch.paused_requests)
stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
stats.inflight_batching_stats.micro_batch_id = 0
if stats.specdec_stats is not None:
stats.specdec_stats.draft_overhead = 0.0 if iter_latency_ms <= 0.0 else float(
stats.specdec_stats.iter_latency_ms) / float(iter_latency_ms)
return stats
def _append_iter_stats(self,
@ -627,7 +635,7 @@ class PyExecutor:
active_requests: List[LlmRequest],
batch_state: BatchState):
iter_end_time = time.time()
iter_latency_ms = iter_end_time - batch_state.iter_start_time
iter_latency_ms = (iter_end_time - batch_state.iter_start_time) * 1e3
if batch_state.iter_stats is None:
return
@ -791,6 +799,10 @@ class PyExecutor:
torch.cuda.set_device(self.device_id)
got_finish_signal = False
num_dummy_request = 0
is_ngram = hasattr(
self.model_engine, "spec_config"
) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
)
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
@ -803,26 +815,28 @@ class PyExecutor:
new_requests) or got_finish_signal
if got_finish_signal and len(self.active_requests) == 0:
break
if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
self.new_active_requests_queue_latency_ms)
if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
if not got_finish_signal:
num_dummy_request = self._get_num_dummy_request()
if num_dummy_request > 0:
self._merge_dummy_request(num_dummy_request)
if self.draft_model_engine is not None:
if self.draft_model_engine is not None or is_ngram:
self._prepare_draft_requests()
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
if self.kv_cache_transceiver:
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
self._prepare_disagg_gen_init(
fitting_disagg_gen_init_requests)
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
@ -848,10 +862,21 @@ class PyExecutor:
finished_requests = []
if scheduled_batch.batch_size > 0:
has_ngram_iter_stats = is_ngram and self.model_engine.spec_config.spec_dec_mode.is_ngram(
) and iter_stats is not None
if has_ngram_iter_stats:
before = time.time()
self.resource_manager.prepare_resources(scheduled_batch)
if self.draft_model_engine is not None:
self._prepare_draft_tokens(scheduled_batch)
if has_ngram_iter_stats:
self._insert_ngram_iter_stats(scheduled_batch,
iter_stats)
iter_stats.specdec_stats.iter_latency_ms = (
time.time() - before) * 1e3
if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
self._prepare_disagg_gen_transmission_complete(
@ -903,6 +928,7 @@ class PyExecutor:
# Set draft tokens here to make the KV cache manager
# and scheduler aware of them.
for req in self.active_requests:
# TODO: enable draft tokens in context phase
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
continue
req.py_last_draft_tokens = req.py_draft_tokens
@ -961,7 +987,7 @@ class PyExecutor:
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs =0 and fitting_disagg_gen_init_requests is empty , may not have enough kvCache"
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self.kv_cache_transceiver.check_context_transfer_status(
1)
@ -1677,6 +1703,43 @@ class PyExecutor:
logger.error(f"Encountered an error in sampling: {error_msg}")
self._handle_errors(error_msg)
def _insert_ngram_iter_stats(
self, scheduled_requests: ScheduledRequests, iter_stats: IterationStats
) -> None:
"""
Get statistic information from the draft tokens in NGram drafter
"""
assert iter_stats is not None
total_num_draft_tokens = 0
total_num_accepted_tokens = 0
num_requests_with_draft_tokens = 0
for request in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
num_draft_tokens = 0 if request.py_last_draft_tokens is None else len(
request.py_last_draft_tokens)
num_accepted_tokens = getattr(request,
"py_num_accepted_draft_tokens", 0)
if num_draft_tokens > 0:
total_num_draft_tokens = total_num_draft_tokens + num_draft_tokens
total_num_accepted_tokens = total_num_accepted_tokens + num_accepted_tokens
num_requests_with_draft_tokens = num_requests_with_draft_tokens + 1
if num_requests_with_draft_tokens > 0:
iter_stats.specdec_stats.iter_latency_ms = 0.0 # We do not count time in this method
iter_stats.specdec_stats.num_draft_tokens = total_num_draft_tokens
iter_stats.specdec_stats.num_accepted_tokens = total_num_accepted_tokens
iter_stats.specdec_stats.num_requests_with_draft_tokens = num_requests_with_draft_tokens
iter_stats.specdec_stats.acceptance_length = float(
(total_num_accepted_tokens + num_requests_with_draft_tokens
)) / float(num_requests_with_draft_tokens)
else:
iter_stats.specdec_stats.iter_latency_ms = 0.0
iter_stats.specdec_stats.num_draft_tokens = 0
iter_stats.specdec_stats.num_accepted_tokens = 0
iter_stats.specdec_stats.num_requests_with_draft_tokens = 0
iter_stats.specdec_stats.acceptance_length = 1.0
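As a quick sanity check of the acceptance_length formula above (a standalone example, not part of this diff): each request with draft tokens produces its accepted drafts plus one target token per step.
```python
# Three requests carried draft tokens this iteration and accepted 2, 3 and 0 of them.
accepted = [2, 3, 0]
n = len(accepted)  # num_requests_with_draft_tokens
acceptance_length = (sum(accepted) + n) / n
assert abs(acceptance_length - 8 / 3) < 1e-9  # ~2.67 tokens produced per step
# draft_overhead (filled in where the total iteration latency is known) is simply
# the spec-dec iter_latency_ms divided by the total iteration latency, e.g. 1.5 / 12.0 = 0.125
```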
@nvtx_range("_prepare_draft_batch")
def _prepare_draft_batch(
self, scheduled_requests: ScheduledRequests

View File

@ -12,7 +12,7 @@ from tensorrt_llm.quantization import KV_CACHE_QUANT_ALGO_LIST
from ..attention_backend.interface import AttentionRuntimeFeatures
from ..distributed import MPIDist
from ..speculative import Eagle3Config, get_spec_resource_manager
from ..speculative import Eagle3Config, NGramConfig, get_spec_resource_manager
from ._util import (create_kv_cache_manager, create_py_executor_instance,
estimate_max_kv_cache_tokens, get_token_num_for_estimation,
instantiate_sampler, is_mla)
@ -63,11 +63,13 @@ def create_py_executor(executor_config: ExecutorConfig,
spec_config = executor_config.speculative_config
has_draft_model_engine = isinstance(spec_config, Eagle3Config)
has_ngram_drafter = isinstance(spec_config, NGramConfig)
attn_runtime_features = AttentionRuntimeFeatures(
chunked_prefill=executor_config.enable_chunked_context,
cache_reuse=executor_config.kv_cache_config.enable_block_reuse,
has_speculative_draft_tokens=has_draft_model_engine,
has_speculative_draft_tokens=has_draft_model_engine
or has_ngram_drafter,
)
model_engine = PyTorchModelEngine(

View File

@ -314,7 +314,7 @@ class TorchSampler(Sampler):
num_accepted += 1
new_token = new_tokens_list[token_idx + num_accepted]
num_tokens = request.add_new_token(new_token, beam_idx)
new_tokens.append(num_tokens)
new_tokens.append(num_tokens) # `num_tokens`->`new_token`
if self._handle_stop_criteria(request, new_token,
num_tokens, beam_idx):

View File

@ -1,6 +1,7 @@
from .eagle3 import Eagle3Config, Eagle3SpecMetadata
from .interface import SpecConfig, SpecMetadata
from .mtp import MTPConfig, MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramConfig
from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_metadata,
get_spec_resource_manager)
@ -8,5 +9,5 @@ __all__ = [
"SpecConfig", "SpecMetadata", "MTPConfig", "MTPWorker", "MTPEagleWorker",
"Eagle3Config", "Eagle3SpecMetadata", "MTPSpecMetadata",
"get_spec_metadata", "get_spec_resource_manager", "get_spec_decoder",
"get_num_spec_layers"
"get_num_spec_layers", "NGramConfig"
]

View File

@ -15,6 +15,7 @@ class SpeculativeDecodingMode(IntEnum):
MTP = auto()
MTP_EAGLE = auto()
EAGLE3 = auto()
NGRAM = auto()
NONE = auto()
def is_mtp(self):
@ -26,6 +27,9 @@ class SpeculativeDecodingMode(IntEnum):
def is_eagle3(self):
return self == SpeculativeDecodingMode.EAGLE3
def is_ngram(self):
return self == SpeculativeDecodingMode.NGRAM
def is_none(self):
return self == SpeculativeDecodingMode.NONE
@ -35,6 +39,9 @@ class SpeculativeDecodingMode(IntEnum):
def support_overlap_scheduler(self):
return self.is_mtp()
def has_spec_decoder(self):
return self.is_mtp() or self.is_eagle3()
def extend_ctx(self, attention_backend: AttentionBackend):
"""
If true, treat generation requests with draft tokens as
@ -43,8 +50,9 @@ class SpeculativeDecodingMode(IntEnum):
"""
# Fixme: only trtllm attention backend supports eagle3 generation-phase kernels on blackwell.
return self.is_eagle3() and not (isinstance(
attention_backend, TrtllmAttention) and get_sm_version() == 100)
return (self.is_eagle3()
and not (isinstance(attention_backend, TrtllmAttention)
and get_sm_version() == 100)) or self.is_ngram()
@staticmethod
def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":

View File

@ -0,0 +1,210 @@
from dataclasses import dataclass
from typing import List
from ordered_set import OrderedSet
from tensorrt_llm.logger import logger  # used by print_pool/_print_line below
from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpeculativeDecodingMode
@dataclass
class NGramConfig(SpecConfig):
"""
Configuration for N-gram drafter.
"""
# The name of speculative decoding.
spec_dec_name = "NGRAM"
num_extra_kv_tokens: int = 0
max_draft_tokens: int = 0
prompt_lookup_num_tokens: int = 5
max_matching_ngram_size: int = 5
end_id: int = -1
is_keep_all: bool = True
is_use_oldest: bool = True
is_public_pool: bool = True
def __post_init__(self) -> None:
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
self.max_draft_tokens = self.prompt_lookup_num_tokens
def update_from_model_config(self, model_config):
pass
class NGramPoolManager(BaseResourceManager):
"""
This class maintains the pattern-matches pairs for NGram drafter.
For example, one of the existing pairs could be: ["I","love"] -> [["apple", "because", "it", "is"], ["banana", "and"]].
Here we call ["I","love"] the `pattern`, and [["apple", "because", "it", "is"], ["banana", "and"]] the `matches`.
`pattern` is a list of token_ids. The pool provides the corresponding draft tokens from the matches if the pattern appears at the tail of the sequence during generation.
`matches` is a list of candidate draft token_id sequences attached to a pattern.
Arguments:
prompt_lookup_num_tokens: int
The maximum length of the draft tokens (i.e., the maximum number of output draft tokens).
max_matching_ngram_size: int
The maximum length of the matched pattern (i.e., the maximum number of input tokens used to search).
is_keep_all: bool = True
Whether to keep all candidate pattern-match pairs; only one match is kept for each pattern if False.
is_use_oldest: bool = True
Whether to provide the oldest match when a pattern is hit; the newest one is provided if False.
is_public_pool: bool = True
Whether to use a common pool for all requests; the pool is private for each request if False.
Members:
pool: dict[tuple[int], OrderedSet[int]] or dict[int, dict[tuple[int], OrderedSet[int]]]
If is_public_pool == True, it maps from patterns to matches
If is_public_pool == False, it maps from request ID to the request-specific pool
start_index: dict[int, int]
It maps from request ID to the index of the prompt to update the pool in the next step
"""
def __init__(self, config: NGramConfig, max_num_requests: int):
self.max_num_requests = max_num_requests
self.max_num_draft_tokens = config.max_draft_tokens
self.prompt_lookup_num_tokens = config.prompt_lookup_num_tokens
self.max_matching_ngram_size = config.max_matching_ngram_size
self.is_keep_all = config.is_keep_all
self.is_use_oldest = config.is_use_oldest # TODO: remove this if updating strategy is supported
self.is_public_pool = config.is_public_pool
self.pool = {}
self.start_index = {}
def prepare_resources(self, scheduled_batch: ScheduledRequests):
# Update pool and provide draft tokens for the requests
for request in scheduled_batch.generation_requests:
num_draft_tokens = 0 if request.py_last_draft_tokens is None else \
len(request.py_last_draft_tokens)
num_accepted_tokens = getattr(request,
"py_num_accepted_draft_tokens", 0)
num_rejected_tokens = num_draft_tokens - num_accepted_tokens
assert num_rejected_tokens >= 0
# Generate draft tokens
draft_tokens = self._get_draft_tokens(
request.get_tokens()[0],
request.request_id,
request.py_end_id,
request.py_orig_prompt_len + request.py_max_new_tokens,
)
# Pad to max_draft_tokens
if draft_tokens is not None:
pad_length = self.max_num_draft_tokens - len(draft_tokens)
draft_tokens.extend([request.py_end_id] * pad_length)
request.py_draft_tokens = draft_tokens
def update_resources(self, scheduled_batch: ScheduledRequests):
pass
def free_resources(self, request: LlmRequest):
if self.is_public_pool:
return # TODO: need to have a strategy to swap out the pairs
request_id = request.request_id
if request_id in self.pool:
self.pool.pop(request_id)
self.start_index.pop(request_id)
def add_dummy_requests(self, request_ids: List[int]):
pass
def shutdown(self):
pass
def get_max_resource_count(self) -> int:
return self.max_num_requests
def get_needed_resource_to_completion(self, request: LlmRequest):
return 0
def print_pool(self): # For debug
if self.is_public_pool:
logger.debug(f"Using public pool, size = {len(self.pool)}")
self._print_line(self.pool)
else:
logger.debug(f"Using private pool")
for request_id, request_map in self.pool.items():
logger.debug(f"Request {request_id}, size={len(request_map)}")
self._print_line(request_map, 4)
def _print_line(self, local_map, indentation=0): # For debug
for pattern, matches in local_map.items():
output = " " * indentation + str(pattern) + "->"
for match in matches:
output += str(match) + ", "
logger.debug(output)
def _get_draft_tokens(
self,
prefix: list[int],
request_id: int,
end_id: int,
max_sequence_length: int,
):
prefix_len = len(prefix)
max_draft_token_length = max_sequence_length - 1 - prefix_len
if max_draft_token_length <= 0: # Skip the search if no room is left for draft tokens
return None
if request_id not in self.start_index: # A new request
self.start_index[request_id] = 0
if not self.is_public_pool:
assert len(self.pool) + 1 <= self.max_num_requests
self.pool[request_id] = {}
pool = (self.pool if self.is_public_pool else self.pool[request_id])
# Update pool
sequence = prefix[self.start_index[request_id]:]
for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
-1):
# Find each possible pattern-match combination, and use tuple for hash
for l in range(len(sequence) - size):
r = min(l + size + self.prompt_lookup_num_tokens, len(sequence))
pattern = tuple(sequence[l:l + size])
new_match = tuple(sequence[l + size:r])
if pattern not in pool or \
(not self.is_keep_all and len(new_match) > len(pool[pattern][0])):
# Replace the match if
# 1. the pattern does not exist in the pool
# 2. only one match is kept, and the new match is longer (MRU)
pool[pattern] = OrderedSet((new_match, ))
elif new_match not in pool[pattern]:
# Update the matches if the pattern already exists:
# TODO: need a strategy to maintain the short candidates, now we just remove them
# Drop all existing matches that are shorter than the new match,
# iterating over a copy since the set is mutated in the loop
for match in list(pool[pattern]):
if len(match) < len(new_match):
pool[pattern].remove(match)
pool[pattern].add(new_match)
# Find match
draft_tokens = [end_id]
for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
-1):
pattern = tuple(prefix[-size:])
if pattern not in pool:
continue
draft_tokens = pool[pattern][0 if self.is_use_oldest else -1]
draft_tokens = list(draft_tokens)[:max_draft_token_length]
break
self.start_index[request_id] = max(
0, prefix_len -
(self.prompt_lookup_num_tokens + self.max_matching_ngram_size - 1))
return draft_tokens
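For illustration, here is a standalone, simplified sketch of the pool update and lookup performed above (a plain dict of lists instead of `OrderedSet`, and without the keep-all/MRU policies); it is not repo code.
```python
def build_pool(tokens, max_ngram=2, max_draft=3):
    # Map every n-gram (size 1..max_ngram) to the tokens that followed it.
    pool = {}
    for size in range(max_ngram, 0, -1):
        for l in range(len(tokens) - size):
            pattern = tuple(tokens[l:l + size])
            match = tuple(tokens[l + size:l + size + max_draft])
            pool.setdefault(pattern, []).append(match)
    return pool


def lookup(pool, tokens, max_ngram=2):
    # Try the longest tail pattern first, like the loop in _get_draft_tokens.
    for size in range(max_ngram, 0, -1):
        pattern = tuple(tokens[-size:])
        if pattern in pool:
            return list(pool[pattern][0])  # oldest match, as with is_use_oldest=True
    return None


tokens = [1, 2, 3, 4, 1, 2]
print(lookup(build_pool(tokens), tokens))  # [3, 4, 1]: the trailing (1, 2) matches its earlier occurrence
```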

View File

@ -1,5 +1,6 @@
from .eagle3 import Eagle3Sampler, Eagle3SpecMetadata
from .mtp import MTPHiddenStatesManager, MTPSampler, MTPSpecMetadata
from .ngram import NGramPoolManager
def get_spec_metadata(spec_config,
@ -34,6 +35,8 @@ def get_spec_resource_manager(spec_config, model_config, max_num_requests):
return MTPHiddenStatesManager(spec_config, model_config.torch_dtype,
model_config.hidden_size,
max_num_requests)
elif spec_config.spec_dec_mode.is_ngram():
return NGramPoolManager(spec_config, max_num_requests)
else:
return None

View File

@ -8,7 +8,7 @@ from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig,
DynamicBatchConfig, EagleDecodingConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig,
LookaheadDecodingConfig, MedusaDecodingConfig,
MTPDecodingConfig, SchedulerConfig)
MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig)
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
QuantConfig)
from .mpi_session import MpiCommSession
@ -40,4 +40,5 @@ __all__ = [
'ContextChunkingPolicy',
'DynamicBatchConfig',
'CacheTransceiverConfig',
'NGramDecodingConfig',
]

View File

@ -195,7 +195,8 @@ class DecodingBaseConfig(BaseModel):
"MTP": MTPDecodingConfig,
"Medusa": MedusaDecodingConfig,
"Eagle": EagleDecodingConfig,
"Lookahead": LookaheadDecodingConfig
"Lookahead": LookaheadDecodingConfig,
"NGram": NGramDecodingConfig,
}
config_class = config_classes.get(decoding_type)
@ -236,6 +237,40 @@ class EagleDecodingConfig(DecodingBaseConfig):
decoding_type: ClassVar[str] = "Eagle"
class NGramDecodingConfig(DecodingBaseConfig):
"""
Configuration for NGram drafter speculative decoding.
Arguments:
prompt_lookup_num_tokens: int
The maximum length of the draft tokens (i.e., the maximum number of output draft tokens).
max_matching_ngram_size: int
The maximum length of the matched pattern (i.e., the maximum number of input tokens used to search).
is_keep_all: bool = True
Whether to keep all candidate pattern-match pairs; only one match is kept for each pattern if False.
is_use_oldest: bool = True
Whether to provide the oldest match when a pattern is hit; the newest one is provided if False.
is_public_pool: bool = True
Whether to use a common pool for all requests; the pool is private for each request if False.
"""
prompt_lookup_num_tokens: int = 2
max_matching_ngram_size: int = 4
is_keep_all: bool = True
is_use_oldest: bool = True
is_public_pool: bool = True
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
decoding_type: ClassVar[str] = "NGram"
class MTPDecodingConfig(DecodingBaseConfig):
num_nextn_predict_layers: Optional[int] = 1
use_relaxed_acceptance_for_thinking: Optional[bool] = False
@ -880,8 +915,8 @@ class LlmArgs(BaseModel):
# Speculative decoding parameters
speculative_config: Optional[Union[
LookaheadDecodingConfig, MedusaDecodingConfig, EagleDecodingConfig,
MTPDecodingConfig]] = Field(default=None,
description="Speculative decoding config.")
MTPDecodingConfig, NGramDecodingConfig]] = Field(
default=None, description="Speculative decoding config.")
batching_type: Optional[BatchingType] = Field(default=None,
description="Batching type.")
@ -1209,7 +1244,21 @@ class LlmArgs(BaseModel):
max_draft_tokens=self.speculative_config.max_draft_len,
eagle_weights_path=self.speculative_config.
pytorch_eagle_weights_path)
elif isinstance(self.speculative_config, NGramDecodingConfig):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.NGRAM
assert self.backend == 'pytorch'
assert self.speculative_config.prompt_lookup_num_tokens > 0 and self.speculative_config.max_matching_ngram_size > 0
self.build_config.max_draft_len = self.speculative_config.max_draft_len
from tensorrt_llm._torch.speculative import NGramConfig
self.speculative_config = NGramConfig(
prompt_lookup_num_tokens=self.speculative_config.
prompt_lookup_num_tokens,
max_matching_ngram_size=self.speculative_config.
max_matching_ngram_size,
is_keep_all=self.speculative_config.is_keep_all,
is_use_oldest=self.speculative_config.is_use_oldest,
is_public_pool=self.speculative_config.is_public_pool,
)
elif isinstance(self.speculative_config, MTPDecodingConfig):
from tensorrt_llm._torch.speculative import MTPConfig
self.speculative_config = MTPConfig(

View File

@ -31,8 +31,8 @@ from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
get_build_cache_config_from_env)
from .llm_args import (CalibConfig, EagleDecodingConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig,
MTPDecodingConfig, _ModelFormatKind, _ModelWrapper,
_ParallelConfig, get_model_format,
MTPDecodingConfig, NGramDecodingConfig, _ModelFormatKind,
_ModelWrapper, _ParallelConfig, get_model_format,
update_llm_args_with_extra_dict,
update_llm_args_with_extra_options)
from .mpi_session import MPINodeState, MpiSession
@ -860,6 +860,7 @@ __all__ = [
'LookaheadDecodingConfig',
'MedusaDecodingConfig',
'MTPDecodingConfig',
'NGramDecodingConfig',
'ContextChunkingPolicy',
'CapacitySchedulerPolicy',
'BuildConfig',

View File

@ -96,6 +96,7 @@ class SpeculativeDecodingMode(IntFlag):
LOOKAHEAD_DECODING = auto()
EXPLICIT_DRAFT_TOKENS = auto()
EAGLE = auto()
NGRAM = auto()
@staticmethod
def from_arguments(args: argparse.Namespace):
@ -111,6 +112,8 @@ class SpeculativeDecodingMode(IntFlag):
return SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS
elif args.speculative_decoding_mode == "eagle":
return SpeculativeDecodingMode.EAGLE
elif args.speculative_decoding_mode == "ngram":
return SpeculativeDecodingMode.NGRAM
else:
assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode

View File

@ -0,0 +1,83 @@
import os
import sys
import unittest
import pytest
import torch
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig, NGramDecodingConfig
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root
# TODO: Add cuda graph enabled tests.
# Cuda graph cannot currently be enabled for ngram because cuda graph requires
# spec metadata and ngram does not have it.
@pytest.mark.parametrize("use_cuda_graph,attn_backend",
[[False, "TRTLLM"], [False, "FLASHINFER"]])
def test_llama_ngram(use_cuda_graph: bool, attn_backend: str):
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 31:
pytest.skip("Not enough memory to load target model")
models_path = llm_models_root()
pytorch_config = PyTorchConfig(
enable_overlap_scheduler=False,
use_cuda_graph=use_cuda_graph,
# Only create a single CUDA graph to prevent OOM in CI
attn_backend=attn_backend,
cuda_graph_batch_sizes=[1],
)
kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=2080)
sampling_params = SamplingParams(
max_tokens=32,
temperature=0,
)
max_batch_size = 1
target_model_dir = f"{models_path}/llama-models-v2/llama-v2-13b-hf"
draft_len = 4
spec_config = NGramDecodingConfig(
prompt_lookup_num_tokens=draft_len,
max_matching_ngram_size=draft_len,
is_keep_all=True,
is_use_oldest=True,
is_public_pool=True,
)
llm_spec = LLM(model=target_model_dir,
max_batch_size=max_batch_size,
pytorch_backend_config=pytorch_config,
kv_cache_config=kv_cache_config,
speculative_config=spec_config)
prompts = [
"The capital of France is", "The president of the United States is"
]
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()
llm_ref = LLM(model=target_model_dir,
max_batch_size=max_batch_size,
pytorch_backend_config=pytorch_config,
kv_cache_config=kv_cache_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()
for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref
if __name__ == "__main__":
unittest.main()

View File

@ -74,7 +74,7 @@ methods:
# Speculative decoding
speculative_config:
annotation: Union[tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_utils.MedusaDecodingConfig,
tensorrt_llm.llmapi.llm_utils.EagleDecodingConfig, tensorrt_llm.llmapi.MTPDecodingConfig, NoneType]
tensorrt_llm.llmapi.llm_utils.EagleDecodingConfig, tensorrt_llm.llmapi.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, NoneType]
default: null
# generation constraints
max_batch_size: