[refactor] Simplification of Speculative decoding configs (#5639)

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
Co-authored-by: wili-65535 <wili-65535@users.noreply.github.com>
wili 2025-07-10 23:37:30 +08:00 committed by GitHub
parent 67a39dbd63
commit 2e3cf42e03
44 changed files with 390 additions and 530 deletions

View File

@ -68,7 +68,7 @@ docker run -d --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
-p 8000:8000 --gpus=all -e "TRTLLM_ENABLE_PDL=1" \
-v /path/to/maverick:/config/models/maverick -v /path/to/eagle:/config/models/eagle \
docker.io/<username>/tensorrt_llm:main sh \
-c "echo -e 'enable_attention_dp: false\nenable_min_latency: true\nenable_autotuner: false\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle\n max_draft_len: 3\n pytorch_weights_path: /config/models/eagle\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
-c "echo -e 'enable_attention_dp: false\nenable_min_latency: true\nenable_autotuner: false\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle\n max_draft_len: 3\n speculative_model_dir: /config/models/eagle\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL=True \
trtllm-serve /config/models/maverick \
--host 0.0.0.0 --port 8000 \
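For reference, a minimal sketch of how the speculative_config block in c.yaml above maps onto the renamed LLM-API fields. The import path, model paths, and the surrounding LLM(...) call are assumptions based on the examples changed later in this commit; only the speculative part of the YAML is mirrored here.

from tensorrt_llm.llmapi import LLM, EagleDecodingConfig

# Equivalent of the speculative_config block written to c.yaml above.
spec_config = EagleDecodingConfig(
    max_draft_len=3,                               # max_draft_len: 3
    speculative_model_dir="/config/models/eagle",  # renamed from pytorch_weights_path
)
llm = LLM(model="/config/models/maverick", speculative_config=spec_config)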

View File

@ -108,7 +108,7 @@ When using EAGLE-2, please enable `--eagle_use_dynamic_tree`, which indicates wh
- In EagleNet2, each of the `N` output nodes of EagleNet1 is expanded into `N` new draft tokens, so this layer also produces `N * N` draft tokens in total, and the top `N` are selected as its output.
- Subsequent EagleNets repeat this expansion.
Finally, after `num_eagle_layer` EagleNets, `N + N * N * (num_eagle_layer - 1)` draft tokens are generated. We will rebuild the final tree based on all draft tokens and their scores. The final generated tree will have `min(N + N * N * (num_eagle_layer - 1), max_draft_tokens)` nodes.
Finally, after `num_eagle_layer` EagleNets, `N + N * N * (num_eagle_layer - 1)` draft tokens are generated. We will rebuild the final tree based on all draft tokens and their scores. The final generated tree will have `min(N + N * N * (num_eagle_layer - 1), max_draft_len)` nodes.
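A small sketch of the draft-token accounting described above; the numbers are examples only, not defaults.

def eagle2_tree_sizes(N: int, num_eagle_layer: int, max_draft_len: int):
    # EagleNet0 proposes N tokens; each later EagleNet expands N nodes by N
    # tokens each, contributing N * N candidates per layer.
    total_candidates = N + N * N * (num_eagle_layer - 1)
    # The rebuilt tree keeps at most max_draft_len nodes.
    final_tree_nodes = min(total_candidates, max_draft_len)
    return total_candidates, final_tree_nodes

# Example: N=10, num_eagle_layer=4, max_draft_len=63
print(eagle2_tree_sizes(10, 4, 63))  # (310, 63)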

View File

@ -23,12 +23,12 @@ def main():
model = "lmsys/vicuna-7b-v1.3"
# The end user can customize the eagle decoding configuration by specifying the
# speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
# speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
# greedy_sampling, posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
# with the EagleDecodingConfig class
speculative_config = EagleDecodingConfig(
speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
max_draft_len=63,
num_eagle_layers=4,
max_non_leaves_per_layer=10,

View File

@ -23,12 +23,12 @@ def main():
model = "lmsys/vicuna-7b-v1.3"
# The end user can customize the eagle decoding configuration by specifying the
# speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
# speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
# greedy_sampling, posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
# with the EagleDecodingConfig class
speculative_config = EagleDecodingConfig(
speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
max_draft_len=63,
num_eagle_layers=4,
max_non_leaves_per_layer=10,

View File

@ -48,10 +48,10 @@ def run_medusa_decoding(use_modelopt_ckpt=False, model_dir=None):
model = "lmsys/vicuna-7b-v1.3"
# The end user can customize the medusa decoding configuration by specifying the
# speculative_model, max_draft_len, medusa heads num and medusa choices
# speculative_model_dir, max_draft_len, medusa heads num and medusa choices
# with the MedusaDecodingConfig class
speculative_config = MedusaDecodingConfig(
speculative_model="FasterDecoding/medusa-vicuna-7b-v1.3",
speculative_model_dir="FasterDecoding/medusa-vicuna-7b-v1.3",
max_draft_len=63,
num_medusa_heads=4,
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \

View File

@ -35,7 +35,7 @@ def run_MTP(model: Optional[str] = None):
def run_Eagle3():
spec_config = EagleDecodingConfig(
max_draft_len=3,
pytorch_weights_path="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
eagle3_one_model=True)
llm = LLM(
@ -50,7 +50,7 @@ def run_Eagle3():
def run_ngram():
spec_config = NGramDecodingConfig(
prompt_lookup_num_tokens=3,
max_draft_len=3,
max_matching_ngram_size=3,
is_keep_all=True,
is_use_oldest=True,

View File

@ -169,15 +169,15 @@ def setup_llm(args):
elif spec_decode_algo == "EAGLE3":
spec_config = EagleDecodingConfig(
max_draft_len=args.spec_decode_nextn,
pytorch_weights_path=args.draft_model_dir,
speculative_model_dir=args.draft_model_dir,
eagle3_one_model=args.use_one_model)
elif spec_decode_algo == "DRAFT_TARGET":
spec_config = DraftTargetDecodingConfig(
max_draft_len=args.spec_decode_nextn,
pytorch_weights_path=args.draft_model_dir)
speculative_model_dir=args.draft_model_dir)
elif spec_decode_algo == "NGRAM":
spec_config = NGramDecodingConfig(
prompt_lookup_num_tokens=args.spec_decode_nextn,
max_draft_len=args.spec_decode_nextn,
max_matching_ngram_size=args.max_matching_ngram_size,
is_keep_all=True,
is_use_oldest=True,

View File

@ -261,8 +261,8 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
# some derivative properties
max_draft_tokens = (
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens
max_draft_len = (
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len
)
# initialize model engine
@ -299,7 +299,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
# correctly for models as needed.
sampler_args = TorchSampler.Args(
max_seq_len=ad_config.max_seq_len,
max_draft_tokens=max_draft_tokens,
max_draft_len=max_draft_len,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
enable_mixed_sampler=ad_config.enable_mixed_sampler,
@ -317,7 +317,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
max_input_len=ad_config.max_input_len,
max_batch_size=ad_config.max_batch_size,
max_draft_tokens=max_draft_tokens,
max_draft_len=max_draft_len,
max_beam_width=ad_config.max_beam_width,
)
return py_executor

View File

@ -70,7 +70,7 @@ class ModelConfig(Generic[TConfig]):
# to support mixed quantization.
skip_create_weights_in_init: bool = False
spec_config: Optional["SpecConfig"] = None
spec_config: Optional["DecodingBaseConfig"] = None
lora_config: Optional["LoraConfig"] = None
is_generation: bool = True

View File

@ -340,7 +340,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
model_config, 'spec_config', None
) and model_config.spec_config.spec_dec_mode.use_one_engine():
draft_config = ModelConfig.from_pretrained(
model_config.spec_config.draft_model_path,
model_config.spec_config.speculative_model_dir,
trust_remote_code=True,
attn_backend=model_config.attn_backend,
moe_backend=model_config.moe_backend,

View File

@ -157,10 +157,10 @@ class KvCacheCreator:
if not pytorch_backend_config.disable_overlap_scheduler:
num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
if spec_cfg is not None:
num_extra_tokens_per_seq += spec_cfg.max_draft_tokens
num_extra_tokens_per_seq += spec_cfg.max_draft_len
if spec_cfg is not None:
num_extra_tokens_per_seq += spec_cfg.max_draft_tokens
num_extra_tokens_per_seq += spec_cfg.max_draft_len
num_extra_tokens_per_seq += spec_cfg.num_extra_kv_tokens
for req in self._dummy_reqs:
num_req_tokens = len(req.input_token_ids) + num_extra_tokens_per_seq
@ -538,7 +538,7 @@ def create_py_executor_instance(
disable_overlap_scheduler,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_draft_tokens=spec_config.max_draft_tokens
max_draft_len=spec_config.max_draft_len
if spec_config is not None else 0,
kv_cache_transceiver=kv_cache_transceiver,
draft_model_engine=draft_model_engine,
@ -549,11 +549,11 @@ def create_py_executor_instance(
def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
*, max_seq_len: int, enable_mixed_sampler: bool):
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
max_draft_tokens = (0 if executor_config.speculative_config is None else
executor_config.speculative_config.max_draft_tokens)
max_draft_len = (0 if executor_config.speculative_config is None else
executor_config.speculative_config.max_draft_len)
return TorchSampler.Args(
max_seq_len=max_seq_len,
max_draft_tokens=max_draft_tokens,
max_draft_len=max_draft_len,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
enable_mixed_sampler=enable_mixed_sampler,

View File

@ -8,7 +8,6 @@ from ...llmapi.llm_args import LoadFormat
from ...logger import logger
from ...mapping import Mapping
from ..model_config import MoeLoadBalancerConfig
from ..speculative import SpecConfig
from .resource_manager import BaseResourceManager
@ -110,7 +109,7 @@ def update_executor_config(
pytorch_backend_config: Optional[PyTorchConfig] = None,
mapping: Optional[Mapping] = None,
build_config: Optional[BuildConfig] = None,
speculative_config: Optional[SpecConfig] = None,
speculative_config: Optional["DecodingBaseConfig"] = None,
hf_model_dir: Optional[str] = None,
max_input_len: Optional[int] = None,
max_seq_len: Optional[int] = None):

View File

@ -53,7 +53,7 @@ class DecodingCUDAGraphRunner:
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = spec_metadata.max_draft_tokens + 1 if spec_metadata is not None else 1
token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
# Using ones instead of zeros prevents NaNs in e.g. Deepseek
self.input_ids = torch.ones((batch_size * token_per_request, ),

View File

@ -51,7 +51,7 @@ from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode,
run_concurrently, timing)
from ..modules.fused_moe.moe_load_balancer import (
MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer)
from ..speculative import SpecConfig, SpecMetadata, get_spec_metadata
from ..speculative import SpecMetadata, get_spec_metadata
from ..utils import (get_model_extra_attrs, set_torch_compiling,
with_model_extra_attrs)
from .config import LoadFormat, PyTorchConfig
@ -353,7 +353,7 @@ class PyTorchModelEngine(ModelEngine):
mapping: Optional[Mapping] = None,
attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
dist: Optional[MPIDist] = None,
spec_config: Optional[SpecConfig] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
lora_config: Optional[LoraConfig] = None,
is_draft_model: bool = False,
@ -456,7 +456,7 @@ class PyTorchModelEngine(ModelEngine):
if self.is_spec_decode:
self.spec_metadata = None
self.spec_config.update_from_model_config(self.model.config)
max_num_draft_tokens = self.spec_config.max_draft_tokens * batch_size
max_num_draft_tokens = self.spec_config.max_draft_len * batch_size
self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
dtype=torch.int,
device='cuda')
@ -472,7 +472,7 @@ class PyTorchModelEngine(ModelEngine):
device='cuda')
self.without_logits = self.spec_config.spec_dec_mode.without_logits(
)
self.max_draft_len = spec_config.max_draft_tokens
self.max_draft_len = spec_config.max_draft_len
else:
self.without_logits = False
self.max_draft_len = 0
@ -858,6 +858,7 @@ class PyTorchModelEngine(ModelEngine):
if no_cache:
return get_spec_metadata(
self.spec_config,
self.model.config,
self.batch_size,
max_num_tokens=self.max_num_tokens,
spec_resource_manager=spec_resource_manager,
@ -867,6 +868,7 @@ class PyTorchModelEngine(ModelEngine):
return self.spec_metadata
self.spec_metadata = get_spec_metadata(
self.spec_config,
self.model.config,
self.batch_size,
max_num_tokens=self.max_num_tokens,
spec_resource_manager=spec_resource_manager,
@ -951,7 +953,7 @@ class PyTorchModelEngine(ModelEngine):
def _maybe_get_cuda_graph(
self,
batch: ScheduledRequests,
spec_config: Optional[SpecConfig] = None
spec_config: Optional["DecodingBaseConfig"] = None
) -> Optional[DecodingCUDAGraphRunner]:
"""
Get a CUDA graph runner or return None (e.g. if CUDA graphs are disabled
@ -961,7 +963,7 @@ class PyTorchModelEngine(ModelEngine):
if ExpertStatistic.set_iter(self.iter_counter):
return None
spec_max_draft_tokens = spec_config.max_draft_tokens if self.is_spec_decode else 0
spec_max_draft_tokens = spec_config.max_draft_len if self.is_spec_decode else 0
can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = len(batch.generation_requests)
if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
@ -1078,7 +1080,8 @@ class PyTorchModelEngine(ModelEngine):
if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
):
weights = load_weights(self.spec_config.draft_model_path)
weights = load_weights(
self.spec_config.speculative_model_dir)
model.load_draft_weights(weights)
elif load_format == LoadFormat.DUMMY:
@ -1261,9 +1264,8 @@ class PyTorchModelEngine(ModelEngine):
extend_requests += extend_dummy_requests
if not self._disable_overlap_scheduler and self.is_spec_decode:
spec_dec_mode = self.spec_config.spec_dec_mode
assert spec_dec_mode.support_overlap_scheduler(
), f"{self.spec_config.spec_dec_name} does not support overlap scheduler"
assert self.spec_config.spec_dec_mode.support_overlap_scheduler(
), f"{self.spec_config.decoding_type} does not support overlap scheduler"
# will contain previous batch indices of generation requests
previous_batch_indices = []
@ -2078,7 +2080,7 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(),
spec_metadata.is_spec_dec_tree,
spec_metadata.is_spec_dec_dynamic_tree,
spec_metadata.max_draft_tokens)
spec_metadata.max_draft_len)
else:
spec_metadata = None

View File

@ -173,7 +173,7 @@ class PyExecutor:
max_input_len: int = 2048,
max_batch_size: int = 8,
max_beam_width: int = 1,
max_draft_tokens: int = 0,
max_draft_len: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
draft_model_engine: Optional[ModelEngine] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
@ -207,7 +207,7 @@ class PyExecutor:
self.active = True
self.next_req_id = max_batch_size # The first max_batch_size request IDs are reserved for dummy requests
self.max_beam_width = max_beam_width
self.max_draft_tokens = max_draft_tokens
self.max_draft_len = max_draft_len
self.print_log = model_engine.pytorch_backend_config.print_iter_log
self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
@ -979,7 +979,7 @@ class PyExecutor:
LlmRequestState.DISAGG_GENERATION_INIT):
continue
req.py_last_draft_tokens = req.py_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_len
if max_draft_len > 0:
req.py_draft_tokens = [0] * max_draft_len
@ -1523,7 +1523,7 @@ class PyExecutor:
request_ids=[0],
is_gen=not self.has_context_request,
prepare_resource=not self.has_context_request,
max_num_draft_tokens=self.max_draft_tokens,
max_num_draft_tokens=self.max_draft_len,
)[0]
llm_request.is_attention_dp_dummy = True
spec_resource_manager = self.resource_manager.get_resource_manager(
@ -1871,15 +1871,15 @@ class PyExecutor:
# this? Just needs proper kernel support.
def _pad_to_max_draft_tokens():
for req in scheduled_requests.generation_requests:
max_draft_tokens = self.max_draft_tokens
max_draft_len = self.max_draft_len
num_draft_tokens = len(req.py_draft_tokens)
req.py_draft_tokens.extend(
0 for _ in range(max_draft_tokens - num_draft_tokens))
0 for _ in range(max_draft_len - num_draft_tokens))
draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests
draft_batch.context_requests = []
for i in range(self.max_draft_tokens - 1):
for i in range(self.max_draft_len - 1):
if len(draft_batch.generation_requests) == 0:
break

View File

@ -247,10 +247,10 @@ def create_py_executor(
draft_spec_config = copy.copy(spec_config)
# The draft model won't have any draft tokens attached to
# generation requests when we invoke it autoregressively
draft_spec_config.max_draft_tokens = 0
draft_spec_config.max_draft_len = 0
draft_model_engine = PyTorchModelEngine(
model_path=spec_config.draft_model_path,
model_path=spec_config.speculative_model_dir,
pytorch_backend_config=pytorch_backend_config,
batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
@ -276,11 +276,11 @@ def create_py_executor(
if not pytorch_backend_config.disable_overlap_scheduler:
max_seq_len = model_engine.max_seq_len + 1
if spec_config is not None:
max_seq_len += spec_config.max_draft_tokens
max_seq_len += spec_config.max_draft_len
if spec_config is not None:
max_seq_len += spec_config.num_extra_kv_tokens
max_seq_len += spec_config.max_draft_tokens
max_seq_len += spec_config.max_draft_len
executor_config.max_seq_len = max_seq_len
executor_config.max_num_tokens = model_engine.max_num_tokens
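Illustrative arithmetic for the max_seq_len bookkeeping above, with hypothetical values; note that max_draft_len is added once inside the overlap-scheduler branch and once again in the unconditional speculative-decoding branch.

# Hypothetical numbers: base max_seq_len=4096, max_draft_len=3,
# num_extra_kv_tokens=2, overlap scheduler enabled.
max_seq_len = 4096
max_seq_len += 1 + 3   # overlap scheduler: one extra token plus max_draft_len
max_seq_len += 2 + 3   # speculation: num_extra_kv_tokens plus max_draft_len
assert max_seq_len == 4105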

View File

@ -2,7 +2,7 @@ import enum
import math
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
@ -22,9 +22,6 @@ if ENABLE_MULTI_DEVICE:
from tensorrt_llm._utils import mpi_comm
if TYPE_CHECKING:
from ..speculative.interface import SpecConfig
BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager
KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig
@ -83,7 +80,7 @@ class BaseResourceManager(ABC):
def get_pp_layers(
num_layers: int,
mapping: Mapping,
spec_config: Optional["SpecConfig"] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
layer_mask: Optional[List[bool]] = None,
) -> Tuple[List[int], int]:
from ..speculative.utils import get_num_spec_layers
@ -127,7 +124,7 @@ class KVCacheManager(BaseResourceManager):
max_batch_size: int,
mapping: Mapping,
dtype: DataType = DataType.HALF,
spec_config: Optional["SpecConfig"] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
layer_mask: Optional[List[bool]] = None,
max_num_tokens: int = 8192,
model_config: Optional[ModelConfig] = None,
@ -902,7 +899,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
max_batch_size: int,
mapping: Mapping,
dtype: DataType = DataType.HALF,
spec_config: Optional["SpecConfig"] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
) -> None:
# mamba hybrid cache requires block reuse to be disabled in KV cache config

View File

@ -220,7 +220,7 @@ class TorchSampler(Sampler):
@dataclass(frozen=True, kw_only=True)
class Args:
max_seq_len: int
max_draft_tokens: int
max_draft_len: int
max_num_sequences: int
max_beam_width: int
enable_mixed_sampler: bool
@ -228,7 +228,7 @@ class TorchSampler(Sampler):
def __init__(self, args: Args):
self.max_seq_len = args.max_seq_len
self.enable_mixed_sampler = args.enable_mixed_sampler
self.max_tokens = args.max_draft_tokens + 1
self.max_tokens = args.max_draft_len + 1
assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
self.num_seq_slots = args.max_num_sequences

View File

@ -1,27 +1,20 @@
from .draft_target import DraftTargetConfig
from .eagle3 import Eagle3Config, Eagle3SpecMetadata
from .eagle3 import Eagle3SpecMetadata
from .interface import SpecConfig, SpecMetadata
from .mtp import MTPConfig, MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramConfig, NGramDrafter, NGramPoolManager
from .user_provided import UserProvidedConfig
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_drafter,
get_spec_metadata, get_spec_resource_manager,
get_spec_worker)
__all__ = [
"DraftTargetConfig",
"Eagle3Config",
"Eagle3SpecMetadata",
"MTPConfig",
"MTPEagleWorker",
"MTPSpecMetadata",
"MTPWorker",
"NGramConfig",
"NGramDrafter",
"NGramPoolManager",
"SpecConfig",
"SpecMetadata",
"UserProvidedConfig",
"get_num_spec_layers",
"get_spec_decoder",
"get_spec_drafter",

View File

@ -1,35 +0,0 @@
from dataclasses import dataclass
import torch
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
@dataclass
class DraftTargetConfig(SpecConfig):
spec_dec_name: str = "DRAFT_TARGET"
def __post_init__(self):
if self.draft_model_path is None:
raise ValueError("Path to Draft weights must be specified.")
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
self.num_extra_kv_tokens = 0
def update_from_model_config(self, model_config):
pass
def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
return input_tokens
@dataclass
class DraftTargetSpecMetadata(SpecMetadata):
def __post_init__(self):
pass
def prepare(self):
pass

View File

@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
import torch
from torch import nn
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ..attention_backend import AttentionMetadata
@ -12,42 +11,10 @@ from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import TorchSampler
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
from .interface import SpecMetadata
from .mtp import MTPSampler
@dataclass
class Eagle3Config(SpecConfig):
spec_dec_name: str = "EAGLE3"
num_layers: int = 0
hidden_size: int = 0
eagle3_one_model: bool = True
def __post_init__(self):
if self.draft_model_path is None:
raise ValueError("Path to EAGLE3 weights must be specified.")
if self.eagle3_one_model:
self.spec_dec_mode = SpeculativeDecodingMode.EAGLE3_ONE_MODEL
self.num_extra_kv_tokens = self.max_draft_tokens - 1
else:
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
logger.info(f"EAGLE3 Config: {self}")
def update_from_model_config(self, model_config):
self.num_layers = model_config.num_hidden_layers
self.hidden_size = model_config.hidden_size
self.dtype = model_config.torch_dtype
def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
"""
Eagle3 always throws away the first token when processing draft inputs
"""
return input_tokens[1:]
class Eagle3ResourceManager(BaseResourceManager):
"""
Eagle3 needs to save the hidden states for the draft model. When using
@ -55,11 +22,11 @@ class Eagle3ResourceManager(BaseResourceManager):
and one for the draft model. Use this class to manage the hidden states.
"""
def __init__(self, config: Eagle3Config, dtype: torch.dtype,
def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
hidden_size: int, max_num_requests: int, max_seq_len: int,
max_num_tokens: int):
self.dtype = dtype
self.max_draft_tokens = config.max_draft_tokens
self.max_draft_len = config.max_draft_len
self.hidden_size = hidden_size
self.max_num_requests = max_num_requests
self.max_seq_len = max_seq_len
@ -268,7 +235,7 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
pin_memory=True)
self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
non_blocking=True)
self.num_tokens -= (self.num_generations) * self.max_draft_tokens
self.num_tokens -= (self.num_generations) * self.max_draft_len
def maybe_capture_hidden_states(
self,
@ -288,15 +255,15 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
class Eagle3OneModelSampler(MTPSampler):
def __init__(self, args: TorchSampler.Args):
super().__init__(args, nextn=args.max_draft_tokens)
super().__init__(args, nextn=args.max_draft_len)
class Eagle3OneModelWorker(nn.Module):
def __init__(self, spec_config: Eagle3Config, mapping: Mapping):
def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
super().__init__()
self.spec_config = spec_config
self.max_draft_tokens = self.spec_config.max_draft_tokens
self.max_draft_len = self.spec_config.max_draft_len
self.mapping = mapping
@torch.compile(mode="max-autotune-no-cudagraphs")
@ -333,7 +300,7 @@ class Eagle3OneModelWorker(nn.Module):
# Predict draft tokens
next_draft_tokens = []
for i in range(self.max_draft_tokens):
for i in range(self.max_draft_len):
hidden_states, hidden_states_to_save = draft_model.model(**inputs)
# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step.
@ -344,7 +311,7 @@ class Eagle3OneModelWorker(nn.Module):
attn_metadata.use_spec_decoding = False
if i == 0:
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_tokens + 1)).long()
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
@ -374,8 +341,7 @@ class Eagle3OneModelWorker(nn.Module):
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_tokens -
num_accepted_tokens[num_contexts:])
self.max_draft_len - num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
@ -428,7 +394,7 @@ class Eagle3OneModelWorker(nn.Module):
logits = logits.unsqueeze(0)
# The return buffer
accepted_tokens = torch.empty((batch_size, (self.max_draft_tokens + 1)),
accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
@ -442,13 +408,13 @@ class Eagle3OneModelWorker(nn.Module):
# generation
gen_target_tokens = target_tokens[num_contexts:].reshape(
num_gens, self.max_draft_tokens + 1)
num_gens, self.max_draft_len + 1)
accepted_tokens[num_contexts:, :] = gen_target_tokens
draft_tokens = spec_metadata.draft_tokens.reshape(
num_gens, self.max_draft_tokens)
num_accepted_tokens[num_contexts:] += torch.cumprod((
draft_tokens == gen_target_tokens[:, :self.max_draft_tokens]).int(),
dim=-1).sum(1)
num_gens, self.max_draft_len)
num_accepted_tokens[num_contexts:] += torch.cumprod(
(draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(),
dim=-1).sum(1)
return accepted_tokens, num_accepted_tokens
def draft_decoder(
@ -468,7 +434,7 @@ class Eagle3OneModelWorker(nn.Module):
Returns:
draft_tokens: torch.Tensor
[batch_size * max_draft_tokens]
[batch_size * max_draft_len]
Draft token ids. Flattened.
'''
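The cumprod/sum expression above counts how many leading draft tokens match the target samples. A toy sketch of that acceptance rule, with made-up token ids:

import torch

draft_tokens  = torch.tensor([[5, 7, 9]])      # proposed draft tokens (max_draft_len = 3)
target_tokens = torch.tensor([[5, 7, 2, 4]])   # max_draft_len + 1 target samples

matches = (draft_tokens == target_tokens[:, :3]).int()   # [[1, 1, 0]]
accepted = torch.cumprod(matches, dim=-1).sum(1)         # first mismatch zeroes the rest
print(1 + accepted)   # tensor([3]): the bonus token plus 2 accepted draft tokens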

View File

@ -7,7 +7,6 @@ import torch
from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
from ..model_config import TConfig
class SpeculativeDecodingMode(IntEnum):
@ -110,32 +109,11 @@ class SpeculativeDecodingMode(IntEnum):
class SpecConfig:
"""
Configuration for speculative decoding.
This class is deprecated, but removing it currently triggers flaky pytest thread-leak errors.
TODO: remove this class safely.
"""
# The name of speculative decoding.
spec_dec_name = None
# The mode of speculative decoding.
spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
# The max number of draft tokens
max_draft_tokens: int = 1024
# The path to the draft model
draft_model_path: Optional[str] = None
# The number of extra kv tokens
num_extra_kv_tokens: int = 0
def __post_init__(self) -> None:
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
def update_from_model_config(self, model_config: TConfig):
pass
def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
"""
Override for spec dec modes that need to preprocess prompt
tokens before passing them to the draft model.
"""
return input_tokens
@dataclass
@ -146,7 +124,7 @@ class SpecMetadata:
# The max number of requests in a single batch.
max_num_requests: int
# The max number of draft tokens.
max_draft_tokens: int
max_draft_len: int
# The number of gen-phase sequences in the batch.
num_generations: int = 0
# Whether CUDA graph is enabled.
@ -180,7 +158,7 @@ class SpecMetadata:
# Some speculative decoding methods need to use different kv lengths for the
# draft/target layers. But KVCacheManager can only support kv caches with the
# same kv lengths for different layers. Add extra kv token in kv cache manager
# to haddle this issue.
# to handle this issue.
num_extra_kv_tokens: Optional[int] = 0
# The number of layers in the target model
num_layers: int = 0

View File

@ -10,7 +10,7 @@ from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
add_token, int_tensor)
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
from .interface import SpecMetadata
@dataclass(kw_only=True)
@ -25,52 +25,10 @@ class SampleStateMTP(SampleState):
host: SampleStateTensorsMTP
@dataclass
class MTPConfig(SpecConfig):
"""
Configuration for MTP.
"""
# The name of speculative decoding.
spec_dec_name = "MTP"
# The number of MTP modules
num_nextn_predict_layers: int = 1
# The number of max batch size
max_batch_size: int = 8
# Whether to use relaxed acceptance during thinking phase for reasoning model
use_relaxed_acceptance_for_thinking: bool = False
# The top-N tokens are sampled from logits to obtain a candidate set.
relaxed_topk: int = 1
# The threshold to further filter the candidate set.
# Filter out tokens with a large probability gap between the top-1 token's log probability.
relaxed_delta: float = 0.
# Whether to use vanilla MTP
use_mtp_vanilla: bool = False
# TODO: Hard code for DeepSeek R1
# When encounter <think>, start thinking phase.
# When encounter </think>, end thinking phase.
# <think> [thinking phase] </think> [real output]
BEGIN_THINKING_PHASE_TOKEN: int = 128798
END_THINKING_PHASE_TOKEN: int = 128799
def __post_init__(self) -> None:
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
self.max_draft_tokens = self.num_nextn_predict_layers
def update_from_model_config(self, model_config):
assert self.num_nextn_predict_layers > 0
if model_config.num_nextn_predict_layers == 1 and not self.use_mtp_vanilla:
self.spec_dec_mode = SpeculativeDecodingMode.MTP_EAGLE
self.num_extra_kv_tokens = self.num_nextn_predict_layers - 1
class MTPHiddenStatesManager(BaseResourceManager):
def __init__(self, config: MTPConfig, dtype: torch.dtype, hidden_size: int,
max_num_requests: int):
def __init__(self, config: "MTPDecodingConfig", dtype: torch.dtype,
hidden_size: int, max_num_requests: int):
self.dtype = dtype
self.num_nextn_predict_layers = config.num_nextn_predict_layers
self.hidden_size = hidden_size
@ -199,8 +157,8 @@ class MTPSpecMetadata(SpecMetadata):
pin_memory=True)
self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
non_blocking=True)
# MTP vanilla worker uses total max_draft_tokens input tokens in generation phase,
# while MTP Eagle worker uses (max_draft_tokens + 1) input tokens in the 1st draft
# MTP vanilla worker uses total max_draft_len input tokens in generation phase,
# while MTP Eagle worker uses (max_draft_len + 1) input tokens in the 1st draft
# forward and only one input token in the following draft forward.
# This num_tokens is used to set the all_rank_num_tokens for attention dp.
if not self.spec_dec_mode.is_mtp_eagle():
@ -353,7 +311,7 @@ class MTPSampler(TorchSampler):
class MTPWorker(nn.Module):
def __init__(self, spec_config: MTPConfig):
def __init__(self, spec_config: "MTPDecodingConfig"):
super().__init__()
self.spec_config = spec_config
self.is_thop = False
@ -742,7 +700,7 @@ class MTPWorker(nn.Module):
Returns:
accepted_tokens: torch.Tensor
[batch_size, (max_draft_tokens + 1)]
[batch_size, (max_draft_len + 1)]
Accepted token ids. Flattened.
num_accepted_tokens: torch.Tensor
@ -942,7 +900,7 @@ class MTPWorker(nn.Module):
Target model's hidden states.
accepted_tokens: torch.Tensor
[batch_size, max_draft_tokens + 1]
[batch_size, max_draft_len + 1]
Accepted token ids. Flattened.
num_accepted_tokens: torch.Tensor
@ -1080,7 +1038,7 @@ class MTPWorker(nn.Module):
Returns:
draft_tokens: torch.Tensor
[batch_size * max_draft_tokens]
[batch_size * max_draft_len]
Draft token ids. Flattened.
'''
@ -1090,7 +1048,7 @@ class MTPWorker(nn.Module):
class MTPEagleWorker(MTPWorker):
def __init__(self, spec_config: MTPConfig):
def __init__(self, spec_config: "MTPDecodingConfig"):
super().__init__(spec_config)
self.mtp_num_modules = spec_config.num_nextn_predict_layers

View File

@ -1,4 +1,3 @@
from dataclasses import dataclass
from itertools import chain
from ordered_set import OrderedSet
@ -9,34 +8,6 @@ from ..pyexecutor.llm_request import *
from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .drafter import Drafter
from .interface import SpecConfig, SpeculativeDecodingMode
@dataclass
class NGramConfig(SpecConfig):
"""
Configuration for NGram drafter.
"""
# The name of speculative decoding.
spec_dec_name = "NGRAM"
num_extra_kv_tokens: int = 0
max_draft_tokens: int = 0
prompt_lookup_num_tokens: int = 5
max_matching_ngram_size: int = 5
end_id: int = -1
is_keep_all: bool = True
is_use_oldest: bool = True
is_public_pool: bool = True
def __post_init__(self) -> None:
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
self.max_draft_tokens = self.prompt_lookup_num_tokens
def update_from_model_config(self, model_config):
pass
class NGramPoolManager(BaseResourceManager):
@ -52,7 +23,7 @@ class NGramPoolManager(BaseResourceManager):
`matches` is a list of candidate draft token ids attached to a pattern.
Arguments:
prompt_lookup_num_tokens: int
max_draft_len: int
The maximum number of draft tokens (i.e., the maximum length of the output draft-token sequence).
max_matching_ngram_size: int
@ -76,8 +47,9 @@ class NGramPoolManager(BaseResourceManager):
It maps from request ID to the index of the prompt to update the pool in the next step.
"""
def __init__(self, spec_config: SpecConfig, max_num_requests: int):
self.prompt_lookup_num_tokens = spec_config.prompt_lookup_num_tokens
def __init__(self, spec_config: "NGramDecodingConfig",
max_num_requests: int):
self.max_draft_len = spec_config.max_draft_len
self.max_matching_ngram_size = spec_config.max_matching_ngram_size
self.is_keep_all = spec_config.is_keep_all
self.is_use_oldest = spec_config.is_use_oldest # TODO: remove this if updating strategy is supported
@ -133,7 +105,7 @@ class NGramPoolManager(BaseResourceManager):
-1):
# Find each possible pattern-match combination, and use tuple for hash
for l in range(len(sequence) - size):
r = min(l + size + self.prompt_lookup_num_tokens, len(sequence))
r = min(l + size + self.max_draft_len, len(sequence))
pattern = tuple(sequence[l:l + size])
new_match = tuple(sequence[l + size:r])
if pattern not in pool or \
@ -165,7 +137,7 @@ class NGramPoolManager(BaseResourceManager):
# Update start_index
self.start_index[request_id] = max(
0, prefix_len -
(self.prompt_lookup_num_tokens + self.max_matching_ngram_size - 1))
(self.max_draft_len + self.max_matching_ngram_size - 1))
return draft_tokens
@ -191,12 +163,12 @@ class NGramDrafter(Drafter):
def __init__(
self,
spec_config: SpecConfig,
spec_config: "NGramDecodingConfig",
ngram_pool_manager: NGramPoolManager = None,
):
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
super().__init__(spec_resource_manager=ngram_pool_manager)
self.max_num_draft_tokens = spec_config.max_draft_tokens
self.max_draft_len = spec_config.max_draft_len
def prepare_draft_tokens(
self,
@ -220,8 +192,8 @@ class NGramDrafter(Drafter):
request.py_end_id,
request.py_orig_prompt_len + request.py_max_new_tokens,
)
# Pad length to `self.max_num_draft_tokens`
# Pad length to `self.max_draft_len`
if len(draft_tokens) > 0:
pad_length = self.max_num_draft_tokens - len(draft_tokens)
pad_length = self.max_draft_len - len(draft_tokens)
draft_tokens.extend([request.py_end_id] * pad_length)
request.py_draft_tokens = draft_tokens
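A toy sketch of the pattern -> matches pool described in the docstring above, using hypothetical values for max_matching_ngram_size and max_draft_len; the is_keep_all/is_use_oldest policies are ignored here.

sequence = [3, 5, 8, 3, 5, 9, 4]
max_matching_ngram_size, max_draft_len = 2, 3

pool = {}
for size in range(max_matching_ngram_size, 0, -1):
    for l in range(len(sequence) - size):
        r = min(l + size + max_draft_len, len(sequence))
        pattern = tuple(sequence[l:l + size])          # key: the observed n-gram
        match = tuple(sequence[l + size:r])            # value: its candidate continuation
        pool.setdefault(pattern, set()).add(match)

# A generation ending in (3, 5) can now look up draft continuations:
print(pool[(3, 5)])   # {(8, 3, 5), (9, 4)} (set order may vary)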

View File

@ -1,26 +0,0 @@
from dataclasses import dataclass
from typing import Optional
from tensorrt_llm._torch.speculative.drafter import Drafter
from .interface import SpecConfig, SpeculativeDecodingMode
@dataclass
class UserProvidedConfig(SpecConfig):
"""
Configuration for user provided speculative decoding.
"""
# The name of speculative decoding.
spec_dec_name = "USER_PROVIDED"
num_extra_kv_tokens: int = 0
max_draft_tokens: int = 0
drafter: Optional[Drafter] = None
def __post_init__(self) -> None:
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
def update_from_model_config(self, model_config):
pass

View File

@ -1,7 +1,6 @@
from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
from tensorrt_llm._torch.speculative.interface import SpecConfig, SpecMetadata
from tensorrt_llm._torch.speculative.interface import SpecMetadata
from .draft_target import DraftTargetSpecMetadata
from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata,
Eagle3OneModelWorker, Eagle3ResourceManager,
Eagle3SpecMetadata)
@ -11,13 +10,14 @@ from .ngram import NGramDrafter, NGramPoolManager
def get_spec_metadata(spec_config,
model_config,
max_num_requests,
max_num_tokens,
spec_resource_manager=None,
is_draft_model=False):
if spec_config.spec_dec_mode.is_mtp():
return MTPSpecMetadata(
max_draft_tokens=spec_config.max_draft_tokens,
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
mtp_num_modules=spec_config.num_nextn_predict_layers,
max_num_requests=max_num_requests,
@ -25,35 +25,30 @@ def get_spec_metadata(spec_config,
)
if spec_config.spec_dec_mode.is_eagle3():
return Eagle3SpecMetadata(
max_draft_tokens=spec_config.max_draft_tokens,
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=spec_config.num_layers,
hidden_size=spec_config.hidden_size,
num_layers=model_config.num_hidden_layers,
hidden_size=model_config.hidden_size,
max_num_tokens=max_num_tokens,
dtype=spec_config.dtype,
dtype=model_config.torch_dtype,
is_draft_model=is_draft_model,
eagle3_resource_manager=spec_resource_manager,
)
if spec_config.spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelSpecMetadata(
max_draft_tokens=spec_config.max_draft_tokens,
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=spec_config.num_layers,
hidden_size=spec_config.hidden_size,
num_layers=model_config.num_hidden_layers,
hidden_size=model_config.hidden_size,
max_num_tokens=max_num_tokens,
)
if spec_config.spec_dec_mode.is_draft_target():
return DraftTargetSpecMetadata(
max_draft_tokens=spec_config.max_draft_tokens,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
)
if spec_config.spec_dec_mode.is_ngram(
) or spec_config.spec_dec_mode.is_user_provided():
if spec_config.spec_dec_mode.is_draft_target() or \
spec_config.spec_dec_mode.is_ngram() or \
spec_config.spec_dec_mode.is_user_provided():
return SpecMetadata(
max_draft_tokens=spec_config.max_draft_tokens,
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
)
@ -104,7 +99,8 @@ def get_spec_resource_manager(model_engine,
return None
def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig):
def get_spec_decoder(sampler_args: TorchSampler.Args,
spec_config: "DecodingBaseConfig"):
if spec_config.spec_dec_mode.is_mtp():
return MTPSampler(sampler_args,
nextn=spec_config.num_nextn_predict_layers)
@ -133,18 +129,16 @@ def get_spec_drafter(model_engine):
def get_num_spec_layers(spec_config):
if spec_config.spec_dec_mode.is_mtp():
return spec_config.num_nextn_predict_layers
elif spec_config.spec_dec_mode.is_eagle3_one_model():
if spec_config.spec_dec_mode.is_eagle3_one_model():
return 1
else:
return 0
return 0
def get_spec_worker(spec_config, mapping):
if spec_config.spec_dec_mode.is_mtp():
return MTPWorker(spec_config)
elif spec_config.spec_dec_mode.is_mtp_eagle():
if spec_config.spec_dec_mode.is_mtp_eagle():
return MTPEagleWorker(spec_config)
elif spec_config.spec_dec_mode.is_eagle3_one_model():
if spec_config.spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelWorker(spec_config, mapping)
else:
return None
return None

View File

@ -1,4 +1,5 @@
import copy
import functools
import json
import math
import os
@ -222,7 +223,8 @@ class _ModelFormatKind(Enum):
class DecodingBaseConfig(BaseModel):
max_draft_len: Optional[int] = None
speculative_model: Optional[Union[str, Path]] = None
speculative_model_dir: Optional[Union[str, Path]] = None
num_extra_kv_tokens: int = 0
@classmethod
def from_dict(cls, data: dict):
@ -247,6 +249,35 @@ class DecodingBaseConfig(BaseModel):
def _check_fields(self):
pass
def supports_backend(self, backend: str) -> bool:
"""
Override if the speculation algorithm supports only
a subset of the possible backends.
"""
return True
def validate(self) -> None:
"""
Do any additional error checking here.
"""
@functools.cached_property
def spec_dec_mode(self):
# spec_dec_mode has more functionality than the raw decoding_mode string.
# Use an alias for the import here to avoid name collisions with the one for the
# TRT backend.
from tensorrt_llm._torch.speculative.interface import \
SpeculativeDecodingMode as TorchSpeculativeDecodingMode
return TorchSpeculativeDecodingMode.from_string(
self.decoding_type.upper())
def update_from_model_config(self, model_config):
pass
def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
return input_tokens
class MedusaDecodingConfig(DecodingBaseConfig):
medusa_choices: Optional[List[List[int]]] = None
@ -258,6 +289,9 @@ class MedusaDecodingConfig(DecodingBaseConfig):
decoding_type: ClassVar[str] = "Medusa"
def supports_backend(self, backend: str) -> bool:
return backend not in ("pytorch", "_autodeploy")
class EagleDecodingConfig(DecodingBaseConfig):
eagle_choices: Optional[List[List[int]]] = None
@ -267,7 +301,6 @@ class EagleDecodingConfig(DecodingBaseConfig):
dynamic_tree_max_topK: Optional[int] = None
num_eagle_layers: Optional[int] = None
max_non_leaves_per_layer: Optional[int] = None
pytorch_weights_path: Optional[str] = None
eagle3_one_model: Optional[bool] = True
@classmethod
@ -276,6 +309,25 @@ class EagleDecodingConfig(DecodingBaseConfig):
decoding_type: ClassVar[str] = "Eagle"
def validate(self) -> None:
if self.speculative_model_dir is None:
raise ValueError("Draft model must be provided for EAGLE")
@functools.cached_property
def spec_dec_mode(self):
from tensorrt_llm._torch.speculative.interface import \
SpeculativeDecodingMode as TorchSpeculativeDecodingMode
if self.eagle3_one_model:
return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
return TorchSpeculativeDecodingMode.EAGLE3
def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
"""
Eagle3 always throws away the first token when processing draft inputs
"""
return input_tokens[1:]
class UserProvidedDecodingConfig(DecodingBaseConfig):
# Type should be Drafter, but it leads to circular import
@ -285,7 +337,7 @@ class UserProvidedDecodingConfig(DecodingBaseConfig):
def from_dict(cls, data: dict):
return cls(**data)
decoding_type: ClassVar[str] = "UserProvided"
decoding_type: ClassVar[str] = "User_Provided"
class NGramDecodingConfig(DecodingBaseConfig):
@ -293,7 +345,7 @@ class NGramDecodingConfig(DecodingBaseConfig):
Configuration for NGram drafter speculative decoding.
Arguments:
prompt_lookup_num_tokens: int
max_draft_len: int
The maximum number of draft tokens (i.e., the maximum length of the output draft-token sequence).
max_matching_ngram_size: int
@ -309,7 +361,6 @@ class NGramDecodingConfig(DecodingBaseConfig):
Whether to use a common pool shared by all requests; if False, each request keeps its own private pool.
"""
prompt_lookup_num_tokens: int = 2
max_matching_ngram_size: int = 4
is_keep_all: bool = True
is_use_oldest: bool = True
@ -321,15 +372,20 @@ class NGramDecodingConfig(DecodingBaseConfig):
decoding_type: ClassVar[str] = "NGram"
def supports_backend(self, backend: str) -> bool:
return backend == "pytorch"
class DraftTargetDecodingConfig(DecodingBaseConfig):
pytorch_weights_path: Optional[str] = None
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
decoding_type: ClassVar[str] = "DraftTarget"
decoding_type: ClassVar[str] = "Draft_Target"
def supports_backend(self, backend: str) -> bool:
return backend == "pytorch"
class MTPDecodingConfig(DecodingBaseConfig):
@ -339,12 +395,39 @@ class MTPDecodingConfig(DecodingBaseConfig):
relaxed_delta: Optional[float] = 0.
use_mtp_vanilla: Optional[bool] = False
# TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
# Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.
num_nextn_predict_layers_from_model_config: Optional[int] = 1
# TODO: Hard code for DeepSeek R1
# When encounter <think>, start thinking phase.
# When encounter </think>, end thinking phase.
# <think> [thinking phase] </think> [real output]
BEGIN_THINKING_PHASE_TOKEN: int = 128798
END_THINKING_PHASE_TOKEN: int = 128799
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
decoding_type: ClassVar[str] = "MTP"
def supports_backend(self, backend: str) -> bool:
return backend == "pytorch"
@functools.cached_property
def spec_dec_mode(self):
from tensorrt_llm._torch.speculative.interface import \
SpeculativeDecodingMode as TorchSpeculativeDecodingMode
if self.num_nextn_predict_layers_from_model_config == 1 and not self.use_mtp_vanilla:
return TorchSpeculativeDecodingMode.MTP_EAGLE
return TorchSpeculativeDecodingMode.MTP
def update_from_model_config(self, model_config):
assert self.num_nextn_predict_layers > 0
if model_config.num_nextn_predict_layers == 1 and not self.use_mtp_vanilla:
self.num_extra_kv_tokens = self.num_nextn_predict_layers - 1
class PybindMirror(ABC):
''' A class containing the utilities for mirroring Python classes to
@ -635,6 +718,9 @@ class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror):
self.max_ngram_size,
self.max_verification_set_size)
def supports_backend(self, backend: str) -> bool:
return backend not in ("pytorch", "_autodeploy")
decoding_type: ClassVar[str] = "Lookahead"
@ -1037,7 +1123,7 @@ class BaseLlmArgs(BaseModel):
return self._model_format
@property
def speculative_model(self) -> Optional[_ModelFormatKind]:
def speculative_model_dir(self) -> Optional[_ModelFormatKind]:
return self._speculative_model
@property
@ -1314,33 +1400,40 @@ class BaseLlmArgs(BaseModel):
@model_validator(mode="after")
def validate_speculative_config(self):
if self.speculative_config:
if isinstance(self.speculative_config, LookaheadDecodingConfig):
lookahead_config = self.speculative_config
# Update the build config
_, _, max_draft_tokens, _ = lookahead_config.calculate_speculative_resource(
)
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.LOOKAHEAD_DECODING
if max_draft_tokens > self.build_config.max_draft_len:
self.build_config.max_draft_len = max_draft_tokens
if not self.speculative_config.supports_backend(self.backend):
raise ValueError(
f"Speculation type {self.speculative_config.decoding_type} does not "
f"support backend {self.backend}")
# Below, we only need to set speculative_decoding_mode/decoding_config for speculation
# on the TRT backend.
if isinstance(self.speculative_config, LookaheadDecodingConfig):
max_draft_len = self.speculative_config.calculate_speculative_resource(
)[2]
assert max_draft_len > 0
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.LOOKAHEAD_DECODING
self.build_config.max_draft_len = max(
self.build_config.max_draft_len, max_draft_len)
self.decoding_config = DecodingConfig(
decoding_mode=DecodingMode.Lookahead(),
lookahead_decoding_config=PybindMirror.maybe_to_pybind(
lookahead_config))
elif isinstance(self.speculative_config, MedusaDecodingConfig):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.MEDUSA
self.speculative_config))
elif isinstance(self.speculative_config, MedusaDecodingConfig):
assert self.speculative_config.max_draft_len > 0
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.MEDUSA
self.build_config.max_draft_len = self.speculative_config.max_draft_len
self.decoding_config = DecodingConfig(
decoding_mode=DecodingMode.Medusa(),
medusa_choices=self.speculative_config.medusa_choices)
elif isinstance(self.speculative_config, EagleDecodingConfig):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
assert self.speculative_config.max_draft_len > 0
assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
self.build_config.max_draft_len = self.speculative_config.max_draft_len
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
if self.speculative_config.eagle3_one_model:
self.num_extra_kv_tokens = self.max_draft_len - 1
if self.backend not in ['pytorch', '_autodeploy']:
eagle_config = _EagleConfig(
self.speculative_config.eagle_choices,
@ -1351,68 +1444,39 @@ class BaseLlmArgs(BaseModel):
self.decoding_config = DecodingConfig(
decoding_mode=DecodingMode.Eagle(),
eagle_config=eagle_config)
else:
from tensorrt_llm._torch.speculative import Eagle3Config
self.speculative_config = Eagle3Config(
max_draft_tokens=self.speculative_config.max_draft_len,
draft_model_path=self.speculative_config.
pytorch_weights_path,
eagle3_one_model=self.speculative_config.
eagle3_one_model)
elif isinstance(self.speculative_config, NGramDecodingConfig):
assert self.backend in ['pytorch', '_autodeploy']
assert self.speculative_config.prompt_lookup_num_tokens > 0 and self.speculative_config.max_matching_ngram_size > 0
assert self.speculative_config.max_draft_len > 0 and self.speculative_config.max_matching_ngram_size > 0
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.NGRAM
self.build_config.max_draft_len = self.speculative_config.max_draft_len
from tensorrt_llm._torch.speculative import NGramConfig
self.speculative_config = NGramConfig(
prompt_lookup_num_tokens=self.speculative_config.
prompt_lookup_num_tokens,
max_matching_ngram_size=self.speculative_config.
max_matching_ngram_size,
is_keep_all=self.speculative_config.is_keep_all,
is_use_oldest=self.speculative_config.is_use_oldest,
is_public_pool=self.speculative_config.is_public_pool,
)
elif isinstance(self.speculative_config, DraftTargetDecodingConfig):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
assert self.backend == 'pytorch'
assert self.backend in ['pytorch']
assert self.speculative_config.max_draft_len > 0
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
self.build_config.max_draft_len = self.speculative_config.max_draft_len
from tensorrt_llm._torch.speculative import DraftTargetConfig
self.speculative_config = DraftTargetConfig(
max_draft_tokens=self.speculative_config.max_draft_len,
draft_model_path=self.speculative_config.
pytorch_weights_path)
elif isinstance(self.speculative_config, MTPDecodingConfig):
from tensorrt_llm._torch.speculative import MTPConfig
self.speculative_config = MTPConfig(
num_nextn_predict_layers=self.speculative_config.
num_nextn_predict_layers,
max_batch_size=self.build_config.max_batch_size,
use_relaxed_acceptance_for_thinking=self.speculative_config.
use_relaxed_acceptance_for_thinking,
relaxed_topk=self.speculative_config.relaxed_topk,
relaxed_delta=self.speculative_config.relaxed_delta,
use_mtp_vanilla=self.speculative_config.use_mtp_vanilla)
assert self.speculative_config.num_nextn_predict_layers > 0
self.speculative_config.max_draft_len = self.speculative_config.num_nextn_predict_layers
elif isinstance(self.speculative_config,
UserProvidedDecodingConfig):
assert self.backend in ['pytorch', '_autodeploy']
from tensorrt_llm._torch.speculative import UserProvidedConfig
self.speculative_config = UserProvidedConfig(
max_draft_tokens=self.speculative_config.max_draft_len,
drafter=self.speculative_config.drafter)
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.USER_PROVIDED
self.build_config.max_draft_len = self.speculative_config.max_draft_tokens
self.build_config.max_draft_len = self.speculative_config.max_draft_len
else:
raise ValueError(
f"Speculative config type not recognized: {self.speculative_config}"
f"Unrecognized speculative config type {type(self.speculative_config)}"
)
else:
self.decoding_config = None
self._speculative_model = getattr(self.speculative_config,
"speculative_model", None)
"speculative_model_dir", None)
speculative_model_obj = _ModelWrapper(
self._speculative_model
) if self._speculative_model is not None else None
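A minimal sketch of the user-facing effect of this refactor: the llmapi decoding configs are now consumed directly by the PyTorch backend instead of being converted into the internal NGramConfig/Eagle3Config/etc. classes. The model path and import location are assumptions; the field values mirror the NGram settings used in the tests changed later in this commit.

from tensorrt_llm.llmapi import LLM, NGramDecodingConfig

spec_config = NGramDecodingConfig(
    max_draft_len=4,               # formerly prompt_lookup_num_tokens
    max_matching_ngram_size=4,
    is_keep_all=True,
    is_use_oldest=True,
    is_public_pool=True,
)
llm = LLM(model="/path/to/llama-3.1-8b-instruct",  # hypothetical model directory
          speculative_config=spec_config)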

View File

@ -110,8 +110,8 @@ class ModelLoader:
self.model_obj = _ModelWrapper(self.llm_args.model)
self.speculative_model_obj = _ModelWrapper(
self.llm_args.speculative_model
) if self.llm_args.speculative_model is not None else None
self.llm_args.speculative_model_dir
) if self.llm_args.speculative_model_dir is not None else None
if isinstance(self.llm_args, TrtLlmArgs):
self.convert_checkpoint_options = self.llm_args._convert_checkpoint_options
@ -440,8 +440,8 @@ class ModelLoader:
model_cls = AutoModelForCausalLM.get_trtllm_model_class(
self._model_dir, self.llm_args.trust_remote_code,
self.llm_args.decoding_config.decoding_mode
if hasattr(self.llm_args, "speculative_model")
and self.llm_args.speculative_model else None)
if hasattr(self.llm_args, "speculative_model_dir")
and self.llm_args.speculative_model_dir else None)
prequantized = self._update_from_hf_quant_config()
@ -484,7 +484,7 @@ class ModelLoader:
load_model_on_cpu=
True, # TODO:TRTLLM-195 to enhance the weights loading memory usage and chose best location
trust_remote_code=self.llm_args.trust_remote_code,
speculative_model=self._speculative_model_dir,
speculative_model_dir=self._speculative_model_dir,
speculative_config=self.llm_args.speculative_config
if not isinstance(self.llm_args.speculative_config,
LookaheadDecodingConfig) else None,

View File

@ -59,7 +59,7 @@ class EagleConfig(LLaMAConfig):
**kwargs):
import transformers
trust_remote_code = kwargs.pop('trust_remote_code', True)
speculative_config_or_dir = kwargs.pop('speculative_model', None)
speculative_config_or_dir = kwargs.pop('speculative_model_dir', None)
if isinstance(hf_config_or_dir, transformers.PretrainedConfig):
hf_config = hf_config_or_dir

View File

@ -948,11 +948,11 @@ class EagleForCausalLM(LLaMAForCausalLM):
spec_decoding_position_offsets: [bs, max_gen_tokens]
spec_decoding_packed_mask: [bs, max_draft_len, packed_length] **
eagle_temperature: [bs]
rand_data_validation: [bs, max_draft_tokens]
rand_data_validation: [bs, max_draft_len]
** The mask is tricky since the boolean mask will need to be
packed in runtime. So, the last dim will be:
packed_length = ceil((max_draft_tokens+1)/32)
packed_length = ceil((max_draft_len+1)/32)
"""
default_range = GenerationMixin.default_range
remove_input_padding = default_net().plugin_config.remove_input_padding
@ -1228,7 +1228,7 @@ class EagleForCausalLM(LLaMAForCausalLM):
quant_config: Optional[QuantConfig] = None,
**kwargs):
assert hf_model_or_dir is not None
speculative_model_dir = kwargs.get('speculative_model', None)
speculative_model_dir = kwargs.get('speculative_model_dir', None)
tllm_config = EagleConfig.from_hugging_face(hf_model_or_dir,
dtype=dtype,
mapping=mapping,

View File

@ -70,7 +70,7 @@ class MedusaConfig(PretrainedConfig):
import transformers
trust_remote_code = kwargs.pop('trust_remote_code', True)
speculative_config_or_dir = kwargs.pop('speculative_model', None)
speculative_config_or_dir = kwargs.pop('speculative_model_dir', None)
speculative_config = kwargs.pop("speculative_config", None)
if isinstance(hf_config_or_dir, transformers.PretrainedConfig):

View File

@ -191,7 +191,7 @@ class MedusaForCausalLm(PretrainedModel):
import transformers
assert hf_model_or_dir is not None
speculative_model_dir = kwargs.get('speculative_model', None)
speculative_model_dir = kwargs.get('speculative_model_dir', None)
use_preloading = isinstance(hf_model_or_dir,
transformers.PreTrainedModel)

View File

@ -179,15 +179,15 @@ class ReDrafterMixin:
bb_range = default_range(max_batch_size)
bb0_range = default_range(max_batch_size, min_range=0, opt_offset=1)
num_beam_tokens = self.num_beams * self.beam_length
max_draft_tokens = num_beam_tokens - self.num_beams # ignore the true token
max_gen_token_len = 1 + max_draft_tokens # for the true token
max_draft_len = num_beam_tokens - self.num_beams # ignore the true token
max_gen_token_len = 1 + max_draft_len # for the true token
max_gen_token_len_range = default_range(max_gen_token_len)
bb_max_gen_token_len_range = default_range(max_gen_token_len *
max_batch_size,
min_range=0)
kwargs['speculative_decoding_draft_tokens_external'] = False
kwargs['max_draft_len'] = max_draft_tokens
kwargs['max_draft_len'] = max_draft_len
kwargs['spec_decoding_is_generation_length_variable'] = True
inputs = super().prepare_inputs(*args, **kwargs)
assert inputs['spec_decoding_params'] is not None
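
The ReDrafter draft-length bookkeeping above can be restated as plain arithmetic; the beam settings below are assumed example values, not taken from the change.

    num_beams, beam_length = 5, 6                  # assumed example values
    num_beam_tokens = num_beams * beam_length
    max_draft_len = num_beam_tokens - num_beams    # drop the one true token per beam
    max_gen_token_len = 1 + max_draft_len          # add the true token back for generation
    assert (max_draft_len, max_gen_token_len) == (25, 26)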

View File

@ -157,6 +157,8 @@ class AccuracyTask:
elif isinstance(llm.args.speculative_config, DecodingBaseConfig):
spec_dec_algo = llm.args.speculative_config.decoding_type
elif isinstance(llm.args.speculative_config, SpecConfig):
# This branch is deprecated, but removing it triggers flaky errors caused by a pytest thread leak.
# TODO: remove this branch safely.
spec_dec_algo = llm.args.speculative_config.spec_dec_name
else:
raise ValueError(

View File

@ -214,7 +214,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
def test_ngram(self):
speculative_decoding_config = {
"decoding_type": "NGram",
"prompt_lookup_num_tokens": 4,
"max_draft_len": 4,
"max_matching_ngram_size": 4,
"is_keep_all": True,
"is_use_oldest": True,

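For reference, a sketch of the same NGram settings expressed through the LLM API config class after the rename; the import path matches the one used elsewhere in this change, and the values mirror the dictionary above.

    from tensorrt_llm.llmapi import NGramDecodingConfig

    spec_config = NGramDecodingConfig(
        max_draft_len=4,               # formerly `prompt_lookup_num_tokens`
        max_matching_ngram_size=4,
        is_keep_all=True,
        is_use_oldest=True,
    )
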
View File

@ -393,7 +393,7 @@ class TestEagleVicuna_7B_v1_3(LlmapiAccuracyTestHarness):
speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
speculative_model_dir=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
num_eagle_layers=4,
max_non_leaves_per_layer=10,
eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
@ -419,7 +419,7 @@ class TestEagle2Vicuna_7B_v1_3(LlmapiAccuracyTestHarness):
speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
speculative_model_dir=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
num_eagle_layers=4,
max_non_leaves_per_layer=10,
use_dynamic_tree=True,

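A condensed sketch of the EAGLE-2 (dynamic-tree) configuration exercised above, using a placeholder weights path and only the fields visible in this change.

    from tensorrt_llm.llmapi import EagleDecodingConfig  # import path assumed

    spec_config = EagleDecodingConfig(
        max_draft_len=63,
        speculative_model_dir="/path/to/EAGLE-Vicuna-7B-v1.3",  # placeholder path
        num_eagle_layers=4,
        max_non_leaves_per_layer=10,
        use_dynamic_tree=True,
    )
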
View File

@ -243,7 +243,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
draft_len = 4
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
pytorch_weights_path=eagle_model_dir)
speculative_model_dir=eagle_model_dir)
llm = LLM(model=target_model_dir,
**pytorch_config,
@ -262,7 +262,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
draft_len = 4
spec_config = NGramDecodingConfig(
prompt_lookup_num_tokens=draft_len,
max_draft_len=draft_len,
max_matching_ngram_size=draft_len,
is_keep_all=True,
is_use_oldest=True,

View File

@ -18,7 +18,7 @@ generation_servers:
- "localhost:8002"
speculative_config:
decoding_type: NGram
prompt_lookup_num_tokens: 4
max_draft_len: 4
max_matching_ngram_size: 4
is_keep_all: True
is_use_oldest: True
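
If the disaggregated-serving YAML above is generated programmatically, the renamed key can be emitted as follows; this is a sketch that assumes PyYAML is available, and the keys mirror the hunk above.

    import yaml  # assumes PyYAML is installed

    speculative_config = {
        "decoding_type": "NGram",
        "max_draft_len": 4,            # formerly `prompt_lookup_num_tokens`
        "max_matching_ngram_size": 4,
        "is_keep_all": True,
        "is_use_oldest": True,
    }
    print(yaml.safe_dump({"speculative_config": speculative_config}, sort_keys=False))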

View File

@ -22,45 +22,43 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
pytest.skip("Not enough memory to load target model")
models_path = llm_models_root()
kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=2080)
sampling_params = SamplingParams(
max_tokens=32,
temperature=0,
)
max_batch_size = 1
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
draft_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
max_batch_size = 2
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None
llm_common_config = dict(
model=target_model_dir,
backend='pytorch',
attn_backend=attn_backend,
disable_overlap_scheduler=True,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
max_num_tokens=2048,
)
draft_len = 4
spec_config = DraftTargetDecodingConfig(
max_draft_len=draft_len, pytorch_weights_path=draft_model_dir)
llm_spec = LLM(model=target_model_dir,
max_batch_size=max_batch_size,
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(
cuda_graph_batch_sizes=[1]) if use_cuda_graph else None,
attn_backend=attn_backend,
kv_cache_config=kv_cache_config,
speculative_config=spec_config)
max_draft_len=max_draft_len,
speculative_model_dir=draft_model_dir,
)
prompts = [
"The capital of France is", "The president of the United States is"
"The capital of France is",
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=32)
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()
llm_ref = LLM(model=target_model_dir,
max_batch_size=max_batch_size,
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(
cuda_graph_batch_sizes=[1]) if use_cuda_graph else None,
attn_backend=attn_backend,
kv_cache_config=kv_cache_config)
llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()
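
Pulling the refactored pieces above into one standalone sketch; the model paths are placeholders, and the import path for DraftTargetDecodingConfig is assumed to sit alongside the other decoding configs.

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig

    llm_common_config = dict(
        model="/path/to/Llama-3.1-8B-Instruct",        # placeholder target model
        max_batch_size=2,
        kv_cache_config=KvCacheConfig(enable_block_reuse=False),
        disable_overlap_scheduler=True,
        max_num_tokens=2048,
    )
    spec_config = DraftTargetDecodingConfig(
        max_draft_len=4,
        speculative_model_dir="/path/to/draft-model",  # placeholder draft model
    )
    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
    outputs = llm_spec.generate(["The capital of France is"],
                                SamplingParams(max_tokens=32))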

View File

@ -32,81 +32,70 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
pytest.skip("Not enough memory to load target + draft model")
models_path = llm_models_root()
pytorch_config = dict(
disable_overlap_scheduler=disable_overlap_scheduler,
# Only create a single CUDA graph to prevent OOM in CI
attn_backend=attn_backend,
)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, )
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
draft_len = 4
spec_config = EagleDecodingConfig(
max_draft_len=draft_len,
pytorch_weights_path=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model)
# bs > 1 gives non-deterministic when doing IFB. There are slight chances
# that ref and spec does not match 100%
max_batch_size = 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
free_gpu_memory_fraction=0.5)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None
llm_spec = LLM(
llm_common_config = dict(
model=target_model_dir,
**pytorch_config,
# bs > 1 gives non-deterministic results when doing IFB. There is a slight chance
# that ref and spec outputs do not match 100%.
max_batch_size=1,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
kv_cache_config=kv_cache_config,
cuda_graph_config=cuda_graph_config,
speculative_config=spec_config)
sampling_params = SamplingParams(
max_tokens=10,
temperature=0,
)
# First make sure the acceptance rate is reasonable.
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one-model EAGLE.
eagle3_one_model=use_one_model,
)
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
# Acceptance rate tests
tok_ids = llm_spec.tokenizer.encode("The future of AI is")
num_tokens = 0
num_drafted = 0
num_accepted = 0
sampling_params = SamplingParams(max_tokens=128, temperature=0)
for output in llm_spec.generate_async(tok_ids,
SamplingParams(max_tokens=128,
temperature=0),
sampling_params,
streaming=True):
beam = output.outputs[0]
new_tokens = beam.token_ids
num_drafted += draft_len
new_tokens = output.outputs[0].token_ids
num_drafted += max_draft_len
num_accepted += len(new_tokens) - num_tokens - 1
num_tokens = len(new_tokens)
accept_rate = num_accepted / num_drafted
assert accept_rate > 0.15
# Output tests
prompts = [
"The capital of France is", "The president of the United States is"
"The capital of France is",
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=10, temperature=0)
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()
llm_ref = LLM(model=target_model_dir,
**pytorch_config,
kv_cache_config=kv_cache_config,
cuda_graph_config=cuda_graph_config)
llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()
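
The acceptance-rate bookkeeping in the test above can be factored into a small standalone helper; the cumulative token counts below are assumed example values.

    def acceptance_rate(cumulative_token_counts, max_draft_len):
        # cumulative_token_counts[i]: total generated tokens after streaming step i.
        # Each step drafts `max_draft_len` tokens; everything beyond the single
        # target-model token counts as accepted draft tokens.
        num_tokens = num_drafted = num_accepted = 0
        for count in cumulative_token_counts:
            num_drafted += max_draft_len
            num_accepted += count - num_tokens - 1
            num_tokens = count
        return num_accepted / num_drafted

    print(acceptance_rate([3, 7, 9, 14], max_draft_len=4))  # assumed counts -> 0.625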

View File

@ -6,9 +6,9 @@ from parameterized import parameterized
import tensorrt_llm
from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.speculative.mtp import (MTPConfig,
MTPHiddenStatesManager,
from tensorrt_llm._torch.speculative.mtp import (MTPHiddenStatesManager,
MTPSpecMetadata, MTPWorker)
from tensorrt_llm.llmapi import MTPDecodingConfig
def unittest_name_func(testcase_func, param_num, param):
@ -40,15 +40,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
device="cuda") # [num_tokens, vocab_size]
draft_tokens = torch.tensor(
[], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
[], dtype=torch.int, device="cuda") # [batch_size * max_draft_len]
draft_len = torch.tensor([0], dtype=torch.int,
device="cuda") # [batch_size]
ref_accepted_tokens = torch.tensor(
[[1, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
ref_num_accepted_tokens = torch.tensor([1],
dtype=torch.int,
@ -74,15 +73,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
device="cuda") # [num_tokens, vocab_size]
draft_tokens = torch.tensor(
[], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
[], dtype=torch.int, device="cuda") # [batch_size * max_draft_len]
draft_len = torch.tensor([0, 0, 0, 0], dtype=torch.int,
device="cuda") # [batch_size]
ref_accepted_tokens = torch.tensor(
[[1, 0], [3, 0], [3, 0], [6, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
ref_num_accepted_tokens = torch.tensor([1, 1, 1, 1],
dtype=torch.int,
@ -111,14 +109,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
draft_tokens = torch.tensor(
[1, 3, 4], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
draft_len = torch.tensor([3], dtype=torch.int,
device="cuda") # [batch_size]
ref_accepted_tokens = torch.tensor(
[[1, 3, 2, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
ref_num_accepted_tokens = torch.tensor([3],
dtype=torch.int,
@ -147,14 +145,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
draft_tokens = torch.tensor(
[1, 5], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
draft_len = torch.tensor([1, 1], dtype=torch.int,
device="cuda") # [batch_size]
ref_accepted_tokens = torch.tensor(
[[1, 3], [4, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
ref_num_accepted_tokens = torch.tensor([2, 1],
dtype=torch.int,
@ -187,14 +185,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
draft_tokens = torch.tensor(
[1, 3, 4, 4, 7, 3], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
draft_len = torch.tensor([3, 3], dtype=torch.int,
device="cuda") # [batch_size]
ref_accepted_tokens = torch.tensor(
[[1, 3, 2, 0], [4, 6, 0, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
ref_num_accepted_tokens = torch.tensor([3, 2],
dtype=torch.int,
@ -231,7 +229,7 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
draft_tokens = torch.tensor(
[1, 3, 5, 4, 6, 5, 5, 7, 4], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
draft_len = torch.tensor([3, 3, 3], dtype=torch.int,
device="cuda") # [batch_size]
@ -239,7 +237,7 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
ref_accepted_tokens = torch.tensor(
[[1, 3, 2, 0], [4, 6, 5, 2], [4, 0, 0, 0]],
dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
ref_num_accepted_tokens = torch.tensor([3, 4, 1],
dtype=torch.int,
@ -267,15 +265,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
device="cuda") # [num_tokens, vocab_size]
draft_tokens = torch.tensor(
[4], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
[4], dtype=torch.int, device="cuda") # [batch_size * max_draft_len]
draft_len = torch.tensor([0, 1], dtype=torch.int,
device="cuda") # [batch_size]
ref_accepted_tokens = torch.tensor(
[[1, 0], [4, 6]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]
ref_num_accepted_tokens = torch.tensor([1, 2],
dtype=torch.int,
@ -297,7 +294,8 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
ref_accepted_tokens,
ref_num_accepted_tokens):
batch_size = len(draft_len)
spec_config = MTPConfig(num_nextn_predict_layers=mtp_num_modules)
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=mtp_num_modules)
# attention metadata
attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size,
@ -310,7 +308,7 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
# speculative decoding metadata
spec_metadata = MTPSpecMetadata(max_num_requests=32,
spec_dec_mode=spec_config.spec_dec_mode,
max_draft_tokens=mtp_num_modules,
max_draft_len=mtp_num_modules,
mtp_num_modules=mtp_num_modules)
spec_metadata.draft_tokens = draft_tokens
@ -871,7 +869,7 @@ class TestMTPUpdateMTPHiddenStates(unittest.TestCase):
batch_size = len(request_ids)
batch_size - num_context_request
hidden_size = hidden_states.shape[1]
spec_config = MTPConfig(
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=num_nextn_predict_layers)
attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size,
@ -892,7 +890,7 @@ class TestMTPUpdateMTPHiddenStates(unittest.TestCase):
spec_metadata = MTPSpecMetadata(
max_num_requests=32,
spec_dec_mode=spec_config.spec_dec_mode,
max_draft_tokens=num_nextn_predict_layers,
max_draft_len=num_nextn_predict_layers,
mtp_num_modules=num_nextn_predict_layers,
mtp_hidden_states_manager=spec_manager)
spec_metadata.request_ids = request_ids
@ -1362,7 +1360,7 @@ class TestMTPPrepareDrafterInputs(unittest.TestCase):
hidden_size = previous_layer_hidden_states.shape[1]
else:
hidden_size = 10
spec_config = MTPConfig(
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=num_nextn_predict_layers)
if attn_metadata is None:
@ -1387,7 +1385,7 @@ class TestMTPPrepareDrafterInputs(unittest.TestCase):
spec_metadata = MTPSpecMetadata(
max_num_requests=32,
spec_dec_mode=spec_config.spec_dec_mode,
max_draft_tokens=num_nextn_predict_layers,
max_draft_len=num_nextn_predict_layers,
mtp_num_modules=num_nextn_predict_layers,
mtp_hidden_states_manager=spec_manager)
spec_metadata.request_ids = request_ids
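
With MTPConfig folded into the LLM API, the unit tests above construct the config as shown in this minimal sketch; three MTP modules is an assumed example value.

    from tensorrt_llm.llmapi import MTPDecodingConfig

    mtp_num_modules = 3                      # assumed example value
    spec_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_num_modules)
    # The derived speculative-decoding mode is what the MTP spec metadata consumes.
    spec_dec_mode = spec_config.spec_dec_mode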

View File

@ -24,6 +24,10 @@ def test_llama_ngram(disable_overlap_scheduler: bool, use_cuda_graph: bool,
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 20:
pytest.skip("Not enough memory to load target model")
max_batch_size = 2
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None
@ -33,13 +37,13 @@ def test_llama_ngram(disable_overlap_scheduler: bool, use_cuda_graph: bool,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=4,
kv_cache_config=KvCacheConfig(enable_block_reuse=False),
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
max_num_tokens=2048,
)
spec_config = NGramDecodingConfig(
prompt_lookup_num_tokens=4,
max_draft_len=max_draft_len,
max_matching_ngram_size=2,
is_keep_all=True,
is_use_oldest=True,

View File

@ -6,9 +6,9 @@ import pytest
import torch
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.speculative.ngram import (NGramConfig, NGramDrafter,
NGramPoolManager)
from tensorrt_llm._torch.speculative.ngram import NGramDrafter, NGramPoolManager
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
NGramDecodingConfig,
UserProvidedDecodingConfig)
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@ -25,11 +25,13 @@ def test_llama_user_provided(disable_overlap_scheduler: bool,
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 20:
pytest.skip("Not enough memory to load target model")
max_batch_size = 2
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None
max_batch_size = 4
llm_common_config = dict( \
model=llm_models_root() / "llama-3.1-model" /"Meta-Llama-3.1-8B",
backend='pytorch',
@ -37,27 +39,32 @@ def test_llama_user_provided(disable_overlap_scheduler: bool,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=KvCacheConfig(enable_block_reuse=False),
kv_cache_config=kv_cache_config,
max_num_tokens=2048,
)
draft_len = 4
ngram_config = NGramConfig(
prompt_lookup_num_tokens=draft_len,
ngram_config = NGramDecodingConfig(
max_draft_len=max_draft_len,
max_matching_ngram_size=2,
is_keep_all=True,
is_use_oldest=True,
is_public_pool=True,
)
drafter = NGramDrafter(spec_config=ngram_config,
ngram_pool_manager=NGramPoolManager(
spec_config=ngram_config,
max_num_requests=max_batch_size))
ngram_pool_manager = NGramPoolManager(
spec_config=ngram_config,
max_num_requests=max_batch_size,
)
spec_config = UserProvidedDecodingConfig(max_draft_len=draft_len,
drafter=drafter)
drafter = NGramDrafter(
spec_config=ngram_config,
ngram_pool_manager=ngram_pool_manager,
)
spec_config = UserProvidedDecodingConfig(
max_draft_len=max_draft_len,
drafter=drafter,
)
prompts = [
"The capital of France is",

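Consolidating the hunk above into a single self-contained sketch of the user-provided drafter path; the import paths are those used in the test, and the batch size and draft length are the test's values.

    from tensorrt_llm._torch.speculative.ngram import NGramDrafter, NGramPoolManager
    from tensorrt_llm.llmapi import NGramDecodingConfig, UserProvidedDecodingConfig

    max_batch_size = 2
    max_draft_len = 4
    ngram_config = NGramDecodingConfig(
        max_draft_len=max_draft_len,
        max_matching_ngram_size=2,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )
    ngram_pool_manager = NGramPoolManager(
        spec_config=ngram_config,
        max_num_requests=max_batch_size,
    )
    drafter = NGramDrafter(
        spec_config=ngram_config,
        ngram_pool_manager=ngram_pool_manager,
    )
    spec_config = UserProvidedDecodingConfig(
        max_draft_len=max_draft_len,
        drafter=drafter,
    )
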
View File

@ -1128,7 +1128,7 @@ def test_llm_api_medusa():
speculative_config = MedusaDecodingConfig(num_medusa_heads=4,
max_draft_len=63,
speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
speculative_model_dir=get_model_path("medusa-vicuna-7b-v1.3"),
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
[0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
[0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
@ -1167,7 +1167,7 @@ def test_llm_api_medusa_tp2():
speculative_config = MedusaDecodingConfig(num_medusa_heads=4,
max_draft_len=63,
speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
speculative_model_dir=get_model_path("medusa-vicuna-7b-v1.3"),
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
[0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
[0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
@ -1205,7 +1205,7 @@ def test_llm_api_eagle(**llm_kwargs):
speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
speculative_model_dir=get_model_path("EAGLE-Vicuna-7B-v1.3"),
num_eagle_layers=4,
max_non_leaves_per_layer=10,
eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
@ -1252,7 +1252,7 @@ def test_llm_api_eagle2(**llm_kwargs):
speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
speculative_model_dir=get_model_path("EAGLE-Vicuna-7B-v1.3"),
num_eagle_layers=4,
max_non_leaves_per_layer=10,
use_dynamic_tree=True,