From e18dacc931daa303720c1bd539314711faa103e8 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 21 Aug 2025 10:30:36 -0700 Subject: [PATCH] [#4403][refactor] Move fusion, kvcache, and compile to modular inference optimizer (#7057) Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> Co-authored-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 49 +- .../_torch/auto_deploy/transform/interface.py | 6 +- .../transform/library/collectives.py | 204 +++++++ .../transform/library/compile_model.py | 65 +++ .../library/fusion.py | 58 +- .../auto_deploy/transform/library/kvcache.py | 299 ++++++++++ .../auto_deploy/transform/library/rms_norm.py | 148 +++++ .../transformations/library/__init__.py | 6 - .../transformations/library/collectives.py | 167 ------ .../transformations/library/fused_moe.py | 511 ------------------ .../transformations/library/kvcache.py | 193 ------- .../transformations/library/rms_norm.py | 113 ---- .../auto_deploy/transformations/transform.py | 134 ++--- .../test_allreduce_residual_rmsnorm_fusion.py | 21 +- .../library/test_collective_fusion.py | 19 +- .../library/test_fuse_rmsnorm.py | 30 +- .../library/test_gemm_fusion.py | 19 +- .../transformations/library/test_kv_cache.py | 117 ++-- 18 files changed, 969 insertions(+), 1190 deletions(-) create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py rename tensorrt_llm/_torch/auto_deploy/{transformations => transform}/library/fusion.py (76%) create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py delete mode 100644 tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py delete mode 100644 tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py delete mode 100644 tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py delete mode 100644 tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index f7ad7934a9..041d51e73d 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -19,6 +19,11 @@ transforms: stage: post_export cleanup_input_constraints: stage: post_export + ############################################################################################ + # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION + ############################################################################################ + match_moe_pattern: + stage: pattern_matcher match_repeat_kv: stage: pattern_matcher match_eager_attention: @@ -27,12 +32,13 @@ transforms: stage: pattern_matcher match_attention_layout: stage: pattern_matcher - match_moe_pattern: - stage: pattern_matcher match_rope_pattern: stage: pattern_matcher match_rope_layout: stage: pattern_matcher + ############################################################################################ + # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION + ############################################################################################ eliminate_redundant_transposes: stage: pattern_matcher # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved @@ -57,5 +63,44 @@ 
transforms: sharding_transform_executor: stage: sharding run_shape_prop: true + ############################################################################################ + # MOVE MODEL AND LOAD WEIGHTS + ############################################################################################ load_weights: stage: weight_load + ############################################################################################ + # RUN POST-LOAD FUSION AND OPTIMIZATIONS + ############################################################################################ + # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs + # fuse_moe: + # stage: post_load_fusion + # fuse_gemms: + # stage: post_load_fusion + fuse_allreduce_residual_rmsnorm: + stage: post_load_fusion + fuse_collectives: + stage: post_load_fusion + # TODO (lucaslie): add backend selection as part of configurable inference optimizers + # check if we can fuse rmsnorm + fuse_rmsnorm: + stage: post_load_fusion + backend: flashinfer + ############################################################################################ + # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES + ############################################################################################ + update_in_out_nodes: + stage: cache_init + insert_cached_attention: + stage: cache_init + insert_cached_mla_attention: + stage: cache_init + attn_backend: MultiHeadLatentAttention + initialize_cache: + stage: cache_init + resize_kv_cache: + stage: cache_init + ############################################################################################ + # COMPILE MODEL + ############################################################################################ + compile_model: + stage: compile diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 1087714177..cddc56b872 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -54,6 +54,7 @@ class SharedConfig(BaseModel): sharding_config: ShardingConfig = Field(default_factory=ShardingConfig) local_rank: int = Field(default=0) world_size: int = Field(default=1) + attn_backend: str = Field(default="flashinfer", description="The attention backend to use.") class TransformConfig(BaseModel): @@ -285,7 +286,10 @@ class BaseTransform(ABC): # update + store new meta data history[t_name] = info autodeploy_meta[self._history_key] = history - self._set_autodeploy_meta(gm, autodeploy_meta) + + if isinstance(gm, GraphModule): + # After compilation, gm becomes type CapturedGraph with no meta data. 
+ self._set_autodeploy_meta(gm, autodeploy_meta) # return the graph module return gm diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py new file mode 100644 index 0000000000..6c5b1fe2b9 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py @@ -0,0 +1,204 @@ +import operator +from typing import Tuple + +import torch +from torch.fx import GraphModule + +from ...distributed.trtllm import is_trtllm_op_available +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + +# TODO: This is an overly simplified model that works well for vanilla Llama models. +# However, we eventually want to consider more sophisticated patterns such as +# * all_reduce(lin1(x) + lin2(x)) +# * version above with fused GEMMs (i.e. with a split node) +# * all_reduce(pointwise_op(linear(x))) +# * ... + + +@TransformRegistry.register("fuse_collectives") +class FuseCollectives(BaseTransform): + """ + Fuses all_reduce ops with preceding (quantized) linear ops into a single fused node for improved performance. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + num_gemm_collective_fusions = 0 + + # lookup for fused ops + # TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly. + lookup = { + torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, + torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, + torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce, + } + + # go through all nodes and find all_reduce nodes + for node in gm.graph.nodes: + if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): + continue + + # check if args are as expected + assert len(node.args) == 1 and not len(node.kwargs), ( + "Unexpected args/kwargs for all_reduce" + ) + + # retrieve parent and check a few conditions on the parent node + parent_node = node.args[0] + if not is_op(parent_node, lookup.keys()): + continue + if len(parent_node.users) > 1: + continue + + with gm.graph.inserting_before(node): + # insert fused node + fused_linear_collective_node = gm.graph.call_function( + lookup[get_op_overload_packet(parent_node.target)], + args=parent_node.args, + kwargs=parent_node.kwargs, + ) + node.replace_all_uses_with(fused_linear_collective_node) + gm.graph.erase_node(node) + gm.graph.erase_node(parent_node) + num_gemm_collective_fusions += 1 + + info = TransformInfo( + skipped=False, + num_matches=num_gemm_collective_fusions, + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info + + +@TransformRegistry.register("fuse_allreduce_residual_rmsnorm") +class FuseAllreduceResidualRMSNorm(BaseTransform): + """Essentially, this transformation fuses the following operators into one allreduce trtllm implementation. 
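As background for the fusion above, a minimal sketch in plain PyTorch (not the TRT-LLM kernels) of the two-step pattern being collapsed: a row-parallel linear followed by a separate all-reduce on its partial result. Here torch.distributed stands in for torch.ops.auto_deploy.torch_dist_all_reduce and an initialized process group is assumed; the fused op issues both steps in a single call.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def linear_then_all_reduce(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # unfused path: GEMM kernel, then a separate collective on the partial output
    partial = F.linear(x, weight)
    dist.all_reduce(partial)  # sum the row-parallel partial results across ranks
    return partial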
+ + * target pattern: + x = all_reduce(x) + y = x + residual + return rmsnorm(y), y + * replacement: + fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps) + + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + if not is_trtllm_op_available(): + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + num_ar_r_rms_fusions = 0 + + def trace_and_fuse(allreduce_node, graph): + # Check if all_reduce is followed by addition + users = list(allreduce_node.users.keys()) + if len(users) != 1: + return # Skip if all_reduce has more than one consumer + add_node = users[0] + + # Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer + # the Huggingface LlamaRMSNorm implementation as example for more details + to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2) + # operand of pow and mul + pow_node = get_user_if_pattern_match( + to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2 + ) + mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1) + add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1) + rsqrt_node = get_user_if_pattern_match( + add_eps_node, [torch.ops.aten.add, operator.add], 1 + ) + mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1) + to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1) + mul_node_2 = get_user_if_pattern_match( + to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1 + ) + # check args of ops: pow(2) and mean(-1) + ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent + ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions + + # Match found: Replace with fused operation + if ( + to_copy_1 + and pow_node + and mean_node + and add_eps_node + and rsqrt_node + and mul_node_1 + and to_copy_2 + and mul_node_2 + and ARGS_MATCH + ): + # Gather the inputs for the custom operation + tensor = allreduce_node.args[0] + # Identify the residual argument in the add operation + # One of the args in add_node.args is the output of all_reduce + # The same idea also applies to norm_weight + residual = ( + add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1] + ) + norm_weight = ( + mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1] + ) + eps = add_eps_node.args[1] + + # Insert nodes + with graph.inserting_before(allreduce_node): + fused_node = graph.call_function( + torch.ops.dist.fused_allreduce_residual_rmsnorm, + args=( + tensor, + residual, + norm_weight, + eps, + ), + ) + # Extract outputs from the tuple returned by `fused_node` + final_output_node = gm.graph.create_node( + "call_function", + target=operator.getitem, + args=(fused_node, 0), + ) + add_output_node = gm.graph.create_node( + "call_function", + target=operator.getitem, + args=(fused_node, 1), + ) + + # Replace all uses of rmsnorm_node with final_output_node + mul_node_2.replace_all_uses_with(final_output_node) + + # Replace all uses of add_node with add_output_node + add_node.replace_all_uses_with(add_output_node) + + nonlocal num_ar_r_rms_fusions + num_ar_r_rms_fusions += 1 + + # Traverse all nodes + for node in gm.graph.nodes: + if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): + trace_and_fuse(allreduce_node=node, graph=gm.graph) + + info = TransformInfo( + skipped=False, 
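Written out in eager PyTorch, the computation matched by trace_and_fuse above (and replaced by fused_allreduce_residual_rmsnorm) looks roughly as follows. This is a reference sketch only, with dist.all_reduce standing in for the distributed op and an initialized process group assumed.

import torch
import torch.distributed as dist

def allreduce_residual_rmsnorm(x, residual, norm_weight, eps=1e-6):
    dist.all_reduce(x)                 # x = all_reduce(x)
    y = x + residual                   # y = x + residual
    y_f32 = y.to(torch.float32)
    variance = y_f32.pow(2).mean(-1, keepdim=True)
    normed = (y_f32 * torch.rsqrt(variance + eps)).to(y.dtype)
    return norm_weight * normed, y     # the fused op returns both outputs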
num_matches=num_ar_r_rms_fusions, is_clean=False, has_valid_shapes=False + ) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py new file mode 100644 index 0000000000..00601303b6 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py @@ -0,0 +1,65 @@ +from typing import List, Literal, Optional, Tuple, Type + +from pydantic import Field +from torch.fx import GraphModule + +from ...compile import compile_and_capture +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + + +class CompileModelConfig(TransformConfig): + """Configuration for the compile model transform.""" + + cuda_graph_batch_sizes: Optional[List[int]] = Field( + default=None, description="The batch sizes to use for CUDA graphs." + ) + num_batched_inputs: int = Field( + default=2, description="The number of batched inputs to use for CUDA graphs." + ) + compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = ( + Field(description="The backend to use for compiling the model.") + ) + + +@TransformRegistry.register("compile_model") +class CompileModel(BaseTransform): + """A transform to compile the model.""" + + config: CompileModelConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return CompileModelConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + cm.info.set_generate_only_batch() + egm_compiled = compile_and_capture( + gm, + self.config.compile_backend, + args=cm.args, + dynamic_shapes=cm.dynamic_shapes, + compiler_kwargs={ + "cuda_graph_batch_sizes": self.config.cuda_graph_batch_sizes, + "num_batched_inputs": self.config.num_batched_inputs, + }, + ) + cm.info.reset() + + # store info object about the transform + info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) + + return egm_compiled, info diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/fusion.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py similarity index 76% rename from tensorrt_llm/_torch/auto_deploy/transformations/library/fusion.py rename to tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py index e66ced8ae6..2d422c42d6 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/fusion.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn from torch.fx import GraphModule, Node +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface from ...utils.cuda_mem_tracker import cuda_memory_tracker from ...utils.logger import ad_logger from ...utils.node_utils import ( @@ -14,7 +16,7 @@ from ...utils.node_utils import ( is_linear_op, ) from ...utils.quantization_utils import QuantizationImpl -from .._graph import canonicalize_graph +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]): @@ -116,30 +118,36 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node gm.delete_all_unused_submodules() -def fuse_gemms(gm: GraphModule) 
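As context for the compile_backend choices ("torch-cudagraph", "torch-opt"), this is the generic stock-PyTorch recipe for capturing and replaying a CUDA graph at a fixed decode batch size; the module and shapes below are placeholders, not the auto_deploy compiler API.

import torch

mod = torch.nn.Linear(64, 64).cuda()
static_in = torch.randn(8, 64, device="cuda")  # one static buffer per captured batch size

# warm up on a side stream, then capture one graph for this batch size
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        mod(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = mod(static_in)

# replay: copy fresh inputs into the captured buffer and relaunch the whole graph
static_in.copy_(torch.randn(8, 64, device="cuda"))
graph.replay()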
-> None: - ad_logger.info("GEMM fusion") - ad_logger.debug("Before GEMM fusion: " + str(gm)) - # sort linear nodes by parent node - linear_nodes = defaultdict(list) - for node in gm.graph.nodes: - # TODO: we don't handle bias for now... - if is_linear_op(node, include_quantization=True) and node.args[2] is None: - linear_nodes[node.args[0]].append(node) +@TransformRegistry.register("fuse_gemms") +class FuseGemms(BaseTransform): + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + # sort linear nodes by parent node + linear_nodes = defaultdict(list) + for node in gm.graph.nodes: + # TODO: we don't handle bias for now... + if is_linear_op(node, include_quantization=True) and node.args[2] is None: + linear_nodes[node.args[0]].append(node) - # fuse linear nodes - idx = -1 - with cuda_memory_tracker(): - for parent_node, lin_children in linear_nodes.items(): - if len(lin_children) < 2: - continue - # linear nodes to fuse - ad_logger.debug( - f"Found linear nodes to fuse: {lin_children} with parent node: {parent_node}" - ) - _insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children) + # fuse linear nodes + idx = -1 + num_matches = 0 + with cuda_memory_tracker(): + for parent_node, lin_children in linear_nodes.items(): + if len(lin_children) < 2: + continue + # linear nodes to fuse + _insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children) + num_matches += 1 - # clean up and return - canonicalize_graph(gm) + torch.cuda.empty_cache() - ad_logger.debug("After GEMM fusion: " + str(gm)) - torch.cuda.empty_cache() + info = TransformInfo( + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py new file mode 100644 index 0000000000..80f9d440c1 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -0,0 +1,299 @@ +"""Graph transformation to automatically add kv cache into fused MHA op.""" + +import operator +from typing import Dict, Optional, Tuple, Type + +import torch +from pydantic import Field +from torch.fx import Graph, GraphModule, Node + +from ...custom_ops.attention_interface import AttentionRegistry +from ...distributed.common import all_gather_object, get_world_size +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...transformations._graph import add_graph_input +from ...utils.logger import ad_logger +from ...utils.node_utils import get_all_input_output_nodes, is_op +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + + +@TransformRegistry.register("update_in_out_nodes") +class UpdateInOutNodes(BaseTransform): + """Modify the graph module by adding new input nodes. + + The new input nodes correspond to the extra arguments needed for cached and flattened attention. + + Args: + egm: The graph module to analyze and modify. + cm: Cached sequence interface containing extra argument information. 
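Numerically, the rewrite performed by _insert_fused_gemm for two bias-free projections that consume the same input reduces to concatenating the weights, running one GEMM, and splitting the output. A small self-contained check with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(4, 64)
w_q, w_k = torch.randn(32, 64), torch.randn(16, 64)

w_fused = torch.cat([w_q, w_k], dim=0)                      # [48, 64]
out_q, out_k = F.linear(x, w_fused).split([32, 16], dim=-1)

assert torch.allclose(out_q, F.linear(x, w_q), atol=1e-5)
assert torch.allclose(out_k, F.linear(x, w_k), atol=1e-5)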
+ """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + # loop through nodes to get input, output, and get_attr nodes + input_nodes, output_nodes = get_all_input_output_nodes(gm.graph) + + # we only expect one input node + assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)." + + # NOTE: for now, we wanna make sure we *only* return the final output and no hidden states. + # Later on, we can revisit how to support returning hidden states. + assert len(output_nodes) == 1, "Expected exactly one output node!" + assert len(output_nodes[0].all_input_nodes) == 1, ( + "Expected to only return final tensor output!" + ) + + # Activate and add extra argument nodes + new_args = cm.info.switch_to_cached_attn_inputs() + for name in new_args: + input_nodes.append(add_graph_input(gm, name)) + + info = TransformInfo(skipped=False, num_matches=1, is_clean=False, has_valid_shapes=False) + + return gm, info + + +class InsertCachedAttentionConfig(TransformConfig): + """Configuration for the insert cached attention transform.""" + + attn_backend: Optional[str] = Field(default=None, description="The attention backend to use.") + + +@TransformRegistry.register("insert_cached_attention") +class InsertCachedAttention(BaseTransform): + """ + A transform to insert cached attention into the graph module. + + If attn_backend is not provided in transform config, will find from shared config. + """ + + config: InsertCachedAttentionConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return InsertCachedAttentionConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + """Replace uncached source attention node with corresponding cached attn node.""" + attn_descriptor = AttentionRegistry.get(self.config.attn_backend) + + cache_config = factory.get_cache_config() + + # Get all attention nodes and their info objects + source_op = attn_descriptor.get_source_attention_op() + + # pick up graph + graph: Graph = gm.graph + + # look for relevant source attention nodes + source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)] + + if not source_attn_nodes: + # If there are no nodes for kv cache insertion found, return current graph + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + # Sanity check + if cm.info.is_paged: + assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op." 
+ + # retrieve input nodes + input_nodes, _ = get_all_input_output_nodes(gm.graph) + + # insert metadata computation and extract each argument as a node + get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op() + with graph.inserting_before(input_nodes[-1].next): + ret_node = graph.call_function( + get_metadata, + args=( + *input_nodes, + cm.info.page_size, + ), + ) + metadata_nodes = [ + graph.call_function(operator.getitem, args=(ret_node, idx)) + for idx in range(num_metadata) + ] + + buffer_in_lookup: Dict[str, Node] = {} + + # replace fused attention node with attention node that has kv cache + num_cached_attn_replacements = 0 + for idx, attn_node in enumerate(source_attn_nodes): + # pick out GEMMs + qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()] + + # setup + store cache initializers and caches as input nodes + cache_in_nodes = [] + for k, get_cache in attn_descriptor.get_cache_initializers( + attn_node, cache_config + ).items(): + k_indexed = f"{k}_{idx}" + cm.add_cache(k_indexed, get_cache) + cache_in_nodes.append(add_graph_input(gm, k_indexed)) + + # setup + store global buffer initializers and buffers as input nodes + # NOTE: we have to check against existing keys to make sure nothing is registered twice... + buffer_in_nodes = [] + for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items(): + if k not in buffer_in_lookup: + cm.add_cache(k, get_buffer) + buffer_in_lookup[k] = add_graph_input(gm, k) + buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op + + # retrieve constants for attention_op + constants = attn_descriptor.get_constants(attn_node) + + # insert cached attention replacement op + with graph.inserting_before(attn_node): + cached_attn_node = graph.call_function( + attn_descriptor.get_cached_attention_op(), + args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants), + ) + attn_node.replace_all_uses_with(cached_attn_node) + graph.erase_node(attn_node) + num_cached_attn_replacements += 1 + + info = TransformInfo( + skipped=False, + num_matches=num_cached_attn_replacements, + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info + + +@TransformRegistry.register("insert_cached_mla_attention") +class InsertCachedMLAAttention(InsertCachedAttention): + """ + A transform to insert cached MLA attention into the graph module. + + This class is identical to InsertCachedAttention and inherits all its behavior. + """ + + pass + + +class ResizeKVCacheConfig(TransformConfig): + """Configuration for the resize kv cache transform.""" + + free_mem_ratio: float = Field( + description="The fraction of available memory to occupy.", default=0.8 + ) + + +@TransformRegistry.register("resize_kv_cache") +class ResizeKVCache(BaseTransform): + """Inflate the kv cache to occupy the available GPU memory. + + free_mem_ratio specifies the fraction of available memory to occupy. 
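A worked instance of this resize rule with made-up numbers: the new cache size is the post-forward free memory scaled by free_mem_ratio plus the current cache size, then converted back into pages.

free_mem_post_mb = 8 * 1024        # MB reported free after the probe forward pass
free_mem_ratio = 0.8
current_cache_size = 2 * 1024**3   # bytes
current_num_pages = 1024

new_cache_size = free_mem_post_mb * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
print(new_num_pages)               # 4300 pages, roughly 4.2x the original cache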
+ """ + + config: ResizeKVCacheConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return ResizeKVCacheConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + free_mem_ratio = self.config.free_mem_ratio + + def _get_mem_info_in_mb(): + free_mem, total_mem = torch.cuda.mem_get_info() + return free_mem // 1024**2, total_mem // 1024**2 + + free_mem, total_mem = _get_mem_info_in_mb() + ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}") + current_cache_size = cm.current_cache_size_bytes() + current_num_pages = cm.info.num_pages + ad_logger.info( + f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}" + ) + + if free_mem_ratio == 0.0: + ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + try: + # Let's run a forward pass to get the memory usage + cm.info._set_max_num_tokens_sample() + free_mem_pre, _ = _get_mem_info_in_mb() + ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}") + + gm(*cm.args) + + free_mem_post, _ = _get_mem_info_in_mb() + ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}") + + memory_for_forward_pass = free_mem_pre - free_mem_post + ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}") + + new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size + new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages)) + + # Need to sync all the GPUs + gathered_num_pages = [None] * get_world_size() + all_gather_object(gathered_num_pages, new_num_pages) + new_num_pages = min(gathered_num_pages) + ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}") + + cm.resize_cache(new_num_pages) + except Exception as e: + ad_logger.warning( + f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize." 
+ ) + + # Free memory + torch.cuda.empty_cache() + + info = TransformInfo( + skipped=False, + num_matches=0, + is_clean=True, + has_valid_shapes=True, + ) + + return gm, info + + +@TransformRegistry.register("initialize_cache") +class InitializeCache(BaseTransform): + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + cm.initialize_caches() + + info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py new file mode 100644 index 0000000000..1772037d93 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py @@ -0,0 +1,148 @@ +"""Graph transform to optimize RMSNorm execution using FlashInfer.""" + +from functools import partial +from typing import Tuple, Type + +import torch +from pydantic import Field +from torch.fx import GraphModule + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface + +# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher +from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + +_BACKEND_OPS = { + "flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm, + "triton": torch.ops.auto_deploy.triton_rms_norm, + "torch": torch.ops.auto_deploy.torch_rmsnorm, +} + + +def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + """Implements the RMSNorm pattern for pattern matching. + + Args: + data: Input tensor to normalize. + weight: Scaling weights for the normalized output. + eps: Small constant for numerical stability. + + Returns: + Normalized and scaled tensor. + """ + input_dtype = data.dtype + data = data.to(torch.float32) + variance = data.pow(2).mean(-1, keepdim=True) + data = data * torch.rsqrt(variance + eps) + return weight * data.to(input_dtype) + + +def _rms_norm_replacement( + data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str +) -> torch.Tensor: + """Backend-specific rms_norm implementation. + + Args: + data: Input tensor to normalize. + weight: Scaling weights for the normalized output. + eps: Small constant for numerical stability. + backend: Backend to use for RMSNorm computation ("flashinfer" or "triton"). + + Returns: + Normalized and scaled tensor using the specified backend implementation. + """ + + assert backend.lower() in _BACKEND_OPS, ( + f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}" + ) + return _BACKEND_OPS[backend.lower()](data, weight, eps) + + +class FuseRMSNormConfig(TransformConfig): + """Configuration for the RMSNorm fusion transform.""" + + backend: str = Field( + default="flashinfer", + description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').", + ) + + +@TransformRegistry.register("fuse_rmsnorm") +class FuseRMSNorm(BaseTransform): + """Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation. + + This function sets up pattern matching to identify RMSNorm operations in the graph + and replaces them with optimized implementations. It uses dummy tensors to register + the pattern matching rules. + + Args: + gm: Input graph module to transform. 
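A quick numerical sanity check of the _rms_norm_pattern reference above against PyTorch's built-in torch.nn.functional.rms_norm (available since PyTorch 2.4); this validates only the eager pattern the matcher searches for, not the fused backend ops.

import torch
import torch.nn.functional as F

def rms_norm_pattern(data, weight, eps=1e-6):
    variance = data.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return weight * (data.to(torch.float32) * torch.rsqrt(variance + eps)).to(data.dtype)

x, w = torch.randn(2, 512), torch.randn(512)
ref = F.rms_norm(x, (512,), weight=w, eps=1e-6)
assert torch.allclose(rms_norm_pattern(x, w), ref, atol=1e-5)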
+ backend: Backend to use for RMSNorm computation ("flashinfer" or "triton"). + + Returns: + Transformed graph module with optimized RMSNorm operations. + """ + + config: FuseRMSNormConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return FuseRMSNormConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + if self.config.backend.lower() not in _BACKEND_OPS: + raise ValueError( + f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}" + ) + + graph = gm.graph + patterns = ADPatternMatcherPass() + + # Create dummy tensors for pattern matching + bs = 2 + hidden_size = 512 + + def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6): + return [ + torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype), + torch.randn(hidden_size, device="cuda", dtype=weight_dtype), + eps, + ] + + # Define configurations for different data types + configs = [ + (torch.bfloat16, torch.bfloat16), + (torch.float16, torch.float16), + (torch.float32, torch.float32), + ] + + # Register patterns for each configuration + for input_dtype, weight_dtype in configs: + register_ad_pattern( + search_fn=_rms_norm_pattern, + replace_fn=partial(_rms_norm_replacement, backend=self.config.backend), + patterns=patterns, + dummy_args=dummy_args(input_dtype, weight_dtype), + op_ignore_types={}, + scalar_workaround={"eps": 1e-6}, + ) + + cnt = patterns.apply(graph) + + info = TransformInfo(skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=False) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py index 0d4c388ebc..e5260ada48 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py @@ -1,11 +1,5 @@ """A library of transformation passes.""" -from .collectives import * -from .fused_moe import * -from .fusion import * -from .kvcache import * -from .rms_norm import * - try: from .visualization import visualize_namespace except ImportError: diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py deleted file mode 100644 index 8cec047561..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py +++ /dev/null @@ -1,167 +0,0 @@ -import operator - -import torch -from torch.fx import GraphModule - -from ...distributed.trtllm import is_trtllm_op_available -from ...utils.logger import ad_logger -from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op -from .._graph import canonicalize_graph - - -# TODO: This is an overly simplified model that works well for vanilla Llama models. -# However, we eventually want to consider more sophisticated patterns such as -# * all_reduce(lin1(x) + lin2(x)) -# * version above with fused GEMMs (i.e. with a split node) -# * all_reduce(pointwise_op(linear(x))) -# * ... -def fuse_collectives(gm: GraphModule) -> None: - num_gemm_collective_fusions = 0 - ad_logger.debug("Before GEMM+Collective fusion: " + str(gm)) - - # lookup for fused ops - # TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly. 
- lookup = { - torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, - torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, - torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce, - } - - # go through all nodes and find all_reduce nodes - for node in gm.graph.nodes: - if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): - continue - - # check if args are as expected - assert len(node.args) == 1 and not len(node.kwargs), "Unexpected args/kwargs for all_reduce" - - # retrieve parent and check a few conditions on the parent node - parent_node = node.args[0] - if not is_op(parent_node, lookup.keys()): - continue - if len(parent_node.users) > 1: - continue - - with gm.graph.inserting_before(node): - # insert fused node - fused_linear_collective_node = gm.graph.call_function( - lookup[get_op_overload_packet(parent_node.target)], - args=parent_node.args, - kwargs=parent_node.kwargs, - ) - node.replace_all_uses_with(fused_linear_collective_node) - gm.graph.erase_node(node) - gm.graph.erase_node(parent_node) - num_gemm_collective_fusions += 1 - - canonicalize_graph(gm) - ad_logger.info(f"Found {num_gemm_collective_fusions} GEMM+Collective fusions") - ad_logger.debug("After GEMM+Collective fusion: " + str(gm)) - - -def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> None: - """Essentially, this function fuses the following operators into one allreduce trtllm implementation. - - * target pattern: - x = all_reduce(x) - y = x + residual - return rmsnorm(y), y - * replacement: - fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps) - - """ - if not is_trtllm_op_available(): - return - - num_ar_r_rms_fusions = 0 - ad_logger.debug("Before allreduce+residual+rmsnorm fusion: " + str(gm)) - - def trace_and_fuse(allreduce_node, graph): - # Check if all_reduce is followed by addition - users = list(allreduce_node.users.keys()) - if len(users) != 1: - return # Skip if all_reduce has more than one consumer - add_node = users[0] - - # Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer - # the Huggingface LlamaRMSNorm implementation as example for more details - to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2) - # operand of pow and mul - pow_node = get_user_if_pattern_match( - to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2 - ) - mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1) - add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1) - rsqrt_node = get_user_if_pattern_match(add_eps_node, [torch.ops.aten.add, operator.add], 1) - mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1) - to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1) - mul_node_2 = get_user_if_pattern_match( - to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1 - ) - # check args of ops: pow(2) and mean(-1) - ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent - ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions - - # Match found: Replace with fused operation - if ( - to_copy_1 - and pow_node - and mean_node - and add_eps_node - and rsqrt_node - and mul_node_1 - and to_copy_2 - and mul_node_2 - and ARGS_MATCH - ): - # Gather the inputs for the custom operation - tensor = allreduce_node.args[0] - # Identify the residual argument in the add operation - # One of 
the args in add_node.args is the output of all_reduce - # The same idea also applies to norm_weight - residual = add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1] - norm_weight = ( - mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1] - ) - eps = add_eps_node.args[1] - - # Insert nodes - with graph.inserting_before(allreduce_node): - fused_node = graph.call_function( - torch.ops.dist.fused_allreduce_residual_rmsnorm, - args=( - tensor, - residual, - norm_weight, - eps, - ), - ) - # Extract outputs from the tuple returned by `fused_node` - final_output_node = gm.graph.create_node( - "call_function", - target=operator.getitem, - args=(fused_node, 0), - ) - add_output_node = gm.graph.create_node( - "call_function", - target=operator.getitem, - args=(fused_node, 1), - ) - - # Replace all uses of rmsnorm_node with final_output_node - mul_node_2.replace_all_uses_with(final_output_node) - - # Replace all uses of add_node with add_output_node - add_node.replace_all_uses_with(add_output_node) - - nonlocal num_ar_r_rms_fusions - num_ar_r_rms_fusions += 1 - - # Traverse all nodes - for node in gm.graph.nodes: - if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): - trace_and_fuse(allreduce_node=node, graph=gm.graph) - - canonicalize_graph(gm) - ad_logger.info(f"Found {num_ar_r_rms_fusions} allreduce+residual+rmsnorm fusions") - ad_logger.debug("After allreduce+residual+rmsnorm fusion: " + str(gm)) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py deleted file mode 100644 index e049970862..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py +++ /dev/null @@ -1,511 +0,0 @@ -from collections import defaultdict -from typing import Optional - -import torch -from torch.fx import GraphModule, Node - -from ...utils.cuda_mem_tracker import cuda_memory_tracker -from ...utils.logger import ad_logger -from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op -from ...utils.quantization_utils import get_scales_and_type_from_node -from .._graph import canonicalize_graph - - -def match_moe_pattern(gm: GraphModule) -> None: - graph = gm.graph - - ad_logger.debug("Before MoE Pattern Matching: " + str(gm)) - # Preprocessing: Identify boundary nodes (e.g. residual connections) in the graph. 
- boundary_nodes = identify_regions_between_residuals(gm) - - num_moe_patterns = 0 - - for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]): - # Step 1: Identify Expert Compute pattern - (pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) = ( - _match_expert_compute_pattern(start_boundary, end_boundary) - ) - if not expert_weights: - continue - # TODO: naming convention to verify the order of the weight nodes - - # Step 2: Trace upwards to locate normalize_routing_weight and selected_experts: - arg1_list, arg2_list = _extract_index_branches_from_expert_outputs(pattern_output_nodes) - normalized_routing_weights = _find_lowest_common_ancessor(arg1_list) - if not normalized_routing_weights: - continue - - common_ancessor2 = _find_lowest_common_ancessor(arg2_list) - if not common_ancessor2: - continue - selected_experts = bfs( - common_ancessor2, - lambda node: is_op(node, torch.ops.aten.one_hot), - attr_next="all_input_nodes", - boundary=start_boundary, - ).args[0] - if not selected_experts: - continue - - # Step 3: Trace upwards to find input node: - hidden_states = _find_lowest_common_ancessor(pattern_input_nodes) - if not hidden_states: - continue - - # Step 4: Find output node with the combine pattern - final_hidden_state_node = _find_final_hidden_state_node(pattern_output_nodes, end_boundary) - if final_hidden_state_node is None: - continue - - # Step 5: Insert the MoE op into the graph. - ad_logger.debug( - f"Found MoE Pattern: between boundary {start_boundary} and {end_boundary}.\n" - f"Input hidden states node: {hidden_states}, " - f"selected_experts node: {selected_experts}, " - f"routing_weights node: {normalized_routing_weights}, " - f"expert weights: {expert_weights}, weight type: {weight_type}" - ) - with graph.inserting_before(final_hidden_state_node): - w1_list = expert_weights["w1"] - w2_list = expert_weights["w2"] - w3_list = expert_weights["w3"] - - if weight_type == "fp8": - fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_quant_fp8_moe, - args=( - hidden_states, - selected_experts, - normalized_routing_weights, - w1_list, - w2_list, - w3_list, - expert_scales["w1_input_scale"], - expert_scales["w2_input_scale"], - expert_scales["w3_input_scale"], - expert_scales["w1_weight_scale"], - expert_scales["w2_weight_scale"], - expert_scales["w3_weight_scale"], - ), - ) - elif weight_type == "fp4": - fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_quant_fp4_moe, - args=( - hidden_states, - selected_experts, - normalized_routing_weights, - w1_list, - w2_list, - w3_list, - expert_scales["w1_input_scale"], - expert_scales["w2_input_scale"], - expert_scales["w3_input_scale"], - expert_scales["w1_weight_scale"], - expert_scales["w2_weight_scale"], - expert_scales["w3_weight_scale"], - expert_scales["w1_alpha"], - expert_scales["w2_alpha"], - expert_scales["w3_alpha"], - ), - ) - else: - fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_moe, - args=( - hidden_states, - selected_experts, - normalized_routing_weights, - w1_list, - w2_list, - w3_list, - ), - ) - - final_hidden_state_node.replace_all_uses_with(fused_moe_node) - graph.erase_node(final_hidden_state_node) - - while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary): - gm.graph.eliminate_dead_code() - - num_moe_patterns += 1 - - canonicalize_graph(gm) - - ad_logger.info(f"Found {num_moe_patterns} MoE Patterns") - ad_logger.debug("After MoE Pattern Matching: " + str(gm)) - - -def 
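As a rough illustration of the MoE computation being matched above, a plausible dense reference loop: per-expert SwiGLU outputs are combined with the normalized routing weights for each selected expert. Shapes and top-k are illustrative assumptions; the actual torch_moe op is a fused kernel and only its argument order is taken from the call above.

import torch
import torch.nn.functional as F

tokens, hidden, inter, num_experts, top_k = 4, 64, 128, 8, 2
x = torch.randn(tokens, hidden)
w1 = torch.randn(num_experts, inter, hidden)
w3 = torch.randn(num_experts, inter, hidden)
w2 = torch.randn(num_experts, hidden, inter)
selected = torch.randint(0, num_experts, (tokens, top_k))
routing = torch.softmax(torch.randn(tokens, top_k), dim=-1)

out = torch.zeros_like(x)
for t in range(tokens):
    for j in range(top_k):
        e = selected[t, j]
        h = F.silu(x[t] @ w1[e].T) * (x[t] @ w3[e].T)   # SwiGLU expert MLP
        out[t] += routing[t, j] * (h @ w2[e].T)          # routing-weighted combine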
fuse_moe(gm: torch.fx.GraphModule) -> None: - """ - Scan the FX graph and replace all calls to torch.ops.auto_deploy.torch_moe with - torch.ops.auto_deploy.trtllm_moe_fused. - """ - ad_logger.debug("Before MoE fusion: " + str(gm)) - - with cuda_memory_tracker(): - fused_key_counter = _insert_fused_moe_ops(gm) - if fused_key_counter: - canonicalize_graph(gm) - - ad_logger.info(f"Found {fused_key_counter} MoE fusions") - ad_logger.debug("After MoE fusion: " + str(gm)) - - -def _insert_fused_moe_ops(gm: GraphModule) -> int: - fused_key_counter = 0 - graph = gm.graph - - for node in list(graph.nodes): - if not is_op(node, torch.ops.auto_deploy.torch_moe): - continue - - ad_logger.debug(f"Found MoE op to fuse: {node} with args: {node.args}") - hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = node.args - - fused_w3_w1_experts = torch.stack( - [ - torch.cat( - [gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], dim=-2 - ) - for w1_node, w3_node in zip(w1_list, w3_list) - ], - dim=0, - ) - - fused_w2_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) - - new_key_w3_w1 = f"fused_moe_w3_w1_stacked_{fused_key_counter}" - new_key_w2 = f"fused_moe_w2_stacked_{fused_key_counter}" - fused_key_counter += 1 - param_w3_w1 = torch.nn.Parameter(fused_w3_w1_experts) - param_w2 = torch.nn.Parameter(fused_w2_experts) - gm.register_parameter(new_key_w3_w1, param_w3_w1) - gm.register_parameter(new_key_w2, param_w2) - - with graph.inserting_before(node): - new_node = graph.call_function( - # TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models - torch.ops.auto_deploy.trtllm_moe_fused, - args=( - hidden_states, - selected_experts, - routing_weights, - graph.get_attr(new_key_w3_w1), - graph.get_attr(new_key_w2), - ), - ) - - node.replace_all_uses_with(new_node) - graph.erase_node(node) - - return fused_key_counter - - -def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]: - """ - Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following - each node's primary branch (recursively following the first Node argument). - - It first finds the LCA of the first two nodes and then - iteratively computes the LCA of the result with the next node, and so on. - - Returns: - The common ancestor Node if found, otherwise None. - """ - if not nodes: - return None - - def get_parent(node: Node) -> Optional[Node]: - """Return the first Node-valued argument for a given node, or None if not found.""" - for arg in node.args: - if isinstance(arg, Node): - return arg - return None - - def get_depth(node: Node) -> int: - """ - Recursively compute the depth of the node by following its primary branch. - Depth is defined as the number of steps to reach a node with no parent. - """ - parent = get_parent(node) - if parent is None: - return 0 - return 1 + get_depth(parent) - - def lca_two(a: Node, b: Node) -> Optional[Node]: - """ - Find the lowest common ancestor of two nodes by first equalizing their depth - and then moving upward until a common node is found. - """ - depth_a = get_depth(a) - depth_b = get_depth(b) - - # Equalize depths - while depth_a > depth_b: - a = get_parent(a) - depth_a -= 1 - while depth_b > depth_a: - b = get_parent(b) - depth_b -= 1 - - # Walk upward in lockstep - while a is not None and b is not None: - if a is b: - return a - a = get_parent(a) - b = get_parent(b) - return None - - # Iteratively compute the LCA across all nodes. 
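For reference, the weight restacking done by _insert_fused_moe_ops above amounts to the following (per-expert shapes are illustrative): each expert's w3 and w1 are concatenated along the output dimension and all experts are stacked into a single tensor, likewise for w2.

import torch

num_experts, hidden, inter = 4, 64, 128
w1 = [torch.randn(inter, hidden) for _ in range(num_experts)]
w3 = [torch.randn(inter, hidden) for _ in range(num_experts)]
w2 = [torch.randn(hidden, inter) for _ in range(num_experts)]

fused_w3_w1 = torch.stack(
    [torch.cat([w3_e, w1_e], dim=-2) for w1_e, w3_e in zip(w1, w3)], dim=0
)                                    # [num_experts, 2 * inter, hidden]
fused_w2 = torch.stack(w2, dim=0)    # [num_experts, hidden, inter]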
- common = nodes[0] - for node in nodes[1:]: - common = lca_two(common, node) - if common is None: - return None - - return common - - -def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]: - """ - Given a linear op node, extract the input tensor node, weight tensor, - any quantization scales (if the op is quantized), and return a weight type. - - For a torch.ops.auto_deploy.torch_linear_simple.default op: - - Returns (input_node, weight, None, "simple") - - For a torch.ops.auto_deploy.torch_quant_fp8_linear op: - - Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale}, "fp8") - For a torch.ops.auto_deploy.torch_quant_fp4_linear op: - - Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale, "alpha": alpha}, "fp4") - """ - input_node = linear_node.args[0] - if is_op(linear_node, torch.ops.auto_deploy.torch_linear_simple): - weight = linear_node.args[1] - return input_node, weight, None, "" - elif { - is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear), - is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear), - }: - weight = linear_node.args[1] - scales, quant_type = get_scales_and_type_from_node(linear_node) - return input_node, weight, scales, quant_type - - -def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node): - """ - Match the expert compute pattern between the given boundaries. - - The expert compute pattern corresponds to: - - (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t() - - For each expert, the function extracts the input node from the w1 branch and - collects the weight parameters from three linear ops (w1, w3, and w2 branches). - - This function supports both: - - torch.ops.auto_deploy.torch_linear_simple.default ops, and - - torch.ops.auto_deploy.torch_quant_fp8_linear ops (also extracts quantization scales). - - torch.ops.auto_deploy.torch_quant_fp4_linear ops (also extracts quantization scales). - - Returns: - A tuple: - (pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) - - - pattern_input_nodes: List of input nodes (x) used for the expert compute. - - pattern_output_nodes: List of final expert output nodes (the linear op with weight w2). - - expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight tensors. - - expert_scales: Dict with keys "w1_input_scale", "w1_weight_scale", etc., containing scale tensors - (empty if weight_type is "simple"). - - weight_type: "fp8" if FP8 ops were used, "simple" otherwise. - """ - pattern_input_nodes, pattern_output_nodes = [], [] - expert_weights = defaultdict(list) - expert_scales = defaultdict(list) - weight_type = "simple" # default - - nodes = list(start_boundary.graph.nodes) - region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)] - - for node in region_nodes: - # Accept both simple and quantized linear ops. 
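The expert compute pattern documented above, (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t(), is the standard SwiGLU expert MLP; as a concrete reference with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(4, 64)                                 # hidden states
w1, w3 = torch.randn(128, 64), torch.randn(128, 64)    # gate / up projections
w2 = torch.randn(64, 128)                              # down projection

expert_out = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)  # [4, 64]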
- if not is_linear_op(node, include_quantization=True): - continue - - final_linear = node - if not final_linear.args or not isinstance(final_linear.args[0], Node): - continue - - mul_node = final_linear.args[0] - if not is_op(mul_node, torch.ops.aten.mul) or len(mul_node.args) < 2: - continue - - arg_a, arg_b = mul_node.args[:2] - silu_node = ( - arg_a - if is_op(arg_a, torch.ops.aten.silu) - else arg_b - if is_op(arg_b, torch.ops.aten.silu) - else None - ) - if silu_node is None: - continue - - if not (silu_node.args and is_linear_op(silu_node.args[0], include_quantization=True)): - continue - linear_w1_node = silu_node.args[0] - - # The other branch should be a linear op (w3 branch). - linear_w3_node = arg_b if arg_a is silu_node else arg_a - if not is_linear_op(linear_w3_node, include_quantization=True): - continue - if not (linear_w1_node.args and linear_w3_node.args): - continue - - # Extract parameters from each linear op. - input_node_w1, weight_w1, quant_params_w1, wt_type_w1 = _extract_linear_parameters( - linear_w1_node - ) - _, weight_w3, quant_params_w3, wt_type_w3 = _extract_linear_parameters(linear_w3_node) - _, weight_w2, quant_params_w2, wt_type_w2 = _extract_linear_parameters(final_linear) - - if None in (weight_w1, weight_w3, weight_w2): - continue - - # Ensure the weight type is consistent across branches. - if wt_type_w1 != wt_type_w3 or wt_type_w1 != wt_type_w2: - continue - weight_type = wt_type_w1 - - pattern_input_nodes.append(input_node_w1) - pattern_output_nodes.append(final_linear) - expert_weights["w1"].append(weight_w1) - expert_weights["w3"].append(weight_w3) - expert_weights["w2"].append(weight_w2) - - # TODO: sanity check that all experts have same weight type - if weight_type == "fp8": - expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"]) - expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"]) - expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"]) - expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"]) - expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"]) - expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"]) - elif weight_type == "fp4": - expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"]) - expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"]) - expert_scales["w1_alpha"].append(quant_params_w1["alpha"]) - expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"]) - expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"]) - expert_scales["w3_alpha"].append(quant_params_w3["alpha"]) - expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"]) - expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"]) - expert_scales["w2_alpha"].append(quant_params_w2["alpha"]) - - return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type - - -def _find_final_hidden_state_node( - pattern_output_nodes: list[Node], end_boundary: Node -) -> Optional[Node]: - """ - Identify the final hidden state node corresponding to the combine pattern: - - (expert_output * routing_weight) → index_add_ - - For each expert output node (from the expert compute pattern), this function: - 1. Retrieves a multiplication node from its users. - 2. Extracts the second argument from the multiplication node (assumed to be the index node). - 3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary). 
- - After collecting all such index_add_ nodes, the final hidden state node is determined - as the one that is not used by any of the other index_add_ nodes. - - If any required attribute (users or args) is missing during the process or if no valid - final node is found, the function returns None. - """ - - if not pattern_output_nodes: - return None - - index_add_nodes = [] - for node in pattern_output_nodes: - if not node.users: - return None - mul_node = next(iter(node.users)) - if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2): - return None - index_node = mul_node.args[1] - index_add_node = bfs( - index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary - ) - if not index_add_node: - return None - index_add_nodes.append(index_add_node) - - # The final node is defined as the index_add_node that is not used by any other index_add_nodes - return next( - ( - candidate - for candidate in index_add_nodes - if not any( - candidate in other.args for other in index_add_nodes if candidate is not other - ) - ), - None, - ) - - -def _extract_index_branches_from_expert_outputs( - pattern_output_nodes: list[Node], -) -> tuple[list[Node], list[Node]]: - """ - Extract routing and experts branches from expert outputs. - - For each expert output, find its multiplication user. From the - multiplication node's second argument (an index node), - extract: - - The first argument as the routing branch. - - The second argument (flattened if a list/tuple) as the experts branch. - - Returns: - A tuple (routing_branches, experts_branches). - """ - routing_branches, experts_branches = [], [] - for out in pattern_output_nodes: - mul = next((u for u in out.users if is_op(u, torch.ops.aten.mul)), None) - if not mul or len(mul.args) < 2: - continue - idx_node = mul.args[1] - if not is_op(idx_node, torch.ops.aten.index): - continue - routing_branches.append(idx_node.args[0]) - experts = idx_node.args[1] - experts_branches.extend(experts) if isinstance( - experts, (list, tuple) - ) else experts_branches.append(experts) - return routing_branches, experts_branches - - -def _remove_dead_inplace_nodes_in_region( - graph: torch.fx.Graph, - start_boundary: torch.fx.Node, - end_boundary: torch.fx.Node, -) -> bool: - """ - Searches (via BFS) for a dead in-place node (index_add_) in the region - between start_boundary and end_boundary. If one is found, it is removed from the graph. - Returns True if a node was removed, False otherwise. 
- """ - - def target(n: torch.fx.Node) -> bool: - return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0 - - try: - node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary) - ad_logger.debug(f"Removing In-place Dead Node: {node_to_remove}") - graph.erase_node(node_to_remove) - return True - except RuntimeError: - return False diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py deleted file mode 100644 index 618c8108f8..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Graph transformation to automatically add kv cache into fused MHA op.""" - -import operator -from typing import Dict, Type - -import torch -from torch.fx import Graph, GraphModule, Node - -from ...custom_ops.attention_interface import AttentionDescriptor, CacheConfig -from ...distributed.common import all_gather_object, get_world_size -from ...shim.interface import CachedSequenceInterface -from ...utils.logger import ad_logger -from ...utils.node_utils import get_all_input_output_nodes, is_op -from .._graph import add_graph_input, canonicalize_graph - - -def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None: - """Modify the graph module by adding new input nodes and canonicalizing the graph. - - The new input nodes correspond to the extra arguments needed for cached and flattened attention. - - Args: - egm: The graph module to analyze and modify. - cm: Cached sequence interface containing extra argument information. - """ - # loop through nodes to get input, output, and get_attr nodes - input_nodes, output_nodes = get_all_input_output_nodes(egm.graph) - - # we only expect one input node - assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)." - - # NOTE: for now, we wanna make sure we *only* return the final output and no hidden states. - # Later on, we can revisit how to support returning hidden states. - assert len(output_nodes) == 1, "Expected exactly one output node!" - assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!" - - ad_logger.info(f"Found {len(input_nodes)} input nodes and {len(output_nodes)} output nodes") - - # Activate and add extra argument nodes - new_args = cm.info.switch_to_cached_attn_inputs() - for name in new_args: - input_nodes.append(add_graph_input(egm, name)) - ad_logger.info(f"Added {len(new_args)} new input nodes for cached attention metadata") - - canonicalize_graph(egm) - - -def insert_cached_attention( - egm: GraphModule, - cm: CachedSequenceInterface, - attn_descriptor: Type[AttentionDescriptor], - cache_config: CacheConfig, -) -> None: - """Replace uncached source attention node with corresponding cached attn node.""" - # Get all attention nodes and their info objects - source_op = attn_descriptor.get_source_attention_op() - - # pick up graph - graph: Graph = egm.graph - - # look for relevant source attention nodes - source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)] - - if not source_attn_nodes: - # If there are no nodes for kv cache insertion found, return current graph - return - - # Sanity check - if cm.info.is_paged: - assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op." 
- - ad_logger.debug(f"Before inserting {attn_descriptor=} with cache: {egm}") - - # retrieve input nodes - input_nodes, _ = get_all_input_output_nodes(egm.graph) - - # insert metadata computation and extract each argument as a node - get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op() - with graph.inserting_before(input_nodes[-1].next): - ret_node = graph.call_function( - get_metadata, - args=( - *input_nodes, - cm.info.page_size, - ), - ) - metadata_nodes = [ - graph.call_function(operator.getitem, args=(ret_node, idx)) - for idx in range(num_metadata) - ] - - buffer_in_lookup: Dict[str, Node] = {} - - # replace fused attention node with attention node that has kv cache - num_cached_attn_replacements = 0 - for idx, attn_node in enumerate(source_attn_nodes): - # pick out GEMMs - qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()] - - # setup + store cache initializers and caches as input nodes - cache_in_nodes = [] - for k, get_cache in attn_descriptor.get_cache_initializers(attn_node, cache_config).items(): - k_indexed = f"{k}_{idx}" - cm.add_cache(k_indexed, get_cache) - cache_in_nodes.append(add_graph_input(egm, k_indexed)) - - # setup + store global buffer initializers and buffers as input nodes - # NOTE: we have to check against existing keys to make sure nothing is registered twice... - buffer_in_nodes = [] - for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items(): - if k not in buffer_in_lookup: - cm.add_cache(k, get_buffer) - buffer_in_lookup[k] = add_graph_input(egm, k) - buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op - - # retrieve constants for attention_op - constants = attn_descriptor.get_constants(attn_node) - - # insert cached attention replacement op - with graph.inserting_before(attn_node): - cached_attn_node = graph.call_function( - attn_descriptor.get_cached_attention_op(), - args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants), - ) - attn_node.replace_all_uses_with(cached_attn_node) - graph.erase_node(attn_node) - num_cached_attn_replacements += 1 - - canonicalize_graph(egm) - ad_logger.info( - f"Replaced {num_cached_attn_replacements} {source_op} ops " - f"with {attn_descriptor.get_cached_attention_op()}" - ) - ad_logger.debug(f"After inserting {attn_descriptor=} with cache: {egm}") - - -def resize_kv_cache( - egm: GraphModule, - cm: CachedSequenceInterface, - free_mem_ratio: float = 0.8, -) -> None: - """Inflate the kv cache to occupy the available GPU memory. - - free_mem_ratio specifies the fraction of available memory to occupy. 
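Editor's note: the cached-attention replacement loop above is an instance of the standard torch.fx rewrite idiom (insert the new call, reroute users, erase the old node). A stripped-down, generic sketch of that idiom, separate from the transform itself; the function and argument names here are illustrative only:

from torch.fx import GraphModule, Node

def swap_call(gm: GraphModule, old_node: Node, new_target, extra_args=()) -> Node:
    """Insert a call to `new_target` before `old_node`, reroute all users, erase the old node."""
    graph = gm.graph
    with graph.inserting_before(old_node):
        new_node = graph.call_function(new_target, args=(*old_node.args, *extra_args))
    old_node.replace_all_uses_with(new_node)
    graph.erase_node(old_node)
    gm.recompile()  # keep the generated forward in sync with the edited graph
    return new_node
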
- """ - - def _get_mem_info_in_mb(): - free_mem, total_mem = torch.cuda.mem_get_info() - return free_mem // 1024**2, total_mem // 1024**2 - - free_mem, total_mem = _get_mem_info_in_mb() - ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}") - current_cache_size = cm.current_cache_size_bytes() - current_num_pages = cm.info.num_pages - ad_logger.info( - f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}" - ) - - if free_mem_ratio == 0.0: - ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}") - return - - try: - # Let's run a forward pass to get the memory usage - cm.info._set_max_num_tokens_sample() - free_mem_pre, _ = _get_mem_info_in_mb() - ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}") - - egm(*cm.args) - - free_mem_post, _ = _get_mem_info_in_mb() - ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}") - - memory_for_forward_pass = free_mem_pre - free_mem_post - ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}") - - new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size - new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages)) - - # Need to sync all the GPUs - gathered_num_pages = [None] * get_world_size() - all_gather_object(gathered_num_pages, new_num_pages) - new_num_pages = min(gathered_num_pages) - ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}") - - cm.resize_cache(new_num_pages) - except Exception as e: - ad_logger.warning( - f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize." - ) - - # Free memory - torch.cuda.empty_cache() diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py deleted file mode 100644 index a94758b181..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Graph transform to optimize RMSNorm execution using FlashInfer.""" - -from functools import partial - -import torch -from torch.fx import GraphModule - -from ...utils.logger import ad_logger - -# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher -from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern -from .._graph import canonicalize_graph - -_BACKEND_OPS = { - "flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm, - "triton": torch.ops.auto_deploy.triton_rms_norm, - "torch": torch.ops.auto_deploy.torch_rmsnorm, -} - - -def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: - """Implements the RMSNorm pattern for pattern matching. - - Args: - data: Input tensor to normalize. - weight: Scaling weights for the normalized output. - eps: Small constant for numerical stability. - - Returns: - Normalized and scaled tensor. - """ - input_dtype = data.dtype - data = data.to(torch.float32) - variance = data.pow(2).mean(-1, keepdim=True) - data = data * torch.rsqrt(variance + eps) - return weight * data.to(input_dtype) - - -def _rms_norm_replacement( - data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str -) -> torch.Tensor: - """Backend-specific rms_norm implementation. - - Args: - data: Input tensor to normalize. - weight: Scaling weights for the normalized output. - eps: Small constant for numerical stability. - backend: Backend to use for RMSNorm computation ("flashinfer" or "triton"). 
- - Returns: - Normalized and scaled tensor using the specified backend implementation. - """ - - assert backend.lower() in _BACKEND_OPS, ( - f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}" - ) - return _BACKEND_OPS[backend.lower()](data, weight, eps) - - -def fuse_rmsnorm(gm: GraphModule, backend: str = "triton") -> None: - """Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation. - - This function sets up pattern matching to identify RMSNorm operations in the graph - and replaces them with optimized implementations. It uses dummy tensors to register - the pattern matching rules. - - Args: - gm: Input graph module to transform. - backend: Backend to use for RMSNorm computation ("flashinfer" or "triton"). - - Returns: - Transformed graph module with optimized RMSNorm operations. - """ - if backend.lower() not in _BACKEND_OPS: - raise ValueError(f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {backend}") - ad_logger.info(f"Starting RMSNorm pattern matching with backend: {backend}") - - graph = gm.graph - patterns = ADPatternMatcherPass() - - # Create dummy tensors for pattern matching - bs = 2 - hidden_size = 512 - - def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6): - return [ - torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype), - torch.randn(hidden_size, device="cuda", dtype=weight_dtype), - eps, - ] - - # Define configurations for different data types - configs = [ - (torch.bfloat16, torch.bfloat16), - (torch.float16, torch.float16), - (torch.float32, torch.float32), - ] - - # Register patterns for each configuration - for input_dtype, weight_dtype in configs: - register_ad_pattern( - search_fn=_rms_norm_pattern, - replace_fn=partial(_rms_norm_replacement, backend=backend), - patterns=patterns, - dummy_args=dummy_args(input_dtype, weight_dtype), - op_ignore_types={}, - scalar_workaround={"eps": 1e-6}, - ) - - cnt = patterns.apply(graph) - ad_logger.info(f"RMSNorm pattern count: {cnt}") - canonicalize_graph(gm) - ad_logger.debug("RMSNorm pattern matching completed.") diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index c841b4601f..931c8ec955 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -5,21 +5,11 @@ import gc import torch import torch.nn as nn -from ..compile import compile_and_capture from ..custom_ops.attention_interface import AttentionRegistry from ..llm_args import AutoDeployConfig from ..models.factory import ModelFactory from ..shim.interface import CachedSequenceInterface from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer -from ..utils.logger import ad_logger -from .library import ( - fuse_allreduce_residual_rmsnorm, - fuse_collectives, - fuse_rmsnorm, - insert_cached_attention, - resize_kv_cache, - update_in_out_nodes, -) class InferenceOptimizer: @@ -55,88 +45,60 @@ class InferenceOptimizer: self.ad_config.attn_backend ).get_attention_layout() - new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms) - - # TODO (hg): similar to above. 
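Editor's note: for reference, the pattern that the deleted fuse_rmsnorm pass matches (and the relocated transform still matches) is plain RMSNorm, y = weight * x / sqrt(mean(x^2) + eps), with the reduction done in float32. A minimal standalone reference one might use to sanity-check any of the backend ops; CPU tensors are used here, whereas the real pass registers CUDA dummy args:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same computation as the matched pattern: reduce in float32, cast back, then scale.
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(-1, keepdim=True)
    return weight * (x32 * torch.rsqrt(variance + eps)).to(x.dtype)

x = torch.randn(2, 512, dtype=torch.float16)
w = torch.randn(512, dtype=torch.float16)
y = rms_norm_reference(x, w)  # any fused backend op should agree with this within tolerance
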
- if "load_weights" in new_optimizer.config: - new_optimizer.config[ + if "load_weights" in self.ad_config.transforms: + self.ad_config.transforms[ "load_weights" ].checkpoint_device = self.ad_config.checkpoint_device - new_optimizer.config["load_weights"].device = cm.device + self.ad_config.transforms["load_weights"].device = cm.device + + if "resize_kv_cache" in self.ad_config.transforms: + self.ad_config.transforms[ + "resize_kv_cache" + ].free_mem_ratio = self.ad_config.free_mem_ratio + if "insert_cached_attention" in self.ad_config.transforms: + self.ad_config.transforms[ + "insert_cached_attention" + ].attn_backend = self.ad_config.attn_backend + if "insert_cached_mla_attention" in self.ad_config.transforms: + self.ad_config.transforms[ + "insert_cached_mla_attention" + ].attn_backend = self.ad_config.mla_backend + + # TODO: (hg)Missing MLA here. Figure out how to add MLA since duplicate transforms are not allowed. + # Old code: + # detect attention op and replace with cache-aware op + # for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]: + # attn_descriptor = AttentionRegistry.get(a_backend) + # insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config()) + + if "compile_model" in self.ad_config.transforms: + self.ad_config.transforms[ + "compile_model" + ].cuda_graph_batch_sizes = self.ad_config.cuda_graph_batch_sizes + self.ad_config.transforms[ + "compile_model" + ].compile_backend = self.ad_config.compile_backend + + new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms) + # TODO: (hg) move this. let match_rope_layout and match_atten_layout use this shared config + new_optimizer.shared_config.attn_backend = self.ad_config.attn_backend egm = new_optimizer(cm) - # TODO (lucaslie): continue moving legacy transforms to the new optimizer - ############################################################################################ - # RUN POST-LOAD FUSION AND OPTIMIZATIONS - ############################################################################################ + # NOTE: (hg)Disabled visualization since compiled gm is a CapturedGraph instead of GraphModule. + # We can add a new stage in the optimizer to visualize the intermediate gm. + # if self.ad_config.visualize: + # try: + # from .library import visualize_namespace - # run MoE fusion - # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs - # fuse_moe(egm) - - # run GEMM fusion - # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs - # fuse_gemms(egm) - - # check if we can fuse allreduce, residual and rmsnorm - fuse_allreduce_residual_rmsnorm(egm) - - # check if we can fuse collectives - fuse_collectives(egm) - - # TODO (lucaslie): add backend selection as part of configurable inference optimizers - # check if we can fuse rmsnorm - fuse_rmsnorm(egm, "flashinfer") - - # visualize the final graph - if self.ad_config.visualize: - try: - from .library import visualize_namespace - - visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes) - ad_logger.warning( - "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize" - " the graph." 
- ) - except ImportError: - pass - - ############################################################################################ - # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES - ############################################################################################ - - update_in_out_nodes(egm, cm) - - # detect attention op and replace with cache-aware op - for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]: - attn_descriptor = AttentionRegistry.get(a_backend) - insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config()) - - # initialize cache on correct device - cm.initialize_caches() - - # resize kv cache to occupy the available GPU memory up to free_mem_ratio - resize_kv_cache(egm, cm, free_mem_ratio=self.ad_config.free_mem_ratio) - - ############################################################################################ - # COMPILE MODEL - ############################################################################################ - - cm.info.set_generate_only_batch() - compiler_kwargs = { - "cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes, - "num_batched_inputs": 2, # TODO (lucaslie): improve once we have a config system... - } - egm_compiled = compile_and_capture( - egm, - self.ad_config.compile_backend, - args=cm.args, - dynamic_shapes=cm.dynamic_shapes, - compiler_kwargs=compiler_kwargs, - ) - cm.info.reset() + # visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes) + # ad_logger.warning( + # "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize" + # " the graph." + # ) + # except ImportError: + # pass torch.cuda.empty_cache() gc.collect() - return egm_compiled + return egm diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py index c81ca0ae1c..58d22302f2 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py @@ -8,9 +8,7 @@ from torch.export import export from tensorrt_llm._torch.auto_deploy.distributed import common as dist from tensorrt_llm._torch.auto_deploy.distributed.trtllm import is_trtllm_op_available from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.transformations.library.collectives import ( - fuse_allreduce_residual_rmsnorm, -) +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op from tensorrt_llm.llmapi.mpi_session import MpiPoolSession @@ -65,14 +63,21 @@ def _test_allreduce_fusion(port: int): original_outputs, residual_original = gm(x, residual) # Fuse ops - fuse_allreduce_residual_rmsnorm(gm) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_allreduce_residual_rmsnorm": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) # Run the fused graph - fused_outputs, residual_fused = gm(x, residual) + fused_outputs, residual_fused = gm_transformed(x, residual) # Check if fused node in the graph has_fused_node = False - for node in gm.graph.nodes: + for node in gm_transformed.graph.nodes: if is_op(node, torch.ops.dist.fused_allreduce_residual_rmsnorm): has_fused_node = True assert has_fused_node, "Fused 
node not found." @@ -86,8 +91,8 @@ def _test_allreduce_fusion(port: int): ) # check if we can still export the model as expected - export(gm, args=args) - torch_export_to_gm(gm, args=args) + export(gm_transformed, args=args) + torch_export_to_gm(gm_transformed, args=args) @pytest.mark.parametrize("device_count", get_device_counts()) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py index 4aa1a875c4..ed3b98f281 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py @@ -8,12 +8,13 @@ import torch import torch.nn as nn import torch.nn.functional as F from _dist_test_utils import get_device_counts -from _graph_test_helpers import run_test +from _graph_test_helpers import run_test_transformed_gm from _torch_test_utils import fp8_compatible import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear -from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_collectives +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -61,11 +62,21 @@ def _run_job( is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm.graph.nodes ) + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_collectives": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) + # now run the test - run_test( + run_test_transformed_gm( model, x, - transform=fuse_collectives, + gm_transformed, check_transformed_graph=check_transformed_graph, _get_expected_num_params=_get_expected_num_params, test_load_hook=False, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py index be2f9d52af..691aad78c5 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py @@ -1,12 +1,11 @@ -from functools import partial - import pytest import torch -from _graph_test_helpers import run_test +from _graph_test_helpers import run_test_transformed_gm from torch.export import Dim from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import * # noqa -from tensorrt_llm._torch.auto_deploy.transformations.library.rms_norm import fuse_rmsnorm +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -52,15 +51,28 @@ def test_rmsnorm_fusion(eps, variant, op): return any(is_op(n, op) for n in gm.graph.nodes) model = TestModel(eps) - gm_transformed = run_test( + x = torch.randn(2, 1024, device="cuda", dtype=torch.float16) + dynamic_shapes = {0: Dim("batch_size", max=8)} + gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_rmsnorm": { + "stage": 
"post_load_fusion", + "backend": variant, + }, + }, + )(None, gm) + + run_test_transformed_gm( model, - torch.randn(2, 1024, device="cuda", dtype=torch.float16), - partial(fuse_rmsnorm, backend=variant), + x, + gm_transformed, checker, lambda num_p_og: num_p_og, - dynamic_shapes={0: Dim("batch_size", max=8)}, + dynamic_shapes=dynamic_shapes, ) - print(gm_transformed.graph) + new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16) y_transformed = gm_transformed(new_input) y_model = model(new_input) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py index 82a5104503..b99862fdc1 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py @@ -7,11 +7,12 @@ import pytest import torch import torch.nn as nn import torch.nn.functional as F -from _graph_test_helpers import count_buffers, run_test +from _graph_test_helpers import count_buffers, run_test_transformed_gm from _torch_test_utils import all_close, fp8_compatible, reset_parameters from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear -from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_gemms +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op torch.manual_seed(0) @@ -254,10 +255,20 @@ def test_fusion(get_model: Callable[[], TestModel], dtype: str): buffer_size_before = count_buffers(model) - gm_transformed = run_test( + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_gemms": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) + + run_test_transformed_gm( model, x, - fuse_gemms, + gm_transformed, lambda gm: sum(is_linear_op(n, include_quantization=True) for n in gm.graph.nodes) == model.num_gemms_after_fusion, lambda num_p_og: num_p_og, # unchanged since fusing doesn't change param count diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index f2fd32ea3e..9266027e11 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -2,19 +2,35 @@ from typing import Optional import pytest import torch -from _graph_test_helpers import FakeFactory +from _graph_test_helpers import SequenceEmbeddingInfo from _model_test_utils import GQA from _torch_test_utils import all_close -from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig, SequenceInfo -from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_attention import FlashInferAttention -from tensorrt_llm._torch.auto_deploy.custom_ops.triton_attention import TritonAttention +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface -from 
tensorrt_llm._torch.auto_deploy.transform.interface import InferenceOptimizerConfig from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer -from tensorrt_llm._torch.auto_deploy.transformations.library import update_in_out_nodes -from tensorrt_llm._torch.auto_deploy.transformations.library.kvcache import insert_cached_attention + + +class DummyFactory(ModelFactory): + """Dummy factory to pass cache_config for testing.""" + + def __init__(self, model, cache_config): + self._model = model + self.cache_config = cache_config + + def build_model(self, device: str): + return self._model.to(device=device) + + def _build_model(self, device: str): + return + + def _load_checkpoint(self, model, device): + return + + def get_cache_config(self): + return self.cache_config # Class that uses SDPA directly instead of the regular attention mechanism @@ -68,42 +84,6 @@ class GQAWithSdpa(GQA): return self.o_proj(attn_output) -def _get_optimizer_config() -> InferenceOptimizerConfig: - return { - "build_model": { - "stage": "factory", - "device": "cuda", - "run_graph_cleanup": False, - "requires_clean_graph": False, - }, - "export_to_gm": { - "stage": "export", - "strict": False, - "clone_state_dict": True, - "run_graph_cleanup": False, - "requires_clean_graph": False, - }, - "cleanup_input_constraints": { - "stage": "post_export", - }, - } - - -class SequenceEmbeddingInfo(SequenceInfo): - hidden_size: int - dtype: torch.dtype - - def set_example_sequence(self) -> None: - super().set_example_sequence() - # set input ids to a 3D tensor (actually input embeddings) - self.input_ids = torch.rand( - *self.input_ids.shape, - self.hidden_size, - device=self.input_ids.device, - dtype=self.dtype, - ) - - # TODO (lucaslie): consider rewriting this test with a custom InferenceOptimizer config @pytest.mark.parametrize( "dtype", @@ -111,8 +91,8 @@ class SequenceEmbeddingInfo(SequenceInfo): ids=["float16", "float32"], ) @pytest.mark.parametrize( - "attn_descriptor", - [TritonAttention, FlashInferAttention], + "attn_backend", + ["triton", "flashinfer"], ids=["triton", "flashinfer"], ) @pytest.mark.parametrize( @@ -125,10 +105,10 @@ class SequenceEmbeddingInfo(SequenceInfo): ids=["regular", "gqa", "mqa"], ) @torch.inference_mode() -def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config): +def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): """Test the SDPA transformation with KV cache.""" # flashinfer doesn't support float32 data type - if attn_descriptor == FlashInferAttention and dtype == torch.float32: + if attn_backend == "flashinfer" and dtype == torch.float32: pytest.skip("flashinfer doesn't support float32 data type") # Unpack the GQA configuration @@ -157,7 +137,6 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config): hidden_size, num_key_value_heads, ).to(dtype=dtype, device="cuda") - factory = FakeFactory(model) # Create input tensor and position_ids x = torch.rand(batch_size, seq_len, hidden_size).to(device="cuda", dtype=dtype) @@ -166,21 +145,37 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config): # Get the model's regular output y_model = model(x, position_ids) # b, s, d - # run modular inference optimizer up to post_export - optimizer = InferenceOptimizer(factory, _get_optimizer_config()) # type: ignore + # Apply the transformation + optimizer = InferenceOptimizer( + DummyFactory(model, CacheConfig()), + { + "build_model": { + "stage": "factory", + "device": "cuda", + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + 
"export_to_gm": { + "stage": "export", + "strict": False, + "clone_state_dict": True, + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "cleanup_input_constraints": { + "stage": "post_export", + }, + "update_in_out_nodes": { + "stage": "cache_init", + }, + "insert_cached_attention": { + "stage": "cache_init", + "attn_backend": attn_backend, + }, + }, + ) # type: ignore gm = optimizer(cm) - y_gm = gm(x, position_ids) - assert all_close(y_model, y_gm, atol=atol, rtol=rtol) - - # Set up cache configuration - cache_config = CacheConfig() - - # Get input node(s) - update_in_out_nodes(gm, cm) - - # Apply the transformation - insert_cached_attention(gm, cm, attn_descriptor=attn_descriptor, cache_config=cache_config) gm.to("cuda") cm.initialize_caches()