perf: [AutoDeploy] Enable AutoDeploy as a backend in trtllm-bench (#3041)

* Enable AutoDeploy as a backend in trtllm-bench

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* update how caches are resized

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* fix: files permission from 100755 to 100644

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* some comments

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* lint

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* lint

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* lint

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* lint

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* Fix function name

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* refactor

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* Remove spurious change

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* Add cursor generated doc strings

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* re-enable ad test

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* some perf cleanup

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* debug ci

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* ensure that overlap scheduler is enabled

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

* Reorder the tests

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

---------

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Authored by Suyog Gupta on 2025-03-26 14:33:14 -07:00; committed by GitHub
parent 3e035f2219
commit 047f2b234d
14 changed files with 182 additions and 53 deletions


@ -44,6 +44,7 @@ def build_llm_from_config(config: SimpleConfig) -> LLM:
model_kwargs=config.model_kwargs,
attn_backend=config.attn_backend,
skip_loading_weights=config.skip_loading_weights,
cuda_graph_max_batch_size=config.max_batch_size,
)
ad_logger.info(f"AutoDeploy Config: {ad_config}")


@ -14,7 +14,9 @@ from ..compiler import BackendCompiler, BackendRegistry, _flatten_args
class CompiledGraph(nn.Module):
def __init__(self, model: GraphModule, max_batch_size: int):
def __init__(
self, model: GraphModule, max_batch_size: int, cuda_graph_batch_sizes: List[int] = None
):
super().__init__()
self._in_spec: TreeSpec = model._in_spec
self._out_spec: TreeSpec = model._out_spec
@ -24,6 +26,11 @@ class CompiledGraph(nn.Module):
self._input_buffer: torch.Tensor = torch.empty(0, 1)
self._out_buffer_flat: List[torch.Tensor] = None
self._args_hash: Optional[Tuple[int, ...]] = None
self.cuda_graph_batch_sizes = (
cuda_graph_batch_sizes
if cuda_graph_batch_sizes is not None
else self._get_graph_batch_sizes(self.max_batch_size)
)
def _get_hash(self, flat_args: List[Any]) -> Tuple[int, ...]:
return tuple(hash(a) for a in flat_args)
@ -90,7 +97,7 @@ class CompiledGraph(nn.Module):
assert out_spec == self._out_spec, "Output spec mismatch."
# capture graph now for a range of batch sizes
for bs in self._get_graph_batch_sizes(self.max_batch_size):
for bs in self.cuda_graph_batch_sizes:
ad_logger.info(f"Capturing graph for batch size: {bs}")
# setup args, kwargs
@ -131,7 +138,12 @@ class CompiledGraph(nn.Module):
class TorchOptCompiler(BackendCompiler):
@torch.inference_mode()
def compile(self) -> CompiledGraph:
compiled_gm = CompiledGraph(self.gm, max_batch_size=self.max_batch_size)
cuda_graph_batch_sizes = self.compiler_kwargs.get("cuda_graph_batch_sizes", None)
compiled_gm = CompiledGraph(
self.gm,
max_batch_size=self.max_batch_size,
cuda_graph_batch_sizes=cuda_graph_batch_sizes,
)
# try capturing cudagraph
if self.args is not None or self.kwargs is not None:

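For reference, a minimal standalone sketch of how the capture list above is chosen: an explicit cuda_graph_batch_sizes list passed via compiler_kwargs wins, otherwise the sizes are derived from max_batch_size. The _get_graph_batch_sizes heuristic below (powers of two up to the maximum) is an assumption for illustration; the real helper is not shown in this diff and may differ.

from typing import List, Optional

def _get_graph_batch_sizes(max_batch_size: int) -> List[int]:
    # assumed heuristic: powers of two up to max_batch_size, always including the max
    sizes = [1]
    while sizes[-1] * 2 < max_batch_size:
        sizes.append(sizes[-1] * 2)
    if sizes[-1] != max_batch_size:
        sizes.append(max_batch_size)
    return sizes

def select_capture_sizes(max_batch_size: int,
                         cuda_graph_batch_sizes: Optional[List[int]] = None) -> List[int]:
    # explicit list from compiler_kwargs takes precedence, mirroring CompiledGraph.__init__
    return (cuda_graph_batch_sizes
            if cuda_graph_batch_sizes is not None
            else _get_graph_batch_sizes(max_batch_size))

print(select_capture_sizes(64))          # heuristic fallback: [1, 2, 4, 8, 16, 32, 64]
print(select_capture_sizes(64, [1, 8]))  # explicit override from the bench config: [1, 8]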

@ -55,12 +55,13 @@ class BackendCompiler(ABC):
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
dynamic_shapes=None,
compiler_kwargs: Optional[Dict[str, Any]] = None,
):
self.gm = gm
self.args = args
self.kwargs = kwargs or {}
self.dynamic_shapes = dynamic_shapes
self.compiler_kwargs = compiler_kwargs or {}
# identify max_batch_size
if self.dynamic_shapes is not None and 0 in self.dynamic_shapes[0]:
self.max_batch_size = self.dynamic_shapes[0][0].max
@ -79,13 +80,16 @@ def compile_and_capture(
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
dynamic_shapes=None,
compiler_kwargs: Optional[Dict[str, Any]] = None,
) -> nn.Module:
"""Compile or capture graph for single-token generation."""
elapsed_time = -time.time()
ad_logger.info("Fusion before compiling...")
ad_logger.info(f"Compiling for {backend} backend...")
compiler_cls = BackendRegistry.get(backend)
compiled_module = compiler_cls(gm, args, kwargs, dynamic_shapes).compile()
compiled_module = compiler_cls(gm, args, kwargs, dynamic_shapes, compiler_kwargs).compile()
elapsed_time += time.time()
ad_logger.info(f"Compile time with backend {backend}: {elapsed_time:.6f} seconds")


@ -6,6 +6,7 @@ object-oriented interface to the high-level runtime via the SequenceInfo dataclass. It
is also responsible for functionalizing information about the sequence and passing it on to the
various attention interfaces. The AttentionDescriptor is the main interface to the attention operator
and operates on a purely functional paradigm that is compatible with the torch custom op system.
"""
from abc import ABC, abstractmethod
@ -121,7 +122,9 @@ class SequenceInfo:
self.page_size = self.max_seq_len
if self.max_num_tokens < 1:
self.max_num_tokens = self.max_batch_size * self.max_seq_len
total_tokens = self.max_batch_size * self.page_size
# if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
# we use the provided max_num_tokens to calculate the number of pages
total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len)
self._num_pages = (total_tokens) // self.page_size + (total_tokens % self.page_size > 0)
self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
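A worked example of the page-count formula above, with illustrative values that are not taken from the diff: max_batch_size=8, max_seq_len=1024, page_size=64, and a provided max_num_tokens=4096 that caps the total.

max_batch_size, max_seq_len, page_size, max_num_tokens = 8, 1024, 64, 4096

total_tokens = min(max_num_tokens, max_batch_size * max_seq_len)        # 4096 (capped by max_num_tokens)
num_pages = total_tokens // page_size + (total_tokens % page_size > 0)  # ceil(4096 / 64) = 64
print(total_tokens, num_pages)                                          # 4096 64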
@ -191,6 +194,12 @@ class SequenceInfo:
def num_pages(self) -> int:
return self._num_pages
@num_pages.setter
def num_pages(self, value):
self._num_pages = value
# update the cache_loc tensor
self.cache_loc.resize_(value)
@property
def is_paged(self) -> bool:
return self.page_size < self.max_seq_len
@ -306,6 +315,19 @@ class SequenceInfo:
self.nest_sequences(input_ids)
self.input_ids = input_ids
def _set_max_num_tokens_sample(self) -> None:
"""Set an example sequence with max_num_tokens."""
self.reset()
seq_len = self.max_num_tokens // self.max_batch_size
input_ids = torch.ones(
self.max_batch_size,
seq_len,
dtype=torch.int,
device=self.device,
)
self.pages_per_seq.fill_(seq_len // self.page_size)
self.nest_sequences(input_ids)
def _set_generate_only_batch(self) -> None:
"""Set an example sequence for generate-only batch."""
self.reset()
@ -319,16 +341,14 @@ class SequenceInfo:
# set new sequence lengths
seq_lens = [len(ids) for ids in input_ids]
self.seq_len.zero_()
self.seq_len[: len(seq_lens)] = torch.tensor(seq_lens, device=self.device)
self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)
# set new input_ids as new tensor from flattened input_ids
ids_tnsr_list = [
lst.detach().to(self.device)
if isinstance(lst, torch.Tensor)
else torch.tensor(lst, dtype=torch.int, device=self.device)
lst.detach() if isinstance(lst, torch.Tensor) else torch.tensor(lst, dtype=torch.int)
for lst in input_ids
]
self.input_ids = torch.cat(ids_tnsr_list, dim=0)
self.input_ids = torch.cat(ids_tnsr_list, dim=0).to(self.device)
# set derivative properties
self._sequence_lengths = seq_lens
@ -362,10 +382,10 @@ class SequenceInfo:
cache_loc_flat = torch.tensor(
[p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
)
self.cache_loc[: len(cache_loc_flat)] = cache_loc_flat.to(self.device)
self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)
pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
self.pages_per_seq[: len(pages_per_seq)] = pages_per_seq.to(self.device)
self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
Constant = Union[int, float, str, None]
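A minimal standalone sketch of the copy pattern adopted above: values are staged in a host tensor and copied into a slice of a preallocated device tensor with copy_(..., non_blocking=True). Tensor names here are illustrative. Note that non_blocking only yields a truly asynchronous host-to-device copy when the source lives in pinned host memory; otherwise the copy falls back to being synchronous.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_len_buf = torch.empty(8, dtype=torch.int, device=device)  # preallocated, like self.seq_len
seq_lens = [5, 3, 7]

seq_len_buf.zero_()
seq_len_buf[: len(seq_lens)].copy_(torch.tensor(seq_lens, dtype=torch.int), non_blocking=True)
print(seq_len_buf[: len(seq_lens)])  # tensor([5, 3, 7], ...)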

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (12 changes; Executable file → Normal file)

@ -62,6 +62,7 @@ class _CacheManagerWithFakePool(KVCacheManager):
# TODO (lliebenwein): this is VERY hacky... Ideally, we want to compute the number of blocks
# just like in the original implementation. However, let's wait for the layer-wise attention
# implementation before over-optimizing the function here
ad_logger.info("Using fake cache manager with head_dim=0 and num pages:", self.num_blocks)
return self.num_blocks, 0
@ -86,6 +87,7 @@ class ADEngine(ModelEngine):
device: DeviceLikeType,
):
"""Build the ADEngine using the AutoDeployConfig that gets passed through from the LLM."""
# construct model factory
model_kwargs = {"max_position_embeddings": seq_info.max_seq_len, **ad_config.model_kwargs}
factory = ModelFactoryRegistry.get("hf")(
@ -95,15 +97,7 @@ class ADEngine(ModelEngine):
)
# construct inference optimizer
# TODO (lliebenwein): let's split up the compile backend to separately handle cuda graph
# and torch compile so we can follow the PyTorchConfig here and enable it separately.
if ad_config.use_cuda_graph or ad_config.torch_compile_enabled:
compile_backend = "torch-opt"
else:
compile_backend = "torch-simple"
build_and_optimize = InferenceOptimizer(
factory=factory, attn_backend=ad_config.attn_backend, compile_backend=compile_backend
)
build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config)
# construct engine
engine = cls(build_and_optimize, seq_info, device)


@ -45,6 +45,26 @@ class CachedSequenceInterface:
name: get_cache(self.info) for name, get_cache in self._cache_initializers.items()
}
def current_cache_size_bytes(self) -> int:
"""Calculate and return the total size of all caches in bytes."""
total_size = 0
for name, cache in self._caches.items():
# this hack is needed since _caches also contains global buffers such as freqs_cis.
if "cache" in name:
total_size += cache.element_size() * cache.numel()
return total_size
def resize_cache(self, new_num_pages: int):
"""Resize the cache to the new number of pages."""
# TODO: We should do some sanity check on the new number of pages.
self.info.num_pages = new_num_pages
for name, cache in self._caches.items():
# We assume the cache is a tensor of shape (num_pages, page_size, n_heads, head_dim)
if "cache" in name:
current_shape = cache.shape
new_shape = (new_num_pages, *current_shape[1:])
cache.resize_(new_shape)
GetInferenceModel = Callable[[CachedSequenceInterface], nn.Module]
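A standalone sketch of the bookkeeping above: total cache bytes via element_size() * numel() over entries whose name contains "cache", then growing the page dimension in place with resize_(). Shapes and names are illustrative, not the engine's real cache layout.

import torch

caches = {
    "k_cache_0": torch.zeros(4, 16, 8, 64, dtype=torch.float16),  # (num_pages, page_size, n_heads, head_dim)
    "freqs_cis": torch.zeros(1024, 64),                           # global buffer, excluded by the "cache" filter
}

total_bytes = sum(t.element_size() * t.numel() for name, t in caches.items() if "cache" in name)
print(total_bytes)  # 4 * 16 * 8 * 64 elements * 2 bytes = 65536

new_num_pages = 10
for name, t in caches.items():
    if "cache" in name:
        t.resize_(new_num_pages, *t.shape[1:])  # in-place; newly added pages are uninitialized memory
print(caches["k_cache_0"].shape)  # torch.Size([10, 16, 8, 64])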


@ -171,3 +171,43 @@ def insert_mha_with_kv_cache(
egm = canonicalize_graph(egm, shape_prop=False)
ad_logger.debug("After inserting MHA with KV cache: " + str(egm))
return egm
def resize_kv_cache(
egm: GraphModule, cm: CachedSequenceInterface, free_mem_ratio: float = 0.8
) -> None:
"""Inflate the kv cache to occupy the available GPU memory.
free_mem_ratio specifies the fraction of available memory to occupy.
"""
free_mem, total_mem = torch.cuda.mem_get_info()
ad_logger.info(f"Free memory: {free_mem}, Total memory: {total_mem}")
current_cache_size = cm.current_cache_size_bytes()
current_num_pages = cm.info.num_pages
ad_logger.info(
f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
)
try:
# Let's run a forward pass to get the memory usage
cm.info._set_max_num_tokens_sample()
free_mem_pre, _ = torch.cuda.mem_get_info()
ad_logger.info(f"Free memory before forward pass: {free_mem_pre}")
egm(*cm.args)
free_mem_post, _ = torch.cuda.mem_get_info()
ad_logger.info(f"Free memory after forward pass: {free_mem_post}")
memory_for_forward_pass = free_mem_pre - free_mem_post
ad_logger.info(f"Memory for forward pass: {memory_for_forward_pass}")
new_cache_size = free_mem_post * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
ad_logger.info(f"New cache size: {new_cache_size}, New num pages: {new_num_pages}")
cm.resize_cache(new_num_pages)
except Exception as e:
ad_logger.warning(
f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
)
# Free memory
torch.cuda.empty_cache()
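A worked example of the resize arithmetic above, with assumed illustrative numbers rather than measurements: 16 GiB free after the probe forward pass, 2 GiB of KV cache currently spread over 1024 pages, and free_mem_ratio = 0.8.

free_mem_post = 16 * 1024**3       # bytes free after the probe forward pass (assumed)
current_cache_size = 2 * 1024**3   # bytes currently held by the KV cache (assumed)
current_num_pages = 1024
free_mem_ratio = 0.8

bytes_per_page = current_cache_size // current_num_pages              # 2 MiB per page
new_cache_size = free_mem_post * free_mem_ratio + current_cache_size  # ~14.8 GiB
new_num_pages = int(new_cache_size // bytes_per_page)                 # 7577 pages
print(bytes_per_page, new_num_pages)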


@ -1,12 +1,15 @@
"""High-level entrypoint to transform a model into an efficient inference model."""
import gc
import torch
from torch.fx import GraphModule
from ..compile import compile_and_capture
from ..custom_ops.attention_interface import AttentionRegistry
from ..distributed import common as dist_ad
from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..shim.interface import AutoDeployConfig, CachedSequenceInterface
from ..utils.logger import ad_logger
from ._graph import move_to_device
from .export import torch_export_to_gm
@ -21,6 +24,7 @@ from .library import (
insert_mha_with_kv_cache,
match_moe_pattern,
quantize,
resize_kv_cache,
)
@ -29,12 +33,18 @@ class InferenceOptimizer:
self,
factory: ModelFactory,
*, # TODO (lliebenwein): temporary until we have a better config system
attn_backend: str,
compile_backend: str,
ad_config: AutoDeployConfig,
visualize: bool = False,
):
self.factory = factory
self.attn_backend = attn_backend
self.attn_backend = ad_config.attn_backend
# TODO (lliebenwein): let's split up the compile backend to separately handle cuda graph
# and torch compile so we can follow the PyTorchConfig here and enable it separately.
self.ad_config = ad_config
if ad_config.use_cuda_graph or ad_config.torch_compile_enabled:
compile_backend = "torch-opt"
else:
compile_backend = "torch-simple"
self.compile_backend = compile_backend
self.visualize = visualize
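A standalone sketch of the backend-selection rule introduced above, with AutoDeployConfig stood in by a small dataclass for illustration: CUDA graphs or torch.compile imply the "torch-opt" compile backend, otherwise "torch-simple" is used.

from dataclasses import dataclass

@dataclass
class AdConfigSketch:  # stand-in for AutoDeployConfig, illustration only
    use_cuda_graph: bool = False
    torch_compile_enabled: bool = False

def pick_compile_backend(cfg: AdConfigSketch) -> str:
    return "torch-opt" if (cfg.use_cuda_graph or cfg.torch_compile_enabled) else "torch-simple"

print(pick_compile_backend(AdConfigSketch(use_cuda_graph=True)))  # torch-opt
print(pick_compile_backend(AdConfigSketch()))                     # torch-simple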
@ -103,6 +113,7 @@ class InferenceOptimizer:
# initialize caches, load weights, and map to correct device
cm.initialize_caches()
# load weights
self.factory.load_or_random_init(egm, mmap=True, map_location=cm.device)
move_to_device(egm, cm.device)
@ -135,14 +146,27 @@ class InferenceOptimizer:
except ImportError:
pass
############################################################################################
# RESIZE CACHE
############################################################################################
# Free memory ratio is hardcoded to 0.8 for now to ensure we have enough memory for graph capture.
resize_kv_cache(egm, cm, free_mem_ratio=0.8)
############################################################################################
# COMPILE MODEL
############################################################################################
cm.info._set_generate_only_batch()
compiler_kwargs = {"cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes}
egm_compiled = compile_and_capture(
egm, self.compile_backend, args=cm.args, dynamic_shapes=cm.dynamic_shapes
egm,
self.compile_backend,
args=cm.args,
dynamic_shapes=cm.dynamic_shapes,
compiler_kwargs=compiler_kwargs,
)
cm.info.reset()
torch.cuda.empty_cache()
gc.collect()
return egm_compiled

tensorrt_llm/bench/benchmark/throughput.py (6 changes; Normal file → Executable file)

@ -41,7 +41,7 @@ from tensorrt_llm.sampling_params import SamplingParams
help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option("--backend",
type=click.Choice(["pytorch"]),
type=click.Choice(["pytorch", "autodeploy"]),
default=None,
help="Set to 'pytorch' for pytorch path. Default is cpp path.")
@optgroup.option(
@ -209,7 +209,7 @@ def throughput_command(
logger.info(metadata.get_summary_for_print())
# Engine configuration parsing
if backend and backend.lower() == "pytorch":
if backend and backend.lower() in ["pytorch", "autodeploy"]:
exec_settings = get_settings(params, metadata, bench_env.model,
bench_env.checkpoint_path)
kwargs_max_sql = max_seq_len or metadata.max_sequence_length
@ -262,6 +262,8 @@ def throughput_command(
try:
logger.info("Setting up throughput benchmark.")
kwargs = kwargs | runtime_config.get_llm_args()
kwargs['backend'] = backend
if runtime_config.backend == 'pytorch':
llm = PyTorchLLM(**kwargs)
else:

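A minimal self-contained sketch of the CLI behavior above, using a plain click.option instead of the optgroup decorator from the real file; the dispatch condition mirrors the throughput_command change, while everything else is a placeholder.

import click

@click.command()
@click.option("--backend", type=click.Choice(["pytorch", "autodeploy"]), default=None,
              help="LLM-API backend; omit to use the C++ engine path.")
def throughput(backend):
    if backend and backend.lower() in ["pytorch", "autodeploy"]:
        click.echo(f"Using LLM-API path with backend={backend}")
    else:
        click.echo("Using the C++ engine path")

if __name__ == "__main__":
    throughput()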
tensorrt_llm/bench/benchmark/utils/general.py (3 changes; Normal file → Executable file)

@ -144,6 +144,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
"enable_overlap_scheduler": True,
"kv_cache_dtype": kv_cache_dtype,
}
backend = params.get("backend", "pytorch")
return {
"sw_version": version("tensorrt_llm"),
@ -154,7 +155,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
"chunking": False,
},
"world_config": world_config,
"backend": "pytorch",
"backend": backend,
"decoding_config": {},
"performance_options": {
"cuda_graphs": True,

tensorrt_llm/bench/dataclasses/configuration.py (21 changes; Normal file → Executable file)

@ -9,6 +9,7 @@ from pydantic import (BaseModel, Field, PositiveFloat, field_validator,
model_validator)
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.auto_deploy.shim import AutoDeployConfig
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.bench.dataclasses.enums import IFBSchedulingPolicy
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
@ -29,7 +30,7 @@ class RuntimeConfig(BaseModel):
world_config: ExecutorWorldConfig
decoding_config: Optional[DecodingConfig] = None
performance_options: PerformanceOptions
backend: Literal["pytorch", None] = None
backend: Literal["pytorch", "autodeploy", None] = None
extra_llm_api_options: Optional[str] = None
def get_llm_args(self) -> Dict:
@ -68,9 +69,14 @@ class RuntimeConfig(BaseModel):
self.settings_config.max_num_tokens,
}
if self.backend == "pytorch":
llm_args["pytorch_backend_config"] = \
self.performance_options.get_pytorch_perf_config()
backend_config_map = {
"pytorch": self.performance_options.get_pytorch_perf_config,
"autodeploy": self.performance_options.get_autodeploy_perf_config
}
if self.backend in backend_config_map:
llm_args["pytorch_backend_config"] = backend_config_map[
self.backend]()
return update_llm_args_with_extra_options(llm_args,
self.extra_llm_api_options)
@ -99,6 +105,13 @@ class PerformanceOptions:
def get_pytorch_perf_config(self) -> PyTorchConfig:
return PyTorchConfig(**self.pytorch_config)
def get_autodeploy_perf_config(self) -> AutoDeployConfig:
ad_config = AutoDeployConfig(**self.pytorch_config)
ad_config.attn_backend = "FlashInfer"
ad_config.torch_compile_enabled = True
ad_config.skip_loading_weights = True
return ad_config
class DecodingConfig(BaseModel):
medusa_choices: Optional[List[List[int]]] = None
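A minimal sketch of the dispatch-map pattern above, with stand-in builders returning plain dicts (the real methods return PyTorchConfig and AutoDeployConfig instances).

def get_pytorch_perf_config():
    return {"kind": "pytorch"}

def get_autodeploy_perf_config():
    return {"kind": "autodeploy", "attn_backend": "FlashInfer", "torch_compile_enabled": True}

backend_config_map = {
    "pytorch": get_pytorch_perf_config,
    "autodeploy": get_autodeploy_perf_config,
}

llm_args = {}
backend = "autodeploy"
if backend in backend_config_map:
    llm_args["pytorch_backend_config"] = backend_config_map[backend]()
print(llm_args)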

tensorrt_llm/bench/dataclasses/reporting.py (4 changes; Normal file → Executable file)

@ -222,7 +222,7 @@ class ReportUtility:
}
# Engine/Backend details
if self.rt_cfg.backend != 'pytorch':
if self.rt_cfg.backend not in ('pytorch', 'autodeploy'):
config_path = self.rt_cfg.engine_dir / "config.json"
with open(config_path, "r") as config:
engine_config = json.load(config)
@ -392,7 +392,7 @@ class ReportUtility:
decoding = stats_dict.get("decoding_stats", None)
backend_info = ""
if self.rt_cfg.backend != "pytorch":
if self.rt_cfg.backend not in ('pytorch', 'autodeploy'):
config_path = self.rt_cfg.engine_dir / "config.json"
with open(config_path, "r") as config:
engine_config = json.load(config)


@ -5,7 +5,6 @@ from typing import Dict, Optional
import pytest
from _dist_test_utils import param_with_device_count
from _model_test_utils import _hf_model_dir_or_hub_id
from _torch_test_utils import fp8_compatible
from build_and_run_ad import main
from simple_config import SimpleConfig
from utils.llm_data import llm_models_root
@ -94,7 +93,7 @@ from utils.llm_data import llm_models_root
"attn_backend": "FlashInfer",
},
marks_extra=[
pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support"),
pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5178508"),
],
),
# full NVSmall (Llama-3.1-Nemotron-51B) with torch-opt backend + simple runtime
@ -124,7 +123,6 @@ from utils.llm_data import llm_models_root
],
)
def test_build_ad(world_size: Optional[int], config: Dict):
pytest.skip("https://nvbugs/5178508")
simple_config = SimpleConfig(**config)
simple_config.world_size = world_size
main(simple_config)


@ -12,20 +12,6 @@ from utils.llm_data import llm_models_root
@pytest.mark.parametrize(
"world_size, config",
[
# small llama3.1-8B model with world_size 0 (no processes are spawned)
(
0,
{
"model": _hf_model_dir_or_hub_id(
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
),
"runtime": "demollm",
"attn_backend": "TritonWithFlattenedInputs",
"compile_backend": "torch-simple",
"model_kwargs": {"num_hidden_layers": 2},
},
),
# small llama3.1-8B model with world_size 1 (processes are spawned)
(
1,
@ -68,6 +54,20 @@ from utils.llm_data import llm_models_root
"model_kwargs": {"num_hidden_layers": 2},
},
),
# small llama3.1-8B model with world_size 0 (no processes are spawned)
(
0,
{
"model": _hf_model_dir_or_hub_id(
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
),
"runtime": "demollm",
"attn_backend": "TritonWithFlattenedInputs",
"compile_backend": "torch-simple",
"model_kwargs": {"num_hidden_layers": 2},
},
),
],
)
def test_build_ad(world_size: Optional[int], config: Dict):