Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[TRTLLM-5000][feat] Pytorch implementation of ngram drafter (#3936)
* v1.5 Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
  v1.5.4 Add back draft_overhead to spec dec stats Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
* v1.5.5: fix CI error Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
* v1.6: fix CI error 8196 > 8192 Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
* Address reviewer concerns Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
* Address reviewer concerns Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
* precommit run Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
* v2.0: Address reviewer concerns Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
* v2.1: add fix from wili Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
* Revert changes that require use of TypeAlias because that requires python version >= 3.10 Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
---------
Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
Co-authored-by: wili-65535 <wili-65535@users.noreply.github.com>
This commit is contained in:
parent
9199793848
commit
5d438be59a
@ -252,6 +252,11 @@ public:
|
||||
static void serialize(InflightBatchingStats const& inflightBatchingStats, std::ostream& os);
|
||||
static size_t serializedSize(InflightBatchingStats const& inflightBatchingStats);
|
||||
|
||||
// SpecDecodingStats
|
||||
static SpecDecodingStats deserializeSpecDecodingStats(std::istream& is);
|
||||
static void serialize(SpecDecodingStats const& specDecStats, std::ostream& os);
|
||||
static size_t serializedSize(SpecDecodingStats const& specDecStats);
|
||||
|
||||
// IterationStats
|
||||
static IterationStats deserializeIterationStats(std::vector<char>& buffer);
|
||||
static IterationStats deserializeIterationStats(std::istream& is);
|
||||
|
||||
@ -46,6 +46,7 @@ class Tensor;
|
||||
|
||||
using TensorPtr = std::shared_ptr<Tensor>;
|
||||
using SizeType32 = std::int32_t;
|
||||
using SizeType64 = std::int64_t;
|
||||
using FloatType = float;
|
||||
using TokenIdType = std::int32_t;
|
||||
using VecTokens = std::vector<TokenIdType>;
|
||||
@ -294,6 +295,24 @@ struct InflightBatchingStats
|
||||
float avgNumDecodedTokensPerIter;
|
||||
};
|
||||
|
||||
/// @brief Struct that holds speculative decoding stats
|
||||
struct SpecDecodingStats
|
||||
{
|
||||
/// @brief Total number of proposed draft tokens for all requests
|
||||
SizeType64 numDraftTokens;
|
||||
/// @brief Total number of accepted draft tokens for all requests
|
||||
SizeType64 numAcceptedTokens;
|
||||
/// @brief Number of requests with at least one draft token in batch
|
||||
SizeType64 numRequestsWithDraftTokens;
|
||||
/// @brief Acceptance length, defined as average number of tokens produced per step for all requests with at least
|
||||
/// one draft token
|
||||
double acceptanceLength;
|
||||
/// @brief Iteration latency for draft token generation only (ms)
|
||||
double iterLatencyMS;
|
||||
/// @brief Draft overhead, defined as iterLatencyMS (specdec) / iterLatencyMS (total)
|
||||
double draftOverhead;
|
||||
};
|
||||
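For reference, the two derived fields above can be computed from the raw counters as follows. This is an illustrative Python sketch of the definitions in the comments (the helper name is ours and not part of the API); the same formulas appear later in the PyExecutor changes.

```python
def spec_decoding_summary(num_accepted_tokens: int,
                          num_requests_with_draft_tokens: int,
                          spec_dec_latency_ms: float,
                          total_iter_latency_ms: float):
    """Illustrative only: mirrors the definitions of acceptanceLength and draftOverhead."""
    if num_requests_with_draft_tokens > 0:
        # Each request with at least one draft token produces its accepted draft tokens
        # plus one token from the target model per step.
        acceptance_length = (num_accepted_tokens + num_requests_with_draft_tokens) \
            / num_requests_with_draft_tokens
    else:
        acceptance_length = 1.0
    # draftOverhead = iterLatencyMS (specdec) / iterLatencyMS (total)
    draft_overhead = 0.0 if total_iter_latency_ms <= 0.0 \
        else spec_dec_latency_ms / total_iter_latency_ms
    return acceptance_length, draft_overhead
```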
|
||||
/// @brief Struct that holds the stats of a single iteration
|
||||
struct IterationStats
|
||||
{
|
||||
@ -341,6 +360,8 @@ struct IterationStats
|
||||
std::optional<StaticBatchingStats> staticBatchingStats;
|
||||
/// @brief Stats specific to inflight batching
|
||||
std::optional<InflightBatchingStats> inflightBatchingStats;
|
||||
/// @brief Stats specific to speculative decoding
|
||||
std::optional<SpecDecodingStats> specDecStats;
|
||||
};
|
||||
|
||||
/// @brief Enum class that represents the state of a request
|
||||
|
||||
@ -31,11 +31,13 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(
|
||||
StaticBatchingStats, numScheduledRequests, numContextRequests, numCtxTokens, numGenTokens, emptyGenSlots);
|
||||
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(InflightBatchingStats, numScheduledRequests, numContextRequests, numGenRequests,
|
||||
numPausedRequests, numCtxTokens, microBatchId, avgNumDecodedTokensPerIter);
|
||||
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SpecDecodingStats, numDraftTokens, numAcceptedTokens, numRequestsWithDraftTokens,
|
||||
acceptanceLength, iterLatencyMS, draftOverhead);
|
||||
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(IterationStats, timestamp, iter, iterLatencyMS, newActiveRequestsQueueLatencyMS,
|
||||
numNewActiveRequests, numActiveRequests, numQueuedRequests, numCompletedRequests, maxNumActiveRequests,
|
||||
maxBatchSizeStatic, maxBatchSizeTunerRecommended, maxBatchSizeRuntime, maxNumTokensStatic,
|
||||
maxNumTokensTunerRecommended, maxNumTokensRuntime, gpuMemUsage, cpuMemUsage, pinnedMemUsage, kvCacheStats,
|
||||
staticBatchingStats, inflightBatchingStats);
|
||||
staticBatchingStats, inflightBatchingStats, specDecStats);
|
||||
NLOHMANN_JSON_SERIALIZE_ENUM(RequestStage,
|
||||
{{RequestStage::kQUEUED, "QUEUED"}, {RequestStage::kCONTEXT_IN_PROGRESS, "CONTEXT_IN_PROGRESS"},
|
||||
{RequestStage::kGENERATION_IN_PROGRESS, "GENERATION_IN_PROGRESS"},
|
||||
|
||||
@ -1727,6 +1727,42 @@ size_t Serialization::serializedSize(InflightBatchingStats const& inflightBatchi
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
// SpecDecodingStats
|
||||
SpecDecodingStats Serialization::deserializeSpecDecodingStats(std::istream& is)
|
||||
{
|
||||
auto numDraftTokens = su::deserialize<SizeType64>(is);
|
||||
auto numAcceptedTokens = su::deserialize<SizeType64>(is);
|
||||
auto numRequestsWithDraftTokens = su::deserialize<SizeType64>(is);
|
||||
auto acceptanceLength = su::deserialize<double>(is);
|
||||
auto iterLatencyMS = su::deserialize<double>(is);
|
||||
auto draftOverhead = su::deserialize<double>(is);
|
||||
|
||||
return SpecDecodingStats{
|
||||
numDraftTokens, numAcceptedTokens, numRequestsWithDraftTokens, acceptanceLength, iterLatencyMS, draftOverhead};
|
||||
}
|
||||
|
||||
void Serialization::serialize(SpecDecodingStats const& state, std::ostream& os)
|
||||
{
|
||||
su::serialize(state.numDraftTokens, os);
|
||||
su::serialize(state.numAcceptedTokens, os);
|
||||
su::serialize(state.numRequestsWithDraftTokens, os);
|
||||
su::serialize(state.acceptanceLength, os);
|
||||
su::serialize(state.iterLatencyMS, os);
|
||||
su::serialize(state.draftOverhead, os);
|
||||
}
|
||||
|
||||
size_t Serialization::serializedSize(SpecDecodingStats const& state)
|
||||
{
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(state.numDraftTokens);
|
||||
totalSize += su::serializedSize(state.numAcceptedTokens);
|
||||
totalSize += su::serializedSize(state.numRequestsWithDraftTokens);
|
||||
totalSize += su::serializedSize(state.acceptanceLength);
|
||||
totalSize += su::serializedSize(state.iterLatencyMS);
|
||||
totalSize += su::serializedSize(state.draftOverhead);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
// IterationStats
|
||||
|
||||
IterationStats Serialization::deserializeIterationStats(std::istream& is)
|
||||
@ -1754,12 +1790,13 @@ IterationStats Serialization::deserializeIterationStats(std::istream& is)
|
||||
auto crossKvCacheStats = su::deserialize<std::optional<KvCacheStats>>(is);
|
||||
auto staticBatchingStats = su::deserialize<std::optional<StaticBatchingStats>>(is);
|
||||
auto inflightBatchingStats = su::deserialize<std::optional<InflightBatchingStats>>(is);
|
||||
auto specdecStats = su::deserialize<std::optional<SpecDecodingStats>>(is);
|
||||
|
||||
return IterationStats{timestamp, iter, iterLatencyMS, newActiveRequestsQueueLatencyMS, numNewActiveRequests,
|
||||
numActiveRequests, numQueuedRequests, numCompletedRequests, maxNumActiveRequests, maxBatchSizeStatic,
|
||||
maxBatchSizeTunerRecommended, maxBatchSizeRuntime, maxNumTokensStatic, maxNumTokensTunerRecommended,
|
||||
maxNumTokensRuntime, gpuMemUsage, cpuMemUsage, pinnedMemUsage, kvCacheStats, crossKvCacheStats,
|
||||
staticBatchingStats, inflightBatchingStats};
|
||||
staticBatchingStats, inflightBatchingStats, specdecStats};
|
||||
}
|
||||
|
||||
IterationStats Serialization::deserializeIterationStats(std::vector<char>& buffer)
|
||||
@ -1797,6 +1834,7 @@ size_t Serialization::serializedSize(IterationStats const& iterStats)
|
||||
totalSize += su::serializedSize(iterStats.crossKvCacheStats);
|
||||
totalSize += su::serializedSize(iterStats.staticBatchingStats);
|
||||
totalSize += su::serializedSize(iterStats.inflightBatchingStats);
|
||||
totalSize += su::serializedSize(iterStats.specDecStats);
|
||||
|
||||
return totalSize;
|
||||
}
|
||||
@ -1825,6 +1863,7 @@ void Serialization::serialize(IterationStats const& iterStats, std::ostream& os)
|
||||
su::serialize(iterStats.crossKvCacheStats, os);
|
||||
su::serialize(iterStats.staticBatchingStats, os);
|
||||
su::serialize(iterStats.inflightBatchingStats, os);
|
||||
su::serialize(iterStats.specDecStats, os);
|
||||
}
|
||||
|
||||
std::vector<char> Serialization::serialize(IterationStats const& iterStats)
|
||||
|
||||
@ -465,6 +465,10 @@ T deserialize(std::istream& is)
|
||||
{
|
||||
return Serialization::deserializeInflightBatchingStats(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::SpecDecodingStats>)
|
||||
{
|
||||
return Serialization::deserializeSpecDecodingStats(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::IterationStats>)
|
||||
{
|
||||
return Serialization::deserializeIterationStats(is);
|
||||
|
||||
@ -132,6 +132,15 @@ void initBindings(pybind11::module_& m)
|
||||
.def_readwrite("micro_batch_id", &tle::InflightBatchingStats::microBatchId)
|
||||
.def_readwrite("avg_num_decoded_tokens_per_iter", &tle::InflightBatchingStats::avgNumDecodedTokensPerIter);
|
||||
|
||||
py::class_<tle::SpecDecodingStats>(m, "SpecDecodingStats")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("num_draft_tokens", &tle::SpecDecodingStats::numDraftTokens)
|
||||
.def_readwrite("num_accepted_tokens", &tle::SpecDecodingStats::numAcceptedTokens)
|
||||
.def_readwrite("num_requests_with_draft_tokens", &tle::SpecDecodingStats::numRequestsWithDraftTokens)
|
||||
.def_readwrite("acceptance_length", &tle::SpecDecodingStats::acceptanceLength)
|
||||
.def_readwrite("iter_latency_ms", &tle::SpecDecodingStats::iterLatencyMS)
|
||||
.def_readwrite("draft_overhead", &tle::SpecDecodingStats::draftOverhead);
|
||||
|
||||
py::class_<tle::IterationStats>(m, "IterationStats")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("timestamp", &tle::IterationStats::timestamp)
|
||||
@ -150,6 +159,7 @@ void initBindings(pybind11::module_& m)
|
||||
.def_readwrite("cross_kv_cache_stats", &tle::IterationStats::crossKvCacheStats)
|
||||
.def_readwrite("static_batching_stats", &tle::IterationStats::staticBatchingStats)
|
||||
.def_readwrite("inflight_batching_stats", &tle::IterationStats::inflightBatchingStats)
|
||||
.def_readwrite("specdec_stats", &tle::IterationStats::specDecStats)
|
||||
.def("to_json_str",
|
||||
[](tle::IterationStats const& iterationStats)
|
||||
{ return tle::JsonSerialization::toJsonStr(iterationStats); });
|
||||
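Once these bindings are built, the new stats can be filled and inspected from Python. A minimal sketch (illustrative values; assumes the `tensorrt_llm.bindings.executor` module produced by this change):

```python
from tensorrt_llm.bindings.executor import IterationStats, SpecDecodingStats

stats = IterationStats()
spec = SpecDecodingStats()
spec.num_draft_tokens = 8
spec.num_accepted_tokens = 5
spec.num_requests_with_draft_tokens = 2
spec.acceptance_length = (5 + 2) / 2   # 3.5 tokens produced per step on average
stats.specdec_stats = spec             # optional field added by this change
print(stats.to_json_str())             # specdec_stats now appears in the JSON output
```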
|
||||
@ -4,10 +4,28 @@ This document shows how to build and run a model using Prompt-Lookup speculative
|
||||
|
||||
## Overview
|
||||
|
||||
We now provide two workflow styles for running Prompt-Lookup, named V1 and V2. V1 belongs to the TRT workflow and is similar to the Draft-Target-Model workflow: it runs in orchestrator mode and calls `runner.generate()` multiple times to produce the output, which is more flexible for customization but adds slightly more overhead. V2 belongs to the PyTorch workflow and is similar to the Look-Ahead workflow: it runs in leader mode and calls `runner.generate()` only once, which gives higher performance but a fixed process.
|
||||
|
||||
Prompt-Lookup has three additional hyperparameters that you need to specify to control the generation process:
|
||||
- `prompt_lookup_num_tokens`: the number of tokens we extract from input prompt or previous generated output as draft tokens in one iteration, which the range is from 4 to 10 in common usage. Empirically, the larger the value is, the higher acceptance ratio but higher overhead is expected at the same time, so the right balance based on the models and application scenarios needs to be found.
|
||||
- `max_matching_ngram_size`: the number of tokens we get from the tail of the generated output as a pattern, which is used to match in input prompt or previous generated output. Empirically, the larger the value is, the more precise context can be matched from the existed sequence, indicating higher acceptance ratio, but the higher probability of miss-match and higher overhead appear, which fall back to normal generation (one token per iteration).
|
||||
- `device_list`: the index list of device(s) to run the model. The length of it must be the same as the TP size of the draft model engine. For instances, `device_list=[0]` means using tp_size=1 and GPU 0 for the model, `device_list=[4,5,6,7]` means using tp=4 and GPU from 4 to 7 for the model.
|
||||
- `prompt_lookup_num_tokens`: the maximum number of tokens provided as draft tokens in one iteration, usually from 4 to 10 (default value: 4). Empirically, the larger the value, the higher the acceptance rate, but also the higher the overhead, so the right balance needs to be found for each model and application scenario.
- `max_matching_ngram_size`: the maximum number of tokens extracted from the tail of the input prompt or generated output as a pattern, which is used to search for corresponding draft tokens (default value: 2). Empirically, the larger the value, the more precise the context that can be matched from the existing sequence, indicating a higher acceptance rate, but also a higher probability of a miss and higher overhead; on a miss the process falls back to normal generation (one token per iteration).
- `device_list`: the list of device indices used to run the model in the V1 workflow. Its length must equal the TP size of the target model engine. For instance, `device_list=[0]` means tp_size=1 with GPU 0, and `device_list=[4,5,6,7]` means tp_size=4 with GPUs 4 to 7. This parameter is not needed in the V2 workflow.
|
||||
|
||||
+ For example, with `prompt_lookup_num_tokens=4` and `max_matching_ngram_size=2`, the process of getting draft tokens from a sentence `prefix=[..., t1, t2, t3, t4]` looks like the following:
|
||||
|
||||
```Python
pattern = tuple(prefix[-2:])  # pattern=(t3, t4) (length=2)
if pattern in pool and len(pool[pattern]) == 4:  # assuming it is {(t3, t4): (t5, t6, t7, t8)}
    return pool[pattern]  # draft tokens = [t5, t6, t7, t8]
elif pattern in pool and len(pool[pattern]) < 4:  # assuming it is {(t3, t4): (t9, t10, t11)}
    return pool[pattern]  # draft tokens = [t9, t10, t11]
pattern = tuple(prefix[-1:])  # Try a shorter pattern if the length-2 pattern has no candidate, pattern=(t4,) (length=1)
if pattern in pool and len(pool[pattern]) == 4:  # The same process as above
    return pool[pattern]
elif pattern in pool and len(pool[pattern]) < 4:
    return pool[pattern]
return None  # No candidate exists
```
|
||||
|
||||
## Support Matrix
|
||||
* GPU Compute Capability >= 8.0 (Ampere or newer)
|
||||
@ -17,17 +35,21 @@ The Prompt-Lookup has 3 additional hyperparameters that you need to specify to c
|
||||
|
||||
## Usage
|
||||
|
||||
### Build engines
|
||||
### V1 workflow
|
||||
|
||||
+ We use the open-source `llama-v2-13B` model in this example.
+ `--use_paged_context_fmha=enable` must be specified since we need KV cache reuse in this approach.
+ `--speculative_decoding_mode=draft_tokens_external` must be specified.
+ `--max_draft_len` must be set larger than or equal to `prompt_lookup_num_tokens`.
+ `--prompt_lookup_config` is the corresponding configuration of Prompt-Lookup; see its usage in [util.py](../util.py) and the parsing sketch after this list.
+ As an example, `[10,2,[0]]` means `prompt_lookup_num_tokens=10`, `max_matching_ngram_size=2`, and the target model runs on `GPU0`.
+ `--kv_cache_enable_block_reuse` must be specified for this approach.
+ Only the C++ session is supported, so `--use_py_session` must not be specified.
+ `--num_beams` cannot be set larger than 1 since beam search is not supported in this approach yet.
|
||||
|
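The `--prompt_lookup_config` string is parsed with `ast.literal_eval`, which matches how `run_dtm_pld` in `util.py` (shown later in this diff) unpacks it; a minimal sketch with the variable names from that function:

```python
import ast

prompt_lookup_config = "[10,2,[0]]"
prompt_lookup_num_tokens, max_matching_ngram_size, target_device_list = ast.literal_eval(
    prompt_lookup_config)
print(prompt_lookup_num_tokens)  # 10
print(max_matching_ngram_size)   # 2
print(target_device_list)        # [0] -> run the target model on GPU 0
```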
||||
```bash
|
||||
cd examples/models/core/llama
|
||||
|
||||
python3 convert_checkpoint.py \
|
||||
# Build engine
|
||||
python3 examples/models/core/llama/convert_checkpoint.py \
|
||||
--model_dir=<Path To Llama-v2-13B repo> \
|
||||
--output_dir=./ckpt-target \
|
||||
--dtype=float16
|
||||
@ -42,34 +64,18 @@ trtllm-build \
|
||||
--max_batch_size=4 \
|
||||
--max_input_len=3200 \
|
||||
--max_seq_len=4800
|
||||
```
|
||||
|
||||
### Run decoding
|
||||
|
||||
+ `---prompt_lookup_config` is corresponding configuration of Prompt-Lookup, we can see its usage in [util.py](../util.py).
|
||||
+ As an example, `[10,2,[0]]` means `prompt_lookup_num_tokens=10`, `max_matching_ngram_size=2`, and device of target model is `GPU0`.
|
||||
+ `--kv_cache_enable_block_reuse` must be specified for this approach.
|
||||
+ Only CPP session is supported, so `--use_py_session` must not be specified.
|
||||
+ `--num_beams` can not be specified as larger than 1 since beam search is not supported in this approach yet.
|
||||
|
||||
```bash
|
||||
cd examples/models/core/llama
|
||||
|
||||
python3 ../../../run.py \
|
||||
# Run decoding
|
||||
python3 examples/run.py \
|
||||
--tokenizer_dir <Path To Llama-v2-7B repo> \
|
||||
--engine_dir ./target-engine \
|
||||
--prompt_lookup_config="[10,2,[0]]" \
|
||||
--max_output_len=256 \
|
||||
--kv_cache_enable_block_reuse \
|
||||
--input_text="How does Draft-Sampling work?"
|
||||
```
|
||||
|
||||
## Run summarization tasks
|
||||
|
||||
```bash
|
||||
cd examples/models/core/llama
|
||||
|
||||
python ../../../summarize.py \
|
||||
# Run summarization tasks
|
||||
python examples/summarize.py \
|
||||
--test_hf \
|
||||
--test_trt_llm \
|
||||
--check_accuracy \
|
||||
@ -79,3 +85,11 @@ python ../../../summarize.py \
|
||||
--prompt_lookup_config="[10,2,[0]]" \
|
||||
--kv_cache_enable_block_reuse
|
||||
```
|
||||
|
||||
### V2 workflow
|
||||
|
||||
```bash
|
||||
python3 examples/pytorch/quickstart_advanced.py \
    --spec_decode_algo NGRAM \
    --max_matching_ngram_size=2 \
    --spec_decode_nextn=4
|
||||
```
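Alternatively, the V2 path can be driven through the LLM API by constructing an `NGramDecodingConfig`, which is what `quickstart_advanced.py` does internally. A hedged sketch (the model path is a placeholder, and we assume the config is passed via the LLM constructor's `speculative_config` argument):

```python
from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import NGramDecodingConfig

spec_config = NGramDecodingConfig(
    prompt_lookup_num_tokens=4,
    max_matching_ngram_size=2,
    is_keep_all=True,
    is_use_oldest=True,
    is_public_pool=True,
)
llm = LLM(model="<path to the target model>",   # placeholder path
          speculative_config=spec_config)        # assumed parameter name
for output in llm.generate(["How does Draft-Sampling work?"]):
    print(output.outputs[0].text)
```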
|
||||
|
||||
@ -36,8 +36,8 @@ class PLDPool: # Ngrams pool for Prompt-Lookup-Decoding
|
||||
is_use_oldest: bool = True,
|
||||
):
|
||||
self.input_batch_size = input_batch_size
|
||||
self.plnt = prompt_lookup_num_tokens # Shorter name
|
||||
self.mmns = max_matching_ngram_size # Shorter name
|
||||
self.prompt_lookup_num_tokens = prompt_lookup_num_tokens
|
||||
self.max_matching_ngram_size = max_matching_ngram_size
|
||||
self.end_id = end_id
|
||||
self.max_seq_len = max_seq_len
|
||||
self.is_keep_all = is_keep_all
|
||||
@ -45,9 +45,25 @@ class PLDPool: # Ngrams pool for Prompt-Lookup-Decoding
|
||||
self.pool = [{} for _ in range(input_batch_size)]
|
||||
self.start_index = [0 for _ in range(input_batch_size)]
|
||||
|
||||
# modified from `transformers/generation/candidate_generator.py`
|
||||
assert self.prompt_lookup_num_tokens > 0, f"prompt_lookup_num_tokens must be greater than 0, but got {self.prompt_lookup_num_tokens}"
|
||||
assert self.max_matching_ngram_size > 0, f"max_matching_ngram_size must be greater than 0, but got {self.max_matching_ngram_size}"
|
||||
|
||||
def print_pool(self):
|
||||
"""
|
||||
For debug
|
||||
"""
|
||||
logger.info(f"Batch size = {self.input_batch_size}")
|
||||
for i, map in enumerate(self.pool):
|
||||
logger.info(f"Slot {i}, size = {len(map)}")
|
||||
for key, values in map.items():
|
||||
logger.info(f" {key}->{values}")
|
||||
|
||||
def get_draft_tokens(self, prefix: list[torch.Tensor],
|
||||
batch_slot: list[int]):
|
||||
"""
|
||||
Get draft tokens from a batch of requests
|
||||
modified from `transformers/generation/candidate_generator.py`
|
||||
"""
|
||||
batch_size = len(prefix)
|
||||
prefix_len = [len(prefix[bi]) for bi in range(batch_size)]
|
||||
draft_tokens = []  # `logits` is not used yet
|
||||
@ -61,25 +77,30 @@ class PLDPool: # Ngrams pool for Prompt-Lookup-Decoding
|
||||
|
||||
# Update pool
|
||||
sequence = prefix[bi][self.start_index[gbi]:].tolist()
|
||||
for size in range(min(self.mmns, prefix_len[bi] - 1), 0, -1):
|
||||
for size in range(
|
||||
min(self.max_matching_ngram_size, prefix_len[bi] - 1), 0,
|
||||
-1):
|
||||
# Find each possible key-value combination, and use tuple for hash
|
||||
for l in range(len(sequence) - size):
|
||||
r = min(l + size + self.plnt, len(sequence))
|
||||
r = min(l + size + self.prompt_lookup_num_tokens,
|
||||
len(sequence))
|
||||
key = tuple(sequence[l:l + size])
|
||||
value = tuple(sequence[l + size:r])
|
||||
if key not in self.pool[gbi] or not self.is_keep_all or \
|
||||
len(self.pool[gbi][key][0]) < self.plnt:
|
||||
len(self.pool[gbi][key][0]) < self.prompt_lookup_num_tokens:
|
||||
# Update the value if
|
||||
# 1. the key does not exist
|
||||
# 2. we only keep one value for each key
|
||||
# 2. we only keep the newest value for each key (MRU)
|
||||
# 3. the length of the value saved before is less than `prompt_lookup_num_tokens`
|
||||
self.pool[gbi][key] = OrderedSet((value, ))
|
||||
elif value not in self.pool[gbi][key]:
|
||||
# Extend the value if the key is already existed and we want to keep all of them
|
||||
# Extend the values if the key already exists but the number of saved values is not enough
|
||||
self.pool[gbi][key].add(value)
|
||||
|
||||
# Find match
|
||||
for size in range(min(self.mmns, prefix_len[bi] - 1), 0, -1):
|
||||
for size in range(
|
||||
min(self.max_matching_ngram_size, prefix_len[bi] - 1), 0,
|
||||
-1):
|
||||
pattern = tuple(prefix[bi][-size:].tolist())
|
||||
if pattern not in self.pool[gbi]:
|
||||
continue
|
||||
@ -92,7 +113,8 @@ class PLDPool: # Ngrams pool for Prompt-Lookup-Decoding
|
||||
break
|
||||
draft_tokens.append(chosen_ids)
|
||||
self.start_index[gbi] = max(
|
||||
0, prefix_len[bi] - (self.plnt + self.mmns - 1))
|
||||
0, prefix_len[bi] - (self.prompt_lookup_num_tokens +
|
||||
self.max_matching_ngram_size - 1))
|
||||
|
||||
return draft_tokens, None
|
||||
|
||||
@ -108,20 +130,21 @@ def run_dtm_pld(batch_input_ids,
|
||||
*,
|
||||
target_runner=None):
|
||||
# `dtm` for Draft-Target-Model, `pld` for Prompt-Lookup-Decoding
|
||||
assert (args.draft_target_model_config is not None) ^ (args.prompt_lookup_config is not None), \
|
||||
"`--draft_target_model_config` and `--prompt_lookup_config` can not be specified at the same time."
|
||||
if args.draft_target_model_config is not None:
|
||||
assert args.draft_engine_dir is not None, "`--draft_engine_dir` must be specified in Draft-Target-Model."
|
||||
is_dtm = (args.draft_target_model_config is not None)
|
||||
is_pld = (args.prompt_lookup_config is not None)
|
||||
assert is_dtm ^ is_pld, "`--draft_target_model_config` and `--prompt_lookup_config` can not be specified at the same time."
|
||||
if is_dtm:
|
||||
assert args.draft_engine_dir is not None, "`--draft_engine_dir` must be specified in Draft-Target-Model."
|
||||
draft_len, draft_device_list, target_device_list, use_logits = ast.literal_eval(
|
||||
args.draft_target_model_config)
|
||||
logger.info(f"Using Draft-Target-Model speculative decoding")
|
||||
logger.info(f"draft_len: {draft_len}")
|
||||
logger.info(f"Device(s) for draft model: {draft_device_list}")
|
||||
logger.info(f"Device(s) for target model: {target_device_list}")
|
||||
logger.info(f"Use logits to accept tokens: {use_logits}")
|
||||
if is_pld:
|
||||
logger.info(
|
||||
f"Using Prompt-Lookup-Decoding speculative decoding V1 workflow")
|
||||
prompt_lookup_num_tokens, max_matching_ngram_size, target_device_list = ast.literal_eval(
|
||||
args.prompt_lookup_config)
|
||||
logger.info(f"prompt_lookup_num_tokens: {prompt_lookup_num_tokens}")
|
||||
@ -206,9 +229,8 @@ def run_dtm_pld(batch_input_ids,
|
||||
device_ids=target_device_list)
|
||||
target_runner = ModelRunnerCpp.from_dir(**target_runner_kwargs)
|
||||
|
||||
if is_dtm and use_logits and \
|
||||
not (draft_runner.gather_generation_logits and target_runner.gather_generation_logits):
|
||||
assert False, "`--gather_generation_logits` must be specified while building draft/target models for using logits to accept"
|
||||
if is_dtm and use_logits:
|
||||
assert draft_runner.gather_generation_logits and target_runner.gather_generation_logits, "`--gather_generation_logits` must be specified while building draft/target models for using logits to accept"
|
||||
|
||||
common_generaion_kwargs = dict(
|
||||
max_attention_window_size=args.max_attention_window_size,
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
# TRT-LLM with PyTorch
|
||||
|
||||
Run the quick start script:
|
||||
## Run the quick start script:
|
||||
|
||||
```bash
|
||||
python3 quickstart.py
|
||||
```
|
||||
|
||||
Run the advanced usage example script:
|
||||
## Run the advanced usage example script:
|
||||
|
||||
```bash
|
||||
# BF16
|
||||
@ -29,8 +29,9 @@ python3 quickstart_advanced.py --model_dir nvidia/Llama-3_1-Nemotron-Ultra-253B-
|
||||
|
||||
# Nemotron-H requires disabling cache reuse in kv cache
|
||||
python3 quickstart_advanced.py --model_dir nvidia/Nemotron-H-8B-Base-8K --disable_kv_cache_reuse --max_batch_size 8
|
||||
```
|
||||
|
||||
Run the multimodal example script:
|
||||
## Run the multimodal example script:
|
||||
|
||||
```bash
|
||||
# default inputs
|
||||
@ -43,29 +44,39 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
|
||||
# Note: media should be either image or video. Mixing image and video is not supported.
|
||||
python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --modality video --prompt "Tell me what you see in the video briefly." "Describe the scene in the video briefly." --media "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4" "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4" --max_tokens 128 [--use_cuda_graph]
|
||||
```
|
||||
## Supported Models
|
||||
| Architecture | Model | HuggingFace Example | Modality |
|
||||
|--------------|-------|---------------------|----------|
|
||||
| `BertForSequenceClassification` | BERT-based | `textattack/bert-base-uncased-yelp-polarity` | L |
|
||||
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3` | L |
|
||||
| `Gemma3ForCausalLM` | Gemma3 | `google/gemma-3-1b-it` | L |
|
||||
| `LlavaLlamaModel` | VILA | `Efficient-Large-Model/NVILA-8B` | L + V |
|
||||
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | `llava-hf/llava-v1.6-mistral-7b-hf` | L + V |
|
||||
| `LlamaForCausalLM` | Llama 3 <br> Llama 3.1 <br> Llama 2 <br> LLaMA | `meta-llama/Meta-Llama-3.1-70B` | L |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 Scout <br> Llama 4 Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct` <br> `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L + V |
|
||||
| `MistralForCausalLM` | Mistral | `mistralai/Mistral-7B-v0.1` | L |
|
||||
| `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
|
||||
| `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |
|
||||
| `NemotronForCausalLM` | Nemotron-3 <br> Nemotron-4 <br> Minitron | `nvidia/Minitron-8B-Base` | L |
|
||||
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K` <br> `nvidia/Nemotron-H-47B-Base-8K` <br> `nvidia/Nemotron-H-56B-Base-8K` | L |
|
||||
| `NemotronNASForCausalLM` | LLamaNemotron <br> LlamaNemotron Super <br> LlamaNemotron Ultra | `nvidia/Llama-3_1-Nemotron-51B-Instruct` <br> `nvidia/Llama-3_3-Nemotron-Super-49B-v1` <br> `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | L |
|
||||
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/Qwen2-7B-Instruct` | L |
|
||||
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B` | L |
|
||||
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B` | L |
|
||||
| `Qwen2VLForConditionalGeneration` | Qwen2-VL | `Qwen/Qwen2-VL-7B-Instruct` | L + V |
|
||||
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | `Qwen/Qwen2.5-VL-7B-Instruct` | L + V |
|
||||
|
||||
### Supported Models
|
||||
| Architecture | Model | HuggingFace Example | Modality |
|
||||
| :----------------------------------: | :----------------------------------------------------------- | :----------------------------------------------------------- | :------: |
|
||||
| `BertForSequenceClassification` | BERT-based | `textattack/bert-base-uncased-yelp-polarity` | L |
|
||||
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3 ` | L |
|
||||
| `Gemma3ForCausalLM` | Gemma3 | `google/gemma-3-1b-it` | L |
|
||||
| `LlavaLlamaModel` | VILA | `Efficient-Large-Model/NVILA-8B` | L + V |
|
||||
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | `llava-hf/llava-v1.6-mistral-7b-hf` | L + V |
|
||||
| `LlamaForCausalLM` | Llama 3 <br> Llama 3.1 <br> Llama 2 <br> LLaMA | `meta-llama/Meta-Llama-3.1-70B` | L |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 Scout <br> Llama 4 Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct` <br> `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L + V |
|
||||
| `MistralForCausalLM` | Mistral | `mistralai/Mistral-7B-v0.1` | L |
|
||||
| `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
|
||||
| `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |
|
||||
| `NemotronForCausalLM` | Nemotron-3 <br> Nemotron-4 <br> Minitron | `nvidia/Minitron-8B-Base` | L |
|
||||
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K` <br> `nvidia/Nemotron-H-47B-Base-8K` <br> `nvidia/Nemotron-H-56B-Base-8K` | L |
|
||||
| `NemotronNASForCausalLM` | LLamaNemotron <br> LlamaNemotron Super <br> LlamaNemotron Ultra | `nvidia/Llama-3_1-Nemotron-51B-Instruct` <br> `nvidia/Llama-3_3-Nemotron-Super-49B-v1` <br> `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | L |
|
||||
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/Qwen2-7B-Instruct` | L |
|
||||
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B` | L |
|
||||
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B` | L |
|
||||
| `Qwen2VLForConditionalGeneration` | Qwen2-VL | `Qwen/Qwen2-VL-7B-Instruct` | L + V |
|
||||
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | `Qwen/Qwen2.5-VL-7B-Instruct` | L + V |
|
||||
|
||||
Note:
|
||||
- L: Language only
|
||||
- L + V: Language and Vision multimodal support
|
||||
- Llama 3.2 accepts vision input, but our support is currently limited to text only.
|
||||
|
||||
## Run the speculative decoding script:
|
||||
|
||||
```bash
|
||||
# NGram drafter
|
||||
python3 examples/pytorch/quickstart_advanced.py \
    --spec_decode_algo NGRAM \
    --max_matching_ngram_size=2 \
    --spec_decode_nextn=4
|
||||
```
|
||||
|
||||
@ -4,7 +4,7 @@ from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
|
||||
MTPDecodingConfig)
|
||||
MTPDecodingConfig, NGramDecodingConfig)
|
||||
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
@ -103,6 +103,7 @@ def add_llm_args(parser):
|
||||
parser.add_argument('--spec_decode_algo', type=str, default=None)
|
||||
parser.add_argument('--spec_decode_nextn', type=int, default=1)
|
||||
parser.add_argument('--eagle_model_dir', type=str, default=None)
|
||||
parser.add_argument('--max_matching_ngram_size', type=int, default=5)
|
||||
|
||||
# Relaxed acceptance
|
||||
parser.add_argument('--use_relaxed_acceptance_for_thinking',
|
||||
@ -130,6 +131,7 @@ def setup_llm(args):
|
||||
use_cuda_graph=args.use_cuda_graph,
|
||||
load_format=args.load_format,
|
||||
print_iter_log=args.print_iter_log,
|
||||
enable_iter_perf_stats=args.print_iter_log,
|
||||
torch_compile_enabled=args.use_torch_compile,
|
||||
torch_compile_piecewise_cuda_graph=args.use_piecewise_cuda_graph,
|
||||
moe_backend=args.moe_backend,
|
||||
@ -154,6 +156,14 @@ def setup_llm(args):
|
||||
spec_config = EagleDecodingConfig(
|
||||
max_draft_len=args.spec_decode_nextn,
|
||||
pytorch_eagle_weights_path=args.eagle_model_dir)
|
||||
elif spec_decode_algo == "NGRAM":
|
||||
spec_config = NGramDecodingConfig(
|
||||
prompt_lookup_num_tokens=args.spec_decode_nextn,
|
||||
max_matching_ngram_size=args.max_matching_ngram_size,
|
||||
is_keep_all=True,
|
||||
is_use_oldest=True,
|
||||
is_public_pool=True,
|
||||
)
|
||||
else:
|
||||
spec_config = None
|
||||
|
||||
|
||||
@ -454,7 +454,8 @@ def main(args):
|
||||
lora_ckpt_source=args.lora_ckpt_source,
|
||||
gpu_weights_percent=args.gpu_weights_percent,
|
||||
max_output_len=args.max_output_len,
|
||||
enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc)
|
||||
enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc,
|
||||
)
|
||||
if args.medusa_choices is not None:
|
||||
args.medusa_choices = ast.literal_eval(args.medusa_choices)
|
||||
assert args.temperature == 1.0, "Medusa should use temperature == 1.0"
|
||||
|
||||
@ -491,10 +491,34 @@ def main(args):
|
||||
f"Using {'Python' if args.use_py_session else 'C++'} session")
|
||||
|
||||
runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp
|
||||
runner_kwargs = dict(engine_dir=args.engine_dir,
|
||||
rank=runtime_rank,
|
||||
debug_mode=args.debug_mode,
|
||||
gpu_weights_percent=args.gpu_weights_percent)
|
||||
runner_kwargs = dict(
|
||||
engine_dir=args.engine_dir,
|
||||
rank=runtime_rank,
|
||||
debug_mode=args.debug_mode,
|
||||
gpu_weights_percent=args.gpu_weights_percent,
|
||||
enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc,
|
||||
)
|
||||
if not args.use_py_session:
|
||||
runner_kwargs.update(
|
||||
lora_dir=args.lora_dir,
|
||||
lora_ckpt_source=args.lora_ckpt_source,
|
||||
max_batch_size=max_batch_size,
|
||||
max_input_len=test_token_num,
|
||||
max_output_len=output_len,
|
||||
max_beam_width=num_beams,
|
||||
max_attention_window_size=max_attention_window_size,
|
||||
sink_token_length=sink_token_length,
|
||||
max_tokens_in_paged_kv_cache=args.max_tokens_in_paged_kv_cache,
|
||||
kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
|
||||
kv_cache_free_gpu_memory_fraction=args.
|
||||
kv_cache_free_gpu_memory_fraction,
|
||||
enable_chunked_context=args.enable_chunked_context,
|
||||
multi_block_mode=args.multi_block_mode,
|
||||
cuda_graph_mode=args.cuda_graph_mode,
|
||||
gather_generation_logits=args.eval_ppl,
|
||||
use_gpu_direct_storage=args.use_gpu_direct_storage,
|
||||
)
|
||||
|
||||
if args.medusa_choices is not None:
|
||||
args.medusa_choices = ast.literal_eval(args.medusa_choices)
|
||||
assert args.temperature == 1.0, "Medusa should use temperature == 1.0"
|
||||
@ -523,38 +547,16 @@ def main(args):
|
||||
if args.prompt_lookup_config is not None:
|
||||
assert args.kv_cache_enable_block_reuse, "`--kv_cache_enable_block_reuse` must be specified in speculative decoding."
|
||||
assert not args.use_py_session, "`--use_py_session` is not supported in Speculative decoding."
|
||||
assert not is_enc_dec, "Encoder-Decoder model is not supported in Speculative decoding."
|
||||
assert args.num_beams == 1, "`--num_beams>1` is not supported in Speculative decoding."
|
||||
prompt_lookup_num_tokens, _, target_device_list = ast.literal_eval(
|
||||
args.prompt_lookup_config)
|
||||
args.max_output_len = output_len # Specialization for PLD
|
||||
runner_kwargs.update(is_orchestrator_mode=True,
|
||||
device_ids=target_device_list)
|
||||
|
||||
if not args.use_py_session:
|
||||
runner_kwargs.update(
|
||||
lora_dir=args.lora_dir,
|
||||
lora_ckpt_source=args.lora_ckpt_source,
|
||||
max_batch_size=max_batch_size,
|
||||
max_input_len=test_token_num,
|
||||
max_output_len=output_len,
|
||||
max_beam_width=num_beams,
|
||||
max_attention_window_size=max_attention_window_size,
|
||||
sink_token_length=sink_token_length,
|
||||
max_tokens_in_paged_kv_cache=args.max_tokens_in_paged_kv_cache,
|
||||
kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
|
||||
kv_cache_free_gpu_memory_fraction=args.
|
||||
kv_cache_free_gpu_memory_fraction,
|
||||
enable_chunked_context=args.enable_chunked_context,
|
||||
multi_block_mode=args.multi_block_mode,
|
||||
cuda_graph_mode=args.cuda_graph_mode,
|
||||
gather_generation_logits=args.eval_ppl,
|
||||
use_gpu_direct_storage=args.use_gpu_direct_storage)
|
||||
runner_kwargs.update(
|
||||
enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc)
|
||||
if args.prompt_lookup_config is not None:
|
||||
# Specialization for PLD since many calls of `generate()` are needed
|
||||
runner_kwargs.update(max_input_len=test_token_num +
|
||||
device_ids=target_device_list,
|
||||
max_input_len=test_token_num +
|
||||
prompt_lookup_num_tokens + output_len)
|
||||
|
||||
runner = runner_cls.from_dir(**runner_kwargs)
|
||||
assert not (args.eval_ppl and not runner.gather_context_logits), \
|
||||
"PPL evaluation requires engine built with gather_context_logits enabled"
|
||||
|
||||
@ -463,7 +463,8 @@ def instantiate_sampler(model_engine: PyTorchModelEngine,
|
||||
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
|
||||
sampler = TorchStarAttentionSampler(
|
||||
max_seq_len=model_engine.max_seq_len)
|
||||
elif model_engine.spec_config is not None:
|
||||
elif model_engine.spec_config is not None and model_engine.spec_config.spec_dec_mode.has_spec_decoder(
|
||||
):
|
||||
sampler = get_spec_decoder(max_seq_len=model_engine.max_seq_len,
|
||||
spec_config=model_engine.spec_config)
|
||||
elif pytorch_backend_config.enable_trtllm_sampler:
|
||||
|
||||
@ -1042,7 +1042,6 @@ class PyTorchModelEngine(ModelEngine):
|
||||
request_ids.append(request.py_request_id)
|
||||
all_prompt_tokens = request.get_tokens(0)
|
||||
draft_lens.append(0)
|
||||
|
||||
begin_compute = request.context_current_position
|
||||
end_compute = begin_compute + request.context_chunk_size
|
||||
prompt_tokens = all_prompt_tokens[begin_compute:end_compute]
|
||||
|
||||
@ -24,7 +24,8 @@ from tensorrt_llm.bindings.executor import (DisServingRequestStats,
|
||||
FinishReason, InflightBatchingStats,
|
||||
IterationStats, KvCacheStats,
|
||||
RequestStage, RequestStats,
|
||||
RequestType, StaticBatchingStats)
|
||||
RequestType, SpecDecodingStats,
|
||||
StaticBatchingStats)
|
||||
from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
|
||||
ReqIdsSet)
|
||||
from tensorrt_llm.logger import logger
|
||||
@ -512,6 +513,10 @@ class PyExecutor:
|
||||
stats.inflight_batching_stats = InflightBatchingStats()
|
||||
# staticBatchingStats is not used in pytorch path
|
||||
stats.static_batching_stats = StaticBatchingStats()
|
||||
spec_resource_manager = self.resource_manager.resource_managers.get(
|
||||
"spec_resource_manager")
|
||||
if spec_resource_manager is not None:
|
||||
stats.specdec_stats = SpecDecodingStats()
|
||||
return stats
|
||||
|
||||
def _populate_req_stats(
|
||||
@ -611,6 +616,9 @@ class PyExecutor:
|
||||
scheduled_batch.paused_requests)
|
||||
stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
|
||||
stats.inflight_batching_stats.micro_batch_id = 0
|
||||
if stats.specdec_stats is not None:
|
||||
stats.specdec_stats.draft_overhead = 0.0 if iter_latency_ms <= 0.0 else float(
|
||||
stats.specdec_stats.iter_latency_ms) / float(iter_latency_ms)
|
||||
return stats
|
||||
|
||||
def _append_iter_stats(self,
|
||||
@ -627,7 +635,7 @@ class PyExecutor:
|
||||
active_requests: List[LlmRequest],
|
||||
batch_state: BatchState):
|
||||
iter_end_time = time.time()
|
||||
iter_latency_ms = iter_end_time - batch_state.iter_start_time
|
||||
iter_latency_ms = (iter_end_time - batch_state.iter_start_time) * 1e3
|
||||
if batch_state.iter_stats is None:
|
||||
return
|
||||
|
||||
@ -791,6 +799,10 @@ class PyExecutor:
|
||||
torch.cuda.set_device(self.device_id)
|
||||
got_finish_signal = False
|
||||
num_dummy_request = 0
|
||||
is_ngram = hasattr(
|
||||
self.model_engine, "spec_config"
|
||||
) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
|
||||
)
|
||||
with self._profiler() as profile_step:
|
||||
iter_start_time = time.time()
|
||||
iter_stats = None
|
||||
@ -803,26 +815,28 @@ class PyExecutor:
|
||||
new_requests) or got_finish_signal
|
||||
if got_finish_signal and len(self.active_requests) == 0:
|
||||
break
|
||||
|
||||
if self.kv_cache_transceiver:
|
||||
self._check_disagg_gen_transfer_status()
|
||||
|
||||
if self.enable_iter_perf_stats:
|
||||
iter_stats = self._get_init_iter_stats(
|
||||
len(new_requests),
|
||||
self.new_active_requests_queue_latency_ms)
|
||||
|
||||
if self.kv_cache_transceiver:
|
||||
self._check_disagg_gen_transfer_status()
|
||||
|
||||
if not got_finish_signal:
|
||||
num_dummy_request = self._get_num_dummy_request()
|
||||
if num_dummy_request > 0:
|
||||
self._merge_dummy_request(num_dummy_request)
|
||||
|
||||
if self.draft_model_engine is not None:
|
||||
if self.draft_model_engine is not None or is_ngram:
|
||||
self._prepare_draft_requests()
|
||||
|
||||
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
|
||||
)
|
||||
|
||||
if self.kv_cache_transceiver:
|
||||
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
|
||||
self._prepare_disagg_gen_init(
|
||||
fitting_disagg_gen_init_requests)
|
||||
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
|
||||
@ -848,10 +862,21 @@ class PyExecutor:
|
||||
finished_requests = []
|
||||
|
||||
if scheduled_batch.batch_size > 0:
|
||||
has_ngram_iter_stats = is_ngram and self.model_engine.spec_config.spec_dec_mode.is_ngram(
|
||||
) and iter_stats is not None
|
||||
if has_ngram_iter_stats:
|
||||
before = time.time()
|
||||
|
||||
self.resource_manager.prepare_resources(scheduled_batch)
|
||||
if self.draft_model_engine is not None:
|
||||
self._prepare_draft_tokens(scheduled_batch)
|
||||
|
||||
if has_ngram_iter_stats:
|
||||
self._insert_ngram_iter_stats(scheduled_batch,
|
||||
iter_stats)
|
||||
iter_stats.specdec_stats.iter_latency_ms = (
|
||||
time.time() - before) * 1e3
|
||||
|
||||
if self.kv_cache_transceiver:
|
||||
# For generation requests which have completed KV cache transfer
|
||||
self._prepare_disagg_gen_transmission_complete(
|
||||
@ -903,6 +928,7 @@ class PyExecutor:
|
||||
# Set draft tokens here to make the KV cache manager
|
||||
# and scheduler aware of them.
|
||||
for req in self.active_requests:
|
||||
# TODO: enable draft tokens in context phase
|
||||
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
|
||||
continue
|
||||
req.py_last_draft_tokens = req.py_draft_tokens
|
||||
@ -961,7 +987,7 @@ class PyExecutor:
|
||||
|
||||
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
|
||||
logger.warning(
|
||||
"num_fitting_reqs =0 and fitting_disagg_gen_init_requests is empty , may not have enough kvCache"
|
||||
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
|
||||
)
|
||||
self.kv_cache_transceiver.check_context_transfer_status(
|
||||
1)
|
||||
@ -1677,6 +1703,43 @@ class PyExecutor:
|
||||
logger.error(f"Encountered an error in sampling: {error_msg}")
|
||||
self._handle_errors(error_msg)
|
||||
|
||||
def _insert_ngram_iter_stats(
|
||||
self, scheduled_requests: ScheduledRequests, iter_stats: IterationStats
|
||||
) -> None:  # fills `iter_stats` in place
|
||||
"""
|
||||
Get statistic information from the draft tokens in NGram drafter
|
||||
"""
|
||||
assert iter_stats is not None
|
||||
|
||||
total_num_draft_tokens = 0
|
||||
total_num_accepted_tokens = 0
|
||||
num_requests_with_draft_tokens = 0
|
||||
for request in chain(scheduled_requests.context_requests,
|
||||
scheduled_requests.generation_requests):
|
||||
num_draft_tokens = 0 if request.py_last_draft_tokens is None else len(
|
||||
request.py_last_draft_tokens)
|
||||
num_accepted_tokens = getattr(request,
|
||||
"py_num_accepted_draft_tokens", 0)
|
||||
if num_draft_tokens > 0:
|
||||
total_num_draft_tokens = total_num_draft_tokens + num_draft_tokens
|
||||
total_num_accepted_tokens = total_num_accepted_tokens + num_accepted_tokens
|
||||
num_requests_with_draft_tokens = num_requests_with_draft_tokens + 1
|
||||
|
||||
if num_requests_with_draft_tokens > 0:
|
||||
iter_stats.specdec_stats.iter_latency_ms = 0.0 # We do not count time in this method
|
||||
iter_stats.specdec_stats.num_draft_tokens = total_num_draft_tokens
|
||||
iter_stats.specdec_stats.num_accepted_tokens = total_num_accepted_tokens
|
||||
iter_stats.specdec_stats.num_requests_with_draft_tokens = num_requests_with_draft_tokens
|
||||
iter_stats.specdec_stats.acceptance_length = float(
|
||||
(total_num_accepted_tokens + num_requests_with_draft_tokens
|
||||
)) / float(num_requests_with_draft_tokens)
|
||||
else:
|
||||
iter_stats.specdec_stats.iter_latency_ms = 0.0
|
||||
iter_stats.specdec_stats.num_draft_tokens = 0
|
||||
iter_stats.specdec_stats.num_accepted_tokens = 0
|
||||
iter_stats.specdec_stats.num_requests_with_draft_tokens = 0
|
||||
iter_stats.specdec_stats.acceptance_length = 1.0
|
||||
|
||||
@nvtx_range("_prepare_draft_batch")
|
||||
def _prepare_draft_batch(
|
||||
self, scheduled_requests: ScheduledRequests
|
||||
|
||||
@ -12,7 +12,7 @@ from tensorrt_llm.quantization import KV_CACHE_QUANT_ALGO_LIST
|
||||
|
||||
from ..attention_backend.interface import AttentionRuntimeFeatures
|
||||
from ..distributed import MPIDist
|
||||
from ..speculative import Eagle3Config, get_spec_resource_manager
|
||||
from ..speculative import Eagle3Config, NGramConfig, get_spec_resource_manager
|
||||
from ._util import (create_kv_cache_manager, create_py_executor_instance,
|
||||
estimate_max_kv_cache_tokens, get_token_num_for_estimation,
|
||||
instantiate_sampler, is_mla)
|
||||
@ -63,11 +63,13 @@ def create_py_executor(executor_config: ExecutorConfig,
|
||||
|
||||
spec_config = executor_config.speculative_config
|
||||
has_draft_model_engine = isinstance(spec_config, Eagle3Config)
|
||||
has_ngram_drafter = isinstance(spec_config, NGramConfig)
|
||||
|
||||
attn_runtime_features = AttentionRuntimeFeatures(
|
||||
chunked_prefill=executor_config.enable_chunked_context,
|
||||
cache_reuse=executor_config.kv_cache_config.enable_block_reuse,
|
||||
has_speculative_draft_tokens=has_draft_model_engine,
|
||||
has_speculative_draft_tokens=has_draft_model_engine
|
||||
or has_ngram_drafter,
|
||||
)
|
||||
|
||||
model_engine = PyTorchModelEngine(
|
||||
|
||||
@ -314,7 +314,7 @@ class TorchSampler(Sampler):
|
||||
num_accepted += 1
|
||||
new_token = new_tokens_list[token_idx + num_accepted]
|
||||
num_tokens = request.add_new_token(new_token, beam_idx)
|
||||
new_tokens.append(num_tokens)
|
||||
new_tokens.append(num_tokens) # `num_tokens`->`new_token`
|
||||
|
||||
if self._handle_stop_criteria(request, new_token,
|
||||
num_tokens, beam_idx):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from .eagle3 import Eagle3Config, Eagle3SpecMetadata
|
||||
from .interface import SpecConfig, SpecMetadata
|
||||
from .mtp import MTPConfig, MTPEagleWorker, MTPSpecMetadata, MTPWorker
|
||||
from .ngram import NGramConfig
|
||||
from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_metadata,
|
||||
get_spec_resource_manager)
|
||||
|
||||
@ -8,5 +9,5 @@ __all__ = [
|
||||
"SpecConfig", "SpecMetadata", "MTPConfig", "MTPWorker", "MTPEagleWorker",
|
||||
"Eagle3Config", "Eagle3SpecMetadata", "MTPSpecMetadata",
|
||||
"get_spec_metadata", "get_spec_resource_manager", "get_spec_decoder",
|
||||
"get_num_spec_layers"
|
||||
"get_num_spec_layers", "NGramConfig"
|
||||
]
|
||||
|
||||
@ -15,6 +15,7 @@ class SpeculativeDecodingMode(IntEnum):
|
||||
MTP = auto()
|
||||
MTP_EAGLE = auto()
|
||||
EAGLE3 = auto()
|
||||
NGRAM = auto()
|
||||
NONE = auto()
|
||||
|
||||
def is_mtp(self):
|
||||
@ -26,6 +27,9 @@ class SpeculativeDecodingMode(IntEnum):
|
||||
def is_eagle3(self):
|
||||
return self == SpeculativeDecodingMode.EAGLE3
|
||||
|
||||
def is_ngram(self):
|
||||
return self == SpeculativeDecodingMode.NGRAM
|
||||
|
||||
def is_none(self):
|
||||
return self == SpeculativeDecodingMode.NONE
|
||||
|
||||
@ -35,6 +39,9 @@ class SpeculativeDecodingMode(IntEnum):
|
||||
def support_overlap_scheduler(self):
|
||||
return self.is_mtp()
|
||||
|
||||
def has_spec_decoder(self):
|
||||
return self.is_mtp() or self.is_eagle3()
|
||||
|
||||
def extend_ctx(self, attention_backend: AttentionBackend):
|
||||
"""
|
||||
If true, treat generation requests with draft tokens as
|
||||
@ -43,8 +50,9 @@ class SpeculativeDecodingMode(IntEnum):
|
||||
"""
|
||||
|
||||
# Fixme: only trtllm attention backend supports eagle3 generation-phase kernels on blackwell.
|
||||
return self.is_eagle3() and not (isinstance(
|
||||
attention_backend, TrtllmAttention) and get_sm_version() == 100)
|
||||
return (self.is_eagle3()
|
||||
and not (isinstance(attention_backend, TrtllmAttention)
|
||||
and get_sm_version() == 100)) or self.is_ngram()
|
||||
|
||||
@staticmethod
|
||||
def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
|
||||
|
||||
tensorrt_llm/_torch/speculative/ngram.py (new file, 210 lines)
@ -0,0 +1,210 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from ordered_set import OrderedSet
|
||||
|
||||
from ..pyexecutor.llm_request import LlmRequest
|
||||
from ..pyexecutor.resource_manager import BaseResourceManager
|
||||
from ..pyexecutor.scheduler import ScheduledRequests
|
||||
from .interface import SpecConfig, SpeculativeDecodingMode
from tensorrt_llm.logger import logger  # used by print_pool()/_print_line() below
|
||||
|
||||
|
||||
@dataclass
|
||||
class NGramConfig(SpecConfig):
|
||||
"""
|
||||
Configuration for N-gram drafter.
|
||||
"""
|
||||
# The name of speculative decoding.
|
||||
spec_dec_name = "NGRAM"
|
||||
|
||||
num_extra_kv_tokens: int = 0
|
||||
max_draft_tokens: int = 0
|
||||
|
||||
prompt_lookup_num_tokens: int = 5
|
||||
max_matching_ngram_size: int = 5
|
||||
end_id: int = -1
|
||||
is_keep_all: bool = True
|
||||
is_use_oldest: bool = True
|
||||
is_public_pool: bool = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
|
||||
self.spec_dec_name)
|
||||
self.max_draft_tokens = self.prompt_lookup_num_tokens
|
||||
|
||||
def update_from_model_config(self, model_config):
|
||||
pass
|
||||
|
||||
|
||||
class NGramPoolManager(BaseResourceManager):
|
||||
"""
|
||||
This class maintains the pattern-matches pairs for NGram drafter.
|
||||
|
||||
For example, one of the existing pairs could be: ["I","love"] -> [["apple", "because", "it", "is"], ["banana", "and"]].
|
||||
|
||||
Here we call ["I","love"] as `pattern`, and [["apple", "because", "it", "is"], ["banana", "and"]] as `matches`.
|
||||
|
||||
`pattern` is a list of token_ids. The pool provides corresponding draft tokens from the matches if the pattern appears at the tail of the sentence during generation.
|
||||
|
||||
`matches` is a list of candidate draft token_ids attaching to a pattern.
|
||||
|
||||
Arguments:
|
||||
prompt_lookup_num_tokens: int
    The maximum length of the draft tokens (i.e., the maximum number of output draft tokens per step).

max_matching_ngram_size: int
    The maximum length of the searched pattern (i.e., the maximum number of tail tokens used as the search key).

is_keep_all: bool = True
    Whether to keep all candidate pattern-match pairs; only one match is kept for each pattern if False.

is_use_oldest: bool = True
    Whether to provide the oldest match when a pattern is hit; the newest one is provided if False.

is_public_pool: bool = True
    Whether to use one common pool for all requests; the pool is private to each request if False.
|
||||
|
||||
Members:
|
||||
pool: dict[tuple[int], OrderedSet[int]] or dict[int, dict[tuple[int], OrderedSet[int]]]
|
||||
If is_public_pool == True, it maps from patterns to matches
|
||||
If is_public_pool == False, it maps from request ID to the request-specific pool
|
||||
|
||||
start_index: dict[int, int]
|
||||
It maps from request ID to the index of the prompt to update the pool in the next step
|
||||
"""
|
||||
|
||||
def __init__(self, config: NGramConfig, max_num_requests: int):
|
||||
|
||||
self.max_num_requests = max_num_requests
|
||||
self.max_num_draft_tokens = config.max_draft_tokens
|
||||
|
||||
self.prompt_lookup_num_tokens = config.prompt_lookup_num_tokens
|
||||
self.max_matching_ngram_size = config.max_matching_ngram_size
|
||||
self.is_keep_all = config.is_keep_all
|
||||
self.is_use_oldest = config.is_use_oldest # TODO: remove this if updating strategy is supported
|
||||
self.is_public_pool = config.is_public_pool
|
||||
self.pool = {}
|
||||
self.start_index = {}
|
||||
|
||||
def prepare_resources(self, scheduled_batch: ScheduledRequests):
|
||||
# Update pool and provide draft tokens for the requests
|
||||
for request in scheduled_batch.generation_requests:
|
||||
num_draft_tokens = 0 if request.py_last_draft_tokens is None else \
|
||||
len(request.py_last_draft_tokens)
|
||||
num_accepted_tokens = getattr(request,
|
||||
"py_num_accepted_draft_tokens", 0)
|
||||
num_rejected_tokens = num_draft_tokens - num_accepted_tokens
|
||||
assert num_rejected_tokens >= 0
|
||||
|
||||
# Generate draft tokens
|
||||
draft_tokens = self._get_draft_tokens(
|
||||
request.get_tokens()[0],
|
||||
request.request_id,
|
||||
request.py_end_id,
|
||||
request.py_orig_prompt_len + request.py_max_new_tokens,
|
||||
)
|
||||
|
||||
# Pad to max_draft_tokens
|
||||
if draft_tokens is not None:
|
||||
pad_length = self.max_num_draft_tokens - len(draft_tokens)
|
||||
draft_tokens.extend([request.py_end_id] * pad_length)
|
||||
request.py_draft_tokens = draft_tokens
|
||||
|
||||
def update_resources(self, scheduled_batch: ScheduledRequests):
|
||||
pass
|
||||
|
||||
def free_resources(self, request: LlmRequest):
|
||||
if self.is_public_pool:
|
||||
return # TODO: need to have a strategy to swap out the pairs
|
||||
request_id = request.request_id
|
||||
if request_id in self.pool:
|
||||
self.pool.pop(request_id)
|
||||
self.start_index.pop(request_id)
|
||||
|
||||
def add_dummy_requests(self, request_ids: List[int]):
|
||||
pass
|
||||
|
||||
def shutdown(self):
|
||||
pass
|
||||
|
||||
def get_max_resource_count(self) -> int:
|
||||
return self.max_num_requests
|
||||
|
||||
def get_needed_resource_to_completion(self, request: LlmRequest):
|
||||
return 0
|
||||
|
||||
def print_pool(self): # For debug
|
||||
if self.is_public_pool:
|
||||
logger.debug(f"Using public pool, size = {len(self.pool)}")
|
||||
self._print_line(self.pool)
|
||||
else:
|
||||
logger.debug(f"Using private pool")
|
||||
for request_id, request_map in self.pool.items():
|
||||
logger.debug(f"Request {request_id}, size={len(request_map)}")
|
||||
self._print_line(request_map, 4)
|
||||
|
||||
def _print_line(self, local_map, indentation=0): # For debug
|
||||
for pattern, matches in local_map.items():
|
||||
output = " " * indentation + str(pattern) + "->"
|
||||
for match in matches:
|
||||
output += str(match) + ", "
|
||||
logger.debug(output)
|
||||
|
||||
def _get_draft_tokens(
|
||||
self,
|
||||
prefix: list[int],
|
||||
request_id: int,
|
||||
end_id: int,
|
||||
max_sequence_length: int,
|
||||
):
|
||||
prefix_len = len(prefix)
|
||||
max_draft_token_length = max_sequence_length - 1 - prefix_len
|
||||
if max_draft_token_length <= 0:  # No room left for draft tokens, skip the search
|
||||
return None
|
||||
|
||||
if request_id not in self.start_index: # A new request
|
||||
self.start_index[request_id] = 0
|
||||
if not self.is_public_pool:
|
||||
assert len(self.pool) + 1 <= self.max_num_requests
|
||||
self.pool[request_id] = {}
|
||||
pool = (self.pool if self.is_public_pool else self.pool[request_id])
|
||||
|
||||
# Update pool
|
||||
sequence = prefix[self.start_index[request_id]:]
|
||||
for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
|
||||
-1):
|
||||
# Find each possible pattern-match combination, and use tuple for hash
|
||||
for l in range(len(sequence) - size):
|
||||
r = min(l + size + self.prompt_lookup_num_tokens, len(sequence))
|
||||
pattern = tuple(sequence[l:l + size])
|
||||
new_match = tuple(sequence[l + size:r])
|
||||
if pattern not in pool or \
|
||||
(not self.is_keep_all and len(match) > pool[pattern][0]):
|
||||
# Replace the match if
|
||||
# 1. the pattern does not exist in the pool
|
||||
# 2. only one match is kept, and the new match is longer (MRU)
|
||||
pool[pattern] = OrderedSet((new_match, ))
|
||||
elif new_match not in pool[pattern]:
|
||||
# Update the matches if the pattern is already existed:
|
||||
# TODO: need a strategy to maintain the short candidates, now we just remove them
|
||||
# Drop all existed matches with small length
|
||||
for match in pool[pattern]:
|
||||
if len(match) < len(new_match):
|
||||
pool[pattern].remove(match)
|
||||
pool[pattern].add(new_match)
|
||||
|
||||
# Find match
|
||||
draft_tokens = [end_id]
|
||||
for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
|
||||
-1):
|
||||
pattern = tuple(prefix[-size:])
|
||||
if pattern not in pool:
|
||||
continue
|
||||
draft_tokens = pool[pattern][0 if self.is_use_oldest else -1]
|
||||
draft_tokens = list(draft_tokens)[:max_draft_token_length]
|
||||
break
|
||||
self.start_index[request_id] = max(
|
||||
0, prefix_len -
|
||||
(self.prompt_lookup_num_tokens + self.max_matching_ngram_size - 1))
|
||||
|
||||
return draft_tokens
|
||||
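To make the pool update and lookup above concrete, here is a minimal standalone sketch of the same idea. It is illustrative only (the function and variable names are not part of the change), it keeps just one match per pattern, and it ignores the per-request start_index bookkeeping the manager uses to avoid re-scanning the whole prefix on every step.

def toy_ngram_draft(prefix, max_ngram=2, draft_len=3):
    pool = {}  # pattern (tuple of tokens) -> most recent following tokens
    # Update: for every n-gram in the prefix, record the tokens that followed it.
    for size in range(max_ngram, 0, -1):
        for l in range(len(prefix) - size):
            r = min(l + size + draft_len, len(prefix))
            pool[tuple(prefix[l:l + size])] = tuple(prefix[l + size:r])
    # Lookup: try the longest n-gram ending the prefix first.
    for size in range(max_ngram, 0, -1):
        match = pool.get(tuple(prefix[-size:]))
        if match:
            return list(match)
    return []

# "A B C D A B": the bigram (A, B) was last followed by (C, D, A), so those
# three tokens are proposed as the draft.
print(toy_ngram_draft(["A", "B", "C", "D", "A", "B"]))  # ['C', 'D', 'A']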
@ -1,5 +1,6 @@
from .eagle3 import Eagle3Sampler, Eagle3SpecMetadata
from .mtp import MTPHiddenStatesManager, MTPSampler, MTPSpecMetadata
from .ngram import NGramPoolManager


def get_spec_metadata(spec_config,
@ -34,6 +35,8 @@ def get_spec_resource_manager(spec_config, model_config, max_num_requests):
        return MTPHiddenStatesManager(spec_config, model_config.torch_dtype,
                                      model_config.hidden_size,
                                      max_num_requests)
    elif spec_config.spec_dec_mode.is_ngram():
        return NGramPoolManager(spec_config, max_num_requests)
    else:
        return None
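For context, a rough sketch of how the returned manager is used; this is assumed wiring rather than part of the change, with the variable names taken from the function signature above.

# Assumed usage: for an ngram speculative config, the executor obtains a
# NGramPoolManager from the dispatcher above and calls prepare_resources on
# each scheduled batch, which fills request.py_draft_tokens (padded with
# end_id) for the generation requests.
resource_manager = get_spec_resource_manager(spec_config, model_config,
                                             max_num_requests)
if resource_manager is not None:
    resource_manager.prepare_resources(scheduled_batch)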
@ -8,7 +8,7 @@ from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig,
                        DynamicBatchConfig, EagleDecodingConfig,
                        ExtendedRuntimePerfKnobConfig, KvCacheConfig,
                        LookaheadDecodingConfig, MedusaDecodingConfig,
                        MTPDecodingConfig, SchedulerConfig)
                        MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig)
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
                        QuantConfig)
from .mpi_session import MpiCommSession
@ -40,4 +40,5 @@ __all__ = [
    'ContextChunkingPolicy',
    'DynamicBatchConfig',
    'CacheTransceiverConfig',
    'NGramDecodingConfig',
]
@ -195,7 +195,8 @@ class DecodingBaseConfig(BaseModel):
            "MTP": MTPDecodingConfig,
            "Medusa": MedusaDecodingConfig,
            "Eagle": EagleDecodingConfig,
            "Lookahead": LookaheadDecodingConfig
            "Lookahead": LookaheadDecodingConfig,
            "NGram": NGramDecodingConfig,
        }

        config_class = config_classes.get(decoding_type)
@ -236,6 +237,40 @@ class EagleDecodingConfig(DecodingBaseConfig):
    decoding_type: ClassVar[str] = "Eagle"


class NGramDecodingConfig(DecodingBaseConfig):
    """
    Configuration for NGram drafter speculative decoding.

    Arguments:
        prompt_lookup_num_tokens: int
            Maximum length of the draft, i.e., the maximum number of draft tokens proposed per step.

        max_matching_ngram_size: int
            Maximum n-gram size used for matching, i.e., the maximum number of trailing prompt tokens used to look up a pattern.

        is_keep_all: bool = True
            Whether to keep all candidate pattern-match pairs; if False, only one match is kept per pattern.

        is_use_oldest: bool = True
            Whether to return the oldest match when a pattern is hit; if False, the newest one is returned.

        is_public_pool: bool = True
            Whether to use one pool shared by all requests; if False, each request gets a private pool.
    """

    prompt_lookup_num_tokens: int = 2
    max_matching_ngram_size: int = 4
    is_keep_all: bool = True
    is_use_oldest: bool = True
    is_public_pool: bool = True

    @classmethod
    def from_dict(cls, data: dict):
        return cls(**data)

    decoding_type: ClassVar[str] = "NGram"


class MTPDecodingConfig(DecodingBaseConfig):
    num_nextn_predict_layers: Optional[int] = 1
    use_relaxed_acceptance_for_thinking: Optional[bool] = False
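A short usage sketch for the new config class; the values are illustrative, and from_dict simply mirrors the keyword constructor.

from tensorrt_llm.llmapi import NGramDecodingConfig

# Direct construction: propose up to 4 draft tokens, matching on n-grams of size <= 2.
spec_config = NGramDecodingConfig(prompt_lookup_num_tokens=4,
                                  max_matching_ngram_size=2)

# Equivalent construction from a plain dict, e.g. loaded from a config file.
spec_config = NGramDecodingConfig.from_dict({
    "prompt_lookup_num_tokens": 4,
    "max_matching_ngram_size": 2,
    "is_public_pool": True,
})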
@ -880,8 +915,8 @@ class LlmArgs(BaseModel):
    # Speculative decoding parameters
    speculative_config: Optional[Union[
        LookaheadDecodingConfig, MedusaDecodingConfig, EagleDecodingConfig,
        MTPDecodingConfig]] = Field(default=None,
                                    description="Speculative decoding config.")
        MTPDecodingConfig, NGramDecodingConfig]] = Field(
            default=None, description="Speculative decoding config.")

    batching_type: Optional[BatchingType] = Field(default=None,
                                                  description="Batching type.")
@ -1209,7 +1244,21 @@ class LlmArgs(BaseModel):
                max_draft_tokens=self.speculative_config.max_draft_len,
                eagle_weights_path=self.speculative_config.
                pytorch_eagle_weights_path)

        elif isinstance(self.speculative_config, NGramDecodingConfig):
            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.NGRAM
            assert self.backend == 'pytorch'
            assert self.speculative_config.prompt_lookup_num_tokens > 0 and self.speculative_config.max_matching_ngram_size > 0
            self.build_config.max_draft_len = self.speculative_config.max_draft_len
            from tensorrt_llm._torch.speculative import NGramConfig
            self.speculative_config = NGramConfig(
                prompt_lookup_num_tokens=self.speculative_config.
                prompt_lookup_num_tokens,
                max_matching_ngram_size=self.speculative_config.
                max_matching_ngram_size,
                is_keep_all=self.speculative_config.is_keep_all,
                is_use_oldest=self.speculative_config.is_use_oldest,
                is_public_pool=self.speculative_config.is_public_pool,
            )
        elif isinstance(self.speculative_config, MTPDecodingConfig):
            from tensorrt_llm._torch.speculative import MTPConfig
            self.speculative_config = MTPConfig(
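To summarize the branch above, the user-facing NGramDecodingConfig is translated field for field into the internal PyTorch-backend NGramConfig; a small hedged sketch of that one-to-one mapping (runnable only with the PyTorch backend installed):

from tensorrt_llm._torch.speculative import NGramConfig
from tensorrt_llm.llmapi import NGramDecodingConfig

user_cfg = NGramDecodingConfig(prompt_lookup_num_tokens=4,
                               max_matching_ngram_size=2)
# Field-for-field translation, as performed during LlmArgs validation:
internal_cfg = NGramConfig(
    prompt_lookup_num_tokens=user_cfg.prompt_lookup_num_tokens,
    max_matching_ngram_size=user_cfg.max_matching_ngram_size,
    is_keep_all=user_cfg.is_keep_all,
    is_use_oldest=user_cfg.is_use_oldest,
    is_public_pool=user_cfg.is_public_pool,
)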
@ -31,8 +31,8 @@ from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
                          get_build_cache_config_from_env)
from .llm_args import (CalibConfig, EagleDecodingConfig, KvCacheConfig, LlmArgs,
                       LookaheadDecodingConfig, MedusaDecodingConfig,
                       MTPDecodingConfig, _ModelFormatKind, _ModelWrapper,
                       _ParallelConfig, get_model_format,
                       MTPDecodingConfig, NGramDecodingConfig, _ModelFormatKind,
                       _ModelWrapper, _ParallelConfig, get_model_format,
                       update_llm_args_with_extra_dict,
                       update_llm_args_with_extra_options)
from .mpi_session import MPINodeState, MpiSession
@ -860,6 +860,7 @@ __all__ = [
    'LookaheadDecodingConfig',
    'MedusaDecodingConfig',
    'MTPDecodingConfig',
    'NGramDecodingConfig',
    'ContextChunkingPolicy',
    'CapacitySchedulerPolicy',
    'BuildConfig',
@ -96,6 +96,7 @@ class SpeculativeDecodingMode(IntFlag):
    LOOKAHEAD_DECODING = auto()
    EXPLICIT_DRAFT_TOKENS = auto()
    EAGLE = auto()
    NGRAM = auto()

    @staticmethod
    def from_arguments(args: argparse.Namespace):
@ -111,6 +112,8 @@ class SpeculativeDecodingMode(IntFlag):
            return SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS
        elif args.speculative_decoding_mode == "eagle":
            return SpeculativeDecodingMode.EAGLE
        elif args.speculative_decoding_mode == "ngram":
            return SpeculativeDecodingMode.NGRAM
        else:
            assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode
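A minimal sketch of how the new enum value is selected from command-line style arguments; it assumes SpeculativeDecodingMode has been imported from its defining module.

import argparse

# e.g. a script exposing --speculative_decoding_mode ngram
args = argparse.Namespace(speculative_decoding_mode="ngram")
assert SpeculativeDecodingMode.from_arguments(args) == SpeculativeDecodingMode.NGRAM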
tests/unittest/_torch/speculative/test_ngram.py (new file)
@ -0,0 +1,83 @@
import os
import sys
import unittest

import pytest
import torch

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig, NGramDecodingConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root


# TODO: Add cuda graph enabled tests.
# Cuda graph cannot currently be enabled for ngram because cuda graph requires
# spec metadata and ngram does not have it.
@pytest.mark.parametrize("use_cuda_graph,attn_backend",
                         [[False, "TRTLLM"], [False, "FLASHINFER"]])
def test_llama_ngram(use_cuda_graph: bool, attn_backend: str):
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 31:
        pytest.skip("Not enough memory to load target model")

    models_path = llm_models_root()

    pytorch_config = PyTorchConfig(
        enable_overlap_scheduler=False,
        use_cuda_graph=use_cuda_graph,
        attn_backend=attn_backend,
        # Only create a single CUDA graph to prevent OOM in CI
        cuda_graph_batch_sizes=[1],
    )

    kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=2080)

    sampling_params = SamplingParams(
        max_tokens=32,
        temperature=0,
    )
    max_batch_size = 1

    target_model_dir = f"{models_path}/llama-models-v2/llama-v2-13b-hf"

    draft_len = 4
    spec_config = NGramDecodingConfig(
        prompt_lookup_num_tokens=draft_len,
        max_matching_ngram_size=draft_len,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )
    llm_spec = LLM(model=target_model_dir,
                   max_batch_size=max_batch_size,
                   pytorch_backend_config=pytorch_config,
                   kv_cache_config=kv_cache_config,
                   speculative_config=spec_config)

    prompts = [
        "The capital of France is", "The president of the United States is"
    ]
    results_spec = llm_spec.generate(prompts, sampling_params)
    generated_text_spec = [result.outputs[0].text for result in results_spec]
    llm_spec.shutdown()

    llm_ref = LLM(model=target_model_dir,
                  max_batch_size=max_batch_size,
                  pytorch_backend_config=pytorch_config,
                  kv_cache_config=kv_cache_config)

    results_ref = llm_ref.generate(prompts, sampling_params)
    generated_text_ref = [result.outputs[0].text for result in results_ref]
    llm_ref.shutdown()

    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
        # The spec decode algorithm currently guarantees identical results
        assert text_spec == text_ref


if __name__ == "__main__":
    unittest.main()
@ -74,7 +74,7 @@ methods:
  # Speculative decoding
  speculative_config:
    annotation: Union[tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_utils.MedusaDecodingConfig,
      tensorrt_llm.llmapi.llm_utils.EagleDecodingConfig, tensorrt_llm.llmapi.MTPDecodingConfig, NoneType]
      tensorrt_llm.llmapi.llm_utils.EagleDecodingConfig, tensorrt_llm.llmapi.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, NoneType]
    default: null
  # generation constraints
  max_batch_size: