From b3146d095dde96d4b1fbb50d444fe50f3d00d446 Mon Sep 17 00:00:00 2001 From: Venky <23023424+venkywonka@users.noreply.github.com> Date: Thu, 22 Jan 2026 23:24:11 +0800 Subject: [PATCH] [TRTC-122][feat] Eagle3 Specdec UX improvements (#10124) Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- .../blogs/tech_blog/blog11_GPT_OSS_Eagle3.md | 2 +- .../blog6_Llama4_maverick_eagle_guide.md | 2 +- docs/source/features/speculative-decoding.md | 10 ++-- .../torch_compile_and_piecewise_cuda_graph.md | 2 +- .../legacy/advanced/speculative-decoding.md | 2 + examples/llm-api/llm_speculative_decoding.py | 4 +- examples/llm-api/quickstart_advanced.py | 4 +- examples/models/core/qwen/README.md | 6 +- tensorrt_llm/_torch/models/modeling_auto.py | 16 +++-- .../_torch/models/modeling_speculative.py | 24 ++++++-- tensorrt_llm/llmapi/__init__.py | 10 ++-- tensorrt_llm/llmapi/llm_args.py | 40 ++++++++++++- tensorrt_llm/llmapi/llm_utils.py | 5 +- .../accuracy/references/cnn_dailymail.yaml | 7 +++ .../accuracy/references/gpqa_diamond.yaml | 8 +++ .../defs/accuracy/references/gsm8k.yaml | 26 ++++++++ .../accuracy/references/json_mode_eval.yaml | 8 +++ .../defs/accuracy/references/mmlu.yaml | 23 +++++++ .../defs/accuracy/test_llm_api_pytorch.py | 60 +++++++++---------- .../test_disaggregated_single_gpu.py | 4 +- .../examples/test_ad_speculative_decoding.py | 6 +- tests/integration/defs/test_e2e.py | 4 +- ...test_draft_token_prepare_for_generation.py | 4 +- .../test_draft_token_tree_sampling.py | 4 +- .../test_draft_token_tree_verification.py | 4 +- .../speculative/test_dynamic_spec_decode.py | 4 +- .../_torch/speculative/test_eagle3.py | 28 ++++----- .../_torch/speculative/test_kv_cache_reuse.py | 4 +- .../_torch/speculative/test_spec_gate.py | 4 +- tests/unittest/llmapi/test_llm_args.py | 42 +++++++++++++ 30 files changed, 271 insertions(+), 96 deletions(-) diff --git a/docs/source/blogs/tech_blog/blog11_GPT_OSS_Eagle3.md b/docs/source/blogs/tech_blog/blog11_GPT_OSS_Eagle3.md index 3b2ddfa782..7a7fc8ce9f 100644 --- a/docs/source/blogs/tech_blog/blog11_GPT_OSS_Eagle3.md +++ b/docs/source/blogs/tech_blog/blog11_GPT_OSS_Eagle3.md @@ -84,7 +84,7 @@ kv_cache_config: enable_block_reuse: false free_gpu_memory_fraction: 0.8 speculative_config: - decoding_type: Eagle + decoding_type: Eagle3 max_draft_len: 3 speculative_model_dir: /config/models/eagle/ cuda_graph_config: diff --git a/docs/source/blogs/tech_blog/blog6_Llama4_maverick_eagle_guide.md b/docs/source/blogs/tech_blog/blog6_Llama4_maverick_eagle_guide.md index 5ebb4e3cbb..9f49d7920b 100644 --- a/docs/source/blogs/tech_blog/blog6_Llama4_maverick_eagle_guide.md +++ b/docs/source/blogs/tech_blog/blog6_Llama4_maverick_eagle_guide.md @@ -68,7 +68,7 @@ docker run -d --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ -p 8000:8000 --gpus=all -e "TRTLLM_ENABLE_PDL=1" \ -v /path/to/maverick:/config/models/maverick -v /path/to/eagle:/config/models/eagle \ docker.io//tensorrt_llm:main sh \ - -c "echo -e 'enable_autotuner: false\nenable_attention_dp: false\nenable_min_latency: true\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle\n max_draft_len: 3\n speculative_model_dir: /config/models/eagle\n eagle3_one_model: true\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \ + -c "echo -e 'enable_autotuner: false\nenable_attention_dp: false\nenable_min_latency: true\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle3\n max_draft_len: 3\n speculative_model_dir: 
/config/models/eagle\n eagle3_one_model: true\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \ TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL=True \ trtllm-serve /config/models/maverick \ --host 0.0.0.0 --port 8000 \ diff --git a/docs/source/features/speculative-decoding.md b/docs/source/features/speculative-decoding.md index 64e52558b2..d77ee797a5 100644 --- a/docs/source/features/speculative-decoding.md +++ b/docs/source/features/speculative-decoding.md @@ -59,14 +59,14 @@ The following draft model checkpoints can be used for EAGLE 3: * Other models, including `gpt-oss-120b` and `Qwen3`: check out the [Speculative Decoding Modules](https://huggingface.co/collections/nvidia/speculative-decoding-modules) collection from NVIDIA. ```python -from tensorrt_llm.llmapi import EagleDecodingConfig +from tensorrt_llm.llmapi import Eagle3DecodingConfig # Enable to use the faster one-model implementation for Llama 4. eagle3_one_model = False model = "meta-llama/Llama-3.1-8B-Instruct" speculative_model = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" -speculative_config = EagleDecodingConfig( +speculative_config = Eagle3DecodingConfig( max_draft_len=3, speculative_model=speculative_model, eagle3_one_model=eagle3_one_model) @@ -141,10 +141,12 @@ llm = LLM("/path/to/target_model", speculative_config=speculative_config) Speculative decoding options must be specified via `--config config.yaml` for both `trtllm-bench` and `trtllm-serve`. All speculative decoding options can be specified in this YAML file. An additional `decoding_type` option is used to specify the type of speculation to use. The available options are: * `MTP` -* `Eagle` (for EAGLE 3) +* `Eagle3` * `NGram` * `DraftTarget` +> Note: The PyTorch backend supports only `Eagle3`. `decoding_type: Eagle` is accepted as a backward-compatible alias for `Eagle3`, but EAGLE (v1/v2) draft checkpoints are incompatible. + The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this: ```yaml @@ -160,7 +162,7 @@ speculative_config: # Or using a local path disable_overlap_scheduler: true speculative_config: - decoding_type: Eagle + decoding_type: Eagle3 max_draft_len: 4 speculative_model: /path/to/draft/model ``` diff --git a/docs/source/features/torch_compile_and_piecewise_cuda_graph.md b/docs/source/features/torch_compile_and_piecewise_cuda_graph.md index 5fab5e09d0..a125157deb 100644 --- a/docs/source/features/torch_compile_and_piecewise_cuda_graph.md +++ b/docs/source/features/torch_compile_and_piecewise_cuda_graph.md @@ -96,7 +96,7 @@ speculative_config: mtp_eagle_one_model: False # Not supported speculative_config: - decoding_type: "Eagle" + decoding_type: "Eagle3" eagle3_one_model: False # Not supported ``` diff --git a/docs/source/legacy/advanced/speculative-decoding.md b/docs/source/legacy/advanced/speculative-decoding.md index b393ae8949..2faf885f8d 100644 --- a/docs/source/legacy/advanced/speculative-decoding.md +++ b/docs/source/legacy/advanced/speculative-decoding.md @@ -171,6 +171,8 @@ The EAGLE approach enhances the single-model Medusa method by predicting and ver Similarly to ReDrafter, TensorRT-LLM implements the EAGLE model such that logits prediction, draft tokens acceptance and draft token generation are performed inside of the TensorRT engine(EAGLE-1 and EAGLE-2 are both supported). 
Please, visit the [EAGLE README](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/eagle/README.md) for information about building and running the model. +> **EAGLE3 note.** If the EAGLE3 draft head config omits `draft_vocab_size`, TensorRT-LLM assumes it matches `vocab_size` and emits a warning. Set `draft_vocab_size` explicitly if the draft head uses a different vocabulary. + ### Disaggregated Serving [Disaggregated Serving](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/features/disaggregated-service.md) with EAGLE3 using the two model approach is supported in the Pytorch backend. Please refer to the following [Dynamo example](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/llama4_plus_eagle.md) on how to run EAGLE3 with Disaggregated Serving for Llama 4 Maverick. diff --git a/examples/llm-api/llm_speculative_decoding.py b/examples/llm-api/llm_speculative_decoding.py index de33278a09..4f19dd6df3 100644 --- a/examples/llm-api/llm_speculative_decoding.py +++ b/examples/llm-api/llm_speculative_decoding.py @@ -6,7 +6,7 @@ from typing import Optional import click from tensorrt_llm import LLM, SamplingParams -from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig, +from tensorrt_llm.llmapi import (Eagle3DecodingConfig, KvCacheConfig, MTPDecodingConfig, NGramDecodingConfig) prompts = [ @@ -33,7 +33,7 @@ def run_MTP(model: Optional[str] = None): def run_Eagle3(): - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=3, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", eagle3_one_model=True) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index c0bb4e31be..1ea69df77c 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -5,7 +5,7 @@ import time from tensorrt_llm import LLM, SamplingParams from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig, CudaGraphConfig, DraftTargetDecodingConfig, - EagleDecodingConfig, KvCacheConfig, MoeConfig, + Eagle3DecodingConfig, KvCacheConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, TorchCompileConfig) @@ -222,7 +222,7 @@ def setup_llm(args, **kwargs): mtp_eagle_one_model=args.use_one_model, speculative_model=args.model_dir) elif spec_decode_algo == "EAGLE3": - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=args.spec_decode_max_draft_len, speculative_model=args.draft_model_dir, eagle3_one_model=args.use_one_model, diff --git a/examples/models/core/qwen/README.md b/examples/models/core/qwen/README.md index 566d4eab1b..f6c9273927 100644 --- a/examples/models/core/qwen/README.md +++ b/examples/models/core/qwen/README.md @@ -837,8 +837,8 @@ settings for your specific use case. Qwen3 now supports Eagle3 (Speculative Decoding with Eagle3). To enable Eagle3 on Qwen3, you need to set the following arguments when running `trtllm-bench` or `trtllm-serve`: -- `speculative_config.decoding_type: Eagle` - Set the decoding type to "Eagle" to enable Eagle3 speculative decoding. +- `speculative_config.decoding_type: Eagle3` + Set the decoding type to `Eagle3` to enable Eagle3 speculative decoding. - `speculative_config.max_draft_len: 3` Set the maximum number of draft tokens generated per step (this value can be adjusted as needed). 
- `speculative_config.speculative_model: ` @@ -855,7 +855,7 @@ Example `config.yml` snippet for Eagle3: echo " enable_attention_dp: false speculative_config: - decoding_type: Eagle + decoding_type: Eagle3 max_draft_len: 3 speculative_model: kv_cache_config: diff --git a/tensorrt_llm/_torch/models/modeling_auto.py b/tensorrt_llm/_torch/models/modeling_auto.py index 738110d287..86c87009b4 100644 --- a/tensorrt_llm/_torch/models/modeling_auto.py +++ b/tensorrt_llm/_torch/models/modeling_auto.py @@ -24,12 +24,18 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]): vision_encoder_cls, vlm_base_model = vision_encoder_info return vision_encoder_cls(config, vlm_base_model) - # Hack to detect eagle3 checkpoints. TODO: should we provide - # our own checkpoints with the correct arch? It would let us - # avoid nasty stuff like this. - model_arch = model_arch.replace("Eagle3", - "") # Strip the appended EAGLE3 + # Hack to detect eagle3 checkpoints. + # Why it exists: + # - Eagle3 checkpoints have draft_vocab_size in config.json (even if None) + # - Some community checkpoints append "Eagle3" to architecture names ("LlamaForCausalLMEagle3") + # - Some checkpoints don't include "Eagle3" in arch name at all ("LlamaForCausalLM") + # - TensorRT-LLM's MODEL_CLASS_MAPPING expects prefixed names like EAGLE3LlamaForCausalLM + # - Hence: LlamaForCausalLMEagle3 -> EAGLE3LlamaForCausalLM + # LlamaForCausalLM (with draft_vocab_size) -> EAGLE3LlamaForCausalLM + # TODO: should we provide our own checkpoints with the correct arch? It would let us avoid nasty stuff like this. if hasattr(config.pretrained_config, "draft_vocab_size"): + # It's an Eagle3 checkpoint - strip "Eagle3" suffix if present, then add prefix + model_arch = model_arch.replace("Eagle3", "") model_arch = "EAGLE3" + model_arch if model_arch in ( "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM", diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index f59d63b6c8..6792d06393 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -4,6 +4,8 @@ import torch from torch import nn from transformers import LlamaConfig, PretrainedConfig +from tensorrt_llm.logger import logger + from ...functional import PositionEmbeddingType from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams @@ -24,6 +26,18 @@ from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel, register_auto_model) +def _ensure_draft_vocab_size(config: PretrainedConfig) -> None: + if hasattr(config, + "draft_vocab_size") and config.draft_vocab_size is not None: + return + + logger.warning( + "Missing 'draft_vocab_size' in pretrained config; defaulting to 'vocab_size'. " + "Set 'draft_vocab_size' explicitly if the draft head uses a different vocabulary." 
+ ) + config.draft_vocab_size = config.vocab_size + + class Eagle3Attention(Attention): def __init__( @@ -417,9 +431,8 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, model_config: ModelConfig[PretrainedConfig], start_layer_idx: int = 0, ): - draft_vocab_size = model_config.pretrained_config.vocab_size - if model_config.pretrained_config.draft_vocab_size is not None: - draft_vocab_size = model_config.pretrained_config.draft_vocab_size + config = model_config.pretrained_config + _ensure_draft_vocab_size(config) # Determine if we should use MLA attention based on config # MLA is used for DeepSeekV3-style models that have kv_lora_rank @@ -435,8 +448,8 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, super().__init__( draft_model, config=model_config, - hidden_size=model_config.pretrained_config.hidden_size, - vocab_size=draft_vocab_size, + hidden_size=config.hidden_size, + vocab_size=config.draft_vocab_size, ) self.load_lm_head_from_target = True @@ -598,6 +611,7 @@ class MistralLarge3DraftModel(DecoderModel): # We use MistralLarge3 as the base architecture for EAGLE3 draft layers +# NOTE: Class name says "Eagle" not "Eagle3" to match checkpoint naming (e.g., "Mistral-Large-3-675B-Instruct-2512-Eagle") @register_auto_model("MistralLarge3EagleForCausalLM") class MistralLarge3EagleForCausalLM(DecoderModelForCausalLM): diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py index b16a6a8acd..9271f51f55 100644 --- a/tensorrt_llm/llmapi/__init__.py +++ b/tensorrt_llm/llmapi/__init__.py @@ -10,10 +10,11 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType, CapacitySchedulerPolicy, ContextChunkingPolicy, CudaGraphConfig, DeepSeekSparseAttentionConfig, DraftTargetDecodingConfig, DynamicBatchConfig, - EagleDecodingConfig, ExtendedRuntimePerfKnobConfig, - KvCacheConfig, LlmArgs, LookaheadDecodingConfig, - MedusaDecodingConfig, MoeConfig, MTPDecodingConfig, - NGramDecodingConfig, RocketSparseAttentionConfig, + Eagle3DecodingConfig, EagleDecodingConfig, + ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs, + LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig, + MTPDecodingConfig, NGramDecodingConfig, + RocketSparseAttentionConfig, SaveHiddenStatesDecodingConfig, SchedulerConfig, SkipSoftmaxAttentionConfig, TorchCompileConfig, TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig) @@ -38,6 +39,7 @@ __all__ = [ 'LookaheadDecodingConfig', 'MedusaDecodingConfig', 'EagleDecodingConfig', + 'Eagle3DecodingConfig', 'MTPDecodingConfig', 'SchedulerConfig', 'CapacitySchedulerPolicy', diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 1b5d3fad76..124da909a6 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -715,6 +715,8 @@ class DecodingBaseConfig(StrictBaseModel): _allow_chain_drafter: bool = PrivateAttr(True) # If set, drafting uses greedy sampling, irrespective of sampling parameters. _allow_greedy_draft_tokens: bool = PrivateAttr(True) + # Internal: record decoding_type alias used during parsing (for warnings). 
+ _decoding_type_alias: Optional[str] = PrivateAttr(default=None) @field_validator('draft_len_schedule') @classmethod @@ -762,13 +764,14 @@ class DecodingBaseConfig(StrictBaseModel): return v @classmethod - def from_dict(cls, data: dict): + def from_dict(cls, data: dict, backend: Optional[str] = None): # dispatch to the correct decoding config decoding_type = data.get("decoding_type") config_classes = { "MTP": MTPDecodingConfig, "Medusa": MedusaDecodingConfig, "Eagle": EagleDecodingConfig, + "Eagle3": Eagle3DecodingConfig, "Lookahead": LookaheadDecodingConfig, "NGram": NGramDecodingConfig, "DraftTarget": DraftTargetDecodingConfig, @@ -777,6 +780,14 @@ class DecodingBaseConfig(StrictBaseModel): "AUTO": AutoDecodingConfig, } + backend = backend.lower() if isinstance(backend, str) else backend + if decoding_type == "Eagle" and backend in ("pytorch", "_autodeploy"): + data = dict(data) + data.pop("decoding_type") + spec_cfg = Eagle3DecodingConfig(**data) + spec_cfg._decoding_type_alias = "Eagle" + return spec_cfg + config_class = config_classes.get(decoding_type) if config_class is None: raise ValueError(f"Invalid decoding type: {decoding_type}") @@ -990,6 +1001,10 @@ class EagleDecodingConfig(DecodingBaseConfig): return False +class Eagle3DecodingConfig(EagleDecodingConfig): + decoding_type: ClassVar[str] = "Eagle3" + + class SaveHiddenStatesDecodingConfig(DecodingBaseConfig): output_directory: str write_interval: int = 20 @@ -2530,6 +2545,11 @@ class TrtLlmArgs(BaseLlmArgs): decoding_mode=DecodingMode.Medusa(), medusa_choices=self.speculative_config.medusa_choices) + elif isinstance(self.speculative_config, Eagle3DecodingConfig): + raise ValueError( + "speculative_config.decoding_type 'Eagle3' is only supported on the PyTorch backend. " + "Use decoding_type 'Eagle' for the TensorRT backend.") + elif isinstance(self.speculative_config, EagleDecodingConfig): assert self.speculative_config.max_draft_len > 0 assert self.speculative_config.speculative_model is not None, "EAGLE3 draft model must be specified." @@ -3046,6 +3066,14 @@ class TorchLlmArgs(BaseLlmArgs): f"support backend {self.backend}") if isinstance(self.speculative_config, EagleDecodingConfig): + if (getattr(self.speculative_config, "_decoding_type_alias", + None) == "Eagle" or type(self.speculative_config) + is EagleDecodingConfig): + logger.warning( + "speculative_config.decoding_type 'Eagle' is not supported on the PyTorch backend; only 'Eagle3' is supported. " + "'Eagle' is treated as 'Eagle3' for backward compatibility. " + "EAGLE (v1/v2) draft checkpoints are incompatible with Eagle3—use an Eagle3 draft model." + ) assert self.speculative_config.max_draft_len > 0 assert self.speculative_config.speculative_model is not None, "EAGLE3 draft model must be specified." elif isinstance(self.speculative_config, NGramDecodingConfig): @@ -3337,8 +3365,14 @@ def update_llm_args_with_extra_dict( if field_name in llm_args_dict: # Some fields need to be converted manually. 
if field_name in ["speculative_config", "sparse_attention_config"]: - llm_args_dict[field_name] = field_type.from_dict( - llm_args_dict[field_name]) + if field_name == "speculative_config": + backend = llm_args_dict.get("backend") or llm_args.get( + "backend") + llm_args_dict[field_name] = field_type.from_dict( + llm_args_dict[field_name], backend=backend) + else: + llm_args_dict[field_name] = field_type.from_dict( + llm_args_dict[field_name]) else: llm_args_dict[field_name] = field_type( **llm_args_dict[field_name]) diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index 9b003991a2..83e26ce93e 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -30,8 +30,8 @@ from ..module import Module from .build_cache import (BuildCache, BuildCacheConfig, CachedStage, get_build_cache_config_from_env) from .llm_args import (CalibConfig, CudaGraphConfig, DraftTargetDecodingConfig, - EagleDecodingConfig, KvCacheConfig, LlmArgs, - LookaheadDecodingConfig, MedusaDecodingConfig, + Eagle3DecodingConfig, EagleDecodingConfig, KvCacheConfig, + LlmArgs, LookaheadDecodingConfig, MedusaDecodingConfig, MTPDecodingConfig, NGramDecodingConfig, UserProvidedDecodingConfig, _ModelFormatKind, _ModelWrapper, _ParallelConfig, @@ -973,6 +973,7 @@ __all__ = [ 'KvCacheConfig', 'CachedModelLoader', 'EagleDecodingConfig', + 'Eagle3DecodingConfig', 'update_llm_args_with_extra_dict', 'update_llm_args_with_extra_options', ] diff --git a/tests/integration/defs/accuracy/references/cnn_dailymail.yaml b/tests/integration/defs/accuracy/references/cnn_dailymail.yaml index 09c2a7f898..4f5f067742 100644 --- a/tests/integration/defs/accuracy/references/cnn_dailymail.yaml +++ b/tests/integration/defs/accuracy/references/cnn_dailymail.yaml @@ -68,6 +68,8 @@ lmsys/vicuna-7b-v1.3: accuracy: 33.419 - spec_dec_algo: Eagle accuracy: 27.832 + - spec_dec_algo: Eagle3 + accuracy: 27.832 llama-7b-hf: - accuracy: 30.457 - quant_algo: W4A16_GPTQ @@ -137,6 +139,8 @@ meta-llama/Llama-3.1-8B-Instruct: - accuracy: 33.640 - spec_dec_algo: Eagle accuracy: 33.640 + - spec_dec_algo: Eagle3 + accuracy: 33.640 - extra_acc_spec: logprobs=2 accuracy: 30.522 - quant_algo: FP8 @@ -189,6 +193,9 @@ meta-llama/Llama-3.3-70B-Instruct: - quant_algo: FP8 spec_dec_algo: Eagle accuracy: 33.244 + - quant_algo: FP8 + spec_dec_algo: Eagle3 + accuracy: 33.244 - quant_algo: NVFP4 kv_cache_quant_algo: FP8 accuracy: 34.383 diff --git a/tests/integration/defs/accuracy/references/gpqa_diamond.yaml b/tests/integration/defs/accuracy/references/gpqa_diamond.yaml index 404dcbb6b7..a0820dd2c3 100644 --- a/tests/integration/defs/accuracy/references/gpqa_diamond.yaml +++ b/tests/integration/defs/accuracy/references/gpqa_diamond.yaml @@ -63,11 +63,16 @@ GPT-OSS/120B-MXFP4: - accuracy: 65.0 - spec_dec_algo: Eagle accuracy: 65.0 + - spec_dec_algo: Eagle3 + accuracy: 65.0 - quant_algo: W4A8_MXFP4_MXFP8 accuracy: 65.0 - quant_algo: W4A8_MXFP4_MXFP8 spec_dec_algo: Eagle accuracy: 65.0 + - quant_algo: W4A8_MXFP4_MXFP8 + spec_dec_algo: Eagle3 + accuracy: 65.0 - quant_algo: W4A8_MXFP4_MXFP8 kv_cache_quant_algo: FP8 accuracy: 65.0 @@ -76,6 +81,9 @@ GPT-OSS/120B-MXFP4: - quant_algo: W4A16_MXFP4 spec_dec_algo: Eagle accuracy: 65.0 + - quant_algo: W4A16_MXFP4 + spec_dec_algo: Eagle3 + accuracy: 65.0 - quant_algo: W4A16_MXFP4 kv_cache_quant_algo: FP8 accuracy: 65.0 diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index 94a64dea1a..70c3f18757 100644 --- 
a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -4,6 +4,8 @@ meta-llama/Llama-3.1-8B-Instruct: accuracy: 74.20 - spec_dec_algo: Eagle accuracy: 74.20 + - spec_dec_algo: Eagle3 + accuracy: 74.20 - quant_algo: FP8 accuracy: 74.30 - quant_algo: FP8 @@ -31,6 +33,10 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct: kv_cache_quant_algo: FP8 spec_dec_algo: Eagle accuracy: 92.20 + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + spec_dec_algo: Eagle3 + accuracy: 92.20 meta-llama/Llama-4-Scout-17B-16E-Instruct: - accuracy: 89.70 - quant_algo: NVFP4 @@ -109,10 +115,14 @@ deepseek-ai/DeepSeek-V3.2-Exp: Qwen3/Qwen3-4B: - spec_dec_algo: Eagle accuracy: 85.823 + - spec_dec_algo: Eagle3 + accuracy: 85.823 Qwen3/Qwen3-8B: - accuracy: 87.1114 - spec_dec_algo: Eagle accuracy: 87.1114 + - spec_dec_algo: Eagle3 + accuracy: 87.1114 - quant_algo: FP8 kv_cache_quant_algo: FP8 accuracy: 87.1114 @@ -128,6 +138,8 @@ Qwen3/Qwen3-30B-A3B: accuracy: 83.43 - spec_dec_algo: Eagle accuracy: 83.43 + - spec_dec_algo: Eagle3 + accuracy: 83.43 Qwen3/Qwen3-235B-A22B: - quant_algo: FP8 kv_cache_quant_algo: FP8 @@ -139,6 +151,10 @@ Qwen3/Qwen3-235B-A22B: quant_algo: NVFP4 kv_cache_quant_algo: FP8 accuracy: 85.78 + - spec_dec_algo: Eagle3 + quant_algo: NVFP4 + kv_cache_quant_algo: FP8 + accuracy: 85.78 Qwen3/Qwen3-Next-80B-A3B-Thinking: - accuracy: 81.577 Qwen3/Qwen3-Next-80B-A3B-Instruct: @@ -245,11 +261,16 @@ GPT-OSS/120B-MXFP4: - accuracy: 90.3 - spec_dec_algo: Eagle accuracy: 90.3 + - spec_dec_algo: Eagle3 + accuracy: 90.3 - quant_algo: W4A8_MXFP4_MXFP8 accuracy: 90.3 - quant_algo: W4A8_MXFP4_MXFP8 spec_dec_algo: Eagle accuracy: 90.3 + - quant_algo: W4A8_MXFP4_MXFP8 + spec_dec_algo: Eagle3 + accuracy: 90.3 - quant_algo: W4A8_MXFP4_MXFP8 kv_cache_quant_algo: FP8 accuracy: 90.3 @@ -263,6 +284,9 @@ GPT-OSS/120B-MXFP4: - quant_algo: W4A16_MXFP4 spec_dec_algo: Eagle accuracy: 90.3 + - quant_algo: W4A16_MXFP4 + spec_dec_algo: Eagle3 + accuracy: 90.3 - quant_algo: W4A16_MXFP4 kv_cache_quant_algo: FP8 accuracy: 90.3 @@ -314,6 +338,8 @@ mistral/Mistral-Large-3-675B: - accuracy: 86.1 - spec_dec_algo: Eagle accuracy: 86.1 + - spec_dec_algo: Eagle3 + accuracy: 86.1 nvidia/Nemotron-Super-V3: - accuracy: 83.74 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/references/json_mode_eval.yaml b/tests/integration/defs/accuracy/references/json_mode_eval.yaml index 0d36ea6d26..6b43fa4e39 100644 --- a/tests/integration/defs/accuracy/references/json_mode_eval.yaml +++ b/tests/integration/defs/accuracy/references/json_mode_eval.yaml @@ -2,6 +2,8 @@ meta-llama/Llama-3.1-8B-Instruct: - accuracy: 74.00 - spec_dec_algo: Eagle accuracy: 74.00 + - spec_dec_algo: Eagle3 + accuracy: 74.00 - spec_dec_algo: NGram accuracy: 74.00 deepseek-ai/DeepSeek-V3-Lite: @@ -16,6 +18,12 @@ GPT-OSS/120B-MXFP4: - quant_algo: W4A16_MXFP4 spec_dec_algo: Eagle accuracy: 62.00 + - quant_algo: W4A16_MXFP4 + spec_dec_algo: Eagle3 + accuracy: 62.00 - quant_algo: W4A8_MXFP4_MXFP8 spec_dec_algo: Eagle accuracy: 62.00 + - quant_algo: W4A8_MXFP4_MXFP8 + spec_dec_algo: Eagle3 + accuracy: 62.00 diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml index 3646793942..479ecca029 100644 --- a/tests/integration/defs/accuracy/references/mmlu.yaml +++ b/tests/integration/defs/accuracy/references/mmlu.yaml @@ -22,6 +22,8 @@ meta-llama/Llama-3.1-8B-Instruct: - accuracy: 68.17 - spec_dec_algo: Eagle accuracy: 68.20 + - spec_dec_algo: Eagle3 + 
accuracy: 68.20 - spec_dec_algo: NGram accuracy: 68.17 - quant_algo: FP8 @@ -57,14 +59,21 @@ meta-llama/Llama-3.2-1B: accuracy: 32.82 meta-llama/Llama-3.2-3B: - accuracy: 57.92 + - spec_dec_algo: Eagle3 + accuracy: 57.92 - quant_algo: FP8 kv_cache_quant_algo: FP8 accuracy: 60.60 meta-llama/Llama-3.3-70B-Instruct: - accuracy: 81.31 + - spec_dec_algo: Eagle3 + accuracy: 81.31 - quant_algo: FP8 spec_dec_algo: Eagle accuracy: 81.31 + - quant_algo: FP8 + spec_dec_algo: Eagle3 + accuracy: 81.31 - quant_algo: NVFP4 kv_cache_quant_algo: FP8 accuracy: 78.78 @@ -82,6 +91,10 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct: kv_cache_quant_algo: FP8 spec_dec_algo: Eagle accuracy: 86.40 + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + spec_dec_algo: Eagle3 + accuracy: 86.40 - quant_algo: FP8 kv_cache_quant_algo: FP8 accuracy: 86.40 @@ -213,6 +226,8 @@ Qwen3/Qwen3-8B: - accuracy: 76.0 # WAR for https://nvbugs/5575902 - spec_dec_algo: Eagle accuracy: 76.12 + - spec_dec_algo: Eagle3 + accuracy: 76.12 Qwen3/Qwen3-30B-A3B: - accuracy: 79.53 - quant_algo: FP8_BLOCK_SCALES @@ -240,6 +255,10 @@ Qwen3/Qwen3-235B-A22B: quant_algo: NVFP4 kv_cache_quant_algo: FP8 accuracy: 85.5 + - spec_dec_algo: Eagle3 + quant_algo: NVFP4 + kv_cache_quant_algo: FP8 + accuracy: 85.5 Qwen3/Qwen3-Next-80B-A3B-Thinking: - accuracy: 86 Qwen3/Qwen3-Next-80B-A3B-Instruct: @@ -335,6 +354,8 @@ microsoft/phi-4: accuracy: 79.36 LGAI-EXAONE/EXAONE-4.0-32B: - accuracy: 78.52 + - spec_dec_algo: Eagle3 + accuracy: 78.52 GPT-OSS/BF16: - accuracy: 77.50 GPT-OSS/MXFP4: @@ -351,6 +372,8 @@ mistral/Mistral-Large-3-675B: - accuracy: 85.30 - spec_dec_algo: Eagle accuracy: 85.30 + - spec_dec_algo: Eagle3 + accuracy: 85.30 nvidia/Nemotron-Super-V3: - accuracy: 80.00 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index ab4315f3cc..76b1f7bb02 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -53,7 +53,7 @@ from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \ IS_TRITON_KERNELS_AVAILABLE from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig, DeepSeekSparseAttentionConfig, - EagleDecodingConfig, KvCacheConfig, MoeConfig, + Eagle3DecodingConfig, KvCacheConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, RocketSparseAttentionConfig, SamplingParams, SkipSoftmaxAttentionConfig, TorchCompileConfig) @@ -290,9 +290,9 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct" draft_len = 4 - spec_config = EagleDecodingConfig(max_draft_len=draft_len, - speculative_model=eagle_model_dir, - eagle3_one_model=eagle3_one_model) + spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, + speculative_model=eagle_model_dir, + eagle3_one_model=eagle3_one_model) with LLM(model=target_model_dir, **pytorch_config, @@ -382,7 +382,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) cuda_graph_config = CudaGraphConfig(enable_padding=True) - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=3, speculative_model=f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B", eagle3_one_model=eagle3_one_model) @@ -634,9 +634,9 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness): model_path = 
f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp8" eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.3-Instruct-70B" kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6) - spec_config = EagleDecodingConfig(max_draft_len=3, - speculative_model=eagle_model_dir, - eagle3_one_model=eagle3_one_model) + spec_config = Eagle3DecodingConfig(max_draft_len=3, + speculative_model=eagle_model_dir, + eagle3_one_model=eagle3_one_model) torch_compile_config = _get_default_torch_compile_config(torch_compile) pytorch_config = dict( disable_overlap_scheduler=not eagle3_one_model, @@ -3567,9 +3567,9 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B" draft_len = 4 - spec_config = EagleDecodingConfig(max_draft_len=draft_len, - speculative_model=eagle_model_dir, - eagle3_one_model=eagle3_one_model) + spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, + speculative_model=eagle_model_dir, + eagle3_one_model=eagle3_one_model) llm = LLM(model=target_model_dir, **pytorch_config, @@ -3937,7 +3937,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness): enable_block_reuse=not eagle3) spec_config = None if eagle3: - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=2, speculative_model= f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/", @@ -3985,7 +3985,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness): enable_block_reuse=not eagle3) spec_config = None if eagle3: - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=2, speculative_model= f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/", @@ -4637,10 +4637,10 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3" draft_len = 3 - spec_config = EagleDecodingConfig(max_draft_len=draft_len, - speculative_model=eagle_model_dir, - eagle3_one_model=one_model, - allow_advanced_sampling=True) + spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, + speculative_model=eagle_model_dir, + eagle3_one_model=one_model, + allow_advanced_sampling=True) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, @@ -4703,10 +4703,10 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3" draft_len = 3 - spec_config = EagleDecodingConfig(max_draft_len=draft_len, - speculative_model=eagle_model_dir, - eagle3_one_model=one_model, - allow_advanced_sampling=True) + spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, + speculative_model=eagle_model_dir, + eagle3_one_model=one_model, + allow_advanced_sampling=True) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, @@ -4767,10 +4767,10 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3" draft_len = 3 - spec_config = EagleDecodingConfig(max_draft_len=draft_len, - speculative_model=eagle_model_dir, - eagle3_one_model=one_model, - allow_advanced_sampling=True) + spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, + speculative_model=eagle_model_dir, + eagle3_one_model=one_model, + allow_advanced_sampling=True) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, @@ -4826,9 +4826,9 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3" draft_len = 3 - spec_config = EagleDecodingConfig(max_draft_len=draft_len, - speculative_model=eagle_model_dir, - 
eagle3_one_model=one_model) + spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, + speculative_model=eagle_model_dir, + eagle3_one_model=one_model) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, @@ -5307,7 +5307,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness): enable_block_reuse=not eagle3) spec_config = None if eagle3: - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=2, speculative_model= f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/", @@ -5358,7 +5358,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness): enable_block_reuse=not eagle3) spec_config = None if eagle3: - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=2, speculative_model= f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/", diff --git a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py index 1d10ead5f8..f72c072aaa 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py +++ b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py @@ -13,7 +13,7 @@ from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams from tensorrt_llm._utils import set_mpi_comm from tensorrt_llm.llmapi import (CacheTransceiverConfig, CudaGraphConfig, KvCacheConfig, MpiCommSession) -from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig +from tensorrt_llm.llmapi.llm_args import Eagle3DecodingConfig cloudpickle.register_pickle_by_value(sys.modules[__name__]) MPI.pickle.__init__( @@ -407,7 +407,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, eagle3_one_model): # Test whether the batch slots are properly released when using speculative decoding # with disaggregated serving. 
- spec_dec_config = EagleDecodingConfig( + spec_dec_config = Eagle3DecodingConfig( speculative_model=model_path(spec_dec_model_path), eagle3_one_model=eagle3_one_model, max_draft_len=3) diff --git a/tests/integration/defs/examples/test_ad_speculative_decoding.py b/tests/integration/defs/examples/test_ad_speculative_decoding.py index e1492a0153..e77709aa7c 100644 --- a/tests/integration/defs/examples/test_ad_speculative_decoding.py +++ b/tests/integration/defs/examples/test_ad_speculative_decoding.py @@ -21,7 +21,7 @@ from defs.conftest import llm_models_root from tensorrt_llm import SamplingParams from tensorrt_llm._torch.auto_deploy.llm import LLM -from tensorrt_llm.llmapi import DraftTargetDecodingConfig, EagleDecodingConfig, KvCacheConfig +from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig prompts = [ "What is the capital of France?", @@ -57,7 +57,7 @@ def make_draft_target_config(spec_model_path: str): def make_eagle3_config(spec_model_path: str): - return EagleDecodingConfig( + return Eagle3DecodingConfig( max_draft_len=EAGLE_MAX_DRAFT_LEN, speculative_model=spec_model_path, eagle3_one_model=False, @@ -214,7 +214,7 @@ def test_autodeploy_eagle3_acceptance_rate(): max_draft_len = EAGLE_MAX_DRAFT_LEN # Configure Eagle3 speculative decoding - speculative_config = EagleDecodingConfig( + speculative_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model, eagle3_one_model=False, diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index d37df0bed2..778b1452a9 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -3482,7 +3482,7 @@ def test_eagle3_output_consistency_4gpus(model_dir: str, draft_model_dir: str): RCCA: https://nvbugspro.nvidia.com/bug/5575211 """ from tensorrt_llm import LLM, SamplingParams - from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, + from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig, KvCacheConfig) models_path = llm_models_root() @@ -3519,7 +3519,7 @@ def test_eagle3_output_consistency_4gpus(model_dir: str, draft_model_dir: str): sampling_params = SamplingParams(max_tokens=1024, temperature=0) # Run with Eagle3 - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=3, speculative_model=eagle_model_dir, eagle3_one_model=True, diff --git a/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py b/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py index 352d6b743f..7293cf9158 100644 --- a/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py +++ b/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py @@ -10,7 +10,7 @@ from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata from tensorrt_llm._torch.speculative.drafting_loops import TreeDraftingLoopWrapper from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager -from tensorrt_llm.llmapi import EagleDecodingConfig +from tensorrt_llm.llmapi import Eagle3DecodingConfig sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -84,7 +84,7 @@ def test_draft_token_static_tree_prepare_for_generation(): ) # 2) Create spec metadata - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, 
max_total_draft_tokens=max_total_draft_tokens, speculative_model=eagle_model_dir, diff --git a/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py b/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py index 689da99cf2..f623167587 100644 --- a/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py +++ b/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py @@ -8,7 +8,7 @@ from utils.llm_data import llm_models_root from tensorrt_llm._torch.speculative.drafting_loops import \ TreeDraftingLoopWrapper from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager -from tensorrt_llm.llmapi import EagleDecodingConfig +from tensorrt_llm.llmapi import Eagle3DecodingConfig sys.path.append(os.path.join(os.path.dirname(__file__), '..')) @@ -35,7 +35,7 @@ def test_draft_token_static_tree_sampling(): def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, logits, use_cuda_graph, ref_new_tokens): - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, max_total_draft_tokens=max_total_draft_tokens, speculative_model=eagle_model_dir, diff --git a/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py b/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py index 29a19a04cc..a5d18f8fed 100644 --- a/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py +++ b/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py @@ -11,7 +11,7 @@ from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest, from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager from tensorrt_llm.bindings.executor import FinishReason -from tensorrt_llm.llmapi import EagleDecodingConfig +from tensorrt_llm.llmapi import Eagle3DecodingConfig sys.path.append(os.path.join(os.path.dirname(__file__), '..')) @@ -20,7 +20,7 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree, max_new_tokens, max_batch_size, input_request, input_new_tokens, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, ref_num_accepted_draft_tokens, ref_mtokens): - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, max_total_draft_tokens=max_total_draft_tokens, speculative_model=eagle_model_dir, diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py index eaa215c81e..1140f53c62 100644 --- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py +++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py @@ -9,7 +9,7 @@ from utils.llm_data import llm_models_root from tensorrt_llm import LLM, SamplingParams from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState -from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, +from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig, KvCacheConfig) sys.path.append(os.path.join(os.path.dirname(__file__), '..')) @@ -54,7 +54,7 @@ def test_dynamic_spec_decode(enforce_single_worker, max_seq_len=8192, ) - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model_dir, # Llama 3 does not support one model eagle. 
diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index 578ff41cb3..71befcfef8 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -14,7 +14,7 @@ from utils.llm_data import llm_models_root from tensorrt_llm import LLM, SamplingParams from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata from tensorrt_llm._torch.metadata import KVCacheParams -from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, +from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig, KvCacheConfig) sys.path.append(os.path.join(os.path.dirname(__file__), '..')) @@ -199,7 +199,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str, # Use a small max_num_tokens so that the chunked prefill path gets exercised. llm_common_config['max_num_tokens'] = 64 - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model, # Llama 3 does not support one model eagle. @@ -275,7 +275,7 @@ def test_eagle3_spec_decoding_stats(eagle3_one_model): kv_cache_config = KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.6) - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=3, speculative_model=eagle_model_dir, eagle3_one_model=eagle3_one_model, @@ -355,7 +355,7 @@ def test_llama_eagle3_long_prompt(use_cuda_graph): eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=3, speculative_model=eagle_model_dir, eagle3_one_model=False, @@ -449,7 +449,7 @@ def test_deepseek_eagle3(): 'transformers_version': '4.52.4', 'use_cache': True, 'vocab_size': 129280, - 'draft_vocab_size': 129280 + 'draft_vocab_size': 129280, } with tempfile.TemporaryDirectory() as temp_dir: eagle_model_dir = Path(temp_dir) @@ -479,7 +479,7 @@ def test_deepseek_eagle3(): enable_chunked_prefill=enable_chunked_prefill, ) - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model_dir, # Llama 3 does not support one model eagle. @@ -590,10 +590,10 @@ def test_deepseek_mla_eagle3(): load_format="dummy", ) - spec_config = EagleDecodingConfig(max_draft_len=max_draft_len, - speculative_model=eagle_model_dir, - eagle3_one_model=use_one_model, - load_format="dummy") + spec_config = Eagle3DecodingConfig(max_draft_len=max_draft_len, + speculative_model=eagle_model_dir, + eagle3_one_model=use_one_model, + load_format="dummy") llm_spec = LLM(**llm_common_config, speculative_config=spec_config) @@ -659,7 +659,7 @@ def test_multi_eagle3(use_one_model: bool): 'transformers_version': '4.52.4', 'use_cache': True, 'vocab_size': 128256, - 'draft_vocab_size': 128256 + 'draft_vocab_size': 128256, } with tempfile.TemporaryDirectory() as temp_dir: eagle_model_dir = Path(temp_dir) @@ -688,7 +688,7 @@ def test_multi_eagle3(use_one_model: bool): load_format="dummy", ) - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model_dir, # Llama 3 does not support one model eagle. 
@@ -747,7 +747,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool): enable_chunked_prefill=enable_chunked_prefill, ) - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model_dir, eagle3_one_model=use_one_model, @@ -800,7 +800,7 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool): enable_chunked_prefill=enable_chunked_prefill, ) - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model_dir, eagle3_one_model=use_one_model, diff --git a/tests/unittest/_torch/speculative/test_kv_cache_reuse.py b/tests/unittest/_torch/speculative/test_kv_cache_reuse.py index eb5a720db1..73078701b2 100644 --- a/tests/unittest/_torch/speculative/test_kv_cache_reuse.py +++ b/tests/unittest/_torch/speculative/test_kv_cache_reuse.py @@ -7,7 +7,7 @@ import torch from utils.llm_data import llm_models_root from tensorrt_llm import LLM, SamplingParams -from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, +from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig, KvCacheConfig) sys.path.append(os.path.join(os.path.dirname(__file__), '..')) @@ -50,7 +50,7 @@ def test_kv_cache_reuse(use_cuda_graph: bool, attn_backend: str): max_seq_len=8192, ) - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model_dir, eagle3_one_model=False, diff --git a/tests/unittest/_torch/speculative/test_spec_gate.py b/tests/unittest/_torch/speculative/test_spec_gate.py index a99654a9c6..2a3dd9b99e 100644 --- a/tests/unittest/_torch/speculative/test_spec_gate.py +++ b/tests/unittest/_torch/speculative/test_spec_gate.py @@ -9,7 +9,7 @@ from utils.util import similar, skip_blackwell from tensorrt_llm import LLM, SamplingParams from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate -from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, +from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig, KvCacheConfig) sys.path.append(os.path.join(os.path.dirname(__file__), '..')) @@ -45,7 +45,7 @@ def test_spec_gate_e2e(): max_seq_len=4096, ) - spec_config = EagleDecodingConfig( + spec_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model_dir, # Llama 3 does not support one model eagle. 
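Before the new `test_llm_args.py` cases below, here is a minimal sketch of the backend-aware dispatch added to `DecodingBaseConfig.from_dict` in this patch; the draft-model path is a placeholder, and the behavior shown is exactly what the tests that follow assert:

```python
from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig, Eagle3DecodingConfig

# "Eagle3" always maps to the new Eagle3DecodingConfig.
cfg = DecodingBaseConfig.from_dict({
    "decoding_type": "Eagle3",
    "max_draft_len": 3,
    "speculative_model_dir": "/path/to/draft/model",  # placeholder path
})
assert isinstance(cfg, Eagle3DecodingConfig)

# On the PyTorch backend, the legacy "Eagle" type is accepted as a
# backward-compatible alias and parsed into Eagle3DecodingConfig; a warning
# is emitted later when the config is attached to TorchLlmArgs.
alias_cfg = DecodingBaseConfig.from_dict(
    {
        "decoding_type": "Eagle",
        "max_draft_len": 3,
        "speculative_model_dir": "/path/to/draft/model",  # placeholder path
    },
    backend="pytorch",
)
assert isinstance(alias_cfg, Eagle3DecodingConfig)
```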
diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 33bf857be4..a2b086feaa 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -152,6 +152,48 @@ model_kwargs: assert llm_args.model_kwargs['num_hidden_layers'] == 2 +def test_decoding_type_eagle3_parses_to_eagle3_decoding_config(): + spec_cfg = DecodingBaseConfig.from_dict( + dict(decoding_type="Eagle3", + max_draft_len=3, + speculative_model_dir="/path/to/draft/model")) + assert isinstance(spec_cfg, Eagle3DecodingConfig) + + +def test_decoding_type_eagle_warns_on_pytorch_backend(monkeypatch): + import tensorrt_llm.llmapi.llm_args as llm_args_mod + + warnings_seen: list[str] = [] + + def _capture_warning(msg, *args, **kwargs): + warnings_seen.append(str(msg)) + + monkeypatch.setattr(llm_args_mod.logger, "warning", _capture_warning) + + spec_cfg = DecodingBaseConfig.from_dict(dict( + decoding_type="Eagle", + max_draft_len=3, + speculative_model_dir="/path/to/draft/model"), + backend="pytorch") + assert isinstance(spec_cfg, Eagle3DecodingConfig) + + TorchLlmArgs(model=llama_model_path, speculative_config=spec_cfg) + + assert any( + "EAGLE (v1/v2) draft checkpoints are incompatible with Eagle3" in m + for m in warnings_seen) + + +def test_decoding_type_eagle3_errors_on_tensorrt_backend(): + spec_cfg = DecodingBaseConfig.from_dict( + dict(decoding_type="Eagle3", + max_draft_len=3, + speculative_model_dir="/path/to/draft/model")) + with pytest.raises(ValueError, + match="only supported on the PyTorch backend"): + TrtLlmArgs(model=llama_model_path, speculative_config=spec_cfg) + + def check_defaults(py_config_cls, pybind_config_cls): py_config = py_config_cls() pybind_config = pybind_config_cls()
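For reference, the renamed config plugs into the LLM API exactly as the updated documentation in this patch shows. A minimal end-to-end sketch, using the target and draft checkpoints from the docs example (`meta-llama/Llama-3.1-8B-Instruct` with `yuhuili/EAGLE3-LLaMA3.1-Instruct-8B`); the prompt and token budget are arbitrary:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import Eagle3DecodingConfig

# Two-model EAGLE3: a separate draft checkpoint proposes candidate tokens
# that the target model verifies.
speculative_config = Eagle3DecodingConfig(
    max_draft_len=3,
    speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
    eagle3_one_model=False,
)

llm = LLM("meta-llama/Llama-3.1-8B-Instruct",
          speculative_config=speculative_config)

for output in llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32)):
    print(output.outputs[0].text)
```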