mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][chore] AutoDeploy: cleanup old inference optimizer configs (#8039)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Co-authored-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
This commit is contained in:
parent
bb7fdcebf4
commit
55fed1873c
@ -40,29 +40,31 @@ trtllm-bench \
|
||||
#### Basic Performance Configuration (`autodeploy_config.yaml`)
|
||||
|
||||
```yaml
|
||||
# Compilation backend
|
||||
compile_backend: torch-opt
|
||||
|
||||
# Runtime engine
|
||||
# runtime engine
|
||||
runtime: trtllm
|
||||
|
||||
# Model loading
|
||||
# model loading
|
||||
skip_loading_weights: false
|
||||
|
||||
# Fraction of free memory to use for kv-caches
|
||||
free_mem_ratio: 0.8
|
||||
|
||||
# CUDA Graph optimization
|
||||
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
|
||||
|
||||
# Attention backend
|
||||
attn_backend: flashinfer
|
||||
|
||||
# Sequence configuration
|
||||
max_batch_size: 256
|
||||
|
||||
# transform options
|
||||
transforms:
|
||||
insert_cached_attention:
|
||||
# attention backend
|
||||
backend: flashinfer
|
||||
resize_kv_cache:
|
||||
# fraction of free memory to use for kv-caches
|
||||
free_mem_ratio: 0.8
|
||||
compile_model:
|
||||
# compilation backend
|
||||
backend: torch-opt
|
||||
# CUDA Graph optimization
|
||||
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
|
||||
```
|
||||
|
||||
Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs
|
||||
Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs.
|
||||
|
||||
## Configuration Options Reference
|
||||
|
||||
|
||||
@ -63,15 +63,15 @@ args:
|
||||
num_hidden_layers: 12
|
||||
hidden_size: 1024
|
||||
world_size: 4
|
||||
compile_backend: torch-compile
|
||||
attn_backend: triton
|
||||
max_seq_len: 2048
|
||||
max_batch_size: 16
|
||||
transforms:
|
||||
sharding:
|
||||
strategy: auto
|
||||
quantization:
|
||||
enabled: false
|
||||
detect_sharding:
|
||||
support_partial_config: true
|
||||
insert_cached_attention:
|
||||
backend: triton
|
||||
compile_model:
|
||||
backend: torch-compile
|
||||
|
||||
prompt:
|
||||
batch_size: 8
|
||||
@ -79,13 +79,6 @@ prompt:
|
||||
max_tokens: 150
|
||||
temperature: 0.8
|
||||
top_k: 50
|
||||
|
||||
benchmark:
|
||||
enabled: true
|
||||
num: 20
|
||||
bs: 4
|
||||
isl: 1024
|
||||
osl: 256
|
||||
```
|
||||
|
||||
Create an additional override file (e.g., `production.yaml`):
|
||||
@ -94,11 +87,10 @@ Create an additional override file (e.g., `production.yaml`):
|
||||
# production.yaml
|
||||
args:
|
||||
world_size: 8
|
||||
compile_backend: torch-opt
|
||||
max_batch_size: 32
|
||||
|
||||
benchmark:
|
||||
enabled: false
|
||||
transforms:
|
||||
compile_model:
|
||||
backend: torch-opt
|
||||
```
|
||||
|
||||
Then use these configurations:
|
||||
@ -107,18 +99,18 @@ Then use these configurations:
|
||||
# Using single YAML config
|
||||
python build_and_run_ad.py \
|
||||
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
|
||||
--yaml-configs my_config.yaml
|
||||
--yaml-extra my_config.yaml
|
||||
|
||||
# Using multiple YAML configs (deep merged in order, later files have higher priority)
|
||||
python build_and_run_ad.py \
|
||||
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
|
||||
--yaml-configs my_config.yaml production.yaml
|
||||
--yaml-extra my_config.yaml production.yaml
|
||||
|
||||
# Targeting nested AutoDeployConfig with separate YAML
|
||||
python build_and_run_ad.py \
|
||||
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
|
||||
--yaml-configs my_config.yaml \
|
||||
--args.yaml-configs autodeploy_overrides.yaml
|
||||
--yaml-extra my_config.yaml \
|
||||
--args.yaml-extra autodeploy_overrides.yaml
|
||||
```
|
||||
|
||||
## Configuration Precedence and Deep Merging
|
||||
@ -126,7 +118,7 @@ python build_and_run_ad.py \
|
||||
The configuration system follows a precedence order in which higher priority sources override lower priority ones:
|
||||
|
||||
1. **CLI Arguments** (highest priority) - Direct command line arguments
|
||||
1. **YAML Configs** - Files specified via `--yaml-configs` and `--args.yaml-configs`
|
||||
1. **YAML Configs** - Files specified via `--yaml-extra` and `--args.yaml-extra`
|
||||
1. **Default Settings** (lowest priority) - Built-in defaults from the config classes
|
||||
|
||||
**Deep Merging**: Unlike simple overwriting, deep merging recursively combines nested dictionaries. For example:
|
||||
@ -152,12 +144,12 @@ args:
|
||||
**Nested Config Behavior**: When using nested configurations, outer YAML configuration files become initialization settings for inner objects, giving them higher precedence:
|
||||
|
||||
```bash
|
||||
# The outer yaml-configs affects the entire ExperimentConfig
|
||||
# The inner args.yaml-configs affects only the AutoDeployConfig
|
||||
# The outer yaml-extra affects the entire ExperimentConfig
|
||||
# The inner args.yaml-extra affects only the AutoDeployConfig
|
||||
python build_and_run_ad.py \
|
||||
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
|
||||
--yaml-configs experiment_config.yaml \
|
||||
--args.yaml-configs autodeploy_config.yaml \
|
||||
--yaml-extra experiment_config.yaml \
|
||||
--args.yaml-extra autodeploy_config.yaml \
|
||||
--args.world-size=8 # CLI override beats both YAML configs
|
||||
```
|
||||
|
||||
|
||||
@ -18,9 +18,7 @@ llm = LLM(
|
||||
attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton)
|
||||
skip_loading_weights=False,
|
||||
model_factory="AutoModelForCausalLM", # choose appropriate model factory
|
||||
mla_backend="MultiHeadLatentAttention", # for models that support MLA
|
||||
free_mem_ratio=0.8, # fraction of available memory for cache
|
||||
simple_shard_only=False, # tensor parallelism sharding strategy
|
||||
max_seq_len=<MAX_SEQ_LEN>,
|
||||
max_batch_size=<MAX_BATCH_SIZE>,
|
||||
)
|
||||
|
||||
@ -113,6 +113,7 @@ Optimize attention operations with different attention kernel implementations:
|
||||
|
||||
| `"attn_backend"` | Description |
|
||||
|----------------------|-------------|
|
||||
| `torch` | Custom fused multi-head attention (MHA) with KV Cache reference implementation in pure PyTorch (slow!) |
|
||||
| `triton` | Custom fused multi-head attention (MHA) with KV Cache kernels for efficient attention processing. |
|
||||
| `flashinfer` | Uses optimized attention kernels with KV Cache from the [`flashinfer`](https://github.com/flashinfer-ai/flashinfer.git) library. |
|
||||
|
||||
|
||||
@ -40,29 +40,31 @@ trtllm-bench \
|
||||
#### Basic Performance Configuration (`autodeploy_config.yaml`)
|
||||
|
||||
```yaml
|
||||
# Compilation backend
|
||||
compile_backend: torch-opt
|
||||
|
||||
# Runtime engine
|
||||
# runtime engine
|
||||
runtime: trtllm
|
||||
|
||||
# Model loading
|
||||
# model loading
|
||||
skip_loading_weights: false
|
||||
|
||||
# Fraction of free memory to use for kv-caches
|
||||
free_mem_ratio: 0.8
|
||||
|
||||
# CUDA Graph optimization
|
||||
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
|
||||
|
||||
# Attention backend
|
||||
attn_backend: flashinfer
|
||||
|
||||
# Sequence configuration
|
||||
max_batch_size: 256
|
||||
|
||||
# transform options
|
||||
transforms:
|
||||
insert_cached_attention:
|
||||
# attention backend
|
||||
backend: flashinfer
|
||||
resize_kv_cache:
|
||||
# fraction of free memory to use for kv-caches
|
||||
free_mem_ratio: 0.8
|
||||
compile_model:
|
||||
# compilation backend
|
||||
backend: torch-opt
|
||||
# CUDA Graph optimization
|
||||
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
|
||||
```
|
||||
|
||||
Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs
|
||||
Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs.
|
||||
|
||||
## Configuration Options Reference
|
||||
|
||||
|
||||
@ -63,15 +63,15 @@ args:
|
||||
num_hidden_layers: 12
|
||||
hidden_size: 1024
|
||||
world_size: 4
|
||||
compile_backend: torch-compile
|
||||
attn_backend: triton
|
||||
max_seq_len: 2048
|
||||
max_batch_size: 16
|
||||
transforms:
|
||||
sharding:
|
||||
strategy: auto
|
||||
quantization:
|
||||
enabled: false
|
||||
detect_sharding:
|
||||
support_partial_config: true
|
||||
insert_cached_attention:
|
||||
backend: triton
|
||||
compile_model:
|
||||
backend: torch-compile
|
||||
|
||||
prompt:
|
||||
batch_size: 8
|
||||
@ -79,13 +79,6 @@ prompt:
|
||||
max_tokens: 150
|
||||
temperature: 0.8
|
||||
top_k: 50
|
||||
|
||||
benchmark:
|
||||
enabled: true
|
||||
num: 20
|
||||
bs: 4
|
||||
isl: 1024
|
||||
osl: 256
|
||||
```
|
||||
|
||||
Create an additional override file (e.g., `production.yaml`):
|
||||
@ -94,11 +87,10 @@ Create an additional override file (e.g., `production.yaml`):
|
||||
# production.yaml
|
||||
args:
|
||||
world_size: 8
|
||||
compile_backend: torch-opt
|
||||
max_batch_size: 32
|
||||
|
||||
benchmark:
|
||||
enabled: false
|
||||
transforms:
|
||||
compile_model:
|
||||
backend: torch-opt
|
||||
```
|
||||
|
||||
Then use these configurations:
|
||||
@ -107,18 +99,18 @@ Then use these configurations:
|
||||
# Using single YAML config
|
||||
python build_and_run_ad.py \
|
||||
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
|
||||
--yaml-configs my_config.yaml
|
||||
--yaml-extra my_config.yaml
|
||||
|
||||
# Using multiple YAML configs (deep merged in order, later files have higher priority)
|
||||
python build_and_run_ad.py \
|
||||
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
|
||||
--yaml-configs my_config.yaml production.yaml
|
||||
--yaml-extra my_config.yaml production.yaml
|
||||
|
||||
# Targeting nested AutoDeployConfig with separate YAML
|
||||
python build_and_run_ad.py \
|
||||
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
|
||||
--yaml-configs my_config.yaml \
|
||||
--args.yaml-configs autodeploy_overrides.yaml
|
||||
--yaml-extra my_config.yaml \
|
||||
--args.yaml-extra autodeploy_overrides.yaml
|
||||
```
|
||||
|
||||
## Configuration Precedence and Deep Merging
|
||||
@ -126,7 +118,7 @@ python build_and_run_ad.py \
|
||||
The configuration system follows a precedence order in which higher priority sources override lower priority ones:
|
||||
|
||||
1. **CLI Arguments** (highest priority) - Direct command line arguments
|
||||
1. **YAML Configs** - Files specified via `--yaml-configs` and `--args.yaml-configs`
|
||||
1. **YAML Configs** - Files specified via `--yaml-extra` and `--args.yaml-extra`
|
||||
1. **Default Settings** (lowest priority) - Built-in defaults from the config classes
|
||||
|
||||
**Deep Merging**: Unlike simple overwriting, deep merging recursively combines nested dictionaries. For example:
|
||||
@ -152,12 +144,12 @@ args:
|
||||
**Nested Config Behavior**: When using nested configurations, outer YAML configuration files become initialization settings for inner objects, giving them higher precedence:
|
||||
|
||||
```bash
|
||||
# The outer yaml-configs affects the entire ExperimentConfig
|
||||
# The inner args.yaml-configs affects only the AutoDeployConfig
|
||||
# The outer yaml-extra affects the entire ExperimentConfig
|
||||
# The inner args.yaml-extra affects only the AutoDeployConfig
|
||||
python build_and_run_ad.py \
|
||||
--model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
|
||||
--yaml-configs experiment_config.yaml \
|
||||
--args.yaml-configs autodeploy_config.yaml \
|
||||
--yaml-extra experiment_config.yaml \
|
||||
--args.yaml-extra autodeploy_config.yaml \
|
||||
--args.world-size=8 # CLI override beats both YAML configs
|
||||
```
|
||||
|
||||
|
||||
@ -42,23 +42,31 @@ trtllm-serve \
|
||||
Example `autodeploy_config.yaml`:
|
||||
|
||||
```yaml
|
||||
# Compilation backend for AutoDeploy
|
||||
compile_backend: torch-opt # options: torch-simple, torch-compile, torch-cudagraph, torch-opt
|
||||
# runtime engine
|
||||
runtime: trtllm
|
||||
|
||||
# Runtime engine
|
||||
runtime: trtllm # options: trtllm, demollm
|
||||
# model loading
|
||||
skip_loading_weights: false
|
||||
|
||||
# Model loading
|
||||
skip_loading_weights: false # set true for architecture-only perf runs
|
||||
# Sequence configuration
|
||||
max_batch_size: 256
|
||||
|
||||
# KV cache memory
|
||||
free_mem_ratio: 0.8 # fraction of free GPU mem for KV cache
|
||||
# multi-gpu execution
|
||||
world_size: 1
|
||||
|
||||
# CUDA graph optimization
|
||||
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]
|
||||
|
||||
# Attention backend
|
||||
attn_backend: flashinfer # recommended for best performance
|
||||
# transform options
|
||||
transforms:
|
||||
insert_cached_attention:
|
||||
# attention backend
|
||||
backend: flashinfer
|
||||
resize_kv_cache:
|
||||
# fraction of free memory to use for kv-caches
|
||||
free_mem_ratio: 0.8
|
||||
compile_model:
|
||||
# compilation backend
|
||||
backend: torch-opt
|
||||
# CUDA Graph optimization
|
||||
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
|
||||
```
|
||||
|
||||
## Limitations and tips
|
||||
|
||||
@ -12,15 +12,18 @@ from tensorrt_llm._torch.auto_deploy import LLM
|
||||
llm = LLM(
|
||||
model=<HF_MODEL_CARD_OR_DIR>,
|
||||
world_size=<DESIRED_WORLD_SIZE>,
|
||||
compile_backend="torch-compile",
|
||||
model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration
|
||||
attn_backend="flashinfer", # choose between "triton" and "flashinfer"
|
||||
attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton)
|
||||
skip_loading_weights=False,
|
||||
model_factory="AutoModelForCausalLM", # choose appropriate model factory
|
||||
mla_backend="MultiHeadLatentAttention", # for models that support MLA
|
||||
free_mem_ratio=0.8, # fraction of available memory for cache
|
||||
simple_shard_only=False, # tensor parallelism sharding strategy
|
||||
model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration
|
||||
transforms={
|
||||
"insert_cached_attention": {"backend": "flashinfer"}, # or "triton"
|
||||
"insert_cached_mla_attention": {"backend": "MultiHeadLatentAttention"},
|
||||
"resize_kv_cache": {"free_mem_ratio": 0.8},
|
||||
"compile_model": {"backend": "torch-compile"},
|
||||
"detect_sharding": {"simple_shard_only": False},
|
||||
|
||||
},
|
||||
attn_page_size=64, # page size for attention
|
||||
skip_loading_weights=False,
|
||||
max_seq_len=<MAX_SEQ_LEN>,
|
||||
max_batch_size=<MAX_BATCH_SIZE>,
|
||||
)
|
||||
|
||||
4
examples/auto_deploy/.vscode/launch.json
vendored
4
examples/auto_deploy/.vscode/launch.json
vendored
@ -10,9 +10,9 @@
|
||||
"--model=meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"--args.world-size=2",
|
||||
"--args.runtime=demollm",
|
||||
"--args.compile-backend=torch-simple",
|
||||
"--args.transforms.compile-model.backend=torch-simple",
|
||||
"--args.attn-page-size=16",
|
||||
"--args.attn-backend=flashinfer",
|
||||
"--args.transforms.insert-cached-attention.backend=flashinfer",
|
||||
"--args.model-factory=AutoModelForCausalLM",
|
||||
"--benchmark.enabled=false",
|
||||
"--prompt.batch-size=2",
|
||||
|
||||
@ -128,9 +128,7 @@ llm = LLM(
|
||||
attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton)
|
||||
skip_loading_weights=False,
|
||||
model_factory="AutoModelForCausalLM", # choose appropriate model factory
|
||||
mla_backend="MultiHeadLatentAttention", # for models that support MLA
|
||||
free_mem_ratio=0.8, # fraction of available memory for cache
|
||||
simple_shard_only=False, # tensor parallelism sharding strategy
|
||||
max_seq_len=<MAX_SEQ_LEN>,
|
||||
max_batch_size=<MAX_BATCH_SIZE>,
|
||||
)
|
||||
@ -218,15 +216,15 @@ args:
|
||||
num_hidden_layers: 12
|
||||
hidden_size: 1024
|
||||
world_size: 4
|
||||
compile_backend: torch-compile
|
||||
attn_backend: triton
|
||||
max_seq_len: 2048
|
||||
max_batch_size: 16
|
||||
transforms:
|
||||
sharding:
|
||||
strategy: auto
|
||||
quantization:
|
||||
enabled: false
|
||||
detect_sharding:
|
||||
support_partial_config: true
|
||||
insert_cached_attention:
|
||||
backend: triton
|
||||
compile_model:
|
||||
backend: torch-compile
|
||||
|
||||
prompt:
|
||||
batch_size: 8
|
||||
@ -234,13 +232,6 @@ prompt:
|
||||
max_tokens: 150
|
||||
temperature: 0.8
|
||||
top_k: 50
|
||||
|
||||
benchmark:
|
||||
enabled: true
|
||||
num: 20
|
||||
bs: 4
|
||||
isl: 1024
|
||||
osl: 256
|
||||
```
|
||||
|
||||
Create an additional override file (e.g., `production.yaml`):
|
||||
@ -249,11 +240,10 @@ Create an additional override file (e.g., `production.yaml`):
|
||||
# production.yaml
|
||||
args:
|
||||
world_size: 8
|
||||
compile_backend: torch-opt
|
||||
max_batch_size: 32
|
||||
|
||||
benchmark:
|
||||
enabled: false
|
||||
transforms:
|
||||
compile_model:
|
||||
backend: torch-opt
|
||||
```
|
||||
|
||||
Then use these configurations:
|
||||
|
||||
@ -280,7 +280,7 @@ def main(config: Optional[ExperimentConfig] = None):
|
||||
# run a benchmark for the model with batch_size == config.benchmark_bs
|
||||
if config.benchmark.enabled and config.args.runtime != "trtllm":
|
||||
ad_logger.info("Running benchmark...")
|
||||
keys_from_args = ["compile_backend", "attn_backend", "mla_backend"]
|
||||
keys_from_args = []
|
||||
fields_to_show = [f"benchmark={config.benchmark}"]
|
||||
fields_to_show.extend([f"{k}={getattr(config.args, k)}" for k in keys_from_args])
|
||||
results["benchmark_results"] = benchmark(
|
||||
|
||||
@ -38,10 +38,12 @@ transforms:
|
||||
stage: pattern_matcher
|
||||
match_attention_layout:
|
||||
stage: pattern_matcher
|
||||
attn_layout: bsnd
|
||||
match_rope_pattern:
|
||||
stage: pattern_matcher
|
||||
match_rope_layout:
|
||||
stage: pattern_matcher
|
||||
expected_layout: bsnd
|
||||
############################################################################################
|
||||
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
|
||||
############################################################################################
|
||||
@ -87,6 +89,7 @@ transforms:
|
||||
load_weights:
|
||||
stage: weight_load
|
||||
run_per_gm: false
|
||||
checkpoint_device: null
|
||||
move_inputs_to_device:
|
||||
stage: weight_load
|
||||
run_per_gm: false
|
||||
@ -122,30 +125,40 @@ transforms:
|
||||
backend: flashinfer
|
||||
requires_shape_prop: true
|
||||
############################################################################################
|
||||
# VISUALIZE GRAPH
|
||||
############################################################################################
|
||||
visualize_namespace:
|
||||
stage: visualize
|
||||
enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/8460
|
||||
############################################################################################
|
||||
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
|
||||
############################################################################################
|
||||
update_in_out_nodes:
|
||||
stage: cache_init
|
||||
insert_cached_attention:
|
||||
stage: cache_init
|
||||
backend: flashinfer
|
||||
insert_cached_mla_attention:
|
||||
stage: cache_init
|
||||
attn_backend: MultiHeadLatentAttention
|
||||
backend: MultiHeadLatentAttention
|
||||
insert_cached_ssm_attention:
|
||||
stage: cache_init
|
||||
attn_backend: triton_ssm
|
||||
backend: triton_ssm
|
||||
insert_cached_causal_conv:
|
||||
stage: cache_init
|
||||
attn_backend: cuda_causal_conv
|
||||
backend: cuda_causal_conv
|
||||
initialize_cache:
|
||||
stage: cache_init
|
||||
run_per_gm: false
|
||||
resize_kv_cache:
|
||||
stage: cache_init
|
||||
run_per_gm: false
|
||||
free_mem_ratio: 0.0
|
||||
############################################################################################
|
||||
# COMPILE MODEL
|
||||
############################################################################################
|
||||
compile_model:
|
||||
stage: compile
|
||||
run_per_gm: false
|
||||
cuda_graph_batch_sizes: null
|
||||
backend: torch-compile
|
||||
|
||||
@ -21,7 +21,7 @@ transforms:
|
||||
run_per_gm: false
|
||||
transformers_replace_cached_attn:
|
||||
stage: cache_init
|
||||
attn_backend: flashinfer
|
||||
backend: flashinfer
|
||||
run_per_gm: false
|
||||
initialize_cache:
|
||||
stage: cache_init
|
||||
@ -29,6 +29,7 @@ transforms:
|
||||
resize_kv_cache:
|
||||
stage: cache_init
|
||||
run_per_gm: false
|
||||
free_mem_ratio: 0.0
|
||||
############################################################################################
|
||||
# COMPILE MODEL
|
||||
############################################################################################
|
||||
|
||||
@ -10,12 +10,7 @@ import torch.export as te
|
||||
import torch.nn as nn
|
||||
from torch import fx
|
||||
|
||||
from ..transformations._graph import (
|
||||
canonicalize_graph,
|
||||
lift_to_meta,
|
||||
load_buffers_and_params,
|
||||
tree_to,
|
||||
)
|
||||
from ..utils._graph import canonicalize_graph, lift_to_meta, load_buffers_and_params, tree_to
|
||||
from ..utils.logger import ad_logger
|
||||
from ..utils.node_utils import is_op
|
||||
from .interface import apply_export_patches
|
||||
|
||||
@ -38,6 +38,19 @@ def _check_for_default_value_only(
|
||||
return value
|
||||
|
||||
|
||||
_TRANSFORMS_SHORTCUT_LOOKUP = {
|
||||
"attn_backend": ("insert_cached_attention.backend", "transformers_replace_cached_attn.backend"),
|
||||
"free_mem_ratio": ("resize_kv_cache.free_mem_ratio",),
|
||||
"compile_backend": ("compile_model.backend",),
|
||||
"cuda_graph_batch_sizes": ("compile_model.cuda_graph_batch_sizes",),
|
||||
}
|
||||
|
||||
|
||||
def _shortcut_description(description: str, shortcut: str) -> str:
|
||||
long_names_str = ", ".join([f"transforms.{k}" for k in _TRANSFORMS_SHORTCUT_LOOKUP[shortcut]])
|
||||
return f"{description} Alias for: {long_names_str}."
|
||||
|
||||
|
||||
class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
|
||||
"""An argument class stripped down to AutoDeploy-specific configurations.
|
||||
|
||||
@ -75,12 +88,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
|
||||
"If True, only the model architecture is loaded.",
|
||||
)
|
||||
|
||||
checkpoint_device: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Device on which to load the model checkpoint. "
|
||||
"Defaults to the same device as the rest of the pipeline.",
|
||||
)
|
||||
|
||||
tokenizer: Optional[PathLike] = Field(
|
||||
description="The tokenizer",
|
||||
default=None,
|
||||
@ -141,53 +148,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
|
||||
)
|
||||
|
||||
### INFERENCE OPTIMIZER CONFIG #################################################################
|
||||
attn_backend: Literal["flashinfer", "triton", "torch"] = Field(
|
||||
default="flashinfer", description="Attention backend to use."
|
||||
)
|
||||
|
||||
mla_backend: Literal["MultiHeadLatentAttention"] = Field(
|
||||
default="MultiHeadLatentAttention",
|
||||
description="The Multi-Head Latent Attention backend to use.",
|
||||
)
|
||||
|
||||
free_mem_ratio: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="The fraction of available memory to allocate for cache.",
|
||||
)
|
||||
|
||||
simple_shard_only: bool = Field(
|
||||
default=False,
|
||||
description="If True, force simple sharding (all_gather) in tensor parallelism. "
|
||||
"If False, auto-detect and use column+row (all_reduce) sharding when possible.",
|
||||
)
|
||||
|
||||
use_sharding_from_factory: bool = Field(
|
||||
default=False,
|
||||
description="If True, use sharding from the model factory. If False, use sharding from the "
|
||||
"AutoDeployConfig.",
|
||||
)
|
||||
|
||||
sharding_dims: List[str] = Field(
|
||||
default=["tp", "ep", "dp"],
|
||||
description="The sharding methods to apply by the heuristic sharding stage.",
|
||||
)
|
||||
|
||||
compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
|
||||
Field(
|
||||
default="torch-compile",
|
||||
description="The backend to use for compiling the model.",
|
||||
)
|
||||
)
|
||||
|
||||
cuda_graph_batch_sizes: Optional[List[int]] = Field(
|
||||
default=None, description="List of batch sizes to create CUDA graphs for."
|
||||
)
|
||||
|
||||
visualize: bool = Field(default=False, description="Whether to visualize the model graph.")
|
||||
|
||||
### NEW INFERENCE OPTIMIZER CONFIG #############################################################
|
||||
mode: Literal["graph", "transformers"] = Field(
|
||||
default="graph",
|
||||
description="The mode to use for the inference optimizer. Currently, we "
|
||||
@ -195,12 +155,38 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
|
||||
"or transformers-only cached attention optimization.",
|
||||
)
|
||||
|
||||
transforms: Dict[str, Any] = Field(
|
||||
transforms: Dict[str, Dict[str, Any]] = Field(
|
||||
default_factory=dict,
|
||||
description="A dictionary of transform configurations. The key is the transform name and "
|
||||
"the value is the transform configuration.",
|
||||
)
|
||||
|
||||
### SHORTCUTS FOR COMMON INFERENCE OPTIMIZER CONFIGS ###########################################
|
||||
attn_backend: str = Field(
|
||||
default="flashinfer",
|
||||
description=_shortcut_description("Attention backend to use.", "attn_backend"),
|
||||
)
|
||||
free_mem_ratio: float = Field(
|
||||
default=0.0,
|
||||
description=_shortcut_description(
|
||||
"The fraction of available memory to allocate for cache.", "free_mem_ratio"
|
||||
),
|
||||
)
|
||||
compile_backend: str = Field(
|
||||
default="torch-compile",
|
||||
description=_shortcut_description(
|
||||
"The backend to use for compiling the model.", "compile_backend"
|
||||
),
|
||||
)
|
||||
cuda_graph_batch_sizes: Optional[List[int]] = Field(
|
||||
default=None,
|
||||
description=_shortcut_description(
|
||||
"List of batch sizes for CUDA graph creation. If not provided, a heuristic will"
|
||||
" be used to determine the batch sizes.",
|
||||
"cuda_graph_batch_sizes",
|
||||
),
|
||||
)
|
||||
|
||||
### SEQUENCE INTERFACE CONFIG ##################################################################
|
||||
max_input_len: int = Field(default=1024, description="The maximum input length.")
|
||||
max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
|
||||
@ -219,8 +205,16 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
|
||||
# TODO: discuss what to do with this once we fully transition to the new inference optimizer
|
||||
def update_attn_page_size(self):
|
||||
# NOTE force attn_page_size to equal max_seq_len for triton backend
|
||||
# TODO: maybe don't do this and rely on slot_idx instead??
|
||||
if self.attn_backend == "triton" or self.attn_backend == "torch":
|
||||
if self.transforms.get("insert_cached_attention", {}).get("backend") in [
|
||||
"triton",
|
||||
"torch",
|
||||
]:
|
||||
self.attn_page_size = self.max_seq_len
|
||||
# NOTE: (hg) For transformers mode. This is ugly.
|
||||
if self.transforms.get("transformers_replace_cached_attn", {}).get("backend") in [
|
||||
"triton",
|
||||
"torch",
|
||||
]:
|
||||
self.attn_page_size = self.max_seq_len
|
||||
return self
|
||||
|
||||
@ -235,6 +229,27 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
|
||||
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def update_transforms_with_shortcuts(self) -> Dict[str, Any]:
|
||||
"""Synchronize the transforms config with the values from the defined shortcuts.
|
||||
|
||||
NOTE: shortcut values always take precedence over the values in the transforms config.
|
||||
"""
|
||||
for shortcut_key, transforms_keys in _TRANSFORMS_SHORTCUT_LOOKUP.items():
|
||||
for transform_key in transforms_keys:
|
||||
t_key, config_key = transform_key.split(".")
|
||||
if t_key not in self.transforms:
|
||||
continue
|
||||
|
||||
# first update the transforms config with the shortcut value
|
||||
if shortcut_key in self.model_fields_set:
|
||||
self.transforms[t_key][config_key] = getattr(self, shortcut_key)
|
||||
# then update the shortcut field with the value from the transforms config to make
|
||||
# sure both fields are in sync
|
||||
setattr(self, shortcut_key, self.transforms[t_key][config_key])
|
||||
|
||||
return self
|
||||
|
||||
### UTILITY METHODS ############################################################################
|
||||
def create_factory(self) -> ModelFactory:
|
||||
"""Create a model factory from the arguments."""
|
||||
|
||||
@ -25,7 +25,7 @@ from ...pyexecutor.scheduler import (
|
||||
from ..custom_ops.attention_interface import SequenceInfo
|
||||
from ..distributed import common as dist
|
||||
from ..llm_args import AutoDeployConfig, LlmArgs
|
||||
from ..transformations.transform import InferenceOptimizer
|
||||
from ..transform.optimizer import InferenceOptimizer
|
||||
from ..utils.logger import ad_logger
|
||||
from .interface import CachedSequenceInterface, GetInferenceModel
|
||||
|
||||
@ -110,7 +110,7 @@ class ADEngine(ModelEngine):
|
||||
# ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.
|
||||
|
||||
# construct inference optimizer
|
||||
build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config)
|
||||
build_and_optimize = InferenceOptimizer(factory=factory, config=ad_config.transforms)
|
||||
|
||||
# construct engine
|
||||
return cls(build_and_optimize, seq_info, device, max_beam_width)
|
||||
|
||||
@ -14,6 +14,11 @@ class CachedSequenceInterface:
|
||||
def __init__(
|
||||
self, sequence_info: SequenceInfo, device: Optional[DeviceLikeType] = None
|
||||
) -> None:
|
||||
# TODO (lucaslie): this is somewhat circular/confusing. Here `device` denotes the desired
|
||||
# device and not the actual device unlike, e.g., in SequenceInfo. We rely on the attribute
|
||||
# here to read the desired device across the inference optimizer pipeline. We should ideally
|
||||
# think about a better way to handle this,
|
||||
# see https://github.com/NVIDIA/TensorRT-LLM/issues/8371
|
||||
self.device = device or "cuda"
|
||||
self.info = sequence_info
|
||||
self._cache_initializers: Dict[str, GetCacheCallable] = {}
|
||||
|
||||
@ -16,7 +16,7 @@ from torch.fx import GraphModule
|
||||
|
||||
from ..models.factory import ModelFactory
|
||||
from ..shim.interface import CachedSequenceInterface
|
||||
from ..transformations._graph import (
|
||||
from ..utils._graph import (
|
||||
canonicalize_graph,
|
||||
lift_to_meta,
|
||||
named_graphmodules,
|
||||
@ -48,6 +48,7 @@ class Stages(Enum):
|
||||
WEIGHT_LOAD = "weight_load" # loading of the model weights
|
||||
POST_LOAD_FUSION = "post_load_fusion" # post-loading fusion and perf optimizations of the graph
|
||||
CACHE_INIT = "cache_init" # initialization of cached attention + (KV) cache initialization
|
||||
VISUALIZE = "visualize" # visualization of the graph
|
||||
COMPILE = "compile" # graph compilation stage using low-level compilers like torch.compile
|
||||
|
||||
def __lt__(self, other):
|
||||
@ -63,7 +64,6 @@ class SharedConfig(BaseModel):
|
||||
sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
|
||||
local_rank: int = Field(default=0)
|
||||
world_size: int = Field(default=1)
|
||||
attn_backend: str = Field(default="flashinfer", description="The attention backend to use.")
|
||||
|
||||
|
||||
class TransformConfig(BaseModel):
|
||||
|
||||
@ -2,14 +2,13 @@
|
||||
|
||||
from inspect import Parameter, Signature
|
||||
from itertools import product
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
from typing import Any, Callable, Dict, List, Literal, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...custom_ops.attention_interface import AttentionDescriptor
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.logger import ad_logger
|
||||
@ -696,9 +695,11 @@ def register_match_attn_layout(patterns: ADPatternMatcherPass):
|
||||
|
||||
|
||||
class MatchAttentionLayoutConfig(TransformConfig):
|
||||
"""Configuration for the insert cached attention transform."""
|
||||
"""Configuration for the match attention layout transform."""
|
||||
|
||||
attention_op: Type[AttentionDescriptor] = Field(description="The attention descriptor to use.")
|
||||
attn_layout: Literal["bsnd", "bnsd"] = Field(
|
||||
description="Layout expected by the attention backend."
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("match_attention_layout")
|
||||
@ -721,13 +722,8 @@ class MatchAttentionLayout(BaseTransform):
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
attention_layout = self.config.attention_op.get_attention_layout()
|
||||
|
||||
if attention_layout not in ("bnsd", "bsnd"):
|
||||
raise ValueError(f"Unsupported attention layout: {attention_layout}")
|
||||
|
||||
# If backend expects bnsd, nothing to do.
|
||||
if attention_layout == "bnsd":
|
||||
if self.config.attn_layout == "bnsd":
|
||||
return gm, TransformInfo(
|
||||
skipped=False, num_matches=0, is_clean=False, has_valid_shapes=False
|
||||
)
|
||||
|
||||
@ -76,7 +76,7 @@ class BuildAndLoadFactoryModel(BuildModel):
|
||||
assert isinstance(factory, hf.AutoModelFactory), "Only HF models are supported."
|
||||
|
||||
# build and load the model
|
||||
model = factory.build_and_load_model(self.config.device)
|
||||
model = factory.build_and_load_model(cm.device)
|
||||
|
||||
# we set the standard example sequence WITHOUT extra_args to set them to None so that
|
||||
# only the text portion of the model gets called.
|
||||
|
||||
@ -24,8 +24,8 @@ class CompileModelConfig(TransformConfig):
|
||||
num_batched_inputs: int = Field(
|
||||
default=2, description="The number of batched inputs to use for CUDA graphs."
|
||||
)
|
||||
compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
|
||||
Field(description="The backend to use for compiling the model.")
|
||||
backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = Field(
|
||||
description="The backend to use for compiling the model."
|
||||
)
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ class CompileModel(BaseTransform):
|
||||
) -> Tuple[nn.Module, TransformInfo]:
|
||||
cm.info.set_generate_only_batch()
|
||||
|
||||
compiler_cls = CompileBackendRegistry.get(self.config.compile_backend)
|
||||
compiler_cls = CompileBackendRegistry.get(self.config.backend)
|
||||
mod_compiled = compiler_cls(
|
||||
mod,
|
||||
args=(),
|
||||
|
||||
@ -12,7 +12,7 @@ from ...custom_ops.attention_interface import AttentionDescriptor, AttentionRegi
|
||||
from ...distributed.common import all_gather_object, get_world_size
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...transformations._graph import add_graph_input
|
||||
from ...utils._graph import add_graph_input
|
||||
from ...utils.node_utils import get_all_input_output_nodes, is_op
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
@ -64,16 +64,13 @@ class UpdateInOutNodes(BaseTransform):
|
||||
class InsertCachedAttentionConfig(TransformConfig):
|
||||
"""Configuration for the insert cached attention transform."""
|
||||
|
||||
attn_backend: Optional[str] = Field(default=None, description="The attention backend to use.")
|
||||
backend: Optional[str] = Field(default=None, description="The attention backend to use.")
|
||||
|
||||
|
||||
@TransformRegistry.register("insert_cached_attention")
|
||||
class InsertCachedAttention(BaseTransform):
|
||||
"""
|
||||
A transform to insert cached attention into the graph module.
|
||||
|
||||
If attn_backend is not provided in transform config, will find from shared config.
|
||||
"""
|
||||
A transform to insert cached attention into the graph module."""
|
||||
|
||||
config: InsertCachedAttentionConfig
|
||||
|
||||
@ -83,7 +80,7 @@ class InsertCachedAttention(BaseTransform):
|
||||
|
||||
@property
|
||||
def attn_descriptor(self) -> Type[AttentionDescriptor]:
|
||||
return AttentionRegistry.get(self.config.attn_backend)
|
||||
return AttentionRegistry.get(self.config.backend)
|
||||
|
||||
def _process_get_metadata(
|
||||
self, gm: GraphModule, m_args: List[str], const_args: List[Constant]
|
||||
@ -222,7 +219,7 @@ class ResizeKVCacheConfig(TransformConfig):
|
||||
"""Configuration for the resize kv cache transform."""
|
||||
|
||||
free_mem_ratio: float = Field(
|
||||
description="The fraction of available memory to occupy.", default=0.8
|
||||
default=0.0, ge=0.0, le=1.0, description="The fraction of available memory to occupy."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from pydantic import Field
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...transformations._graph import move_to_device
|
||||
from ...utils._graph import move_to_device
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
@ -20,9 +20,9 @@ from ..interface import (
|
||||
class MoveDeviceConfig(TransformConfig):
|
||||
"""Configuration for the moving inputs/arguments to the device transform."""
|
||||
|
||||
device: str = Field(default="meta", description="The device to load the weights on.")
|
||||
adconfig_checkpoint_device: Optional[str] = Field(
|
||||
default=None, description="Optional checkpoint device argument from adconfig."
|
||||
checkpoint_device: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional device to init checkpoint before move to shared_config.local_device.",
|
||||
)
|
||||
|
||||
|
||||
@ -45,9 +45,9 @@ class LoadWeightsToDevice(BaseTransform):
|
||||
) -> Tuple[nn.Module, TransformInfo]:
|
||||
factory.load_or_random_init(
|
||||
mod,
|
||||
device=self.config.adconfig_checkpoint_device or self.config.device,
|
||||
device=self.config.checkpoint_device or cm.device,
|
||||
)
|
||||
move_to_device(mod, self.config.device)
|
||||
move_to_device(mod, cm.device)
|
||||
|
||||
info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
|
||||
|
||||
@ -71,7 +71,9 @@ class LoadFactoryModelWeights(BaseTransform):
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[nn.Module, TransformInfo]:
|
||||
cm.to(self.config.device)
|
||||
# TODO (hg) This is weird but equivalent to previous code.
|
||||
# We does not seems to need this transform.
|
||||
cm.to(cm.device)
|
||||
|
||||
info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
|
||||
|
||||
|
||||
@ -3,12 +3,24 @@
|
||||
import json
|
||||
from typing import Tuple
|
||||
|
||||
import model_explorer
|
||||
import torch
|
||||
import torch.export as te
|
||||
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
|
||||
from model_explorer.pytorch_exported_program_adater_impl import PytorchExportedProgramAdapterImpl
|
||||
from torch import fx
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
try:
|
||||
import model_explorer
|
||||
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
|
||||
from model_explorer.pytorch_exported_program_adater_impl import (
|
||||
PytorchExportedProgramAdapterImpl,
|
||||
)
|
||||
except ImportError:
|
||||
model_explorer = None
|
||||
GraphNode = KeyValue = MetadataItem = PytorchExportedProgramAdapterImpl = None
|
||||
# Optionally, you can log a warning or handle this gracefully elsewhere
|
||||
|
||||
|
||||
def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
|
||||
@ -62,9 +74,6 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
|
||||
raise ValueError(f"Unsupported output type: {type(out_vals)}")
|
||||
|
||||
|
||||
PytorchExportedProgramAdapterImpl.print_tensor = print_tensor
|
||||
PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
|
||||
|
||||
# TODO(yudong): make custom_ops configurable
|
||||
CUSTOM_OPS = (
|
||||
torch.ops.auto_deploy.torch_dist_all_reduce.default,
|
||||
@ -76,13 +85,26 @@ CUSTOM_OPS = (
|
||||
)
|
||||
|
||||
|
||||
# TODO(yudong): make viz as non-block call.
|
||||
def visualize_namespace(gm: fx.GraphModule, args: Tuple[torch.Tensor, ...], dynamic_shapes):
|
||||
ep = te.export(gm, args=args, dynamic_shapes=dynamic_shapes)
|
||||
graph = ep.graph
|
||||
# Ensure the ops land up in the right module for better viz
|
||||
for n in graph.nodes:
|
||||
if n.target in CUSTOM_OPS:
|
||||
n.meta["nn_module_stack"] = n.args[0].meta["nn_module_stack"]
|
||||
@TransformRegistry.register("visualize_namespace")
|
||||
class VisualizeNamespace(BaseTransform):
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
PytorchExportedProgramAdapterImpl.print_tensor = print_tensor
|
||||
PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
|
||||
|
||||
model_explorer.visualize_pytorch("model-viz", ep)
|
||||
# TODO(yudong): make viz as non-block call.
|
||||
ep = te.export(gm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
|
||||
graph = ep.graph
|
||||
# Ensure the ops land up in the right module for better viz
|
||||
for n in graph.nodes:
|
||||
if n.target in CUSTOM_OPS:
|
||||
n.meta["nn_module_stack"] = n.args[0].meta["nn_module_stack"]
|
||||
|
||||
model_explorer.visualize_pytorch("model-viz", ep)
|
||||
|
||||
return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
|
||||
@ -1,7 +1,9 @@
|
||||
"""High-level entrypoint to transform a model into an efficient inference model."""
|
||||
|
||||
import gc
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
@ -71,4 +73,6 @@ class InferenceOptimizer:
|
||||
############################################################################################
|
||||
# RETURN OPTIMIZED MODEL
|
||||
############################################################################################
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
return mod
|
||||
|
||||
@ -1 +0,0 @@
|
||||
"""V1 Graph Transformations Module --> will be deprecated and replaced by auto_deploy.transform."""
|
||||
@ -1,6 +0,0 @@
|
||||
"""A library of transformation passes."""
|
||||
|
||||
try:
|
||||
from .visualization import visualize_namespace
|
||||
except ImportError:
|
||||
pass
|
||||
@ -1,117 +0,0 @@
|
||||
"""High-level entrypoint to transform a model into an efficient inference model."""
|
||||
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..custom_ops.attention_interface import AttentionRegistry
|
||||
from ..llm_args import AutoDeployConfig
|
||||
from ..models.factory import ModelFactory
|
||||
from ..shim.interface import CachedSequenceInterface
|
||||
from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer
|
||||
|
||||
|
||||
class InferenceOptimizer:
|
||||
def __init__(self, factory: ModelFactory, ad_config: AutoDeployConfig):
|
||||
self.factory = factory
|
||||
self.ad_config = ad_config
|
||||
|
||||
def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
|
||||
"""Transform a model into an optimized inference model.
|
||||
|
||||
Args:
|
||||
model: The model to transform.
|
||||
cp: The cache pool to use for caching.
|
||||
args: Example inputs to the model.
|
||||
dynamic_shapes: Dynamic shapes to use. Defaults to None.
|
||||
poe_config: The config for positional encoding. Defaults to None.
|
||||
quantization: The quantization method to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
A nn.Module representing the optimized inference model.
|
||||
"""
|
||||
############################################################################################
|
||||
# RUN MODULAR INFERENCE OPTIMIZER FOR ALREADY-MIGRATED TRANSFORMS
|
||||
############################################################################################
|
||||
# TODO (hg): default values that are not representable in YAML.
|
||||
# move to the optimizer
|
||||
if "match_attention_layout" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["match_attention_layout"]["attention_op"] = (
|
||||
AttentionRegistry.get(self.ad_config.attn_backend)
|
||||
)
|
||||
if "match_rope_layout" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["match_rope_layout"]["expected_layout"] = (
|
||||
AttentionRegistry.get(self.ad_config.attn_backend).get_attention_layout()
|
||||
)
|
||||
|
||||
if "load_weights" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["load_weights"]["checkpoint_device"] = (
|
||||
self.ad_config.checkpoint_device
|
||||
)
|
||||
self.ad_config.transforms["load_weights"]["device"] = cm.device
|
||||
|
||||
if "build_and_load_factory_model" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["build_and_load_factory_model"]["device"] = cm.device
|
||||
|
||||
if "move_inputs_to_device" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["move_inputs_to_device"]["checkpoint_device"] = (
|
||||
self.ad_config.checkpoint_device
|
||||
)
|
||||
self.ad_config.transforms["move_inputs_to_device"]["device"] = cm.device
|
||||
|
||||
if "resize_kv_cache" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["resize_kv_cache"]["free_mem_ratio"] = (
|
||||
self.ad_config.free_mem_ratio
|
||||
)
|
||||
if "insert_cached_attention" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["insert_cached_attention"]["attn_backend"] = (
|
||||
self.ad_config.attn_backend
|
||||
)
|
||||
if "insert_cached_mla_attention" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["insert_cached_mla_attention"]["attn_backend"] = (
|
||||
self.ad_config.mla_backend
|
||||
)
|
||||
if "transformers_replace_cached_attn" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["transformers_replace_cached_attn"]["attn_backend"] = (
|
||||
self.ad_config.attn_backend
|
||||
)
|
||||
|
||||
# TODO: (hg)Missing MLA here. Figure out how to add MLA since duplicate transforms are not allowed.
|
||||
# Old code:
|
||||
# detect attention op and replace with cache-aware op
|
||||
# for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
|
||||
# attn_descriptor = AttentionRegistry.get(a_backend)
|
||||
# insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
|
||||
|
||||
if "compile_model" in self.ad_config.transforms:
|
||||
self.ad_config.transforms["compile_model"]["cuda_graph_batch_sizes"] = (
|
||||
self.ad_config.cuda_graph_batch_sizes
|
||||
)
|
||||
self.ad_config.transforms["compile_model"]["compile_backend"] = (
|
||||
self.ad_config.compile_backend
|
||||
)
|
||||
|
||||
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
|
||||
# TODO: (hg) move this. let match_rope_layout and match_atten_layout use this shared config
|
||||
new_optimizer.shared_config.attn_backend = self.ad_config.attn_backend
|
||||
|
||||
egm = new_optimizer(cm)
|
||||
|
||||
# NOTE: (hg)Disabled visualization since compiled gm is a CapturedGraph instead of GraphModule.
|
||||
# We can add a new stage in the optimizer to visualize the intermediate gm.
|
||||
# if self.ad_config.visualize:
|
||||
# try:
|
||||
# from .library import visualize_namespace
|
||||
|
||||
# visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
|
||||
# ad_logger.warning(
|
||||
# "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
|
||||
# " the graph."
|
||||
# )
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
return egm
|
||||
@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, YamlConfigSettingsSource
|
||||
from pydantic_settings.sources.types import DEFAULT_PATH, PathType
|
||||
|
||||
@ -161,23 +161,6 @@ class DynamicYamlMixInForSettings:
|
||||
'with higher priority. "Later" files have higher priority. Should be used with care!',
|
||||
)
|
||||
|
||||
# TODO: remove this field in a future version
|
||||
yaml_configs: List[PathType] = Field(
|
||||
default_factory=list,
|
||||
description="DEPRECATED: Please use yaml_extra instead.",
|
||||
)
|
||||
|
||||
@field_validator("yaml_configs")
|
||||
@classmethod
|
||||
def validate_yaml_configs_deprecated(cls, v):
|
||||
"""Throw error that yaml_configs is deprecated in favor of yaml_extra."""
|
||||
if v: # Only raise error if the field is actually being used (not empty)
|
||||
raise ValueError(
|
||||
"The 'yaml_configs' field is deprecated and no longer supported. "
|
||||
"Please use 'yaml_extra' instead."
|
||||
)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode_and_yaml_default_not_both_provided(self):
|
||||
"""Validate that both mode and yaml_default are not provided simultaneously.
|
||||
|
||||
@ -15,8 +15,8 @@ from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from torch.utils._pytree import _LEAF_SPEC
|
||||
|
||||
from ..utils.logger import ad_logger
|
||||
from ..utils.node_utils import is_op
|
||||
from .logger import ad_logger
|
||||
from .node_utils import is_op
|
||||
|
||||
_NoValType = type("_NoValType", (), {})
|
||||
_NO_VAL = _NoValType()
|
||||
@ -131,7 +131,6 @@ class PerformanceOptions:
|
||||
def get_autodeploy_perf_config(self) -> Dict:
|
||||
AutoDeployPerfConfig = dict
|
||||
ad_config = AutoDeployPerfConfig()
|
||||
ad_config["attn_backend"] = "flashinfer"
|
||||
return ad_config
|
||||
|
||||
|
||||
|
||||
@ -52,9 +52,17 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
|
||||
# Set it explicitly here to 8192 which is the default in build_config.
|
||||
"max_num_tokens": 8192,
|
||||
"skip_loading_weights": False,
|
||||
"compile_backend": "torch-opt",
|
||||
"free_mem_ratio": 0.7,
|
||||
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
|
||||
"transforms": {
|
||||
"resize_kv_cache": {
|
||||
"free_mem_ratio": 0.7
|
||||
},
|
||||
"compile_model": {
|
||||
"backend":
|
||||
"torch-opt",
|
||||
"cuda_graph_batch_sizes":
|
||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def get_default_sampling_params(self):
|
||||
@ -99,9 +107,15 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
|
||||
# Set explicitly to match default build_config behavior
|
||||
"max_num_tokens": 8192,
|
||||
"skip_loading_weights": False,
|
||||
"compile_backend": "torch-opt",
|
||||
"free_mem_ratio": 0.7,
|
||||
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
|
||||
"transforms": {
|
||||
"resize_kv_cache": {
|
||||
"free_mem_ratio": 0.7
|
||||
},
|
||||
"compile_model": {
|
||||
"backend": "torch-opt",
|
||||
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def get_default_sampling_params(self):
|
||||
|
||||
@ -1411,8 +1411,14 @@ class MultiMetricPerfTest(AbstractPerfScriptTestClass):
|
||||
|
||||
# Create _autodeploy specific configuration
|
||||
autodeploy_config = {
|
||||
'compile_backend': self._config.ad_compile_backend,
|
||||
'free_mem_ratio': self._config.free_mem_ratio,
|
||||
'transforms': {
|
||||
'compile_model': {
|
||||
'backend': self._config.ad_compile_backend
|
||||
},
|
||||
'resize_kv_cache': {
|
||||
'free_mem_ratio': self._config.free_mem_ratio
|
||||
},
|
||||
},
|
||||
'runtime': self._config.extra_runtime,
|
||||
'skip_loading_weights': self._config.skip_loading_weights
|
||||
}
|
||||
|
||||
@ -2,7 +2,6 @@ import copy
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
@ -444,7 +443,6 @@ _SMALL_MODEL_CONFIGS = {
|
||||
"mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
|
||||
"llm_models_subdir": "Mistral-Small-3.1-24B-Instruct-2503",
|
||||
"model_factory": "AutoModelForImageTextToText",
|
||||
"compile_backend": "torch-simple",
|
||||
"model_kwargs": {
|
||||
"text_config": {
|
||||
"num_hidden_layers": 2,
|
||||
@ -531,10 +529,8 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
|
||||
|
||||
# add some defaults to llm_args
|
||||
llm_args["skip_loading_weights"] = True # No weight loading to speed up things
|
||||
llm_args["free_mem_ratio"] = 0.00 # we don't need the cache and it may cause OOM issues
|
||||
llm_args["attn_page_size"] = 4 # Make sure paging is activated despite small max_tokens
|
||||
llm_args["max_batch_size"] = 2 # Minimum batching to speed up things
|
||||
|
||||
# update with custom llm_args kwargs
|
||||
llm_args.update(llm_args_kwargs)
|
||||
|
||||
@ -549,13 +545,3 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
|
||||
}
|
||||
|
||||
return experiment_config
|
||||
|
||||
|
||||
def get_small_model_config_pytest_param(
|
||||
model_hub_id: str, pytest_param_kwargs=None, **llm_args_kwargs
|
||||
):
|
||||
return pytest.param(
|
||||
get_small_model_config(model_hub_id, **llm_args_kwargs),
|
||||
id=model_hub_id,
|
||||
**(pytest_param_kwargs or {}),
|
||||
)
|
||||
|
||||
@ -1,28 +1,40 @@
|
||||
"""Testing build_and_run_ad end2end."""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
from _model_test_utils import get_small_model_config_pytest_param
|
||||
from _model_test_utils import get_small_model_config
|
||||
from build_and_run_ad import ExperimentConfig, main
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@pytest.mark.parametrize("mode", ["graph", "transformers"])
|
||||
@pytest.mark.parametrize(
|
||||
"experiment_config",
|
||||
"model_hub_id, llm_extra_args",
|
||||
[
|
||||
get_small_model_config_pytest_param(
|
||||
(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
attn_backend="flashinfer",
|
||||
compile_backend="torch-opt",
|
||||
{
|
||||
"transforms": {
|
||||
"insert_cached_attention": {"backend": "flashinfer"},
|
||||
"compile_model": {"backend": "torch-opt"},
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
{
|
||||
"transforms": {
|
||||
"transformers_replace_cached_attn": {"backend": "flashinfer"},
|
||||
},
|
||||
"mode": "transformers",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_build_ad(world_size: int, experiment_config: Dict, mode: str):
|
||||
def test_build_ad(world_size: int, model_hub_id: str, llm_extra_args: dict):
|
||||
experiment_config = get_small_model_config(model_hub_id, **llm_extra_args)
|
||||
|
||||
experiment_config["args"]["world_size"] = world_size
|
||||
experiment_config["args"]["runtime"] = "trtllm" # Default runtime set to trtllm
|
||||
experiment_config["args"]["mode"] = mode
|
||||
|
||||
experiment_config = ExperimentConfig(**experiment_config)
|
||||
print(f"Experiment Config: {experiment_config}")
|
||||
main(experiment_config)
|
||||
|
||||
@ -1,12 +1,12 @@
import torch  # noqa
import torch.export as te
from torch.export import Dim  # noqa
import pytest
import torch
import torch.export as te
from _model_test_utils import get_small_model_config
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device  # noqa
from _model_test_utils import get_small_model_config
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device

# NOTE: find example inputs with the same tokenization length to avoid seq concat.
EXAMPLE_INPUT = "Mamba is a snake with the following properties:"
@ -31,10 +31,12 @@ def test_bamba_patches(model_dir: str, run_verify_generation: bool):
    common_kwargs = {
        "world_size": 0,
        "runtime": "demollm",
        "compile_backend": "torch-simple",
        "attn_backend": "flashinfer",
        "model_factory": "AutoModelForCausalLM",
        "max_seq_len": 512,
        "transforms": {
            "insert_cached_attention": {"backend": "flashinfer"},
            "compile_model": {"backend": "torch-simple"},
        },
    }

    if use_small_config:
@ -5,7 +5,7 @@ from PIL import Image

from tensorrt_llm._torch.auto_deploy import LlmArgs
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device


def test_build_run_llama4_vlm():
@ -4,7 +4,7 @@ from build_and_run_ad import ExperimentConfig

from tensorrt_llm._torch.auto_deploy import LlmArgs
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device


def test_build_run_mistral3_vlm():
@ -41,10 +41,8 @@ def get_inference_model(cache_seq_interface):


@pytest.mark.parametrize("engine_cls", [ADEngine, DemoEngine])
@pytest.mark.parametrize(
    "attn_backend, attn_page_size", [("triton", 0), ("flashinfer", 2), ("torch", 0)]
)
def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: int):
@pytest.mark.parametrize("attn_page_size", [0, 2, 0])
def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
    """Test the SimpleEngine functionality."""

    seed = 42  # Set random seed for model param init
@ -4,6 +4,7 @@ import pydantic
import pytest

from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM, LlmArgs
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer


def test_custom_values():
@ -12,13 +13,23 @@ def test_custom_values():
        "model": "test-model",
        "model_factory": "AutoModelForImageTextToText",
        "model_kwargs": {"custom_param": True},
        "mla_backend": "MultiHeadLatentAttention",
        "skip_loading_weights": True,
        "free_mem_ratio": 0.9,
        "simple_shard_only": True,
        "attn_page_size": 128,
        "attn_backend": "flashinfer",
        "max_seq_len": 2048,
        "transforms": {
            "detect_sharding": {
                "stage": "sharding",
                "simple_shard_only": True,
            },
            "insert_cached_attention": {
                "stage": "cache_init",
                "backend": "flashinfer",
            },
            "resize_kv_cache": {
                "stage": "cache_init",
                "free_mem_ratio": 0.9,
            },
        },
    }

    args = LlmArgs(**custom_kwargs)
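
The hunk above shows the old flat keys and the new nested `transforms` entries side by side. As a rough illustration of that mapping, here is a minimal sketch; it is not code from the repo, the key names are taken from the diff, and the helper itself is hypothetical:

```python
# Hypothetical helper, not part of TensorRT-LLM: maps a few of the old flat
# llm_args keys shown above onto the nested `transforms` layout they moved to.
def migrate_flat_llm_args(flat: dict) -> dict:
    flat = dict(flat)  # don't mutate the caller's dict
    transforms: dict = {}
    if "attn_backend" in flat:
        transforms["insert_cached_attention"] = {"backend": flat.pop("attn_backend")}
    if "compile_backend" in flat:
        transforms["compile_model"] = {"backend": flat.pop("compile_backend")}
    if "free_mem_ratio" in flat:
        transforms["resize_kv_cache"] = {"free_mem_ratio": flat.pop("free_mem_ratio")}
    if "simple_shard_only" in flat:
        transforms["detect_sharding"] = {"simple_shard_only": flat.pop("simple_shard_only")}
    return {**flat, "transforms": transforms}


# Example: flat keys from the hunk above become nested transform options.
print(migrate_flat_llm_args({"free_mem_ratio": 0.9, "attn_backend": "flashinfer"}))
```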
@ -28,26 +39,30 @@ def test_custom_values():
        "custom_param": True,
    }
    assert args.skip_loading_weights
    assert args.free_mem_ratio == 0.9
    assert args.simple_shard_only
    assert args.transforms["resize_kv_cache"]["free_mem_ratio"] == 0.9
    assert args.transforms["detect_sharding"]["simple_shard_only"]
    assert args.attn_page_size == 128
    assert args.max_seq_len == 2048
    # attn_backend should be overridden if it was 'TRTLLM'
    assert args.attn_backend == "flashinfer"
    # backend should be overridden if it was 'TRTLLM'
    assert args.transforms["insert_cached_attention"]["backend"] == "flashinfer"


def test_free_mem_ratio_validation():
    """Test free_mem_ratio validation."""

    def get_transform_config(free_mem_ratio):
        return {"resize_kv_cache": {"stage": "cache_init", "free_mem_ratio": free_mem_ratio}}

    # Valid values
    LlmArgs(model="test-model", free_mem_ratio=0.0)
    LlmArgs(model="test-model", free_mem_ratio=1.0)
    LlmArgs(model="test-model", free_mem_ratio=0.5)
    InferenceOptimizer(None, get_transform_config(0.0))
    InferenceOptimizer(None, get_transform_config(1.0))
    InferenceOptimizer(None, get_transform_config(0.5))

    # Invalid values
    with pytest.raises(ValueError):
        LlmArgs(model="test-model", free_mem_ratio=-0.1)
        InferenceOptimizer(None, get_transform_config(-0.1))
    with pytest.raises(ValueError):
        LlmArgs(model="test-model", free_mem_ratio=1.1)
        InferenceOptimizer(None, get_transform_config(1.1))


def test_get_pytorch_backend_config():
@ -67,14 +82,25 @@ def test_config_params():
    return {
        "model": "test-model",
        "model_factory": "AutoModelForImageTextToText",
        "free_mem_ratio": 0.7,
        "simple_shard_only": True,
        "skip_loading_weights": True,
        "attn_page_size": 17,
        "attn_backend": "flashinfer",
        "max_seq_len": 19,
        "max_batch_size": 5,
        "world_size": 3,
        "transforms": {
            "detect_sharding": {
                "stage": "sharding",
                "simple_shard_only": True,
            },
            "insert_cached_attention": {
                "stage": "cache_init",
                "backend": "flashinfer",
            },
            "resize_kv_cache": {
                "stage": "cache_init",
                "free_mem_ratio": 0.7,
            },
        },
    }

@ -131,8 +157,14 @@ def test_config_flow(

    # Common assertions for both APIs
    assert instance.args.model_factory == test_config_params["model_factory"]
    assert instance.args.free_mem_ratio == test_config_params["free_mem_ratio"]
    assert instance.args.simple_shard_only == test_config_params["simple_shard_only"]
    assert (
        instance.args.transforms["resize_kv_cache"]["free_mem_ratio"]
        == test_config_params["transforms"]["resize_kv_cache"]["free_mem_ratio"]
    )
    assert (
        instance.args.transforms["detect_sharding"]["simple_shard_only"]
        == test_config_params["transforms"]["detect_sharding"]["simple_shard_only"]
    )
    assert instance.args.skip_loading_weights == test_config_params["skip_loading_weights"]
    assert instance.args.attn_page_size == test_config_params["attn_page_size"]
    assert instance.args.max_seq_len == test_config_params["max_seq_len"]
@ -190,13 +222,17 @@ def test_parallel_config_validation(parallel_field, invalid_value):


@pytest.mark.parametrize(
    "attn_backend,expected_attn_page_size",
    "backend,expected_attn_page_size",
    [
        ("flashinfer", 64),  # Default attn_page_size
        ("triton", 1024),  # Should equal max_seq_len
    ],
)
def test_attention_backend_page_size_logic(attn_backend, expected_attn_page_size):
def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
    """Test attn_page_size logic for different attention backends."""
    args = LlmArgs(model="test-model", attn_backend=attn_backend, max_seq_len=1024)
    args = LlmArgs(
        model="test-model",
        max_seq_len=1024,
        transforms={"insert_cached_attention": {"stage": "cache_init", "backend": backend}},
    )
    assert args.attn_page_size == expected_attn_page_size
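
The parametrization above pins down the expected behavior: the paged flashinfer backend keeps the default page size of 64, while triton falls back to `max_seq_len`. A standalone sketch of that rule follows; it only reproduces the test's expectation and is not the actual `LlmArgs` logic, and the backend set and default value are assumptions:

```python
# Illustrative only: mirrors the expectation checked by the test above.
def derive_attn_page_size(backend: str, max_seq_len: int, default_page_size: int = 64) -> int:
    paged_backends = {"flashinfer"}  # assumption for this sketch
    return default_page_size if backend in paged_backends else max_seq_len


assert derive_attn_page_size("flashinfer", 1024) == 64
assert derive_attn_page_size("triton", 1024) == 1024
```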
@ -1,13 +1,11 @@
"""Testing build_and_run_ad end2end."""

from typing import Dict

import pytest
from _model_test_utils import get_small_model_config_pytest_param
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main

from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig, LlmArgs, _ParallelConfig
from tensorrt_llm._torch.auto_deploy.transformations.transform import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine


def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
@ -41,87 +39,175 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
    )


@pytest.mark.parametrize("mode", ["graph", "transformers"])
@pytest.mark.parametrize(
    "experiment_config",
    "model_hub_id, llm_extra_args",
    [
        get_small_model_config_pytest_param(
        (
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
            attn_backend="flashinfer",
            compile_backend="torch-opt",
            free_mem_ratio=0.0001,
            {
                "transforms": {
                    "resize_kv_cache": {"free_mem_ratio": 0.0001},
                    "insert_cached_attention": {"backend": "flashinfer"},
                    "compile_model": {"backend": "torch-opt"},
                },
            },
        ),
        get_small_model_config_pytest_param(
        (
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
            {
                "transforms": {
                    "transformers_replace_cached_attn": {"backend": "flashinfer"},
                },
                "mode": "transformers",
            },
        ),
        (
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            attn_backend="triton",
            compile_backend="torch-simple",
            {
                "transforms": {
                    "insert_cached_attention": {"backend": "triton"},
                    "compile_model": {"backend": "torch-simple"},
                },
            },
        ),
        get_small_model_config_pytest_param(
        (
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            {
                "transforms": {
                    "transformers_replace_cached_attn": {"backend": "triton"},
                },
                "mode": "transformers",
            },
        ),
        (
            "Qwen/Qwen3-30B-A3B",
            attn_backend="triton",
            compile_backend="torch-simple",
            {
                "transforms": {
                    "insert_cached_attention": {"backend": "triton"},
                    "compile_model": {"backend": "torch-simple"},
                },
            },
        ),
        get_small_model_config_pytest_param(
        (
            "Qwen/Qwen3-30B-A3B",
            {
                "transforms": {
                    "transformers_replace_cached_attn": {"backend": "triton"},
                },
                "mode": "transformers",
            },
        ),
        (
            "microsoft/Phi-3-mini-4k-instruct",
            attn_backend="triton",
            compile_backend="torch-simple",
            {
                "transforms": {
                    "insert_cached_attention": {"backend": "triton"},
                    "compile_model": {"backend": "torch-simple"},
                },
            },
        ),
        get_small_model_config_pytest_param(
        (
            "microsoft/Phi-3-mini-4k-instruct",
            attn_backend="torch",
            compile_backend="torch-simple",
            {
                "transforms": {
                    "insert_cached_attention": {"backend": "torch"},
                    "compile_model": {"backend": "torch-simple"},
                },
            },
        ),
        get_small_model_config_pytest_param(
        (
            "meta-llama/Llama-4-Scout-17B-16E-Instruct",
            attn_backend="flashinfer",
            compile_backend="torch-opt",
            {
                "transforms": {
                    "insert_cached_attention": {"backend": "flashinfer"},
                    "compile_model": {"backend": "torch-opt"},
                },
            },
        ),
        get_small_model_config_pytest_param(
        (
            "meta-llama/Llama-4-Scout-17B-16E-Instruct",
            {
                "transforms": {
                    "transformers_replace_cached_attn": {"backend": "flashinfer"},
                },
                "mode": "transformers",
            },
        ),
        (
            "deepseek-ai/DeepSeek-V3",
            attn_backend="triton",
            compile_backend="torch-simple",
            {
                "transforms": {
                    "insert_cached_attention": {"backend": "triton"},
                    "compile_model": {"backend": "torch-simple"},
                },
            },
        ),
        get_small_model_config_pytest_param(
        (
            "Qwen/Qwen2.5-3B-Instruct",
            attn_backend="triton",
            compile_backend="torch-compile",
            {
                "transforms": {
                    "insert_cached_attention": {"backend": "triton"},
                    "compile_model": {"backend": "torch-compile"},
                },
            },
        ),
        get_small_model_config_pytest_param(
        (
            "Qwen/Qwen2.5-3B-Instruct",
            {
                "transforms": {
                    "transformers_replace_cached_attn": {"backend": "triton"},
                },
                "mode": "transformers",
            },
        ),
        (
            "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
            attn_backend="flashinfer",
            compile_backend="torch-cudagraph",
            {
                "transforms": {
                    "insert_cached_attention": {"backend": "flashinfer"},
                    "compile_model": {"backend": "torch-cudagraph"},
                },
            },
        ),
        get_small_model_config_pytest_param(
        (
            "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
            {
                "transforms": {
                    "transformers_replace_cached_attn": {"backend": "flashinfer"},
                },
                "mode": "transformers",
            },
        ),
        (
            "nvidia/NVIDIA-Nemotron-Nano-12B-v2",
            attn_backend="flashinfer",
            compile_backend="torch-simple",
            {
                "transforms": {
                    "insert_cached_attention": {"backend": "flashinfer"},
                    "compile_model": {"backend": "torch-simple"},
                },
            },
        ),
    ],
)
def test_build_ad(experiment_config: Dict, mode: str):
    if (
        "DeepSeek-V3" in experiment_config["args"]["model"]
        or "Phi-3-mini-4k-instruct" in experiment_config["args"]["model"]
        or "NVIDIA-Nemotron-Nano-12B-v2" in experiment_config["args"]["model"]
        and mode == "transformers"
    ):
        pytest.skip(f"{experiment_config['args']['model']} is not supported in transformers mode")

def test_build_ad(model_hub_id: str, llm_extra_args: dict):
    experiment_config = get_small_model_config(model_hub_id, **llm_extra_args)
    experiment_config["args"]["runtime"] = "demollm"  # Default runtime set to demollm
    experiment_config["args"]["world_size"] = 0  # Default world_size set to 0
    experiment_config["args"]["mode"] = mode

    print(f"Experiment Config: {experiment_config}")
    experiment_config = ExperimentConfig(**experiment_config)
    original_init = InferenceOptimizer.__init__
    original_build_from_config = ADEngine.build_from_config

    def check_and_original_init(self, factory, ad_config):
    @classmethod
    def check_and_original_build(cls, ad_config):
        _check_ad_config(experiment_config, ad_config)
        return original_init(self, factory, ad_config=ad_config)
        return original_build_from_config.__func__(cls, ad_config)

    # Temporarily replace the __init__ method
    InferenceOptimizer.__init__ = check_and_original_init
    # Temporarily replace the build_from_config classmethod
    ADEngine.build_from_config = check_and_original_build

    try:
        main(experiment_config)
    finally:
        # Restore original __init__
        InferenceOptimizer.__init__ = original_init
        # Restore original build_from_config
        ADEngine.build_from_config = original_build_from_config
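
The new test body swaps `ADEngine.build_from_config` for a checking wrapper and restores it afterwards. The following self-contained sketch shows the same patch-and-restore pattern on a toy class; all names here are illustrative, not from the repo:

```python
# Toy illustration of the classmethod patch/restore pattern used above.
class ToyEngine:
    @classmethod
    def build_from_config(cls, config):
        return f"{cls.__name__} built with {config}"


original = ToyEngine.build_from_config  # bound classmethod


@classmethod
def checking_build(cls, config):
    assert "backend" in config  # stand-in for the config check
    # `.__func__` unwraps the bound classmethod so it can be re-invoked on `cls`.
    return original.__func__(cls, config)


ToyEngine.build_from_config = checking_build
try:
    print(ToyEngine.build_from_config({"backend": "torch-opt"}))
finally:
    ToyEngine.build_from_config = original  # restore the original classmethod
```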
@ -75,8 +75,14 @@ def test_trtllm_bench(llm_root, compile_backend, model_name):  # noqa: F811
    with open(extra_llm_api_options_path, "w") as f:
        yaml.dump(
            {
                "compile_backend": compile_backend,
                **config["args"],
                "transforms": {
                    "compile_model": {
                        "stage": "compile",
                        "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
                        "backend": compile_backend,
                    }
                },
            },
            f,
        )
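
For reference, a minimal round-trip of the nested options layout written by the test above; only the `transforms.compile_model` keys come from the diff, while the file name and the concrete backend value are placeholders:

```python
# Standalone sketch: write and re-read the nested extra-options layout.
import yaml

extra_options = {
    "transforms": {
        "compile_model": {
            "stage": "compile",
            "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
            "backend": "torch-opt",  # placeholder for the parametrized compile_backend
        }
    }
}

with open("extra_llm_api_options.yaml", "w") as f:  # hypothetical path
    yaml.dump(extra_options, f)

with open("extra_llm_api_options.yaml") as f:
    assert yaml.safe_load(f)["transforms"]["compile_model"]["backend"] == "torch-opt"
```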
@ -7,7 +7,6 @@ from torch.export import Dim
from torch.fx import GraphModule
from transformers.integrations.sdpa_attention import repeat_kv as hf_repeat_kv

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -1014,21 +1013,6 @@ class Llama3CausalAttentionModel(torch.nn.Module):
        return {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}


class MockAttentionDescriptor(AttentionDescriptor):
    """A mock class that mimics the AttentionDescriptor interface for testing."""

    layout: str = "bnsd"
    source_attention_op: Callable = torch.ops.auto_deploy.torch_attention_sdpa

    @classmethod
    def get_attention_layout(cls) -> str:
        return cls.layout

    @classmethod
    def get_source_attention_op(cls) -> Callable:
        return cls.source_attention_op


class AttentionLayoutModel(torch.nn.Module):
    """Model that uses SDPA for testing the layout transformation."""
@ -1169,17 +1153,6 @@ def test_match_attention_layout(layout, model_config, has_mask):
    hidden_size = 512
    num_heads = 8

    # Set up the mock attention descriptor class with the specified layout
    MockAttentionDescriptor.layout = layout
    if layout == "bnsd":
        if model_config.get("use_grouped_sdpa"):
            source_op = torch.ops.auto_deploy.torch_attention
        else:
            source_op = torch.ops.auto_deploy.torch_attention_sdpa
    else:
        source_op = torch.ops.auto_deploy.torch_attention
    MockAttentionDescriptor.source_attention_op = source_op

    # Create appropriate model based on model_config
    if model_config["type"] == "standard":
        model = AttentionLayoutModel(
@ -1329,7 +1302,7 @@ def test_match_attention_layout(layout, model_config, has_mask):
            {
                "match_attention_layout": {
                    "stage": "pattern_matcher",
                    "attention_op": MockAttentionDescriptor,
                    "attn_layout": layout,
                },
            },
        )(None, gm),
@ -1,7 +1,7 @@
"""Test that the attention matcher works with HF's SDPA backends."""

import copy
from typing import Any, Callable, Dict
from typing import Any, Dict

import pytest
import torch
@ -12,29 +12,14 @@ from torch.fx import GraphModule
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

torch.manual_seed(0)


class MockAttentionDescriptor(AttentionDescriptor):
    """A mock class that mimics the AttentionDescriptor interface for testing."""

    layout: str = "bsnd"

    @classmethod
    def get_attention_layout(cls) -> str:
        return cls.layout

    @classmethod
    def get_source_attention_op(cls) -> Callable:
        return torch.ops.auto_deploy.torch_attention


class HFWrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
@ -61,7 +46,7 @@ def _joint_transform(gm: GraphModule) -> None:
            },
            "match_attention_layout": {
                "stage": "pattern_matcher",
                "attention_op": MockAttentionDescriptor,
                "attn_layout": "bsnd",
            },
        },
    )(None, gm)
@ -192,7 +192,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
            },
            "insert_cached_attention": {
                "stage": "cache_init",
                "attn_backend": attn_backend,
                "backend": attn_backend,
            },
        },
    )  # type: ignore
@ -414,23 +414,6 @@ def test_partial_env_override(basic_yaml_files):
    assert settings.option.option == "on"  # from env


# Error handling tests
def test_deprecated_yaml_configs_field_error(basic_yaml_files):
    """Test that using deprecated yaml_configs field raises ValueError."""
    with pytest.raises(
        ValueError, match=r"The 'yaml_configs' field is deprecated.*Please use 'yaml_extra' instead"
    ):
        BasicSettings(yaml_configs=[basic_yaml_files["config1"]])


def test_empty_yaml_configs_allowed():
    """Test that empty yaml_configs list doesn't raise error."""
    # Empty yaml_configs should not raise error (but validation will still fail for missing fields)
    with pytest.raises(ValidationError):
        # Should fail validation for missing required fields, not for yaml_configs deprecation
        BasicSettings(yaml_configs=[])


def test_missing_yaml_file(temp_dir):
    """Test handling of missing yaml file."""
    missing_file = temp_dir / "missing.yaml"