Anish Shanbhag 2026-01-13 21:25:08 +08:00 committed by GitHub
commit 55ff21ff81
42 changed files with 430 additions and 276 deletions

View File

@ -37,8 +37,13 @@ Draft/target is the simplest form of speculative decoding. In this approach, an
```python
from tensorrt_llm.llmapi import DraftTargetDecodingConfig
# Option 1: Use a HuggingFace Hub model ID (auto-downloaded)
speculative_config = DraftTargetDecodingConfig(
    max_draft_len=3, speculative_model_dir="/path/to/draft_model")
    max_draft_len=3, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")
# Option 2: Use a local path
# speculative_config = DraftTargetDecodingConfig(
# max_draft_len=3, speculative_model="/path/to/draft_model")
llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
```
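Once the `LLM` object above is constructed, generation works exactly as in the non-speculative case. A minimal usage sketch follows (the prompt and token budget are illustrative; `SamplingParams` is imported from `tensorrt_llm` as done elsewhere in this change):

```python
from tensorrt_llm import SamplingParams

# Speculative decoding is transparent to generate(); only the speculative_config
# passed to the LLM constructor above changes.
prompts = ["The capital of France is"]
outputs = llm.generate(prompts, SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```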
@ -51,18 +56,23 @@ TRT-LLM supports a modified version of the algorithm presented in the paper: tre
The following draft model checkpoints can be used for EAGLE 3:
* Llama 3 variants: [use the checkpoints from the authors of the original EAGLE 3 paper](https://huggingface.co/yuhuili).
* Llama 4 Maverick: [use the checkpoint from the NVIDIA HuggingFace repository](https://huggingface.co/nvidia/Llama-4-Maverick-17B-128E-Eagle3).
* Other models, including `gpt-oss-120b` and `Qwen3`: check out the [Speculative Decoding Modules](https://huggingface.co/collections/nvidia/speculative-decoding-modules) collection from NVIDIA.
```python
from tensorrt_llm.llmapi import EagleDecodingConfig
# Set eagle3_one_model = True to use the faster one-model implementation for Llama 4.
eagle3_one_model = False
model = "meta-llama/Llama-3.1-8B-Instruct"
speculative_model = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
speculative_config = EagleDecodingConfig(
    max_draft_len=3, speculative_model_dir="/path/to/draft_model", eagle3_one_model=eagle3_one_model)
    max_draft_len=3,
    speculative_model=speculative_model,
    eagle3_one_model=eagle3_one_model)
# The overlap scheduler only needs to be disabled if eagle3_one_model is False.
llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
llm = LLM(model, speculative_config=speculative_config, disable_overlap_scheduler=True)
```
### NGram
@ -137,7 +147,17 @@ Speculative decoding options must be specified via `--config config.yaml` for bo
The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:
```yaml
# Using a HuggingFace Hub model ID (auto-downloaded)
disable_overlap_scheduler: true
speculative_config:
  decoding_type: Eagle
  max_draft_len: 4
  speculative_model: yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
```
```yaml
# Or using a local path
disable_overlap_scheduler: true
speculative_config:
  decoding_type: Eagle
@ -145,6 +165,16 @@ speculative_config:
  speculative_model: /path/to/draft/model
```
```{note}
The field name `speculative_model_dir` is also accepted as an alias for `speculative_config.speculative_model`. For example:

speculative_config:
  decoding_type: Eagle
  max_draft_len: 4
  speculative_model_dir: /path/to/draft/model
```
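The same alias is honored when constructing the config in Python, via the `AliasChoices` validation alias introduced in this change. A minimal sketch (the path is a placeholder), mirroring the unit test added for this alias:

```python
from tensorrt_llm.llmapi import EagleDecodingConfig

# Both keyword spellings populate the same field; the value is then readable
# as `speculative_model`.
cfg = EagleDecodingConfig(max_draft_len=4,
                          speculative_model_dir="/path/to/draft/model",
                          eagle3_one_model=False)
assert cfg.speculative_model == "/path/to/draft/model"
```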
## Developer Guide
This section describes the components of a speculative decoding algorithm. All of the interfaces are defined in [`_torch/speculative/interface.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/speculative/interface.py).

View File

@ -23,12 +23,12 @@ def main():
model = "lmsys/vicuna-7b-v1.3"
# The end user can customize the eagle decoding configuration by specifying the
# speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
# speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
# greedy_sampling, posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
# with the EagleDecodingConfig class
speculative_config = EagleDecodingConfig(
speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
max_draft_len=63,
num_eagle_layers=4,
max_non_leaves_per_layer=10,

View File

@ -23,12 +23,12 @@ def main():
model = "lmsys/vicuna-7b-v1.3"
# The end user can customize the eagle decoding configuration by specifying the
# speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
# speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
# greedy_sampling, posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
# with the EagleDecodingConfig class
speculative_config = EagleDecodingConfig(
speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
max_draft_len=63,
num_eagle_layers=4,
max_non_leaves_per_layer=10,

View File

@ -48,10 +48,10 @@ def run_medusa_decoding(use_modelopt_ckpt=False, model_dir=None):
model = "lmsys/vicuna-7b-v1.3"
# The end user can customize the medusa decoding configuration by specifying the
# speculative_model_dir, max_draft_len, medusa heads num and medusa choices
# speculative_model, max_draft_len, medusa heads num and medusa choices
# with the MedusaDecodingConfig class
speculative_config = MedusaDecodingConfig(
speculative_model_dir="FasterDecoding/medusa-vicuna-7b-v1.3",
speculative_model="FasterDecoding/medusa-vicuna-7b-v1.3",
max_draft_len=63,
num_medusa_heads=4,
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \

View File

@ -35,7 +35,7 @@ def run_MTP(model: Optional[str] = None):
def run_Eagle3():
spec_config = EagleDecodingConfig(
max_draft_len=3,
speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
eagle3_one_model=True)
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

View File

@ -220,11 +220,11 @@ def setup_llm(args, **kwargs):
relaxed_topk=args.relaxed_topk,
relaxed_delta=args.relaxed_delta,
mtp_eagle_one_model=args.use_one_model,
speculative_model_dir=args.model_dir)
speculative_model=args.model_dir)
elif spec_decode_algo == "EAGLE3":
spec_config = EagleDecodingConfig(
max_draft_len=args.spec_decode_max_draft_len,
speculative_model_dir=args.draft_model_dir,
speculative_model=args.draft_model_dir,
eagle3_one_model=args.use_one_model,
eagle_choices=args.eagle_choices,
use_dynamic_tree=args.use_dynamic_tree,
@ -234,7 +234,7 @@ def setup_llm(args, **kwargs):
elif spec_decode_algo == "DRAFT_TARGET":
spec_config = DraftTargetDecodingConfig(
max_draft_len=args.spec_decode_max_draft_len,
speculative_model_dir=args.draft_model_dir)
speculative_model=args.draft_model_dir)
elif spec_decode_algo == "NGRAM":
spec_config = NGramDecodingConfig(
max_draft_len=args.spec_decode_max_draft_len,

View File

@ -841,8 +841,8 @@ Qwen3 now supports Eagle3 (Speculative Decoding with Eagle3). To enable Eagle3 o
Set the decoding type to "Eagle" to enable Eagle3 speculative decoding.
- `speculative_config.max_draft_len: 3`
Set the maximum number of draft tokens generated per step (this value can be adjusted as needed).
- `speculative_config.speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>`
Specify the path to the Eagle3 draft model (ensure the corresponding draft model weights are prepared).
- `speculative_config.speculative_model: <HUGGINGFACE ID / LOCAL PATH>`
Specify the Eagle3 draft model either as a HuggingFace model ID or a local path. You can find ready-to-use Eagle3 draft models at https://huggingface.co/collections/nvidia/speculative-decoding-modules.
Currently, there are some limitations when enabling Eagle3:
@ -857,7 +857,7 @@ enable_attention_dp: false
speculative_config:
decoding_type: Eagle
max_draft_len: 3
speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>
speculative_model: <HUGGINGFACE ID / LOCAL PATH>
kv_cache_config:
enable_block_reuse: false
" >> ${path_config}

View File

@ -921,7 +921,7 @@ def create_draft_model_engine_maybe(
drafting_loop_wrapper = None
draft_model_engine = PyTorchModelEngine(
model_path=draft_spec_config.speculative_model_dir,
model_path=draft_spec_config.speculative_model,
llm_args=draft_llm_args,
mapping=dist_mapping,
attn_runtime_features=attn_runtime_features,

View File

@ -887,7 +887,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
MistralConfigLoader
self.draft_config = MistralConfigLoader().load(
spec_config.speculative_model_dir,
spec_config.speculative_model,
mapping=model_config.mapping,
moe_backend=model_config.moe_backend,
moe_max_num_tokens=model_config.moe_max_num_tokens,
@ -898,7 +898,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
self.draft_config.extra_attrs = model_config.extra_attrs
elif spec_config.eagle3_model_arch == "llama3":
self.draft_config = ModelConfig.from_pretrained(
model_config.spec_config.speculative_model_dir,
model_config.spec_config.speculative_model,
trust_remote_code=True,
attn_backend=model_config.attn_backend,
moe_backend=model_config.moe_backend,

View File

@ -278,7 +278,7 @@ class ModelLoader:
if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
):
weights = checkpoint_loader.load_weights(
self.spec_config.speculative_model_dir,
self.spec_config.speculative_model,
mapping=self.mapping)
draft_model_arch = model.draft_config.pretrained_config.architectures[

View File

@ -398,7 +398,7 @@ def create_py_executor(
draft_llm_args.load_format = LoadFormat.DUMMY
draft_model_engine = PyTorchModelEngine(
model_path=spec_config.speculative_model_dir,
model_path=spec_config.speculative_model,
llm_args=draft_llm_args,
mapping=mapping,
attn_runtime_features=attn_runtime_features,

View File

@ -13,7 +13,7 @@ from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple,
import torch
import yaml
from pydantic import BaseModel
from pydantic import AliasChoices, BaseModel
from pydantic import Field as PydanticField
from pydantic import PrivateAttr, field_validator, model_validator
from strenum import StrEnum
@ -651,7 +651,14 @@ class DecodingBaseConfig(StrictBaseModel):
# If it's a static or dynamic tree, each draft layer may generate more than one draft token.
# In this case, max_total_draft_tokens >= max_draft_len.
max_total_draft_tokens: Optional[int] = None
speculative_model_dir: Optional[Union[str, Path]] = None
# The speculative (draft) model. Accepts either:
# - A HuggingFace Hub model ID (str), e.g., "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
# which will be automatically downloaded.
# - A local filesystem path to a downloaded model directory.
speculative_model: Optional[Union[str, Path]] = Field(
default=None,
validation_alias=AliasChoices("speculative_model",
"speculative_model_dir"))
# PyTorch only.
# When specified, speculation will be disabled at batch sizes above
@ -858,28 +865,29 @@ class EagleDecodingConfig(DecodingBaseConfig):
# choices: llama3, mistral_large3
eagle3_model_arch: str = "llama3"
def __init__(self, **kwargs):
super().__init__()
for attr_name, attr_value in kwargs.items():
if attr_name == 'max_draft_len':
self.num_eagle_layers = attr_value
self.max_total_draft_tokens = attr_value # If using linear-tree, the max_total_draft_tokens is the same as max_draft_len
# Convert the data type of Eagle choice from str to List[List[int]]
if attr_name == 'eagle_choices' and attr_value is not None:
logger.warning(
"NOTE: The Draft token tree is still under development, PLEASE DO NOT USE IT !!!"
)
if not isinstance(attr_value, list):
if isinstance(attr_value, str):
attr_value = ast.literal_eval(
attr_value.replace(" ", ""))
else:
raise ValueError(
"Wrong eagle choices type. Eagle choices should be a List[List[int]] or a string like [[0], [1], [2], [0, 0], [0, 1]]."
)
setattr(self, attr_name, attr_value)
@field_validator('eagle_choices', mode='before')
@classmethod
def validate_eagle_choices(cls, v):
if v is not None:
logger.warning(
"NOTE: The Draft token tree is still under development, PLEASE DO NOT USE IT !!!"
)
if not isinstance(v, list):
if isinstance(v, str):
v = ast.literal_eval(v.replace(" ", ""))
else:
raise ValueError(
"Wrong eagle choices type. Eagle choices should be a List[List[int]] or a string like [[0], [1], [2], [0, 0], [0, 1]]."
)
return v
@model_validator(mode='after')
def validate_eagle_config(self) -> 'EagleDecodingConfig':
if self.max_draft_len is None:
raise ValueError("max_draft_len is required for Eagle")
self.num_eagle_layers = self.max_draft_len
self.max_total_draft_tokens = self.max_draft_len # If using linear-tree, the max_total_draft_tokens is the same as max_draft_len
assert self.max_draft_len is not None, "max_draft_len is required for Eagle"
if self.eagle3_model_arch == "mistral_large3" and self.eagle3_layers_to_capture is None:
# FIXME find a better way to setup it.
self.eagle3_layers_to_capture = {-1}
@ -889,7 +897,10 @@ class EagleDecodingConfig(DecodingBaseConfig):
# and reset the max_draft_len and num_eagle_layers if necessary
if self.eagle_choices is not None:
# If eagle_choices is provided, use_dynamic_tree should not be used
assert not self.use_dynamic_tree, "If eagle_choices is provided, use_dynamic_tree need to be False"
if self.use_dynamic_tree:
raise ValueError(
"If eagle_choices is provided, use_dynamic_tree need to be False"
)
# Get num_eagle_layers from eagle_choices
num_eagle_layers_from_choices = self.check_eagle_choices()
@ -906,10 +917,23 @@ class EagleDecodingConfig(DecodingBaseConfig):
# Dynamic tree logic
if self.use_dynamic_tree:
assert self.eagle_choices is None, "If use_dynamic_tree is True, eagle_choices should be None"
assert self.max_draft_len is not None and self.max_draft_len > 0, "max_draft_len should be provided, which indicates the number of drafter layers"
assert self.dynamic_tree_max_topK is not None and self.dynamic_tree_max_topK > 0, "dynamic_tree_max_topK should be provided, which indicates the number of nodes to expand each time"
assert self.max_total_draft_tokens is not None and self.max_total_draft_tokens > 0, "max_total_draft_tokens should be provided, which indicates the total nodes of the final draft tree. (exclude the root node)"
if self.eagle_choices is not None:
raise ValueError(
"If use_dynamic_tree is True, eagle_choices should be None")
if self.max_draft_len is None or self.max_draft_len <= 0:
raise ValueError(
"max_draft_len should be provided, which indicates the number of drafter layers"
)
if self.dynamic_tree_max_topK is None or self.dynamic_tree_max_topK <= 0:
raise ValueError(
"dynamic_tree_max_topK should be provided, which indicates the number of nodes to expand each time"
)
if self.max_total_draft_tokens is None or self.max_total_draft_tokens <= 0:
raise ValueError(
"max_total_draft_tokens should be provided, which indicates the total nodes of the final draft tree. (exclude the root node)"
)
return self
@classmethod
def from_dict(cls, data: dict):
@ -918,7 +942,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
decoding_type: ClassVar[str] = "Eagle"
def validate(self) -> None:
if self.speculative_model_dir is None:
if self.speculative_model is None:
raise ValueError("Draft model must be provided for EAGLE")
def check_eagle_choices(self):
@ -2119,9 +2143,6 @@ class BaseLlmArgs(StrictBaseModel):
_parallel_config: Optional[_ParallelConfig] = PrivateAttr(default=None)
_model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
_speculative_model: Optional[str] = PrivateAttr(default=None)
_speculative_model_format: Optional[_ModelFormatKind] = PrivateAttr(
default=None)
@property
def parallel_config(self) -> _ParallelConfig:
@ -2132,12 +2153,8 @@ class BaseLlmArgs(StrictBaseModel):
return self._model_format
@property
def speculative_model_dir(self) -> Optional[_ModelFormatKind]:
return self._speculative_model
@property
def speculative_model_format(self) -> _ModelFormatKind:
return self._speculative_model_format
def speculative_model(self) -> Optional[Union[str, Path]]:
return self.speculative_config.speculative_model if self.speculative_config is not None else None
@classmethod
def from_kwargs(cls, **kwargs: Any) -> "BaseLlmArgs":
@ -2508,7 +2525,7 @@ class TrtLlmArgs(BaseLlmArgs):
elif isinstance(self.speculative_config, EagleDecodingConfig):
assert self.speculative_config.max_draft_len > 0
assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
assert self.speculative_config.speculative_model is not None, "EAGLE3 draft model must be specified."
self.build_config.max_draft_len = self.speculative_config.max_draft_len
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
eagle_config = _EagleConfig(
@ -2528,14 +2545,6 @@ class TrtLlmArgs(BaseLlmArgs):
else:
self.decoding_config = None
self._speculative_model = getattr(self.speculative_config,
"speculative_model_dir", None)
speculative_model_obj = _ModelWrapper(
self._speculative_model
) if self._speculative_model is not None else None
if self._speculative_model and speculative_model_obj.is_local_model:
self._speculative_model_format = _ModelFormatKind.HF
return self
def _load_config_from_engine(self, engine_dir: Path):
@ -3025,12 +3034,12 @@ class TorchLlmArgs(BaseLlmArgs):
if isinstance(self.speculative_config, EagleDecodingConfig):
assert self.speculative_config.max_draft_len > 0
assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
assert self.speculative_config.speculative_model is not None, "EAGLE3 draft model must be specified."
elif isinstance(self.speculative_config, NGramDecodingConfig):
assert self.speculative_config.max_draft_len > 0 and self.speculative_config.max_matching_ngram_size > 0
elif isinstance(self.speculative_config, DraftTargetDecodingConfig):
assert self.speculative_config.max_draft_len > 0
assert self.speculative_config.speculative_model_dir is not None, "Path to draft model must be specified."
assert self.speculative_config.speculative_model is not None, "Draft model must be specified."
elif isinstance(self.speculative_config, MTPDecodingConfig):
assert self.speculative_config.num_nextn_predict_layers > 0
self.speculative_config.max_draft_len = self.speculative_config.num_nextn_predict_layers
@ -3057,14 +3066,6 @@ class TorchLlmArgs(BaseLlmArgs):
else:
self.decoding_config = None
self._speculative_model = getattr(self.speculative_config,
"speculative_model_dir", None)
speculative_model_obj = _ModelWrapper(
self._speculative_model
) if self._speculative_model is not None else None
if self._speculative_model and speculative_model_obj.is_local_model:
self._speculative_model_format = _ModelFormatKind.HF
return self
@model_validator(mode="after")

View File

@ -109,8 +109,8 @@ class ModelLoader:
self.model_obj = _ModelWrapper(self.llm_args.model)
self.speculative_model_obj = _ModelWrapper(
self.llm_args.speculative_model_dir
) if self.llm_args.speculative_model_dir is not None else None
self.llm_args.speculative_model
) if self.llm_args.speculative_model is not None else None
if isinstance(self.llm_args, TrtLlmArgs):
self.convert_checkpoint_options = self.llm_args._convert_checkpoint_options
@ -125,7 +125,7 @@ class ModelLoader:
Path] = self.model_obj.model_dir if self.model_obj.is_local_model else None
self._speculative_model_dir: Optional[
Path] = self.speculative_model_obj.model_dir if self.speculative_model_obj is not None and self.model_obj.is_local_model else None
Path] = self.speculative_model_obj.model_dir if self.speculative_model_obj is not None and self.speculative_model_obj.is_local_model else None
self._model_info: Optional[_ModelInfo] = None
self._model_format = self.llm_args.model_format
@ -145,9 +145,7 @@ class ModelLoader:
return
if (self.model_obj.is_hub_model
and self._model_format is not _ModelFormatKind.TLLM_ENGINE) or (
self.speculative_model_obj
and self.speculative_model_obj.is_hub_model):
and self._model_format is not _ModelFormatKind.TLLM_ENGINE):
# Download HF model if necessary
if self.model_obj.model_name is None:
raise ValueError(
@ -305,31 +303,18 @@ class ModelLoader:
def _download_hf_model(self):
''' Download HF model from third-party model hub like www.modelscope.cn or huggingface. '''
model_dir = None
speculative_model_dir = None
# Only rank 0 is allowed to download the model
if mpi_rank() == 0:
assert self._workspace is not None
assert isinstance(self.model_obj.model_name, str)
# this will download only once when multiple MPI processes are running
model_dir = download_hf_model(self.model_obj.model_name,
revision=self.llm_args.revision)
print_colored(f"Downloaded model to {model_dir}\n", 'grey')
if self.speculative_model_obj:
speculative_model_dir = download_hf_model(
self.speculative_model_obj.model_name)
print_colored(f"Downloaded model to {speculative_model_dir}\n",
'grey')
# Make sure all processes get the same model_dir
self._model_dir = mpi_broadcast(model_dir, root=0)
self.model_obj.model_dir = self._model_dir # mark as a local model
assert self.model_obj.is_local_model
if self.speculative_model_obj:
self._speculative_model_dir = mpi_broadcast(speculative_model_dir,
root=0)
self.speculative_model_obj.model_dir = self._speculative_model_dir
assert self.speculative_model_obj.is_local_model
def _update_from_hf_quant_config(self) -> bool:
"""Update quant_config from the config file of pre-quantized HF checkpoint.
@ -440,8 +425,8 @@ class ModelLoader:
model_cls = AutoModelForCausalLM.get_trtllm_model_class(
self._model_dir, self.llm_args.trust_remote_code,
self.llm_args.decoding_config.decoding_mode
if hasattr(self.llm_args, "speculative_model_dir")
and self.llm_args.speculative_model_dir else None)
if hasattr(self.llm_args, "speculative_model")
and self.llm_args.speculative_model else None)
prequantized = self._update_from_hf_quant_config()
@ -638,18 +623,42 @@ class CachedModelLoader:
else:
return [task(*args, **kwargs)]
def _download_hf_model_if_needed(self,
model_obj: _ModelWrapper,
revision: Optional[str] = None) -> Path:
"""Download a model from HF hub if needed.
Also updates the model_obj.model_dir with the local model dir on rank 0.
"""
if model_obj.is_hub_model:
model_dirs = self._submit_to_all_workers(
CachedModelLoader._node_download_hf_model,
model=model_obj.model_name,
revision=revision)
model_dir = model_dirs[0]
model_obj.model_dir = model_dir
return model_dir
return model_obj.model_dir
def __call__(self) -> Tuple[Path, Union[Path, None]]:
if self.llm_args.model_format is _ModelFormatKind.TLLM_ENGINE:
return Path(self.llm_args.model), None
# Download speculative model from HuggingFace if needed (all backends)
if (self.llm_args.speculative_config is not None and
self.llm_args.speculative_config.speculative_model is not None):
spec_model_obj = _ModelWrapper(
self.llm_args.speculative_config.speculative_model)
spec_model_dir = self._download_hf_model_if_needed(spec_model_obj)
self.llm_args.speculative_config.speculative_model = spec_model_dir
# AutoDeploy doesn't use ModelLoader
if self.llm_args.backend == "_autodeploy":
return None, ""
self.engine_cache_stage: Optional[CachedStage] = None
self._hf_model_dir = None
self.model_loader = ModelLoader(self.llm_args)
if self.llm_args.backend is not None:
@ -657,14 +666,8 @@ class CachedModelLoader:
raise ValueError(
f'backend {self.llm_args.backend} is not supported.')
if self.model_loader.model_obj.is_hub_model:
hf_model_dirs = self._submit_to_all_workers(
CachedModelLoader._node_download_hf_model,
model=self.model_loader.model_obj.model_name,
revision=self.llm_args.revision)
self._hf_model_dir = hf_model_dirs[0]
else:
self._hf_model_dir = self.model_loader.model_obj.model_dir
self._hf_model_dir = self._download_hf_model_if_needed(
self.model_loader.model_obj, revision=self.llm_args.revision)
if self.llm_args.quant_config.quant_algo is not None:
logger.warning(

View File

@ -224,6 +224,7 @@ class DisabledTqdm(tqdm):
def download_hf_model(model: str, revision: Optional[str] = None) -> Path:
ignore_patterns = ["original/**/*"]
logger.info(f"Downloading model {model} from HuggingFace")
with get_file_lock(model):
hf_folder = snapshot_download(
model,
@ -231,6 +232,7 @@ def download_hf_model(model: str, revision: Optional[str] = None) -> Path:
ignore_patterns=ignore_patterns,
revision=revision,
tqdm_class=DisabledTqdm)
logger.info(f"Finished downloading model {model} from HuggingFace")
return Path(hf_folder)

View File

@ -576,7 +576,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
speculative_decoding_config = {
"decoding_type": "Eagle",
"max_draft_len": 4,
"speculative_model_dir":
"speculative_model":
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
"eagle3_one_model": eagle3_one_model
}
@ -675,7 +675,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
speculative_decoding_config = {
"decoding_type": "Eagle",
"max_draft_len": 3,
"speculative_model_dir":
"speculative_model":
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
"eagle3_one_model": eagle3_one_model
}

View File

@ -471,7 +471,7 @@ class TestEagleVicuna_7B_v1_3(LlmapiAccuracyTestHarness):
speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model_dir=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
num_eagle_layers=4,
max_non_leaves_per_layer=10,
eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
@ -497,7 +497,7 @@ class TestEagle2Vicuna_7B_v1_3(LlmapiAccuracyTestHarness):
speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model_dir=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
num_eagle_layers=4,
max_non_leaves_per_layer=10,
use_dynamic_tree=True,

View File

@ -13,34 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from test_common.llm_data import hf_model_dir_or_hub_id, llm_models_root
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.quantization import QuantAlgo
from tensorrt_llm.sampling_params import SamplingParams
from ..conftest import llm_models_root
from .accuracy_core import GSM8K, MMLU, CnnDailymail, LlmapiAccuracyTestHarness
def _hf_model_dir_or_hub_id(
hf_model_subdir: str,
hf_hub_id: str,
) -> str:
llm_models_path = llm_models_root()
if llm_models_path and os.path.isdir(
(model_fullpath := os.path.join(llm_models_path, hf_model_subdir))):
return str(model_fullpath)
else:
return hf_hub_id
class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B"
MODEL_PATH = _hf_model_dir_or_hub_id("llama-3.1-model/Meta-Llama-3.1-8B",
MODEL_NAME)
MODEL_PATH = hf_model_dir_or_hub_id(MODEL_NAME)
def get_default_kwargs(self, enable_chunked_prefill=False):
config = {

View File

@ -276,7 +276,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
draft_len = 4
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
with LLM(model=target_model_dir,
@ -369,8 +369,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
cuda_graph_config = CudaGraphConfig(enable_padding=True)
spec_config = EagleDecodingConfig(
max_draft_len=3,
speculative_model_dir=
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
speculative_model=f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
eagle3_one_model=eagle3_one_model)
llm = LLM(
self.MODEL_PATH,
@ -621,7 +620,7 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.3-Instruct-70B"
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
spec_config = EagleDecodingConfig(max_draft_len=3,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
torch_compile_config = _get_default_torch_compile_config(torch_compile)
pytorch_config = dict(
@ -1383,7 +1382,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
)
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
mtp_eagle_one_model=False,
speculative_model_dir=self.MODEL_PATH)
speculative_model=self.MODEL_PATH)
with LLM(self.MODEL_PATH,
kv_cache_config=kv_cache_config,
enable_chunked_prefill=False,
@ -2935,7 +2934,7 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
mtp_eagle_one_model=False,
speculative_model_dir=model_path)
speculative_model=model_path)
with LLM(model_path,
max_batch_size=max_batch_size,
@ -3441,7 +3440,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
draft_len = 4
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
llm = LLM(model=target_model_dir,
@ -3812,7 +3811,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
if eagle3:
spec_config = EagleDecodingConfig(
max_draft_len=2,
speculative_model_dir=
speculative_model=
f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/",
eagle3_one_model=True)
with LLM(
@ -3860,7 +3859,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
if eagle3:
spec_config = EagleDecodingConfig(
max_draft_len=2,
speculative_model_dir=
speculative_model=
f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/",
eagle3_one_model=True)
with LLM(
@ -4479,7 +4478,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=one_model,
allow_advanced_sampling=True)
@ -4545,7 +4544,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=one_model,
allow_advanced_sampling=True)
@ -4609,7 +4608,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=one_model,
allow_advanced_sampling=True)
@ -4668,7 +4667,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=one_model)
max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
@ -5150,7 +5149,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
if eagle3:
spec_config = EagleDecodingConfig(
max_draft_len=2,
speculative_model_dir=
speculative_model=
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",
eagle3_one_model=True,
eagle3_model_arch="mistral_large3")
@ -5201,7 +5200,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
if eagle3:
spec_config = EagleDecodingConfig(
max_draft_len=2,
speculative_model_dir=
speculative_model=
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",
eagle3_one_model=True,
eagle3_model_arch="mistral_large3")

View File

@ -400,7 +400,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
# Test whether the batch slots are properly released when using speculative decoding
# with disaggregated serving.
spec_dec_config = EagleDecodingConfig(
speculative_model_dir=model_path(spec_dec_model_path),
speculative_model=model_path(spec_dec_model_path),
eagle3_one_model=eagle3_one_model,
max_draft_len=3)

View File

@ -93,7 +93,7 @@ def test_spec_decoding_metrics_eagle3_one_model():
"speculative_config": {
"decoding_type": "Eagle",
"max_draft_len": 4,
"speculative_model_dir": eagle3_path,
"speculative_model": eagle3_path,
"eagle3_one_model": True,
},
}
@ -174,7 +174,7 @@ def test_spec_decoding_metrics_eagle3_two_model():
"speculative_config": {
"decoding_type": "Eagle",
"max_draft_len": 4,
"speculative_model_dir": eagle3_path,
"speculative_model": eagle3_path,
"eagle3_one_model": False, # Two-model mode
},
}

View File

@ -52,14 +52,14 @@ def get_model_paths():
def make_draft_target_config(spec_model_path: str):
return DraftTargetDecodingConfig(
max_draft_len=DRAFT_TARGET_MAX_DRAFT_LEN, speculative_model_dir=spec_model_path
max_draft_len=DRAFT_TARGET_MAX_DRAFT_LEN, speculative_model=spec_model_path
)
def make_eagle3_config(spec_model_path: str):
return EagleDecodingConfig(
max_draft_len=EAGLE_MAX_DRAFT_LEN,
speculative_model_dir=spec_model_path,
speculative_model=spec_model_path,
eagle3_one_model=False,
eagle3_layers_to_capture=None,
)
@ -216,7 +216,7 @@ def test_autodeploy_eagle3_acceptance_rate():
# Configure Eagle3 speculative decoding
speculative_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model,
speculative_model=eagle_model,
eagle3_one_model=False,
eagle3_layers_to_capture=None,
)

View File

@ -223,7 +223,7 @@ def get_model_yaml_config(model_label: str,
'speculative_config': {
'decoding_type': 'Eagle',
'eagle3_one_model': True,
'speculative_model_dir': 'Qwen3-4B_eagle3',
'speculative_model': 'Qwen3-4B_eagle3',
'max_draft_len': 3,
},
'kv_cache_config': {

View File

@ -211,7 +211,7 @@ class ServerConfig:
else:
self.eagle3_layers_to_capture = []
self.max_draft_len = speculative_config.get("max_draft_len", 0)
self.speculative_model_dir = speculative_config.get("speculative_model_dir", "")
self.speculative_model = speculative_config.get("speculative_model", "")
# match_mode: "config" (default) or "scenario"
self.match_mode = server_config_data.get("match_mode", "config")
@ -338,7 +338,7 @@ class ServerConfig:
"l_num_nextn_predict_layers": self.num_nextn_predict_layers,
"s_eagle3_layers_to_capture": ",".join(map(str, self.eagle3_layers_to_capture)),
"l_max_draft_len": self.max_draft_len,
"s_speculative_model_dir": self.speculative_model_dir,
"s_speculative_model_dir": self.speculative_model,
"s_server_log_link": "",
"s_server_env_var": self.env_vars,
}
@ -348,15 +348,15 @@ class ServerConfig:
"""Generate extra-llm-api-config.yml content."""
config_data = dict(self.extra_llm_api_config_data)
# Handle speculative_model_dir path conversion
# Handle speculative_model path conversion
if (
"speculative_config" in config_data
and "speculative_model_dir" in config_data["speculative_config"]
and "speculative_model" in config_data["speculative_config"]
):
spec_model_dir = config_data["speculative_config"]["speculative_model_dir"]
if spec_model_dir:
config_data["speculative_config"]["speculative_model_dir"] = os.path.join(
llm_models_root(), spec_model_dir
spec_model = config_data["speculative_config"]["speculative_model"]
if spec_model:
config_data["speculative_config"]["speculative_model"] = os.path.join(
llm_models_root(), spec_model
)
return yaml.dump(config_data, default_flow_style=False, sort_keys=False)

View File

@ -3414,7 +3414,7 @@ def test_eagle3_output_consistency_4gpus(model_dir: str, draft_model_dir: str):
# Run with Eagle3
spec_config = EagleDecodingConfig(
max_draft_len=3,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=True,
)
with LLM(**llm_common_config, speculative_config=spec_config) as llm_spec:

View File

@ -146,7 +146,7 @@ server_configs:
decoding_type: 'Eagle'
eagle3_layers_to_capture: [-1]
max_draft_len: 3
speculative_model_dir: "gpt_oss/gpt-oss-120b-Eagle3"
speculative_model: "gpt_oss/gpt-oss-120b-Eagle3"
stream_interval: 20
num_postprocess_workers: 4
client_configs:

View File

@ -218,7 +218,7 @@ SERVER_CONFIG_METRICS = {
SPECULATIVE_CONFIG_METRICS = {
"decoding_type": (True, str),
"max_draft_len": (True, int),
"speculative_model_dir": (True, str),
"speculative_model": (True, str),
"eagle3_one_model": (True, str_to_bool),
}
@ -259,7 +259,7 @@ class ServerConfig:
enable_padding: bool = True,
decoding_type: str = "",
max_draft_len: int = 0,
speculative_model_dir: str = "",
speculative_model: str = "",
eagle3_one_model: bool = False,
):
self.name = name
@ -285,7 +285,7 @@ class ServerConfig:
self.enable_padding = enable_padding
self.decoding_type = decoding_type
self.max_draft_len = max_draft_len
self.speculative_model_dir = speculative_model_dir
self.speculative_model = speculative_model
self.eagle3_one_model = eagle3_one_model
model_dir = get_model_dir(self.model_name)
@ -345,9 +345,9 @@ class ServerConfig:
config_lines.append(f" decoding_type: {self.decoding_type}")
if self.max_draft_len > 0:
config_lines.append(f" max_draft_len: {self.max_draft_len}")
if self.speculative_model_dir:
if self.speculative_model:
config_lines.append(
f" speculative_model_dir: {self.speculative_model_dir}")
f" speculative_model: {self.speculative_model}")
if self.eagle3_one_model:
config_lines.append(
f" eagle3_one_model: {str(self.eagle3_one_model).lower()}")
@ -500,8 +500,8 @@ def parse_config_file(config_file_path: str, select_pattern: str = None):
{}).get('decoding_type', ''),
max_draft_len=server_config_data.get('speculative_config',
{}).get('max_draft_len', 0),
speculative_model_dir=server_config_data.get(
'speculative_config', {}).get('speculative_model_dir', ''),
speculative_model=server_config_data.get(
'speculative_config', {}).get('speculative_model', ''),
eagle3_one_model=server_config_data.get(
'speculative_config', {}).get('eagle3_one_model', False))

View File

@ -0,0 +1,115 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for local LLM model paths and HuggingFace download mocking."""
import os
from functools import wraps
from pathlib import Path
from typing import Optional
from unittest.mock import patch
# Mapping from HuggingFace Hub ID to local subdirectory under LLM_MODELS_ROOT.
# NOTE: hf_id_to_local_model_dir below will fall back to checking if the model name exists
# in LLM_MODELS_ROOT if not present here, so it's not required to exhaustively list all
# models here.
HF_ID_TO_LLM_MODELS_SUBDIR = {
"meta-llama/Meta-Llama-3.1-8B-Instruct": "llama-3.1-model/Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct": "llama-3.1-model/Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-8B": "llama-3.1-model/Meta-Llama-3.1-8B",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0": "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
"meta-llama/Llama-4-Scout-17B-16E-Instruct": "llama4-models/Llama-4-Scout-17B-16E-Instruct",
"mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-Instruct-v0.1",
"mistralai/Mistral-Small-3.1-24B-Instruct-2503": "Mistral-Small-3.1-24B-Instruct-2503",
"Qwen/Qwen3-30B-A3B": "Qwen3/Qwen3-30B-A3B",
"Qwen/Qwen2.5-3B-Instruct": "Qwen2.5-3B-Instruct",
"microsoft/Phi-3-mini-4k-instruct": "Phi-3/Phi-3-mini-4k-instruct",
"deepseek-ai/DeepSeek-V3": "DeepSeek-V3",
"deepseek-ai/DeepSeek-R1": "DeepSeek-R1/DeepSeek-R1",
"ibm-ai-platform/Bamba-9B-v2": "Bamba-9B-v2",
"nvidia/NVIDIA-Nemotron-Nano-12B-v2": "NVIDIA-Nemotron-Nano-12B-v2",
"nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3": "NVIDIA-Nemotron-Nano-31B-A3-v3",
"nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": "Nemotron-Nano-3-30B-A3.5B-dev-1024",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": "EAGLE3-LLaMA3.1-Instruct-8B",
}
def llm_models_root(check: bool = False) -> Optional[Path]:
root = Path("/home/scratch.trt_llm_data/llm-models/")
if "LLM_MODELS_ROOT" in os.environ:
root = Path(os.environ.get("LLM_MODELS_ROOT"))
if not root.exists():
root = Path("/scratch.trt_llm_data/llm-models/")
if check:
assert root.exists(), (
"You must set LLM_MODELS_ROOT env or be able to access /home/scratch.trt_llm_data to run this test"
)
return root if root.exists() else None
def llm_datasets_root() -> str:
return os.path.join(llm_models_root(check=True), "datasets")
def hf_id_to_local_model_dir(hf_hub_id: str) -> str | None:
"""Return the local model directory under LLM_MODELS_ROOT for a given HuggingFace Hub ID, or None if not found."""
root = llm_models_root()
if root is None:
return None
if hf_hub_id in HF_ID_TO_LLM_MODELS_SUBDIR:
return str(root / HF_ID_TO_LLM_MODELS_SUBDIR[hf_hub_id])
# Fall back to checking if the model name exists as a top-level directory in LLM_MODELS_ROOT
model_name = hf_hub_id.split("/")[-1]
if os.path.isdir(root / model_name):
return str(root / model_name)
return None
def hf_model_dir_or_hub_id(hf_hub_id: str) -> str:
"""Resolve a HuggingFace Hub ID to local path if available, otherwise return the Hub ID."""
return hf_id_to_local_model_dir(hf_hub_id) or hf_hub_id
def mock_snapshot_download(repo_id: str, **kwargs) -> str:
"""Mock huggingface_hub.snapshot_download that returns an existing local model directory.
NOTE: This function does not currently handle the revision / allow_patterns / ignore_patterns parameters.
"""
local_path = hf_id_to_local_model_dir(repo_id)
if local_path is None:
raise ValueError(f"Model '{repo_id}' not found in LLM_MODELS_ROOT")
return local_path
def with_mocked_hf_download(func):
"""Decorator to mock huggingface_hub.snapshot_download for tests.
When applied, any calls to snapshot_download will be redirected to use
local model paths from LLM_MODELS_ROOT instead of downloading from HuggingFace.
"""
@wraps(func)
def wrapper(*args, **kwargs):
with patch("huggingface_hub.snapshot_download", side_effect=mock_snapshot_download):
return func(*args, **kwargs)
return wrapper
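
For reference, a usage sketch of these helpers as they are exercised by the tests touched in this change (the test body is illustrative):

```python
from test_common.llm_data import hf_model_dir_or_hub_id, with_mocked_hf_download

@with_mocked_hf_download
def test_spec_model_resolution():
    # Resolves to a directory under LLM_MODELS_ROOT when the model is mirrored
    # locally, and falls back to the Hub ID otherwise.
    model = hf_model_dir_or_hub_id("meta-llama/Llama-3.1-8B-Instruct")
    # Inside the decorated test, huggingface_hub.snapshot_download is patched,
    # so a Hub ID such as "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" is served from
    # the local model root instead of being downloaded.
    assert model is not None
```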

View File

@ -1,12 +1,11 @@
import copy
import os
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from test_common.llm_data import hf_model_dir_or_hub_id
from torch import nn
from torch.export import Dim
from utils.llm_data import llm_models_root
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
@ -285,17 +284,6 @@ def generate_dynamic_shapes(max_batch_size, max_seq_len):
return dynamic_shapes
def _hf_model_dir_or_hub_id(
hf_model_subdir: str,
hf_hub_id: str,
) -> str:
llm_models_path = llm_models_root()
if llm_models_path and os.path.isdir((model_fullpath := llm_models_path / hf_model_subdir)):
return str(model_fullpath)
else:
return hf_hub_id
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
@ -351,7 +339,6 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
_SMALL_MODEL_CONFIGS = {
"meta-llama/Meta-Llama-3.1-8B-Instruct": {
"llm_models_subdir": "llama-3.1-model/Llama-3.1-8B-Instruct",
"model_kwargs": {
"num_hidden_layers": 1,
"hidden_size": 64,
@ -361,7 +348,6 @@ _SMALL_MODEL_CONFIGS = {
},
},
"mistralai/Mixtral-8x7B-Instruct-v0.1": {
"llm_models_subdir": "Mixtral-8x7B-Instruct-v0.1",
"model_kwargs": {
"num_hidden_layers": 2,
"intermediate_size": 256,
@ -372,7 +358,6 @@ _SMALL_MODEL_CONFIGS = {
},
},
"Qwen/Qwen3-30B-A3B": {
"llm_models_subdir": "Qwen3/Qwen3-30B-A3B",
"model_kwargs": {
"num_hidden_layers": 2,
"intermediate_size": 256,
@ -383,7 +368,6 @@ _SMALL_MODEL_CONFIGS = {
},
},
"microsoft/Phi-3-mini-4k-instruct": {
"llm_models_subdir": "Phi-3/Phi-3-mini-4k-instruct",
"model_kwargs": {
"num_hidden_layers": 2,
"hidden_size": 128,
@ -393,7 +377,6 @@ _SMALL_MODEL_CONFIGS = {
},
},
"meta-llama/Llama-4-Scout-17B-16E-Instruct": {
"llm_models_subdir": "llama4-models/Llama-4-Scout-17B-16E-Instruct",
"model_factory": "AutoModelForImageTextToText",
"model_kwargs": {
"text_config": {
@ -412,7 +395,6 @@ _SMALL_MODEL_CONFIGS = {
},
},
"deepseek-ai/DeepSeek-V3": {
"llm_models_subdir": "DeepSeek-V3",
"model_kwargs": {
"first_k_dense_replace": 1,
"num_hidden_layers": 2,
@ -431,7 +413,6 @@ _SMALL_MODEL_CONFIGS = {
},
},
"Qwen/Qwen2.5-3B-Instruct": {
"llm_models_subdir": "Qwen2.5-3B-Instruct",
"model_kwargs": {
"num_hidden_layers": 2,
"hidden_size": 64,
@ -441,7 +422,6 @@ _SMALL_MODEL_CONFIGS = {
},
},
"mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
"llm_models_subdir": "Mistral-Small-3.1-24B-Instruct-2503",
"model_factory": "AutoModelForImageTextToText",
"model_kwargs": {
"text_config": {
@ -463,7 +443,6 @@ _SMALL_MODEL_CONFIGS = {
},
},
"ibm-ai-platform/Bamba-9B-v2": {
"llm_models_subdir": "Bamba-9B-v2",
"model_kwargs": {
"dtype": "bfloat16",
"hidden_size": 64,
@ -482,7 +461,6 @@ _SMALL_MODEL_CONFIGS = {
},
},
"nvidia/NVIDIA-Nemotron-Nano-12B-v2": {
"llm_models_subdir": "NVIDIA-Nemotron-Nano-12B-v2",
"model_kwargs": {
"dtype": "bfloat16",
"hidden_size": 32,
@ -497,13 +475,11 @@ _SMALL_MODEL_CONFIGS = {
},
},
"TinyLlama/TinyLlama-1.1B-Chat-v1.0": {
"llm_models_subdir": "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
"model_kwargs": {
"num_hidden_layers": 2,
},
},
"nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": {
"llm_models_subdir": "Nemotron-Nano-3-30B-A3.5B-dev-1024",
"model_kwargs": {
"num_hidden_layers": 8,
},
@ -531,7 +507,7 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
llm_args = copy.deepcopy(_SMALL_MODEL_CONFIGS[model_hub_id])
# check if should use llm_models_root or hf_hub_id
llm_args["model"] = _hf_model_dir_or_hub_id(llm_args.pop("llm_models_subdir"), model_hub_id)
llm_args["model"] = hf_model_dir_or_hub_id(model_hub_id)
# add some defaults to llm_args
llm_args["skip_loading_weights"] = True # No weight loading to speed up things

View File

@ -4,7 +4,7 @@ import types
import pytest
import torch
from _model_test_utils import _hf_model_dir_or_hub_id
from test_common.llm_data import hf_model_dir_or_hub_id
from transformers import AutoConfig, AutoModelForCausalLM
from tensorrt_llm._torch.auto_deploy.models.patches.deepseek import (
@ -77,7 +77,7 @@ def _generate_ds_attention_mask(b, s):
"model_name, module_name, patch, inputs",
[
pytest.param(
_hf_model_dir_or_hub_id("DeepSeek-R1/DeepSeek-R1", "deepseek-ai/DeepSeek-R1"),
hf_model_dir_or_hub_id("deepseek-ai/DeepSeek-R1"),
"model.layers.0.self_attn",
deepseek_v3_attention,
[
@ -87,7 +87,7 @@ def _generate_ds_attention_mask(b, s):
],
), # attention requires inputs [hidden_states, attention_mask, position_ids]
pytest.param(
_hf_model_dir_or_hub_id("DeepSeek-R1/DeepSeek-R1", "deepseek-ai/DeepSeek-R1"),
hf_model_dir_or_hub_id("deepseek-ai/DeepSeek-R1"),
"model.layers.0.mlp",
deepseek_v3_moe_exact,
[torch.randn(2, 6, 8, dtype=torch.bfloat16)],

View File

@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main
from test_common.llm_data import with_mocked_hf_download
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig
def test_ad_speculative_decoding_smoke():
@pytest.mark.parametrize("use_hf_speculative_model", [False, True])
@with_mocked_hf_download
def test_ad_speculative_decoding_smoke(use_hf_speculative_model: bool):
"""Test speculative decoding with AutoDeploy using the build_and_run_ad main()."""
# Use a simple test prompt
@ -27,15 +31,15 @@ def test_ad_speculative_decoding_smoke():
# Get base model config
experiment_config = get_small_model_config("meta-llama/Meta-Llama-3.1-8B-Instruct")
speculative_model_dir = get_small_model_config("TinyLlama/TinyLlama-1.1B-Chat-v1.0")["args"][
"model"
]
speculative_model_hf_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
if use_hf_speculative_model:
# NOTE: this will still mock out the actual HuggingFace download
speculative_model = speculative_model_hf_id
else:
speculative_model = get_small_model_config(speculative_model_hf_id)["args"]["model"]
print(f"Speculative model path: {speculative_model_dir}")
# Configure speculative decoding with a draft model
spec_config = DraftTargetDecodingConfig(
max_draft_len=3, speculative_model_dir=speculative_model_dir
)
spec_config = DraftTargetDecodingConfig(max_draft_len=3, speculative_model=speculative_model)
# Configure KV cache
kv_cache_config = KvCacheConfig(

View File

@ -77,7 +77,7 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
else:
spec_config = DraftTargetDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=str(draft_model),
speculative_model=str(draft_model),
draft_len_schedule=schedule,
)
@ -123,7 +123,7 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
else:
spec_config_fixed = DraftTargetDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=str(draft_model),
speculative_model=str(draft_model),
draft_len_schedule=None, # No schedule - fixed draft length
)
llm_fixed = LLM(**llm_common_config, speculative_config=spec_config_fixed)
@ -186,9 +186,7 @@ def test_draft_len_schedule_functionality(drafter_type: str, draft_schedule: dic
else:
spec_config = DraftTargetDecodingConfig(
max_draft_len=5,
speculative_model_dir=str(
llm_models_root() / "llama-3.2-models" / "Llama-3.2-3B-Instruct"
),
speculative_model=str(llm_models_root() / "llama-3.2-models" / "Llama-3.2-3B-Instruct"),
draft_len_schedule=draft_schedule,
)
prompts = ["The capital of France is" for i in range(7)]

View File

@ -45,7 +45,7 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
spec_config = DraftTargetDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=draft_model_dir,
speculative_model=draft_model_dir,
)
prompts = [

View File

@ -87,7 +87,7 @@ def test_draft_token_static_tree_prepare_for_generation():
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=False,
eagle_choices=eagle_choices,
use_dynamic_tree=use_dynamic_tree,

View File

@ -38,7 +38,7 @@ def test_draft_token_static_tree_sampling():
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=False,
eagle_choices=eagle_choices,
use_dynamic_tree=use_dynamic_tree,

View File

@ -23,7 +23,7 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=False,
eagle_choices=eagle_choices,
use_dynamic_tree=use_dynamic_tree,

View File

@ -56,7 +56,7 @@ def test_dynamic_spec_decode(enforce_single_worker,
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=False,
)

View File

@ -8,6 +8,7 @@ from unittest.mock import MagicMock
import pytest
import torch
from test_common.llm_data import with_mocked_hf_download
from utils.llm_data import llm_models_root
from tensorrt_llm import LLM, SamplingParams
@ -92,48 +93,83 @@ def test_kv_lens_runtime_with_eagle3_one_model():
@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,use_hf_speculative_model",
[
[True, "TRTLLM", True, False, False, False, True, False, False],
[True, "TRTLLM", True, False, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False, False],
[False, "TRTLLM", True, False, False, False, False, False, False],
[True, "FLASHINFER", True, False, False, False, True, False, False],
[False, "FLASHINFER", True, False, False, False, True, False, False],
[False, "TRTLLM", False, True, True, False, True, False, False],
[True, "TRTLLM", False, True, True, False, True, False, False],
[True, "TRTLLM", True, False, True, True, True, False, False],
[True, "TRTLLM", True, False, True, False, True, False, False],
[True, "TRTLLM", True, False, False, True, True, False, False],
[True, "TRTLLM", False, False, False, False, True, False, False],
[False, "TRTLLM", False, False, False, False, True, False, False],
[True, "TRTLLM", False, False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, False, True, True],
[False, "TRTLLM", False, False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, True, True, False],
[False, "TRTLLM", False, False, False, False, True, True, False],
[True, "TRTLLM", False, False, False, False, False, False, False],
[False, "TRTLLM", False, False, False, False, False, False, False],
[True, "TRTLLM", False, False, False, True, True, False, False],
[True, "TRTLLM", False, False, False, True, False, False, False],
[True, "FLASHINFER", False, False, False, False, True, False, False],
[False, "FLASHINFER", False, False, False, False, True, False, False],
[True, "TRTLLM", True, False, False, False, True, False, False, False],
[True, "TRTLLM", True, False, False, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False, False, False],
[
False, "TRTLLM", True, False, False, False, False, False, False,
False
],
[
True, "FLASHINFER", True, False, False, False, True, False, False,
False
],
[
False, "FLASHINFER", True, False, False, False, True, False, False,
False
],
[False, "TRTLLM", False, True, True, False, True, False, False, False],
[True, "TRTLLM", False, True, True, False, True, False, False, False],
[True, "TRTLLM", True, False, True, True, True, False, False, False],
[True, "TRTLLM", True, False, True, False, True, False, False, False],
[True, "TRTLLM", True, False, False, True, True, False, False, False],
[True, "TRTLLM", False, False, False, False, True, False, False, False],
[
False, "TRTLLM", False, False, False, False, True, False, False,
False
],
[True, "TRTLLM", False, False, False, False, False, True, False, False],
[True, "TRTLLM", False, False, False, False, False, True, True, False],
[
False, "TRTLLM", False, False, False, False, False, True, False,
False
],
[True, "TRTLLM", False, False, False, False, True, True, False, False],
[False, "TRTLLM", False, False, False, False, True, True, False, False],
[
True, "TRTLLM", False, False, False, False, False, False, False,
False
],
[
False, "TRTLLM", False, False, False, False, False, False, False,
False
],
[True, "TRTLLM", False, False, False, True, True, False, False, False],
[True, "TRTLLM", False, False, False, True, False, False, False, False],
[
True, "FLASHINFER", False, False, False, False, True, False, False,
False
],
[
False, "FLASHINFER", False, False, False, False, True, False, False,
False
],
# Tests (mocked) speculative model auto-download from HuggingFace
[False, "TRTLLM", True, False, False, False, True, False, False, True],
])
@pytest.mark.high_cuda_memory
@with_mocked_hf_download
def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
disable_overlap_scheduler: bool, enable_block_reuse: bool,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool, multi_batch: bool,
attention_dp: bool, request):
attention_dp: bool, use_hf_speculative_model: bool,
request):
# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")
models_path = llm_models_root()
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
if use_hf_speculative_model:
eagle_model = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
else:
eagle_model = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
# bs > 1 gives non-deterministic results when doing IFB. There is a slight chance
# that ref and spec do not match 100%
max_batch_size = 4 if multi_batch else 1
@ -165,7 +201,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)
@ -241,7 +277,7 @@ def test_eagle3_spec_decoding_stats(eagle3_one_model):
free_gpu_memory_fraction=0.6)
spec_config = EagleDecodingConfig(
max_draft_len=3,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=eagle3_one_model,
)
@ -321,7 +357,7 @@ def test_llama_eagle3_long_prompt(use_cuda_graph):
spec_config = EagleDecodingConfig(
max_draft_len=3,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=False,
)
@ -445,7 +481,7 @@ def test_deepseek_eagle3():
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
eagle3_layers_to_capture={29},
@ -555,7 +591,7 @@ def test_deepseek_mla_eagle3():
)
spec_config = EagleDecodingConfig(max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=use_one_model,
load_format="dummy")
@ -654,7 +690,7 @@ def test_multi_eagle3(use_one_model: bool):
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
num_eagle_layers=2,
@ -713,7 +749,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=use_one_model,
)
@ -766,7 +802,7 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool):
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=use_one_model,
)

View File

@ -52,7 +52,7 @@ def test_kv_cache_reuse(use_cuda_graph: bool, attn_backend: str):
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
eagle3_one_model=False,
)

View File

@ -47,7 +47,7 @@ def test_spec_gate_e2e():
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
speculative_model=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=False,
max_concurrency=10000,

View File

@ -1218,7 +1218,7 @@ def test_llm_api_medusa():
speculative_config = MedusaDecodingConfig(num_medusa_heads=4,
max_draft_len=63,
speculative_model_dir=get_model_path("medusa-vicuna-7b-v1.3"),
speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
[0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
[0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
@ -1257,7 +1257,7 @@ def test_llm_api_medusa_tp2():
speculative_config = MedusaDecodingConfig(num_medusa_heads=4,
max_draft_len=63,
speculative_model_dir=get_model_path("medusa-vicuna-7b-v1.3"),
speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
[0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
[0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
@ -1295,7 +1295,7 @@ def test_llm_api_eagle(**llm_kwargs):
speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model_dir=get_model_path("EAGLE-Vicuna-7B-v1.3"),
speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
num_eagle_layers=4,
max_non_leaves_per_layer=10,
eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
@ -1342,7 +1342,7 @@ def test_llm_api_eagle2(**llm_kwargs):
speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model_dir=get_model_path("EAGLE-Vicuna-7B-v1.3"),
speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
num_eagle_layers=4,
max_non_leaves_per_layer=10,
use_dynamic_tree=True,

View File

@ -445,6 +445,19 @@ class TestTorchLlmArgs:
args = TorchLlmArgs(model=llama_model_path)
args.invalid_arg = 1
def test_speculative_model_alias(self):
"""Test that speculative_model_dir is accepted as an alias for speculative_model."""
spec_config = EagleDecodingConfig(
max_draft_len=3,
speculative_model_dir="/path/to/model",
eagle3_one_model=False,
)
args = TorchLlmArgs(model=llama_model_path,
speculative_config=spec_config)
assert args.speculative_model == "/path/to/model"
class TestTrtLlmArgs:

View File

@ -1,23 +1,15 @@
import os
from pathlib import Path
from typing import Optional
import sys
# Ensure tests/ directory is in path for test_common imports
sys.path.insert(
0,
os.path.dirname(os.path.dirname(os.path.dirname(
os.path.abspath(__file__)))))
def llm_models_root(check=False) -> Optional[Path]:
root = Path("/home/scratch.trt_llm_data/llm-models/")
from test_common.llm_data import llm_datasets_root, llm_models_root
if "LLM_MODELS_ROOT" in os.environ:
root = Path(os.environ.get("LLM_MODELS_ROOT"))
if not root.exists():
root = Path("/scratch.trt_llm_data/llm-models/")
if check:
assert root.exists(), \
"You shall set LLM_MODELS_ROOT env or be able to access /home/scratch.trt_llm_data to run this test"
return root if root.exists() else None
def llm_datasets_root() -> str:
return os.path.join(llm_models_root(check=True), "datasets")
__all__ = [
"llm_datasets_root",
"llm_models_root",
]