[#9283][feat] AutoDeploy: separate rms pattern detection from fusion (#9969)

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Frida Hou 2026-01-13 11:57:27 -08:00 committed by GitHub
parent 7b7f1e2ba1
commit bf16fbd86c
6 changed files with 247 additions and 84 deletions


@@ -48,6 +48,8 @@ transforms:
match_rope_layout:
stage: pattern_matcher
expected_layout: bsnd
match_rmsnorm_pattern:
stage: pattern_matcher
############################################################################################
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
############################################################################################
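The unit-test hunks at the end of this diff wire up the same ordering programmatically. A minimal Python sketch of that configuration, assuming only the keys visible in this diff (other options omitted):

# match_rmsnorm_pattern standardizes RMSNorm subgraphs in the pattern_matcher stage;
# fuse_rmsnorm later swaps in a backend-specific op in the post_load_fusion stage.
ad_transforms = {
    "match_rmsnorm_pattern": {
        "stage": "pattern_matcher",
    },
    "fuse_rmsnorm": {
        "stage": "post_load_fusion",
        # "rmsnorm_backend": "flashinfer",  # regular backend: flashinfer, triton, or torch
        "gated_rmsnorm_backend": "triton",  # currently the only gated backend
    },
}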


@@ -88,6 +88,65 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
return torch.empty_like(input)
@torch.library.custom_op("auto_deploy::torch_rmsnorm_gated", mutates_args=())
def torch_rmsnorm_gated(
x: torch.Tensor,
weight: torch.Tensor,
gate: torch.Tensor | None,
eps: float,
group_size: int,
norm_before_gate: bool = False,
) -> torch.Tensor:
"""Custom operator for Torch gated RMSNorm implementation.
Group RMSNorm with optional SiLU gating, using pure PyTorch operations.
Args:
x: Input tensor of shape [..., H].
weight: Scaling weights of shape [H].
gate: Optional gate tensor with same shape as x, or None.
eps: Small constant for numerical stability.
group_size: Size of groups for grouped normalization. H must be divisible by group_size.
norm_before_gate: If True, normalize first and apply the SiLU gate to the normalized output; if False, gate the input before normalization.
Returns:
Normalized and optionally gated tensor with the same shape as x.
"""
dtype = x.dtype
weight = weight.float()
x = x.float()
z = gate.float() if gate is not None else gate
if z is not None and not norm_before_gate:
x = x * F.silu(z)
if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = x * rstd * weight
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if z is not None and norm_before_gate:
out *= F.silu(z)
return out.to(dtype)
@torch_rmsnorm_gated.register_fake
def _(
x: torch.Tensor,
weight: torch.Tensor,
gate: torch.Tensor | None,
eps: float,
group_size: int,
norm_before_gate: bool = False,
) -> torch.Tensor:
"""Fake implementation for the custom operator during tracing."""
return x.new_empty(x.shape, dtype=x.dtype)
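A small usage sketch for the new op; shapes are illustrative, and it assumes the auto_deploy custom-ops module has been imported so the op is registered under torch.ops.auto_deploy:

import torch

# Hidden size 4096 normalized in groups of 512, matching the dummy args used
# by the gated pattern registration later in this diff.
x = torch.randn(2, 3, 4096, dtype=torch.bfloat16)
w = torch.randn(4096, dtype=torch.bfloat16)
g = torch.randn_like(x)
y = torch.ops.auto_deploy.torch_rmsnorm_gated(x, w, g, 1e-5, 512, False)
assert y.shape == x.shape and y.dtype == x.dtype  # output keeps the input shape and dtype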
@torch.library.custom_op("auto_deploy::triton_rmsnorm_gated", mutates_args=())
def triton_rmsnorm_gated(
x: torch.Tensor,


@@ -32,7 +32,10 @@ from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformReg
def _make_allreduce_residual_rmsnorm_pattern(
add_order: str = "residual_first", strategy: str = "AUTO"
):
"""Factory function to create pattern functions for allreduce+residual+rmsnorm fusion.
"""Factory function to create pattern functions for allreduce+residual+torch_rmsnorm fusion.
The generated patterns match the graph after match_rmsnorm_pattern has replaced
eager RMSNorm subgraphs with torch_rmsnorm ops.
Args:
add_order: Either "residual_first" (residual + x) or "x_first" (x + residual)
@@ -45,15 +48,14 @@ def _make_allreduce_residual_rmsnorm_pattern(
def pattern_fn(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253
):
"""Pattern: trtllm_dist_all_reduce(x) -> add residual -> RMSNorm
"""Pattern: trtllm_dist_all_reduce(x) -> add residual -> torch_rmsnorm
Reference PyTorch composition:
y = trtllm_dist_all_reduce(x)
z = residual + y (or y + residual)
normed = RMSNorm(z, weight, eps)
normed = torch_rmsnorm(z, weight, eps)
Returns (normed, z)
"""
input_dtype = x.dtype
hidden_states = torch.ops.auto_deploy.trtllm_dist_all_reduce(x, strategy)
# Handle addition order
@@ -62,11 +64,8 @@ def _make_allreduce_residual_rmsnorm_pattern(
else: # x_first
add = hidden_states + residual
hidden_states = add.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
normed = weight * hidden_states.to(input_dtype)
# Use the standardized torch_rmsnorm op (the eager RMSNorm subgraph was already replaced by match_rmsnorm_pattern)
normed = torch.ops.auto_deploy.torch_rmsnorm(add, weight, eps)
return normed, add
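For context, a pure-PyTorch sketch of the composition pattern_fn now targets once match_rmsnorm_pattern has standardized the norm; op names are taken from this diff, the fused replacement op is defined elsewhere, and the custom ops are assumed to be registered:

import torch

def allreduce_residual_rmsnorm_ref(x, residual, weight, eps, strategy="AUTO"):
    # Reference composition matched by pattern_fn ("residual_first" ordering):
    # all-reduce, add the residual, then the standardized torch_rmsnorm op.
    y = torch.ops.auto_deploy.trtllm_dist_all_reduce(x, strategy)
    z = residual + y
    normed = torch.ops.auto_deploy.torch_rmsnorm(z, weight, eps)
    return normed, z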
@@ -94,6 +93,9 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
This transform only applies when TRT-LLM ops are used (MPI mode), as it provides
optimized fused kernels. The torch backend (demollm mode) does not benefit from
this fusion and uses unfused operations.
Note: This transform expects torch_rmsnorm ops in the graph, which are created
by the match_rmsnorm_pattern transform that runs earlier in the pipeline.
"""
def _apply(
@@ -114,7 +116,6 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
0.1253, # eps
]
op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)}
scalar_workaround = {"eps": 0.1253}
# ============================================================================
@@ -139,7 +140,6 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
patterns=patterns,
dummy_args=dummy_args,
op_ignore_types=op_ignore_types,
scalar_workaround=scalar_workaround,
)
@@ -149,7 +149,6 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
patterns=patterns,
dummy_args=dummy_args,
op_ignore_types=op_ignore_types,
scalar_workaround=scalar_workaround,
)


@@ -1,17 +1,17 @@
"""Graph transform to optimize RMSNorm execution using FlashInfer."""
from functools import partial
from typing import Tuple, Type
import torch
from pydantic import Field
from torch.fx import GraphModule
from torch.fx import GraphModule, Node
from ...custom_ops.rms_norm import gated_rms_norm_ref
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.node_utils import is_op
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
BaseTransform,
@@ -66,6 +66,22 @@ def _rms_norm_pattern_float32_weights(
return (weight.to(torch.float32) * data).to(input_dtype)
def _rms_norm_to_torch_rmsnorm(
data: torch.Tensor, weight: torch.Tensor, eps: float
) -> torch.Tensor:
"""Replace RMSNorm pattern with torch_rmsnorm op (standardized representation).
Args:
data: Input tensor to normalize.
weight: Scaling weights for the normalized output.
eps: Small constant for numerical stability.
Returns:
Normalized and scaled tensor using torch_rmsnorm.
"""
return torch.ops.auto_deploy.torch_rmsnorm(data, weight, eps)
def _rms_norm_replacement(
data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
) -> torch.Tensor:
@@ -87,6 +103,109 @@ def _rms_norm_replacement(
return _BACKEND_OPS[backend.lower()](data, weight, eps)
@TransformRegistry.register("match_rmsnorm_pattern")
class MatchRMSNormPattern(BaseTransform):
"""Matches RMSNorm patterns in the graph and replaces them with torch_rmsnorm op.
This transform runs in the pattern_matcher stage and standardizes RMSNorm patterns
to use torch_rmsnorm op, which can later be fused to a specific backend in the
post_load_fusion stage.
Args:
gm: Input graph module to transform.
Returns:
Transformed graph module with standardized torch_rmsnorm operations.
"""
config: TransformConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return TransformConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
graph = gm.graph
patterns = ADPatternMatcherPass()
# Pattern matching for regular RMSNorm
bs = 2
hidden_size = 512
def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
return [
torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
eps,
]
# Define configurations for different data types
configs = [
(torch.bfloat16, torch.bfloat16),
(torch.float16, torch.float16),
(torch.float32, torch.float32),
]
# Register patterns for each configuration - replace with torch_rmsnorm
search_fns = [
_rms_norm_pattern,
_rms_norm_pattern_float32_weights,
]
for search_fn in search_fns:
for input_dtype, weight_dtype in configs:
register_ad_pattern(
search_fn=search_fn,
replace_fn=_rms_norm_to_torch_rmsnorm,
patterns=patterns,
dummy_args=dummy_args(input_dtype, weight_dtype),
op_ignore_types={},
scalar_workaround={"eps": 1e-6},
)
# Pattern matching for gated RMSNorm
B, S, H = 2, 3, 4096
group_size = 512
eps = 1e-5
def make_dummy_args_gated(group_size: int, eps: float) -> list:
x = torch.randn(B, S, H, dtype=torch.float32)
w = torch.randn(H, dtype=torch.float32)
g = torch.randn(B, S, H, dtype=torch.float32)
return [x, w, g, eps, group_size]
op_ignore_types = {
torch.ops.aten.reshape.default: (int, list, tuple),
torch.ops.aten.view.default: (int, list, tuple),
torch.ops.aten.mean.dim: (list, tuple),
torch.ops.aten.to.dtype: (torch.dtype,),
}
# Register pattern for gated RMSNorm - replace with torch_rmsnorm_gated
register_ad_pattern(
search_fn=_gated_rmsnorm_pattern_ref,
replace_fn=_gated_rmsnorm_to_torch_rmsnorm_gated,
patterns=patterns,
dummy_args=make_dummy_args_gated(group_size, eps),
op_ignore_types=op_ignore_types,
scalar_workaround={"eps": eps, "group_size": group_size},
skip_duplicates=True,
)
cnt = patterns.apply(graph)
info = TransformInfo(
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
)
return gm, info
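To illustrate what gets standardized, a sketch of the eager RMSNorm subgraph this transform looks for (mirroring the inline pattern removed from the allreduce fusion earlier in this diff) and the single op it collapses into:

import torch

def rmsnorm_eager(hidden_states, weight, eps=1e-6):
    # Typical eager-mode RMSNorm as it appears in exported model graphs.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

# After MatchRMSNormPattern runs, such a subgraph is replaced by a single node:
#   normed = torch.ops.auto_deploy.torch_rmsnorm(hidden_states, weight, eps)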
class FuseRMSNormConfig(TransformConfig):
"""Configuration for the RMSNorm fusion transform."""
@@ -102,11 +221,10 @@ class FuseRMSNormConfig(TransformConfig):
@TransformRegistry.register("fuse_rmsnorm")
class FuseRMSNorm(BaseTransform):
"""Matches and replaces RMSNorm patterns (regular and gated) in the graph with optimized implementations.
"""Fuses torch_rmsnorm ops with the selected backend implementation.
This function sets up pattern matching to identify both regular and gated RMSNorm operations in the graph
and replaces them with optimized implementations. It uses dummy tensors to register
the pattern matching rules.
This transform runs in the post_load_fusion stage and replaces torch_rmsnorm ops
with the specified backend implementation (flashinfer, triton, or torch).
Args:
gm: Input graph module to transform.
@@ -114,7 +232,7 @@ class FuseRMSNorm(BaseTransform):
gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton").
Returns:
Transformed graph module with optimized RMSNorm operations.
Transformed graph module with backend-specific RMSNorm operations.
"""
config: FuseRMSNormConfig
@@ -144,72 +262,39 @@ class FuseRMSNorm(BaseTransform):
)
graph = gm.graph
patterns = ADPatternMatcherPass()
backend = self.config.rmsnorm_backend.lower()
target_op = _BACKEND_OPS[backend]
cnt = 0
# Pattern matching for regular RMSNorm
bs = 2
hidden_size = 512
# Replace torch_rmsnorm ops with the selected backend
for node in list(graph.nodes):
if is_op(node, torch.ops.auto_deploy.torch_rmsnorm):
# Replace with the selected backend op
with graph.inserting_after(node):
new_node: Node = graph.call_function(
target_op,
args=node.args,
kwargs=node.kwargs,
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
cnt += 1
def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
return [
torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
eps,
]
# Replace torch_rmsnorm_gated ops with triton_rmsnorm_gated
for node in list(graph.nodes):
if is_op(node, torch.ops.auto_deploy.torch_rmsnorm_gated):
# Replace with triton_rmsnorm_gated op
with graph.inserting_after(node):
new_node: Node = graph.call_function(
torch.ops.auto_deploy.triton_rmsnorm_gated,
args=node.args,
kwargs=node.kwargs,
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
cnt += 1
# Define configurations for different data types
configs = [
(torch.bfloat16, torch.bfloat16),
(torch.float16, torch.float16),
(torch.float32, torch.float32),
]
# Register patterns for each configuration
search_fns = [
_rms_norm_pattern,
_rms_norm_pattern_float32_weights,
]
for search_fn in search_fns:
for input_dtype, weight_dtype in configs:
register_ad_pattern(
search_fn=search_fn,
replace_fn=partial(_rms_norm_replacement, backend=self.config.rmsnorm_backend),
patterns=patterns,
dummy_args=dummy_args(input_dtype, weight_dtype),
op_ignore_types={},
scalar_workaround={"eps": 1e-6},
)
# Pattern matching for gated RMSNorm
B, S, H = 2, 3, 4096
group_size = 512
eps = 1e-5
def make_dummy_args_gated(group_size: int, eps: float) -> list:
x = torch.randn(B, S, H, dtype=torch.float32)
w = torch.randn(H, dtype=torch.float32)
g = torch.randn(B, S, H, dtype=torch.float32)
return [x, w, g, eps, group_size]
op_ignore_types = {
torch.ops.aten.reshape.default: (int, list, tuple),
torch.ops.aten.view.default: (int, list, tuple),
torch.ops.aten.mean.dim: (list, tuple),
torch.ops.aten.to.dtype: (torch.dtype,),
}
# Register pattern for gated RMSNorm
register_ad_pattern(
search_fn=_gated_rmsnorm_pattern_ref,
replace_fn=_gated_rmsnorm_replacement,
patterns=patterns,
dummy_args=make_dummy_args_gated(group_size, eps),
op_ignore_types=op_ignore_types,
scalar_workaround={"eps": eps, "group_size": group_size},
skip_duplicates=True,
)
cnt = patterns.apply(graph)
gm.recompile()
info = TransformInfo(
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
@@ -239,13 +324,25 @@ def _gated_rmsnorm_pattern_ref(
return y
def _gated_rmsnorm_replacement(
def _gated_rmsnorm_to_torch_rmsnorm_gated(
x: torch.Tensor,
weight: torch.Tensor,
gate: torch.Tensor,
eps: float,
group_size: int,
) -> torch.Tensor:
return torch.ops.auto_deploy.triton_rmsnorm_gated(
"""Replace gated RMSNorm pattern with torch_rmsnorm_gated op (standardized representation).
Args:
x: Input tensor to normalize.
weight: Scaling weights for the normalized output.
gate: Gate tensor for gated normalization.
eps: Small constant for numerical stability.
group_size: Size of groups for grouped normalization.
Returns:
Normalized and gated tensor using torch_rmsnorm_gated.
"""
return torch.ops.auto_deploy.torch_rmsnorm_gated(
x, weight, gate, float(eps), int(group_size), False
)
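Downstream, fuse_rmsnorm swaps torch_rmsnorm_gated nodes for triton_rmsnorm_gated with identical arguments, which presumes the two ops agree numerically. A sanity-check sketch under that assumption (requires a CUDA device for the Triton kernel and the auto_deploy custom ops to be registered):

import torch

x = torch.randn(2, 3, 4096, device="cuda")
w = torch.randn(4096, device="cuda")
g = torch.randn(2, 3, 4096, device="cuda")
# Same argument order as above: x, weight, gate, eps, group_size, norm_before_gate.
ref = torch.ops.auto_deploy.torch_rmsnorm_gated(x, w, g, 1e-5, 512, False)
out = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, g, 1e-5, 512, False)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)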


@@ -87,6 +87,9 @@ def _test_allreduce_fusion(port: int, ModuleCls, strategy: str):
gm_transformed = InferenceOptimizer(
None,
{
"match_rmsnorm_pattern": {
"stage": "pattern_matcher",
},
"detect_sharding": {
"stage": "post_export",
"allreduce_strategy": strategy,


@@ -64,6 +64,9 @@ def _run_test(model, op, variant):
gm_transformed = InferenceOptimizer(
None,
{
"match_rmsnorm_pattern": {
"stage": "pattern_matcher",
},
"fuse_rmsnorm": {
"stage": "post_load_fusion",
"gated_rmsnorm_backend": "triton",