From 5a97374f3c93e10cc8c3eb049829ae2ac4674b22 Mon Sep 17 00:00:00 2001
From: Karthik
Date: Fri, 30 Jan 2026 16:05:53 -0500
Subject: [PATCH] [#9525][feat] add L2 norm pattern matcher and fusion transform (#10767)

Signed-off-by: Karthik Vetrivel
---
 .../_torch/auto_deploy/config/default.yaml    |   5 +
 .../models/custom/modeling_nemotron_flash.py  |   4 +-
 .../auto_deploy/transform/library/l2_norm.py  | 200 ++++++++++++++++++
 .../library/test_fuse_l2norm.py               | 101 +++++++++
 4 files changed, 308 insertions(+), 2 deletions(-)
 create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/l2_norm.py
 create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_l2norm.py

diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index 5fc629fd8f..ff5a32eac8 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -50,6 +50,8 @@ transforms:
     expected_layout: bsnd
   match_rmsnorm_pattern:
     stage: pattern_matcher
+  match_l2norm_pattern:
+    stage: pattern_matcher
   ############################################################################################
   # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
   ############################################################################################
@@ -140,6 +142,9 @@ transforms:
     rmsnorm_backend: flashinfer
     gated_rmsnorm_backend: triton
     requires_shape_prop: true
+  fuse_l2norm:
+    stage: post_load_fusion
+    backend: fla
   fuse_add_rms_norm:
     stage: post_load_fusion
     enabled: true
diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py
index 588eb82c33..55c1090f53 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py
@@ -284,8 +284,8 @@ class DeltaNet(nn.Module):
         beta = beta * 2.0
 
         if self.qk_norm == "l2":
-            q = torch.ops.auto_deploy.fla_l2norm(q)
-            k = torch.ops.auto_deploy.fla_l2norm(k)
+            q = torch.ops.auto_deploy.torch_l2norm(q)
+            k = torch.ops.auto_deploy.torch_l2norm(k)
         elif self.qk_norm == "sum":
             q = sum_norm(q).to(q)
             k = sum_norm(k).to(k)
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/l2_norm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/l2_norm.py
new file mode 100644
index 0000000000..e83245bf70
--- /dev/null
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/l2_norm.py
@@ -0,0 +1,200 @@
+"""Graph transform to optimize L2Norm execution using FLA Triton kernels."""
+
+from typing import Literal, Tuple, Type
+
+import torch
+from pydantic import Field
+from torch.fx import GraphModule, Node
+
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
+
+# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
+from ...utils.node_utils import is_op
+from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
+from ..interface import (
+    BaseTransform,
+    SharedConfig,
+    TransformConfig,
+    TransformInfo,
+    TransformRegistry,
+)
+
+_BACKEND_OPS = {
+    "fla": torch.ops.auto_deploy.fla_l2norm.default,
+    "torch": torch.ops.auto_deploy.torch_l2norm.default,
+}
+
+
+def _l2_norm_pattern(data: torch.Tensor, eps: float) -> torch.Tensor:
+    """Implements the L2Norm pattern for pattern matching.
+
+    L2 normalization: x / sqrt(sum(x^2) + eps)
+
+    Args:
+        data: Input tensor to normalize.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        L2 normalized tensor.
+    """
+    input_dtype = data.dtype
+    data = data.to(torch.float32)
+    sum_sq = (data * data).sum(dim=-1, keepdim=True)
+    data = data * torch.rsqrt(sum_sq + eps)
+    return data.to(input_dtype)
+
+
+def _l2_norm_pattern_no_dtype_cast(data: torch.Tensor, eps: float) -> torch.Tensor:
+    """Implements the L2Norm pattern without dtype casting for pattern matching.
+
+    Some models may already operate in float32 and skip the dtype cast.
+
+    Args:
+        data: Input tensor to normalize.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        L2 normalized tensor.
+    """
+    sum_sq = (data * data).sum(dim=-1, keepdim=True)
+    return data * torch.rsqrt(sum_sq + eps)
+
+
+def _l2_norm_to_torch_l2norm(data: torch.Tensor, eps: float) -> torch.Tensor:
+    """Replace L2Norm pattern with torch_l2norm op (standardized representation).
+
+    Args:
+        data: Input tensor to normalize.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        L2 normalized tensor using torch_l2norm.
+    """
+    return torch.ops.auto_deploy.torch_l2norm(data, eps)
+
+
+@TransformRegistry.register("match_l2norm_pattern")
+class MatchL2NormPattern(BaseTransform):
+    """Matches L2Norm patterns in the graph and replaces them with torch_l2norm op.
+
+    This transform runs in the pattern_matcher stage and standardizes L2Norm patterns
+    to use torch_l2norm op, which can later be fused to a specific backend in the
+    post_load_fusion stage.
+
+    Args:
+        gm: Input graph module to transform.
+
+    Returns:
+        Transformed graph module with standardized torch_l2norm operations.
+    """
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        graph = gm.graph
+        patterns = ADPatternMatcherPass()
+
+        bs = 2
+        hidden_size = 512
+
+        def dummy_args(input_dtype: torch.dtype, eps: float = 1e-6):
+            return [
+                torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
+                eps,
+            ]
+
+        configs = [
+            torch.bfloat16,
+            torch.float16,
+            torch.float32,
+        ]
+
+        search_fns = [
+            _l2_norm_pattern,
+            _l2_norm_pattern_no_dtype_cast,
+        ]
+        for search_fn in search_fns:
+            for input_dtype in configs:
+                register_ad_pattern(
+                    search_fn=search_fn,
+                    replace_fn=_l2_norm_to_torch_l2norm,
+                    patterns=patterns,
+                    dummy_args=dummy_args(input_dtype),
+                    op_ignore_types={},
+                    scalar_workaround={"eps": 1e-6},
+                    skip_duplicates=True,
+                )
+
+        cnt = patterns.apply(graph)
+
+        info = TransformInfo(
+            skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
+        )
+
+        return gm, info
+
+
+class FuseL2NormConfig(TransformConfig):
+    """Configuration for the L2Norm fusion transform."""
+
+    backend: Literal["torch", "fla"] = Field(
+        default="fla",
+        description="Backend to use for L2Norm computation ('fla' or 'torch').",
+    )
+
+
+@TransformRegistry.register("fuse_l2norm")
+class FuseL2Norm(BaseTransform):
+    """Fuses torch_l2norm ops with the selected backend implementation.
+
+    This transform runs in the post_load_fusion stage and replaces torch_l2norm ops
+    with the specified backend implementation (fla or torch).
+
+    Args:
+        gm: Input graph module to transform.
+        backend: Backend to use for L2Norm computation ("fla" or "torch").
+
+    Returns:
+        Transformed graph module with backend-specific L2Norm operations.
+    """
+
+    config: FuseL2NormConfig
+
+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return FuseL2NormConfig
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        graph = gm.graph
+        target_op = _BACKEND_OPS[self.config.backend]
+        cnt = 0
+
+        for node in list(graph.nodes):
+            if is_op(node, torch.ops.auto_deploy.torch_l2norm):
+                with graph.inserting_after(node):
+                    new_node: Node = graph.call_function(
+                        target_op,
+                        args=node.args,
+                        kwargs=node.kwargs,
+                    )
+                new_node.meta = node.meta.copy()
+                node.replace_all_uses_with(new_node)
+                graph.erase_node(node)
+                cnt += 1
+
+        info = TransformInfo(
+            skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
+        )
+
+        return gm, info
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_l2norm.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_l2norm.py
new file mode 100644
index 0000000000..e52b8e4392
--- /dev/null
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_l2norm.py
@@ -0,0 +1,101 @@
+import pytest
+import torch
+from _graph_test_helpers import run_test_transformed_gm
+from torch.export import Dim
+
+from tensorrt_llm._torch.auto_deploy.custom_ops.l2norm import *  # noqa
+from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
+from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
+from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
+
+
+class L2Norm(torch.nn.Module):
+    """L2 normalization module that normalizes along the last dimension."""
+
+    def __init__(self, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x):
+        input_dtype = x.dtype
+        x = x.to(torch.float32)
+        sum_sq = (x * x).sum(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(sum_sq + self.eps)
+        return x.to(input_dtype)
+
+
+class L2NormNoCast(torch.nn.Module):
+    """L2 normalization module without dtype casting (for float32 inputs)."""
+
+    def __init__(self, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x):
+        sum_sq = (x * x).sum(dim=-1, keepdim=True)
+        return x * torch.rsqrt(sum_sq + self.eps)
+
+
+class TestModel(torch.nn.Module):
+    def __init__(self, eps: float = 1e-6, use_no_cast: bool = False):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
+        if use_no_cast:
+            self.l2_norm = L2NormNoCast(eps)
+        else:
+            self.l2_norm = L2Norm(eps)
+        self.linear2 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.l2_norm(x)
+        x = self.linear2(x)
+        return x
+
+
+def _run_test(model, op, variant):
+    def checker(gm):
+        return any(is_op(n, op) for n in gm.graph.nodes)
+
+    x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
+    dynamic_shapes = {0: Dim.DYNAMIC}
+    gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
+    gm_transformed = InferenceOptimizer(
+        None,
+        {
+            "match_l2norm_pattern": {
+                "stage": "pattern_matcher",
+            },
+            "fuse_l2norm": {
+                "stage": "post_load_fusion",
+                "backend": variant,
+            },
+        },
+    )(None, gm)
+
+    run_test_transformed_gm(
+        model,
+        x,
+        gm_transformed,
+        checker,
+        lambda num_p_og: num_p_og,
+        dynamic_shapes=dynamic_shapes,
+    )
+
+    new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
+    y_transformed = gm_transformed(new_input)
+    y_model = model(new_input)
+    torch.testing.assert_close(y_transformed, y_model, atol=1e-3, rtol=1e-3)
+
+
+@pytest.mark.parametrize("eps", [1e-2, 1e-6])
+@pytest.mark.parametrize(
+    "variant, op",
+    [
+        ("fla", torch.ops.auto_deploy.fla_l2norm.default),
+        ("torch", torch.ops.auto_deploy.torch_l2norm.default),
+    ],
+)
+def test_l2norm_fusion(eps, variant, op):
+    model = TestModel(eps)
+    _run_test(model, op, variant)
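
Reviewer note: the torch_l2norm and fla_l2norm custom ops that this patch standardizes on are defined in the
custom_ops.l2norm module imported by the test, not in this diff. The snippet below is only an illustrative
sketch of the numerical contract the pattern matcher and the unit test assume (the same math as
_l2_norm_pattern); the helper name l2norm_ref is hypothetical and does not appear in the codebase.

# Illustrative sketch only -- not part of the patch.
import torch


def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Upcast to fp32, scale by 1/sqrt(sum(x^2) + eps) along the last dim, cast back."""
    x_f32 = x.to(torch.float32)
    inv_norm = torch.rsqrt((x_f32 * x_f32).sum(dim=-1, keepdim=True) + eps)
    return (x_f32 * inv_norm).to(x.dtype)


if __name__ == "__main__":
    x = torch.randn(2, 8, 64, dtype=torch.float16)
    y = l2norm_ref(x)
    # Each vector along the last dim should come out with (approximately) unit L2 norm,
    # up to eps and fp16 rounding.
    print(y.float().norm(dim=-1))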