[AutoDeploy] merge feat/ad-2025-06-29 (#5737)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Co-authored-by: Neta Zmora <nzmora@nvidia.com>
Co-authored-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Lucas Liebenwein 2025-07-03 21:21:18 -04:00 committed by GitHub
parent aa72d39b72
commit 24ac9b5f69
15 changed files with 353 additions and 55 deletions

View File

@ -8,15 +8,15 @@
"program": "build_and_run_ad.py",
"args": [
"--model=meta-llama/Meta-Llama-3.1-8B-Instruct",
"--args.world_size=2",
"--args.world-size=2",
"--args.runtime=demollm",
"--args.compile_backend=torch-simple",
"--args.attn_page_size=16",
"--args.attn_backend=flashinfer",
"--args.model_factory=AutoModelForCausalLM",
"--args.compile-backend=torch-simple",
"--args.attn-page-size=16",
"--args.attn-backend=flashinfer",
"--args.model-factory=AutoModelForCausalLM",
"--benchmark.enabled=false",
"--prompt.batch_size=2",
"--args.model_kwargs",
"--prompt.batch-size=2",
"--args.model-kwargs",
"num_hidden_layers=3,num_attention_heads=32",
],
"console": "integratedTerminal",

View File

@ -25,6 +25,8 @@
"python.testing.pytestArgs": [
"./tests/unittest/_torch/auto_deploy",
"--no-cov",
"-n=auto",
"--dist=worksteal",
],
"files.exclude": {
"build": true,

View File

@ -142,19 +142,19 @@ Below is a non-exhaustive list of common config options:
| Configuration Key | Description |
|-------------------|-------------|
| `--model` | The HF model card or path to a HF checkpoint folder |
| `--args.model_factory` | Choose model factory implementation (`"AutoModelForCausalLM"`, ...) |
| `--args.skip_loading_weights` | Only load the architecture, not the weights |
| `--args.model_kwargs` | Extra kwargs that are being passed to the model initializer in the model factory |
| `--args.tokenizer_kwargs` | Extra kwargs that are being passed to the tokenizer initializer in the model factory |
| `--args.world_size` | The number of GPUs for Tensor Parallel |
| `--args.model-factory` | Choose model factory implementation (`"AutoModelForCausalLM"`, ...) |
| `--args.skip-loading-weights` | Only load the architecture, not the weights |
| `--args.model-kwargs` | Extra kwargs that are being passed to the model initializer in the model factory |
| `--args.tokenizer-kwargs` | Extra kwargs that are being passed to the tokenizer initializer in the model factory |
| `--args.world-size` | The number of GPUs for Tensor Parallel |
| `--args.runtime` | Specifies which type of Engine to use during runtime (`"demollm"` or `"trtllm"`) |
| `--args.compile_backend` | Specifies how to compile the graph at the end |
| `--args.attn_backend` | Specifies kernel implementation for attention |
| `--args.mla_backend` | Specifies implementation for multi-head latent attention |
| `--args.max_seq_len` | Maximum sequence length for inference/cache |
| `--args.max_batch_size` | Maximum dimension for statically allocated KV cache |
| `--args.attn_page_size` | Page size for attention |
| `--prompt.batch_size` | Number of queries to generate |
| `--args.compile-backend` | Specifies how to compile the graph at the end |
| `--args.attn-backend` | Specifies kernel implementation for attention |
| `--args.mla-backend` | Specifies implementation for multi-head latent attention |
| `--args.max-seq-len` | Maximum sequence length for inference/cache |
| `--args.max-batch-size` | Maximum dimension for statically allocated KV cache |
| `--args.attn-page-size` | Page size for attention |
| `--prompt.batch-size` | Number of queries to generate |
| `--benchmark.enabled` | Whether to run the built-in benchmark (true/false) |
For default values and additional configuration options, refer to the `ExperimentConfig` class in the [build_and_run_ad.py](./build_and_run_ad.py) file.
@ -165,10 +165,10 @@ Here is a more complete example of using the script:
cd examples/auto_deploy
python build_and_run_ad.py \
--model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \
--args.world_size 2 \
--args.world-size 2 \
--args.runtime "demollm" \
--args.compile_backend "torch-compile" \
--args.attn_backend "flashinfer" \
--args.compile-backend "torch-compile" \
--args.attn-backend "flashinfer" \
--benchmark.enabled True
```
@ -214,7 +214,7 @@ Refer to [NVIDIA TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Mo
```bash
cd examples/auto_deploy
python build_and_run_ad.py --model "<MODELOPT_CKPT_PATH>" --args.world_size 1
python build_and_run_ad.py --model "<MODELOPT_CKPT_PATH>" --args.world-size 1
```
### Incorporating `auto_deploy` into your own workflow

View File

@ -96,6 +96,12 @@ class LlmArgs(BaseLlmArgs):
device: str = Field(default="cuda", description="The device to use for the model.", frozen=True)
kv_cache_dtype: str = Field(
default="auto",
description="Data type for KV cache. This is a temporary field until kv_cache_dtype is "
"supported in AutoDeploy.",
)
# INFERENCE OPTIMIZER CONFIG ###################################################################
attn_backend: Literal["flashinfer", "triton"] = Field(
default="flashinfer", description="Attention backend to use."

View File

@ -2,5 +2,6 @@ from . import hf
from .decilm import *
from .deepseek import *
from .factory import *
from .mixtral import *
from .phi import *
from .qwen3 import *

View File

@ -0,0 +1,50 @@
"""A patch for Mixtral MoE to make it compatible with torch.export."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
def _forward_moe(self: MixtralSparseMoeBlock, hidden_states: torch.Tensor):
# check if we can apply the patch
use_original_forward = False
if not all(isinstance(expert.act_fn, nn.SiLU) for expert in self.experts):
use_original_forward = True
if any(getattr(mod, "bias", None) is not None for mod in self.experts.modules()):
use_original_forward = True
# rely on original forward instead
if use_original_forward:
return self._original_forward(hidden_states)
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
)
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states,
selected_experts,
routing_weights,
w1_weight=[expert.w1.weight for expert in self.experts], # gate projection
w2_weight=[expert.w2.weight for expert in self.experts], # down projection
w3_weight=[expert.w3.weight for expert in self.experts], # up projection
)
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
MixtralSparseMoeBlock._original_forward = MixtralSparseMoeBlock.forward
MixtralSparseMoeBlock.forward = _forward_moe
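
As a rough sanity check, the patched block can be compared against the preserved `_original_forward` on a toy config. This is a minimal sketch, assuming the `torch.ops.auto_deploy.torch_moe` custom op has been registered (for example by importing TensorRT-LLM's auto_deploy package) and using illustrative tiny dimensions:

```python
import torch
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Tiny, illustrative config; real checkpoints use far larger dimensions.
cfg = MixtralConfig(
    hidden_size=64, intermediate_size=128, num_local_experts=4, num_experts_per_tok=2
)
block = MixtralSparseMoeBlock(cfg).eval()

x = torch.randn(2, 8, cfg.hidden_size)  # (batch, seq, hidden)
with torch.inference_mode():
    # patched path; requires the auto_deploy custom op (may need a GPU build of TensorRT-LLM)
    out_patched, _ = block(x)
    # HF reference path kept alive by the patch
    out_ref, _ = block._original_forward(x)
torch.testing.assert_close(out_patched, out_ref, rtol=1e-4, atol=1e-4)
```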

View File

@ -1,10 +1,24 @@
"""A patch for Qwen3 MoE to make it compatible with torch.export and reduce export time."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
# patch for MoE to reduce torch.export time
def _forward_moe(self: Qwen3MoeSparseMoeBlock, hidden_states: torch.Tensor):
# check if we can apply the patch
use_original_forward = False
if not all(isinstance(expert.act_fn, nn.SiLU) for expert in self.experts):
use_original_forward = True
if any(getattr(mod, "bias", None) is not None for mod in self.experts.modules()):
use_original_forward = True
# rely on original forward instead
if use_original_forward:
return self._original_forward(hidden_states)
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)

View File

@ -11,6 +11,7 @@ from ....executor import GenerationExecutor
from ....executor.request import GenerationRequest
from ....executor.result import CompletionOutput, GenerationResult
from ....sampling_params import SamplingParams
from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch
from ..distributed import common as dist_ad
from ..utils.logger import ad_logger
from .ad_executor import ADEngine
@ -201,11 +202,12 @@ class DemoEngine(ADEngine):
def _sample(
cls, logits: torch.Tensor, sampling_params: SamplingParams
) -> Tuple[torch.Tensor, torch.Tensor]:
from tensorrt_llm._torch.pyexecutor.sampler import top_k_sampling_batch
logits_shape = logits.shape
logits = logits.view(-1, logits_shape[-1]) # top_k_sampling_batch expects 2D logits
idx_next, probs = top_k_sampling_batch(logits, sampling_params.top_k)
logits = logits.view(-1, logits_shape[-1]) # sampling_batch expects 2D logits
if isinstance(sampling_params.top_k, int):
idx_next, probs = top_k_sampling_batch(logits, sampling_params.top_k)
else:
idx_next, probs = greedy_search_sampling_batch(logits)
idx_next = idx_next.view(logits_shape[:-1])
return idx_next, probs
@ -213,10 +215,6 @@ class DemoEngine(ADEngine):
self, logits_last: torch.Tensor, sampling_params: SamplingParams
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns a sampled token per input sequence and associating probability."""
if sampling_params.top_k == 1: # greedy decoding
# idx_next is the index of the max logit for each sequence
idx_next = logits_last.argmax(dim=-1, keepdim=False)
return idx_next, logits_last.squeeze(-1)
# run sampling
return self._sample(logits_last, sampling_params)
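
A small illustration of the dispatch above. It is a sketch that assumes `_sample` is exposed as a classmethod and that `SamplingParams` accepts `top_k=None`, both consistent with the unit test added later in this commit:

```python
import torch
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.shim.demollm import DemoEngine

logits = torch.randn(4, 32000)  # (num_sequences, vocab_size)

# top_k=None -> greedy_search_sampling_batch, i.e. plain argmax per sequence
greedy_ids, _ = DemoEngine._sample(logits, SamplingParams(top_k=None))

# integer top_k -> top_k_sampling_batch, stochastic sampling over the k largest logits
sampled_ids, _ = DemoEngine._sample(logits, SamplingParams(top_k=5))

assert greedy_ids.shape == sampled_ids.shape == logits.shape[:-1]
```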

View File

@ -1,6 +1,7 @@
import importlib.metadata
import math
from collections import defaultdict
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from functools import partial
from typing import Any, Dict, List, Optional, Tuple
@ -8,6 +9,7 @@ import torch
import torch.export as te
import torch.nn as nn
import torch.nn.functional as F
from packaging import version
from torch import fx
from torch.utils._sympy.value_ranges import ValueRanges
@ -230,6 +232,63 @@ def _torch_modulelist_getitem_patch(self: nn.ModuleList, idx):
_torch_modulelist_getitem_patch.getitem_original = nn.ModuleList.__getitem__
def _torch_tensor_patch(data, **kwargs):
"""Patch torch.tensor to handle 0.0 on meta device.
``torch.tensor(0.0, device="meta")`` does not work and hence we are patching it to use
``torch.zeros((), device="meta")`` instead, which is equivalent.
"""
device = kwargs.get("device", None)
if data == 0.0 and device is not None and torch.device(device) == torch.device("meta"):
return torch.zeros((), **kwargs)
return _torch_tensor_patch.tensor_original(data, **kwargs)
_torch_tensor_patch.tensor_original = torch.tensor
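
For context, a standalone illustration of the behavior this patch works around, as described in the docstring above (the exact error depends on the PyTorch version):

```python
import torch

# Constructing a Python scalar directly on the meta device fails ...
try:
    torch.tensor(0.0, device="meta")
except Exception as err:  # exact exception type varies across PyTorch releases
    print(f"torch.tensor on meta device raised: {err}")

# ... while an empty-shape zeros() call yields the equivalent 0-dim meta tensor.
scalar = torch.zeros((), device="meta")
print(scalar.shape, scalar.device)  # torch.Size([]) meta
```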
def _transformers_version() -> str:
"""Get the version of transformers."""
return version.parse(importlib.metadata.version("transformers")).base_version
# TODO (@lucaslie): https://github.com/NVIDIA/TensorRT-LLM/issues/5728
# not great that this patch is here, but it's the least invasive change until we make headway on the
# above issue.
@contextmanager
def _transformers_sdpa_mask_patch():
"""Patch transformers.masking_utils.sdpa_mask to be export-compatible."""
# this patch is only needed+compatible for transformers >= 4.53.0
if version.parse(_transformers_version()) < version.parse("4.53.0"):
yield # Just yield without doing anything (like nullcontext)
return
# imports only after version check
from transformers import masking_utils
from transformers.integrations.executorch import sdpa_mask_without_vmap
# recall original implementation
sdpa_mask_original = masking_utils.sdpa_mask
# patch function and mask attention interface
masking_utils.sdpa_mask = sdpa_mask_without_vmap
if "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._local_mapping:
sdpa_local_original = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._local_mapping["sdpa"]
else:
sdpa_local_original = None
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = sdpa_mask_without_vmap
try:
yield
finally:
# revert patches
masking_utils.sdpa_mask = sdpa_mask_original
if sdpa_local_original is None:
del masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
else:
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = sdpa_local_original
def add_missing_load_hooks(gm: fx.GraphModule, model: nn.Module) -> fx.GraphModule:
"""Adds back the state dict load hooks stripped away during export."""
hooks = {
@ -356,24 +415,29 @@ def torch_export_to_gm(
torch.autocast = lambda *args, **kwargs: nullcontext()
torch.nn.attention.sdpa_kernel = lambda *args, **kwargs: nullcontext()
with lift_to_meta(model) as state_dict:
# clean up args, kwargs and move to correct device
args, kwargs = tree_to((args, kwargs or {}), device="meta")
# patch torch.tensor to handle 0.0 on meta device
torch.tensor = _torch_tensor_patch
# NOTE: we always export in non-strict mode for now as it relaxes some
# assumptions around tracing. Strict mode uses torchdynamo (symbolic bytecode analysis),
# which can be brittle since it relies on the exact bytecode representation of the model.
# see here as well: https://pytorch.org/docs/stable/export.html#non-strict-export
export_kwargs["strict"] = False
# run export with sdpa masking patch and lifted to meta
with _transformers_sdpa_mask_patch():
with lift_to_meta(model) as state_dict:
# clean up args, kwargs and move to correct device
args, kwargs = tree_to((args, kwargs or {}), device="meta")
# run export and extract graph module
egm: fx.GraphModule = torch_export(model, args, kwargs, **export_kwargs).module()
# NOTE: we always export in non-strict mode for now as it relaxes some
# assumptions around tracing. Strict mode uses torchdynamo (symbolic bytecode analysis),
# which can be brittle since it relies on the exact bytecode representation of the model
# see here as well: https://pytorch.org/docs/stable/export.html#non-strict-export
export_kwargs["strict"] = False
# load state_dict into egm
# NOTE: export might have removed unused params/buffers (hence we allow unexpected keys)
load_buffers_and_params(
egm, state_dict, strict_missing=True, strict_unexpected=False, clone=clone
)
# run export and extract graph module
egm: fx.GraphModule = torch_export(model, args, kwargs, **export_kwargs).module()
# load state_dict into egm
# NOTE: export might have removed unused params/buffers (hence we allow unexpected keys)
load_buffers_and_params(
egm, state_dict, strict_missing=True, strict_unexpected=False, clone=clone
)
# revert sdpa back to original
F.scaled_dot_product_attention = sdpa_original
@ -391,6 +455,9 @@ def torch_export_to_gm(
torch.autocast = autocast_original
torch.nn.attention.sdpa_kernel = sdpa_kernel_original
# revert torch.tensor patch
torch.tensor = _torch_tensor_patch.tensor_original
# Export strips away all methods not traced during forward. The model could have
# load hooks that contain logic for correct state_dict loading. We need to add those
# hooks back to the exported graph module.
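
Putting the pieces together, here is a minimal usage sketch of the export helper. The import path and keyword names are assumptions inferred from this diff, not copied from the repository:

```python
import torch
import torch.nn as nn

# NOTE: module path is an assumption; only the function name appears in the diff.
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm


class TinyMLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


model = TinyMLP().eval()
x = torch.randn(2, 16)

# Export to an fx.GraphModule: weights are lifted to the meta device during export and
# loaded back afterwards, torch.export runs in non-strict mode, and the transformers
# sdpa-mask patch is applied around the export call.
gm = torch_export_to_gm(model, args=(x,), kwargs=None, clone=False)
print(gm.graph)
```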

View File

@ -19,6 +19,7 @@ from tensorrt_llm.bench.benchmark.utils.general import (
# isort: on
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.bench.benchmark.utils.general import generate_warmup_dataset
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
@ -370,6 +371,12 @@ def throughput_command(
"Ignore extended_runtime_perf_knob_config for pytorch backend."
)
llm = PyTorchLLM(**kwargs)
elif runtime_config.backend == "_autodeploy":
if kwargs.pop("extended_runtime_perf_knob_config", None):
logger.warning(
"Ignore extended_runtime_perf_knob_config for _autodeploy backend."
)
llm = AutoDeployLLM(**kwargs)
else:
llm = LLM(**kwargs)

View File

@ -9,13 +9,11 @@ from pydantic import (BaseModel, Field, PositiveFloat, field_validator,
model_validator)
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.auto_deploy import LlmArgs as _AutoDeployLlmArgs
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import (BatchingType, CapacitySchedulerPolicy,
ContextChunkingPolicy, DynamicBatchConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig,
SchedulerConfig)
from tensorrt_llm.llmapi.llm_args import TorchCompileConfig
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
@ -111,11 +109,11 @@ class PerformanceOptions:
def get_pytorch_perf_config(self) -> PyTorchConfig:
return self.pytorch_config
def get_autodeploy_perf_config(self) -> _AutoDeployLlmArgs:
ad_config = _AutoDeployLlmArgs(**self.pytorch_config)
ad_config.attn_backend = "flashinfer"
ad_config.torch_compile_config = TorchCompileConfig()
ad_config.skip_loading_weights = True
def get_autodeploy_perf_config(self) -> Dict:
AutoDeployPerfConfig = dict
ad_config = AutoDeployPerfConfig()
ad_config["kv_cache_dtype"] = "auto"
ad_config["attn_backend"] = "flashinfer"
return ad_config

View File

@ -357,6 +357,21 @@ _SMALL_MODEL_CONFIGS = {
"model_kwargs": {
"num_hidden_layers": 2,
"intermediate_size": 256,
"hidden_size": 64,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"num_local_experts": 2,
},
},
"Qwen/Qwen3-30B-A3B": {
"model": _hf_model_dir_or_hub_id(
f"{llm_models_root()}/Qwen3/Qwen3-30B-A3B",
"Qwen/Qwen3-30B-A3B",
),
"model_kwargs": {
"num_hidden_layers": 2,
"intermediate_size": 256,
"hidden_size": 64,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"num_local_experts": 2,

View File

@ -4,6 +4,7 @@ import pytest
import torch
import torch.nn as nn
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
from tensorrt_llm._torch.auto_deploy.shim.demollm import DemoEngine
@ -76,3 +77,55 @@ def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: i
mock_input = None
original_logits = get_inference_model(mock_input)(input_ids[0].unsqueeze(0))[0]
assert torch.allclose(logits, original_logits, atol=1e-5), "Generated Token ID mismatch"
@pytest.mark.parametrize("attn_page_size", [0, 2])
def test_demo_engine_sampling(attn_page_size: int):
"""Test sampling logic specific to DemoEngine."""
seed = 0
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
device = torch.device("cuda")
max_seq_len = 64
max_batch_size = 8
sequence_info = SequenceInfo(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
page_size=attn_page_size,
)
sequence_info.to(device)
engine = DemoEngine(get_inference_model, sequence_info, device)
with torch.inference_mode():
input_ids = [torch.tensor([1, 2, 3, 4], device=device)]
sequence_info.reset()
sequence_info.nest_sequences(input_ids)
engine.cache_seq_interface.info.sync(sequence_info)
logits = engine._compute_logits()
logits = torch.stack(logits)
vocab_size = logits.size(-1)
sampling_params = SamplingParams(top_k=5, temperature=1.0)
token_ids, _ = engine._sample(logits, sampling_params)
expected_shape = logits.shape[:-1]
assert token_ids.shape == expected_shape, (
f"Unexpected shape for sampled token IDs, expected {expected_shape}, but got {token_ids.shape}"
)
assert torch.all((token_ids >= 0) & (token_ids < vocab_size)), (
"Sampled indices out of range"
)
# Test that top_k=1 (greedy) matches top_k=None (argmax fallback)
sampling_params_greedy = SamplingParams(top_k=1)
sampling_params_none = SamplingParams(top_k=None)
token_ids_1, _ = engine._sample(logits, sampling_params_greedy)
token_ids_2, _ = engine._sample(logits, sampling_params_none)
torch.testing.assert_close(token_ids_1, token_ids_2)

View File

@ -51,6 +51,11 @@ def _check_ad_config(experiment_config: ExperimentConfig, ad_config: LlmArgs):
attn_backend="triton",
compile_backend="torch-simple",
),
get_small_model_config(
"Qwen/Qwen3-30B-A3B",
attn_backend="triton",
compile_backend="torch-simple",
),
get_small_model_config(
"microsoft/Phi-3-mini-4k-instruct",
attn_backend="triton",

View File

@ -0,0 +1,82 @@
import subprocess
import tempfile
from pathlib import Path
import yaml
from _model_test_utils import _hf_model_dir_or_hub_id
from click.testing import CliRunner
from utils.cpp_paths import llm_root # noqa: F401
from utils.llm_data import llm_models_root
from tensorrt_llm.commands.bench import main
def prepare_dataset(root_dir: str, temp_dir: str, model_name: str):
_DATASET_NAME = "synthetic_128_128.txt"
dataset_path = Path(temp_dir, _DATASET_NAME)
dataset_tool = Path(root_dir, "benchmarks", "cpp", "prepare_dataset.py")
# Generate a small dataset to run a test.
command = [
"python3",
f"{dataset_tool}",
"--stdout",
"--tokenizer",
model_name,
"token-norm-dist",
"--input-mean",
"128",
"--output-mean",
"128",
"--input-stdev",
"0",
"--output-stdev",
"0",
"--num-requests",
"10",
]
print(f"Running command: {' '.join(command)}")
result = subprocess.run(command, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"Failed to prepare dataset: {result.stderr}")
# Grab the stdout and write it to a dataset file for passing to suite.
with open(dataset_path, "w") as dataset:
dataset.write(result.stdout)
return dataset_path
def run_benchmark(model_name: str, dataset_path: str, temp_dir: str):
runner = CliRunner()
args = [
"--model",
model_name,
"throughput",
"--backend",
"_autodeploy",
"--dataset",
dataset_path,
"--extra_llm_api_options",
f"{temp_dir}/model_kwargs.yaml",
]
runner.invoke(main, args, catch_exceptions=False)
def test_trtllm_bench(llm_root): # noqa: F811
model_name = _hf_model_dir_or_hub_id(
f"{llm_models_root()}/TinyLlama-1.1B-Chat-v1.0", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
)
with tempfile.TemporaryDirectory() as temp_dir:
with open(f"{temp_dir}/model_kwargs.yaml", "w") as f:
yaml.dump(
{
"model_kwargs": {"num_hidden_layers": 2},
"cuda_graph_batch_sizes": [1, 2],
"max_batch_size": 128,
},
f,
)
dataset_path = prepare_dataset(llm_root, temp_dir, model_name)
run_benchmark(model_name, dataset_path, temp_dir)