[#4403][refactor] Move fusion, kvcache, and compile to modular inference optimizer (#7057)

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Co-authored-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Fridah-nv 2025-08-21 10:30:36 -07:00 committed by GitHub
parent 344bc4575d
commit e18dacc931
18 changed files with 969 additions and 1190 deletions

View File

@ -19,6 +19,11 @@ transforms:
stage: post_export
cleanup_input_constraints:
stage: post_export
############################################################################################
# RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
############################################################################################
match_moe_pattern:
stage: pattern_matcher
match_repeat_kv:
stage: pattern_matcher
match_eager_attention:
@ -27,12 +32,13 @@ transforms:
stage: pattern_matcher
match_attention_layout:
stage: pattern_matcher
match_moe_pattern:
stage: pattern_matcher
match_rope_pattern:
stage: pattern_matcher
match_rope_layout:
stage: pattern_matcher
############################################################################################
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
############################################################################################
eliminate_redundant_transposes:
stage: pattern_matcher
# TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
@ -57,5 +63,44 @@ transforms:
sharding_transform_executor:
stage: sharding
run_shape_prop: true
############################################################################################
# MOVE MODEL AND LOAD WEIGHTS
############################################################################################
load_weights:
stage: weight_load
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
# fuse_moe:
# stage: post_load_fusion
# fuse_gemms:
# stage: post_load_fusion
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
fuse_collectives:
stage: post_load_fusion
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
# check if we can fuse rmsnorm
fuse_rmsnorm:
stage: post_load_fusion
backend: flashinfer
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
update_in_out_nodes:
stage: cache_init
insert_cached_attention:
stage: cache_init
insert_cached_mla_attention:
stage: cache_init
attn_backend: MultiHeadLatentAttention
initialize_cache:
stage: cache_init
resize_kv_cache:
stage: cache_init
############################################################################################
# COMPILE MODEL
############################################################################################
compile_model:
stage: compile
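
The YAML above orders transforms by stage: post_export, pattern_matcher, sharding, weight_load, post_load_fusion, cache_init, and compile. As a rough sketch of how a stage entry resolves to code (the names `my_noop`/`MyNoop` are hypothetical; only the BaseTransform/TransformRegistry interface appears in this diff, and the relative import assumes the file lives inside the transform library):

from torch.fx import GraphModule
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry

@TransformRegistry.register("my_noop")  # hypothetical key; the YAML would list it as `my_noop:` under some stage
class MyNoop(BaseTransform):
    """Toy transform: inspects the graph but rewrites nothing."""
    def _apply(self, gm: GraphModule, cm, factory, shared_config: SharedConfig):
        # a real transform rewrites gm.graph here and reports how many patterns it matched
        info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
        return gm, info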

View File

@ -54,6 +54,7 @@ class SharedConfig(BaseModel):
sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
local_rank: int = Field(default=0)
world_size: int = Field(default=1)
attn_backend: str = Field(default="flashinfer", description="The attention backend to use.")
class TransformConfig(BaseModel):
@ -285,7 +286,10 @@ class BaseTransform(ABC):
# update + store new meta data
history[t_name] = info
autodeploy_meta[self._history_key] = history
self._set_autodeploy_meta(gm, autodeploy_meta)
if isinstance(gm, GraphModule):
# After compilation, gm becomes a CapturedGraph, which carries no metadata.
self._set_autodeploy_meta(gm, autodeploy_meta)
# return the graph module
return gm
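
The new isinstance guard exists because later stages (notably compile_model) can hand back an object that is no longer a torch.fx.GraphModule, so blindly writing transform history into it would fail. A minimal, self-contained sketch of the same guard, with a hypothetical wrapper standing in for CapturedGraph:

import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace

class Tiny(nn.Module):
    def forward(self, x):
        return x + 1

class CapturedGraphLike:
    """Hypothetical stand-in for the post-compile wrapper; deliberately not a GraphModule."""
    def __init__(self, inner):
        self._inner = inner
    def __call__(self, *args, **kwargs):
        return self._inner(*args, **kwargs)

gm = symbolic_trace(Tiny())
for candidate in (gm, CapturedGraphLike(gm)):
    if isinstance(candidate, GraphModule):
        print("GraphModule: safe to attach transform history")
    else:
        print("compiled wrapper: skip the metadata bookkeeping")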

View File

@ -0,0 +1,204 @@
import operator
from typing import Tuple
import torch
from torch.fx import GraphModule
from ...distributed.trtllm import is_trtllm_op_available
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
# TODO: This is an overly simplified model that works well for vanilla Llama models.
# However, we eventually want to consider more sophisticated patterns such as
# * all_reduce(lin1(x) + lin2(x))
# * version above with fused GEMMs (i.e. with a split node)
# * all_reduce(pointwise_op(linear(x)))
# * ...
@TransformRegistry.register("fuse_collectives")
class FuseCollectives(BaseTransform):
"""
Fuses all_reduce ops with preceding (quantized) linear ops into a single fused node for improved performance.
"""
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
num_gemm_collective_fusions = 0
# lookup for fused ops
# TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
lookup = {
torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
}
# go through all nodes and find all_reduce nodes
for node in gm.graph.nodes:
if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
continue
# check if args are as expected
assert len(node.args) == 1 and not len(node.kwargs), (
"Unexpected args/kwargs for all_reduce"
)
# retrieve parent and check a few conditions on the parent node
parent_node = node.args[0]
if not is_op(parent_node, lookup.keys()):
continue
if len(parent_node.users) > 1:
continue
with gm.graph.inserting_before(node):
# insert fused node
fused_linear_collective_node = gm.graph.call_function(
lookup[get_op_overload_packet(parent_node.target)],
args=parent_node.args,
kwargs=parent_node.kwargs,
)
node.replace_all_uses_with(fused_linear_collective_node)
gm.graph.erase_node(node)
gm.graph.erase_node(parent_node)
num_gemm_collective_fusions += 1
info = TransformInfo(
skipped=False,
num_matches=num_gemm_collective_fusions,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
@TransformRegistry.register("fuse_allreduce_residual_rmsnorm")
class FuseAllreduceResidualRMSNorm(BaseTransform):
"""Essentially, this transformation fuses the following operators into one allreduce trtllm implementation.
* target pattern:
x = all_reduce(x)
y = x + residual
return rmsnorm(y), y
* replacement:
fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps)
"""
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
if not is_trtllm_op_available():
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
num_ar_r_rms_fusions = 0
def trace_and_fuse(allreduce_node, graph):
# Check if all_reduce is followed by addition
users = list(allreduce_node.users.keys())
if len(users) != 1:
return  # skip unless all_reduce has exactly one consumer
add_node = users[0]
# Traverse nodes for the RMSNorm pattern, which is composed of to_copy, pow, mean, add; see
# the Hugging Face LlamaRMSNorm implementation for reference
to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2)
# operand of pow and mul
pow_node = get_user_if_pattern_match(
to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2
)
mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1)
add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1)
rsqrt_node = get_user_if_pattern_match(
add_eps_node, [torch.ops.aten.add, operator.add], 1
)
mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1)
to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1)
mul_node_2 = get_user_if_pattern_match(
to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1
)
# check args of ops: pow(2) and mean(-1)
ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent
ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions
# Match found: Replace with fused operation
if (
to_copy_1
and pow_node
and mean_node
and add_eps_node
and rsqrt_node
and mul_node_1
and to_copy_2
and mul_node_2
and ARGS_MATCH
):
# Gather the inputs for the custom operation
tensor = allreduce_node.args[0]
# Identify the residual argument in the add operation
# One of the args in add_node.args is the output of all_reduce
# The same idea also applies to norm_weight
residual = (
add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1]
)
norm_weight = (
mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1]
)
eps = add_eps_node.args[1]
# Insert nodes
with graph.inserting_before(allreduce_node):
fused_node = graph.call_function(
torch.ops.dist.fused_allreduce_residual_rmsnorm,
args=(
tensor,
residual,
norm_weight,
eps,
),
)
# Extract outputs from the tuple returned by `fused_node`
final_output_node = gm.graph.create_node(
"call_function",
target=operator.getitem,
args=(fused_node, 0),
)
add_output_node = gm.graph.create_node(
"call_function",
target=operator.getitem,
args=(fused_node, 1),
)
# Replace all uses of rmsnorm_node with final_output_node
mul_node_2.replace_all_uses_with(final_output_node)
# Replace all uses of add_node with add_output_node
add_node.replace_all_uses_with(add_output_node)
nonlocal num_ar_r_rms_fusions
num_ar_r_rms_fusions += 1
# Traverse all nodes
for node in gm.graph.nodes:
if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
trace_and_fuse(allreduce_node=node, graph=gm.graph)
info = TransformInfo(
skipped=False, num_matches=num_ar_r_rms_fusions, is_clean=False, has_valid_shapes=False
)
return gm, info
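
FuseCollectives is a plain FX rewrite: find an all_reduce node, check that its single producer is one of the known linear ops, and splice in the fused op via inserting_before / replace_all_uses_with / erase_node. A hedged, dependency-free sketch of that same rewrite idiom on a toy add+relu pair, where `fused_add_relu` stands in for a fused kernel such as trtllm_dist_fused_linear_all_reduce:

import operator
import torch
from torch.fx import symbolic_trace

class AddRelu(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)

def fused_add_relu(x, y):
    # stand-in for the fused kernel inserted by the real transform
    return torch.relu(x + y)

gm = symbolic_trace(AddRelu())
for node in list(gm.graph.nodes):
    if node.op != "call_function" or node.target is not torch.relu:
        continue
    parent = node.args[0]
    if parent.op != "call_function" or parent.target is not operator.add:
        continue
    if len(parent.users) > 1:  # same single-consumer check as in FuseCollectives
        continue
    with gm.graph.inserting_before(node):
        fused = gm.graph.call_function(fused_add_relu, args=parent.args, kwargs=parent.kwargs)
    node.replace_all_uses_with(fused)
    gm.graph.erase_node(node)
    gm.graph.erase_node(parent)
gm.recompile()
print(gm.code)  # forward now calls fused_add_relu directly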

View File

@ -0,0 +1,65 @@
from typing import List, Literal, Optional, Tuple, Type
from pydantic import Field
from torch.fx import GraphModule
from ...compile import compile_and_capture
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
class CompileModelConfig(TransformConfig):
"""Configuration for the compile model transform."""
cuda_graph_batch_sizes: Optional[List[int]] = Field(
default=None, description="The batch sizes to use for CUDA graphs."
)
num_batched_inputs: int = Field(
default=2, description="The number of batched inputs to use for CUDA graphs."
)
compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
Field(description="The backend to use for compiling the model.")
)
@TransformRegistry.register("compile_model")
class CompileModel(BaseTransform):
"""A transform to compile the model."""
config: CompileModelConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return CompileModelConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
cm.info.set_generate_only_batch()
egm_compiled = compile_and_capture(
gm,
self.config.compile_backend,
args=cm.args,
dynamic_shapes=cm.dynamic_shapes,
compiler_kwargs={
"cuda_graph_batch_sizes": self.config.cuda_graph_batch_sizes,
"num_batched_inputs": self.config.num_batched_inputs,
},
)
cm.info.reset()
# store info object about the transform
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
return egm_compiled, info
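
CompileModelConfig is a pydantic model, so the compile backend is validated up front and the CUDA-graph options carry defaults. A self-contained mirror of the config for illustration only (the real class derives from TransformConfig inside the transform library):

from typing import List, Literal, Optional
from pydantic import BaseModel, Field, ValidationError

class CompileModelConfigSketch(BaseModel):
    """Illustrative mirror of CompileModelConfig; not the class used by the optimizer."""
    cuda_graph_batch_sizes: Optional[List[int]] = Field(default=None)
    num_batched_inputs: int = Field(default=2)
    compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"]

cfg = CompileModelConfigSketch(compile_backend="torch-cudagraph", cuda_graph_batch_sizes=[1, 2, 4, 8])
print(cfg.num_batched_inputs)  # 2 (default)
try:
    CompileModelConfigSketch()  # compile_backend has no default, so construction fails
except ValidationError:
    print("compile_backend is required")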

View File

@ -6,6 +6,8 @@ import torch
import torch.nn as nn
from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import (
@ -14,7 +16,7 @@ from ...utils.node_utils import (
is_linear_op,
)
from ...utils.quantization_utils import QuantizationImpl
from .._graph import canonicalize_graph
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]):
@ -116,30 +118,36 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
gm.delete_all_unused_submodules()
def fuse_gemms(gm: GraphModule) -> None:
ad_logger.info("GEMM fusion")
ad_logger.debug("Before GEMM fusion: " + str(gm))
# sort linear nodes by parent node
linear_nodes = defaultdict(list)
for node in gm.graph.nodes:
# TODO: we don't handle bias for now...
if is_linear_op(node, include_quantization=True) and node.args[2] is None:
linear_nodes[node.args[0]].append(node)
@TransformRegistry.register("fuse_gemms")
class FuseGemms(BaseTransform):
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
# sort linear nodes by parent node
linear_nodes = defaultdict(list)
for node in gm.graph.nodes:
# TODO: we don't handle bias for now...
if is_linear_op(node, include_quantization=True) and node.args[2] is None:
linear_nodes[node.args[0]].append(node)
# fuse linear nodes
idx = -1
with cuda_memory_tracker():
for parent_node, lin_children in linear_nodes.items():
if len(lin_children) < 2:
continue
# linear nodes to fuse
ad_logger.debug(
f"Found linear nodes to fuse: {lin_children} with parent node: {parent_node}"
)
_insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children)
# fuse linear nodes
idx = -1
num_matches = 0
with cuda_memory_tracker():
for parent_node, lin_children in linear_nodes.items():
if len(lin_children) < 2:
continue
# linear nodes to fuse
_insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children)
num_matches += 1
# clean up and return
canonicalize_graph(gm)
torch.cuda.empty_cache()
ad_logger.debug("After GEMM fusion: " + str(gm))
torch.cuda.empty_cache()
info = TransformInfo(
skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
)
return gm, info
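
FuseGemms stacks sibling linear weights that consume the same parent activation into one larger GEMM. A hedged numeric sketch of why the stacked form is equivalent (random stand-in weights, no bias, matching the `node.args[2] is None` restriction above):

import torch

x = torch.randn(4, 16)                                   # shared parent activation
w_a, w_b = torch.randn(32, 16), torch.randn(32, 16)      # two sibling linear weights

separate = torch.cat([x @ w_a.t(), x @ w_b.t()], dim=-1)
fused = x @ torch.cat([w_a, w_b], dim=0).t()             # one fused GEMM; split the output afterwards if needed
print(torch.allclose(separate, fused, atol=1e-5))        # True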

View File

@ -0,0 +1,299 @@
"""Graph transformation to automatically add kv cache into fused MHA op."""
import operator
from typing import Dict, Optional, Tuple, Type
import torch
from pydantic import Field
from torch.fx import Graph, GraphModule, Node
from ...custom_ops.attention_interface import AttentionRegistry
from ...distributed.common import all_gather_object, get_world_size
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...transformations._graph import add_graph_input
from ...utils.logger import ad_logger
from ...utils.node_utils import get_all_input_output_nodes, is_op
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
@TransformRegistry.register("update_in_out_nodes")
class UpdateInOutNodes(BaseTransform):
"""Modify the graph module by adding new input nodes.
The new input nodes correspond to the extra arguments needed for cached and flattened attention.
Args:
gm: The graph module to analyze and modify.
cm: Cached sequence interface containing extra argument information.
"""
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
# loop through nodes to get input, output, and get_attr nodes
input_nodes, output_nodes = get_all_input_output_nodes(gm.graph)
# we expect exactly two input nodes (input_ids, position_ids)
assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."
# NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
# Later on, we can revisit how to support returning hidden states.
assert len(output_nodes) == 1, "Expected exactly one output node!"
assert len(output_nodes[0].all_input_nodes) == 1, (
"Expected to only return final tensor output!"
)
# Activate and add extra argument nodes
new_args = cm.info.switch_to_cached_attn_inputs()
for name in new_args:
input_nodes.append(add_graph_input(gm, name))
info = TransformInfo(skipped=False, num_matches=1, is_clean=False, has_valid_shapes=False)
return gm, info
class InsertCachedAttentionConfig(TransformConfig):
"""Configuration for the insert cached attention transform."""
attn_backend: Optional[str] = Field(default=None, description="The attention backend to use.")
@TransformRegistry.register("insert_cached_attention")
class InsertCachedAttention(BaseTransform):
"""
A transform to insert cached attention into the graph module.
If attn_backend is not provided in the transform config, it is taken from the shared config.
"""
config: InsertCachedAttentionConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return InsertCachedAttentionConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
"""Replace uncached source attention node with corresponding cached attn node."""
attn_descriptor = AttentionRegistry.get(self.config.attn_backend)
cache_config = factory.get_cache_config()
# Get all attention nodes and their info objects
source_op = attn_descriptor.get_source_attention_op()
# pick up graph
graph: Graph = gm.graph
# look for relevant source attention nodes
source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)]
if not source_attn_nodes:
# If no source attention nodes are found for kv-cache insertion, return the current graph unchanged
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
# Sanity check
if cm.info.is_paged:
assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."
# retrieve input nodes
input_nodes, _ = get_all_input_output_nodes(gm.graph)
# insert metadata computation and extract each argument as a node
get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
with graph.inserting_before(input_nodes[-1].next):
ret_node = graph.call_function(
get_metadata,
args=(
*input_nodes,
cm.info.page_size,
),
)
metadata_nodes = [
graph.call_function(operator.getitem, args=(ret_node, idx))
for idx in range(num_metadata)
]
buffer_in_lookup: Dict[str, Node] = {}
# replace fused attention node with attention node that has kv cache
num_cached_attn_replacements = 0
for idx, attn_node in enumerate(source_attn_nodes):
# pick out GEMMs
qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()]
# setup + store cache initializers and caches as input nodes
cache_in_nodes = []
for k, get_cache in attn_descriptor.get_cache_initializers(
attn_node, cache_config
).items():
k_indexed = f"{k}_{idx}"
cm.add_cache(k_indexed, get_cache)
cache_in_nodes.append(add_graph_input(gm, k_indexed))
# setup + store global buffer initializers and buffers as input nodes
# NOTE: we have to check against existing keys to make sure nothing is registered twice...
buffer_in_nodes = []
for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
if k not in buffer_in_lookup:
cm.add_cache(k, get_buffer)
buffer_in_lookup[k] = add_graph_input(gm, k)
buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op
# retrieve constants for attention_op
constants = attn_descriptor.get_constants(attn_node)
# insert cached attention replacement op
with graph.inserting_before(attn_node):
cached_attn_node = graph.call_function(
attn_descriptor.get_cached_attention_op(),
args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants),
)
attn_node.replace_all_uses_with(cached_attn_node)
graph.erase_node(attn_node)
num_cached_attn_replacements += 1
info = TransformInfo(
skipped=False,
num_matches=num_cached_attn_replacements,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
@TransformRegistry.register("insert_cached_mla_attention")
class InsertCachedMLAAttention(InsertCachedAttention):
"""
A transform to insert cached MLA attention into the graph module.
This class is identical to InsertCachedAttention and inherits all its behavior.
"""
pass
class ResizeKVCacheConfig(TransformConfig):
"""Configuration for the resize kv cache transform."""
free_mem_ratio: float = Field(
description="The fraction of available memory to occupy.", default=0.8
)
@TransformRegistry.register("resize_kv_cache")
class ResizeKVCache(BaseTransform):
"""Inflate the kv cache to occupy the available GPU memory.
free_mem_ratio specifies the fraction of available memory to occupy.
"""
config: ResizeKVCacheConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return ResizeKVCacheConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
free_mem_ratio = self.config.free_mem_ratio
def _get_mem_info_in_mb():
free_mem, total_mem = torch.cuda.mem_get_info()
return free_mem // 1024**2, total_mem // 1024**2
free_mem, total_mem = _get_mem_info_in_mb()
ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
current_cache_size = cm.current_cache_size_bytes()
current_num_pages = cm.info.num_pages
ad_logger.info(
f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
)
if free_mem_ratio == 0.0:
ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}")
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
try:
# Let's run a forward pass to get the memory usage
cm.info._set_max_num_tokens_sample()
free_mem_pre, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
gm(*cm.args)
free_mem_post, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")
memory_for_forward_pass = free_mem_pre - free_mem_post
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
# Need to sync all the GPUs
gathered_num_pages = [None] * get_world_size()
all_gather_object(gathered_num_pages, new_num_pages)
new_num_pages = min(gathered_num_pages)
ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}")
cm.resize_cache(new_num_pages)
except Exception as e:
ad_logger.warning(
f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
)
# Free memory
torch.cuda.empty_cache()
info = TransformInfo(
skipped=False,
num_matches=0,
is_clean=True,
has_valid_shapes=True,
)
return gm, info
@TransformRegistry.register("initialize_cache")
class InitializeCache(BaseTransform):
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
cm.initialize_caches()
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
return gm, info
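
ResizeKVCache measures free memory after a max-token probe forward pass, then inflates the page count so the cache claims free_mem_ratio of what remains; the minimum of the all-gathered page counts is what every rank finally adopts. A hedged arithmetic sketch of the formula in `_apply` (all numbers illustrative):

free_mem_ratio = 0.8
free_mem_post_mb = 40_000                        # free memory after the probe forward pass (MB)
current_cache_size = 2 * 1024**3                 # bytes currently held by the kv cache
current_num_pages = 512

bytes_per_page = current_cache_size // current_num_pages
new_cache_size = free_mem_post_mb * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // bytes_per_page)
print(new_num_pages)                             # 8512 pages instead of 512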

View File

@ -0,0 +1,148 @@
"""Graph transform to optimize RMSNorm execution using FlashInfer."""
from functools import partial
from typing import Tuple, Type
import torch
from pydantic import Field
from torch.fx import GraphModule
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
_BACKEND_OPS = {
"flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
"triton": torch.ops.auto_deploy.triton_rms_norm,
"torch": torch.ops.auto_deploy.torch_rmsnorm,
}
def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
"""Implements the RMSNorm pattern for pattern matching.
Args:
data: Input tensor to normalize.
weight: Scaling weights for the normalized output.
eps: Small constant for numerical stability.
Returns:
Normalized and scaled tensor.
"""
input_dtype = data.dtype
data = data.to(torch.float32)
variance = data.pow(2).mean(-1, keepdim=True)
data = data * torch.rsqrt(variance + eps)
return weight * data.to(input_dtype)
def _rms_norm_replacement(
data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
) -> torch.Tensor:
"""Backend-specific rms_norm implementation.
Args:
data: Input tensor to normalize.
weight: Scaling weights for the normalized output.
eps: Small constant for numerical stability.
backend: Backend to use for RMSNorm computation ("flashinfer", "triton", or "torch").
Returns:
Normalized and scaled tensor using the specified backend implementation.
"""
assert backend.lower() in _BACKEND_OPS, (
f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
)
return _BACKEND_OPS[backend.lower()](data, weight, eps)
class FuseRMSNormConfig(TransformConfig):
"""Configuration for the RMSNorm fusion transform."""
backend: str = Field(
default="flashinfer",
description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
)
@TransformRegistry.register("fuse_rmsnorm")
class FuseRMSNorm(BaseTransform):
"""Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
This function sets up pattern matching to identify RMSNorm operations in the graph
and replaces them with optimized implementations. It uses dummy tensors to register
the pattern matching rules.
Args:
gm: Input graph module to transform.
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
Returns:
Transformed graph module with optimized RMSNorm operations.
"""
config: FuseRMSNormConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return FuseRMSNormConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
if self.config.backend.lower() not in _BACKEND_OPS:
raise ValueError(
f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}"
)
graph = gm.graph
patterns = ADPatternMatcherPass()
# Create dummy tensors for pattern matching
bs = 2
hidden_size = 512
def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
return [
torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
eps,
]
# Define configurations for different data types
configs = [
(torch.bfloat16, torch.bfloat16),
(torch.float16, torch.float16),
(torch.float32, torch.float32),
]
# Register patterns for each configuration
for input_dtype, weight_dtype in configs:
register_ad_pattern(
search_fn=_rms_norm_pattern,
replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
patterns=patterns,
dummy_args=dummy_args(input_dtype, weight_dtype),
op_ignore_types={},
scalar_workaround={"eps": 1e-6},
)
cnt = patterns.apply(graph)
info = TransformInfo(skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=False)
return gm, info
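
The registered search pattern is the standard RMSNorm computation written in eager PyTorch; a short numeric check of the same math (no flashinfer or triton required) showing that with unit weights every row comes out with unit root-mean-square:

import torch

def rms_norm_reference(data: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # identical math to _rms_norm_pattern above
    input_dtype = data.dtype
    data = data.to(torch.float32)
    variance = data.pow(2).mean(-1, keepdim=True)
    return weight * (data * torch.rsqrt(variance + eps)).to(input_dtype)

x = torch.randn(2, 512)
w = torch.ones(512)
out = rms_norm_reference(x, w)
print(out.pow(2).mean(-1).sqrt())  # ~tensor([1.0000, 1.0000])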

View File

@ -1,11 +1,5 @@
"""A library of transformation passes."""
from .collectives import *
from .fused_moe import *
from .fusion import *
from .kvcache import *
from .rms_norm import *
try:
from .visualization import visualize_namespace
except ImportError:

View File

@ -1,167 +0,0 @@
import operator
import torch
from torch.fx import GraphModule
from ...distributed.trtllm import is_trtllm_op_available
from ...utils.logger import ad_logger
from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
from .._graph import canonicalize_graph
# TODO: This is an overly simplified model that works well for vanilla Llama models.
# However, we eventually want to consider more sophisticated patterns such as
# * all_reduce(lin1(x) + lin2(x))
# * version above with fused GEMMs (i.e. with a split node)
# * all_reduce(pointwise_op(linear(x)))
# * ...
def fuse_collectives(gm: GraphModule) -> None:
num_gemm_collective_fusions = 0
ad_logger.debug("Before GEMM+Collective fusion: " + str(gm))
# lookup for fused ops
# TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
lookup = {
torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
}
# go through all nodes and find all_reduce nodes
for node in gm.graph.nodes:
if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
continue
# check if args are as expected
assert len(node.args) == 1 and not len(node.kwargs), "Unexpected args/kwargs for all_reduce"
# retrieve parent and check a few conditions on the parent node
parent_node = node.args[0]
if not is_op(parent_node, lookup.keys()):
continue
if len(parent_node.users) > 1:
continue
with gm.graph.inserting_before(node):
# insert fused node
fused_linear_collective_node = gm.graph.call_function(
lookup[get_op_overload_packet(parent_node.target)],
args=parent_node.args,
kwargs=parent_node.kwargs,
)
node.replace_all_uses_with(fused_linear_collective_node)
gm.graph.erase_node(node)
gm.graph.erase_node(parent_node)
num_gemm_collective_fusions += 1
canonicalize_graph(gm)
ad_logger.info(f"Found {num_gemm_collective_fusions} GEMM+Collective fusions")
ad_logger.debug("After GEMM+Collective fusion: " + str(gm))
def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> None:
"""Essentially, this function fuses the following operators into one allreduce trtllm implementation.
* target pattern:
x = all_reduce(x)
y = x + residual
return rmsnorm(y), y
* replacement:
fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps)
"""
if not is_trtllm_op_available():
return
num_ar_r_rms_fusions = 0
ad_logger.debug("Before allreduce+residual+rmsnorm fusion: " + str(gm))
def trace_and_fuse(allreduce_node, graph):
# Check if all_reduce is followed by addition
users = list(allreduce_node.users.keys())
if len(users) != 1:
return # Skip if all_reduce has more than one consumer
add_node = users[0]
# Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer
# the Huggingface LlamaRMSNorm implementation as example for more details
to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2)
# operand of pow and mul
pow_node = get_user_if_pattern_match(
to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2
)
mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1)
add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1)
rsqrt_node = get_user_if_pattern_match(add_eps_node, [torch.ops.aten.add, operator.add], 1)
mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1)
to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1)
mul_node_2 = get_user_if_pattern_match(
to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1
)
# check args of ops: pow(2) and mean(-1)
ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent
ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions
# Match found: Replace with fused operation
if (
to_copy_1
and pow_node
and mean_node
and add_eps_node
and rsqrt_node
and mul_node_1
and to_copy_2
and mul_node_2
and ARGS_MATCH
):
# Gather the inputs for the custom operation
tensor = allreduce_node.args[0]
# Identify the residual argument in the add operation
# One of the args in add_node.args is the output of all_reduce
# The same idea also applies to norm_weight
residual = add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1]
norm_weight = (
mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1]
)
eps = add_eps_node.args[1]
# Insert nodes
with graph.inserting_before(allreduce_node):
fused_node = graph.call_function(
torch.ops.dist.fused_allreduce_residual_rmsnorm,
args=(
tensor,
residual,
norm_weight,
eps,
),
)
# Extract outputs from the tuple returned by `fused_node`
final_output_node = gm.graph.create_node(
"call_function",
target=operator.getitem,
args=(fused_node, 0),
)
add_output_node = gm.graph.create_node(
"call_function",
target=operator.getitem,
args=(fused_node, 1),
)
# Replace all uses of rmsnorm_node with final_output_node
mul_node_2.replace_all_uses_with(final_output_node)
# Replace all uses of add_node with add_output_node
add_node.replace_all_uses_with(add_output_node)
nonlocal num_ar_r_rms_fusions
num_ar_r_rms_fusions += 1
# Traverse all nodes
for node in gm.graph.nodes:
if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
trace_and_fuse(allreduce_node=node, graph=gm.graph)
canonicalize_graph(gm)
ad_logger.info(f"Found {num_ar_r_rms_fusions} allreduce+residual+rmsnorm fusions")
ad_logger.debug("After allreduce+residual+rmsnorm fusion: " + str(gm))

View File

@ -1,511 +0,0 @@
from collections import defaultdict
from typing import Optional
import torch
from torch.fx import GraphModule, Node
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op
from ...utils.quantization_utils import get_scales_and_type_from_node
from .._graph import canonicalize_graph
def match_moe_pattern(gm: GraphModule) -> None:
graph = gm.graph
ad_logger.debug("Before MoE Pattern Matching: " + str(gm))
# Preprocessing: Identify boundary nodes (e.g. residual connections) in the graph.
boundary_nodes = identify_regions_between_residuals(gm)
num_moe_patterns = 0
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
# Step 1: Identify Expert Compute pattern
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) = (
_match_expert_compute_pattern(start_boundary, end_boundary)
)
if not expert_weights:
continue
# TODO: naming convention to verify the order of the weight nodes
# Step 2: Trace upwards to locate normalize_routing_weight and selected_experts:
arg1_list, arg2_list = _extract_index_branches_from_expert_outputs(pattern_output_nodes)
normalized_routing_weights = _find_lowest_common_ancessor(arg1_list)
if not normalized_routing_weights:
continue
common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
if not common_ancessor2:
continue
selected_experts = bfs(
common_ancessor2,
lambda node: is_op(node, torch.ops.aten.one_hot),
attr_next="all_input_nodes",
boundary=start_boundary,
).args[0]
if not selected_experts:
continue
# Step 3: Trace upwards to find input node:
hidden_states = _find_lowest_common_ancessor(pattern_input_nodes)
if not hidden_states:
continue
# Step 4: Find output node with the combine pattern
final_hidden_state_node = _find_final_hidden_state_node(pattern_output_nodes, end_boundary)
if final_hidden_state_node is None:
continue
# Step 5: Insert the MoE op into the graph.
ad_logger.debug(
f"Found MoE Pattern: between boundary {start_boundary} and {end_boundary}.\n"
f"Input hidden states node: {hidden_states}, "
f"selected_experts node: {selected_experts}, "
f"routing_weights node: {normalized_routing_weights}, "
f"expert weights: {expert_weights}, weight type: {weight_type}"
)
with graph.inserting_before(final_hidden_state_node):
w1_list = expert_weights["w1"]
w2_list = expert_weights["w2"]
w3_list = expert_weights["w3"]
if weight_type == "fp8":
fused_moe_node = graph.call_function(
torch.ops.auto_deploy.torch_quant_fp8_moe,
args=(
hidden_states,
selected_experts,
normalized_routing_weights,
w1_list,
w2_list,
w3_list,
expert_scales["w1_input_scale"],
expert_scales["w2_input_scale"],
expert_scales["w3_input_scale"],
expert_scales["w1_weight_scale"],
expert_scales["w2_weight_scale"],
expert_scales["w3_weight_scale"],
),
)
elif weight_type == "fp4":
fused_moe_node = graph.call_function(
torch.ops.auto_deploy.torch_quant_fp4_moe,
args=(
hidden_states,
selected_experts,
normalized_routing_weights,
w1_list,
w2_list,
w3_list,
expert_scales["w1_input_scale"],
expert_scales["w2_input_scale"],
expert_scales["w3_input_scale"],
expert_scales["w1_weight_scale"],
expert_scales["w2_weight_scale"],
expert_scales["w3_weight_scale"],
expert_scales["w1_alpha"],
expert_scales["w2_alpha"],
expert_scales["w3_alpha"],
),
)
else:
fused_moe_node = graph.call_function(
torch.ops.auto_deploy.torch_moe,
args=(
hidden_states,
selected_experts,
normalized_routing_weights,
w1_list,
w2_list,
w3_list,
),
)
final_hidden_state_node.replace_all_uses_with(fused_moe_node)
graph.erase_node(final_hidden_state_node)
while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
gm.graph.eliminate_dead_code()
num_moe_patterns += 1
canonicalize_graph(gm)
ad_logger.info(f"Found {num_moe_patterns} MoE Patterns")
ad_logger.debug("After MoE Pattern Matching: " + str(gm))
def fuse_moe(gm: torch.fx.GraphModule) -> None:
"""
Scan the FX graph and replace all calls to torch.ops.auto_deploy.torch_moe with
torch.ops.auto_deploy.trtllm_moe_fused.
"""
ad_logger.debug("Before MoE fusion: " + str(gm))
with cuda_memory_tracker():
fused_key_counter = _insert_fused_moe_ops(gm)
if fused_key_counter:
canonicalize_graph(gm)
ad_logger.info(f"Found {fused_key_counter} MoE fusions")
ad_logger.debug("After MoE fusion: " + str(gm))
def _insert_fused_moe_ops(gm: GraphModule) -> int:
fused_key_counter = 0
graph = gm.graph
for node in list(graph.nodes):
if not is_op(node, torch.ops.auto_deploy.torch_moe):
continue
ad_logger.debug(f"Found MoE op to fuse: {node} with args: {node.args}")
hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = node.args
fused_w3_w1_experts = torch.stack(
[
torch.cat(
[gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], dim=-2
)
for w1_node, w3_node in zip(w1_list, w3_list)
],
dim=0,
)
fused_w2_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0)
new_key_w3_w1 = f"fused_moe_w3_w1_stacked_{fused_key_counter}"
new_key_w2 = f"fused_moe_w2_stacked_{fused_key_counter}"
fused_key_counter += 1
param_w3_w1 = torch.nn.Parameter(fused_w3_w1_experts)
param_w2 = torch.nn.Parameter(fused_w2_experts)
gm.register_parameter(new_key_w3_w1, param_w3_w1)
gm.register_parameter(new_key_w2, param_w2)
with graph.inserting_before(node):
new_node = graph.call_function(
# TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models
torch.ops.auto_deploy.trtllm_moe_fused,
args=(
hidden_states,
selected_experts,
routing_weights,
graph.get_attr(new_key_w3_w1),
graph.get_attr(new_key_w2),
),
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
return fused_key_counter
def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
"""
Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following
each node's primary branch (recursively following the first Node argument).
It first finds the LCA of the first two nodes and then
iteratively computes the LCA of the result with the next node, and so on.
Returns:
The common ancestor Node if found, otherwise None.
"""
if not nodes:
return None
def get_parent(node: Node) -> Optional[Node]:
"""Return the first Node-valued argument for a given node, or None if not found."""
for arg in node.args:
if isinstance(arg, Node):
return arg
return None
def get_depth(node: Node) -> int:
"""
Recursively compute the depth of the node by following its primary branch.
Depth is defined as the number of steps to reach a node with no parent.
"""
parent = get_parent(node)
if parent is None:
return 0
return 1 + get_depth(parent)
def lca_two(a: Node, b: Node) -> Optional[Node]:
"""
Find the lowest common ancestor of two nodes by first equalizing their depth
and then moving upward until a common node is found.
"""
depth_a = get_depth(a)
depth_b = get_depth(b)
# Equalize depths
while depth_a > depth_b:
a = get_parent(a)
depth_a -= 1
while depth_b > depth_a:
b = get_parent(b)
depth_b -= 1
# Walk upward in lockstep
while a is not None and b is not None:
if a is b:
return a
a = get_parent(a)
b = get_parent(b)
return None
# Iteratively compute the LCA across all nodes.
common = nodes[0]
for node in nodes[1:]:
common = lca_two(common, node)
if common is None:
return None
return common
def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]:
"""
Given a linear op node, extract the input tensor node, weight tensor,
any quantization scales (if the op is quantized), and return a weight type.
For a torch.ops.auto_deploy.torch_linear_simple.default op:
- Returns (input_node, weight, None, "simple")
For a torch.ops.auto_deploy.torch_quant_fp8_linear op:
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale}, "fp8")
For a torch.ops.auto_deploy.torch_quant_fp4_linear op:
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale, "alpha": alpha}, "fp4")
"""
input_node = linear_node.args[0]
if is_op(linear_node, torch.ops.auto_deploy.torch_linear_simple):
weight = linear_node.args[1]
return input_node, weight, None, ""
elif {
is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear),
is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear),
}:
weight = linear_node.args[1]
scales, quant_type = get_scales_and_type_from_node(linear_node)
return input_node, weight, scales, quant_type
def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node):
"""
Match the expert compute pattern between the given boundaries.
The expert compute pattern corresponds to:
(F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
For each expert, the function extracts the input node from the w1 branch and
collects the weight parameters from three linear ops (w1, w3, and w2 branches).
This function supports both:
- torch.ops.auto_deploy.torch_linear_simple.default ops, and
- torch.ops.auto_deploy.torch_quant_fp8_linear ops (also extracts quantization scales).
- torch.ops.auto_deploy.torch_quant_fp4_linear ops (also extracts quantization scales).
Returns:
A tuple:
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type)
- pattern_input_nodes: List of input nodes (x) used for the expert compute.
- pattern_output_nodes: List of final expert output nodes (the linear op with weight w2).
- expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight tensors.
- expert_scales: Dict with keys "w1_input_scale", "w1_weight_scale", etc., containing scale tensors
(empty if weight_type is "simple").
- weight_type: "fp8" if FP8 ops were used, "simple" otherwise.
"""
pattern_input_nodes, pattern_output_nodes = [], []
expert_weights = defaultdict(list)
expert_scales = defaultdict(list)
weight_type = "simple" # default
nodes = list(start_boundary.graph.nodes)
region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)]
for node in region_nodes:
# Accept both simple and quantized linear ops.
if not is_linear_op(node, include_quantization=True):
continue
final_linear = node
if not final_linear.args or not isinstance(final_linear.args[0], Node):
continue
mul_node = final_linear.args[0]
if not is_op(mul_node, torch.ops.aten.mul) or len(mul_node.args) < 2:
continue
arg_a, arg_b = mul_node.args[:2]
silu_node = (
arg_a
if is_op(arg_a, torch.ops.aten.silu)
else arg_b
if is_op(arg_b, torch.ops.aten.silu)
else None
)
if silu_node is None:
continue
if not (silu_node.args and is_linear_op(silu_node.args[0], include_quantization=True)):
continue
linear_w1_node = silu_node.args[0]
# The other branch should be a linear op (w3 branch).
linear_w3_node = arg_b if arg_a is silu_node else arg_a
if not is_linear_op(linear_w3_node, include_quantization=True):
continue
if not (linear_w1_node.args and linear_w3_node.args):
continue
# Extract parameters from each linear op.
input_node_w1, weight_w1, quant_params_w1, wt_type_w1 = _extract_linear_parameters(
linear_w1_node
)
_, weight_w3, quant_params_w3, wt_type_w3 = _extract_linear_parameters(linear_w3_node)
_, weight_w2, quant_params_w2, wt_type_w2 = _extract_linear_parameters(final_linear)
if None in (weight_w1, weight_w3, weight_w2):
continue
# Ensure the weight type is consistent across branches.
if wt_type_w1 != wt_type_w3 or wt_type_w1 != wt_type_w2:
continue
weight_type = wt_type_w1
pattern_input_nodes.append(input_node_w1)
pattern_output_nodes.append(final_linear)
expert_weights["w1"].append(weight_w1)
expert_weights["w3"].append(weight_w3)
expert_weights["w2"].append(weight_w2)
# TODO: sanity check that all experts have same weight type
if weight_type == "fp8":
expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
elif weight_type == "fp4":
expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
expert_scales["w1_alpha"].append(quant_params_w1["alpha"])
expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
expert_scales["w3_alpha"].append(quant_params_w3["alpha"])
expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
expert_scales["w2_alpha"].append(quant_params_w2["alpha"])
return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type
def _find_final_hidden_state_node(
pattern_output_nodes: list[Node], end_boundary: Node
) -> Optional[Node]:
"""
Identify the final hidden state node corresponding to the combine pattern:
(expert_output * routing_weight) index_add_
For each expert output node (from the expert compute pattern), this function:
1. Retrieves a multiplication node from its users.
2. Extracts the second argument from the multiplication node (assumed to be the index node).
3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary).
After collecting all such index_add_ nodes, the final hidden state node is determined
as the one that is not used by any of the other index_add_ nodes.
If any required attribute (users or args) is missing during the process or if no valid
final node is found, the function returns None.
"""
if not pattern_output_nodes:
return None
index_add_nodes = []
for node in pattern_output_nodes:
if not node.users:
return None
mul_node = next(iter(node.users))
if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
return None
index_node = mul_node.args[1]
index_add_node = bfs(
index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
)
if not index_add_node:
return None
index_add_nodes.append(index_add_node)
# The final node is defined as the index_add_node that is not used by any other index_add_nodes
return next(
(
candidate
for candidate in index_add_nodes
if not any(
candidate in other.args for other in index_add_nodes if candidate is not other
)
),
None,
)
def _extract_index_branches_from_expert_outputs(
pattern_output_nodes: list[Node],
) -> tuple[list[Node], list[Node]]:
"""
Extract routing and experts branches from expert outputs.
For each expert output, find its multiplication user. From the
multiplication node's second argument (an index node),
extract:
- The first argument as the routing branch.
- The second argument (flattened if a list/tuple) as the experts branch.
Returns:
A tuple (routing_branches, experts_branches).
"""
routing_branches, experts_branches = [], []
for out in pattern_output_nodes:
mul = next((u for u in out.users if is_op(u, torch.ops.aten.mul)), None)
if not mul or len(mul.args) < 2:
continue
idx_node = mul.args[1]
if not is_op(idx_node, torch.ops.aten.index):
continue
routing_branches.append(idx_node.args[0])
experts = idx_node.args[1]
experts_branches.extend(experts) if isinstance(
experts, (list, tuple)
) else experts_branches.append(experts)
return routing_branches, experts_branches
def _remove_dead_inplace_nodes_in_region(
graph: torch.fx.Graph,
start_boundary: torch.fx.Node,
end_boundary: torch.fx.Node,
) -> bool:
"""
Searches (via BFS) for a dead in-place node (index_add_) in the region
between start_boundary and end_boundary. If one is found, it is removed from the graph.
Returns True if a node was removed, False otherwise.
"""
def target(n: torch.fx.Node) -> bool:
return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0
try:
node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
ad_logger.debug(f"Removing In-place Dead Node: {node_to_remove}")
graph.erase_node(node_to_remove)
return True
except RuntimeError:
return False
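
The removed `_find_lowest_common_ancessor` (the MoE matcher itself presumably lives on behind the `match_moe_pattern` entry scheduled in the pattern_matcher stage) equalizes depths along each node's primary input chain and then walks both chains up in lockstep. A hedged standalone sketch of that walk, with a minimal node type standing in for an fx Node that has a single primary parent:

from typing import Optional

class N:
    """Hypothetical stand-in for an fx Node with one primary parent."""
    def __init__(self, name: str, parent: Optional["N"] = None):
        self.name, self.parent = name, parent

def depth(n: N) -> int:
    return 0 if n.parent is None else 1 + depth(n.parent)

def lca_two(a: N, b: N) -> Optional[N]:
    da, db = depth(a), depth(b)
    while da > db:
        a, da = a.parent, da - 1   # equalize depths along the primary branch
    while db > da:
        b, db = b.parent, db - 1
    while a is not None and b is not None:
        if a is b:
            return a               # chains meet at the lowest common ancestor
        a, b = a.parent, b.parent
    return None

root = N("hidden_states")
w1_branch = N("silu", N("linear_w1", root))
w3_branch = N("linear_w3", root)
print(lca_two(w1_branch, w3_branch).name)  # hidden_states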

View File

@ -1,193 +0,0 @@
"""Graph transformation to automatically add kv cache into fused MHA op."""
import operator
from typing import Dict, Type
import torch
from torch.fx import Graph, GraphModule, Node
from ...custom_ops.attention_interface import AttentionDescriptor, CacheConfig
from ...distributed.common import all_gather_object, get_world_size
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.node_utils import get_all_input_output_nodes, is_op
from .._graph import add_graph_input, canonicalize_graph
def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None:
"""Modify the graph module by adding new input nodes and canonicalizing the graph.
The new input nodes correspond to the extra arguments needed for cached and flattened attention.
Args:
egm: The graph module to analyze and modify.
cm: Cached sequence interface containing extra argument information.
"""
# loop through nodes to get input, output, and get_attr nodes
input_nodes, output_nodes = get_all_input_output_nodes(egm.graph)
# we only expect one input node
assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."
# NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
# Later on, we can revisit how to support returning hidden states.
assert len(output_nodes) == 1, "Expected exactly one output node!"
assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!"
ad_logger.info(f"Found {len(input_nodes)} input nodes and {len(output_nodes)} output nodes")
# Activate and add extra argument nodes
new_args = cm.info.switch_to_cached_attn_inputs()
for name in new_args:
input_nodes.append(add_graph_input(egm, name))
ad_logger.info(f"Added {len(new_args)} new input nodes for cached attention metadata")
canonicalize_graph(egm)
def insert_cached_attention(
egm: GraphModule,
cm: CachedSequenceInterface,
attn_descriptor: Type[AttentionDescriptor],
cache_config: CacheConfig,
) -> None:
"""Replace uncached source attention node with corresponding cached attn node."""
# Get all attention nodes and their info objects
source_op = attn_descriptor.get_source_attention_op()
# pick up graph
graph: Graph = egm.graph
# look for relevant source attention nodes
source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)]
if not source_attn_nodes:
# If there are no nodes for kv cache insertion found, return current graph
return
# Sanity check
if cm.info.is_paged:
assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."
ad_logger.debug(f"Before inserting {attn_descriptor=} with cache: {egm}")
# retrieve input nodes
input_nodes, _ = get_all_input_output_nodes(egm.graph)
# insert metadata computation and extract each argument as a node
get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
with graph.inserting_before(input_nodes[-1].next):
ret_node = graph.call_function(
get_metadata,
args=(
*input_nodes,
cm.info.page_size,
),
)
metadata_nodes = [
graph.call_function(operator.getitem, args=(ret_node, idx))
for idx in range(num_metadata)
]
buffer_in_lookup: Dict[str, Node] = {}
# replace fused attention node with attention node that has kv cache
num_cached_attn_replacements = 0
for idx, attn_node in enumerate(source_attn_nodes):
# pick out GEMMs
qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()]
# setup + store cache initializers and caches as input nodes
cache_in_nodes = []
for k, get_cache in attn_descriptor.get_cache_initializers(attn_node, cache_config).items():
k_indexed = f"{k}_{idx}"
cm.add_cache(k_indexed, get_cache)
cache_in_nodes.append(add_graph_input(egm, k_indexed))
# setup + store global buffer initializers and buffers as input nodes
# NOTE: we have to check against existing keys to make sure nothing is registered twice...
buffer_in_nodes = []
for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
if k not in buffer_in_lookup:
cm.add_cache(k, get_buffer)
buffer_in_lookup[k] = add_graph_input(egm, k)
buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op
# retrieve constants for attention_op
constants = attn_descriptor.get_constants(attn_node)
# insert cached attention replacement op
with graph.inserting_before(attn_node):
cached_attn_node = graph.call_function(
attn_descriptor.get_cached_attention_op(),
args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants),
)
attn_node.replace_all_uses_with(cached_attn_node)
graph.erase_node(attn_node)
num_cached_attn_replacements += 1
canonicalize_graph(egm)
ad_logger.info(
f"Replaced {num_cached_attn_replacements} {source_op} ops "
f"with {attn_descriptor.get_cached_attention_op()}"
)
ad_logger.debug(f"After inserting {attn_descriptor=} with cache: {egm}")
def resize_kv_cache(
egm: GraphModule,
cm: CachedSequenceInterface,
free_mem_ratio: float = 0.8,
) -> None:
"""Inflate the kv cache to occupy the available GPU memory.
free_mem_ratio specifies the fraction of available memory to occupy.
"""
def _get_mem_info_in_mb():
free_mem, total_mem = torch.cuda.mem_get_info()
return free_mem // 1024**2, total_mem // 1024**2
free_mem, total_mem = _get_mem_info_in_mb()
ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
current_cache_size = cm.current_cache_size_bytes()
current_num_pages = cm.info.num_pages
ad_logger.info(
f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
)
if free_mem_ratio == 0.0:
ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}")
return
try:
# Let's run a forward pass to get the memory usage
cm.info._set_max_num_tokens_sample()
free_mem_pre, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
egm(*cm.args)
free_mem_post, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")
memory_for_forward_pass = free_mem_pre - free_mem_post
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
# Need to sync all the GPUs
gathered_num_pages = [None] * get_world_size()
all_gather_object(gathered_num_pages, new_num_pages)
new_num_pages = min(gathered_num_pages)
ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}")
cm.resize_cache(new_num_pages)
except Exception as e:
ad_logger.warning(
f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
)
# Free memory
torch.cuda.empty_cache()

View File

@ -1,113 +0,0 @@
"""Graph transform to optimize RMSNorm execution using FlashInfer."""
from functools import partial
import torch
from torch.fx import GraphModule
from ...utils.logger import ad_logger
# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from .._graph import canonicalize_graph
_BACKEND_OPS = {
"flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
"triton": torch.ops.auto_deploy.triton_rms_norm,
"torch": torch.ops.auto_deploy.torch_rmsnorm,
}
def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
"""Implements the RMSNorm pattern for pattern matching.
Args:
data: Input tensor to normalize.
weight: Scaling weights for the normalized output.
eps: Small constant for numerical stability.
Returns:
Normalized and scaled tensor.
"""
input_dtype = data.dtype
data = data.to(torch.float32)
variance = data.pow(2).mean(-1, keepdim=True)
data = data * torch.rsqrt(variance + eps)
return weight * data.to(input_dtype)
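
# Side note (illustrative only): the pattern body above is the standard RMSNorm formula --
# normalize by the root-mean-square in float32, cast back to the input dtype, then scale.
# Assuming a recent PyTorch that ships torch.nn.functional.rms_norm, a quick standalone
# sanity check of that formula looks like this; the helper is hypothetical and unused by
# the transform itself.
def _rms_norm_sanity_check() -> None:
    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 512)
    w = torch.randn(512)
    eps = 1e-6
    ref = w * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps))
    out = F.rms_norm(x, (512,), weight=w, eps=eps)  # availability depends on the PyTorch version
    torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)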
def _rms_norm_replacement(
    data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
) -> torch.Tensor:
    """Backend-specific rms_norm implementation.

    Args:
        data: Input tensor to normalize.
        weight: Scaling weights for the normalized output.
        eps: Small constant for numerical stability.
        backend: Backend to use for RMSNorm computation ("flashinfer", "triton", or "torch").

    Returns:
        Normalized and scaled tensor using the specified backend implementation.
    """
    assert backend.lower() in _BACKEND_OPS, (
        f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
    )
    return _BACKEND_OPS[backend.lower()](data, weight, eps)
def fuse_rmsnorm(gm: GraphModule, backend: str = "triton") -> None:
    """Match RMSNorm patterns in the graph and replace them with the selected backend op.

    This function sets up pattern matching to identify RMSNorm operations in the graph
    and replaces them with optimized implementations. It uses dummy tensors to register
    the pattern matching rules.

    Args:
        gm: Input graph module to transform (modified in place).
        backend: Backend to use for RMSNorm computation ("flashinfer", "triton", or "torch").
    """
    if backend.lower() not in _BACKEND_OPS:
        raise ValueError(f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {backend}")

    ad_logger.info(f"Starting RMSNorm pattern matching with backend: {backend}")

    graph = gm.graph
    patterns = ADPatternMatcherPass()

    # Create dummy tensors for pattern matching
    bs = 2
    hidden_size = 512

    def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
        return [
            torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
            torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
            eps,
        ]

    # Define configurations for different data types
    configs = [
        (torch.bfloat16, torch.bfloat16),
        (torch.float16, torch.float16),
        (torch.float32, torch.float32),
    ]

    # Register patterns for each configuration
    for input_dtype, weight_dtype in configs:
        register_ad_pattern(
            search_fn=_rms_norm_pattern,
            replace_fn=partial(_rms_norm_replacement, backend=backend),
            patterns=patterns,
            dummy_args=dummy_args(input_dtype, weight_dtype),
            op_ignore_types={},
            scalar_workaround={"eps": 1e-6},
        )

    cnt = patterns.apply(graph)
    ad_logger.info(f"RMSNorm pattern count: {cnt}")
    canonicalize_graph(gm)
    ad_logger.debug("RMSNorm pattern matching completed.")

View File

@ -5,21 +5,11 @@ import gc
import torch
import torch.nn as nn
from ..compile import compile_and_capture
from ..custom_ops.attention_interface import AttentionRegistry
from ..llm_args import AutoDeployConfig
from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer
from ..utils.logger import ad_logger
from .library import (
    fuse_allreduce_residual_rmsnorm,
    fuse_collectives,
    fuse_rmsnorm,
    insert_cached_attention,
    resize_kv_cache,
    update_in_out_nodes,
)
class InferenceOptimizer:
@ -55,88 +45,60 @@ class InferenceOptimizer:
self.ad_config.attn_backend
).get_attention_layout()
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
# TODO (hg): similar to above.
if "load_weights" in new_optimizer.config:
new_optimizer.config[
if "load_weights" in self.ad_config.transforms:
self.ad_config.transforms[
"load_weights"
].checkpoint_device = self.ad_config.checkpoint_device
new_optimizer.config["load_weights"].device = cm.device
self.ad_config.transforms["load_weights"].device = cm.device
if "resize_kv_cache" in self.ad_config.transforms:
self.ad_config.transforms[
"resize_kv_cache"
].free_mem_ratio = self.ad_config.free_mem_ratio
if "insert_cached_attention" in self.ad_config.transforms:
self.ad_config.transforms[
"insert_cached_attention"
].attn_backend = self.ad_config.attn_backend
if "insert_cached_mla_attention" in self.ad_config.transforms:
self.ad_config.transforms[
"insert_cached_mla_attention"
].attn_backend = self.ad_config.mla_backend
# TODO (hg): MLA is still missing here; figure out how to add it, since duplicate transforms are not allowed.
# Old code:
# detect attention op and replace with cache-aware op
# for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
# attn_descriptor = AttentionRegistry.get(a_backend)
# insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
if "compile_model" in self.ad_config.transforms:
self.ad_config.transforms[
"compile_model"
].cuda_graph_batch_sizes = self.ad_config.cuda_graph_batch_sizes
self.ad_config.transforms[
"compile_model"
].compile_backend = self.ad_config.compile_backend
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
# TODO (hg): move this; let match_rope_layout and match_attention_layout use this shared config.
new_optimizer.shared_config.attn_backend = self.ad_config.attn_backend
egm = new_optimizer(cm)
# TODO (lucaslie): continue moving legacy transforms to the new optimizer
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
# NOTE (hg): Disabled visualization since the compiled gm is a CapturedGraph instead of a GraphModule.
# We can add a new stage in the optimizer to visualize the intermediate gm.
# if self.ad_config.visualize:
# try:
# from .library import visualize_namespace
# run MoE fusion
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
# fuse_moe(egm)
# run GEMM fusion
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
# fuse_gemms(egm)
# check if we can fuse allreduce, residual and rmsnorm
fuse_allreduce_residual_rmsnorm(egm)
# check if we can fuse collectives
fuse_collectives(egm)
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
# check if we can fuse rmsnorm
fuse_rmsnorm(egm, "flashinfer")
# visualize the final graph
if self.ad_config.visualize:
try:
from .library import visualize_namespace
visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
ad_logger.warning(
"Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
" the graph."
)
except ImportError:
pass
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
update_in_out_nodes(egm, cm)
# detect attention op and replace with cache-aware op
for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
attn_descriptor = AttentionRegistry.get(a_backend)
insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
# initialize cache on correct device
cm.initialize_caches()
# resize kv cache to occupy the available GPU memory up to free_mem_ratio
resize_kv_cache(egm, cm, free_mem_ratio=self.ad_config.free_mem_ratio)
############################################################################################
# COMPILE MODEL
############################################################################################
cm.info.set_generate_only_batch()
compiler_kwargs = {
    "cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes,
    "num_batched_inputs": 2,  # TODO (lucaslie): improve once we have a config system...
}
egm_compiled = compile_and_capture(
    egm,
    self.ad_config.compile_backend,
    args=cm.args,
    dynamic_shapes=cm.dynamic_shapes,
    compiler_kwargs=compiler_kwargs,
)
cm.info.reset()
# visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
# ad_logger.warning(
# "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
# " the graph."
# )
# except ImportError:
# pass
torch.cuda.empty_cache()
gc.collect()
return egm_compiled
return egm
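
# Schematic consolidation of the forwarding logic above, gathered in one place for
# readability. The field and transform names mirror the shim; "ad_config", "cm", and
# "factory" stand for the objects available inside the shim's call path, so this helper
# is a sketch rather than part of the module.
def _forward_ad_config_into_transforms(ad_config, cm, factory):
    transforms = ad_config.transforms  # mapping: transform name -> transform config
    if "load_weights" in transforms:
        transforms["load_weights"].checkpoint_device = ad_config.checkpoint_device
        transforms["load_weights"].device = cm.device
    if "resize_kv_cache" in transforms:
        transforms["resize_kv_cache"].free_mem_ratio = ad_config.free_mem_ratio
    if "insert_cached_attention" in transforms:
        transforms["insert_cached_attention"].attn_backend = ad_config.attn_backend
    if "insert_cached_mla_attention" in transforms:
        transforms["insert_cached_mla_attention"].attn_backend = ad_config.mla_backend
    if "compile_model" in transforms:
        transforms["compile_model"].cuda_graph_batch_sizes = ad_config.cuda_graph_batch_sizes
        transforms["compile_model"].compile_backend = ad_config.compile_backend
    return ModularInferenceOptimizer(factory, transforms)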

View File

@ -8,9 +8,7 @@ from torch.export import export
from tensorrt_llm._torch.auto_deploy.distributed import common as dist
from tensorrt_llm._torch.auto_deploy.distributed.trtllm import is_trtllm_op_available
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations.library.collectives import (
fuse_allreduce_residual_rmsnorm,
)
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm.llmapi.mpi_session import MpiPoolSession
@ -65,14 +63,21 @@ def _test_allreduce_fusion(port: int):
original_outputs, residual_original = gm(x, residual)
# Fuse ops
fuse_allreduce_residual_rmsnorm(gm)
gm_transformed = InferenceOptimizer(
    None,
    {
        "fuse_allreduce_residual_rmsnorm": {
            "stage": "post_load_fusion",
        },
    },
)(None, gm)
# Run the fused graph
fused_outputs, residual_fused = gm(x, residual)
fused_outputs, residual_fused = gm_transformed(x, residual)
# Check if fused node in the graph
has_fused_node = False
for node in gm.graph.nodes:
for node in gm_transformed.graph.nodes:
if is_op(node, torch.ops.dist.fused_allreduce_residual_rmsnorm):
has_fused_node = True
assert has_fused_node, "Fused node not found."
@ -86,8 +91,8 @@ def _test_allreduce_fusion(port: int):
)
# check if we can still export the model as expected
export(gm, args=args)
torch_export_to_gm(gm, args=args)
export(gm_transformed, args=args)
torch_export_to_gm(gm_transformed, args=args)
@pytest.mark.parametrize("device_count", get_device_counts())

View File

@ -8,12 +8,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from _dist_test_utils import get_device_counts
from _graph_test_helpers import run_test
from _graph_test_helpers import run_test_transformed_gm
from _torch_test_utils import fp8_compatible
import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear
from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_collectives
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -61,11 +62,21 @@ def _run_job(
is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm.graph.nodes
)
gm = torch_export_to_gm(model, args=(x,), clone=True)
gm_transformed = InferenceOptimizer(
    None,
    {
        "fuse_collectives": {
            "stage": "post_load_fusion",
        },
    },
)(None, gm)
# now run the test
run_test(
run_test_transformed_gm(
model,
x,
transform=fuse_collectives,
gm_transformed,
check_transformed_graph=check_transformed_graph,
_get_expected_num_params=_get_expected_num_params,
test_load_hook=False,

View File

@ -1,12 +1,11 @@
from functools import partial
import pytest
import torch
from _graph_test_helpers import run_test
from _graph_test_helpers import run_test_transformed_gm
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.transformations.library.rms_norm import fuse_rmsnorm
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -52,15 +51,28 @@ def test_rmsnorm_fusion(eps, variant, op):
return any(is_op(n, op) for n in gm.graph.nodes)
model = TestModel(eps)
gm_transformed = run_test(
x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
dynamic_shapes = {0: Dim("batch_size", max=8)}
gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
gm_transformed = InferenceOptimizer(
    None,
    {
        "fuse_rmsnorm": {
            "stage": "post_load_fusion",
            "backend": variant,
        },
    },
)(None, gm)
run_test_transformed_gm(
model,
torch.randn(2, 1024, device="cuda", dtype=torch.float16),
partial(fuse_rmsnorm, backend=variant),
x,
gm_transformed,
checker,
lambda num_p_og: num_p_og,
dynamic_shapes={0: Dim("batch_size", max=8)},
dynamic_shapes=dynamic_shapes,
)
print(gm_transformed.graph)
new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
y_transformed = gm_transformed(new_input)
y_model = model(new_input)

View File

@ -7,11 +7,12 @@ import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from _graph_test_helpers import count_buffers, run_test
from _graph_test_helpers import count_buffers, run_test_transformed_gm
from _torch_test_utils import all_close, fp8_compatible, reset_parameters
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear
from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_gemms
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op
torch.manual_seed(0)
@ -254,10 +255,20 @@ def test_fusion(get_model: Callable[[], TestModel], dtype: str):
buffer_size_before = count_buffers(model)
gm_transformed = run_test(
gm = torch_export_to_gm(model, args=(x,), clone=True)
gm_transformed = InferenceOptimizer(
    None,
    {
        "fuse_gemms": {
            "stage": "post_load_fusion",
        },
    },
)(None, gm)
run_test_transformed_gm(
model,
x,
fuse_gemms,
gm_transformed,
lambda gm: sum(is_linear_op(n, include_quantization=True) for n in gm.graph.nodes)
== model.num_gemms_after_fusion,
lambda num_p_og: num_p_og, # unchanged since fusing doesn't change param count

View File

@ -2,19 +2,35 @@ from typing import Optional
import pytest
import torch
from _graph_test_helpers import FakeFactory
from _graph_test_helpers import SequenceEmbeddingInfo
from _model_test_utils import GQA
from _torch_test_utils import all_close
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig, SequenceInfo
from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_attention import FlashInferAttention
from tensorrt_llm._torch.auto_deploy.custom_ops.triton_attention import TritonAttention
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.interface import InferenceOptimizerConfig
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations.library import update_in_out_nodes
from tensorrt_llm._torch.auto_deploy.transformations.library.kvcache import insert_cached_attention
class DummyFactory(ModelFactory):
    """Dummy factory to pass cache_config for testing."""

    def __init__(self, model, cache_config):
        self._model = model
        self.cache_config = cache_config

    def build_model(self, device: str):
        return self._model.to(device=device)

    def _build_model(self, device: str):
        return

    def _load_checkpoint(self, model, device):
        return

    def get_cache_config(self):
        return self.cache_config
# Class that uses SDPA directly instead of the regular attention mechanism
@ -68,42 +84,6 @@ class GQAWithSdpa(GQA):
return self.o_proj(attn_output)
def _get_optimizer_config() -> InferenceOptimizerConfig:
    return {
        "build_model": {
            "stage": "factory",
            "device": "cuda",
            "run_graph_cleanup": False,
            "requires_clean_graph": False,
        },
        "export_to_gm": {
            "stage": "export",
            "strict": False,
            "clone_state_dict": True,
            "run_graph_cleanup": False,
            "requires_clean_graph": False,
        },
        "cleanup_input_constraints": {
            "stage": "post_export",
        },
    }
class SequenceEmbeddingInfo(SequenceInfo):
    hidden_size: int
    dtype: torch.dtype

    def set_example_sequence(self) -> None:
        super().set_example_sequence()
        # set input ids to a 3D tensor (actually input embeddings)
        self.input_ids = torch.rand(
            *self.input_ids.shape,
            self.hidden_size,
            device=self.input_ids.device,
            dtype=self.dtype,
        )
# TODO (lucaslie): consider rewriting this test with a custom InferenceOptimizer config
@pytest.mark.parametrize(
"dtype",
@ -111,8 +91,8 @@ class SequenceEmbeddingInfo(SequenceInfo):
ids=["float16", "float32"],
)
@pytest.mark.parametrize(
"attn_descriptor",
[TritonAttention, FlashInferAttention],
"attn_backend",
["triton", "flashinfer"],
ids=["triton", "flashinfer"],
)
@pytest.mark.parametrize(
@ -125,10 +105,10 @@ class SequenceEmbeddingInfo(SequenceInfo):
ids=["regular", "gqa", "mqa"],
)
@torch.inference_mode()
def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config):
def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
"""Test the SDPA transformation with KV cache."""
# flashinfer doesn't support float32 data type
if attn_descriptor == FlashInferAttention and dtype == torch.float32:
if attn_backend == "flashinfer" and dtype == torch.float32:
pytest.skip("flashinfer doesn't support float32 data type")
# Unpack the GQA configuration
@ -157,7 +137,6 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config):
hidden_size,
num_key_value_heads,
).to(dtype=dtype, device="cuda")
factory = FakeFactory(model)
# Create input tensor and position_ids
x = torch.rand(batch_size, seq_len, hidden_size).to(device="cuda", dtype=dtype)
@ -166,21 +145,37 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config):
# Get the model's regular output
y_model = model(x, position_ids) # b, s, d
# run modular inference optimizer up to post_export
optimizer = InferenceOptimizer(factory, _get_optimizer_config()) # type: ignore
# Apply the transformation
optimizer = InferenceOptimizer(
    DummyFactory(model, CacheConfig()),
    {
        "build_model": {
            "stage": "factory",
            "device": "cuda",
            "run_graph_cleanup": False,
            "requires_clean_graph": False,
        },
        "export_to_gm": {
            "stage": "export",
            "strict": False,
            "clone_state_dict": True,
            "run_graph_cleanup": False,
            "requires_clean_graph": False,
        },
        "cleanup_input_constraints": {
            "stage": "post_export",
        },
        "update_in_out_nodes": {
            "stage": "cache_init",
        },
        "insert_cached_attention": {
            "stage": "cache_init",
            "attn_backend": attn_backend,
        },
    },
)  # type: ignore
gm = optimizer(cm)
y_gm = gm(x, position_ids)
assert all_close(y_model, y_gm, atol=atol, rtol=rtol)
# Set up cache configuration
cache_config = CacheConfig()
# Get input node(s)
update_in_out_nodes(gm, cm)
# Apply the transformation
insert_cached_attention(gm, cm, attn_descriptor=attn_descriptor, cache_config=cache_config)
gm.to("cuda")
cm.initialize_caches()