Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Commit 55ff21ff81: Merge dc401b5ec9 into 6df2c8a074
@@ -37,8 +37,13 @@ Draft/target is the simplest form of speculative decoding. In this approach, an
 ```python
 from tensorrt_llm.llmapi import DraftTargetDecodingConfig
 
+# Option 1: Use a HuggingFace Hub model ID (auto-downloaded)
 speculative_config = DraftTargetDecodingConfig(
-    max_draft_len=3, speculative_model_dir="/path/to/draft_model")
+    max_draft_len=3, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")
+
+# Option 2: Use a local path
+# speculative_config = DraftTargetDecodingConfig(
+#     max_draft_len=3, speculative_model="/path/to/draft_model")
 
 llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
 ```

@@ -51,18 +56,23 @@ TRT-LLM supports a modified version of the algorithm presented in the paper: tre
 The following draft model checkpoints can be used for EAGLE 3:
 * Llama 3 variants: [use the checkpoints from the authors of the original EAGLE 3 paper](https://huggingface.co/yuhuili).
 * Llama 4 Maverick: [use the checkpoint from the NVIDIA HuggingFace repository](https://huggingface.co/nvidia/Llama-4-Maverick-17B-128E-Eagle3).
+* Other models, including `gpt-oss-120b` and `Qwen3`: check out the [Speculative Decoding Modules](https://huggingface.co/collections/nvidia/speculative-decoding-modules) collection from NVIDIA.
 
 ```python
 from tensorrt_llm.llmapi import EagleDecodingConfig
 
 # Enable to use the faster one-model implementation for Llama 4.
 eagle3_one_model = False
+model = "meta-llama/Llama-3.1-8B-Instruct"
+speculative_model = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 
 speculative_config = EagleDecodingConfig(
-    max_draft_len=3, speculative_model_dir="/path/to/draft_model", eagle3_one_model=eagle3_one_model)
+    max_draft_len=3,
+    speculative_model=speculative_model,
+    eagle3_one_model=eagle3_one_model)
 
 # Only need to disable overlap scheduler if eagle3_one_model is False.
-llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
+llm = LLM(model, speculative_config=speculative_config, disable_overlap_scheduler=True)
 ```
 
 ### NGram

@@ -137,7 +147,17 @@ Speculative decoding options must be specified via `--config config.yaml` for bo
 The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:
 
 ```yaml
+# Using a HuggingFace Hub model ID (auto-downloaded)
 disable_overlap_scheduler: true
 speculative_config:
   decoding_type: Eagle
   max_draft_len: 4
+  speculative_model: yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
+```
+
+```yaml
+# Or using a local path
+disable_overlap_scheduler: true
+speculative_config:
+  decoding_type: Eagle
@@ -145,6 +165,16 @@ speculative_config:
-  speculative_model_dir: /path/to/draft/model
+  speculative_model: /path/to/draft/model
 ```
 
+```{note}
+The field name `speculative_model_dir` can also be used as an alias for `speculative_config.speculative_model`. For example:
+
+speculative_config:
+  decoding_type: Eagle
+  max_draft_len: 4
+  speculative_model_dir: /path/to/draft/model
+```
+
 ## Developer Guide
 
 This section describes the components of a speculative decoding algorithm. All of the interfaces are defined in [`_torch/speculative/interface.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/speculative/interface.py).
@@ -23,12 +23,12 @@ def main():
     model = "lmsys/vicuna-7b-v1.3"
 
     # The end user can customize the eagle decoding configuration by specifying the
-    # speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
+    # speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
     # greedy_sampling,posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
     # with the EagleDecodingConfig class
 
     speculative_config = EagleDecodingConfig(
-        speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
+        speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
         max_draft_len=63,
         num_eagle_layers=4,
         max_non_leaves_per_layer=10,

@@ -23,12 +23,12 @@ def main():
     model = "lmsys/vicuna-7b-v1.3"
 
     # The end user can customize the eagle decoding configuration by specifying the
-    # speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
+    # speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
    # greedy_sampling,posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
     # with the EagleDecodingConfig class
 
     speculative_config = EagleDecodingConfig(
-        speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
+        speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
         max_draft_len=63,
         num_eagle_layers=4,
         max_non_leaves_per_layer=10,

@@ -48,10 +48,10 @@ def run_medusa_decoding(use_modelopt_ckpt=False, model_dir=None):
     model = "lmsys/vicuna-7b-v1.3"
 
     # The end user can customize the medusa decoding configuration by specifying the
-    # speculative_model_dir, max_draft_len, medusa heads num and medusa choices
+    # speculative_model, max_draft_len, medusa heads num and medusa choices
     # with the MedusaDecodingConfig class
     speculative_config = MedusaDecodingConfig(
-        speculative_model_dir="FasterDecoding/medusa-vicuna-7b-v1.3",
+        speculative_model="FasterDecoding/medusa-vicuna-7b-v1.3",
         max_draft_len=63,
         num_medusa_heads=4,
         medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \

@@ -35,7 +35,7 @@ def run_MTP(model: Optional[str] = None):
 def run_Eagle3():
     spec_config = EagleDecodingConfig(
         max_draft_len=3,
-        speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
+        speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
         eagle3_one_model=True)
 
     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

@@ -220,11 +220,11 @@ def setup_llm(args, **kwargs):
             relaxed_topk=args.relaxed_topk,
             relaxed_delta=args.relaxed_delta,
             mtp_eagle_one_model=args.use_one_model,
-            speculative_model_dir=args.model_dir)
+            speculative_model=args.model_dir)
     elif spec_decode_algo == "EAGLE3":
         spec_config = EagleDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
-            speculative_model_dir=args.draft_model_dir,
+            speculative_model=args.draft_model_dir,
             eagle3_one_model=args.use_one_model,
             eagle_choices=args.eagle_choices,
             use_dynamic_tree=args.use_dynamic_tree,

@@ -234,7 +234,7 @@ def setup_llm(args, **kwargs):
     elif spec_decode_algo == "DRAFT_TARGET":
         spec_config = DraftTargetDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
-            speculative_model_dir=args.draft_model_dir)
+            speculative_model=args.draft_model_dir)
     elif spec_decode_algo == "NGRAM":
         spec_config = NGramDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
@@ -841,8 +841,8 @@ Qwen3 now supports Eagle3 (Speculative Decoding with Eagle3). To enable Eagle3 o
   Set the decoding type to "Eagle" to enable Eagle3 speculative decoding.
 - `speculative_config.max_draft_len: 3`
   Set the maximum number of draft tokens generated per step (this value can be adjusted as needed).
-- `speculative_config.speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>`
-  Specify the path to the Eagle3 draft model (ensure the corresponding draft model weights are prepared).
+- `speculative_config.speculative_model: <HUGGINGFACE ID / LOCAL PATH>`
+  Specify the Eagle3 draft model either as a HuggingFace model ID or a local path. You can find ready-to-use Eagle3 draft models at https://huggingface.co/collections/nvidia/speculative-decoding-modules.
 
 Currently, there are some limitations when enabling Eagle3:

@@ -857,7 +857,7 @@ enable_attention_dp: false
 speculative_config:
   decoding_type: Eagle
   max_draft_len: 3
-  speculative_model_dir: <EAGLE3_DRAFT_MODEL_PATH>
+  speculative_model: <HUGGINGFACE ID / LOCAL PATH>
 kv_cache_config:
   enable_block_reuse: false
 " >> ${path_config}

@@ -921,7 +921,7 @@ def create_draft_model_engine_maybe(
         drafting_loop_wrapper = None
 
     draft_model_engine = PyTorchModelEngine(
-        model_path=draft_spec_config.speculative_model_dir,
+        model_path=draft_spec_config.speculative_model,
         llm_args=draft_llm_args,
         mapping=dist_mapping,
         attn_runtime_features=attn_runtime_features,

@@ -887,7 +887,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
             from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
                 MistralConfigLoader
             self.draft_config = MistralConfigLoader().load(
-                spec_config.speculative_model_dir,
+                spec_config.speculative_model,
                 mapping=model_config.mapping,
                 moe_backend=model_config.moe_backend,
                 moe_max_num_tokens=model_config.moe_max_num_tokens,

@@ -898,7 +898,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
             self.draft_config.extra_attrs = model_config.extra_attrs
         elif spec_config.eagle3_model_arch == "llama3":
             self.draft_config = ModelConfig.from_pretrained(
-                model_config.spec_config.speculative_model_dir,
+                model_config.spec_config.speculative_model,
                 trust_remote_code=True,
                 attn_backend=model_config.attn_backend,
                 moe_backend=model_config.moe_backend,

@@ -278,7 +278,7 @@ class ModelLoader:
         if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
         ):
             weights = checkpoint_loader.load_weights(
-                self.spec_config.speculative_model_dir,
+                self.spec_config.speculative_model,
                 mapping=self.mapping)
 
             draft_model_arch = model.draft_config.pretrained_config.architectures[

@@ -398,7 +398,7 @@ def create_py_executor(
         draft_llm_args.load_format = LoadFormat.DUMMY
 
         draft_model_engine = PyTorchModelEngine(
-            model_path=spec_config.speculative_model_dir,
+            model_path=spec_config.speculative_model,
             llm_args=draft_llm_args,
             mapping=mapping,
             attn_runtime_features=attn_runtime_features,
@@ -13,7 +13,7 @@ from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple,
 
 import torch
 import yaml
-from pydantic import BaseModel
+from pydantic import AliasChoices, BaseModel
 from pydantic import Field as PydanticField
 from pydantic import PrivateAttr, field_validator, model_validator
 from strenum import StrEnum

@@ -651,7 +651,14 @@ class DecodingBaseConfig(StrictBaseModel):
     # If it's a static or dynamic tree, each draft layer may generate more than one draft token.
     # In this case, max_total_draft_tokens >= max_draft_len.
     max_total_draft_tokens: Optional[int] = None
-    speculative_model_dir: Optional[Union[str, Path]] = None
+    # The speculative (draft) model. Accepts either:
+    # - A HuggingFace Hub model ID (str), e.g., "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+    #   which will be automatically downloaded.
+    # - A local filesystem path to a downloaded model directory.
+    speculative_model: Optional[Union[str, Path]] = Field(
+        default=None,
+        validation_alias=AliasChoices("speculative_model",
+                                      "speculative_model_dir"))
 
     # PyTorch only.
     # When specified, speculation will be disabled at batch sizes above
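The alias wiring above is what keeps existing configs that still say `speculative_model_dir` working. A minimal standalone sketch of the same pattern (plain pydantic, not TRT-LLM code; the class name is hypothetical):

```python
from typing import Optional

from pydantic import AliasChoices, BaseModel, Field


class SpecConfigSketch(BaseModel):
    # Both spellings validate into the same field.
    speculative_model: Optional[str] = Field(
        default=None,
        validation_alias=AliasChoices("speculative_model", "speculative_model_dir"))


assert SpecConfigSketch(speculative_model="org/model").speculative_model == "org/model"
assert SpecConfigSketch(speculative_model_dir="/tmp/model").speculative_model == "/tmp/model"
```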
@@ -858,28 +865,29 @@ class EagleDecodingConfig(DecodingBaseConfig):
     # choices: llama3, mistral_large3
     eagle3_model_arch: str = "llama3"
 
-    def __init__(self, **kwargs):
-        super().__init__()
-        for attr_name, attr_value in kwargs.items():
-            if attr_name == 'max_draft_len':
-                self.num_eagle_layers = attr_value
-                self.max_total_draft_tokens = attr_value  # If using linear-tree, the max_total_draft_tokens is the same as max_draft_len
-            # Convert the data type of Eagle choice from str to List[List[int]]
-            if attr_name == 'eagle_choices' and attr_value is not None:
-                logger.warning(
-                    "NOTE: The Draft token tree is still under development, PLEASE DO NOT USE IT !!!"
-                )
-                if not isinstance(attr_value, list):
-                    if isinstance(attr_value, str):
-                        attr_value = ast.literal_eval(
-                            attr_value.replace(" ", ""))
-                    else:
-                        raise ValueError(
-                            "Wrong eagle choices type. Eagle choices should be a List[List[int]] or a string like [[0], [1], [2], [0, 0], [0, 1]]."
-                        )
-            setattr(self, attr_name, attr_value)
+    @field_validator('eagle_choices', mode='before')
+    @classmethod
+    def validate_eagle_choices(cls, v):
+        if v is not None:
+            logger.warning(
+                "NOTE: The Draft token tree is still under development, PLEASE DO NOT USE IT !!!"
+            )
+            if not isinstance(v, list):
+                if isinstance(v, str):
+                    v = ast.literal_eval(v.replace(" ", ""))
+                else:
+                    raise ValueError(
+                        "Wrong eagle choices type. Eagle choices should be a List[List[int]] or a string like [[0], [1], [2], [0, 0], [0, 1]]."
+                    )
+        return v
+
+    @model_validator(mode='after')
+    def validate_eagle_config(self) -> 'EagleDecodingConfig':
+        if self.max_draft_len is None:
+            raise ValueError("max_draft_len is required for Eagle")
+        self.num_eagle_layers = self.max_draft_len
+        self.max_total_draft_tokens = self.max_draft_len  # If using linear-tree, the max_total_draft_tokens is the same as max_draft_len
 
-        assert self.max_draft_len is not None, "max_draft_len is required for Eagle"
         if self.eagle3_model_arch == "mistral_large3" and self.eagle3_layers_to_capture is None:
             # FIXME find a better way to setup it.
             self.eagle3_layers_to_capture = {-1}
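For reference, the string-to-list conversion that `validate_eagle_choices` performs can be reproduced with the standard library alone (illustrative input; the real validator additionally warns that the draft token tree is still under development):

```python
import ast

raw = "[[0], [1], [2], [0, 0], [0, 1]]"
# Strip spaces, then parse the literal into List[List[int]].
choices = ast.literal_eval(raw.replace(" ", ""))
assert choices == [[0], [1], [2], [0, 0], [0, 1]]
```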
@@ -889,7 +897,10 @@ class EagleDecodingConfig(DecodingBaseConfig):
         # and reset the max_draft_len and num_eagle_layers if necessary
         if self.eagle_choices is not None:
             # If eagle_choices is provided, use_dynamic_tree should not be used
-            assert not self.use_dynamic_tree, "If eagle_choices is provided, use_dynamic_tree need to be False"
+            if self.use_dynamic_tree:
+                raise ValueError(
+                    "If eagle_choices is provided, use_dynamic_tree need to be False"
+                )
 
             # Get num_eagle_layers from eagle_choices
             num_eagle_layers_from_choices = self.check_eagle_choices()

@@ -906,10 +917,23 @@ class EagleDecodingConfig(DecodingBaseConfig):
 
         # Dynamic tree logic
         if self.use_dynamic_tree:
-            assert self.eagle_choices is None, "If use_dynamic_tree is True, eagle_choices should be None"
-            assert self.max_draft_len is not None and self.max_draft_len > 0, "max_draft_len should be provided, which indicates the number of drafter layers"
-            assert self.dynamic_tree_max_topK is not None and self.dynamic_tree_max_topK > 0, "dynamic_tree_max_topK should be provided, which indicates the number of nodes to expand each time"
-            assert self.max_total_draft_tokens is not None and self.max_total_draft_tokens > 0, "max_total_draft_tokens should be provided, which indicates the total nodes of the final draft tree. (exclude the root node)"
+            if self.eagle_choices is not None:
+                raise ValueError(
+                    "If use_dynamic_tree is True, eagle_choices should be None")
+            if self.max_draft_len is None or self.max_draft_len <= 0:
+                raise ValueError(
+                    "max_draft_len should be provided, which indicates the number of drafter layers"
+                )
+            if self.dynamic_tree_max_topK is None or self.dynamic_tree_max_topK <= 0:
+                raise ValueError(
+                    "dynamic_tree_max_topK should be provided, which indicates the number of nodes to expand each time"
+                )
+            if self.max_total_draft_tokens is None or self.max_total_draft_tokens <= 0:
+                raise ValueError(
+                    "max_total_draft_tokens should be provided, which indicates the total nodes of the final draft tree. (exclude the root node)"
+                )
 
         return self
 
     @classmethod
     def from_dict(cls, data: dict):
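Taken together, the dynamic-tree checks above only pass when `max_draft_len`, `dynamic_tree_max_topK`, and `max_total_draft_tokens` are all positive and `eagle_choices` is unset. An illustrative configuration that satisfies them (values are arbitrary; field names as defined in this class):

```python
from tensorrt_llm.llmapi import EagleDecodingConfig

cfg = EagleDecodingConfig(
    max_draft_len=3,            # number of drafter layers
    use_dynamic_tree=True,
    dynamic_tree_max_topK=4,    # nodes expanded at each step
    max_total_draft_tokens=12,  # total nodes of the final draft tree, excluding the root
    speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
)
```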
@@ -918,7 +942,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
     decoding_type: ClassVar[str] = "Eagle"
 
     def validate(self) -> None:
-        if self.speculative_model_dir is None:
+        if self.speculative_model is None:
             raise ValueError("Draft model must be provided for EAGLE")
 
     def check_eagle_choices(self):
@@ -2119,9 +2143,6 @@ class BaseLlmArgs(StrictBaseModel):
 
     _parallel_config: Optional[_ParallelConfig] = PrivateAttr(default=None)
     _model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
-    _speculative_model: Optional[str] = PrivateAttr(default=None)
-    _speculative_model_format: Optional[_ModelFormatKind] = PrivateAttr(
-        default=None)
 
     @property
     def parallel_config(self) -> _ParallelConfig:

@@ -2132,12 +2153,8 @@ class BaseLlmArgs(StrictBaseModel):
         return self._model_format
 
     @property
-    def speculative_model_dir(self) -> Optional[_ModelFormatKind]:
-        return self._speculative_model
-
-    @property
-    def speculative_model_format(self) -> _ModelFormatKind:
-        return self._speculative_model_format
+    def speculative_model(self) -> Optional[Union[str, Path]]:
+        return self.speculative_config.speculative_model if self.speculative_config is not None else None
 
     @classmethod
     def from_kwargs(cls, **kwargs: Any) -> "BaseLlmArgs":

@@ -2508,7 +2525,7 @@ class TrtLlmArgs(BaseLlmArgs):
 
         elif isinstance(self.speculative_config, EagleDecodingConfig):
             assert self.speculative_config.max_draft_len > 0
-            assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
+            assert self.speculative_config.speculative_model is not None, "EAGLE3 draft model must be specified."
             self.build_config.max_draft_len = self.speculative_config.max_draft_len
             self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
             eagle_config = _EagleConfig(

@@ -2528,14 +2545,6 @@ class TrtLlmArgs(BaseLlmArgs):
         else:
             self.decoding_config = None
 
-        self._speculative_model = getattr(self.speculative_config,
-                                          "speculative_model_dir", None)
-        speculative_model_obj = _ModelWrapper(
-            self._speculative_model
-        ) if self._speculative_model is not None else None
-        if self._speculative_model and speculative_model_obj.is_local_model:
-            self._speculative_model_format = _ModelFormatKind.HF
-
         return self
 
     def _load_config_from_engine(self, engine_dir: Path):

@@ -3025,12 +3034,12 @@ class TorchLlmArgs(BaseLlmArgs):
 
         if isinstance(self.speculative_config, EagleDecodingConfig):
             assert self.speculative_config.max_draft_len > 0
-            assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
+            assert self.speculative_config.speculative_model is not None, "EAGLE3 draft model must be specified."
         elif isinstance(self.speculative_config, NGramDecodingConfig):
             assert self.speculative_config.max_draft_len > 0 and self.speculative_config.max_matching_ngram_size > 0
         elif isinstance(self.speculative_config, DraftTargetDecodingConfig):
             assert self.speculative_config.max_draft_len > 0
-            assert self.speculative_config.speculative_model_dir is not None, "Path to draft model must be specified."
+            assert self.speculative_config.speculative_model is not None, "Draft model must be specified."
         elif isinstance(self.speculative_config, MTPDecodingConfig):
             assert self.speculative_config.num_nextn_predict_layers > 0
             self.speculative_config.max_draft_len = self.speculative_config.num_nextn_predict_layers

@@ -3057,14 +3066,6 @@ class TorchLlmArgs(BaseLlmArgs):
         else:
             self.decoding_config = None
 
-        self._speculative_model = getattr(self.speculative_config,
-                                          "speculative_model_dir", None)
-        speculative_model_obj = _ModelWrapper(
-            self._speculative_model
-        ) if self._speculative_model is not None else None
-        if self._speculative_model and speculative_model_obj.is_local_model:
-            self._speculative_model_format = _ModelFormatKind.HF
-
         return self
 
     @model_validator(mode="after")
@@ -109,8 +109,8 @@ class ModelLoader:
 
         self.model_obj = _ModelWrapper(self.llm_args.model)
         self.speculative_model_obj = _ModelWrapper(
-            self.llm_args.speculative_model_dir
-        ) if self.llm_args.speculative_model_dir is not None else None
+            self.llm_args.speculative_model
+        ) if self.llm_args.speculative_model is not None else None
 
         if isinstance(self.llm_args, TrtLlmArgs):
             self.convert_checkpoint_options = self.llm_args._convert_checkpoint_options

@@ -125,7 +125,7 @@ class ModelLoader:
             Path] = self.model_obj.model_dir if self.model_obj.is_local_model else None
 
         self._speculative_model_dir: Optional[
-            Path] = self.speculative_model_obj.model_dir if self.speculative_model_obj is not None and self.model_obj.is_local_model else None
+            Path] = self.speculative_model_obj.model_dir if self.speculative_model_obj is not None and self.speculative_model_obj.is_local_model else None
         self._model_info: Optional[_ModelInfo] = None
         self._model_format = self.llm_args.model_format
 
@@ -145,9 +145,7 @@ class ModelLoader:
             return
 
         if (self.model_obj.is_hub_model
-                and self._model_format is not _ModelFormatKind.TLLM_ENGINE) or (
-                    self.speculative_model_obj
-                    and self.speculative_model_obj.is_hub_model):
+                and self._model_format is not _ModelFormatKind.TLLM_ENGINE):
             # Download HF model if necessary
             if self.model_obj.model_name is None:
                 raise ValueError(

@@ -305,31 +303,18 @@ class ModelLoader:
     def _download_hf_model(self):
         ''' Download HF model from third-party model hub like www.modelscope.cn or huggingface. '''
         model_dir = None
-        speculative_model_dir = None
         # Only the rank0 are allowed to download model
         if mpi_rank() == 0:
             assert self._workspace is not None
             assert isinstance(self.model_obj.model_name, str)
             # this will download only once when multiple MPI processes are running
-
             model_dir = download_hf_model(self.model_obj.model_name,
                                           revision=self.llm_args.revision)
             print_colored(f"Downloaded model to {model_dir}\n", 'grey')
-            if self.speculative_model_obj:
-                speculative_model_dir = download_hf_model(
-                    self.speculative_model_obj.model_name)
-                print_colored(f"Downloaded model to {speculative_model_dir}\n",
-                              'grey')
         # Make all the processes got the same model_dir
         self._model_dir = mpi_broadcast(model_dir, root=0)
         self.model_obj.model_dir = self._model_dir  # mark as a local model
         assert self.model_obj.is_local_model
-        if self.speculative_model_obj:
-            self._speculative_model_dir = mpi_broadcast(speculative_model_dir,
-                                                        root=0)
-            self.speculative_model_obj.model_dir = self._speculative_model_dir
-
-            assert self.speculative_model_obj.is_local_model
 
     def _update_from_hf_quant_config(self) -> bool:
         """Update quant_config from the config file of pre-quantized HF checkpoint.

@@ -440,8 +425,8 @@ class ModelLoader:
         model_cls = AutoModelForCausalLM.get_trtllm_model_class(
             self._model_dir, self.llm_args.trust_remote_code,
             self.llm_args.decoding_config.decoding_mode
-            if hasattr(self.llm_args, "speculative_model_dir")
-            and self.llm_args.speculative_model_dir else None)
+            if hasattr(self.llm_args, "speculative_model")
+            and self.llm_args.speculative_model else None)
 
         prequantized = self._update_from_hf_quant_config()
@@ -638,18 +623,42 @@ class CachedModelLoader:
         else:
             return [task(*args, **kwargs)]
 
+    def _download_hf_model_if_needed(self,
+                                     model_obj: _ModelWrapper,
+                                     revision: Optional[str] = None) -> Path:
+        """Download a model from HF hub if needed.
+
+        Also updates the model_obj.model_dir with the local model dir on rank 0.
+        """
+        if model_obj.is_hub_model:
+            model_dirs = self._submit_to_all_workers(
+                CachedModelLoader._node_download_hf_model,
+                model=model_obj.model_name,
+                revision=revision)
+            model_dir = model_dirs[0]
+            model_obj.model_dir = model_dir
+            return model_dir
+        return model_obj.model_dir
+
     def __call__(self) -> Tuple[Path, Union[Path, None]]:
 
         if self.llm_args.model_format is _ModelFormatKind.TLLM_ENGINE:
             return Path(self.llm_args.model), None
 
+        # Download speculative model from HuggingFace if needed (all backends)
+        if (self.llm_args.speculative_config is not None and
+                self.llm_args.speculative_config.speculative_model is not None):
+            spec_model_obj = _ModelWrapper(
+                self.llm_args.speculative_config.speculative_model)
+            spec_model_dir = self._download_hf_model_if_needed(spec_model_obj)
+            self.llm_args.speculative_config.speculative_model = spec_model_dir
+
         # AutoDeploy doesn't use ModelLoader
         if self.llm_args.backend == "_autodeploy":
             return None, ""
 
         self.engine_cache_stage: Optional[CachedStage] = None
 
         self._hf_model_dir = None
 
         self.model_loader = ModelLoader(self.llm_args)
 
         if self.llm_args.backend is not None:

@@ -657,14 +666,8 @@ class CachedModelLoader:
             raise ValueError(
                 f'backend {self.llm_args.backend} is not supported.')
 
-        if self.model_loader.model_obj.is_hub_model:
-            hf_model_dirs = self._submit_to_all_workers(
-                CachedModelLoader._node_download_hf_model,
-                model=self.model_loader.model_obj.model_name,
-                revision=self.llm_args.revision)
-            self._hf_model_dir = hf_model_dirs[0]
-        else:
-            self._hf_model_dir = self.model_loader.model_obj.model_dir
+        self._hf_model_dir = self._download_hf_model_if_needed(
+            self.model_loader.model_obj, revision=self.llm_args.revision)
 
         if self.llm_args.quant_config.quant_algo is not None:
             logger.warning(

@@ -224,6 +224,7 @@ class DisabledTqdm(tqdm):
 
 def download_hf_model(model: str, revision: Optional[str] = None) -> Path:
     ignore_patterns = ["original/**/*"]
+    logger.info(f"Downloading model {model} from HuggingFace")
     with get_file_lock(model):
         hf_folder = snapshot_download(
             model,
@@ -231,6 +232,7 @@ def download_hf_model(model: str, revision: Optional[str] = None) -> Path:
             ignore_patterns=ignore_patterns,
             revision=revision,
             tqdm_class=DisabledTqdm)
+    logger.info(f"Finished downloading model {model} from HuggingFace")
     return Path(hf_folder)
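Hedged usage note: `download_hf_model` serializes concurrent downloads of the same repo through `get_file_lock` and returns a local `Path`, so callers can treat it as a drop-in resolver, e.g.:

```python
# Sketch only; assumes download_hf_model is imported from the module patched above.
local_dir = download_hf_model("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")
print(f"Draft model snapshot at: {local_dir}")
```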
@@ -576,7 +576,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         speculative_decoding_config = {
             "decoding_type": "Eagle",
             "max_draft_len": 4,
-            "speculative_model_dir":
+            "speculative_model":
             f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
             "eagle3_one_model": eagle3_one_model
         }

@@ -675,7 +675,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         speculative_decoding_config = {
             "decoding_type": "Eagle",
             "max_draft_len": 3,
-            "speculative_model_dir":
+            "speculative_model":
             f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
             "eagle3_one_model": eagle3_one_model
         }

@@ -471,7 +471,7 @@ class TestEagleVicuna_7B_v1_3(LlmapiAccuracyTestHarness):
 
         speculative_config = EagleDecodingConfig(
             max_draft_len=63,
-            speculative_model_dir=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
+            speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
             num_eagle_layers=4,
             max_non_leaves_per_layer=10,
             eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \

@@ -497,7 +497,7 @@ class TestEagle2Vicuna_7B_v1_3(LlmapiAccuracyTestHarness):
 
         speculative_config = EagleDecodingConfig(
             max_draft_len=63,
-            speculative_model_dir=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
+            speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
             num_eagle_layers=4,
             max_non_leaves_per_layer=10,
             use_dynamic_tree=True,
@@ -13,34 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-
 import pytest
+from test_common.llm_data import hf_model_dir_or_hub_id, llm_models_root
 
 from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
 from tensorrt_llm.quantization import QuantAlgo
 from tensorrt_llm.sampling_params import SamplingParams
 
-from ..conftest import llm_models_root
 from .accuracy_core import GSM8K, MMLU, CnnDailymail, LlmapiAccuracyTestHarness
 
 
-def _hf_model_dir_or_hub_id(
-    hf_model_subdir: str,
-    hf_hub_id: str,
-) -> str:
-    llm_models_path = llm_models_root()
-    if llm_models_path and os.path.isdir(
-            (model_fullpath := os.path.join(llm_models_path, hf_model_subdir))):
-        return str(model_fullpath)
-    else:
-        return hf_hub_id
-
-
 class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.1-8B"
-    MODEL_PATH = _hf_model_dir_or_hub_id("llama-3.1-model/Meta-Llama-3.1-8B",
-                                         MODEL_NAME)
+    MODEL_PATH = hf_model_dir_or_hub_id(MODEL_NAME)
 
     def get_default_kwargs(self, enable_chunked_prefill=False):
         config = {
@@ -276,7 +276,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
 
         draft_len = 4
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
+                                          speculative_model=eagle_model_dir,
                                           eagle3_one_model=eagle3_one_model)
 
         with LLM(model=target_model_dir,

@@ -369,8 +369,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         cuda_graph_config = CudaGraphConfig(enable_padding=True)
         spec_config = EagleDecodingConfig(
             max_draft_len=3,
-            speculative_model_dir=
-            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            speculative_model=f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
             eagle3_one_model=eagle3_one_model)
         llm = LLM(
             self.MODEL_PATH,

@@ -621,7 +620,7 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
         eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.3-Instruct-70B"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
         spec_config = EagleDecodingConfig(max_draft_len=3,
-                                          speculative_model_dir=eagle_model_dir,
+                                          speculative_model=eagle_model_dir,
                                           eagle3_one_model=eagle3_one_model)
         torch_compile_config = _get_default_torch_compile_config(torch_compile)
         pytorch_config = dict(

@@ -1383,7 +1382,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         )
         mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
                                        mtp_eagle_one_model=False,
-                                       speculative_model_dir=self.MODEL_PATH)
+                                       speculative_model=self.MODEL_PATH)
         with LLM(self.MODEL_PATH,
                  kv_cache_config=kv_cache_config,
                  enable_chunked_prefill=False,
@@ -2935,7 +2934,7 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
 
         mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
                                        mtp_eagle_one_model=False,
-                                       speculative_model_dir=model_path)
+                                       speculative_model=model_path)
 
         with LLM(model_path,
                  max_batch_size=max_batch_size,

@@ -3441,7 +3440,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
 
         draft_len = 4
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
+                                          speculative_model=eagle_model_dir,
                                           eagle3_one_model=eagle3_one_model)
 
         llm = LLM(model=target_model_dir,

@@ -3812,7 +3811,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
         if eagle3:
             spec_config = EagleDecodingConfig(
                 max_draft_len=2,
-                speculative_model_dir=
+                speculative_model=
                 f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/",
                 eagle3_one_model=True)
         with LLM(

@@ -3860,7 +3859,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
         if eagle3:
             spec_config = EagleDecodingConfig(
                 max_draft_len=2,
-                speculative_model_dir=
+                speculative_model=
                 f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/",
                 eagle3_one_model=True)
         with LLM(

@@ -4479,7 +4478,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
         draft_len = 3
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
+                                          speculative_model=eagle_model_dir,
                                           eagle3_one_model=one_model,
                                           allow_advanced_sampling=True)

@@ -4545,7 +4544,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
         draft_len = 3
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
+                                          speculative_model=eagle_model_dir,
                                           eagle3_one_model=one_model,
                                           allow_advanced_sampling=True)

@@ -4609,7 +4608,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
         draft_len = 3
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
+                                          speculative_model=eagle_model_dir,
                                           eagle3_one_model=one_model,
                                           allow_advanced_sampling=True)
@@ -4668,7 +4667,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
         draft_len = 3
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
+                                          speculative_model=eagle_model_dir,
                                           eagle3_one_model=one_model)
 
         max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN

@@ -5150,7 +5149,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
         if eagle3:
             spec_config = EagleDecodingConfig(
                 max_draft_len=2,
-                speculative_model_dir=
+                speculative_model=
                 f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",
                 eagle3_one_model=True,
                 eagle3_model_arch="mistral_large3")

@@ -5201,7 +5200,7 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
         if eagle3:
             spec_config = EagleDecodingConfig(
                 max_draft_len=2,
-                speculative_model_dir=
+                speculative_model=
                 f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",
                 eagle3_one_model=True,
                 eagle3_model_arch="mistral_large3")

@@ -400,7 +400,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
     # Test whether the batch slots are properly released when using speculative decoding
     # with disaggregated serving.
     spec_dec_config = EagleDecodingConfig(
-        speculative_model_dir=model_path(spec_dec_model_path),
+        speculative_model=model_path(spec_dec_model_path),
         eagle3_one_model=eagle3_one_model,
         max_draft_len=3)

@@ -93,7 +93,7 @@ def test_spec_decoding_metrics_eagle3_one_model():
         "speculative_config": {
             "decoding_type": "Eagle",
             "max_draft_len": 4,
-            "speculative_model_dir": eagle3_path,
+            "speculative_model": eagle3_path,
             "eagle3_one_model": True,
         },
     }

@@ -174,7 +174,7 @@ def test_spec_decoding_metrics_eagle3_two_model():
         "speculative_config": {
             "decoding_type": "Eagle",
             "max_draft_len": 4,
-            "speculative_model_dir": eagle3_path,
+            "speculative_model": eagle3_path,
             "eagle3_one_model": False,  # Two-model mode
         },
     }
@@ -52,14 +52,14 @@ def get_model_paths():
 
 def make_draft_target_config(spec_model_path: str):
     return DraftTargetDecodingConfig(
-        max_draft_len=DRAFT_TARGET_MAX_DRAFT_LEN, speculative_model_dir=spec_model_path
+        max_draft_len=DRAFT_TARGET_MAX_DRAFT_LEN, speculative_model=spec_model_path
     )
 
 
 def make_eagle3_config(spec_model_path: str):
     return EagleDecodingConfig(
         max_draft_len=EAGLE_MAX_DRAFT_LEN,
-        speculative_model_dir=spec_model_path,
+        speculative_model=spec_model_path,
         eagle3_one_model=False,
         eagle3_layers_to_capture=None,
     )

@@ -216,7 +216,7 @@ def test_autodeploy_eagle3_acceptance_rate():
     # Configure Eagle3 speculative decoding
     speculative_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=eagle_model,
+        speculative_model=eagle_model,
         eagle3_one_model=False,
         eagle3_layers_to_capture=None,
     )

@@ -223,7 +223,7 @@ def get_model_yaml_config(model_label: str,
         'speculative_config': {
             'decoding_type': 'Eagle',
             'eagle3_one_model': True,
-            'speculative_model_dir': 'Qwen3-4B_eagle3',
+            'speculative_model': 'Qwen3-4B_eagle3',
             'max_draft_len': 3,
         },
         'kv_cache_config': {
@@ -211,7 +211,7 @@ class ServerConfig:
         else:
             self.eagle3_layers_to_capture = []
         self.max_draft_len = speculative_config.get("max_draft_len", 0)
-        self.speculative_model_dir = speculative_config.get("speculative_model_dir", "")
+        self.speculative_model = speculative_config.get("speculative_model", "")
 
         # match_mode: "config" (default) or "scenario"
         self.match_mode = server_config_data.get("match_mode", "config")

@@ -338,7 +338,7 @@ class ServerConfig:
             "l_num_nextn_predict_layers": self.num_nextn_predict_layers,
             "s_eagle3_layers_to_capture": ",".join(map(str, self.eagle3_layers_to_capture)),
             "l_max_draft_len": self.max_draft_len,
-            "s_speculative_model_dir": self.speculative_model_dir,
+            "s_speculative_model_dir": self.speculative_model,
             "s_server_log_link": "",
             "s_server_env_var": self.env_vars,
         }

@@ -348,15 +348,15 @@ class ServerConfig:
         """Generate extra-llm-api-config.yml content."""
         config_data = dict(self.extra_llm_api_config_data)
 
-        # Handle speculative_model_dir path conversion
+        # Handle speculative_model path conversion
         if (
             "speculative_config" in config_data
-            and "speculative_model_dir" in config_data["speculative_config"]
+            and "speculative_model" in config_data["speculative_config"]
         ):
-            spec_model_dir = config_data["speculative_config"]["speculative_model_dir"]
-            if spec_model_dir:
-                config_data["speculative_config"]["speculative_model_dir"] = os.path.join(
-                    llm_models_root(), spec_model_dir
+            spec_model = config_data["speculative_config"]["speculative_model"]
+            if spec_model:
+                config_data["speculative_config"]["speculative_model"] = os.path.join(
+                    llm_models_root(), spec_model
                 )
 
         return yaml.dump(config_data, default_flow_style=False, sort_keys=False)
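Illustrative input/output for the path conversion above (assuming `llm_models_root()` returns `/models`; standard library only):

```python
import os

config_data = {"speculative_config": {"speculative_model": "gpt_oss/gpt-oss-120b-Eagle3"}}
root = "/models"  # stand-in for llm_models_root()

spec_model = config_data["speculative_config"]["speculative_model"]
if spec_model:
    config_data["speculative_config"]["speculative_model"] = os.path.join(root, spec_model)

assert config_data["speculative_config"]["speculative_model"] == "/models/gpt_oss/gpt-oss-120b-Eagle3"
```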
@@ -3414,7 +3414,7 @@ def test_eagle3_output_consistency_4gpus(model_dir: str, draft_model_dir: str):
     # Run with Eagle3
     spec_config = EagleDecodingConfig(
         max_draft_len=3,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         eagle3_one_model=True,
     )
     with LLM(**llm_common_config, speculative_config=spec_config) as llm_spec:

@@ -146,7 +146,7 @@ server_configs:
       decoding_type: 'Eagle'
       eagle3_layers_to_capture: [-1]
       max_draft_len: 3
-      speculative_model_dir: "gpt_oss/gpt-oss-120b-Eagle3"
+      speculative_model: "gpt_oss/gpt-oss-120b-Eagle3"
     stream_interval: 20
     num_postprocess_workers: 4
   client_configs:

@@ -218,7 +218,7 @@ SERVER_CONFIG_METRICS = {
 SPECULATIVE_CONFIG_METRICS = {
     "decoding_type": (True, str),
     "max_draft_len": (True, int),
-    "speculative_model_dir": (True, str),
+    "speculative_model": (True, str),
     "eagle3_one_model": (True, str_to_bool),
 }
@@ -259,7 +259,7 @@ class ServerConfig:
         enable_padding: bool = True,
         decoding_type: str = "",
         max_draft_len: int = 0,
-        speculative_model_dir: str = "",
+        speculative_model: str = "",
         eagle3_one_model: bool = False,
     ):
         self.name = name

@@ -285,7 +285,7 @@ class ServerConfig:
         self.enable_padding = enable_padding
         self.decoding_type = decoding_type
         self.max_draft_len = max_draft_len
-        self.speculative_model_dir = speculative_model_dir
+        self.speculative_model = speculative_model
         self.eagle3_one_model = eagle3_one_model
 
         model_dir = get_model_dir(self.model_name)

@@ -345,9 +345,9 @@ class ServerConfig:
             config_lines.append(f"  decoding_type: {self.decoding_type}")
         if self.max_draft_len > 0:
             config_lines.append(f"  max_draft_len: {self.max_draft_len}")
-        if self.speculative_model_dir:
+        if self.speculative_model:
             config_lines.append(
-                f"  speculative_model_dir: {self.speculative_model_dir}")
+                f"  speculative_model: {self.speculative_model}")
         if self.eagle3_one_model:
             config_lines.append(
                 f"  eagle3_one_model: {str(self.eagle3_one_model).lower()}")

@@ -500,8 +500,8 @@ def parse_config_file(config_file_path: str, select_pattern: str = None):
                                  {}).get('decoding_type', ''),
         max_draft_len=server_config_data.get('speculative_config',
                                              {}).get('max_draft_len', 0),
-        speculative_model_dir=server_config_data.get(
-            'speculative_config', {}).get('speculative_model_dir', ''),
+        speculative_model=server_config_data.get(
+            'speculative_config', {}).get('speculative_model', ''),
         eagle3_one_model=server_config_data.get(
             'speculative_config', {}).get('eagle3_one_model', False))
tests/test_common/llm_data.py (new file, 115 lines)
@@ -0,0 +1,115 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared utilities for local LLM model paths and HuggingFace download mocking."""
+
+import os
+from functools import wraps
+from pathlib import Path
+from typing import Optional
+from unittest.mock import patch
+
+# Mapping from HuggingFace Hub ID to local subdirectory under LLM_MODELS_ROOT.
+# NOTE: hf_id_to_local_model_dir below will fall back to checking if the model name exists
+# in LLM_MODELS_ROOT if not present here, so it's not required to exhaustively list all
+# models here.
+HF_ID_TO_LLM_MODELS_SUBDIR = {
+    "meta-llama/Meta-Llama-3.1-8B-Instruct": "llama-3.1-model/Llama-3.1-8B-Instruct",
+    "meta-llama/Llama-3.1-8B-Instruct": "llama-3.1-model/Llama-3.1-8B-Instruct",
+    "meta-llama/Llama-3.1-8B": "llama-3.1-model/Meta-Llama-3.1-8B",
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0": "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
+    "meta-llama/Llama-4-Scout-17B-16E-Instruct": "llama4-models/Llama-4-Scout-17B-16E-Instruct",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-Instruct-v0.1",
+    "mistralai/Mistral-Small-3.1-24B-Instruct-2503": "Mistral-Small-3.1-24B-Instruct-2503",
+    "Qwen/Qwen3-30B-A3B": "Qwen3/Qwen3-30B-A3B",
+    "Qwen/Qwen2.5-3B-Instruct": "Qwen2.5-3B-Instruct",
+    "microsoft/Phi-3-mini-4k-instruct": "Phi-3/Phi-3-mini-4k-instruct",
+    "deepseek-ai/DeepSeek-V3": "DeepSeek-V3",
+    "deepseek-ai/DeepSeek-R1": "DeepSeek-R1/DeepSeek-R1",
+    "ibm-ai-platform/Bamba-9B-v2": "Bamba-9B-v2",
+    "nvidia/NVIDIA-Nemotron-Nano-12B-v2": "NVIDIA-Nemotron-Nano-12B-v2",
+    "nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3": "NVIDIA-Nemotron-Nano-31B-A3-v3",
+    "nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": "Nemotron-Nano-3-30B-A3.5B-dev-1024",
+    "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": "EAGLE3-LLaMA3.1-Instruct-8B",
+}
+
+
+def llm_models_root(check: bool = False) -> Optional[Path]:
+    root = Path("/home/scratch.trt_llm_data/llm-models/")
+
+    if "LLM_MODELS_ROOT" in os.environ:
+        root = Path(os.environ.get("LLM_MODELS_ROOT"))
+
+    if not root.exists():
+        root = Path("/scratch.trt_llm_data/llm-models/")
+
+    if check:
+        assert root.exists(), (
+            "You must set LLM_MODELS_ROOT env or be able to access /home/scratch.trt_llm_data to run this test"
+        )
+
+    return root if root.exists() else None
+
+
+def llm_datasets_root() -> str:
+    return os.path.join(llm_models_root(check=True), "datasets")
+
+
+def hf_id_to_local_model_dir(hf_hub_id: str) -> str | None:
+    """Return the local model directory under LLM_MODELS_ROOT for a given HuggingFace Hub ID, or None if not found."""
+    root = llm_models_root()
+    if root is None:
+        return None
+
+    if hf_hub_id in HF_ID_TO_LLM_MODELS_SUBDIR:
+        return str(root / HF_ID_TO_LLM_MODELS_SUBDIR[hf_hub_id])
+
+    # Fall back to checking if the model name exists as a top-level directory in LLM_MODELS_ROOT
+    model_name = hf_hub_id.split("/")[-1]
+    if os.path.isdir(root / model_name):
+        return str(root / model_name)
+
+    return None
+
+
+def hf_model_dir_or_hub_id(hf_hub_id: str) -> str:
+    """Resolve a HuggingFace Hub ID to local path if available, otherwise return the Hub ID."""
+    return hf_id_to_local_model_dir(hf_hub_id) or hf_hub_id
+
+
+def mock_snapshot_download(repo_id: str, **kwargs) -> str:
+    """Mock huggingface_hub.snapshot_download that returns an existing local model directory.
+
+    NOTE: This function does not currently handle the revision / allow_patterns / ignore_patterns parameters.
+    """
+    local_path = hf_id_to_local_model_dir(repo_id)
+    if local_path is None:
+        raise ValueError(f"Model '{repo_id}' not found in LLM_MODELS_ROOT")
+    return local_path
+
+
+def with_mocked_hf_download(func):
+    """Decorator to mock huggingface_hub.snapshot_download for tests.
+
+    When applied, any calls to snapshot_download will be redirected to use
+    local model paths from LLM_MODELS_ROOT instead of downloading from HuggingFace.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        with patch("huggingface_hub.snapshot_download", side_effect=mock_snapshot_download):
+            return func(*args, **kwargs)
+
+    return wrapper
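A hedged usage sketch for the helpers added in this file (assumes `LLM_MODELS_ROOT` points at a populated model cache; otherwise `hf_model_dir_or_hub_id` falls through to the Hub ID):

```python
from test_common.llm_data import hf_model_dir_or_hub_id, with_mocked_hf_download

# Resolves to a local directory when the model is cached, else returns the Hub ID.
model = hf_model_dir_or_hub_id("meta-llama/Llama-3.1-8B-Instruct")


@with_mocked_hf_download
def test_without_network():
    # Inside this test, huggingface_hub.snapshot_download is redirected to the
    # local cache, so nothing is fetched over the network.
    ...
```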
@@ -1,12 +1,11 @@
 import copy
-import os
 from typing import Any, Dict, Optional
 
 import torch
 import torch.nn.functional as F
+from test_common.llm_data import hf_model_dir_or_hub_id
 from torch import nn
 from torch.export import Dim
-from utils.llm_data import llm_models_root
 
 
 def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:

@@ -285,17 +284,6 @@ def generate_dynamic_shapes(max_batch_size, max_seq_len):
     return dynamic_shapes
 
 
-def _hf_model_dir_or_hub_id(
-    hf_model_subdir: str,
-    hf_hub_id: str,
-) -> str:
-    llm_models_path = llm_models_root()
-    if llm_models_path and os.path.isdir((model_fullpath := llm_models_path / hf_model_subdir)):
-        return str(model_fullpath)
-    else:
-        return hf_hub_id
-
-
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
@@ -351,7 +339,6 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 
 _SMALL_MODEL_CONFIGS = {
     "meta-llama/Meta-Llama-3.1-8B-Instruct": {
-        "llm_models_subdir": "llama-3.1-model/Llama-3.1-8B-Instruct",
         "model_kwargs": {
             "num_hidden_layers": 1,
             "hidden_size": 64,
@@ -361,7 +348,6 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "mistralai/Mixtral-8x7B-Instruct-v0.1": {
-        "llm_models_subdir": "Mixtral-8x7B-Instruct-v0.1",
         "model_kwargs": {
             "num_hidden_layers": 2,
             "intermediate_size": 256,
@@ -372,7 +358,6 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "Qwen/Qwen3-30B-A3B": {
-        "llm_models_subdir": "Qwen3/Qwen3-30B-A3B",
         "model_kwargs": {
             "num_hidden_layers": 2,
             "intermediate_size": 256,
@@ -383,7 +368,6 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "microsoft/Phi-3-mini-4k-instruct": {
-        "llm_models_subdir": "Phi-3/Phi-3-mini-4k-instruct",
         "model_kwargs": {
             "num_hidden_layers": 2,
             "hidden_size": 128,
@@ -393,7 +377,6 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "meta-llama/Llama-4-Scout-17B-16E-Instruct": {
-        "llm_models_subdir": "llama4-models/Llama-4-Scout-17B-16E-Instruct",
         "model_factory": "AutoModelForImageTextToText",
         "model_kwargs": {
             "text_config": {
@@ -412,7 +395,6 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "deepseek-ai/DeepSeek-V3": {
-        "llm_models_subdir": "DeepSeek-V3",
         "model_kwargs": {
             "first_k_dense_replace": 1,
             "num_hidden_layers": 2,
@@ -431,7 +413,6 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "Qwen/Qwen2.5-3B-Instruct": {
-        "llm_models_subdir": "Qwen2.5-3B-Instruct",
         "model_kwargs": {
             "num_hidden_layers": 2,
             "hidden_size": 64,
@@ -441,7 +422,6 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
-        "llm_models_subdir": "Mistral-Small-3.1-24B-Instruct-2503",
         "model_factory": "AutoModelForImageTextToText",
         "model_kwargs": {
             "text_config": {
@@ -463,7 +443,6 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "ibm-ai-platform/Bamba-9B-v2": {
-        "llm_models_subdir": "Bamba-9B-v2",
         "model_kwargs": {
             "dtype": "bfloat16",
             "hidden_size": 64,
@@ -482,7 +461,6 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "nvidia/NVIDIA-Nemotron-Nano-12B-v2": {
-        "llm_models_subdir": "NVIDIA-Nemotron-Nano-12B-v2",
         "model_kwargs": {
             "dtype": "bfloat16",
             "hidden_size": 32,
@@ -497,13 +475,11 @@ _SMALL_MODEL_CONFIGS = {
         },
     },
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0": {
-        "llm_models_subdir": "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
         "model_kwargs": {
             "num_hidden_layers": 2,
         },
     },
     "nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": {
-        "llm_models_subdir": "Nemotron-Nano-3-30B-A3.5B-dev-1024",
         "model_kwargs": {
             "num_hidden_layers": 8,
         },
@@ -531,7 +507,7 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
     llm_args = copy.deepcopy(_SMALL_MODEL_CONFIGS[model_hub_id])
 
     # check if should use llm_models_root or hf_hub_id
-    llm_args["model"] = _hf_model_dir_or_hub_id(llm_args.pop("llm_models_subdir"), model_hub_id)
+    llm_args["model"] = hf_model_dir_or_hub_id(model_hub_id)
 
     # add some defaults to llm_args
     llm_args["skip_loading_weights"] = True  # No weight loading to speed up things
@@ -4,7 +4,7 @@ import types
 
 import pytest
 import torch
-from _model_test_utils import _hf_model_dir_or_hub_id
+from test_common.llm_data import hf_model_dir_or_hub_id
 from transformers import AutoConfig, AutoModelForCausalLM
 
 from tensorrt_llm._torch.auto_deploy.models.patches.deepseek import (

@@ -77,7 +77,7 @@ def _generate_ds_attention_mask(b, s):
     "model_name, module_name, patch, inputs",
     [
         pytest.param(
-            _hf_model_dir_or_hub_id("DeepSeek-R1/DeepSeek-R1", "deepseek-ai/DeepSeek-R1"),
+            hf_model_dir_or_hub_id("deepseek-ai/DeepSeek-R1"),
             "model.layers.0.self_attn",
             deepseek_v3_attention,
             [

@@ -87,7 +87,7 @@ def _generate_ds_attention_mask(b, s):
             ],
         ),  # attention requires inputs [hidden_states, attention_mask, position_ids]
         pytest.param(
-            _hf_model_dir_or_hub_id("DeepSeek-R1/DeepSeek-R1", "deepseek-ai/DeepSeek-R1"),
+            hf_model_dir_or_hub_id("deepseek-ai/DeepSeek-R1"),
             "model.layers.0.mlp",
             deepseek_v3_moe_exact,
             [torch.randn(2, 6, 8, dtype=torch.bfloat16)],
@@ -13,13 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import pytest
 from _model_test_utils import get_small_model_config
 from build_and_run_ad import ExperimentConfig, main
+from test_common.llm_data import with_mocked_hf_download

 from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig


-def test_ad_speculative_decoding_smoke():
+@pytest.mark.parametrize("use_hf_speculative_model", [False, True])
+@with_mocked_hf_download
+def test_ad_speculative_decoding_smoke(use_hf_speculative_model: bool):
     """Test speculative decoding with AutoDeploy using the build_and_run_ad main()."""

     # Use a simple test prompt
@@ -27,15 +31,15 @@ def test_ad_speculative_decoding_smoke():

     # Get base model config
     experiment_config = get_small_model_config("meta-llama/Meta-Llama-3.1-8B-Instruct")
-    speculative_model_dir = get_small_model_config("TinyLlama/TinyLlama-1.1B-Chat-v1.0")["args"][
-        "model"
-    ]
+    speculative_model_hf_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    if use_hf_speculative_model:
+        # NOTE: this will still mock out the actual HuggingFace download
+        speculative_model = speculative_model_hf_id
+    else:
+        speculative_model = get_small_model_config(speculative_model_hf_id)["args"]["model"]

-    print(f"Speculative model path: {speculative_model_dir}")
     # Configure speculative decoding with a draft model
-    spec_config = DraftTargetDecodingConfig(
-        max_draft_len=3, speculative_model_dir=speculative_model_dir
-    )
+    spec_config = DraftTargetDecodingConfig(max_draft_len=3, speculative_model=speculative_model)

     # Configure KV cache
     kv_cache_config = KvCacheConfig(
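The new `use_hf_speculative_model=True` case exercises the hub-id code path while `with_mocked_hf_download` keeps the test offline. A hypothetical sketch of how such a decorator can stub the download, assuming `huggingface_hub.snapshot_download` is the call being intercepted; the real helper in `test_common.llm_data` may work differently:

```python
# Hypothetical sketch in the spirit of with_mocked_hf_download; names and the
# patch target are assumptions for illustration only.
import functools
from unittest.mock import patch


def with_mocked_hf_download(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Redirect hub downloads to a known local checkout so the test
        # exercises the hub-id code path without touching the network.
        with patch("huggingface_hub.snapshot_download",
                   side_effect=lambda repo_id, **kw: f"/local/mirror/{repo_id}"):
            return fn(*args, **kwargs)

    return wrapper
```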
@@ -77,7 +77,7 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
     else:
         spec_config = DraftTargetDecodingConfig(
             max_draft_len=max_draft_len,
-            speculative_model_dir=str(draft_model),
+            speculative_model=str(draft_model),
             draft_len_schedule=schedule,
         )
@@ -123,7 +123,7 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
     else:
         spec_config_fixed = DraftTargetDecodingConfig(
             max_draft_len=max_draft_len,
-            speculative_model_dir=str(draft_model),
+            speculative_model=str(draft_model),
             draft_len_schedule=None,  # No schedule - fixed draft length
         )
         llm_fixed = LLM(**llm_common_config, speculative_config=spec_config_fixed)
@@ -186,9 +186,7 @@ def test_draft_len_schedule_functionality(drafter_type: str, draft_schedule: dict):
     else:
         spec_config = DraftTargetDecodingConfig(
             max_draft_len=5,
-            speculative_model_dir=str(
-                llm_models_root() / "llama-3.2-models" / "Llama-3.2-3B-Instruct"
-            ),
+            speculative_model=str(llm_models_root() / "llama-3.2-models" / "Llama-3.2-3B-Instruct"),
             draft_len_schedule=draft_schedule,
         )
         prompts = ["The capital of France is" for i in range(7)]
@@ -45,7 +45,7 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):

     spec_config = DraftTargetDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=draft_model_dir,
+        speculative_model=draft_model_dir,
     )

     prompts = [
@@ -87,7 +87,7 @@ def test_draft_token_static_tree_prepare_for_generation():
     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
         max_total_draft_tokens=max_total_draft_tokens,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         eagle3_one_model=False,
         eagle_choices=eagle_choices,
         use_dynamic_tree=use_dynamic_tree,
@@ -38,7 +38,7 @@ def test_draft_token_static_tree_sampling():
     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
         max_total_draft_tokens=max_total_draft_tokens,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         eagle3_one_model=False,
         eagle_choices=eagle_choices,
         use_dynamic_tree=use_dynamic_tree,
@@ -23,7 +23,7 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
         max_total_draft_tokens=max_total_draft_tokens,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         eagle3_one_model=False,
         eagle_choices=eagle_choices,
         use_dynamic_tree=use_dynamic_tree,
@@ -56,7 +56,7 @@ def test_dynamic_spec_decode(enforce_single_worker,

     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         # Llama 3 does not support one model eagle.
         eagle3_one_model=False,
     )
@@ -8,6 +8,7 @@ from unittest.mock import MagicMock

 import pytest
 import torch
+from test_common.llm_data import with_mocked_hf_download
 from utils.llm_data import llm_models_root

 from tensorrt_llm import LLM, SamplingParams
@@ -92,48 +93,83 @@ def test_kv_lens_runtime_with_eagle3_one_model():


 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,use_hf_speculative_model",
     [
-        [True, "TRTLLM", True, False, False, False, True, False, False],
-        [True, "TRTLLM", True, False, False, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True, False, False],
-        [False, "TRTLLM", True, False, False, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True, False, False],
-        [False, "FLASHINFER", True, False, False, False, True, False, False],
-        [False, "TRTLLM", False, True, True, False, True, False, False],
-        [True, "TRTLLM", False, True, True, False, True, False, False],
-        [True, "TRTLLM", True, False, True, True, True, False, False],
-        [True, "TRTLLM", True, False, True, False, True, False, False],
-        [True, "TRTLLM", True, False, False, True, True, False, False],
-        [True, "TRTLLM", False, False, False, False, True, False, False],
-        [False, "TRTLLM", False, False, False, False, True, False, False],
-        [True, "TRTLLM", False, False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, False, True, True],
-        [False, "TRTLLM", False, False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, True, True, False],
-        [False, "TRTLLM", False, False, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True, False, False],
-        [True, "TRTLLM", False, False, False, True, False, False, False],
-        [True, "FLASHINFER", False, False, False, False, True, False, False],
-        [False, "FLASHINFER", False, False, False, False, True, False, False],
+        [True, "TRTLLM", True, False, False, False, True, False, False, False],
+        [True, "TRTLLM", True, False, False, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False, False, False],
+        [
+            False, "TRTLLM", True, False, False, False, False, False, False,
+            False
+        ],
+        [
+            True, "FLASHINFER", True, False, False, False, True, False, False,
+            False
+        ],
+        [
+            False, "FLASHINFER", True, False, False, False, True, False, False,
+            False
+        ],
+        [False, "TRTLLM", False, True, True, False, True, False, False, False],
+        [True, "TRTLLM", False, True, True, False, True, False, False, False],
+        [True, "TRTLLM", True, False, True, True, True, False, False, False],
+        [True, "TRTLLM", True, False, True, False, True, False, False, False],
+        [True, "TRTLLM", True, False, False, True, True, False, False, False],
+        [True, "TRTLLM", False, False, False, False, True, False, False, False],
+        [
+            False, "TRTLLM", False, False, False, False, True, False, False,
+            False
+        ],
+        [True, "TRTLLM", False, False, False, False, False, True, False, False],
+        [True, "TRTLLM", False, False, False, False, False, True, True, False],
+        [
+            False, "TRTLLM", False, False, False, False, False, True, False,
+            False
+        ],
+        [True, "TRTLLM", False, False, False, False, True, True, False, False],
+        [False, "TRTLLM", False, False, False, False, True, True, False, False],
+        [
+            True, "TRTLLM", False, False, False, False, False, False, False,
+            False
+        ],
+        [
+            False, "TRTLLM", False, False, False, False, False, False, False,
+            False
+        ],
+        [True, "TRTLLM", False, False, False, True, True, False, False, False],
+        [True, "TRTLLM", False, False, False, True, False, False, False, False],
+        [
+            True, "FLASHINFER", False, False, False, False, True, False, False,
+            False
+        ],
+        [
+            False, "FLASHINFER", False, False, False, False, True, False, False,
+            False
+        ],
+        # Tests (mocked) speculative model auto-download from HuggingFace
+        [False, "TRTLLM", True, False, False, False, True, False, False, True],
     ])
 @pytest.mark.high_cuda_memory
+@with_mocked_hf_download
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
                       use_chain_drafter: bool, multi_batch: bool,
-                      attention_dp: bool, request):
+                      attention_dp: bool, use_hf_speculative_model: bool,
+                      request):
     # Eagle3 one model works with overlap scheduler and block reuse.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:
         pytest.skip("Not enough memory to load target + draft model")

     models_path = llm_models_root()
-    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
     target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

+    if use_hf_speculative_model:
+        eagle_model = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+    else:
+        eagle_model = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
+
     # bs > 1 gives non-deterministic when doing IFB. There are slight chances
     # that ref and spec does not match 100%
     max_batch_size = 4 if multi_batch else 1
@@ -165,7 +201,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,

     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model,
         # Llama 3 does not support one model eagle.
         eagle3_one_model=use_one_model,
     )
@@ -241,7 +277,7 @@ def test_eagle3_spec_decoding_stats(eagle3_one_model):
                                     free_gpu_memory_fraction=0.6)
     spec_config = EagleDecodingConfig(
         max_draft_len=3,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         eagle3_one_model=eagle3_one_model,
     )
@@ -321,7 +357,7 @@ def test_llama_eagle3_long_prompt(use_cuda_graph):

     spec_config = EagleDecodingConfig(
         max_draft_len=3,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         eagle3_one_model=False,
     )
@@ -445,7 +481,7 @@ def test_deepseek_eagle3():

     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         # Llama 3 does not support one model eagle.
         eagle3_one_model=use_one_model,
         eagle3_layers_to_capture={29},
@@ -555,7 +591,7 @@ def test_deepseek_mla_eagle3():
     )

     spec_config = EagleDecodingConfig(max_draft_len=max_draft_len,
-                                      speculative_model_dir=eagle_model_dir,
+                                      speculative_model=eagle_model_dir,
                                       eagle3_one_model=use_one_model,
                                       load_format="dummy")
@@ -654,7 +690,7 @@ def test_multi_eagle3(use_one_model: bool):

     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         # Llama 3 does not support one model eagle.
         eagle3_one_model=use_one_model,
         num_eagle_layers=2,
@@ -713,7 +749,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):

     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         eagle3_one_model=use_one_model,
     )
@@ -766,7 +802,7 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool):

     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         eagle3_one_model=use_one_model,
     )
@@ -52,7 +52,7 @@ def test_kv_cache_reuse(use_cuda_graph: bool, attn_backend: str):

     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         eagle3_one_model=False,
     )
@@ -47,7 +47,7 @@ def test_spec_gate_e2e():

     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model=eagle_model_dir,
         # Llama 3 does not support one model eagle.
         eagle3_one_model=False,
         max_concurrency=10000,
@@ -1218,7 +1218,7 @@ def test_llm_api_medusa():

     speculative_config = MedusaDecodingConfig(num_medusa_heads=4,
                 max_draft_len=63,
-                speculative_model_dir=get_model_path("medusa-vicuna-7b-v1.3"),
+                speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
                 medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
                 [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
                 [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
@@ -1257,7 +1257,7 @@ def test_llm_api_medusa_tp2():

     speculative_config = MedusaDecodingConfig(num_medusa_heads=4,
                 max_draft_len=63,
-                speculative_model_dir=get_model_path("medusa-vicuna-7b-v1.3"),
+                speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
                 medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
                 [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
                 [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
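For readers unfamiliar with the format: in the standard Medusa formulation (as in the original Medusa repository), each inner list in `medusa_choices` is one path through the candidate tree, where element *i* selects among the top-k candidates proposed by Medusa head *i*. A tiny illustrative tree (the tests above pass the full 63-node tree):

```python
# Each entry is a path in the draft tree: element i picks the (value+1)-th
# most likely token from Medusa head i along that path.
tiny_choices = [
    [0],     # head 0, top-1 candidate
    [0, 0],  # head 0 top-1, then head 1 top-1
    [1],     # head 0, second-best candidate
]
```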
@@ -1295,7 +1295,7 @@ def test_llm_api_eagle(**llm_kwargs):

     speculative_config = EagleDecodingConfig(
         max_draft_len=63,
-        speculative_model_dir=get_model_path("EAGLE-Vicuna-7B-v1.3"),
+        speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
         num_eagle_layers=4,
         max_non_leaves_per_layer=10,
         eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
@@ -1342,7 +1342,7 @@ def test_llm_api_eagle2(**llm_kwargs):

     speculative_config = EagleDecodingConfig(
         max_draft_len=63,
-        speculative_model_dir=get_model_path("EAGLE-Vicuna-7B-v1.3"),
+        speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
         num_eagle_layers=4,
         max_non_leaves_per_layer=10,
         use_dynamic_tree=True,
@@ -445,6 +445,19 @@ class TestTorchLlmArgs:
         args = TorchLlmArgs(model=llama_model_path)
         args.invalid_arg = 1

+    def test_speculative_model_alias(self):
+        """Test that speculative_model_dir is accepted as an alias for speculative_model."""
+
+        spec_config = EagleDecodingConfig(
+            max_draft_len=3,
+            speculative_model_dir="/path/to/model",
+            eagle3_one_model=False,
+        )
+
+        args = TorchLlmArgs(model=llama_model_path,
+                            speculative_config=spec_config)
+        assert args.speculative_model == "/path/to/model"
+

 class TestTrtLlmArgs:
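The alias test above only checks behavior. One plausible way to implement such an alias, sketched with Pydantic v2; this is an assumption for illustration, and TensorRT-LLM's actual config classes may wire the alias differently:

```python
# Sketch only: accepting a legacy field name via Pydantic v2 alias choices.
from pydantic import AliasChoices, BaseModel, Field


class SpecConfig(BaseModel):
    speculative_model: str = Field(
        validation_alias=AliasChoices("speculative_model",
                                      "speculative_model_dir"))


# Both spellings populate the same field:
assert SpecConfig(speculative_model="/a").speculative_model == "/a"
assert SpecConfig(speculative_model_dir="/b").speculative_model == "/b"
```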
@@ -1,23 +1,15 @@
 import os
-from pathlib import Path
-from typing import Optional
+import sys

+# Ensure tests/ directory is in path for test_common imports
+sys.path.insert(
+    0,
+    os.path.dirname(os.path.dirname(os.path.dirname(
+        os.path.abspath(__file__)))))

-def llm_models_root(check=False) -> Optional[Path]:
-    root = Path("/home/scratch.trt_llm_data/llm-models/")
+from test_common.llm_data import llm_datasets_root, llm_models_root

-    if "LLM_MODELS_ROOT" in os.environ:
-        root = Path(os.environ.get("LLM_MODELS_ROOT"))
-
-    if not root.exists():
-        root = Path("/scratch.trt_llm_data/llm-models/")
-
-    if check:
-        assert root.exists(), \
-            "You shall set LLM_MODELS_ROOT env or be able to access /home/scratch.trt_llm_data to run this test"
-
-    return root if root.exists() else None
-
-
-def llm_datasets_root() -> str:
-    return os.path.join(llm_models_root(check=True), "datasets")
+__all__ = [
+    "llm_datasets_root",
+    "llm_models_root",
+]
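Because the shim re-exports the helpers and lists them in `__all__`, existing call sites keep working unchanged. The signature below is taken from the removed implementation; the centralized version in `test_common.llm_data` is assumed to behave the same:

```python
# Existing imports are unaffected by the refactor: utils.llm_data now
# re-exports the helpers from test_common.llm_data.
from utils.llm_data import llm_models_root

models_path = llm_models_root()  # Optional[Path]; None if no model root is mounted
if models_path is not None:
    eagle_model_dir = models_path / "EAGLE3-LLaMA3.1-Instruct-8B"
```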