[None][chore] AutoDeploy: cleanup old inference optimizer configs (#8039)

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
h-guo18 2025-10-17 12:55:57 -07:00 committed by GitHub
parent bb7fdcebf4
commit 55fed1873c
46 changed files with 555 additions and 572 deletions

View File

@ -40,29 +40,31 @@ trtllm-bench \
#### Basic Performance Configuration (`autodeploy_config.yaml`)
```yaml
# Compilation backend
compile_backend: torch-opt
# Runtime engine
# runtime engine
runtime: trtllm
# Model loading
# model loading
skip_loading_weights: false
# Fraction of free memory to use for kv-caches
free_mem_ratio: 0.8
# CUDA Graph optimization
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
# Attention backend
attn_backend: flashinfer
# Sequence configuration
max_batch_size: 256
# transform options
transforms:
insert_cached_attention:
# attention backend
backend: flashinfer
resize_kv_cache:
# fraction of free memory to use for kv-caches
free_mem_ratio: 0.8
compile_model:
# compilation backend
backend: torch-opt
# CUDA Graph optimization
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
```
Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs
Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs.
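The same options can also be set programmatically through the AutoDeploy `LLM` API shown elsewhere in this change. A minimal sketch (the model card is a placeholder, and the exact set of accepted keyword arguments may differ from this illustration):

```python
from tensorrt_llm._torch.auto_deploy import LLM

# Sketch only: mirrors autodeploy_config.yaml above via the Python API.
llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model card
    skip_loading_weights=False,
    max_batch_size=256,
    transforms={
        "insert_cached_attention": {"backend": "flashinfer"},
        "resize_kv_cache": {"free_mem_ratio": 0.8},
        "compile_model": {
            "backend": "torch-opt",
            "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256],
        },
    },
)
```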
## Configuration Options Reference

View File

@ -63,15 +63,15 @@ args:
num_hidden_layers: 12
hidden_size: 1024
world_size: 4
compile_backend: torch-compile
attn_backend: triton
max_seq_len: 2048
max_batch_size: 16
transforms:
sharding:
strategy: auto
quantization:
enabled: false
detect_sharding:
support_partial_config: true
insert_cached_attention:
backend: triton
compile_model:
backend: torch-compile
prompt:
batch_size: 8
@ -79,13 +79,6 @@ prompt:
max_tokens: 150
temperature: 0.8
top_k: 50
benchmark:
enabled: true
num: 20
bs: 4
isl: 1024
osl: 256
```
Create an additional override file (e.g., `production.yaml`):
@ -94,11 +87,10 @@ Create an additional override file (e.g., `production.yaml`):
# production.yaml
args:
world_size: 8
compile_backend: torch-opt
max_batch_size: 32
benchmark:
enabled: false
transforms:
compile_model:
backend: torch-opt
```
Then use these configurations:
@ -107,18 +99,18 @@ Then use these configurations:
# Using single YAML config
python build_and_run_ad.py \
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
--yaml-configs my_config.yaml
--yaml-extra my_config.yaml
# Using multiple YAML configs (deep merged in order, later files have higher priority)
python build_and_run_ad.py \
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
--yaml-configs my_config.yaml production.yaml
--yaml-extra my_config.yaml production.yaml
# Targeting nested AutoDeployConfig with separate YAML
python build_and_run_ad.py \
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
--yaml-configs my_config.yaml \
--args.yaml-configs autodeploy_overrides.yaml
--yaml-extra my_config.yaml \
--args.yaml-extra autodeploy_overrides.yaml
```
## Configuration Precedence and Deep Merging
@ -126,7 +118,7 @@ python build_and_run_ad.py \
The configuration system follows a precedence order in which higher priority sources override lower priority ones:
1. **CLI Arguments** (highest priority) - Direct command line arguments
1. **YAML Configs** - Files specified via `--yaml-configs` and `--args.yaml-configs`
1. **YAML Configs** - Files specified via `--yaml-extra` and `--args.yaml-extra`
1. **Default Settings** (lowest priority) - Built-in defaults from the config classes
**Deep Merging**: Unlike simple overwriting, deep merging recursively combines nested dictionaries. For example:
@ -152,12 +144,12 @@ args:
**Nested Config Behavior**: When using nested configurations, outer YAML configuration files become initialization settings for inner objects, giving them higher precedence:
```bash
# The outer yaml-configs affects the entire ExperimentConfig
# The inner args.yaml-configs affects only the AutoDeployConfig
# The outer yaml-extra affects the entire ExperimentConfig
# The inner args.yaml-extra affects only the AutoDeployConfig
python build_and_run_ad.py \
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
--yaml-configs experiment_config.yaml \
--args.yaml-configs autodeploy_config.yaml \
--yaml-extra experiment_config.yaml \
--args.yaml-extra autodeploy_config.yaml \
--args.world-size=8 # CLI override beats both YAML configs
```
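To make the deep-merge behavior described above concrete, here is a minimal standalone sketch. The `deep_merge` helper is illustrative only and is not the project's actual merge code; the example dictionaries roughly mirror `my_config.yaml` and `production.yaml` from above.

```python
from copy import deepcopy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge `override` into `base`; values in `override` win on conflicts."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # recurse into nested dicts
        else:
            merged[key] = value  # non-dict values are simply replaced
    return merged


my_config = {
    "args": {
        "world_size": 4,
        "max_batch_size": 16,
        "transforms": {
            "insert_cached_attention": {"backend": "triton"},
            "compile_model": {"backend": "torch-compile"},
        },
    }
}
production = {
    "args": {
        "world_size": 8,
        "max_batch_size": 32,
        "transforms": {"compile_model": {"backend": "torch-opt"}},
    }
}

merged = deep_merge(my_config, production)
assert merged["args"]["world_size"] == 8  # overridden by production.yaml
assert merged["args"]["transforms"]["compile_model"]["backend"] == "torch-opt"
# the nested attention transform from my_config.yaml is preserved, not wiped out
assert merged["args"]["transforms"]["insert_cached_attention"]["backend"] == "triton"
```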

View File

@ -18,9 +18,7 @@ llm = LLM(
attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton)
skip_loading_weights=False,
model_factory="AutoModelForCausalLM", # choose appropriate model factory
mla_backend="MultiHeadLatentAttention", # for models that support MLA
free_mem_ratio=0.8, # fraction of available memory for cache
simple_shard_only=False, # tensor parallelism sharding strategy
max_seq_len=<MAX_SEQ_LEN>,
max_batch_size=<MAX_BATCH_SIZE>,
)

View File

@ -113,6 +113,7 @@ Optimize attention operations with different attention kernel implementations:
| `"attn_backend"` | Description |
|----------------------|-------------|
| `torch` | Custom fused multi-head attention (MHA) with KV Cache reference implementation in pure PyTorch (slow!) |
| `triton` | Custom fused multi-head attention (MHA) with KV Cache kernels for efficient attention processing. |
| `flashinfer` | Uses optimized attention kernels with KV Cache from the [`flashinfer`](https://github.com/flashinfer-ai/flashinfer.git) library. |
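With the new transforms layout, the backend is selected on the `insert_cached_attention` transform. A hedged sketch of switching to `triton` (placeholder model card and sequence length); note that for the `triton` and `torch` backends the config forces `attn_page_size` to equal `max_seq_len`:

```python
from tensorrt_llm._torch.auto_deploy import LLM

# Sketch: pick the triton cached-attention backend via the transforms config.
# For triton/torch backends, the config sets attn_page_size == max_seq_len.
llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder
    max_seq_len=2048,
    transforms={"insert_cached_attention": {"backend": "triton"}},
)
```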

View File

@ -40,29 +40,31 @@ trtllm-bench \
#### Basic Performance Configuration (`autodeploy_config.yaml`)
```yaml
# Compilation backend
compile_backend: torch-opt
# Runtime engine
# runtime engine
runtime: trtllm
# Model loading
# model loading
skip_loading_weights: false
# Fraction of free memory to use for kv-caches
free_mem_ratio: 0.8
# CUDA Graph optimization
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
# Attention backend
attn_backend: flashinfer
# Sequence configuration
max_batch_size: 256
# transform options
transforms:
insert_cached_attention:
# attention backend
backend: flashinfer
resize_kv_cache:
# fraction of free memory to use for kv-caches
free_mem_ratio: 0.8
compile_model:
# compilation backend
backend: torch-opt
# CUDA Graph optimization
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
```
Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs
Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs.
## Configuration Options Reference

View File

@ -63,15 +63,15 @@ args:
num_hidden_layers: 12
hidden_size: 1024
world_size: 4
compile_backend: torch-compile
attn_backend: triton
max_seq_len: 2048
max_batch_size: 16
transforms:
sharding:
strategy: auto
quantization:
enabled: false
detect_sharding:
support_partial_config: true
insert_cached_attention:
backend: triton
compile_model:
backend: torch-compile
prompt:
batch_size: 8
@ -79,13 +79,6 @@ prompt:
max_tokens: 150
temperature: 0.8
top_k: 50
benchmark:
enabled: true
num: 20
bs: 4
isl: 1024
osl: 256
```
Create an additional override file (e.g., `production.yaml`):
@ -94,11 +87,10 @@ Create an additional override file (e.g., `production.yaml`):
# production.yaml
args:
world_size: 8
compile_backend: torch-opt
max_batch_size: 32
benchmark:
enabled: false
transforms:
compile_model:
backend: torch-opt
```
Then use these configurations:
@ -107,18 +99,18 @@ Then use these configurations:
# Using single YAML config
python build_and_run_ad.py \
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
--yaml-configs my_config.yaml
--yaml-extra my_config.yaml
# Using multiple YAML configs (deep merged in order, later files have higher priority)
python build_and_run_ad.py \
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
--yaml-configs my_config.yaml production.yaml
--yaml-extra my_config.yaml production.yaml
# Targeting nested AutoDeployConfig with separate YAML
python build_and_run_ad.py \
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
--yaml-configs my_config.yaml \
--args.yaml-configs autodeploy_overrides.yaml
--yaml-extra my_config.yaml \
--args.yaml-extra autodeploy_overrides.yaml
```
## Configuration Precedence and Deep Merging
@ -126,7 +118,7 @@ python build_and_run_ad.py \
The configuration system follows a precedence order in which higher priority sources override lower priority ones:
1. **CLI Arguments** (highest priority) - Direct command line arguments
1. **YAML Configs** - Files specified via `--yaml-configs` and `--args.yaml-configs`
1. **YAML Configs** - Files specified via `--yaml-extra` and `--args.yaml-extra`
1. **Default Settings** (lowest priority) - Built-in defaults from the config classes
**Deep Merging**: Unlike simple overwriting, deep merging recursively combines nested dictionaries. For example:
@ -152,12 +144,12 @@ args:
**Nested Config Behavior**: When using nested configurations, outer YAML configuration files become initialization settings for inner objects, giving them higher precedence:
```bash
# The outer yaml-configs affects the entire ExperimentConfig
# The inner args.yaml-configs affects only the AutoDeployConfig
# The outer yaml-extra affects the entire ExperimentConfig
# The inner args.yaml-extra affects only the AutoDeployConfig
python build_and_run_ad.py \
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
--yaml-configs experiment_config.yaml \
--args.yaml-configs autodeploy_config.yaml \
--yaml-extra experiment_config.yaml \
--args.yaml-extra autodeploy_config.yaml \
--args.world-size=8 # CLI override beats both YAML configs
```

View File

@ -42,23 +42,31 @@ trtllm-serve \
Example `autodeploy_config.yaml`:
```yaml
# Compilation backend for AutoDeploy
compile_backend: torch-opt # options: torch-simple, torch-compile, torch-cudagraph, torch-opt
# runtime engine
runtime: trtllm
# Runtime engine
runtime: trtllm # options: trtllm, demollm
# model loading
skip_loading_weights: false
# Model loading
skip_loading_weights: false # set true for architecture-only perf runs
# Sequence configuration
max_batch_size: 256
# KV cache memory
free_mem_ratio: 0.8 # fraction of free GPU mem for KV cache
# multi-gpu execution
world_size: 1
# CUDA graph optimization
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]
# Attention backend
attn_backend: flashinfer # recommended for best performance
# transform options
transforms:
insert_cached_attention:
# attention backend
backend: flashinfer
resize_kv_cache:
# fraction of free memory to use for kv-caches
free_mem_ratio: 0.8
compile_model:
# compilation backend
backend: torch-opt
# CUDA Graph optimization
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
```
## Limitations and tips

View File

@ -12,15 +12,18 @@ from tensorrt_llm._torch.auto_deploy import LLM
llm = LLM(
model=<HF_MODEL_CARD_OR_DIR>,
world_size=<DESIRED_WORLD_SIZE>,
compile_backend="torch-compile",
model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration
attn_backend="flashinfer", # choose between "triton" and "flashinfer"
attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton)
skip_loading_weights=False,
model_factory="AutoModelForCausalLM", # choose appropriate model factory
mla_backend="MultiHeadLatentAttention", # for models that support MLA
free_mem_ratio=0.8, # fraction of available memory for cache
simple_shard_only=False, # tensor parallelism sharding strategy
model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration
transforms={
"insert_cached_attention": {"backend": "flashinfer"}, # or "triton"
"insert_cached_mla_attention": {"backend": "MultiHeadLatentAttention"},
"resize_kv_cache": {"free_mem_ratio": 0.8},
"compile_model": {"backend": "torch-compile"},
"detect_sharding": {"simple_shard_only": False},
},
attn_page_size=64, # page size for attention
skip_loading_weights=False,
max_seq_len=<MAX_SEQ_LEN>,
max_batch_size=<MAX_BATCH_SIZE>,
)

View File

@ -10,9 +10,9 @@
"--model=meta-llama/Meta-Llama-3.1-8B-Instruct",
"--args.world-size=2",
"--args.runtime=demollm",
"--args.compile-backend=torch-simple",
"--args.transforms.compile-model.backend=torch-simple",
"--args.attn-page-size=16",
"--args.attn-backend=flashinfer",
"--args.transforms.insert-cached-attention.backend=flashinfer",
"--args.model-factory=AutoModelForCausalLM",
"--benchmark.enabled=false",
"--prompt.batch-size=2",

View File

@ -128,9 +128,7 @@ llm = LLM(
attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton)
skip_loading_weights=False,
model_factory="AutoModelForCausalLM", # choose appropriate model factory
mla_backend="MultiHeadLatentAttention", # for models that support MLA
free_mem_ratio=0.8, # fraction of available memory for cache
simple_shard_only=False, # tensor parallelism sharding strategy
max_seq_len=<MAX_SEQ_LEN>,
max_batch_size=<MAX_BATCH_SIZE>,
)
@ -218,15 +216,15 @@ args:
num_hidden_layers: 12
hidden_size: 1024
world_size: 4
compile_backend: torch-compile
attn_backend: triton
max_seq_len: 2048
max_batch_size: 16
transforms:
sharding:
strategy: auto
quantization:
enabled: false
detect_sharding:
support_partial_config: true
insert_cached_attention:
backend: triton
compile_model:
backend: torch-compile
prompt:
batch_size: 8
@ -234,13 +232,6 @@ prompt:
max_tokens: 150
temperature: 0.8
top_k: 50
benchmark:
enabled: true
num: 20
bs: 4
isl: 1024
osl: 256
```
Create an additional override file (e.g., `production.yaml`):
@ -249,11 +240,10 @@ Create an additional override file (e.g., `production.yaml`):
# production.yaml
args:
world_size: 8
compile_backend: torch-opt
max_batch_size: 32
benchmark:
enabled: false
transforms:
compile_model:
backend: torch-opt
```
Then use these configurations:

View File

@ -280,7 +280,7 @@ def main(config: Optional[ExperimentConfig] = None):
# run a benchmark for the model with batch_size == config.benchmark_bs
if config.benchmark.enabled and config.args.runtime != "trtllm":
ad_logger.info("Running benchmark...")
keys_from_args = ["compile_backend", "attn_backend", "mla_backend"]
keys_from_args = []
fields_to_show = [f"benchmark={config.benchmark}"]
fields_to_show.extend([f"{k}={getattr(config.args, k)}" for k in keys_from_args])
results["benchmark_results"] = benchmark(

View File

@ -38,10 +38,12 @@ transforms:
stage: pattern_matcher
match_attention_layout:
stage: pattern_matcher
attn_layout: bsnd
match_rope_pattern:
stage: pattern_matcher
match_rope_layout:
stage: pattern_matcher
expected_layout: bsnd
############################################################################################
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
############################################################################################
@ -87,6 +89,7 @@ transforms:
load_weights:
stage: weight_load
run_per_gm: false
checkpoint_device: null
move_inputs_to_device:
stage: weight_load
run_per_gm: false
@ -122,30 +125,40 @@ transforms:
backend: flashinfer
requires_shape_prop: true
############################################################################################
# VISUALIZE GRAPH
############################################################################################
visualize_namespace:
stage: visualize
enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/8460
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
update_in_out_nodes:
stage: cache_init
insert_cached_attention:
stage: cache_init
backend: flashinfer
insert_cached_mla_attention:
stage: cache_init
attn_backend: MultiHeadLatentAttention
backend: MultiHeadLatentAttention
insert_cached_ssm_attention:
stage: cache_init
attn_backend: triton_ssm
backend: triton_ssm
insert_cached_causal_conv:
stage: cache_init
attn_backend: cuda_causal_conv
backend: cuda_causal_conv
initialize_cache:
stage: cache_init
run_per_gm: false
resize_kv_cache:
stage: cache_init
run_per_gm: false
free_mem_ratio: 0.0
############################################################################################
# COMPILE MODEL
############################################################################################
compile_model:
stage: compile
run_per_gm: false
cuda_graph_batch_sizes: null
backend: torch-compile

View File

@ -21,7 +21,7 @@ transforms:
run_per_gm: false
transformers_replace_cached_attn:
stage: cache_init
attn_backend: flashinfer
backend: flashinfer
run_per_gm: false
initialize_cache:
stage: cache_init
@ -29,6 +29,7 @@ transforms:
resize_kv_cache:
stage: cache_init
run_per_gm: false
free_mem_ratio: 0.0
############################################################################################
# COMPILE MODEL
############################################################################################

View File

@ -10,12 +10,7 @@ import torch.export as te
import torch.nn as nn
from torch import fx
from ..transformations._graph import (
canonicalize_graph,
lift_to_meta,
load_buffers_and_params,
tree_to,
)
from ..utils._graph import canonicalize_graph, lift_to_meta, load_buffers_and_params, tree_to
from ..utils.logger import ad_logger
from ..utils.node_utils import is_op
from .interface import apply_export_patches

View File

@ -38,6 +38,19 @@ def _check_for_default_value_only(
return value
_TRANSFORMS_SHORTCUT_LOOKUP = {
"attn_backend": ("insert_cached_attention.backend", "transformers_replace_cached_attn.backend"),
"free_mem_ratio": ("resize_kv_cache.free_mem_ratio",),
"compile_backend": ("compile_model.backend",),
"cuda_graph_batch_sizes": ("compile_model.cuda_graph_batch_sizes",),
}
def _shortcut_description(description: str, shortcut: str) -> str:
long_names_str = ", ".join([f"transforms.{k}" for k in _TRANSFORMS_SHORTCUT_LOOKUP[shortcut]])
return f"{description} Alias for: {long_names_str}."
class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"""An argument class stripped down to AutoDeploy-specific configurations.
@ -75,12 +88,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"If True, only the model architecture is loaded.",
)
checkpoint_device: Optional[str] = Field(
default=None,
description="Device on which to load the model checkpoint. "
"Defaults to the same device as the rest of the pipeline.",
)
tokenizer: Optional[PathLike] = Field(
description="The tokenizer",
default=None,
@ -141,53 +148,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
)
### INFERENCE OPTIMIZER CONFIG #################################################################
attn_backend: Literal["flashinfer", "triton", "torch"] = Field(
default="flashinfer", description="Attention backend to use."
)
mla_backend: Literal["MultiHeadLatentAttention"] = Field(
default="MultiHeadLatentAttention",
description="The Multi-Head Latent Attention backend to use.",
)
free_mem_ratio: float = Field(
default=0.0,
ge=0.0,
le=1.0,
description="The fraction of available memory to allocate for cache.",
)
simple_shard_only: bool = Field(
default=False,
description="If True, force simple sharding (all_gather) in tensor parallelism. "
"If False, auto-detect and use column+row (all_reduce) sharding when possible.",
)
use_sharding_from_factory: bool = Field(
default=False,
description="If True, use sharding from the model factory. If False, use sharding from the "
"AutoDeployConfig.",
)
sharding_dims: List[str] = Field(
default=["tp", "ep", "dp"],
description="The sharding methods to apply by the heuristic sharding stage.",
)
compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
Field(
default="torch-compile",
description="The backend to use for compiling the model.",
)
)
cuda_graph_batch_sizes: Optional[List[int]] = Field(
default=None, description="List of batch sizes to create CUDA graphs for."
)
visualize: bool = Field(default=False, description="Whether to visualize the model graph.")
### NEW INFERENCE OPTIMIZER CONFIG #############################################################
mode: Literal["graph", "transformers"] = Field(
default="graph",
description="The mode to use for the inference optimizer. Currently, we "
@ -195,12 +155,38 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"or transformers-only cached attention optimization.",
)
transforms: Dict[str, Any] = Field(
transforms: Dict[str, Dict[str, Any]] = Field(
default_factory=dict,
description="A dictionary of transform configurations. The key is the transform name and "
"the value is the transform configuration.",
)
### SHORTCUTS FOR COMMON INFERENCE OPTIMIZER CONFIGS ###########################################
attn_backend: str = Field(
default="flashinfer",
description=_shortcut_description("Attention backend to use.", "attn_backend"),
)
free_mem_ratio: float = Field(
default=0.0,
description=_shortcut_description(
"The fraction of available memory to allocate for cache.", "free_mem_ratio"
),
)
compile_backend: str = Field(
default="torch-compile",
description=_shortcut_description(
"The backend to use for compiling the model.", "compile_backend"
),
)
cuda_graph_batch_sizes: Optional[List[int]] = Field(
default=None,
description=_shortcut_description(
"List of batch sizes for CUDA graph creation. If not provided, a heuristic will"
" be used to determine the batch sizes.",
"cuda_graph_batch_sizes",
),
)
### SEQUENCE INTERFACE CONFIG ##################################################################
max_input_len: int = Field(default=1024, description="The maximum input length.")
max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
@ -219,8 +205,16 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
# TODO: discuss what to do with this once we fully transition to the new inference optimizer
def update_attn_page_size(self):
# NOTE force attn_page_size to equal max_seq_len for triton backend
# TODO: maybe don't do this and rely on slot_idx instead??
if self.attn_backend == "triton" or self.attn_backend == "torch":
if self.transforms.get("insert_cached_attention", {}).get("backend") in [
"triton",
"torch",
]:
self.attn_page_size = self.max_seq_len
# NOTE: (hg) For transformers mode. This is ugly.
if self.transforms.get("transformers_replace_cached_attn", {}).get("backend") in [
"triton",
"torch",
]:
self.attn_page_size = self.max_seq_len
return self
@ -235,6 +229,27 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
return value
@model_validator(mode="after")
def update_transforms_with_shortcuts(self) -> "AutoDeployConfig":
"""Synchronize the transforms config with the values from the defined shortcuts.
NOTE: shortcut values always take precedence over the values in the transforms config.
"""
for shortcut_key, transforms_keys in _TRANSFORMS_SHORTCUT_LOOKUP.items():
for transform_key in transforms_keys:
t_key, config_key = transform_key.split(".")
if t_key not in self.transforms:
continue
# first update the transforms config with the shortcut value
if shortcut_key in self.model_fields_set:
self.transforms[t_key][config_key] = getattr(self, shortcut_key)
# then update the shortcut field with the value from the transforms config to make
# sure both fields are in sync
setattr(self, shortcut_key, self.transforms[t_key][config_key])
return self
### UTILITY METHODS ############################################################################
def create_factory(self) -> ModelFactory:
"""Create a model factory from the arguments."""

View File

@ -25,7 +25,7 @@ from ...pyexecutor.scheduler import (
from ..custom_ops.attention_interface import SequenceInfo
from ..distributed import common as dist
from ..llm_args import AutoDeployConfig, LlmArgs
from ..transformations.transform import InferenceOptimizer
from ..transform.optimizer import InferenceOptimizer
from ..utils.logger import ad_logger
from .interface import CachedSequenceInterface, GetInferenceModel
@ -110,7 +110,7 @@ class ADEngine(ModelEngine):
# ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.
# construct inference optimizer
build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config)
build_and_optimize = InferenceOptimizer(factory=factory, config=ad_config.transforms)
# construct engine
return cls(build_and_optimize, seq_info, device, max_beam_width)

View File

@ -14,6 +14,11 @@ class CachedSequenceInterface:
def __init__(
self, sequence_info: SequenceInfo, device: Optional[DeviceLikeType] = None
) -> None:
# TODO (lucaslie): this is somewhat circular/confusing. Here `device` denotes the desired
# device and not the actual device, unlike, e.g., in SequenceInfo. We rely on the attribute
# here to read the desired device across the inference optimizer pipeline. We should ideally
# think about a better way to handle this,
# see https://github.com/NVIDIA/TensorRT-LLM/issues/8371
self.device = device or "cuda"
self.info = sequence_info
self._cache_initializers: Dict[str, GetCacheCallable] = {}

View File

@ -16,7 +16,7 @@ from torch.fx import GraphModule
from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..transformations._graph import (
from ..utils._graph import (
canonicalize_graph,
lift_to_meta,
named_graphmodules,
@ -48,6 +48,7 @@ class Stages(Enum):
WEIGHT_LOAD = "weight_load" # loading of the model weights
POST_LOAD_FUSION = "post_load_fusion" # post-loading fusion and perf optimizations of the graph
CACHE_INIT = "cache_init" # initialization of cached attention + (KV) cache initialization
VISUALIZE = "visualize" # visualization of the graph
COMPILE = "compile" # graph compilation stage using low-level compilers like torch.compile
def __lt__(self, other):
@ -63,7 +64,6 @@ class SharedConfig(BaseModel):
sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
local_rank: int = Field(default=0)
world_size: int = Field(default=1)
attn_backend: str = Field(default="flashinfer", description="The attention backend to use.")
class TransformConfig(BaseModel):

View File

@ -2,14 +2,13 @@
from inspect import Parameter, Signature
from itertools import product
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Any, Callable, Dict, List, Literal, Tuple, Type
import torch
import torch.nn.functional as F
from pydantic import Field
from torch.fx import GraphModule
from ...custom_ops.attention_interface import AttentionDescriptor
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
@ -696,9 +695,11 @@ def register_match_attn_layout(patterns: ADPatternMatcherPass):
class MatchAttentionLayoutConfig(TransformConfig):
"""Configuration for the insert cached attention transform."""
"""Configuration for the match attention layout transform."""
attention_op: Type[AttentionDescriptor] = Field(description="The attention descriptor to use.")
attn_layout: Literal["bsnd", "bnsd"] = Field(
description="Layout expected by the attention backend."
)
@TransformRegistry.register("match_attention_layout")
@ -721,13 +722,8 @@ class MatchAttentionLayout(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
attention_layout = self.config.attention_op.get_attention_layout()
if attention_layout not in ("bnsd", "bsnd"):
raise ValueError(f"Unsupported attention layout: {attention_layout}")
# If backend expects bnsd, nothing to do.
if attention_layout == "bnsd":
if self.config.attn_layout == "bnsd":
return gm, TransformInfo(
skipped=False, num_matches=0, is_clean=False, has_valid_shapes=False
)

View File

@ -76,7 +76,7 @@ class BuildAndLoadFactoryModel(BuildModel):
assert isinstance(factory, hf.AutoModelFactory), "Only HF models are supported."
# build and load the model
model = factory.build_and_load_model(self.config.device)
model = factory.build_and_load_model(cm.device)
# we set the standard example sequence WITHOUT extra_args to set them to None so that
# only the text portion of the model gets called.

View File

@ -24,8 +24,8 @@ class CompileModelConfig(TransformConfig):
num_batched_inputs: int = Field(
default=2, description="The number of batched inputs to use for CUDA graphs."
)
compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
Field(description="The backend to use for compiling the model.")
backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = Field(
description="The backend to use for compiling the model."
)
@ -48,7 +48,7 @@ class CompileModel(BaseTransform):
) -> Tuple[nn.Module, TransformInfo]:
cm.info.set_generate_only_batch()
compiler_cls = CompileBackendRegistry.get(self.config.compile_backend)
compiler_cls = CompileBackendRegistry.get(self.config.backend)
mod_compiled = compiler_cls(
mod,
args=(),

View File

@ -12,7 +12,7 @@ from ...custom_ops.attention_interface import AttentionDescriptor, AttentionRegi
from ...distributed.common import all_gather_object, get_world_size
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...transformations._graph import add_graph_input
from ...utils._graph import add_graph_input
from ...utils.node_utils import get_all_input_output_nodes, is_op
from ..interface import (
BaseTransform,
@ -64,16 +64,13 @@ class UpdateInOutNodes(BaseTransform):
class InsertCachedAttentionConfig(TransformConfig):
"""Configuration for the insert cached attention transform."""
attn_backend: Optional[str] = Field(default=None, description="The attention backend to use.")
backend: Optional[str] = Field(default=None, description="The attention backend to use.")
@TransformRegistry.register("insert_cached_attention")
class InsertCachedAttention(BaseTransform):
"""
A transform to insert cached attention into the graph module.
If attn_backend is not provided in transform config, will find from shared config.
"""
A transform to insert cached attention into the graph module."""
config: InsertCachedAttentionConfig
@ -83,7 +80,7 @@ class InsertCachedAttention(BaseTransform):
@property
def attn_descriptor(self) -> Type[AttentionDescriptor]:
return AttentionRegistry.get(self.config.attn_backend)
return AttentionRegistry.get(self.config.backend)
def _process_get_metadata(
self, gm: GraphModule, m_args: List[str], const_args: List[Constant]
@ -222,7 +219,7 @@ class ResizeKVCacheConfig(TransformConfig):
"""Configuration for the resize kv cache transform."""
free_mem_ratio: float = Field(
description="The fraction of available memory to occupy.", default=0.8
default=0.0, ge=0.0, le=1.0, description="The fraction of available memory to occupy."
)

View File

@ -7,7 +7,7 @@ from pydantic import Field
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...transformations._graph import move_to_device
from ...utils._graph import move_to_device
from ..interface import (
BaseTransform,
SharedConfig,
@ -20,9 +20,9 @@ from ..interface import (
class MoveDeviceConfig(TransformConfig):
"""Configuration for the moving inputs/arguments to the device transform."""
device: str = Field(default="meta", description="The device to load the weights on.")
adconfig_checkpoint_device: Optional[str] = Field(
default=None, description="Optional checkpoint device argument from adconfig."
checkpoint_device: Optional[str] = Field(
default=None,
description="Optional device to init checkpoint before move to shared_config.local_device.",
)
@ -45,9 +45,9 @@ class LoadWeightsToDevice(BaseTransform):
) -> Tuple[nn.Module, TransformInfo]:
factory.load_or_random_init(
mod,
device=self.config.adconfig_checkpoint_device or self.config.device,
device=self.config.checkpoint_device or cm.device,
)
move_to_device(mod, self.config.device)
move_to_device(mod, cm.device)
info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
@ -71,7 +71,9 @@ class LoadFactoryModelWeights(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[nn.Module, TransformInfo]:
cm.to(self.config.device)
# TODO (hg) This is weird but equivalent to previous code.
# We do not seem to need this transform.
cm.to(cm.device)
info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)

View File

@ -3,12 +3,24 @@
import json
from typing import Tuple
import model_explorer
import torch
import torch.export as te
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
from model_explorer.pytorch_exported_program_adater_impl import PytorchExportedProgramAdapterImpl
from torch import fx
from torch.fx import GraphModule
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
try:
import model_explorer
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
from model_explorer.pytorch_exported_program_adater_impl import (
PytorchExportedProgramAdapterImpl,
)
except ImportError:
model_explorer = None
GraphNode = KeyValue = MetadataItem = PytorchExportedProgramAdapterImpl = None
# Optionally, you can log a warning or handle this gracefully elsewhere
def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
@ -62,9 +74,6 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
raise ValueError(f"Unsupported output type: {type(out_vals)}")
PytorchExportedProgramAdapterImpl.print_tensor = print_tensor
PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
# TODO(yudong): make custom_ops configurable
CUSTOM_OPS = (
torch.ops.auto_deploy.torch_dist_all_reduce.default,
@ -76,13 +85,26 @@ CUSTOM_OPS = (
)
# TODO(yudong): make viz as non-block call.
def visualize_namespace(gm: fx.GraphModule, args: Tuple[torch.Tensor, ...], dynamic_shapes):
ep = te.export(gm, args=args, dynamic_shapes=dynamic_shapes)
graph = ep.graph
# Ensure the ops land up in the right module for better viz
for n in graph.nodes:
if n.target in CUSTOM_OPS:
n.meta["nn_module_stack"] = n.args[0].meta["nn_module_stack"]
@TransformRegistry.register("visualize_namespace")
class VisualizeNamespace(BaseTransform):
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
PytorchExportedProgramAdapterImpl.print_tensor = print_tensor
PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
model_explorer.visualize_pytorch("model-viz", ep)
# TODO(yudong): make viz as non-block call.
ep = te.export(gm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
graph = ep.graph
# Ensure the ops land up in the right module for better viz
for n in graph.nodes:
if n.target in CUSTOM_OPS:
n.meta["nn_module_stack"] = n.args[0].meta["nn_module_stack"]
model_explorer.visualize_pytorch("model-viz", ep)
return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)

View File

@ -1,7 +1,9 @@
"""High-level entrypoint to transform a model into an efficient inference model."""
import gc
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
@ -71,4 +73,6 @@ class InferenceOptimizer:
############################################################################################
# RETURN OPTIMIZED MODEL
############################################################################################
torch.cuda.empty_cache()
gc.collect()
return mod

View File

@ -1 +0,0 @@
"""V1 Graph Transformations Module --> will be deprecated and replaced by auto_deploy.transform."""

View File

@ -1,6 +0,0 @@
"""A library of transformation passes."""
try:
from .visualization import visualize_namespace
except ImportError:
pass

View File

@ -1,117 +0,0 @@
"""High-level entrypoint to transform a model into an efficient inference model."""
import gc
import torch
import torch.nn as nn
from ..custom_ops.attention_interface import AttentionRegistry
from ..llm_args import AutoDeployConfig
from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer
class InferenceOptimizer:
def __init__(self, factory: ModelFactory, ad_config: AutoDeployConfig):
self.factory = factory
self.ad_config = ad_config
def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
"""Transform a model into an optimized inference model.
Args:
model: The model to transform.
cp: The cache pool to use for caching.
args: Example inputs to the model.
dynamic_shapes: Dynamic shapes to use. Defaults to None.
poe_config: The config for positional encoding. Defaults to None.
quantization: The quantization method to use. Defaults to None.
Returns:
A nn.Module representing the optimized inference model.
"""
############################################################################################
# RUN MODULAR INFERENCE OPTIMIZER FOR ALREADY-MIGRATED TRANSFORMS
############################################################################################
# TODO (hg): default values that are not representable in YAML.
# move to the optimizer
if "match_attention_layout" in self.ad_config.transforms:
self.ad_config.transforms["match_attention_layout"]["attention_op"] = (
AttentionRegistry.get(self.ad_config.attn_backend)
)
if "match_rope_layout" in self.ad_config.transforms:
self.ad_config.transforms["match_rope_layout"]["expected_layout"] = (
AttentionRegistry.get(self.ad_config.attn_backend).get_attention_layout()
)
if "load_weights" in self.ad_config.transforms:
self.ad_config.transforms["load_weights"]["checkpoint_device"] = (
self.ad_config.checkpoint_device
)
self.ad_config.transforms["load_weights"]["device"] = cm.device
if "build_and_load_factory_model" in self.ad_config.transforms:
self.ad_config.transforms["build_and_load_factory_model"]["device"] = cm.device
if "move_inputs_to_device" in self.ad_config.transforms:
self.ad_config.transforms["move_inputs_to_device"]["checkpoint_device"] = (
self.ad_config.checkpoint_device
)
self.ad_config.transforms["move_inputs_to_device"]["device"] = cm.device
if "resize_kv_cache" in self.ad_config.transforms:
self.ad_config.transforms["resize_kv_cache"]["free_mem_ratio"] = (
self.ad_config.free_mem_ratio
)
if "insert_cached_attention" in self.ad_config.transforms:
self.ad_config.transforms["insert_cached_attention"]["attn_backend"] = (
self.ad_config.attn_backend
)
if "insert_cached_mla_attention" in self.ad_config.transforms:
self.ad_config.transforms["insert_cached_mla_attention"]["attn_backend"] = (
self.ad_config.mla_backend
)
if "transformers_replace_cached_attn" in self.ad_config.transforms:
self.ad_config.transforms["transformers_replace_cached_attn"]["attn_backend"] = (
self.ad_config.attn_backend
)
# TODO: (hg)Missing MLA here. Figure out how to add MLA since duplicate transforms are not allowed.
# Old code:
# detect attention op and replace with cache-aware op
# for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
# attn_descriptor = AttentionRegistry.get(a_backend)
# insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
if "compile_model" in self.ad_config.transforms:
self.ad_config.transforms["compile_model"]["cuda_graph_batch_sizes"] = (
self.ad_config.cuda_graph_batch_sizes
)
self.ad_config.transforms["compile_model"]["compile_backend"] = (
self.ad_config.compile_backend
)
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
# TODO: (hg) move this. let match_rope_layout and match_atten_layout use this shared config
new_optimizer.shared_config.attn_backend = self.ad_config.attn_backend
egm = new_optimizer(cm)
# NOTE: (hg)Disabled visualization since compiled gm is a CapturedGraph instead of GraphModule.
# We can add a new stage in the optimizer to visualize the intermediate gm.
# if self.ad_config.visualize:
# try:
# from .library import visualize_namespace
# visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
# ad_logger.warning(
# "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
# " the graph."
# )
# except ImportError:
# pass
torch.cuda.empty_cache()
gc.collect()
return egm

View File

@ -5,7 +5,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from omegaconf import DictConfig, OmegaConf
from pydantic import Field, field_validator, model_validator
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, YamlConfigSettingsSource
from pydantic_settings.sources.types import DEFAULT_PATH, PathType
@ -161,23 +161,6 @@ class DynamicYamlMixInForSettings:
'with higher priority. "Later" files have higher priority. Should be used with care!',
)
# TODO: remove this field in a future version
yaml_configs: List[PathType] = Field(
default_factory=list,
description="DEPRECATED: Please use yaml_extra instead.",
)
@field_validator("yaml_configs")
@classmethod
def validate_yaml_configs_deprecated(cls, v):
"""Throw error that yaml_configs is deprecated in favor of yaml_extra."""
if v: # Only raise error if the field is actually being used (not empty)
raise ValueError(
"The 'yaml_configs' field is deprecated and no longer supported. "
"Please use 'yaml_extra' instead."
)
return v
@model_validator(mode="after")
def validate_mode_and_yaml_default_not_both_provided(self):
"""Validate that both mode and yaml_default are not provided simultaneously.

View File

@ -15,8 +15,8 @@ from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._pytree import _LEAF_SPEC
from ..utils.logger import ad_logger
from ..utils.node_utils import is_op
from .logger import ad_logger
from .node_utils import is_op
_NoValType = type("_NoValType", (), {})
_NO_VAL = _NoValType()

View File

@ -131,7 +131,6 @@ class PerformanceOptions:
def get_autodeploy_perf_config(self) -> Dict:
AutoDeployPerfConfig = dict
ad_config = AutoDeployPerfConfig()
ad_config["attn_backend"] = "flashinfer"
return ad_config

View File

@ -52,9 +52,17 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
# Set it explicitly here to 8192 which is the default in build_config.
"max_num_tokens": 8192,
"skip_loading_weights": False,
"compile_backend": "torch-opt",
"free_mem_ratio": 0.7,
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
"transforms": {
"resize_kv_cache": {
"free_mem_ratio": 0.7
},
"compile_model": {
"backend":
"torch-opt",
"cuda_graph_batch_sizes":
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
},
},
}
def get_default_sampling_params(self):
@ -99,9 +107,15 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
# Set explicitly to match default build_config behavior
"max_num_tokens": 8192,
"skip_loading_weights": False,
"compile_backend": "torch-opt",
"free_mem_ratio": 0.7,
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
"transforms": {
"resize_kv_cache": {
"free_mem_ratio": 0.7
},
"compile_model": {
"backend": "torch-opt",
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
},
},
}
def get_default_sampling_params(self):

View File

@ -1411,8 +1411,14 @@ class MultiMetricPerfTest(AbstractPerfScriptTestClass):
# Create _autodeploy specific configuration
autodeploy_config = {
'compile_backend': self._config.ad_compile_backend,
'free_mem_ratio': self._config.free_mem_ratio,
'transforms': {
'compile_model': {
'backend': self._config.ad_compile_backend
},
'resize_kv_cache': {
'free_mem_ratio': self._config.free_mem_ratio
},
},
'runtime': self._config.extra_runtime,
'skip_loading_weights': self._config.skip_loading_weights
}

View File

@ -2,7 +2,6 @@ import copy
import os
from typing import Any, Dict, Optional
import pytest
import torch
import torch.nn.functional as F
from torch import nn
@ -444,7 +443,6 @@ _SMALL_MODEL_CONFIGS = {
"mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
"llm_models_subdir": "Mistral-Small-3.1-24B-Instruct-2503",
"model_factory": "AutoModelForImageTextToText",
"compile_backend": "torch-simple",
"model_kwargs": {
"text_config": {
"num_hidden_layers": 2,
@ -531,10 +529,8 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
# add some defaults to llm_args
llm_args["skip_loading_weights"] = True # No weight loading to speed up things
llm_args["free_mem_ratio"] = 0.00 # we don't need the cache and it may cause OOM issues
llm_args["attn_page_size"] = 4 # Make sure paging is activated despite small max_tokens
llm_args["max_batch_size"] = 2 # Minimum batching to speed up things
# update with custom llm_args kwargs
llm_args.update(llm_args_kwargs)
@ -549,13 +545,3 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
}
return experiment_config
def get_small_model_config_pytest_param(
model_hub_id: str, pytest_param_kwargs=None, **llm_args_kwargs
):
return pytest.param(
get_small_model_config(model_hub_id, **llm_args_kwargs),
id=model_hub_id,
**(pytest_param_kwargs or {}),
)

View File

@ -1,28 +1,40 @@
"""Testing build_and_run_ad end2end."""
from typing import Dict
import pytest
from _model_test_utils import get_small_model_config_pytest_param
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("mode", ["graph", "transformers"])
@pytest.mark.parametrize(
"experiment_config",
"model_hub_id, llm_extra_args",
[
get_small_model_config_pytest_param(
(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
attn_backend="flashinfer",
compile_backend="torch-opt",
{
"transforms": {
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-opt"},
},
},
),
(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
{
"transforms": {
"transformers_replace_cached_attn": {"backend": "flashinfer"},
},
"mode": "transformers",
},
),
],
)
def test_build_ad(world_size: int, experiment_config: Dict, mode: str):
def test_build_ad(world_size: int, model_hub_id: str, llm_extra_args: dict):
experiment_config = get_small_model_config(model_hub_id, **llm_extra_args)
experiment_config["args"]["world_size"] = world_size
experiment_config["args"]["runtime"] = "trtllm" # Default runtime set to trtllm
experiment_config["args"]["mode"] = mode
experiment_config = ExperimentConfig(**experiment_config)
print(f"Experiment Config: {experiment_config}")
main(experiment_config)

View File

@ -1,12 +1,12 @@
import torch # noqa
import torch.export as te
from torch.export import Dim # noqa
import pytest
import torch
import torch.export as te
from _model_test_utils import get_small_model_config
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device # noqa
from _model_test_utils import get_small_model_config
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
# NOTE: find example inputs with the same tokenization length to avoid seq concat.
EXAMPLE_INPUT = "Mamba is a snake with the following properties:"
@ -31,10 +31,12 @@ def test_bamba_patches(model_dir: str, run_verify_generation: bool):
common_kwargs = {
"world_size": 0,
"runtime": "demollm",
"compile_backend": "torch-simple",
"attn_backend": "flashinfer",
"model_factory": "AutoModelForCausalLM",
"max_seq_len": 512,
"transforms": {
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-simple"},
},
}
if use_small_config:

View File

@ -5,7 +5,7 @@ from PIL import Image
from tensorrt_llm._torch.auto_deploy import LlmArgs
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
def test_build_run_llama4_vlm():

View File

@ -4,7 +4,7 @@ from build_and_run_ad import ExperimentConfig
from tensorrt_llm._torch.auto_deploy import LlmArgs
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
def test_build_run_mistral3_vlm():

View File

@ -41,10 +41,8 @@ def get_inference_model(cache_seq_interface):
@pytest.mark.parametrize("engine_cls", [ADEngine, DemoEngine])
@pytest.mark.parametrize(
"attn_backend, attn_page_size", [("triton", 0), ("flashinfer", 2), ("torch", 0)]
)
def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: int):
@pytest.mark.parametrize("attn_page_size", [0, 2, 0])
def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
"""Test the SimpleEngine functionality."""
seed = 42 # Set random seed for model param init

View File

@ -4,6 +4,7 @@ import pydantic
import pytest
from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM, LlmArgs
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
def test_custom_values():
@ -12,13 +13,23 @@ def test_custom_values():
"model": "test-model",
"model_factory": "AutoModelForImageTextToText",
"model_kwargs": {"custom_param": True},
"mla_backend": "MultiHeadLatentAttention",
"skip_loading_weights": True,
"free_mem_ratio": 0.9,
"simple_shard_only": True,
"attn_page_size": 128,
"attn_backend": "flashinfer",
"max_seq_len": 2048,
"transforms": {
"detect_sharding": {
"stage": "sharding",
"simple_shard_only": True,
},
"insert_cached_attention": {
"stage": "cache_init",
"backend": "flashinfer",
},
"resize_kv_cache": {
"stage": "cache_init",
"free_mem_ratio": 0.9,
},
},
}
args = LlmArgs(**custom_kwargs)
@ -28,26 +39,30 @@ def test_custom_values():
"custom_param": True,
}
assert args.skip_loading_weights
assert args.free_mem_ratio == 0.9
assert args.simple_shard_only
assert args.transforms["resize_kv_cache"]["free_mem_ratio"] == 0.9
assert args.transforms["detect_sharding"]["simple_shard_only"]
assert args.attn_page_size == 128
assert args.max_seq_len == 2048
# attn_backend should be overridden if it was 'TRTLLM'
assert args.attn_backend == "flashinfer"
# backend should be overridden if it was 'TRTLLM'
assert args.transforms["insert_cached_attention"]["backend"] == "flashinfer"
def test_free_mem_ratio_validation():
"""Test free_mem_ratio validation."""
def get_transform_config(free_mem_ratio):
return {"resize_kv_cache": {"stage": "cache_init", "free_mem_ratio": free_mem_ratio}}
# Valid values
LlmArgs(model="test-model", free_mem_ratio=0.0)
LlmArgs(model="test-model", free_mem_ratio=1.0)
LlmArgs(model="test-model", free_mem_ratio=0.5)
InferenceOptimizer(None, get_transform_config(0.0))
InferenceOptimizer(None, get_transform_config(1.0))
InferenceOptimizer(None, get_transform_config(0.5))
# Invalid values
with pytest.raises(ValueError):
LlmArgs(model="test-model", free_mem_ratio=-0.1)
InferenceOptimizer(None, get_transform_config(-0.1))
with pytest.raises(ValueError):
LlmArgs(model="test-model", free_mem_ratio=1.1)
InferenceOptimizer(None, get_transform_config(1.1))
def test_get_pytorch_backend_config():
@ -67,14 +82,25 @@ def test_config_params():
return {
"model": "test-model",
"model_factory": "AutoModelForImageTextToText",
"free_mem_ratio": 0.7,
"simple_shard_only": True,
"skip_loading_weights": True,
"attn_page_size": 17,
"attn_backend": "flashinfer",
"max_seq_len": 19,
"max_batch_size": 5,
"world_size": 3,
"transforms": {
"detect_sharding": {
"stage": "sharding",
"simple_shard_only": True,
},
"insert_cached_attention": {
"stage": "cache_init",
"backend": "flashinfer",
},
"resize_kv_cache": {
"stage": "cache_init",
"free_mem_ratio": 0.7,
},
},
}
@ -131,8 +157,14 @@ def test_config_flow(
# Common assertions for both APIs
assert instance.args.model_factory == test_config_params["model_factory"]
assert instance.args.free_mem_ratio == test_config_params["free_mem_ratio"]
assert instance.args.simple_shard_only == test_config_params["simple_shard_only"]
assert (
instance.args.transforms["resize_kv_cache"]["free_mem_ratio"]
== test_config_params["transforms"]["resize_kv_cache"]["free_mem_ratio"]
)
assert (
instance.args.transforms["detect_sharding"]["simple_shard_only"]
== test_config_params["transforms"]["detect_sharding"]["simple_shard_only"]
)
assert instance.args.skip_loading_weights == test_config_params["skip_loading_weights"]
assert instance.args.attn_page_size == test_config_params["attn_page_size"]
assert instance.args.max_seq_len == test_config_params["max_seq_len"]
@ -190,13 +222,17 @@ def test_parallel_config_validation(parallel_field, invalid_value):
@pytest.mark.parametrize(
"attn_backend,expected_attn_page_size",
"backend,expected_attn_page_size",
[
("flashinfer", 64), # Default attn_page_size
("triton", 1024), # Should equal max_seq_len
],
)
def test_attention_backend_page_size_logic(attn_backend, expected_attn_page_size):
def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
"""Test attn_page_size logic for different attention backends."""
args = LlmArgs(model="test-model", attn_backend=attn_backend, max_seq_len=1024)
args = LlmArgs(
model="test-model",
max_seq_len=1024,
transforms={"insert_cached_attention": {"stage": "cache_init", "backend": backend}},
)
assert args.attn_page_size == expected_attn_page_size

View File

@ -1,13 +1,11 @@
"""Testing build_and_run_ad end2end."""
from typing import Dict
import pytest
from _model_test_utils import get_small_model_config_pytest_param
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig, LlmArgs, _ParallelConfig
from tensorrt_llm._torch.auto_deploy.transformations.transform import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
@ -41,87 +39,175 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
)
@pytest.mark.parametrize("mode", ["graph", "transformers"])
@pytest.mark.parametrize(
"experiment_config",
"model_hub_id, llm_extra_args",
[
get_small_model_config_pytest_param(
(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
attn_backend="flashinfer",
compile_backend="torch-opt",
free_mem_ratio=0.0001,
{
"transforms": {
"resize_kv_cache": {"free_mem_ratio": 0.0001},
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-opt"},
},
},
),
get_small_model_config_pytest_param(
(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
{
"transforms": {
"transformers_replace_cached_attn": {"backend": "flashinfer"},
},
"mode": "transformers",
},
),
(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
attn_backend="triton",
compile_backend="torch-simple",
{
"transforms": {
"insert_cached_attention": {"backend": "triton"},
"compile_model": {"backend": "torch-simple"},
},
},
),
get_small_model_config_pytest_param(
(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
{
"transforms": {
"transformers_replace_cached_attn": {"backend": "triton"},
},
"mode": "transformers",
},
),
(
"Qwen/Qwen3-30B-A3B",
attn_backend="triton",
compile_backend="torch-simple",
{
"transforms": {
"insert_cached_attention": {"backend": "triton"},
"compile_model": {"backend": "torch-simple"},
},
},
),
get_small_model_config_pytest_param(
(
"Qwen/Qwen3-30B-A3B",
{
"transforms": {
"transformers_replace_cached_attn": {"backend": "triton"},
},
"mode": "transformers",
},
),
(
"microsoft/Phi-3-mini-4k-instruct",
attn_backend="triton",
compile_backend="torch-simple",
{
"transforms": {
"insert_cached_attention": {"backend": "triton"},
"compile_model": {"backend": "torch-simple"},
},
},
),
get_small_model_config_pytest_param(
(
"microsoft/Phi-3-mini-4k-instruct",
attn_backend="torch",
compile_backend="torch-simple",
{
"transforms": {
"insert_cached_attention": {"backend": "torch"},
"compile_model": {"backend": "torch-simple"},
},
},
),
get_small_model_config_pytest_param(
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
attn_backend="flashinfer",
compile_backend="torch-opt",
{
"transforms": {
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-opt"},
},
},
),
get_small_model_config_pytest_param(
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
{
"transforms": {
"transformers_replace_cached_attn": {"backend": "flashinfer"},
},
"mode": "transformers",
},
),
(
"deepseek-ai/DeepSeek-V3",
attn_backend="triton",
compile_backend="torch-simple",
{
"transforms": {
"insert_cached_attention": {"backend": "triton"},
"compile_model": {"backend": "torch-simple"},
},
},
),
get_small_model_config_pytest_param(
(
"Qwen/Qwen2.5-3B-Instruct",
attn_backend="triton",
compile_backend="torch-compile",
{
"transforms": {
"insert_cached_attention": {"backend": "triton"},
"compile_model": {"backend": "torch-compile"},
},
},
),
get_small_model_config_pytest_param(
(
"Qwen/Qwen2.5-3B-Instruct",
{
"transforms": {
"transformers_replace_cached_attn": {"backend": "triton"},
},
"mode": "transformers",
},
),
(
"mistralai/Mistral-Small-3.1-24B-Instruct-2503",
attn_backend="flashinfer",
compile_backend="torch-cudagraph",
{
"transforms": {
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-cudagraph"},
},
},
),
get_small_model_config_pytest_param(
(
"mistralai/Mistral-Small-3.1-24B-Instruct-2503",
{
"transforms": {
"transformers_replace_cached_attn": {"backend": "flashinfer"},
},
"mode": "transformers",
},
),
(
"nvidia/NVIDIA-Nemotron-Nano-12B-v2",
attn_backend="flashinfer",
compile_backend="torch-simple",
{
"transforms": {
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-simple"},
},
},
),
],
)
def test_build_ad(experiment_config: Dict, mode: str):
if (
"DeepSeek-V3" in experiment_config["args"]["model"]
or "Phi-3-mini-4k-instruct" in experiment_config["args"]["model"]
or "NVIDIA-Nemotron-Nano-12B-v2" in experiment_config["args"]["model"]
and mode == "transformers"
):
pytest.skip(f"{experiment_config['args']['model']} is not supported in transformers mode")
def test_build_ad(model_hub_id: str, llm_extra_args: dict):
experiment_config = get_small_model_config(model_hub_id, **llm_extra_args)
experiment_config["args"]["runtime"] = "demollm" # Default runtime set to demollm
experiment_config["args"]["world_size"] = 0 # Default world_size set to 0
experiment_config["args"]["mode"] = mode
print(f"Experiment Config: {experiment_config}")
experiment_config = ExperimentConfig(**experiment_config)
original_init = InferenceOptimizer.__init__
original_build_from_config = ADEngine.build_from_config
def check_and_original_init(self, factory, ad_config):
@classmethod
def check_and_original_build(cls, ad_config):
_check_ad_config(experiment_config, ad_config)
return original_init(self, factory, ad_config=ad_config)
return original_build_from_config.__func__(cls, ad_config)
# Temporarily replace the __init__ method
InferenceOptimizer.__init__ = check_and_original_init
# Temporarily replace the build_from_config classmethod
ADEngine.build_from_config = check_and_original_build
try:
main(experiment_config)
finally:
# Restore original __init__
InferenceOptimizer.__init__ = original_init
# Restore original build_from_config
ADEngine.build_from_config = original_build_from_config
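
The interception in `test_build_ad` (temporarily swapping `ADEngine.build_from_config`, checking the resolved config, then restoring the original in `finally`) can be packaged as a small context manager. A minimal sketch, assuming `build_from_config` is a classmethod that receives the resolved `ad_config` as shown above:

```python
from contextlib import contextmanager

from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine


@contextmanager
def intercept_build_from_config(check_fn):
    """Run check_fn on the resolved ad_config before delegating to the real builder."""
    original = ADEngine.build_from_config

    @classmethod
    def wrapped(cls, ad_config):
        check_fn(ad_config)
        return original.__func__(cls, ad_config)

    ADEngine.build_from_config = wrapped
    try:
        yield
    finally:
        ADEngine.build_from_config = original
```

The test body above could then be reduced to `with intercept_build_from_config(lambda cfg: _check_ad_config(experiment_config, cfg)): main(experiment_config)`.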

View File

@ -75,8 +75,14 @@ def test_trtllm_bench(llm_root, compile_backend, model_name): # noqa: F811
with open(extra_llm_api_options_path, "w") as f:
yaml.dump(
{
"compile_backend": compile_backend,
**config["args"],
"transforms": {
"compile_model": {
"stage": "compile",
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
"backend": compile_backend,
}
},
},
f,
)
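
Put differently, the extra-options file consumed by `trtllm-bench` now nests the compile settings under `transforms.compile_model` rather than exposing a flat `compile_backend` key. A minimal sketch of the dumped structure (the base args and backend value are illustrative stand-ins for `config["args"]` and the parametrized `compile_backend`):

```python
import yaml

# Illustrative stand-in for config["args"] from the small-model test config.
base_args = {"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "max_batch_size": 32}

extra_llm_api_options = {
    **base_args,
    "transforms": {
        "compile_model": {
            "stage": "compile",
            "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
            "backend": "torch-opt",  # previously the flat `compile_backend` field
        }
    },
}

with open("extra_llm_api_options.yaml", "w") as f:
    yaml.dump(extra_llm_api_options, f)
```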

View File

@ -7,7 +7,6 @@ from torch.export import Dim
from torch.fx import GraphModule
from transformers.integrations.sdpa_attention import repeat_kv as hf_repeat_kv
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -1014,21 +1013,6 @@ class Llama3CausalAttentionModel(torch.nn.Module):
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}
class MockAttentionDescriptor(AttentionDescriptor):
"""A mock class that mimics the AttentionDescriptor interface for testing."""
layout: str = "bnsd"
source_attention_op: Callable = torch.ops.auto_deploy.torch_attention_sdpa
@classmethod
def get_attention_layout(cls) -> str:
return cls.layout
@classmethod
def get_source_attention_op(cls) -> Callable:
return cls.source_attention_op
class AttentionLayoutModel(torch.nn.Module):
"""Model that uses SDPA for testing the layout transformation."""
@ -1169,17 +1153,6 @@ def test_match_attention_layout(layout, model_config, has_mask):
hidden_size = 512
num_heads = 8
# Set up the mock attention descriptor class with the specified layout
MockAttentionDescriptor.layout = layout
if layout == "bnsd":
if model_config.get("use_grouped_sdpa"):
source_op = torch.ops.auto_deploy.torch_attention
else:
source_op = torch.ops.auto_deploy.torch_attention_sdpa
else:
source_op = torch.ops.auto_deploy.torch_attention
MockAttentionDescriptor.source_attention_op = source_op
# Create appropriate model based on model_config
if model_config["type"] == "standard":
model = AttentionLayoutModel(
@ -1329,7 +1302,7 @@ def test_match_attention_layout(layout, model_config, has_mask):
{
"match_attention_layout": {
"stage": "pattern_matcher",
"attention_op": MockAttentionDescriptor,
"attn_layout": layout,
},
},
)(None, gm),

View File

@ -1,7 +1,7 @@
"""Test that the attention matcher works with HF's SDPA backends."""
import copy
from typing import Any, Callable, Dict
from typing import Any, Dict
import pytest
import torch
@ -12,29 +12,14 @@ from torch.fx import GraphModule
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
torch.manual_seed(0)
class MockAttentionDescriptor(AttentionDescriptor):
"""A mock class that mimics the AttentionDescriptor interface for testing."""
layout: str = "bsnd"
@classmethod
def get_attention_layout(cls) -> str:
return cls.layout
@classmethod
def get_source_attention_op(cls) -> Callable:
return torch.ops.auto_deploy.torch_attention
class HFWrapper(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
@ -61,7 +46,7 @@ def _joint_transform(gm: GraphModule) -> None:
},
"match_attention_layout": {
"stage": "pattern_matcher",
"attention_op": MockAttentionDescriptor,
"attn_layout": "bsnd",
},
},
)(None, gm)
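
The surrounding hunks show the optimizer call only in fragments, so here is a consolidated sketch of the new matcher-test configuration: a plain `attn_layout` string replaces the old `attention_op=MockAttentionDescriptor` hook (assuming, as above, that the optimizer is built from a factory — `None` in these unit tests — and a per-transform config dict):

```python
from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer


def apply_layout_match(gm: GraphModule, layout: str = "bsnd") -> None:
    """Run only the attention-layout matching transform on an exported graph module."""
    InferenceOptimizer(
        None,  # no model factory needed at this unit-test level
        {
            "match_attention_layout": {
                "stage": "pattern_matcher",
                "attn_layout": layout,  # was: attention_op=MockAttentionDescriptor
            },
        },
    )(None, gm)
```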

View File

@ -192,7 +192,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
},
"insert_cached_attention": {
"stage": "cache_init",
"attn_backend": attn_backend,
"backend": attn_backend,
},
},
) # type: ignore

View File

@ -414,23 +414,6 @@ def test_partial_env_override(basic_yaml_files):
assert settings.option.option == "on" # from env
# Error handling tests
def test_deprecated_yaml_configs_field_error(basic_yaml_files):
"""Test that using deprecated yaml_configs field raises ValueError."""
with pytest.raises(
ValueError, match=r"The 'yaml_configs' field is deprecated.*Please use 'yaml_extra' instead"
):
BasicSettings(yaml_configs=[basic_yaml_files["config1"]])
def test_empty_yaml_configs_allowed():
"""Test that empty yaml_configs list doesn't raise error."""
# Empty yaml_configs should not raise error (but validation will still fail for missing fields)
with pytest.raises(ValidationError):
# Should fail validation for missing required fields, not for yaml_configs deprecation
BasicSettings(yaml_configs=[])
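
For completeness, the replacement spelled out in the error message is `yaml_extra`. A hedged sketch of the migrated call, assuming `BasicSettings` is the test suite's settings model and the listed paths are hypothetical stand-ins for the `basic_yaml_files` fixture:

```python
# Old spelling: rejected with the ValueError asserted above.
# BasicSettings(yaml_configs=["config1.yaml"])

# New spelling: extra YAML files are supplied via `yaml_extra`.
settings = BasicSettings(yaml_extra=["config1.yaml", "config2.yaml"])
```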
def test_missing_yaml_file(temp_dir):
"""Test handling of missing yaml file."""
missing_file = temp_dir / "missing.yaml"