Commit 7c60356daf: Merge dda588ec23 into 6df2c8a074
@@ -84,7 +84,7 @@ kv_cache_config:
   enable_block_reuse: false
   free_gpu_memory_fraction: 0.8
 speculative_config:
-  decoding_type: Eagle
+  decoding_type: Eagle3
   max_draft_len: 3
   speculative_model_dir: /config/models/eagle/
 cuda_graph_config:
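For reference, the same serving settings can be expressed in-process through the LLM API. The sketch below is a hedged illustration only: the class names mirror the `llmapi` imports used elsewhere in this change, and the Maverick/EAGLE paths are the placeholder container mount points from the Docker command in the next hunk.

```python
# Hedged sketch: mirrors the YAML keys above with llmapi config objects.
# The model paths are the placeholder mount points used in this guide.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
                                 KvCacheConfig)

llm = LLM(
    "/config/models/maverick",
    kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                  free_gpu_memory_fraction=0.8),
    speculative_config=Eagle3DecodingConfig(
        max_draft_len=3,
        speculative_model_dir="/config/models/eagle/"),
    cuda_graph_config=CudaGraphConfig(),
)
```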
@@ -68,7 +68,7 @@ docker run -d --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
     -p 8000:8000 --gpus=all -e "TRTLLM_ENABLE_PDL=1" \
     -v /path/to/maverick:/config/models/maverick -v /path/to/eagle:/config/models/eagle \
     docker.io/<username>/tensorrt_llm:main sh \
-    -c "echo -e 'enable_autotuner: false\nenable_attention_dp: false\nenable_min_latency: true\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle\n max_draft_len: 3\n speculative_model_dir: /config/models/eagle\n eagle3_one_model: true\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
+    -c "echo -e 'enable_autotuner: false\nenable_attention_dp: false\nenable_min_latency: true\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle3\n max_draft_len: 3\n speculative_model_dir: /config/models/eagle\n eagle3_one_model: true\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
     TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL=True \
     trtllm-serve /config/models/maverick \
     --host 0.0.0.0 --port 8000 \
@@ -53,12 +53,12 @@ The following draft model checkpoints can be used for EAGLE 3:
 * Llama 4 Maverick: [use the checkpoint from the NVIDIA HuggingFace repository](https://huggingface.co/nvidia/Llama-4-Maverick-17B-128E-Eagle3).
 
 ```python
-from tensorrt_llm.llmapi import EagleDecodingConfig
+from tensorrt_llm.llmapi import Eagle3DecodingConfig
 
 # Enable to use the faster one-model implementation for Llama 4.
 eagle3_one_model = False
 
-speculative_config = EagleDecodingConfig(
+speculative_config = Eagle3DecodingConfig(
     max_draft_len=3, speculative_model_dir="/path/to/draft_model", eagle3_one_model=eagle3_one_model)
 
 # Only need to disable overlap scheduler if eagle3_one_model is False.
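As a follow-up to the Quick Start snippet in the hunk above, here is a hedged end-to-end sketch with the renamed config class. The model paths and prompt are placeholders, and the overlap-scheduler handling follows the comment in the diff.

```python
# Minimal sketch only; assumes an Eagle3 draft checkpoint at the placeholder path.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import Eagle3DecodingConfig

eagle3_one_model = False  # two-model implementation

speculative_config = Eagle3DecodingConfig(
    max_draft_len=3,
    speculative_model_dir="/path/to/draft_model",
    eagle3_one_model=eagle3_one_model)

# Per the note in the diff, the overlap scheduler only needs to be disabled
# for the two-model (eagle3_one_model=False) path.
llm = LLM("/path/to/target_model",
          speculative_config=speculative_config,
          disable_overlap_scheduler=not eagle3_one_model)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```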
@@ -131,16 +131,18 @@ llm = LLM("/path/to/target_model", speculative_config=speculative_config)
 Speculative decoding options must be specified via `--config config.yaml` for both `trtllm-bench` and `trtllm-serve`. All speculative decoding options can be specified in this YAML file. An additional `decoding_type` option is used to specify the type of speculation to use. The available options are:
 
 * `MTP`
-* `Eagle` (for EAGLE 3)
+* `Eagle3`
 * `NGram`
 * `DraftTarget`
 
+> Note: The PyTorch backend supports only `Eagle3`. `decoding_type: Eagle` is accepted as a backward-compatible alias for `Eagle3`, but EAGLE (v1/v2) draft checkpoints are incompatible.
+
 The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:
 
 ```
 disable_overlap_scheduler: true
 speculative_config:
-  decoding_type: Eagle
+  decoding_type: Eagle3
   max_draft_len: 4
   speculative_model: /path/to/draft/model
 ```
 
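Since that YAML feeds `trtllm-bench`/`trtllm-serve`, the following is a hedged sketch of generating the file from Python and launching the server with it. The `--config` flag name is taken from the paragraph above, PyYAML is assumed to be installed, and all paths are placeholders.

```python
# Sketch under assumptions: PyYAML available, trtllm-serve on PATH, and the
# --config flag as described in the docs above; paths are placeholders.
import subprocess

import yaml

config = {
    "disable_overlap_scheduler": True,
    "speculative_config": {
        "decoding_type": "Eagle3",
        "max_draft_len": 4,
        "speculative_model": "/path/to/draft/model",
    },
}
with open("config.yaml", "w") as f:
    yaml.safe_dump(config, f)

# The same file can be passed to trtllm-bench.
subprocess.run(
    ["trtllm-serve", "/path/to/target_model", "--config", "config.yaml"],
    check=True)
```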
@@ -96,7 +96,7 @@ speculative_config:
   mtp_eagle_one_model: False # Not supported
 
 speculative_config:
-  decoding_type: "Eagle"
+  decoding_type: "Eagle3"
   eagle3_one_model: False # Not supported
 ```
 
@@ -171,6 +171,8 @@ The EAGLE approach enhances the single-model Medusa method by predicting and ver
 
 Similarly to ReDrafter, TensorRT-LLM implements the EAGLE model such that logits prediction, draft tokens acceptance and draft token generation are performed inside of the TensorRT engine(EAGLE-1 and EAGLE-2 are both supported). Please, visit the [EAGLE README](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/eagle/README.md) for information about building and running the model.
 
+> **EAGLE3 note.** If the EAGLE3 draft head config omits `draft_vocab_size`, TensorRT-LLM assumes it matches `vocab_size` and emits a warning. Set `draft_vocab_size` explicitly if the draft head uses a different vocabulary.
+
 ### Disaggregated Serving
 
 [Disaggregated Serving](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/features/disaggregated-service.md) with EAGLE3 using the two model approach is supported in the Pytorch backend. Please refer to the following [Dynamo example](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/llama4_plus_eagle.md) on how to run EAGLE3 with Disaggregated Serving for Llama 4 Maverick.
 
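Returning to the EAGLE3 note in the hunk above: a hedged sketch of pinning `draft_vocab_size` in a draft-head `config.json` so the fallback warning never fires. The path is a placeholder and the key names follow the test configs later in this diff.

```python
# Sketch only: copy vocab_size into draft_vocab_size when the draft head's
# config.json does not set it (placeholder path; adjust to your checkpoint).
import json

cfg_path = "/path/to/draft_model/config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg.setdefault("draft_vocab_size", cfg["vocab_size"])

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)
```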
@@ -6,7 +6,7 @@ from typing import Optional
 import click
 
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
+from tensorrt_llm.llmapi import (Eagle3DecodingConfig, KvCacheConfig,
                                  MTPDecodingConfig, NGramDecodingConfig)
 
 prompts = [
@@ -33,7 +33,7 @@ def run_MTP(model: Optional[str] = None):
 
 
 def run_Eagle3():
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=3,
         speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
         eagle3_one_model=True)
@@ -5,7 +5,7 @@ import time
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
                                  CudaGraphConfig, DraftTargetDecodingConfig,
-                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
+                                 Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
                                  MTPDecodingConfig, NGramDecodingConfig,
                                  TorchCompileConfig)
 
@@ -222,7 +222,7 @@ def setup_llm(args, **kwargs):
             mtp_eagle_one_model=args.use_one_model,
             speculative_model_dir=args.model_dir)
     elif spec_decode_algo == "EAGLE3":
-        spec_config = EagleDecodingConfig(
+        spec_config = Eagle3DecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
             speculative_model_dir=args.draft_model_dir,
             eagle3_one_model=args.use_one_model,
@@ -837,8 +837,8 @@ settings for your specific use case.
 Qwen3 now supports Eagle3 (Speculative Decoding with Eagle3). To enable Eagle3 on Qwen3, you need to set the following arguments when running `trtllm-bench` or `trtllm-serve`:
 
-- `speculative_config.decoding_type: Eagle`
-  Set the decoding type to "Eagle" to enable Eagle3 speculative decoding.
+- `speculative_config.decoding_type: Eagle3`
+  Set the decoding type to `Eagle3` to enable Eagle3 speculative decoding.
 - `speculative_config.max_draft_len: 3`
   Set the maximum number of draft tokens generated per step (this value can be adjusted as needed).
 - `speculative_config.speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>`
@@ -855,7 +855,7 @@ Example `config.yml` snippet for Eagle3:
 echo "
 enable_attention_dp: false
 speculative_config:
-  decoding_type: Eagle
+  decoding_type: Eagle3
   max_draft_len: 3
   speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>
 kv_cache_config:
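For readers who prefer the LLM API over the served YAML, a hedged Python equivalent of the Qwen3 Eagle3 settings above. The target model path is a placeholder and `<EAGLE3_DRAFT_MODEL_PATH>` is kept from the docs.

```python
# Hedged sketch: the in-process counterpart of the Qwen3 config.yml above.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import Eagle3DecodingConfig

spec_config = Eagle3DecodingConfig(
    max_draft_len=3,
    speculative_model_dir="<EAGLE3_DRAFT_MODEL_PATH>")  # placeholder from the docs

llm = LLM("/path/to/Qwen3-8B",  # placeholder target model
          enable_attention_dp=False,
          speculative_config=spec_config)
```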
@@ -24,12 +24,18 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
             vision_encoder_cls, vlm_base_model = vision_encoder_info
             return vision_encoder_cls(config, vlm_base_model)
 
-        # Hack to detect eagle3 checkpoints. TODO: should we provide
-        # our own checkpoints with the correct arch? It would let us
-        # avoid nasty stuff like this.
-        model_arch = model_arch.replace("Eagle3",
-                                        "")  # Strip the appended EAGLE3
+        # Hack to detect eagle3 checkpoints.
+        # Why it exists:
+        # - Eagle3 checkpoints have draft_vocab_size in config.json (even if None)
+        # - Some community checkpoints append "Eagle3" to architecture names ("LlamaForCausalLMEagle3")
+        # - Some checkpoints don't include "Eagle3" in arch name at all ("LlamaForCausalLM")
+        # - TensorRT-LLM's MODEL_CLASS_MAPPING expects prefixed names like EAGLE3LlamaForCausalLM
+        # - Hence: LlamaForCausalLMEagle3 -> EAGLE3LlamaForCausalLM
+        #          LlamaForCausalLM (with draft_vocab_size) -> EAGLE3LlamaForCausalLM
+        # TODO: should we provide our own checkpoints with the correct arch? It would let us avoid nasty stuff like this.
         if hasattr(config.pretrained_config, "draft_vocab_size"):
+            # It's an Eagle3 checkpoint - strip "Eagle3" suffix if present, then add prefix
+            model_arch = model_arch.replace("Eagle3", "")
             model_arch = "EAGLE3" + model_arch
         if model_arch in (
                 "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"
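To make the mapping in the comment block above concrete, here is a small self-contained sketch of the normalization rule. It is illustrative only, not the module's actual code path.

```python
# Illustrative sketch of the arch-name normalization described above.
def normalize_eagle3_arch(model_arch: str, has_draft_vocab_size: bool) -> str:
    if not has_draft_vocab_size:
        return model_arch  # not detected as an Eagle3 draft checkpoint
    # "LlamaForCausalLMEagle3" -> "LlamaForCausalLM" -> "EAGLE3LlamaForCausalLM"
    return "EAGLE3" + model_arch.replace("Eagle3", "")


assert normalize_eagle3_arch("LlamaForCausalLMEagle3", True) == "EAGLE3LlamaForCausalLM"
assert normalize_eagle3_arch("LlamaForCausalLM", True) == "EAGLE3LlamaForCausalLM"
assert normalize_eagle3_arch("LlamaForCausalLM", False) == "LlamaForCausalLM"
```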
@@ -4,6 +4,8 @@ import torch
 from torch import nn
 from transformers import LlamaConfig, PretrainedConfig
 
+from tensorrt_llm.logger import logger
+
 from ...functional import PositionEmbeddingType
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -24,6 +26,18 @@ from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
                              register_auto_model)
 
 
+def _ensure_draft_vocab_size(config: PretrainedConfig) -> None:
+    if hasattr(config,
+               "draft_vocab_size") and config.draft_vocab_size is not None:
+        return
+
+    logger.warning(
+        "Missing 'draft_vocab_size' in pretrained config; defaulting to 'vocab_size'. "
+        "Set 'draft_vocab_size' explicitly if the draft head uses a different vocabulary."
+    )
+    config.draft_vocab_size = config.vocab_size
+
+
 class Eagle3Attention(Attention):
 
     def __init__(
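A quick, self-contained illustration of the fallback the helper above implements. It re-creates the logic with `warnings` instead of the TensorRT-LLM logger, and the config values are hypothetical.

```python
# Sketch only: mirrors the semantics of _ensure_draft_vocab_size above.
import warnings

from transformers import PretrainedConfig


def ensure_draft_vocab_size(config: PretrainedConfig) -> None:
    if getattr(config, "draft_vocab_size", None) is not None:
        return
    warnings.warn("Missing 'draft_vocab_size'; defaulting to 'vocab_size'.")
    config.draft_vocab_size = config.vocab_size


cfg = PretrainedConfig(vocab_size=128256)  # hypothetical draft-head config
ensure_draft_vocab_size(cfg)
assert cfg.draft_vocab_size == 128256
```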
@@ -417,9 +431,8 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel,
         model_config: ModelConfig[PretrainedConfig],
         start_layer_idx: int = 0,
     ):
-        draft_vocab_size = model_config.pretrained_config.vocab_size
-        if model_config.pretrained_config.draft_vocab_size is not None:
-            draft_vocab_size = model_config.pretrained_config.draft_vocab_size
+        config = model_config.pretrained_config
+        _ensure_draft_vocab_size(config)
 
         # Determine if we should use MLA attention based on config
         # MLA is used for DeepSeekV3-style models that have kv_lora_rank
@@ -435,8 +448,8 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel,
         super().__init__(
             draft_model,
             config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=draft_vocab_size,
+            hidden_size=config.hidden_size,
+            vocab_size=config.draft_vocab_size,
         )
         self.load_lm_head_from_target = True
 
@@ -598,6 +611,7 @@ class MistralLarge3DraftModel(DecoderModel):
 
 
 # We use MistralLarge3 as the base architecture for EAGLE3 draft layers
+# NOTE: Class name says "Eagle" not "Eagle3" to match checkpoint naming (e.g., "Mistral-Large-3-675B-Instruct-2512-Eagle")
 @register_auto_model("MistralLarge3EagleForCausalLM")
 class MistralLarge3EagleForCausalLM(DecoderModelForCausalLM):
 
@@ -10,10 +10,11 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
                        CapacitySchedulerPolicy, ContextChunkingPolicy,
                        CudaGraphConfig, DeepSeekSparseAttentionConfig,
                        DraftTargetDecodingConfig, DynamicBatchConfig,
-                       EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
-                       KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
-                       MedusaDecodingConfig, MoeConfig, MTPDecodingConfig,
-                       NGramDecodingConfig, RocketSparseAttentionConfig,
+                       Eagle3DecodingConfig, EagleDecodingConfig,
+                       ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
+                       LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
+                       MTPDecodingConfig, NGramDecodingConfig,
+                       RocketSparseAttentionConfig,
                        SaveHiddenStatesDecodingConfig, SchedulerConfig,
                        SkipSoftmaxAttentionConfig, TorchCompileConfig,
                        TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig)
@@ -38,6 +39,7 @@ __all__ = [
    'LookaheadDecodingConfig',
    'MedusaDecodingConfig',
    'EagleDecodingConfig',
+    'Eagle3DecodingConfig',
    'MTPDecodingConfig',
    'SchedulerConfig',
    'CapacitySchedulerPolicy',
@@ -708,6 +708,8 @@ class DecodingBaseConfig(StrictBaseModel):
     _allow_chain_drafter: bool = PrivateAttr(True)
     # If set, drafting uses greedy sampling, irrespective of sampling parameters.
     _allow_greedy_draft_tokens: bool = PrivateAttr(True)
+    # Internal: record decoding_type alias used during parsing (for warnings).
+    _decoding_type_alias: Optional[str] = PrivateAttr(default=None)
 
     @field_validator('draft_len_schedule')
     @classmethod
@@ -755,13 +757,14 @@ class DecodingBaseConfig(StrictBaseModel):
         return v
 
     @classmethod
-    def from_dict(cls, data: dict):
+    def from_dict(cls, data: dict, backend: Optional[str] = None):
         # dispatch to the correct decoding config
         decoding_type = data.get("decoding_type")
         config_classes = {
             "MTP": MTPDecodingConfig,
             "Medusa": MedusaDecodingConfig,
             "Eagle": EagleDecodingConfig,
+            "Eagle3": Eagle3DecodingConfig,
             "Lookahead": LookaheadDecodingConfig,
             "NGram": NGramDecodingConfig,
             "DraftTarget": DraftTargetDecodingConfig,
@@ -770,6 +773,14 @@ class DecodingBaseConfig(StrictBaseModel):
             "AUTO": AutoDecodingConfig,
         }
 
+        backend = backend.lower() if isinstance(backend, str) else backend
+        if decoding_type == "Eagle" and backend in ("pytorch", "_autodeploy"):
+            data = dict(data)
+            data.pop("decoding_type")
+            spec_cfg = Eagle3DecodingConfig(**data)
+            spec_cfg._decoding_type_alias = "Eagle"
+            return spec_cfg
+
         config_class = config_classes.get(decoding_type)
         if config_class is None:
             raise ValueError(f"Invalid decoding type: {decoding_type}")
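A hedged usage sketch of the dispatch above: on the PyTorch backend the legacy `Eagle` string resolves to `Eagle3DecodingConfig`, while `Eagle3` maps to it directly. The import path and behavior are taken from this diff; the model path is a placeholder.

```python
# Sketch only; mirrors the new from_dict() dispatch shown above.
from tensorrt_llm.llmapi.llm_args import (DecodingBaseConfig,
                                          Eagle3DecodingConfig)

cfg = DecodingBaseConfig.from_dict(
    {
        "decoding_type": "Eagle",  # backward-compatible alias
        "max_draft_len": 3,
        "speculative_model_dir": "/path/to/draft/model",
    },
    backend="pytorch")
assert isinstance(cfg, Eagle3DecodingConfig)
```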
@@ -966,6 +977,10 @@ class EagleDecodingConfig(DecodingBaseConfig):
         return False
 
 
+class Eagle3DecodingConfig(EagleDecodingConfig):
+    decoding_type: ClassVar[str] = "Eagle3"
+
+
 class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
     output_directory: str
     write_interval: int = 20
@@ -2506,9 +2521,14 @@ class TrtLlmArgs(BaseLlmArgs):
                 decoding_mode=DecodingMode.Medusa(),
                 medusa_choices=self.speculative_config.medusa_choices)
 
+        elif isinstance(self.speculative_config, Eagle3DecodingConfig):
+            raise ValueError(
+                "speculative_config.decoding_type 'Eagle3' is only supported on the PyTorch backend. "
+                "Use decoding_type 'Eagle' for the TensorRT backend.")
+
         elif isinstance(self.speculative_config, EagleDecodingConfig):
             assert self.speculative_config.max_draft_len > 0
-            assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
+            assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE weights must be specified."
             self.build_config.max_draft_len = self.speculative_config.max_draft_len
             self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
             eagle_config = _EagleConfig(
@@ -3024,6 +3044,14 @@ class TorchLlmArgs(BaseLlmArgs):
                                  f"support backend {self.backend}")
 
         if isinstance(self.speculative_config, EagleDecodingConfig):
+            if (getattr(self.speculative_config, "_decoding_type_alias",
+                        None) == "Eagle" or type(self.speculative_config)
+                    is EagleDecodingConfig):
+                logger.warning(
+                    "speculative_config.decoding_type 'Eagle' is not supported on the PyTorch backend; only 'Eagle3' is supported. "
+                    "'Eagle' is treated as 'Eagle3' for backward compatibility. "
+                    "EAGLE (v1/v2) draft checkpoints are incompatible with Eagle3—use an Eagle3 draft model."
+                )
             assert self.speculative_config.max_draft_len > 0
             assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
         elif isinstance(self.speculative_config, NGramDecodingConfig):
@@ -3323,8 +3351,14 @@ def update_llm_args_with_extra_dict(
         if field_name in llm_args_dict:
             # Some fields need to be converted manually.
             if field_name in ["speculative_config", "sparse_attention_config"]:
-                llm_args_dict[field_name] = field_type.from_dict(
-                    llm_args_dict[field_name])
+                if field_name == "speculative_config":
+                    backend = llm_args_dict.get("backend") or llm_args.get(
+                        "backend")
+                    llm_args_dict[field_name] = field_type.from_dict(
+                        llm_args_dict[field_name], backend=backend)
+                else:
+                    llm_args_dict[field_name] = field_type.from_dict(
+                        llm_args_dict[field_name])
             else:
                 llm_args_dict[field_name] = field_type(
                     **llm_args_dict[field_name])
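A hedged sketch of how that backend hand-off is exercised from the YAML/extra-dict path: the options dict carries `speculative_config`, and the backend recorded in the base args decides how the `Eagle` alias is resolved. The function location and argument order are inferred from the hunk above; paths are placeholders.

```python
# Sketch only; dict-based usage inferred from the update_llm_args_with_extra_dict hunk above.
from tensorrt_llm.llmapi.llm_args import update_llm_args_with_extra_dict

llm_args = {"model": "/path/to/target_model", "backend": "pytorch"}
extra = {
    "speculative_config": {
        "decoding_type": "Eagle",  # resolved to Eagle3 on the PyTorch backend
        "max_draft_len": 3,
        "speculative_model_dir": "/path/to/draft/model",
    },
}
llm_args = update_llm_args_with_extra_dict(llm_args, extra)
```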
@@ -30,8 +30,8 @@ from ..module import Module
 from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
                           get_build_cache_config_from_env)
 from .llm_args import (CalibConfig, CudaGraphConfig, DraftTargetDecodingConfig,
-                       EagleDecodingConfig, KvCacheConfig, LlmArgs,
-                       LookaheadDecodingConfig, MedusaDecodingConfig,
+                       Eagle3DecodingConfig, EagleDecodingConfig, KvCacheConfig,
+                       LlmArgs, LookaheadDecodingConfig, MedusaDecodingConfig,
                        MTPDecodingConfig, NGramDecodingConfig,
                        UserProvidedDecodingConfig, _ModelFormatKind,
                        _ModelWrapper, _ParallelConfig,
@@ -923,6 +923,7 @@ __all__ = [
     'KvCacheConfig',
     'CachedModelLoader',
     'EagleDecodingConfig',
+    'Eagle3DecodingConfig',
     'update_llm_args_with_extra_dict',
     'update_llm_args_with_extra_options',
 ]
@@ -52,7 +52,7 @@ from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
                                  DeepSeekSparseAttentionConfig,
-                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
+                                 Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
                                  MTPDecodingConfig, NGramDecodingConfig,
                                  RocketSparseAttentionConfig, SamplingParams,
                                  SkipSoftmaxAttentionConfig, TorchCompileConfig)
@@ -275,9 +275,10 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
 
         draft_len = 4
-        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=eagle3_one_model)
+        spec_config = Eagle3DecodingConfig(
+            max_draft_len=draft_len,
+            speculative_model_dir=eagle_model_dir,
+            eagle3_one_model=eagle3_one_model)
 
         with LLM(model=target_model_dir,
                  **pytorch_config,
@@ -367,7 +368,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
         cuda_graph_config = CudaGraphConfig(enable_padding=True)
-        spec_config = EagleDecodingConfig(
+        spec_config = Eagle3DecodingConfig(
             max_draft_len=3,
             speculative_model_dir=
             f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
@@ -620,9 +621,10 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
         model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp8"
         eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.3-Instruct-70B"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
-        spec_config = EagleDecodingConfig(max_draft_len=3,
-                                          speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=eagle3_one_model)
+        spec_config = Eagle3DecodingConfig(
+            max_draft_len=3,
+            speculative_model_dir=eagle_model_dir,
+            eagle3_one_model=eagle3_one_model)
         torch_compile_config = _get_default_torch_compile_config(torch_compile)
         pytorch_config = dict(
             disable_overlap_scheduler=not eagle3_one_model,
@@ -3440,9 +3442,10 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B"
 
         draft_len = 4
-        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=eagle3_one_model)
+        spec_config = Eagle3DecodingConfig(
+            max_draft_len=draft_len,
+            speculative_model_dir=eagle_model_dir,
+            eagle3_one_model=eagle3_one_model)
 
         llm = LLM(model=target_model_dir,
                   **pytorch_config,
@@ -3810,7 +3813,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
                                         enable_block_reuse=not eagle3)
         spec_config = None
         if eagle3:
-            spec_config = EagleDecodingConfig(
+            spec_config = Eagle3DecodingConfig(
                 max_draft_len=2,
                 speculative_model_dir=
                 f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/",
@@ -3858,7 +3861,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
                                         enable_block_reuse=not eagle3)
         spec_config = None
         if eagle3:
-            spec_config = EagleDecodingConfig(
+            spec_config = Eagle3DecodingConfig(
                 max_draft_len=2,
                 speculative_model_dir=
                 f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/",
@@ -4478,10 +4481,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
 
         eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
         draft_len = 3
-        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=one_model,
-                                          allow_advanced_sampling=True)
+        spec_config = Eagle3DecodingConfig(
+            max_draft_len=draft_len,
+            speculative_model_dir=eagle_model_dir,
+            eagle3_one_model=one_model,
+            allow_advanced_sampling=True)
 
         max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
         llm = LLM(self.MODEL_PATH,
@@ -4544,10 +4548,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
 
         eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
         draft_len = 3
-        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=one_model,
-                                          allow_advanced_sampling=True)
+        spec_config = Eagle3DecodingConfig(
+            max_draft_len=draft_len,
+            speculative_model_dir=eagle_model_dir,
+            eagle3_one_model=one_model,
+            allow_advanced_sampling=True)
 
         max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
         llm = LLM(self.MODEL_PATH,
@@ -4608,10 +4613,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
 
         eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
         draft_len = 3
-        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=one_model,
-                                          allow_advanced_sampling=True)
+        spec_config = Eagle3DecodingConfig(
+            max_draft_len=draft_len,
+            speculative_model_dir=eagle_model_dir,
+            eagle3_one_model=one_model,
+            allow_advanced_sampling=True)
 
         max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
         llm = LLM(self.MODEL_PATH,
@@ -4667,9 +4673,10 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
 
         eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
         draft_len = 3
-        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=one_model)
+        spec_config = Eagle3DecodingConfig(
+            max_draft_len=draft_len,
+            speculative_model_dir=eagle_model_dir,
+            eagle3_one_model=one_model)
 
         max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
         llm = LLM(self.MODEL_PATH,
@@ -5148,7 +5155,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
                                         enable_block_reuse=not eagle3)
         spec_config = None
         if eagle3:
-            spec_config = EagleDecodingConfig(
+            spec_config = Eagle3DecodingConfig(
                 max_draft_len=2,
                 speculative_model_dir=
                 f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",
@@ -5199,7 +5206,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
                                         enable_block_reuse=not eagle3)
         spec_config = None
         if eagle3:
-            spec_config = EagleDecodingConfig(
+            spec_config = Eagle3DecodingConfig(
                 max_draft_len=2,
                 speculative_model_dir=
                 f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",
@@ -13,7 +13,7 @@ from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams
 from tensorrt_llm._utils import set_mpi_comm
 from tensorrt_llm.llmapi import (CacheTransceiverConfig, CudaGraphConfig,
                                  KvCacheConfig, MpiCommSession)
-from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig
+from tensorrt_llm.llmapi.llm_args import Eagle3DecodingConfig
 
 cloudpickle.register_pickle_by_value(sys.modules[__name__])
 MPI.pickle.__init__(
@@ -399,7 +399,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
                                                  eagle3_one_model):
     # Test whether the batch slots are properly released when using speculative decoding
     # with disaggregated serving.
-    spec_dec_config = EagleDecodingConfig(
+    spec_dec_config = Eagle3DecodingConfig(
         speculative_model_dir=model_path(spec_dec_model_path),
         eagle3_one_model=eagle3_one_model,
         max_draft_len=3)
@@ -21,7 +21,7 @@ from defs.conftest import llm_models_root
 
 from tensorrt_llm import SamplingParams
 from tensorrt_llm._torch.auto_deploy.llm import LLM
-from tensorrt_llm.llmapi import DraftTargetDecodingConfig, EagleDecodingConfig, KvCacheConfig
+from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig
 
 prompts = [
     "What is the capital of France?",
@@ -57,7 +57,7 @@ def make_draft_target_config(spec_model_path: str):
 
 
 def make_eagle3_config(spec_model_path: str):
-    return EagleDecodingConfig(
+    return Eagle3DecodingConfig(
         max_draft_len=EAGLE_MAX_DRAFT_LEN,
         speculative_model_dir=spec_model_path,
         eagle3_one_model=False,
@@ -214,7 +214,7 @@ def test_autodeploy_eagle3_acceptance_rate():
     max_draft_len = EAGLE_MAX_DRAFT_LEN

    # Configure Eagle3 speculative decoding
-    speculative_config = EagleDecodingConfig(
+    speculative_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model_dir=eagle_model,
         eagle3_one_model=False,
@@ -3375,7 +3375,7 @@ def test_eagle3_output_consistency_4gpus(model_dir: str, draft_model_dir: str):
     RCCA: https://nvbugspro.nvidia.com/bug/5575211
     """
     from tensorrt_llm import LLM, SamplingParams
-    from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
+    from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
                                      KvCacheConfig)
 
     models_path = llm_models_root()
@@ -3412,7 +3412,7 @@ def test_eagle3_output_consistency_4gpus(model_dir: str, draft_model_dir: str):
     sampling_params = SamplingParams(max_tokens=1024, temperature=0)
 
     # Run with Eagle3
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=3,
         speculative_model_dir=eagle_model_dir,
         eagle3_one_model=True,
@@ -10,7 +10,7 @@ from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata
 from tensorrt_llm._torch.speculative.drafting_loops import TreeDraftingLoopWrapper
 from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
 from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager
-from tensorrt_llm.llmapi import EagleDecodingConfig
+from tensorrt_llm.llmapi import Eagle3DecodingConfig
 
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 
@@ -84,7 +84,7 @@ def test_draft_token_static_tree_prepare_for_generation():
     )
 
     # 2) Create spec metadata
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         max_total_draft_tokens=max_total_draft_tokens,
         speculative_model_dir=eagle_model_dir,
@@ -8,7 +8,7 @@ from utils.llm_data import llm_models_root
 from tensorrt_llm._torch.speculative.drafting_loops import \
     TreeDraftingLoopWrapper
 from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager
-from tensorrt_llm.llmapi import EagleDecodingConfig
+from tensorrt_llm.llmapi import Eagle3DecodingConfig
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 
@@ -35,7 +35,7 @@ def test_draft_token_static_tree_sampling():
 def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens,
              max_draft_len, eagle_choices, logits, use_cuda_graph,
             ref_new_tokens):
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         max_total_draft_tokens=max_total_draft_tokens,
         speculative_model_dir=eagle_model_dir,
@@ -11,7 +11,7 @@ from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
 from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
 from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager
 from tensorrt_llm.bindings.executor import FinishReason
-from tensorrt_llm.llmapi import EagleDecodingConfig
+from tensorrt_llm.llmapi import Eagle3DecodingConfig
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 
@@ -20,7 +20,7 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
              max_new_tokens, max_batch_size, input_request, input_new_tokens,
             draft_layer_id, max_total_draft_tokens, max_draft_len,
             eagle_choices, ref_num_accepted_draft_tokens, ref_mtokens):
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         max_total_draft_tokens=max_total_draft_tokens,
         speculative_model_dir=eagle_model_dir,
@@ -9,7 +9,7 @@ from utils.llm_data import llm_models_root
 
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
-from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
+from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
                                  KvCacheConfig)
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@@ -54,7 +54,7 @@ def test_dynamic_spec_decode(enforce_single_worker,
         max_seq_len=8192,
     )
 
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model_dir=eagle_model_dir,
         # Llama 3 does not support one model eagle.
@@ -13,7 +13,7 @@ from utils.llm_data import llm_models_root
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata
 from tensorrt_llm._torch.metadata import KVCacheParams
-from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
+from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
                                  KvCacheConfig)
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@@ -163,7 +163,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
         # Use a small max_num_tokens so that the chunked prefill path gets exercised.
         llm_common_config['max_num_tokens'] = 64
 
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model_dir=eagle_model_dir,
         # Llama 3 does not support one model eagle.
@@ -239,7 +239,7 @@ def test_eagle3_spec_decoding_stats(eagle3_one_model):
 
     kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                     free_gpu_memory_fraction=0.6)
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=3,
         speculative_model_dir=eagle_model_dir,
         eagle3_one_model=eagle3_one_model,
@@ -319,7 +319,7 @@ def test_llama_eagle3_long_prompt(use_cuda_graph):
     eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
     target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
 
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=3,
         speculative_model_dir=eagle_model_dir,
         eagle3_one_model=False,
@@ -413,7 +413,7 @@ def test_deepseek_eagle3():
         'transformers_version': '4.52.4',
         'use_cache': True,
         'vocab_size': 129280,
-        'draft_vocab_size': 129280
+        'draft_vocab_size': 129280,
     }
     with tempfile.TemporaryDirectory() as temp_dir:
         eagle_model_dir = Path(temp_dir)
@@ -443,7 +443,7 @@ def test_deepseek_eagle3():
         enable_chunked_prefill=enable_chunked_prefill,
     )
 
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model_dir=eagle_model_dir,
         # Llama 3 does not support one model eagle.
@@ -554,10 +554,11 @@ def test_deepseek_mla_eagle3():
         load_format="dummy",
     )
 
-    spec_config = EagleDecodingConfig(max_draft_len=max_draft_len,
-                                      speculative_model_dir=eagle_model_dir,
-                                      eagle3_one_model=use_one_model,
-                                      load_format="dummy")
+    spec_config = Eagle3DecodingConfig(
+        max_draft_len=max_draft_len,
+        speculative_model_dir=eagle_model_dir,
+        eagle3_one_model=use_one_model,
+        load_format="dummy")
 
     llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
 
@@ -623,7 +624,7 @@ def test_multi_eagle3(use_one_model: bool):
         'transformers_version': '4.52.4',
         'use_cache': True,
         'vocab_size': 128256,
-        'draft_vocab_size': 128256
+        'draft_vocab_size': 128256,
     }
     with tempfile.TemporaryDirectory() as temp_dir:
         eagle_model_dir = Path(temp_dir)
@@ -652,7 +653,7 @@ def test_multi_eagle3(use_one_model: bool):
         load_format="dummy",
     )
 
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model_dir=eagle_model_dir,
         # Llama 3 does not support one model eagle.
@@ -711,7 +712,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
         enable_chunked_prefill=enable_chunked_prefill,
     )
 
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model_dir=eagle_model_dir,
         eagle3_one_model=use_one_model,
@@ -764,7 +765,7 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool):
         enable_chunked_prefill=enable_chunked_prefill,
    )
 
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model_dir=eagle_model_dir,
         eagle3_one_model=use_one_model,
@@ -7,7 +7,7 @@ import torch
 from utils.llm_data import llm_models_root
 
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
+from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
                                  KvCacheConfig)
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@@ -50,7 +50,7 @@ def test_kv_cache_reuse(use_cuda_graph: bool, attn_backend: str):
         max_seq_len=8192,
     )
 
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model_dir=eagle_model_dir,
         eagle3_one_model=False,
@@ -9,7 +9,7 @@ from utils.util import similar, skip_blackwell
 
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate
-from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
+from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
                                  KvCacheConfig)
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@@ -45,7 +45,7 @@ def test_spec_gate_e2e():
         max_seq_len=4096,
     )
 
-    spec_config = EagleDecodingConfig(
+    spec_config = Eagle3DecodingConfig(
         max_draft_len=max_draft_len,
         speculative_model_dir=eagle_model_dir,
         # Llama 3 does not support one model eagle.
@@ -139,6 +139,48 @@ max_seq_len: 128
     assert llm_args.max_seq_len == 128
 
 
+def test_decoding_type_eagle3_parses_to_eagle3_decoding_config():
+    spec_cfg = DecodingBaseConfig.from_dict(
+        dict(decoding_type="Eagle3",
+             max_draft_len=3,
+             speculative_model_dir="/path/to/draft/model"))
+    assert isinstance(spec_cfg, Eagle3DecodingConfig)
+
+
+def test_decoding_type_eagle_warns_on_pytorch_backend(monkeypatch):
+    import tensorrt_llm.llmapi.llm_args as llm_args_mod
+
+    warnings_seen: list[str] = []
+
+    def _capture_warning(msg, *args, **kwargs):
+        warnings_seen.append(str(msg))
+
+    monkeypatch.setattr(llm_args_mod.logger, "warning", _capture_warning)
+
+    spec_cfg = DecodingBaseConfig.from_dict(dict(
+        decoding_type="Eagle",
+        max_draft_len=3,
+        speculative_model_dir="/path/to/draft/model"),
+                                            backend="pytorch")
+    assert isinstance(spec_cfg, Eagle3DecodingConfig)
+
+    TorchLlmArgs(model=llama_model_path, speculative_config=spec_cfg)
+
+    assert any(
+        "EAGLE (v1/v2) draft checkpoints are incompatible with Eagle3" in m
+        for m in warnings_seen)
+
+
+def test_decoding_type_eagle3_errors_on_tensorrt_backend():
+    spec_cfg = DecodingBaseConfig.from_dict(
+        dict(decoding_type="Eagle3",
+             max_draft_len=3,
+             speculative_model_dir="/path/to/draft/model"))
+    with pytest.raises(ValueError,
+                       match="only supported on the PyTorch backend"):
+        TrtLlmArgs(model=llama_model_path, speculative_config=spec_cfg)
+
+
 def check_defaults(py_config_cls, pybind_config_cls):
     py_config = py_config_cls()
     pybind_config = pybind_config_cls()