commit 7c60356daf
Author: Venky
Date: 2026-01-13 21:25:08 +08:00 (committed by GitHub)
25 changed files with 208 additions and 97 deletions

View File

@ -84,7 +84,7 @@ kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.8
speculative_config:
decoding_type: Eagle
decoding_type: Eagle3
max_draft_len: 3
speculative_model_dir: /config/models/eagle/
cuda_graph_config:

View File

@ -68,7 +68,7 @@ docker run -d --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
-p 8000:8000 --gpus=all -e "TRTLLM_ENABLE_PDL=1" \
-v /path/to/maverick:/config/models/maverick -v /path/to/eagle:/config/models/eagle \
docker.io/<username>/tensorrt_llm:main sh \
-c "echo -e 'enable_autotuner: false\nenable_attention_dp: false\nenable_min_latency: true\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle\n max_draft_len: 3\n speculative_model_dir: /config/models/eagle\n eagle3_one_model: true\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
-c "echo -e 'enable_autotuner: false\nenable_attention_dp: false\nenable_min_latency: true\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle3\n max_draft_len: 3\n speculative_model_dir: /config/models/eagle\n eagle3_one_model: true\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL=True \
trtllm-serve /config/models/maverick \
--host 0.0.0.0 --port 8000 \

View File

@ -53,12 +53,12 @@ The following draft model checkpoints can be used for EAGLE 3:
* Llama 4 Maverick: [use the checkpoint from the NVIDIA HuggingFace repository](https://huggingface.co/nvidia/Llama-4-Maverick-17B-128E-Eagle3).
```python
from tensorrt_llm.llmapi import EagleDecodingConfig
from tensorrt_llm.llmapi import Eagle3DecodingConfig
# Enable to use the faster one-model implementation for Llama 4.
eagle3_one_model = False
speculative_config = EagleDecodingConfig(
speculative_config = Eagle3DecodingConfig(
max_draft_len=3, speculative_model_dir="/path/to/draft_model", eagle3_one_model=eagle3_one_model)
# Only need to disable overlap scheduler if eagle3_one_model is False.
@ -131,16 +131,18 @@ llm = LLM("/path/to/target_model", speculative_config=speculative_config)
Speculative decoding options must be specified via `--config config.yaml` for both `trtllm-bench` and `trtllm-serve`. All speculative decoding options can be specified in this YAML file. An additional `decoding_type` option is used to specify the type of speculation to use. The available options are:
* `MTP`
* `Eagle` (for EAGLE 3)
* `Eagle3`
* `NGram`
* `DraftTarget`
> Note: The PyTorch backend supports only `Eagle3`. `decoding_type: Eagle` is accepted as a backward-compatible alias for `Eagle3`, but EAGLE (v1/v2) draft checkpoints are incompatible.
The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:
```
disable_overlap_scheduler: true
speculative_config:
decoding_type: Eagle
decoding_type: Eagle3
max_draft_len: 4
speculative_model: /path/to/draft/model
```
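For reference, a minimal sketch of an equivalent setup through the Python LLM API, using the config class introduced in this commit (paths are placeholders):
```python
# A hedged sketch with placeholder model paths; mirrors the YAML config above.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import Eagle3DecodingConfig

spec_config = Eagle3DecodingConfig(
    max_draft_len=4,
    speculative_model_dir="/path/to/draft/model",  # placeholder draft checkpoint
)
llm = LLM("/path/to/target_model",                 # placeholder target checkpoint
          disable_overlap_scheduler=True,
          speculative_config=spec_config)
```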

View File

@ -96,7 +96,7 @@ speculative_config:
mtp_eagle_one_model: False # Not supported
speculative_config:
decoding_type: "Eagle"
decoding_type: "Eagle3"
eagle3_one_model: False # Not supported
```

View File

@ -171,6 +171,8 @@ The EAGLE approach enhances the single-model Medusa method by predicting and ver
Similarly to ReDrafter, TensorRT-LLM implements the EAGLE model such that logits prediction, draft-token acceptance, and draft-token generation are performed inside the TensorRT engine (EAGLE-1 and EAGLE-2 are both supported). Please visit the [EAGLE README](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/eagle/README.md) for information about building and running the model.
> **EAGLE3 note.** If the EAGLE3 draft head config omits `draft_vocab_size`, TensorRT-LLM assumes it matches `vocab_size` and emits a warning. Set `draft_vocab_size` explicitly if the draft head uses a different vocabulary.
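For illustration, a hypothetical draft-head configuration expressed as a Python dict (all values are placeholders, styled after the dummy configs used in this commit's tests):
```python
# Hypothetical excerpt of an EAGLE3 draft head's config.json, shown as a Python dict.
# Setting draft_vocab_size explicitly avoids the fallback warning when the draft
# head's vocabulary differs from the target model's.
draft_head_config = {
    "architectures": ["LlamaForCausalLMEagle3"],
    "vocab_size": 128256,       # placeholder target vocabulary size
    "draft_vocab_size": 32000,  # placeholder draft vocabulary size
}
```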
### Disaggregated Serving
[Disaggregated Serving](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/features/disaggregated-service.md) with EAGLE3 using the two-model approach is supported in the PyTorch backend. Refer to the following [Dynamo example](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/llama4_plus_eagle.md) for how to run EAGLE3 with Disaggregated Serving for Llama 4 Maverick.

View File

@ -6,7 +6,7 @@ from typing import Optional
import click
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
from tensorrt_llm.llmapi import (Eagle3DecodingConfig, KvCacheConfig,
MTPDecodingConfig, NGramDecodingConfig)
prompts = [
@ -33,7 +33,7 @@ def run_MTP(model: Optional[str] = None):
def run_Eagle3():
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=3,
speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
eagle3_one_model=True)
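The function would then build an LLM with this config and generate; a hedged sketch (the target model path is a placeholder, and the real example file may finish the function differently):
```python
# Sketch only (placeholder target model path): build an LLM with the Eagle3
# config above and generate from the module-level prompts defined in this file.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          speculative_config=spec_config,
          kv_cache_config=KvCacheConfig(enable_block_reuse=False))
sampling_params = SamplingParams(max_tokens=64, temperature=0)
for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)
```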

View File

@ -5,7 +5,7 @@ import time
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
CudaGraphConfig, DraftTargetDecodingConfig,
EagleDecodingConfig, KvCacheConfig, MoeConfig,
Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
TorchCompileConfig)
@ -222,7 +222,7 @@ def setup_llm(args, **kwargs):
mtp_eagle_one_model=args.use_one_model,
speculative_model_dir=args.model_dir)
elif spec_decode_algo == "EAGLE3":
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=args.spec_decode_max_draft_len,
speculative_model_dir=args.draft_model_dir,
eagle3_one_model=args.use_one_model,

View File

@ -837,8 +837,8 @@ settings for your specific use case.
Qwen3 now supports Eagle3 speculative decoding. To enable Eagle3 on Qwen3, set the following arguments when running `trtllm-bench` or `trtllm-serve`:
- `speculative_config.decoding_type: Eagle`
Set the decoding type to "Eagle" to enable Eagle3 speculative decoding.
- `speculative_config.decoding_type: Eagle3`
Set the decoding type to `Eagle3` to enable Eagle3 speculative decoding.
- `speculative_config.max_draft_len: 3`
Set the maximum number of draft tokens generated per step (this value can be adjusted as needed).
- `speculative_config.speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>`
@ -855,7 +855,7 @@ Example `config.yml` snippet for Eagle3:
echo "
enable_attention_dp: false
speculative_config:
decoding_type: Eagle
decoding_type: Eagle3
max_draft_len: 3
speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>
kv_cache_config:

View File

@ -24,12 +24,18 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
vision_encoder_cls, vlm_base_model = vision_encoder_info
return vision_encoder_cls(config, vlm_base_model)
# Hack to detect eagle3 checkpoints. TODO: should we provide
# our own checkpoints with the correct arch? It would let us
# avoid nasty stuff like this.
model_arch = model_arch.replace("Eagle3",
"") # Strip the appended EAGLE3
# Hack to detect eagle3 checkpoints.
# Why it exists:
# - Eagle3 checkpoints have draft_vocab_size in config.json (even if None)
# - Some community checkpoints append "Eagle3" to architecture names ("LlamaForCausalLMEagle3")
# - Some checkpoints don't include "Eagle3" in arch name at all ("LlamaForCausalLM")
# - TensorRT-LLM's MODEL_CLASS_MAPPING expects prefixed names like EAGLE3LlamaForCausalLM
# - Hence: LlamaForCausalLMEagle3 -> EAGLE3LlamaForCausalLM
# LlamaForCausalLM (with draft_vocab_size) -> EAGLE3LlamaForCausalLM
# TODO: should we provide our own checkpoints with the correct arch? It would let us avoid nasty stuff like this.
if hasattr(config.pretrained_config, "draft_vocab_size"):
# It's an Eagle3 checkpoint - strip "Eagle3" suffix if present, then add prefix
model_arch = model_arch.replace("Eagle3", "")
model_arch = "EAGLE3" + model_arch
if model_arch in (
"DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"

View File

@ -4,6 +4,8 @@ import torch
from torch import nn
from transformers import LlamaConfig, PretrainedConfig
from tensorrt_llm.logger import logger
from ...functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@ -24,6 +26,18 @@ from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
register_auto_model)
def _ensure_draft_vocab_size(config: PretrainedConfig) -> None:
if hasattr(config,
"draft_vocab_size") and config.draft_vocab_size is not None:
return
logger.warning(
"Missing 'draft_vocab_size' in pretrained config; defaulting to 'vocab_size'. "
"Set 'draft_vocab_size' explicitly if the draft head uses a different vocabulary."
)
config.draft_vocab_size = config.vocab_size
class Eagle3Attention(Attention):
def __init__(
@ -417,9 +431,8 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel,
model_config: ModelConfig[PretrainedConfig],
start_layer_idx: int = 0,
):
draft_vocab_size = model_config.pretrained_config.vocab_size
if model_config.pretrained_config.draft_vocab_size is not None:
draft_vocab_size = model_config.pretrained_config.draft_vocab_size
config = model_config.pretrained_config
_ensure_draft_vocab_size(config)
# Determine if we should use MLA attention based on config
# MLA is used for DeepSeekV3-style models that have kv_lora_rank
@ -435,8 +448,8 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel,
super().__init__(
draft_model,
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=draft_vocab_size,
hidden_size=config.hidden_size,
vocab_size=config.draft_vocab_size,
)
self.load_lm_head_from_target = True
@ -598,6 +611,7 @@ class MistralLarge3DraftModel(DecoderModel):
# We use MistralLarge3 as the base architecture for EAGLE3 draft layers
# NOTE: Class name says "Eagle" not "Eagle3" to match checkpoint naming (e.g., "Mistral-Large-3-675B-Instruct-2512-Eagle")
@register_auto_model("MistralLarge3EagleForCausalLM")
class MistralLarge3EagleForCausalLM(DecoderModelForCausalLM):

View File

@ -10,10 +10,11 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
CapacitySchedulerPolicy, ContextChunkingPolicy,
CudaGraphConfig, DeepSeekSparseAttentionConfig,
DraftTargetDecodingConfig, DynamicBatchConfig,
EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
MedusaDecodingConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, RocketSparseAttentionConfig,
Eagle3DecodingConfig, EagleDecodingConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
RocketSparseAttentionConfig,
SaveHiddenStatesDecodingConfig, SchedulerConfig,
SkipSoftmaxAttentionConfig, TorchCompileConfig,
TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig)
@ -38,6 +39,7 @@ __all__ = [
'LookaheadDecodingConfig',
'MedusaDecodingConfig',
'EagleDecodingConfig',
'Eagle3DecodingConfig',
'MTPDecodingConfig',
'SchedulerConfig',
'CapacitySchedulerPolicy',

View File

@ -708,6 +708,8 @@ class DecodingBaseConfig(StrictBaseModel):
_allow_chain_drafter: bool = PrivateAttr(True)
# If set, drafting uses greedy sampling, irrespective of sampling parameters.
_allow_greedy_draft_tokens: bool = PrivateAttr(True)
# Internal: record decoding_type alias used during parsing (for warnings).
_decoding_type_alias: Optional[str] = PrivateAttr(default=None)
@field_validator('draft_len_schedule')
@classmethod
@ -755,13 +757,14 @@ class DecodingBaseConfig(StrictBaseModel):
return v
@classmethod
def from_dict(cls, data: dict):
def from_dict(cls, data: dict, backend: Optional[str] = None):
# dispatch to the correct decoding config
decoding_type = data.get("decoding_type")
config_classes = {
"MTP": MTPDecodingConfig,
"Medusa": MedusaDecodingConfig,
"Eagle": EagleDecodingConfig,
"Eagle3": Eagle3DecodingConfig,
"Lookahead": LookaheadDecodingConfig,
"NGram": NGramDecodingConfig,
"DraftTarget": DraftTargetDecodingConfig,
@ -770,6 +773,14 @@ class DecodingBaseConfig(StrictBaseModel):
"AUTO": AutoDecodingConfig,
}
backend = backend.lower() if isinstance(backend, str) else backend
if decoding_type == "Eagle" and backend in ("pytorch", "_autodeploy"):
data = dict(data)
data.pop("decoding_type")
spec_cfg = Eagle3DecodingConfig(**data)
spec_cfg._decoding_type_alias = "Eagle"
return spec_cfg
config_class = config_classes.get(decoding_type)
if config_class is None:
raise ValueError(f"Invalid decoding type: {decoding_type}")
@ -966,6 +977,10 @@ class EagleDecodingConfig(DecodingBaseConfig):
return False
class Eagle3DecodingConfig(EagleDecodingConfig):
decoding_type: ClassVar[str] = "Eagle3"
class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
output_directory: str
write_interval: int = 20
@ -2506,9 +2521,14 @@ class TrtLlmArgs(BaseLlmArgs):
decoding_mode=DecodingMode.Medusa(),
medusa_choices=self.speculative_config.medusa_choices)
elif isinstance(self.speculative_config, Eagle3DecodingConfig):
raise ValueError(
"speculative_config.decoding_type 'Eagle3' is only supported on the PyTorch backend. "
"Use decoding_type 'Eagle' for the TensorRT backend.")
elif isinstance(self.speculative_config, EagleDecodingConfig):
assert self.speculative_config.max_draft_len > 0
assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE weights must be specified."
self.build_config.max_draft_len = self.speculative_config.max_draft_len
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
eagle_config = _EagleConfig(
@ -3024,6 +3044,14 @@ class TorchLlmArgs(BaseLlmArgs):
f"support backend {self.backend}")
if isinstance(self.speculative_config, EagleDecodingConfig):
if (getattr(self.speculative_config, "_decoding_type_alias",
None) == "Eagle" or type(self.speculative_config)
is EagleDecodingConfig):
logger.warning(
"speculative_config.decoding_type 'Eagle' is not supported on the PyTorch backend; only 'Eagle3' is supported. "
"'Eagle' is treated as 'Eagle3' for backward compatibility. "
"EAGLE (v1/v2) draft checkpoints are incompatible with Eagle3—use an Eagle3 draft model."
)
assert self.speculative_config.max_draft_len > 0
assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
elif isinstance(self.speculative_config, NGramDecodingConfig):
@ -3323,8 +3351,14 @@ def update_llm_args_with_extra_dict(
if field_name in llm_args_dict:
# Some fields need to be converted manually.
if field_name in ["speculative_config", "sparse_attention_config"]:
llm_args_dict[field_name] = field_type.from_dict(
llm_args_dict[field_name])
if field_name == "speculative_config":
backend = llm_args_dict.get("backend") or llm_args.get(
"backend")
llm_args_dict[field_name] = field_type.from_dict(
llm_args_dict[field_name], backend=backend)
else:
llm_args_dict[field_name] = field_type.from_dict(
llm_args_dict[field_name])
else:
llm_args_dict[field_name] = field_type(
**llm_args_dict[field_name])
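A hedged usage sketch of the backend-aware dispatch added above, following the pattern in this commit's unit tests (the draft path is a placeholder):
```python
from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig, Eagle3DecodingConfig

# On the PyTorch backend, the legacy decoding_type "Eagle" is parsed as Eagle3;
# the warning about EAGLE v1/v2 checkpoint incompatibility is emitted later,
# during TorchLlmArgs validation.
spec_cfg = DecodingBaseConfig.from_dict(
    dict(decoding_type="Eagle",
         max_draft_len=3,
         speculative_model_dir="/path/to/eagle3/draft"),  # placeholder path
    backend="pytorch")
assert isinstance(spec_cfg, Eagle3DecodingConfig)
```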

View File

@ -30,8 +30,8 @@ from ..module import Module
from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
get_build_cache_config_from_env)
from .llm_args import (CalibConfig, CudaGraphConfig, DraftTargetDecodingConfig,
EagleDecodingConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig,
Eagle3DecodingConfig, EagleDecodingConfig, KvCacheConfig,
LlmArgs, LookaheadDecodingConfig, MedusaDecodingConfig,
MTPDecodingConfig, NGramDecodingConfig,
UserProvidedDecodingConfig, _ModelFormatKind,
_ModelWrapper, _ParallelConfig,
@ -923,6 +923,7 @@ __all__ = [
'KvCacheConfig',
'CachedModelLoader',
'EagleDecodingConfig',
'Eagle3DecodingConfig',
'update_llm_args_with_extra_dict',
'update_llm_args_with_extra_options',
]

View File

@ -52,7 +52,7 @@ from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
DeepSeekSparseAttentionConfig,
EagleDecodingConfig, KvCacheConfig, MoeConfig,
Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
RocketSparseAttentionConfig, SamplingParams,
SkipSoftmaxAttentionConfig, TorchCompileConfig)
@ -275,9 +275,10 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
draft_len = 4
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
spec_config = Eagle3DecodingConfig(
max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
with LLM(model=target_model_dir,
**pytorch_config,
@ -367,7 +368,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
cuda_graph_config = CudaGraphConfig(enable_padding=True)
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=3,
speculative_model_dir=
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
@ -620,9 +621,10 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp8"
eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.3-Instruct-70B"
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
spec_config = EagleDecodingConfig(max_draft_len=3,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
spec_config = Eagle3DecodingConfig(
max_draft_len=3,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
torch_compile_config = _get_default_torch_compile_config(torch_compile)
pytorch_config = dict(
disable_overlap_scheduler=not eagle3_one_model,
@ -3440,9 +3442,10 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B"
draft_len = 4
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
spec_config = Eagle3DecodingConfig(
max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
llm = LLM(model=target_model_dir,
**pytorch_config,
@ -3810,7 +3813,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
enable_block_reuse=not eagle3)
spec_config = None
if eagle3:
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=2,
speculative_model_dir=
f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/",
@ -3858,7 +3861,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
enable_block_reuse=not eagle3)
spec_config = None
if eagle3:
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=2,
speculative_model_dir=
f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/",
@ -4478,10 +4481,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model,
allow_advanced_sampling=True)
spec_config = Eagle3DecodingConfig(
max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model,
allow_advanced_sampling=True)
max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
llm = LLM(self.MODEL_PATH,
@ -4544,10 +4548,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model,
allow_advanced_sampling=True)
spec_config = Eagle3DecodingConfig(
max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model,
allow_advanced_sampling=True)
max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
llm = LLM(self.MODEL_PATH,
@ -4608,10 +4613,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model,
allow_advanced_sampling=True)
spec_config = Eagle3DecodingConfig(
max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model,
allow_advanced_sampling=True)
max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
llm = LLM(self.MODEL_PATH,
@ -4667,9 +4673,10 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model)
spec_config = Eagle3DecodingConfig(
max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model)
max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
llm = LLM(self.MODEL_PATH,
@ -5148,7 +5155,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
enable_block_reuse=not eagle3)
spec_config = None
if eagle3:
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=2,
speculative_model_dir=
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",
@ -5199,7 +5206,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
enable_block_reuse=not eagle3)
spec_config = None
if eagle3:
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=2,
speculative_model_dir=
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",

View File

@ -13,7 +13,7 @@ from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams
from tensorrt_llm._utils import set_mpi_comm
from tensorrt_llm.llmapi import (CacheTransceiverConfig, CudaGraphConfig,
KvCacheConfig, MpiCommSession)
from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig
from tensorrt_llm.llmapi.llm_args import Eagle3DecodingConfig
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@ -399,7 +399,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
eagle3_one_model):
# Test whether the batch slots are properly released when using speculative decoding
# with disaggregated serving.
spec_dec_config = EagleDecodingConfig(
spec_dec_config = Eagle3DecodingConfig(
speculative_model_dir=model_path(spec_dec_model_path),
eagle3_one_model=eagle3_one_model,
max_draft_len=3)

View File

@ -21,7 +21,7 @@ from defs.conftest import llm_models_root
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.llm import LLM
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, EagleDecodingConfig, KvCacheConfig
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig
prompts = [
"What is the capital of France?",
@ -57,7 +57,7 @@ def make_draft_target_config(spec_model_path: str):
def make_eagle3_config(spec_model_path: str):
return EagleDecodingConfig(
return Eagle3DecodingConfig(
max_draft_len=EAGLE_MAX_DRAFT_LEN,
speculative_model_dir=spec_model_path,
eagle3_one_model=False,
@ -214,7 +214,7 @@ def test_autodeploy_eagle3_acceptance_rate():
max_draft_len = EAGLE_MAX_DRAFT_LEN
# Configure Eagle3 speculative decoding
speculative_config = EagleDecodingConfig(
speculative_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model,
eagle3_one_model=False,

View File

@ -3375,7 +3375,7 @@ def test_eagle3_output_consistency_4gpus(model_dir: str, draft_model_dir: str):
RCCA: https://nvbugspro.nvidia.com/bug/5575211
"""
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
KvCacheConfig)
models_path = llm_models_root()
@ -3412,7 +3412,7 @@ def test_eagle3_output_consistency_4gpus(model_dir: str, draft_model_dir: str):
sampling_params = SamplingParams(max_tokens=1024, temperature=0)
# Run with Eagle3
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=3,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=True,

View File

@ -10,7 +10,7 @@ from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata
from tensorrt_llm._torch.speculative.drafting_loops import TreeDraftingLoopWrapper
from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager
from tensorrt_llm.llmapi import EagleDecodingConfig
from tensorrt_llm.llmapi import Eagle3DecodingConfig
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
@ -84,7 +84,7 @@ def test_draft_token_static_tree_prepare_for_generation():
)
# 2) Create spec metadata
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
speculative_model_dir=eagle_model_dir,

View File

@ -8,7 +8,7 @@ from utils.llm_data import llm_models_root
from tensorrt_llm._torch.speculative.drafting_loops import \
TreeDraftingLoopWrapper
from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager
from tensorrt_llm.llmapi import EagleDecodingConfig
from tensorrt_llm.llmapi import Eagle3DecodingConfig
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@ -35,7 +35,7 @@ def test_draft_token_static_tree_sampling():
def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens,
max_draft_len, eagle_choices, logits, use_cuda_graph,
ref_new_tokens):
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
speculative_model_dir=eagle_model_dir,

View File

@ -11,7 +11,7 @@ from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager
from tensorrt_llm.bindings.executor import FinishReason
from tensorrt_llm.llmapi import EagleDecodingConfig
from tensorrt_llm.llmapi import Eagle3DecodingConfig
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@ -20,7 +20,7 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
max_new_tokens, max_batch_size, input_request, input_new_tokens,
draft_layer_id, max_total_draft_tokens, max_draft_len,
eagle_choices, ref_num_accepted_draft_tokens, ref_mtokens):
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
speculative_model_dir=eagle_model_dir,

View File

@ -9,7 +9,7 @@ from utils.llm_data import llm_models_root
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
KvCacheConfig)
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@ -54,7 +54,7 @@ def test_dynamic_spec_decode(enforce_single_worker,
max_seq_len=8192,
)
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.

View File

@ -13,7 +13,7 @@ from utils.llm_data import llm_models_root
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
KvCacheConfig)
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@ -163,7 +163,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
@ -239,7 +239,7 @@ def test_eagle3_spec_decoding_stats(eagle3_one_model):
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
free_gpu_memory_fraction=0.6)
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=3,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=eagle3_one_model,
@ -319,7 +319,7 @@ def test_llama_eagle3_long_prompt(use_cuda_graph):
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=3,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=False,
@ -413,7 +413,7 @@ def test_deepseek_eagle3():
'transformers_version': '4.52.4',
'use_cache': True,
'vocab_size': 129280,
'draft_vocab_size': 129280
'draft_vocab_size': 129280,
}
with tempfile.TemporaryDirectory() as temp_dir:
eagle_model_dir = Path(temp_dir)
@ -443,7 +443,7 @@ def test_deepseek_eagle3():
enable_chunked_prefill=enable_chunked_prefill,
)
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
@ -554,10 +554,11 @@ def test_deepseek_mla_eagle3():
load_format="dummy",
)
spec_config = EagleDecodingConfig(max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=use_one_model,
load_format="dummy")
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=use_one_model,
load_format="dummy")
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
@ -623,7 +624,7 @@ def test_multi_eagle3(use_one_model: bool):
'transformers_version': '4.52.4',
'use_cache': True,
'vocab_size': 128256,
'draft_vocab_size': 128256
'draft_vocab_size': 128256,
}
with tempfile.TemporaryDirectory() as temp_dir:
eagle_model_dir = Path(temp_dir)
@ -652,7 +653,7 @@ def test_multi_eagle3(use_one_model: bool):
load_format="dummy",
)
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
@ -711,7 +712,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
enable_chunked_prefill=enable_chunked_prefill,
)
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=use_one_model,
@ -764,7 +765,7 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool):
enable_chunked_prefill=enable_chunked_prefill,
)
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=use_one_model,

View File

@ -7,7 +7,7 @@ import torch
from utils.llm_data import llm_models_root
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
KvCacheConfig)
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@ -50,7 +50,7 @@ def test_kv_cache_reuse(use_cuda_graph: bool, attn_backend: str):
max_seq_len=8192,
)
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=False,

View File

@ -9,7 +9,7 @@ from utils.util import similar, skip_blackwell
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
KvCacheConfig)
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@ -45,7 +45,7 @@ def test_spec_gate_e2e():
max_seq_len=4096,
)
spec_config = EagleDecodingConfig(
spec_config = Eagle3DecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.

View File

@ -139,6 +139,48 @@ max_seq_len: 128
assert llm_args.max_seq_len == 128
def test_decoding_type_eagle3_parses_to_eagle3_decoding_config():
spec_cfg = DecodingBaseConfig.from_dict(
dict(decoding_type="Eagle3",
max_draft_len=3,
speculative_model_dir="/path/to/draft/model"))
assert isinstance(spec_cfg, Eagle3DecodingConfig)
def test_decoding_type_eagle_warns_on_pytorch_backend(monkeypatch):
import tensorrt_llm.llmapi.llm_args as llm_args_mod
warnings_seen: list[str] = []
def _capture_warning(msg, *args, **kwargs):
warnings_seen.append(str(msg))
monkeypatch.setattr(llm_args_mod.logger, "warning", _capture_warning)
spec_cfg = DecodingBaseConfig.from_dict(dict(
decoding_type="Eagle",
max_draft_len=3,
speculative_model_dir="/path/to/draft/model"),
backend="pytorch")
assert isinstance(spec_cfg, Eagle3DecodingConfig)
TorchLlmArgs(model=llama_model_path, speculative_config=spec_cfg)
assert any(
"EAGLE (v1/v2) draft checkpoints are incompatible with Eagle3" in m
for m in warnings_seen)
def test_decoding_type_eagle3_errors_on_tensorrt_backend():
spec_cfg = DecodingBaseConfig.from_dict(
dict(decoding_type="Eagle3",
max_draft_len=3,
speculative_model_dir="/path/to/draft/model"))
with pytest.raises(ValueError,
match="only supported on the PyTorch backend"):
TrtLlmArgs(model=llama_model_path, speculative_config=spec_cfg)
def check_defaults(py_config_cls, pybind_config_cls):
py_config = py_config_cls()
pybind_config = pybind_config_cls()