Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-09 12:41:52 +08:00)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Co-authored-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Parent: 344bc4575d
Commit: e18dacc931
@@ -19,6 +19,11 @@ transforms:
    stage: post_export
  cleanup_input_constraints:
    stage: post_export
  ############################################################################################
  # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
  ############################################################################################
  match_moe_pattern:
    stage: pattern_matcher
  match_repeat_kv:
    stage: pattern_matcher
  match_eager_attention:
@@ -27,12 +32,13 @@ transforms:
    stage: pattern_matcher
  match_attention_layout:
    stage: pattern_matcher
  match_moe_pattern:
    stage: pattern_matcher
  match_rope_pattern:
    stage: pattern_matcher
  match_rope_layout:
    stage: pattern_matcher
  ############################################################################################
  # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
  ############################################################################################
  eliminate_redundant_transposes:
    stage: pattern_matcher
  # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
@@ -57,5 +63,44 @@ transforms:
  sharding_transform_executor:
    stage: sharding
    run_shape_prop: true
  ############################################################################################
  # MOVE MODEL AND LOAD WEIGHTS
  ############################################################################################
  load_weights:
    stage: weight_load
  ############################################################################################
  # RUN POST-LOAD FUSION AND OPTIMIZATIONS
  ############################################################################################
  # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
  # fuse_moe:
  #   stage: post_load_fusion
  # fuse_gemms:
  #   stage: post_load_fusion
  fuse_allreduce_residual_rmsnorm:
    stage: post_load_fusion
  fuse_collectives:
    stage: post_load_fusion
  # TODO (lucaslie): add backend selection as part of configurable inference optimizers
  # check if we can fuse rmsnorm
  fuse_rmsnorm:
    stage: post_load_fusion
    backend: flashinfer
  ############################################################################################
  # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
  ############################################################################################
  update_in_out_nodes:
    stage: cache_init
  insert_cached_attention:
    stage: cache_init
  insert_cached_mla_attention:
    stage: cache_init
    attn_backend: MultiHeadLatentAttention
  initialize_cache:
    stage: cache_init
  resize_kv_cache:
    stage: cache_init
  ############################################################################################
  # COMPILE MODEL
  ############################################################################################
  compile_model:
    stage: compile
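Each key under `transforms:` above is resolved by name against the transform registry that the rest of this diff builds on (`TransformRegistry.register`, `BaseTransform._apply`, `TransformInfo`). As a rough, self-contained sketch of that registration contract — the registry class below is a simplified stand-in, not the library code, and `my_custom_pass` is a hypothetical entry — a pass that could be listed in this YAML looks roughly like:

# Minimal, self-contained illustration of the register-by-name pattern used above.
# The real classes live in tensorrt_llm/_torch/auto_deploy/transform/interface.py;
# everything here is a simplified stand-in, not the library API.
from typing import Callable, Dict, Type


class TransformRegistry:
    """Maps the YAML transform names (e.g. 'fuse_rmsnorm') to transform classes."""

    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        def decorator(transform_cls: Type) -> Type:
            cls._registry[name] = transform_cls
            return transform_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> Type:
        return cls._registry[name]


@TransformRegistry.register("my_custom_pass")
class MyCustomPass:
    """Would run at whatever `stage:` the YAML entry assigns to 'my_custom_pass'."""

    def _apply(self, gm, cm, factory, shared_config):
        # A real transform rewrites gm.graph and returns (gm, TransformInfo(...)).
        return gm, None


print(TransformRegistry.get("my_custom_pass"))  # -> <class '__main__.MyCustomPass'>

The registered string is what the YAML keys refer to; the `stage:` field assigns the pass to one of the pipeline phases named in the banner comments.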
@@ -54,6 +54,7 @@ class SharedConfig(BaseModel):
    sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
    local_rank: int = Field(default=0)
    world_size: int = Field(default=1)
    attn_backend: str = Field(default="flashinfer", description="The attention backend to use.")


class TransformConfig(BaseModel):
@@ -285,7 +286,10 @@ class BaseTransform(ABC):
        # update + store new meta data
        history[t_name] = info
        autodeploy_meta[self._history_key] = history
        self._set_autodeploy_meta(gm, autodeploy_meta)

        if isinstance(gm, GraphModule):
            # After compilation, gm becomes type CapturedGraph with no meta data.
            self._set_autodeploy_meta(gm, autodeploy_meta)

        # return the graph module
        return gm
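The new `attn_backend` field on `SharedConfig` pairs with the per-transform `attn_backend` option added in kvcache.py below, which falls back to the shared value when unset. A minimal sketch of that fallback, assuming plain pydantic models (the `resolve_attn_backend` helper is illustrative, not library API):

# Hedged sketch of the "per-transform override with shared fallback" pattern implied by
# SharedConfig.attn_backend and InsertCachedAttentionConfig.attn_backend later in this diff.
from typing import Optional

from pydantic import BaseModel, Field


class SharedConfig(BaseModel):
    attn_backend: str = Field(default="flashinfer", description="The attention backend to use.")


class InsertCachedAttentionConfig(BaseModel):
    attn_backend: Optional[str] = Field(default=None, description="The attention backend to use.")


def resolve_attn_backend(cfg: InsertCachedAttentionConfig, shared: SharedConfig) -> str:
    # Prefer the transform-level setting; otherwise fall back to the shared config.
    return cfg.attn_backend or shared.attn_backend


print(resolve_attn_backend(InsertCachedAttentionConfig(), SharedConfig()))  # flashinfer
print(resolve_attn_backend(InsertCachedAttentionConfig(attn_backend="triton"), SharedConfig()))  # triton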
tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py (new file, 204 lines)
@@ -0,0 +1,204 @@
import operator
from typing import Tuple

import torch
from torch.fx import GraphModule

from ...distributed.trtllm import is_trtllm_op_available
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry

# TODO: This is an overly simplified model that works well for vanilla Llama models.
# However, we eventually want to consider more sophisticated patterns such as
# * all_reduce(lin1(x) + lin2(x))
# * version above with fused GEMMs (i.e. with a split node)
# * all_reduce(pointwise_op(linear(x)))
# * ...


@TransformRegistry.register("fuse_collectives")
class FuseCollectives(BaseTransform):
    """
    Fuses all_reduce ops with preceding (quantized) linear ops into a single fused node for improved performance.
    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        num_gemm_collective_fusions = 0

        # lookup for fused ops
        # TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
        lookup = {
            torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
            torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
            torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
        }

        # go through all nodes and find all_reduce nodes
        for node in gm.graph.nodes:
            if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
                continue

            # check if args are as expected
            assert len(node.args) == 1 and not len(node.kwargs), (
                "Unexpected args/kwargs for all_reduce"
            )

            # retrieve parent and check a few conditions on the parent node
            parent_node = node.args[0]
            if not is_op(parent_node, lookup.keys()):
                continue
            if len(parent_node.users) > 1:
                continue

            with gm.graph.inserting_before(node):
                # insert fused node
                fused_linear_collective_node = gm.graph.call_function(
                    lookup[get_op_overload_packet(parent_node.target)],
                    args=parent_node.args,
                    kwargs=parent_node.kwargs,
                )
            node.replace_all_uses_with(fused_linear_collective_node)
            gm.graph.erase_node(node)
            gm.graph.erase_node(parent_node)
            num_gemm_collective_fusions += 1

        info = TransformInfo(
            skipped=False,
            num_matches=num_gemm_collective_fusions,
            is_clean=False,
            has_valid_shapes=False,
        )

        return gm, info


@TransformRegistry.register("fuse_allreduce_residual_rmsnorm")
class FuseAllreduceResidualRMSNorm(BaseTransform):
    """Fuses the following operator sequence into a single TRT-LLM fused allreduce implementation.

    * target pattern:
        x = all_reduce(x)
        y = x + residual
        return rmsnorm(y), y
    * replacement:
        fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps)

    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        if not is_trtllm_op_available():
            return gm, TransformInfo(
                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
            )

        num_ar_r_rms_fusions = 0

        def trace_and_fuse(allreduce_node, graph):
            # Check if all_reduce is followed by addition
            users = list(allreduce_node.users.keys())
            if len(users) != 1:
                return  # Skip if all_reduce has more than one consumer
            add_node = users[0]

            # Traverse nodes for the RMSNorm pattern, which is composed of to_copy, pow, mean, add,
            # rsqrt, and mul; refer to the Huggingface LlamaRMSNorm implementation for more details.
            to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2)
            # operand of pow and mul
            pow_node = get_user_if_pattern_match(
                to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2
            )
            mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1)
            add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1)
            rsqrt_node = get_user_if_pattern_match(
                add_eps_node, [torch.ops.aten.add, operator.add], 1
            )
            mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1)
            to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1)
            mul_node_2 = get_user_if_pattern_match(
                to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1
            )
            # check args of ops: pow(2) and mean(-1)
            ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2  # exponent
            ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1]  # dimensions

            # Match found: Replace with fused operation
            if (
                to_copy_1
                and pow_node
                and mean_node
                and add_eps_node
                and rsqrt_node
                and mul_node_1
                and to_copy_2
                and mul_node_2
                and ARGS_MATCH
            ):
                # Gather the inputs for the custom operation
                tensor = allreduce_node.args[0]
                # Identify the residual argument in the add operation.
                # One of the args in add_node.args is the output of all_reduce;
                # the same idea also applies to norm_weight.
                residual = (
                    add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1]
                )
                norm_weight = (
                    mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1]
                )
                eps = add_eps_node.args[1]

                # Insert nodes
                with graph.inserting_before(allreduce_node):
                    fused_node = graph.call_function(
                        torch.ops.dist.fused_allreduce_residual_rmsnorm,
                        args=(
                            tensor,
                            residual,
                            norm_weight,
                            eps,
                        ),
                    )
                    # Extract outputs from the tuple returned by `fused_node`
                    final_output_node = gm.graph.create_node(
                        "call_function",
                        target=operator.getitem,
                        args=(fused_node, 0),
                    )
                    add_output_node = gm.graph.create_node(
                        "call_function",
                        target=operator.getitem,
                        args=(fused_node, 1),
                    )

                # Replace all uses of rmsnorm_node with final_output_node
                mul_node_2.replace_all_uses_with(final_output_node)

                # Replace all uses of add_node with add_output_node
                add_node.replace_all_uses_with(add_output_node)

                nonlocal num_ar_r_rms_fusions
                num_ar_r_rms_fusions += 1

        # Traverse all nodes
        for node in gm.graph.nodes:
            if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
                trace_and_fuse(allreduce_node=node, graph=gm.graph)

        info = TransformInfo(
            skipped=False, num_matches=num_ar_r_rms_fusions, is_clean=False, has_valid_shapes=False
        )

        return gm, info
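`FuseCollectives` follows the standard FX fusion recipe: find the collective, check its single producer, insert one fused `call_function` node, reroute users, and erase the originals. The custom `torch.ops.auto_deploy.*` ops above are not available outside TensorRT-LLM, so here is a self-contained analogue of the same recipe that fuses `mul` + `add` into `torch.addcmul` (the toy module and op choice are illustrative only):

import torch
from torch.fx import GraphModule, Node, symbolic_trace


class Toy(torch.nn.Module):
    def forward(self, a, b, c):
        return torch.add(torch.mul(a, b), c)


def fuse_mul_add(gm: GraphModule) -> GraphModule:
    """Same rewrite steps as FuseCollectives: insert fused node, reroute users, erase originals."""
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target is not torch.add:
            continue
        parent = node.args[0]
        if not isinstance(parent, Node) or parent.target is not torch.mul:
            continue
        if len(parent.users) > 1:  # only fuse if the producer has a single consumer
            continue
        with gm.graph.inserting_before(node):
            fused = gm.graph.call_function(
                torch.addcmul, args=(node.args[1], parent.args[0], parent.args[1])
            )
        node.replace_all_uses_with(fused)
        gm.graph.erase_node(node)
        gm.graph.erase_node(parent)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = fuse_mul_add(symbolic_trace(Toy()))
a, b, c = (torch.randn(4) for _ in range(3))
assert torch.allclose(gm(a, b, c), a * b + c)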
@@ -0,0 +1,65 @@
from typing import List, Literal, Optional, Tuple, Type

from pydantic import Field
from torch.fx import GraphModule

from ...compile import compile_and_capture
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ..interface import (
    BaseTransform,
    SharedConfig,
    TransformConfig,
    TransformInfo,
    TransformRegistry,
)


class CompileModelConfig(TransformConfig):
    """Configuration for the compile model transform."""

    cuda_graph_batch_sizes: Optional[List[int]] = Field(
        default=None, description="The batch sizes to use for CUDA graphs."
    )
    num_batched_inputs: int = Field(
        default=2, description="The number of batched inputs to use for CUDA graphs."
    )
    compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
        Field(description="The backend to use for compiling the model.")
    )


@TransformRegistry.register("compile_model")
class CompileModel(BaseTransform):
    """A transform to compile the model."""

    config: CompileModelConfig

    @classmethod
    def get_config_class(cls) -> Type[TransformConfig]:
        return CompileModelConfig

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        cm.info.set_generate_only_batch()
        egm_compiled = compile_and_capture(
            gm,
            self.config.compile_backend,
            args=cm.args,
            dynamic_shapes=cm.dynamic_shapes,
            compiler_kwargs={
                "cuda_graph_batch_sizes": self.config.cuda_graph_batch_sizes,
                "num_batched_inputs": self.config.num_batched_inputs,
            },
        )
        cm.info.reset()

        # store info object about the transform
        info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)

        return egm_compiled, info
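`CompileModel` simply forwards the captured arguments, dynamic shapes, and the two CUDA-graph knobs to `compile_and_capture` under the selected backend. A small sketch of how such a config validates its `Literal` backend field, using a plain pydantic stand-in for the diff's `CompileModelConfig` (not the library class itself):

from typing import List, Literal, Optional

from pydantic import BaseModel, Field, ValidationError


class CompileModelConfig(BaseModel):
    """Standalone stand-in mirroring the fields of the diff's CompileModelConfig."""

    cuda_graph_batch_sizes: Optional[List[int]] = Field(default=None)
    num_batched_inputs: int = Field(default=2)
    compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"]


cfg = CompileModelConfig(compile_backend="torch-opt", cuda_graph_batch_sizes=[1, 2, 4, 8])
print(cfg.num_batched_inputs)  # 2 (default)

try:
    CompileModelConfig(compile_backend="tensorrt")  # not one of the allowed literals
except ValidationError as e:
    print("rejected:", len(e.errors()), "error(s)")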
@@ -6,6 +6,8 @@ import torch
import torch.nn as nn
from torch.fx import GraphModule, Node

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import (
@@ -14,7 +16,7 @@ from ...utils.node_utils import (
    is_linear_op,
)
from ...utils.quantization_utils import QuantizationImpl
from .._graph import canonicalize_graph
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]):
@@ -116,30 +118,36 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
    gm.delete_all_unused_submodules()


def fuse_gemms(gm: GraphModule) -> None:
    ad_logger.info("GEMM fusion")
    ad_logger.debug("Before GEMM fusion: " + str(gm))
    # sort linear nodes by parent node
    linear_nodes = defaultdict(list)
    for node in gm.graph.nodes:
        # TODO: we don't handle bias for now...
        if is_linear_op(node, include_quantization=True) and node.args[2] is None:
            linear_nodes[node.args[0]].append(node)
@TransformRegistry.register("fuse_gemms")
class FuseGemms(BaseTransform):
    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        # sort linear nodes by parent node
        linear_nodes = defaultdict(list)
        for node in gm.graph.nodes:
            # TODO: we don't handle bias for now...
            if is_linear_op(node, include_quantization=True) and node.args[2] is None:
                linear_nodes[node.args[0]].append(node)

    # fuse linear nodes
    idx = -1
    with cuda_memory_tracker():
        for parent_node, lin_children in linear_nodes.items():
            if len(lin_children) < 2:
                continue
            # linear nodes to fuse
            ad_logger.debug(
                f"Found linear nodes to fuse: {lin_children} with parent node: {parent_node}"
            )
            _insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children)
        # fuse linear nodes
        idx = -1
        num_matches = 0
        with cuda_memory_tracker():
            for parent_node, lin_children in linear_nodes.items():
                if len(lin_children) < 2:
                    continue
                # linear nodes to fuse
                _insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children)
                num_matches += 1

    # clean up and return
    canonicalize_graph(gm)
    torch.cuda.empty_cache()

    ad_logger.debug("After GEMM fusion: " + str(gm))
        torch.cuda.empty_cache()
        info = TransformInfo(
            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
        )
        return gm, info
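The new `FuseGemms` transform keeps the old grouping logic: linear nodes are bucketed by the node feeding them (`node.args[0]`), and only buckets with two or more members are fused into one GEMM. A self-contained sketch of that grouping step on a toy q/k/v projection module — note the real pass matches exported aten linear ops via `is_linear_op` (including quantized variants), whereas this stand-in checks `call_module` Linear nodes:

from collections import defaultdict

import torch
from torch.fx import symbolic_trace


class ToyAttnProj(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(16, 16, bias=False)
        self.k = torch.nn.Linear(16, 16, bias=False)
        self.v = torch.nn.Linear(16, 16, bias=False)

    def forward(self, x):
        return self.q(x), self.k(x), self.v(x)


gm = symbolic_trace(ToyAttnProj())

# Group linear calls by the node that feeds them, mirroring `linear_nodes[node.args[0]]`.
linear_nodes = defaultdict(list)
for node in gm.graph.nodes:
    if node.op == "call_module" and isinstance(gm.get_submodule(node.target), torch.nn.Linear):
        linear_nodes[node.args[0]].append(node)

for parent, group in linear_nodes.items():
    if len(group) >= 2:  # candidates for a single fused GEMM
        print(f"would fuse {[n.name for n in group]} sharing input {parent.name}")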
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (new file, 299 lines)
@@ -0,0 +1,299 @@
"""Graph transformation to automatically add kv cache into fused MHA op."""
|
||||
|
||||
import operator
|
||||
from typing import Dict, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
|
||||
from ...custom_ops.attention_interface import AttentionRegistry
|
||||
from ...distributed.common import all_gather_object, get_world_size
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...transformations._graph import add_graph_input
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import get_all_input_output_nodes, is_op
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
TransformConfig,
|
||||
TransformInfo,
|
||||
TransformRegistry,
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("update_in_out_nodes")
|
||||
class UpdateInOutNodes(BaseTransform):
|
||||
"""Modify the graph module by adding new input nodes.
|
||||
|
||||
The new input nodes correspond to the extra arguments needed for cached and flattened attention.
|
||||
|
||||
    Args:
        gm: The graph module to analyze and modify.
        cm: Cached sequence interface containing extra argument information.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
# loop through nodes to get input, output, and get_attr nodes
|
||||
input_nodes, output_nodes = get_all_input_output_nodes(gm.graph)
|
||||
|
||||
        # we expect exactly two input nodes (input_ids, position_ids)
        assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."
|
||||
|
||||
# NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
|
||||
# Later on, we can revisit how to support returning hidden states.
|
||||
assert len(output_nodes) == 1, "Expected exactly one output node!"
|
||||
assert len(output_nodes[0].all_input_nodes) == 1, (
|
||||
"Expected to only return final tensor output!"
|
||||
)
|
||||
|
||||
# Activate and add extra argument nodes
|
||||
new_args = cm.info.switch_to_cached_attn_inputs()
|
||||
for name in new_args:
|
||||
input_nodes.append(add_graph_input(gm, name))
|
||||
|
||||
info = TransformInfo(skipped=False, num_matches=1, is_clean=False, has_valid_shapes=False)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
class InsertCachedAttentionConfig(TransformConfig):
|
||||
"""Configuration for the insert cached attention transform."""
|
||||
|
||||
attn_backend: Optional[str] = Field(default=None, description="The attention backend to use.")
|
||||
|
||||
|
||||
@TransformRegistry.register("insert_cached_attention")
|
||||
class InsertCachedAttention(BaseTransform):
|
||||
"""
|
||||
A transform to insert cached attention into the graph module.
|
||||
|
||||
If attn_backend is not provided in transform config, will find from shared config.
|
||||
"""
|
||||
|
||||
config: InsertCachedAttentionConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return InsertCachedAttentionConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
"""Replace uncached source attention node with corresponding cached attn node."""
|
||||
attn_descriptor = AttentionRegistry.get(self.config.attn_backend)
|
||||
|
||||
cache_config = factory.get_cache_config()
|
||||
|
||||
# Get all attention nodes and their info objects
|
||||
source_op = attn_descriptor.get_source_attention_op()
|
||||
|
||||
# pick up graph
|
||||
graph: Graph = gm.graph
|
||||
|
||||
# look for relevant source attention nodes
|
||||
source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)]
|
||||
|
||||
if not source_attn_nodes:
|
||||
# If there are no nodes for kv cache insertion found, return current graph
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
# Sanity check
|
||||
if cm.info.is_paged:
|
||||
assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."
|
||||
|
||||
# retrieve input nodes
|
||||
input_nodes, _ = get_all_input_output_nodes(gm.graph)
|
||||
|
||||
# insert metadata computation and extract each argument as a node
|
||||
get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
|
||||
with graph.inserting_before(input_nodes[-1].next):
|
||||
ret_node = graph.call_function(
|
||||
get_metadata,
|
||||
args=(
|
||||
*input_nodes,
|
||||
cm.info.page_size,
|
||||
),
|
||||
)
|
||||
metadata_nodes = [
|
||||
graph.call_function(operator.getitem, args=(ret_node, idx))
|
||||
for idx in range(num_metadata)
|
||||
]
|
||||
|
||||
buffer_in_lookup: Dict[str, Node] = {}
|
||||
|
||||
# replace fused attention node with attention node that has kv cache
|
||||
num_cached_attn_replacements = 0
|
||||
for idx, attn_node in enumerate(source_attn_nodes):
|
||||
# pick out GEMMs
|
||||
qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()]
|
||||
|
||||
# setup + store cache initializers and caches as input nodes
|
||||
cache_in_nodes = []
|
||||
for k, get_cache in attn_descriptor.get_cache_initializers(
|
||||
attn_node, cache_config
|
||||
).items():
|
||||
k_indexed = f"{k}_{idx}"
|
||||
cm.add_cache(k_indexed, get_cache)
|
||||
cache_in_nodes.append(add_graph_input(gm, k_indexed))
|
||||
|
||||
# setup + store global buffer initializers and buffers as input nodes
|
||||
# NOTE: we have to check against existing keys to make sure nothing is registered twice...
|
||||
buffer_in_nodes = []
|
||||
for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
|
||||
if k not in buffer_in_lookup:
|
||||
cm.add_cache(k, get_buffer)
|
||||
buffer_in_lookup[k] = add_graph_input(gm, k)
|
||||
buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op
|
||||
|
||||
# retrieve constants for attention_op
|
||||
constants = attn_descriptor.get_constants(attn_node)
|
||||
|
||||
# insert cached attention replacement op
|
||||
with graph.inserting_before(attn_node):
|
||||
cached_attn_node = graph.call_function(
|
||||
attn_descriptor.get_cached_attention_op(),
|
||||
args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants),
|
||||
)
|
||||
attn_node.replace_all_uses_with(cached_attn_node)
|
||||
graph.erase_node(attn_node)
|
||||
num_cached_attn_replacements += 1
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_cached_attn_replacements,
|
||||
is_clean=False,
|
||||
has_valid_shapes=False,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
@TransformRegistry.register("insert_cached_mla_attention")
|
||||
class InsertCachedMLAAttention(InsertCachedAttention):
|
||||
"""
|
||||
A transform to insert cached MLA attention into the graph module.
|
||||
|
||||
This class is identical to InsertCachedAttention and inherits all its behavior.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResizeKVCacheConfig(TransformConfig):
|
||||
"""Configuration for the resize kv cache transform."""
|
||||
|
||||
free_mem_ratio: float = Field(
|
||||
description="The fraction of available memory to occupy.", default=0.8
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("resize_kv_cache")
|
||||
class ResizeKVCache(BaseTransform):
|
||||
"""Inflate the kv cache to occupy the available GPU memory.
|
||||
|
||||
free_mem_ratio specifies the fraction of available memory to occupy.
|
||||
"""
|
||||
|
||||
config: ResizeKVCacheConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return ResizeKVCacheConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
free_mem_ratio = self.config.free_mem_ratio
|
||||
|
||||
def _get_mem_info_in_mb():
|
||||
free_mem, total_mem = torch.cuda.mem_get_info()
|
||||
return free_mem // 1024**2, total_mem // 1024**2
|
||||
|
||||
free_mem, total_mem = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
|
||||
current_cache_size = cm.current_cache_size_bytes()
|
||||
current_num_pages = cm.info.num_pages
|
||||
ad_logger.info(
|
||||
f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
|
||||
)
|
||||
|
||||
if free_mem_ratio == 0.0:
|
||||
ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}")
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
try:
|
||||
# Let's run a forward pass to get the memory usage
|
||||
cm.info._set_max_num_tokens_sample()
|
||||
free_mem_pre, _ = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
|
||||
|
||||
gm(*cm.args)
|
||||
|
||||
free_mem_post, _ = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")
|
||||
|
||||
memory_for_forward_pass = free_mem_pre - free_mem_post
|
||||
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
|
||||
|
||||
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
|
||||
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
|
||||
|
||||
# Need to sync all the GPUs
|
||||
gathered_num_pages = [None] * get_world_size()
|
||||
all_gather_object(gathered_num_pages, new_num_pages)
|
||||
new_num_pages = min(gathered_num_pages)
|
||||
ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}")
|
||||
|
||||
cm.resize_cache(new_num_pages)
|
||||
except Exception as e:
|
||||
ad_logger.warning(
|
||||
f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
|
||||
)
|
||||
|
||||
# Free memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=0,
|
||||
is_clean=True,
|
||||
has_valid_shapes=True,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
@TransformRegistry.register("initialize_cache")
|
||||
class InitializeCache(BaseTransform):
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
cm.initialize_caches()
|
||||
|
||||
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
|
||||
|
||||
return gm, info
|
||||
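`ResizeKVCache` above sizes the cache from the free memory measured after a max-token trial forward pass: it adds `free_mem_ratio` of the remaining free bytes on top of the current cache, converts back to pages using the current bytes-per-page, and takes the minimum across ranks. A worked example of that arithmetic with made-up numbers:

# Worked example of the resize arithmetic in ResizeKVCache (all numbers are made up).
free_mem_ratio = 0.8
current_num_pages = 1_000
current_cache_size = 2 * 1024**3  # 2 GiB of KV cache currently allocated (bytes)
free_mem_post_mb = 30_000         # free GPU memory after the trial forward pass (MB)

bytes_per_page = current_cache_size // current_num_pages
new_cache_size = free_mem_post_mb * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // bytes_per_page)
print(bytes_per_page, new_num_pages)  # 2147483 bytes/page -> 12718 pages

# Ranks then agree on the smallest proposal, exactly like the all_gather_object + min() step:
gathered_num_pages = [new_num_pages, 11_500]  # e.g. another rank had less free memory
print(min(gathered_num_pages))                # 11500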
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (new file, 148 lines)
@@ -0,0 +1,148 @@
"""Graph transform to optimize RMSNorm execution using FlashInfer."""
|
||||
|
||||
from functools import partial
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
|
||||
# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
|
||||
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
TransformConfig,
|
||||
TransformInfo,
|
||||
TransformRegistry,
|
||||
)
|
||||
|
||||
_BACKEND_OPS = {
|
||||
"flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
|
||||
"triton": torch.ops.auto_deploy.triton_rms_norm,
|
||||
"torch": torch.ops.auto_deploy.torch_rmsnorm,
|
||||
}
|
||||
|
||||
|
||||
def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
"""Implements the RMSNorm pattern for pattern matching.
|
||||
|
||||
Args:
|
||||
data: Input tensor to normalize.
|
||||
weight: Scaling weights for the normalized output.
|
||||
eps: Small constant for numerical stability.
|
||||
|
||||
Returns:
|
||||
Normalized and scaled tensor.
|
||||
"""
|
||||
input_dtype = data.dtype
|
||||
data = data.to(torch.float32)
|
||||
variance = data.pow(2).mean(-1, keepdim=True)
|
||||
data = data * torch.rsqrt(variance + eps)
|
||||
return weight * data.to(input_dtype)
|
||||
|
||||
|
||||
def _rms_norm_replacement(
|
||||
data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
|
||||
) -> torch.Tensor:
|
||||
"""Backend-specific rms_norm implementation.
|
||||
|
||||
Args:
|
||||
data: Input tensor to normalize.
|
||||
weight: Scaling weights for the normalized output.
|
||||
eps: Small constant for numerical stability.
|
||||
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
|
||||
|
||||
Returns:
|
||||
Normalized and scaled tensor using the specified backend implementation.
|
||||
"""
|
||||
|
||||
assert backend.lower() in _BACKEND_OPS, (
|
||||
f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
|
||||
)
|
||||
return _BACKEND_OPS[backend.lower()](data, weight, eps)
|
||||
|
||||
|
||||
class FuseRMSNormConfig(TransformConfig):
|
||||
"""Configuration for the RMSNorm fusion transform."""
|
||||
|
||||
backend: str = Field(
|
||||
default="flashinfer",
|
||||
description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_rmsnorm")
|
||||
class FuseRMSNorm(BaseTransform):
|
||||
"""Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
|
||||
|
||||
This function sets up pattern matching to identify RMSNorm operations in the graph
|
||||
and replaces them with optimized implementations. It uses dummy tensors to register
|
||||
the pattern matching rules.
|
||||
|
||||
Args:
|
||||
gm: Input graph module to transform.
|
||||
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
|
||||
|
||||
Returns:
|
||||
Transformed graph module with optimized RMSNorm operations.
|
||||
"""
|
||||
|
||||
config: FuseRMSNormConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return FuseRMSNormConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
if self.config.backend.lower() not in _BACKEND_OPS:
|
||||
raise ValueError(
|
||||
f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}"
|
||||
)
|
||||
|
||||
graph = gm.graph
|
||||
patterns = ADPatternMatcherPass()
|
||||
|
||||
# Create dummy tensors for pattern matching
|
||||
bs = 2
|
||||
hidden_size = 512
|
||||
|
||||
def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
|
||||
return [
|
||||
torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
|
||||
torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
|
||||
eps,
|
||||
]
|
||||
|
||||
# Define configurations for different data types
|
||||
configs = [
|
||||
(torch.bfloat16, torch.bfloat16),
|
||||
(torch.float16, torch.float16),
|
||||
(torch.float32, torch.float32),
|
||||
]
|
||||
|
||||
# Register patterns for each configuration
|
||||
for input_dtype, weight_dtype in configs:
|
||||
register_ad_pattern(
|
||||
search_fn=_rms_norm_pattern,
|
||||
replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args(input_dtype, weight_dtype),
|
||||
op_ignore_types={},
|
||||
scalar_workaround={"eps": 1e-6},
|
||||
)
|
||||
|
||||
cnt = patterns.apply(graph)
|
||||
|
||||
info = TransformInfo(skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=False)
|
||||
|
||||
return gm, info
|
||||
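`fuse_rmsnorm` registers `_rms_norm_pattern` (the eager HF-style computation) as the search pattern and swaps in a backend kernel. The standalone check below reproduces the pattern function from this file and confirms it matches the reference formula `weight * x / sqrt(mean(x^2) + eps)` on CPU tensors; it does not exercise the flashinfer/triton ops:

import torch


def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same computation the transform searches for (HF LlamaRMSNorm-style).
    input_dtype = data.dtype
    data = data.to(torch.float32)
    variance = data.pow(2).mean(-1, keepdim=True)
    data = data * torch.rsqrt(variance + eps)
    return weight * data.to(input_dtype)


torch.manual_seed(0)
x = torch.randn(2, 512)
w = torch.randn(512)
eps = 1e-6

reference = w * x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
assert torch.allclose(_rms_norm_pattern(x, w, eps), reference, atol=1e-5)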
@@ -1,11 +1,5 @@
"""A library of transformation passes."""

from .collectives import *
from .fused_moe import *
from .fusion import *
from .kvcache import *
from .rms_norm import *

try:
    from .visualization import visualize_namespace
except ImportError:
@ -1,167 +0,0 @@
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...distributed.trtllm import is_trtllm_op_available
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
# TODO: This is an overly simplified model that works well for vanilla Llama models.
|
||||
# However, we eventually want to consider more sophisticated patterns such as
|
||||
# * all_reduce(lin1(x) + lin2(x))
|
||||
# * version above with fused GEMMs (i.e. with a split node)
|
||||
# * all_reduce(pointwise_op(linear(x)))
|
||||
# * ...
|
||||
def fuse_collectives(gm: GraphModule) -> None:
|
||||
num_gemm_collective_fusions = 0
|
||||
ad_logger.debug("Before GEMM+Collective fusion: " + str(gm))
|
||||
|
||||
# lookup for fused ops
|
||||
# TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
|
||||
lookup = {
|
||||
torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
|
||||
torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
|
||||
torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
|
||||
}
|
||||
|
||||
# go through all nodes and find all_reduce nodes
|
||||
for node in gm.graph.nodes:
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
|
||||
continue
|
||||
|
||||
# check if args are as expected
|
||||
assert len(node.args) == 1 and not len(node.kwargs), "Unexpected args/kwargs for all_reduce"
|
||||
|
||||
# retrieve parent and check a few conditions on the parent node
|
||||
parent_node = node.args[0]
|
||||
if not is_op(parent_node, lookup.keys()):
|
||||
continue
|
||||
if len(parent_node.users) > 1:
|
||||
continue
|
||||
|
||||
with gm.graph.inserting_before(node):
|
||||
# insert fused node
|
||||
fused_linear_collective_node = gm.graph.call_function(
|
||||
lookup[get_op_overload_packet(parent_node.target)],
|
||||
args=parent_node.args,
|
||||
kwargs=parent_node.kwargs,
|
||||
)
|
||||
node.replace_all_uses_with(fused_linear_collective_node)
|
||||
gm.graph.erase_node(node)
|
||||
gm.graph.erase_node(parent_node)
|
||||
num_gemm_collective_fusions += 1
|
||||
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_gemm_collective_fusions} GEMM+Collective fusions")
|
||||
ad_logger.debug("After GEMM+Collective fusion: " + str(gm))
|
||||
|
||||
|
||||
def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> None:
|
||||
"""Essentially, this function fuses the following operators into one allreduce trtllm implementation.
|
||||
|
||||
* target pattern:
|
||||
x = all_reduce(x)
|
||||
y = x + residual
|
||||
return rmsnorm(y), y
|
||||
* replacement:
|
||||
fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps)
|
||||
|
||||
"""
|
||||
if not is_trtllm_op_available():
|
||||
return
|
||||
|
||||
num_ar_r_rms_fusions = 0
|
||||
ad_logger.debug("Before allreduce+residual+rmsnorm fusion: " + str(gm))
|
||||
|
||||
def trace_and_fuse(allreduce_node, graph):
|
||||
# Check if all_reduce is followed by addition
|
||||
users = list(allreduce_node.users.keys())
|
||||
if len(users) != 1:
|
||||
return # Skip if all_reduce has more than one consumer
|
||||
add_node = users[0]
|
||||
|
||||
# Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer
|
||||
# the Huggingface LlamaRMSNorm implementation as example for more details
|
||||
to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2)
|
||||
# operand of pow and mul
|
||||
pow_node = get_user_if_pattern_match(
|
||||
to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2
|
||||
)
|
||||
mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1)
|
||||
add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1)
|
||||
rsqrt_node = get_user_if_pattern_match(add_eps_node, [torch.ops.aten.add, operator.add], 1)
|
||||
mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1)
|
||||
to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1)
|
||||
mul_node_2 = get_user_if_pattern_match(
|
||||
to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1
|
||||
)
|
||||
# check args of ops: pow(2) and mean(-1)
|
||||
ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent
|
||||
ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions
|
||||
|
||||
# Match found: Replace with fused operation
|
||||
if (
|
||||
to_copy_1
|
||||
and pow_node
|
||||
and mean_node
|
||||
and add_eps_node
|
||||
and rsqrt_node
|
||||
and mul_node_1
|
||||
and to_copy_2
|
||||
and mul_node_2
|
||||
and ARGS_MATCH
|
||||
):
|
||||
# Gather the inputs for the custom operation
|
||||
tensor = allreduce_node.args[0]
|
||||
# Identify the residual argument in the add operation
|
||||
# One of the args in add_node.args is the output of all_reduce
|
||||
# The same idea also applies to norm_weight
|
||||
residual = add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1]
|
||||
norm_weight = (
|
||||
mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1]
|
||||
)
|
||||
eps = add_eps_node.args[1]
|
||||
|
||||
# Insert nodes
|
||||
with graph.inserting_before(allreduce_node):
|
||||
fused_node = graph.call_function(
|
||||
torch.ops.dist.fused_allreduce_residual_rmsnorm,
|
||||
args=(
|
||||
tensor,
|
||||
residual,
|
||||
norm_weight,
|
||||
eps,
|
||||
),
|
||||
)
|
||||
# Extract outputs from the tuple returned by `fused_node`
|
||||
final_output_node = gm.graph.create_node(
|
||||
"call_function",
|
||||
target=operator.getitem,
|
||||
args=(fused_node, 0),
|
||||
)
|
||||
add_output_node = gm.graph.create_node(
|
||||
"call_function",
|
||||
target=operator.getitem,
|
||||
args=(fused_node, 1),
|
||||
)
|
||||
|
||||
# Replace all uses of rmsnorm_node with final_output_node
|
||||
mul_node_2.replace_all_uses_with(final_output_node)
|
||||
|
||||
# Replace all uses of add_node with add_output_node
|
||||
add_node.replace_all_uses_with(add_output_node)
|
||||
|
||||
nonlocal num_ar_r_rms_fusions
|
||||
num_ar_r_rms_fusions += 1
|
||||
|
||||
# Traverse all nodes
|
||||
for node in gm.graph.nodes:
|
||||
if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
|
||||
trace_and_fuse(allreduce_node=node, graph=gm.graph)
|
||||
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_ar_r_rms_fusions} allreduce+residual+rmsnorm fusions")
|
||||
ad_logger.debug("After allreduce+residual+rmsnorm fusion: " + str(gm))
|
||||
@ -1,511 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...utils.cuda_mem_tracker import cuda_memory_tracker
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op
|
||||
from ...utils.quantization_utils import get_scales_and_type_from_node
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
def match_moe_pattern(gm: GraphModule) -> None:
|
||||
graph = gm.graph
|
||||
|
||||
ad_logger.debug("Before MoE Pattern Matching: " + str(gm))
|
||||
# Preprocessing: Identify boundary nodes (e.g. residual connections) in the graph.
|
||||
boundary_nodes = identify_regions_between_residuals(gm)
|
||||
|
||||
num_moe_patterns = 0
|
||||
|
||||
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
|
||||
# Step 1: Identify Expert Compute pattern
|
||||
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) = (
|
||||
_match_expert_compute_pattern(start_boundary, end_boundary)
|
||||
)
|
||||
if not expert_weights:
|
||||
continue
|
||||
# TODO: naming convention to verify the order of the weight nodes
|
||||
|
||||
# Step 2: Trace upwards to locate normalize_routing_weight and selected_experts:
|
||||
arg1_list, arg2_list = _extract_index_branches_from_expert_outputs(pattern_output_nodes)
|
||||
normalized_routing_weights = _find_lowest_common_ancessor(arg1_list)
|
||||
if not normalized_routing_weights:
|
||||
continue
|
||||
|
||||
common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
|
||||
if not common_ancessor2:
|
||||
continue
|
||||
selected_experts = bfs(
|
||||
common_ancessor2,
|
||||
lambda node: is_op(node, torch.ops.aten.one_hot),
|
||||
attr_next="all_input_nodes",
|
||||
boundary=start_boundary,
|
||||
).args[0]
|
||||
if not selected_experts:
|
||||
continue
|
||||
|
||||
# Step 3: Trace upwards to find input node:
|
||||
hidden_states = _find_lowest_common_ancessor(pattern_input_nodes)
|
||||
if not hidden_states:
|
||||
continue
|
||||
|
||||
# Step 4: Find output node with the combine pattern
|
||||
final_hidden_state_node = _find_final_hidden_state_node(pattern_output_nodes, end_boundary)
|
||||
if final_hidden_state_node is None:
|
||||
continue
|
||||
|
||||
# Step 5: Insert the MoE op into the graph.
|
||||
ad_logger.debug(
|
||||
f"Found MoE Pattern: between boundary {start_boundary} and {end_boundary}.\n"
|
||||
f"Input hidden states node: {hidden_states}, "
|
||||
f"selected_experts node: {selected_experts}, "
|
||||
f"routing_weights node: {normalized_routing_weights}, "
|
||||
f"expert weights: {expert_weights}, weight type: {weight_type}"
|
||||
)
|
||||
with graph.inserting_before(final_hidden_state_node):
|
||||
w1_list = expert_weights["w1"]
|
||||
w2_list = expert_weights["w2"]
|
||||
w3_list = expert_weights["w3"]
|
||||
|
||||
if weight_type == "fp8":
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_quant_fp8_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
expert_scales["w1_input_scale"],
|
||||
expert_scales["w2_input_scale"],
|
||||
expert_scales["w3_input_scale"],
|
||||
expert_scales["w1_weight_scale"],
|
||||
expert_scales["w2_weight_scale"],
|
||||
expert_scales["w3_weight_scale"],
|
||||
),
|
||||
)
|
||||
elif weight_type == "fp4":
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_quant_fp4_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
expert_scales["w1_input_scale"],
|
||||
expert_scales["w2_input_scale"],
|
||||
expert_scales["w3_input_scale"],
|
||||
expert_scales["w1_weight_scale"],
|
||||
expert_scales["w2_weight_scale"],
|
||||
expert_scales["w3_weight_scale"],
|
||||
expert_scales["w1_alpha"],
|
||||
expert_scales["w2_alpha"],
|
||||
expert_scales["w3_alpha"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
),
|
||||
)
|
||||
|
||||
final_hidden_state_node.replace_all_uses_with(fused_moe_node)
|
||||
graph.erase_node(final_hidden_state_node)
|
||||
|
||||
while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
num_moe_patterns += 1
|
||||
|
||||
canonicalize_graph(gm)
|
||||
|
||||
ad_logger.info(f"Found {num_moe_patterns} MoE Patterns")
|
||||
ad_logger.debug("After MoE Pattern Matching: " + str(gm))
|
||||
|
||||
|
||||
def fuse_moe(gm: torch.fx.GraphModule) -> None:
|
||||
"""
|
||||
Scan the FX graph and replace all calls to torch.ops.auto_deploy.torch_moe with
|
||||
torch.ops.auto_deploy.trtllm_moe_fused.
|
||||
"""
|
||||
ad_logger.debug("Before MoE fusion: " + str(gm))
|
||||
|
||||
with cuda_memory_tracker():
|
||||
fused_key_counter = _insert_fused_moe_ops(gm)
|
||||
if fused_key_counter:
|
||||
canonicalize_graph(gm)
|
||||
|
||||
ad_logger.info(f"Found {fused_key_counter} MoE fusions")
|
||||
ad_logger.debug("After MoE fusion: " + str(gm))
|
||||
|
||||
|
||||
def _insert_fused_moe_ops(gm: GraphModule) -> int:
|
||||
fused_key_counter = 0
|
||||
graph = gm.graph
|
||||
|
||||
for node in list(graph.nodes):
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_moe):
|
||||
continue
|
||||
|
||||
ad_logger.debug(f"Found MoE op to fuse: {node} with args: {node.args}")
|
||||
hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = node.args
|
||||
|
||||
fused_w3_w1_experts = torch.stack(
|
||||
[
|
||||
torch.cat(
|
||||
[gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], dim=-2
|
||||
)
|
||||
for w1_node, w3_node in zip(w1_list, w3_list)
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
fused_w2_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0)
|
||||
|
||||
new_key_w3_w1 = f"fused_moe_w3_w1_stacked_{fused_key_counter}"
|
||||
new_key_w2 = f"fused_moe_w2_stacked_{fused_key_counter}"
|
||||
fused_key_counter += 1
|
||||
param_w3_w1 = torch.nn.Parameter(fused_w3_w1_experts)
|
||||
param_w2 = torch.nn.Parameter(fused_w2_experts)
|
||||
gm.register_parameter(new_key_w3_w1, param_w3_w1)
|
||||
gm.register_parameter(new_key_w2, param_w2)
|
||||
|
||||
with graph.inserting_before(node):
|
||||
new_node = graph.call_function(
|
||||
# TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models
|
||||
torch.ops.auto_deploy.trtllm_moe_fused,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
routing_weights,
|
||||
graph.get_attr(new_key_w3_w1),
|
||||
graph.get_attr(new_key_w2),
|
||||
),
|
||||
)
|
||||
|
||||
node.replace_all_uses_with(new_node)
|
||||
graph.erase_node(node)
|
||||
|
||||
return fused_key_counter
|
||||
|
||||
|
||||
def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
|
||||
"""
|
||||
Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following
|
||||
each node's primary branch (recursively following the first Node argument).
|
||||
|
||||
It first finds the LCA of the first two nodes and then
|
||||
iteratively computes the LCA of the result with the next node, and so on.
|
||||
|
||||
Returns:
|
||||
The common ancestor Node if found, otherwise None.
|
||||
"""
|
||||
if not nodes:
|
||||
return None
|
||||
|
||||
def get_parent(node: Node) -> Optional[Node]:
|
||||
"""Return the first Node-valued argument for a given node, or None if not found."""
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node):
|
||||
return arg
|
||||
return None
|
||||
|
||||
def get_depth(node: Node) -> int:
|
||||
"""
|
||||
Recursively compute the depth of the node by following its primary branch.
|
||||
Depth is defined as the number of steps to reach a node with no parent.
|
||||
"""
|
||||
parent = get_parent(node)
|
||||
if parent is None:
|
||||
return 0
|
||||
return 1 + get_depth(parent)
|
||||
|
||||
def lca_two(a: Node, b: Node) -> Optional[Node]:
|
||||
"""
|
||||
Find the lowest common ancestor of two nodes by first equalizing their depth
|
||||
and then moving upward until a common node is found.
|
||||
"""
|
||||
depth_a = get_depth(a)
|
||||
depth_b = get_depth(b)
|
||||
|
||||
# Equalize depths
|
||||
while depth_a > depth_b:
|
||||
a = get_parent(a)
|
||||
depth_a -= 1
|
||||
while depth_b > depth_a:
|
||||
b = get_parent(b)
|
||||
depth_b -= 1
|
||||
|
||||
# Walk upward in lockstep
|
||||
while a is not None and b is not None:
|
||||
if a is b:
|
||||
return a
|
||||
a = get_parent(a)
|
||||
b = get_parent(b)
|
||||
return None
|
||||
|
||||
# Iteratively compute the LCA across all nodes.
|
||||
common = nodes[0]
|
||||
for node in nodes[1:]:
|
||||
common = lca_two(common, node)
|
||||
if common is None:
|
||||
return None
|
||||
|
||||
return common
|
||||
|
||||
|
||||
def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]:
|
||||
"""
|
||||
Given a linear op node, extract the input tensor node, weight tensor,
|
||||
any quantization scales (if the op is quantized), and return a weight type.
|
||||
|
||||
For a torch.ops.auto_deploy.torch_linear_simple.default op:
|
||||
- Returns (input_node, weight, None, "simple")
|
||||
|
||||
For a torch.ops.auto_deploy.torch_quant_fp8_linear op:
|
||||
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale}, "fp8")
|
||||
For a torch.ops.auto_deploy.torch_quant_fp4_linear op:
|
||||
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale, "alpha": alpha}, "fp4")
|
||||
"""
|
||||
input_node = linear_node.args[0]
|
||||
if is_op(linear_node, torch.ops.auto_deploy.torch_linear_simple):
|
||||
weight = linear_node.args[1]
|
||||
return input_node, weight, None, ""
|
||||
elif {
|
||||
is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear),
|
||||
is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear),
|
||||
}:
|
||||
weight = linear_node.args[1]
|
||||
scales, quant_type = get_scales_and_type_from_node(linear_node)
|
||||
return input_node, weight, scales, quant_type
|
||||
|
||||
|
||||
def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node):
|
||||
"""
|
||||
Match the expert compute pattern between the given boundaries.
|
||||
|
||||
The expert compute pattern corresponds to:
|
||||
|
||||
(F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
|
||||
|
||||
For each expert, the function extracts the input node from the w1 branch and
|
||||
collects the weight parameters from three linear ops (w1, w3, and w2 branches).
|
||||
|
||||
This function supports both:
|
||||
- torch.ops.auto_deploy.torch_linear_simple.default ops, and
|
||||
- torch.ops.auto_deploy.torch_quant_fp8_linear ops (also extracts quantization scales).
|
||||
- torch.ops.auto_deploy.torch_quant_fp4_linear ops (also extracts quantization scales).
|
||||
|
||||
Returns:
|
||||
A tuple:
|
||||
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type)
|
||||
|
||||
- pattern_input_nodes: List of input nodes (x) used for the expert compute.
|
||||
- pattern_output_nodes: List of final expert output nodes (the linear op with weight w2).
|
||||
- expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight tensors.
|
||||
- expert_scales: Dict with keys "w1_input_scale", "w1_weight_scale", etc., containing scale tensors
|
||||
(empty if weight_type is "simple").
|
||||
- weight_type: "fp8" if FP8 ops were used, "simple" otherwise.
|
||||
"""
|
||||
pattern_input_nodes, pattern_output_nodes = [], []
|
||||
expert_weights = defaultdict(list)
|
||||
expert_scales = defaultdict(list)
|
||||
weight_type = "simple" # default
|
||||
|
||||
nodes = list(start_boundary.graph.nodes)
|
||||
region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)]
|
||||
|
||||
for node in region_nodes:
|
||||
# Accept both simple and quantized linear ops.
|
||||
if not is_linear_op(node, include_quantization=True):
|
            continue

        final_linear = node
        if not final_linear.args or not isinstance(final_linear.args[0], Node):
            continue

        mul_node = final_linear.args[0]
        if not is_op(mul_node, torch.ops.aten.mul) or len(mul_node.args) < 2:
            continue

        arg_a, arg_b = mul_node.args[:2]
        silu_node = (
            arg_a
            if is_op(arg_a, torch.ops.aten.silu)
            else arg_b
            if is_op(arg_b, torch.ops.aten.silu)
            else None
        )
        if silu_node is None:
            continue

        if not (silu_node.args and is_linear_op(silu_node.args[0], include_quantization=True)):
            continue
        linear_w1_node = silu_node.args[0]

        # The other branch should be a linear op (w3 branch).
        linear_w3_node = arg_b if arg_a is silu_node else arg_a
        if not is_linear_op(linear_w3_node, include_quantization=True):
            continue
        if not (linear_w1_node.args and linear_w3_node.args):
            continue

        # Extract parameters from each linear op.
        input_node_w1, weight_w1, quant_params_w1, wt_type_w1 = _extract_linear_parameters(
            linear_w1_node
        )
        _, weight_w3, quant_params_w3, wt_type_w3 = _extract_linear_parameters(linear_w3_node)
        _, weight_w2, quant_params_w2, wt_type_w2 = _extract_linear_parameters(final_linear)

        if None in (weight_w1, weight_w3, weight_w2):
            continue

        # Ensure the weight type is consistent across branches.
        if wt_type_w1 != wt_type_w3 or wt_type_w1 != wt_type_w2:
            continue
        weight_type = wt_type_w1

        pattern_input_nodes.append(input_node_w1)
        pattern_output_nodes.append(final_linear)
        expert_weights["w1"].append(weight_w1)
        expert_weights["w3"].append(weight_w3)
        expert_weights["w2"].append(weight_w2)

        # TODO: sanity check that all experts have same weight type
        if weight_type == "fp8":
            expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
            expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
            expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
            expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
            expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
            expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
        elif weight_type == "fp4":
            expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
            expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
            expert_scales["w1_alpha"].append(quant_params_w1["alpha"])
            expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
            expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
            expert_scales["w3_alpha"].append(quant_params_w3["alpha"])
            expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
            expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
            expert_scales["w2_alpha"].append(quant_params_w2["alpha"])

    return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type


def _find_final_hidden_state_node(
    pattern_output_nodes: list[Node], end_boundary: Node
) -> Optional[Node]:
    """
    Identify the final hidden state node corresponding to the combine pattern:

        (expert_output * routing_weight) → index_add_

    For each expert output node (from the expert compute pattern), this function:
    1. Retrieves a multiplication node from its users.
    2. Extracts the second argument from the multiplication node (assumed to be the index node).
    3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary).

    After collecting all such index_add_ nodes, the final hidden state node is determined
    as the one that is not used by any of the other index_add_ nodes.

    If any required attribute (users or args) is missing during the process or if no valid
    final node is found, the function returns None.
    """

    if not pattern_output_nodes:
        return None

    index_add_nodes = []
    for node in pattern_output_nodes:
        if not node.users:
            return None
        mul_node = next(iter(node.users))
        if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
            return None
        index_node = mul_node.args[1]
        index_add_node = bfs(
            index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
        )
        if not index_add_node:
            return None
        index_add_nodes.append(index_add_node)

    # The final node is defined as the index_add_node that is not used by any other index_add_nodes
    return next(
        (
            candidate
            for candidate in index_add_nodes
            if not any(
                candidate in other.args for other in index_add_nodes if candidate is not other
            )
        ),
        None,
    )


def _extract_index_branches_from_expert_outputs(
    pattern_output_nodes: list[Node],
) -> tuple[list[Node], list[Node]]:
    """
    Extract routing and experts branches from expert outputs.

    For each expert output, find its multiplication user. From the multiplication node's
    second argument (an index node), extract:
    - The first argument as the routing branch.
    - The second argument (flattened if a list/tuple) as the experts branch.

    Returns:
        A tuple (routing_branches, experts_branches).
    """
    routing_branches, experts_branches = [], []
    for out in pattern_output_nodes:
        mul = next((u for u in out.users if is_op(u, torch.ops.aten.mul)), None)
        if not mul or len(mul.args) < 2:
            continue
        idx_node = mul.args[1]
        if not is_op(idx_node, torch.ops.aten.index):
            continue
        routing_branches.append(idx_node.args[0])
        experts = idx_node.args[1]
        if isinstance(experts, (list, tuple)):
            experts_branches.extend(experts)
        else:
            experts_branches.append(experts)
    return routing_branches, experts_branches


def _remove_dead_inplace_nodes_in_region(
    graph: torch.fx.Graph,
    start_boundary: torch.fx.Node,
    end_boundary: torch.fx.Node,
) -> bool:
    """
    Searches (via BFS) for a dead in-place node (index_add_) in the region
    between start_boundary and end_boundary. If one is found, it is removed from the graph.
    Returns True if a node was removed, False otherwise.
    """

    def target(n: torch.fx.Node) -> bool:
        return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0

    try:
        node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
        ad_logger.debug(f"Removing In-place Dead Node: {node_to_remove}")
        graph.erase_node(node_to_remove)
        return True
    except RuntimeError:
        return False
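For orientation, the expert compute pattern that the helpers above walk is the standard gated-MLP block w2(silu(w1(x)) * w3(x)): the matcher starts from the final linear (w2), expects a mul whose operands are a silu(w1(x)) branch and a w3(x) branch, and then collects the three weights per expert. The snippet below is a minimal, self-contained sketch of that structure (module and dimension names are illustrative, not taken from the library); tracing it would produce the linear → silu → mul → linear chain that the code above inspects node by node.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedMLPExpert(nn.Module):
    """Illustrative expert block matching the pattern: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, hidden_size: int = 16, intermediate_size: int = 32):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The final linear (w2) consumes mul(silu(w1(x)), w3(x)), i.e. the chain checked above:
        # mul -> silu -> linear (w1 branch) and mul -> linear (w3 branch).
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


if __name__ == "__main__":
    expert = GatedMLPExpert()
    tokens = torch.randn(4, 16)
    print(expert(tokens).shape)  # torch.Size([4, 16])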
@@ -1,193 +0,0 @@
"""Graph transformation to automatically add kv cache into fused MHA op."""

import operator
from typing import Dict, Type

import torch
from torch.fx import Graph, GraphModule, Node

from ...custom_ops.attention_interface import AttentionDescriptor, CacheConfig
from ...distributed.common import all_gather_object, get_world_size
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.node_utils import get_all_input_output_nodes, is_op
from .._graph import add_graph_input, canonicalize_graph


def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None:
    """Modify the graph module by adding new input nodes and canonicalizing the graph.

    The new input nodes correspond to the extra arguments needed for cached and flattened attention.

    Args:
        egm: The graph module to analyze and modify.
        cm: Cached sequence interface containing extra argument information.
    """
    # loop through nodes to get input, output, and get_attr nodes
    input_nodes, output_nodes = get_all_input_output_nodes(egm.graph)

    # we expect exactly two input nodes: input_ids and position_ids
    assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."

    # NOTE: for now, we want to make sure we *only* return the final output and no hidden states.
    # Later on, we can revisit how to support returning hidden states.
    assert len(output_nodes) == 1, "Expected exactly one output node!"
    assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!"

    ad_logger.info(f"Found {len(input_nodes)} input nodes and {len(output_nodes)} output nodes")

    # Activate and add extra argument nodes
    new_args = cm.info.switch_to_cached_attn_inputs()
    for name in new_args:
        input_nodes.append(add_graph_input(egm, name))
    ad_logger.info(f"Added {len(new_args)} new input nodes for cached attention metadata")

    canonicalize_graph(egm)


def insert_cached_attention(
    egm: GraphModule,
    cm: CachedSequenceInterface,
    attn_descriptor: Type[AttentionDescriptor],
    cache_config: CacheConfig,
) -> None:
    """Replace uncached source attention node with corresponding cached attn node."""
    # Get all attention nodes and their info objects
    source_op = attn_descriptor.get_source_attention_op()

    # pick up graph
    graph: Graph = egm.graph

    # look for relevant source attention nodes
    source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)]

    if not source_attn_nodes:
        # If there are no nodes for kv cache insertion found, return current graph
        return

    # Sanity check
    if cm.info.is_paged:
        assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."

    ad_logger.debug(f"Before inserting {attn_descriptor=} with cache: {egm}")

    # retrieve input nodes
    input_nodes, _ = get_all_input_output_nodes(egm.graph)

    # insert metadata computation and extract each argument as a node
    get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
    with graph.inserting_before(input_nodes[-1].next):
        ret_node = graph.call_function(
            get_metadata,
            args=(
                *input_nodes,
                cm.info.page_size,
            ),
        )
        metadata_nodes = [
            graph.call_function(operator.getitem, args=(ret_node, idx))
            for idx in range(num_metadata)
        ]

    buffer_in_lookup: Dict[str, Node] = {}

    # replace fused attention node with attention node that has kv cache
    num_cached_attn_replacements = 0
    for idx, attn_node in enumerate(source_attn_nodes):
        # pick out GEMMs
        qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()]

        # setup + store cache initializers and caches as input nodes
        cache_in_nodes = []
        for k, get_cache in attn_descriptor.get_cache_initializers(attn_node, cache_config).items():
            k_indexed = f"{k}_{idx}"
            cm.add_cache(k_indexed, get_cache)
            cache_in_nodes.append(add_graph_input(egm, k_indexed))

        # setup + store global buffer initializers and buffers as input nodes
        # NOTE: we have to check against existing keys to make sure nothing is registered twice...
        buffer_in_nodes = []
        for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
            if k not in buffer_in_lookup:
                cm.add_cache(k, get_buffer)
                buffer_in_lookup[k] = add_graph_input(egm, k)
            buffer_in_nodes.append(buffer_in_lookup[k])  # store buffer nodes for this op

        # retrieve constants for attention_op
        constants = attn_descriptor.get_constants(attn_node)

        # insert cached attention replacement op
        with graph.inserting_before(attn_node):
            cached_attn_node = graph.call_function(
                attn_descriptor.get_cached_attention_op(),
                args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants),
            )
        attn_node.replace_all_uses_with(cached_attn_node)
        graph.erase_node(attn_node)
        num_cached_attn_replacements += 1

    canonicalize_graph(egm)
    ad_logger.info(
        f"Replaced {num_cached_attn_replacements} {source_op} ops "
        f"with {attn_descriptor.get_cached_attention_op()}"
    )
    ad_logger.debug(f"After inserting {attn_descriptor=} with cache: {egm}")


def resize_kv_cache(
    egm: GraphModule,
    cm: CachedSequenceInterface,
    free_mem_ratio: float = 0.8,
) -> None:
    """Inflate the kv cache to occupy the available GPU memory.

    free_mem_ratio specifies the fraction of available memory to occupy.
    """

    def _get_mem_info_in_mb():
        free_mem, total_mem = torch.cuda.mem_get_info()
        return free_mem // 1024**2, total_mem // 1024**2

    free_mem, total_mem = _get_mem_info_in_mb()
    ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
    current_cache_size = cm.current_cache_size_bytes()
    current_num_pages = cm.info.num_pages
    ad_logger.info(
        f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
    )

    if free_mem_ratio == 0.0:
        ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}")
        return

    try:
        # Let's run a forward pass to get the memory usage
        cm.info._set_max_num_tokens_sample()
        free_mem_pre, _ = _get_mem_info_in_mb()
        ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")

        egm(*cm.args)

        free_mem_post, _ = _get_mem_info_in_mb()
        ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")

        memory_for_forward_pass = free_mem_pre - free_mem_post
        ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")

        new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
        new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))

        # Need to sync all the GPUs
        gathered_num_pages = [None] * get_world_size()
        all_gather_object(gathered_num_pages, new_num_pages)
        new_num_pages = min(gathered_num_pages)
        ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}")

        cm.resize_cache(new_num_pages)
    except Exception as e:
        ad_logger.warning(
            f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
        )

    # Free memory
    torch.cuda.empty_cache()
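To make the sizing arithmetic in resize_kv_cache concrete: the new page count is derived from the bytes per page of the current cache and the free memory left after a forward pass, scaled by free_mem_ratio. Below is a small standalone sketch of the same computation with made-up numbers; the real inputs come from torch.cuda.mem_get_info() and the cache manager, and each rank's result is reduced with min across GPUs as shown above.

def plan_new_num_pages(
    free_mem_post_mb: int,
    current_cache_size_bytes: int,
    current_num_pages: int,
    free_mem_ratio: float = 0.8,
) -> int:
    """Mirror of the resize logic: grow the cache into a fraction of the remaining free memory."""
    bytes_per_page = current_cache_size_bytes // current_num_pages
    new_cache_size = free_mem_post_mb * 1024 * 1024 * free_mem_ratio + current_cache_size_bytes
    return int(new_cache_size // bytes_per_page)


# Hypothetical numbers: 40 GiB free after the forward pass, a 2 GiB cache spread over 1024 pages.
print(
    plan_new_num_pages(
        free_mem_post_mb=40 * 1024,
        current_cache_size_bytes=2 * 1024**3,
        current_num_pages=1024,
    )
)  # -> 17408 pages, i.e. the cache grows to roughly 34 GiB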
@@ -1,113 +0,0 @@
"""Graph transform to optimize RMSNorm execution using FlashInfer."""

from functools import partial

import torch
from torch.fx import GraphModule

from ...utils.logger import ad_logger

# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from .._graph import canonicalize_graph

_BACKEND_OPS = {
    "flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
    "triton": torch.ops.auto_deploy.triton_rms_norm,
    "torch": torch.ops.auto_deploy.torch_rmsnorm,
}


def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Implements the RMSNorm pattern for pattern matching.

    Args:
        data: Input tensor to normalize.
        weight: Scaling weights for the normalized output.
        eps: Small constant for numerical stability.

    Returns:
        Normalized and scaled tensor.
    """
    input_dtype = data.dtype
    data = data.to(torch.float32)
    variance = data.pow(2).mean(-1, keepdim=True)
    data = data * torch.rsqrt(variance + eps)
    return weight * data.to(input_dtype)


def _rms_norm_replacement(
    data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
) -> torch.Tensor:
    """Backend-specific rms_norm implementation.

    Args:
        data: Input tensor to normalize.
        weight: Scaling weights for the normalized output.
        eps: Small constant for numerical stability.
        backend: Backend to use for RMSNorm computation ("flashinfer", "triton", or "torch").

    Returns:
        Normalized and scaled tensor using the specified backend implementation.
    """

    assert backend.lower() in _BACKEND_OPS, (
        f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
    )
    return _BACKEND_OPS[backend.lower()](data, weight, eps)


def fuse_rmsnorm(gm: GraphModule, backend: str = "triton") -> None:
    """Matches and replaces RMSNorm patterns in the graph with the selected backend implementation.

    This function sets up pattern matching to identify RMSNorm operations in the graph
    and replaces them with optimized implementations. It uses dummy tensors to register
    the pattern matching rules.

    Args:
        gm: Input graph module to transform.
        backend: Backend to use for RMSNorm computation ("flashinfer", "triton", or "torch").

    Returns:
        Transformed graph module with optimized RMSNorm operations.
    """
    if backend.lower() not in _BACKEND_OPS:
        raise ValueError(f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {backend}")
    ad_logger.info(f"Starting RMSNorm pattern matching with backend: {backend}")

    graph = gm.graph
    patterns = ADPatternMatcherPass()

    # Create dummy tensors for pattern matching
    bs = 2
    hidden_size = 512

    def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
        return [
            torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
            torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
            eps,
        ]

    # Define configurations for different data types
    configs = [
        (torch.bfloat16, torch.bfloat16),
        (torch.float16, torch.float16),
        (torch.float32, torch.float32),
    ]

    # Register patterns for each configuration
    for input_dtype, weight_dtype in configs:
        register_ad_pattern(
            search_fn=_rms_norm_pattern,
            replace_fn=partial(_rms_norm_replacement, backend=backend),
            patterns=patterns,
            dummy_args=dummy_args(input_dtype, weight_dtype),
            op_ignore_types={},
            scalar_workaround={"eps": 1e-6},
        )

    cnt = patterns.apply(graph)
    ad_logger.info(f"RMSNorm pattern count: {cnt}")
    canonicalize_graph(gm)
    ad_logger.debug("RMSNorm pattern matching completed.")
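As a quick reference for what _rms_norm_pattern computes before a backend op is swapped in: each row is divided by its root-mean-square (computed in float32 for stability) and then rescaled by weight. The following self-contained check restates that definition; shapes and eps are illustrative only.

import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Same math as the matched pattern: x / sqrt(mean(x^2) + eps), scaled by weight."""
    x32 = x.to(torch.float32)
    inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return weight * (x32 * inv_rms).to(x.dtype)


x = torch.randn(2, 512)
y = rms_norm_reference(x, torch.ones(512))
# With unit weights, every row of the output has (approximately) unit root-mean-square.
print(y.pow(2).mean(-1).sqrt())  # ~[1.0, 1.0]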
@@ -5,21 +5,11 @@ import gc
import torch
import torch.nn as nn

from ..compile import compile_and_capture
from ..custom_ops.attention_interface import AttentionRegistry
from ..llm_args import AutoDeployConfig
from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer
from ..utils.logger import ad_logger
from .library import (
    fuse_allreduce_residual_rmsnorm,
    fuse_collectives,
    fuse_rmsnorm,
    insert_cached_attention,
    resize_kv_cache,
    update_in_out_nodes,
)


class InferenceOptimizer:
@@ -55,88 +45,60 @@ class InferenceOptimizer:
            self.ad_config.attn_backend
        ).get_attention_layout()

        new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)

        # TODO (hg): similar to above.
        if "load_weights" in new_optimizer.config:
            new_optimizer.config[
        if "load_weights" in self.ad_config.transforms:
            self.ad_config.transforms[
                "load_weights"
            ].checkpoint_device = self.ad_config.checkpoint_device
            new_optimizer.config["load_weights"].device = cm.device
            self.ad_config.transforms["load_weights"].device = cm.device

        if "resize_kv_cache" in self.ad_config.transforms:
            self.ad_config.transforms[
                "resize_kv_cache"
            ].free_mem_ratio = self.ad_config.free_mem_ratio
        if "insert_cached_attention" in self.ad_config.transforms:
            self.ad_config.transforms[
                "insert_cached_attention"
            ].attn_backend = self.ad_config.attn_backend
        if "insert_cached_mla_attention" in self.ad_config.transforms:
            self.ad_config.transforms[
                "insert_cached_mla_attention"
            ].attn_backend = self.ad_config.mla_backend

        # TODO (hg): MLA is missing here. Figure out how to add MLA since duplicate transforms are not allowed.
        # Old code:
        # detect attention op and replace with cache-aware op
        # for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
        #     attn_descriptor = AttentionRegistry.get(a_backend)
        #     insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())

        if "compile_model" in self.ad_config.transforms:
            self.ad_config.transforms[
                "compile_model"
            ].cuda_graph_batch_sizes = self.ad_config.cuda_graph_batch_sizes
            self.ad_config.transforms[
                "compile_model"
            ].compile_backend = self.ad_config.compile_backend

        new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
        # TODO (hg): move this; let match_rope_layout and match_attention_layout use this shared config.
        new_optimizer.shared_config.attn_backend = self.ad_config.attn_backend

        egm = new_optimizer(cm)

        # TODO (lucaslie): continue moving legacy transforms to the new optimizer
        ############################################################################################
        # RUN POST-LOAD FUSION AND OPTIMIZATIONS
        ############################################################################################
        # NOTE (hg): disabled visualization since the compiled gm is a CapturedGraph instead of a GraphModule.
        # We can add a new stage in the optimizer to visualize the intermediate gm.
        # if self.ad_config.visualize:
        #     try:
        #         from .library import visualize_namespace

        # run MoE fusion
        # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
        # fuse_moe(egm)

        # run GEMM fusion
        # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
        # fuse_gemms(egm)

        # check if we can fuse allreduce, residual and rmsnorm
        fuse_allreduce_residual_rmsnorm(egm)

        # check if we can fuse collectives
        fuse_collectives(egm)

        # TODO (lucaslie): add backend selection as part of configurable inference optimizers
        # check if we can fuse rmsnorm
        fuse_rmsnorm(egm, "flashinfer")

        # visualize the final graph
        if self.ad_config.visualize:
            try:
                from .library import visualize_namespace

                visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
                ad_logger.warning(
                    "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
                    " the graph."
                )
            except ImportError:
                pass

        ############################################################################################
        # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
        ############################################################################################

        update_in_out_nodes(egm, cm)

        # detect attention op and replace with cache-aware op
        for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
            attn_descriptor = AttentionRegistry.get(a_backend)
            insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())

        # initialize cache on correct device
        cm.initialize_caches()

        # resize kv cache to occupy the available GPU memory up to free_mem_ratio
        resize_kv_cache(egm, cm, free_mem_ratio=self.ad_config.free_mem_ratio)

        ############################################################################################
        # COMPILE MODEL
        ############################################################################################

        cm.info.set_generate_only_batch()
        compiler_kwargs = {
            "cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes,
            "num_batched_inputs": 2,  # TODO (lucaslie): improve once we have a config system...
        }
        egm_compiled = compile_and_capture(
            egm,
            self.ad_config.compile_backend,
            args=cm.args,
            dynamic_shapes=cm.dynamic_shapes,
            compiler_kwargs=compiler_kwargs,
        )
        cm.info.reset()
        # visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
        # ad_logger.warning(
        #     "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
        #     " the graph."
        # )
        # except ImportError:
        #     pass

        torch.cuda.empty_cache()
        gc.collect()
        return egm_compiled
        return egm
@@ -8,9 +8,7 @@ from torch.export import export
from tensorrt_llm._torch.auto_deploy.distributed import common as dist
from tensorrt_llm._torch.auto_deploy.distributed.trtllm import is_trtllm_op_available
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations.library.collectives import (
    fuse_allreduce_residual_rmsnorm,
)
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm.llmapi.mpi_session import MpiPoolSession

@@ -65,14 +63,21 @@ def _test_allreduce_fusion(port: int):
    original_outputs, residual_original = gm(x, residual)

    # Fuse ops
    fuse_allreduce_residual_rmsnorm(gm)
    gm_transformed = InferenceOptimizer(
        None,
        {
            "fuse_allreduce_residual_rmsnorm": {
                "stage": "post_load_fusion",
            },
        },
    )(None, gm)

    # Run the fused graph
    fused_outputs, residual_fused = gm(x, residual)
    fused_outputs, residual_fused = gm_transformed(x, residual)

    # Check that the fused node is in the graph
    has_fused_node = False
    for node in gm.graph.nodes:
    for node in gm_transformed.graph.nodes:
        if is_op(node, torch.ops.dist.fused_allreduce_residual_rmsnorm):
            has_fused_node = True
    assert has_fused_node, "Fused node not found."
@@ -86,8 +91,8 @@ def _test_allreduce_fusion(port: int):
    )

    # check if we can still export the model as expected
    export(gm, args=args)
    torch_export_to_gm(gm, args=args)
    export(gm_transformed, args=args)
    torch_export_to_gm(gm_transformed, args=args)


@pytest.mark.parametrize("device_count", get_device_counts())
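The test updates in this and the following files all follow the same migration: instead of importing a library transform and calling it directly on the exported GraphModule, they build a modular InferenceOptimizer with a single-entry transform config and apply it. A condensed sketch of that calling convention, mirroring the diffs (the model, input, and chosen transform are placeholders):

import torch

from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer

model = torch.nn.Linear(16, 16)  # placeholder module
x = torch.randn(2, 16)
gm = torch_export_to_gm(model, args=(x,), clone=True)

# One named transform, pinned to its pipeline stage; no model factory is needed
# when starting from an already exported graph (hence the None arguments).
gm_transformed = InferenceOptimizer(
    None,
    {"fuse_collectives": {"stage": "post_load_fusion"}},
)(None, gm)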
@@ -8,12 +8,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from _dist_test_utils import get_device_counts
from _graph_test_helpers import run_test
from _graph_test_helpers import run_test_transformed_gm
from _torch_test_utils import fp8_compatible

import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear
from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_collectives
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


@@ -61,11 +62,21 @@ def _run_job(
        is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm.graph.nodes
    )

    gm = torch_export_to_gm(model, args=(x,), clone=True)
    gm_transformed = InferenceOptimizer(
        None,
        {
            "fuse_collectives": {
                "stage": "post_load_fusion",
            },
        },
    )(None, gm)

    # now run the test
    run_test(
    run_test_transformed_gm(
        model,
        x,
        transform=fuse_collectives,
        gm_transformed,
        check_transformed_graph=check_transformed_graph,
        _get_expected_num_params=_get_expected_num_params,
        test_load_hook=False,
@@ -1,12 +1,11 @@
from functools import partial

import pytest
import torch
from _graph_test_helpers import run_test
from _graph_test_helpers import run_test_transformed_gm
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import *  # noqa
from tensorrt_llm._torch.auto_deploy.transformations.library.rms_norm import fuse_rmsnorm
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


@@ -52,15 +51,28 @@ def test_rmsnorm_fusion(eps, variant, op):
        return any(is_op(n, op) for n in gm.graph.nodes)

    model = TestModel(eps)
    gm_transformed = run_test(
    x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
    dynamic_shapes = {0: Dim("batch_size", max=8)}
    gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
    gm_transformed = InferenceOptimizer(
        None,
        {
            "fuse_rmsnorm": {
                "stage": "post_load_fusion",
                "backend": variant,
            },
        },
    )(None, gm)

    run_test_transformed_gm(
        model,
        torch.randn(2, 1024, device="cuda", dtype=torch.float16),
        partial(fuse_rmsnorm, backend=variant),
        x,
        gm_transformed,
        checker,
        lambda num_p_og: num_p_og,
        dynamic_shapes={0: Dim("batch_size", max=8)},
        dynamic_shapes=dynamic_shapes,
    )
    print(gm_transformed.graph)

    new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    y_transformed = gm_transformed(new_input)
    y_model = model(new_input)
@@ -7,11 +7,12 @@ import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from _graph_test_helpers import count_buffers, run_test
from _graph_test_helpers import count_buffers, run_test_transformed_gm
from _torch_test_utils import all_close, fp8_compatible, reset_parameters

from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear
from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_gemms
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op

torch.manual_seed(0)
@@ -254,10 +255,20 @@ def test_fusion(get_model: Callable[[], TestModel], dtype: str):

    buffer_size_before = count_buffers(model)

    gm_transformed = run_test(
    gm = torch_export_to_gm(model, args=(x,), clone=True)
    gm_transformed = InferenceOptimizer(
        None,
        {
            "fuse_gemms": {
                "stage": "post_load_fusion",
            },
        },
    )(None, gm)

    run_test_transformed_gm(
        model,
        x,
        fuse_gemms,
        gm_transformed,
        lambda gm: sum(is_linear_op(n, include_quantization=True) for n in gm.graph.nodes)
        == model.num_gemms_after_fusion,
        lambda num_p_og: num_p_og,  # unchanged since fusing doesn't change param count
@@ -2,19 +2,35 @@ from typing import Optional

import pytest
import torch
from _graph_test_helpers import FakeFactory
from _graph_test_helpers import SequenceEmbeddingInfo
from _model_test_utils import GQA
from _torch_test_utils import all_close

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig, SequenceInfo
from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_attention import FlashInferAttention
from tensorrt_llm._torch.auto_deploy.custom_ops.triton_attention import TritonAttention
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.interface import InferenceOptimizerConfig
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations.library import update_in_out_nodes
from tensorrt_llm._torch.auto_deploy.transformations.library.kvcache import insert_cached_attention


class DummyFactory(ModelFactory):
    """Dummy factory to pass cache_config for testing."""

    def __init__(self, model, cache_config):
        self._model = model
        self.cache_config = cache_config

    def build_model(self, device: str):
        return self._model.to(device=device)

    def _build_model(self, device: str):
        return

    def _load_checkpoint(self, model, device):
        return

    def get_cache_config(self):
        return self.cache_config


# Class that uses SDPA directly instead of the regular attention mechanism
@@ -68,42 +84,6 @@ class GQAWithSdpa(GQA):
        return self.o_proj(attn_output)


def _get_optimizer_config() -> InferenceOptimizerConfig:
    return {
        "build_model": {
            "stage": "factory",
            "device": "cuda",
            "run_graph_cleanup": False,
            "requires_clean_graph": False,
        },
        "export_to_gm": {
            "stage": "export",
            "strict": False,
            "clone_state_dict": True,
            "run_graph_cleanup": False,
            "requires_clean_graph": False,
        },
        "cleanup_input_constraints": {
            "stage": "post_export",
        },
    }


class SequenceEmbeddingInfo(SequenceInfo):
    hidden_size: int
    dtype: torch.dtype

    def set_example_sequence(self) -> None:
        super().set_example_sequence()
        # set input ids to a 3D tensor (actually input embeddings)
        self.input_ids = torch.rand(
            *self.input_ids.shape,
            self.hidden_size,
            device=self.input_ids.device,
            dtype=self.dtype,
        )


# TODO (lucaslie): consider rewriting this test with a custom InferenceOptimizer config
@pytest.mark.parametrize(
    "dtype",
@@ -111,8 +91,8 @@ class SequenceEmbeddingInfo(SequenceInfo):
    ids=["float16", "float32"],
)
@pytest.mark.parametrize(
    "attn_descriptor",
    [TritonAttention, FlashInferAttention],
    "attn_backend",
    ["triton", "flashinfer"],
    ids=["triton", "flashinfer"],
)
@pytest.mark.parametrize(
@@ -125,10 +105,10 @@ class SequenceEmbeddingInfo(SequenceInfo):
    ids=["regular", "gqa", "mqa"],
)
@torch.inference_mode()
def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config):
def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
    """Test the SDPA transformation with KV cache."""
    # flashinfer doesn't support float32 data type
    if attn_descriptor == FlashInferAttention and dtype == torch.float32:
    if attn_backend == "flashinfer" and dtype == torch.float32:
        pytest.skip("flashinfer doesn't support float32 data type")

    # Unpack the GQA configuration
@@ -157,7 +137,6 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config):
        hidden_size,
        num_key_value_heads,
    ).to(dtype=dtype, device="cuda")
    factory = FakeFactory(model)

    # Create input tensor and position_ids
    x = torch.rand(batch_size, seq_len, hidden_size).to(device="cuda", dtype=dtype)

@@ -166,21 +145,37 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config):
    # Get the model's regular output
    y_model = model(x, position_ids)  # b, s, d

    # run modular inference optimizer up to post_export
    optimizer = InferenceOptimizer(factory, _get_optimizer_config())  # type: ignore
    # Apply the transformation
    optimizer = InferenceOptimizer(
        DummyFactory(model, CacheConfig()),
        {
            "build_model": {
                "stage": "factory",
                "device": "cuda",
                "run_graph_cleanup": False,
                "requires_clean_graph": False,
            },
            "export_to_gm": {
                "stage": "export",
                "strict": False,
                "clone_state_dict": True,
                "run_graph_cleanup": False,
                "requires_clean_graph": False,
            },
            "cleanup_input_constraints": {
                "stage": "post_export",
            },
            "update_in_out_nodes": {
                "stage": "cache_init",
            },
            "insert_cached_attention": {
                "stage": "cache_init",
                "attn_backend": attn_backend,
            },
        },
    )  # type: ignore
    gm = optimizer(cm)

    y_gm = gm(x, position_ids)
    assert all_close(y_model, y_gm, atol=atol, rtol=rtol)

    # Set up cache configuration
    cache_config = CacheConfig()

    # Get input node(s)
    update_in_out_nodes(gm, cm)

    # Apply the transformation
    insert_cached_attention(gm, cm, attn_descriptor=attn_descriptor, cache_config=cache_config)
    gm.to("cuda")
    cm.initialize_caches()