[#9525][feat] add L2 norm pattern matcher and fusion transform (#10767)

Signed-off-by: Karthik Vetrivel <kvetrivel@nvidia.com>
Karthik 2026-01-30 16:05:53 -05:00 committed by GitHub
parent 4af47208d8
commit 5a97374f3c
4 changed files with 308 additions and 2 deletions

@@ -50,6 +50,8 @@ transforms:
expected_layout: bsnd
match_rmsnorm_pattern:
stage: pattern_matcher
+match_l2norm_pattern:
+stage: pattern_matcher
############################################################################################
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
############################################################################################
@@ -140,6 +142,9 @@ transforms:
rmsnorm_backend: flashinfer
gated_rmsnorm_backend: triton
requires_shape_prop: true
+fuse_l2norm:
+stage: post_load_fusion
+backend: fla
fuse_add_rms_norm:
stage: post_load_fusion
enabled: true
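
The two new entries follow the same two-stage convention as the RMSNorm transforms above: match_l2norm_pattern runs in the pattern_matcher stage and rewrites matched subgraphs to the standardized torch_l2norm op, while fuse_l2norm runs in post_load_fusion and lowers that op to the configured backend (fla by default). Both backends are expected to compute the same normalization; the standalone sketch below is illustrative only (the helper name is not part of this change):

import torch

def l2norm_reference(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(sum(x^2, dim=-1) + eps), accumulated in float32 and cast back.
    x32 = x.to(torch.float32)
    out = x32 * torch.rsqrt((x32 * x32).sum(dim=-1, keepdim=True) + eps)
    return out.to(x.dtype)

x = torch.randn(2, 512, dtype=torch.float16)
# Row norms of the result should be ~1 up to eps and fp16 rounding.
norms = l2norm_reference(x).to(torch.float32).norm(dim=-1)
assert torch.allclose(norms, torch.ones_like(norms), atol=1e-2)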

@@ -284,8 +284,8 @@ class DeltaNet(nn.Module):
beta = beta * 2.0
if self.qk_norm == "l2":
-q = torch.ops.auto_deploy.fla_l2norm(q)
-k = torch.ops.auto_deploy.fla_l2norm(k)
+q = torch.ops.auto_deploy.torch_l2norm(q)
+k = torch.ops.auto_deploy.torch_l2norm(k)
elif self.qk_norm == "sum":
q = sum_norm(q).to(q)
k = sum_norm(k).to(k)

@@ -0,0 +1,200 @@
"""Graph transform to optimize L2Norm execution using FLA Triton kernels."""
from typing import Literal, Tuple, Type
import torch
from pydantic import Field
from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.node_utils import is_op
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
_BACKEND_OPS = {
"fla": torch.ops.auto_deploy.fla_l2norm.default,
"torch": torch.ops.auto_deploy.torch_l2norm.default,
}
def _l2_norm_pattern(data: torch.Tensor, eps: float) -> torch.Tensor:
"""Implements the L2Norm pattern for pattern matching.
L2 normalization: x / sqrt(sum(x^2) + eps)
Args:
data: Input tensor to normalize.
eps: Small constant for numerical stability.
Returns:
L2 normalized tensor.
"""
input_dtype = data.dtype
data = data.to(torch.float32)
sum_sq = (data * data).sum(dim=-1, keepdim=True)
data = data * torch.rsqrt(sum_sq + eps)
return data.to(input_dtype)
def _l2_norm_pattern_no_dtype_cast(data: torch.Tensor, eps: float) -> torch.Tensor:
"""Implements the L2Norm pattern without dtype casting for pattern matching.
Some models may already operate in float32 and skip the dtype cast.
Args:
data: Input tensor to normalize.
eps: Small constant for numerical stability.
Returns:
L2 normalized tensor.
"""
sum_sq = (data * data).sum(dim=-1, keepdim=True)
return data * torch.rsqrt(sum_sq + eps)
def _l2_norm_to_torch_l2norm(data: torch.Tensor, eps: float) -> torch.Tensor:
"""Replace L2Norm pattern with torch_l2norm op (standardized representation).
Args:
data: Input tensor to normalize.
eps: Small constant for numerical stability.
Returns:
L2 normalized tensor using torch_l2norm.
"""
return torch.ops.auto_deploy.torch_l2norm(data, eps)
@TransformRegistry.register("match_l2norm_pattern")
class MatchL2NormPattern(BaseTransform):
"""Matches L2Norm patterns in the graph and replaces them with torch_l2norm op.
This transform runs in the pattern_matcher stage and standardizes L2Norm patterns
to use torch_l2norm op, which can later be fused to a specific backend in the
post_load_fusion stage.
Args:
gm: Input graph module to transform.
Returns:
Transformed graph module with standardized torch_l2norm operations.
"""
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
graph = gm.graph
patterns = ADPatternMatcherPass()
bs = 2
hidden_size = 512
def dummy_args(input_dtype: torch.dtype, eps: float = 1e-6):
return [
torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
eps,
]
configs = [
torch.bfloat16,
torch.float16,
torch.float32,
]
search_fns = [
_l2_norm_pattern,
_l2_norm_pattern_no_dtype_cast,
]
for search_fn in search_fns:
for input_dtype in configs:
register_ad_pattern(
search_fn=search_fn,
replace_fn=_l2_norm_to_torch_l2norm,
patterns=patterns,
dummy_args=dummy_args(input_dtype),
op_ignore_types={},
scalar_workaround={"eps": 1e-6},
skip_duplicates=True,
)
cnt = patterns.apply(graph)
info = TransformInfo(
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
)
return gm, info
class FuseL2NormConfig(TransformConfig):
"""Configuration for the L2Norm fusion transform."""
backend: Literal["torch", "fla"] = Field(
default="fla",
description="Backend to use for L2Norm computation ('fla' or 'torch').",
)
@TransformRegistry.register("fuse_l2norm")
class FuseL2Norm(BaseTransform):
"""Fuses torch_l2norm ops with the selected backend implementation.
This transform runs in the post_load_fusion stage and replaces torch_l2norm ops
with the specified backend implementation (fla or torch).
Args:
gm: Input graph module to transform.
backend: Backend to use for L2Norm computation ("fla" or "torch").
Returns:
Transformed graph module with backend-specific L2Norm operations.
"""
config: FuseL2NormConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return FuseL2NormConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
graph = gm.graph
target_op = _BACKEND_OPS[self.config.backend]
cnt = 0
for node in list(graph.nodes):
if is_op(node, torch.ops.auto_deploy.torch_l2norm):
with graph.inserting_after(node):
new_node: Node = graph.call_function(
target_op,
args=node.args,
kwargs=node.kwargs,
)
new_node.meta = node.meta.copy()
node.replace_all_uses_with(new_node)
graph.erase_node(node)
cnt += 1
info = TransformInfo(
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
)
return gm, info
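
The rewrite loop in FuseL2Norm._apply is the standard torch.fx insert/replace/erase idiom. The self-contained sketch below shows the same idiom on stock ops so it runs without tensorrt_llm; the toy module and the relu-to-sigmoid swap are illustrative, not part of this change:

import torch
from torch import fx

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

gm = fx.symbolic_trace(Tiny())

# Mirror FuseL2Norm._apply: insert the replacement node after the original,
# move all users over, then erase the original node.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.sigmoid, args=node.args, kwargs=node.kwargs)
        new_node.meta = node.meta.copy()
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
x = torch.randn(4)
torch.testing.assert_close(gm(x), torch.sigmoid(x) + 1.0)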

@@ -0,0 +1,101 @@
import pytest
import torch
from _graph_test_helpers import run_test_transformed_gm
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.custom_ops.l2norm import * # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
class L2Norm(torch.nn.Module):
"""L2 normalization module that normalizes along the last dimension."""
def __init__(self, eps=1e-6):
super().__init__()
self.eps = eps
def forward(self, x):
input_dtype = x.dtype
x = x.to(torch.float32)
sum_sq = (x * x).sum(dim=-1, keepdim=True)
x = x * torch.rsqrt(sum_sq + self.eps)
return x.to(input_dtype)
class L2NormNoCast(torch.nn.Module):
"""L2 normalization module without dtype casting (for float32 inputs)."""
def __init__(self, eps=1e-6):
super().__init__()
self.eps = eps
def forward(self, x):
sum_sq = (x * x).sum(dim=-1, keepdim=True)
return x * torch.rsqrt(sum_sq + self.eps)
class TestModel(torch.nn.Module):
def __init__(self, eps: float = 1e-6, use_no_cast: bool = False):
super().__init__()
self.linear1 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
if use_no_cast:
self.l2_norm = L2NormNoCast(eps)
else:
self.l2_norm = L2Norm(eps)
self.linear2 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
def forward(self, x):
x = self.linear1(x)
x = self.l2_norm(x)
x = self.linear2(x)
return x
def _run_test(model, op, variant):
def checker(gm):
return any(is_op(n, op) for n in gm.graph.nodes)
x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
dynamic_shapes = {0: Dim.DYNAMIC}
gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
gm_transformed = InferenceOptimizer(
None,
{
"match_l2norm_pattern": {
"stage": "pattern_matcher",
},
"fuse_l2norm": {
"stage": "post_load_fusion",
"backend": variant,
},
},
)(None, gm)
run_test_transformed_gm(
model,
x,
gm_transformed,
checker,
lambda num_p_og: num_p_og,
dynamic_shapes=dynamic_shapes,
)
new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
y_transformed = gm_transformed(new_input)
y_model = model(new_input)
torch.testing.assert_close(y_transformed, y_model, atol=1e-3, rtol=1e-3)
@pytest.mark.parametrize("eps", [1e-2, 1e-6])
@pytest.mark.parametrize(
"variant, op",
[
("fla", torch.ops.auto_deploy.fla_l2norm.default),
("torch", torch.ops.auto_deploy.torch_l2norm.default),
],
)
def test_l2norm_fusion(eps, variant, op):
model = TestModel(eps)
_run_test(model, op, variant)
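
Note on the two module flavors above: for float32 inputs the cast in L2Norm is a no-op, which is why MatchL2NormPattern registers both the casting and the no-cast search patterns. A quick eager check of that equivalence (standalone, no CUDA or tensorrt_llm required; the helper names are illustrative):

import torch

def l2norm_cast(x, eps=1e-6):
    # Mirrors L2Norm.forward: upcast, normalize, cast back.
    x32 = x.to(torch.float32)
    return (x32 * torch.rsqrt((x32 * x32).sum(-1, keepdim=True) + eps)).to(x.dtype)

def l2norm_no_cast(x, eps=1e-6):
    # Mirrors L2NormNoCast.forward.
    return x * torch.rsqrt((x * x).sum(-1, keepdim=True) + eps)

x = torch.randn(2, 1024, dtype=torch.float32)
torch.testing.assert_close(l2norm_cast(x), l2norm_no_cast(x))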