Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-04 18:21:52 +08:00

Signed-off-by: Karthik Vetrivel <kvetrivel@nvidia.com>

parent 4af47208d8
commit 5a97374f3c
@@ -50,6 +50,8 @@ transforms:
     expected_layout: bsnd
   match_rmsnorm_pattern:
     stage: pattern_matcher
+  match_l2norm_pattern:
+    stage: pattern_matcher
   ############################################################################################
   # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
   ############################################################################################
@@ -140,6 +142,9 @@ transforms:
     rmsnorm_backend: flashinfer
     gated_rmsnorm_backend: triton
     requires_shape_prop: true
+  fuse_l2norm:
+    stage: post_load_fusion
+    backend: fla
   fuse_add_rms_norm:
     stage: post_load_fusion
     enabled: true
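For orientation, the pattern that match_l2norm_pattern standardizes (and fuse_l2norm later lowers to a backend) is plain L2 normalization over the last dimension, x / sqrt(sum(x^2) + eps). The sketch below is an illustration, not part of this commit; note that eps sits under the square root, which differs slightly from torch.nn.functional.normalize, where eps clamps the norm.

import torch
import torch.nn.functional as F

def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Matched pattern: upcast, x * rsqrt(sum(x^2, dim=-1) + eps), cast back.
    xf = x.to(torch.float32)
    out = xf * torch.rsqrt((xf * xf).sum(dim=-1, keepdim=True) + eps)
    return out.to(x.dtype)

x = torch.randn(2, 512, dtype=torch.bfloat16)
y = l2norm_ref(x)
# F.normalize computes x / max(||x||, eps), so the two only agree up to eps-level and dtype-level differences.
torch.testing.assert_close(
    y.float(), F.normalize(x.float(), p=2.0, dim=-1, eps=1e-12), atol=1e-2, rtol=1e-2
)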
@@ -284,8 +284,8 @@ class DeltaNet(nn.Module):
             beta = beta * 2.0

         if self.qk_norm == "l2":
-            q = torch.ops.auto_deploy.fla_l2norm(q)
-            k = torch.ops.auto_deploy.fla_l2norm(k)
+            q = torch.ops.auto_deploy.torch_l2norm(q)
+            k = torch.ops.auto_deploy.torch_l2norm(k)
         elif self.qk_norm == "sum":
             q = sum_norm(q).to(q)
             k = sum_norm(k).to(k)
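DeltaNet now emits the backend-agnostic torch_l2norm op and leaves kernel selection to the fuse_l2norm transform, so the call-site swap is only safe if the two ops agree numerically. A small sanity sketch, not part of this commit, assuming the custom ops are registered by importing the custom_ops.l2norm module (as the unit test below does), a CUDA device is available, and bfloat16 inputs are supported:

import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.l2norm import *  # noqa: F401,F403  (assumed to register the ops)

q = torch.randn(2, 8, 64, device="cuda", dtype=torch.bfloat16)

# Standardized op (graph representation) vs. the FLA Triton backend it may be fused to.
y_torch = torch.ops.auto_deploy.torch_l2norm(q)
y_fla = torch.ops.auto_deploy.fla_l2norm(q)
torch.testing.assert_close(y_torch, y_fla, atol=1e-2, rtol=1e-2)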
tensorrt_llm/_torch/auto_deploy/transform/library/l2_norm.py (new file, 200 lines)
@@ -0,0 +1,200 @@
"""Graph transform to optimize L2Norm execution using FLA Triton kernels."""

from typing import Literal, Tuple, Type

import torch
from pydantic import Field
from torch.fx import GraphModule, Node

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface

# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.node_utils import is_op
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
    BaseTransform,
    SharedConfig,
    TransformConfig,
    TransformInfo,
    TransformRegistry,
)

_BACKEND_OPS = {
    "fla": torch.ops.auto_deploy.fla_l2norm.default,
    "torch": torch.ops.auto_deploy.torch_l2norm.default,
}

def _l2_norm_pattern(data: torch.Tensor, eps: float) -> torch.Tensor:
    """Implements the L2Norm pattern for pattern matching.

    L2 normalization: x / sqrt(sum(x^2) + eps)

    Args:
        data: Input tensor to normalize.
        eps: Small constant for numerical stability.

    Returns:
        L2 normalized tensor.
    """
    input_dtype = data.dtype
    data = data.to(torch.float32)
    sum_sq = (data * data).sum(dim=-1, keepdim=True)
    data = data * torch.rsqrt(sum_sq + eps)
    return data.to(input_dtype)


def _l2_norm_pattern_no_dtype_cast(data: torch.Tensor, eps: float) -> torch.Tensor:
    """Implements the L2Norm pattern without dtype casting for pattern matching.

    Some models may already operate in float32 and skip the dtype cast.

    Args:
        data: Input tensor to normalize.
        eps: Small constant for numerical stability.

    Returns:
        L2 normalized tensor.
    """
    sum_sq = (data * data).sum(dim=-1, keepdim=True)
    return data * torch.rsqrt(sum_sq + eps)


def _l2_norm_to_torch_l2norm(data: torch.Tensor, eps: float) -> torch.Tensor:
    """Replace L2Norm pattern with torch_l2norm op (standardized representation).

    Args:
        data: Input tensor to normalize.
        eps: Small constant for numerical stability.

    Returns:
        L2 normalized tensor using torch_l2norm.
    """
    return torch.ops.auto_deploy.torch_l2norm(data, eps)

@TransformRegistry.register("match_l2norm_pattern")
class MatchL2NormPattern(BaseTransform):
    """Matches L2Norm patterns in the graph and replaces them with the torch_l2norm op.

    This transform runs in the pattern_matcher stage and standardizes L2Norm patterns
    to use the torch_l2norm op, which can later be fused to a specific backend in the
    post_load_fusion stage.

    Args:
        gm: Input graph module to transform.

    Returns:
        Transformed graph module with standardized torch_l2norm operations.
    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        graph = gm.graph
        patterns = ADPatternMatcherPass()

        bs = 2
        hidden_size = 512

        def dummy_args(input_dtype: torch.dtype, eps: float = 1e-6):
            return [
                torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
                eps,
            ]

        configs = [
            torch.bfloat16,
            torch.float16,
            torch.float32,
        ]

        search_fns = [
            _l2_norm_pattern,
            _l2_norm_pattern_no_dtype_cast,
        ]
        for search_fn in search_fns:
            for input_dtype in configs:
                register_ad_pattern(
                    search_fn=search_fn,
                    replace_fn=_l2_norm_to_torch_l2norm,
                    patterns=patterns,
                    dummy_args=dummy_args(input_dtype),
                    op_ignore_types={},
                    scalar_workaround={"eps": 1e-6},
                    skip_duplicates=True,
                )

        cnt = patterns.apply(graph)

        info = TransformInfo(
            skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
        )

        return gm, info

class FuseL2NormConfig(TransformConfig):
    """Configuration for the L2Norm fusion transform."""

    backend: Literal["torch", "fla"] = Field(
        default="fla",
        description="Backend to use for L2Norm computation ('fla' or 'torch').",
    )

@TransformRegistry.register("fuse_l2norm")
class FuseL2Norm(BaseTransform):
    """Fuses torch_l2norm ops with the selected backend implementation.

    This transform runs in the post_load_fusion stage and replaces torch_l2norm ops
    with the specified backend implementation (fla or torch).

    Args:
        gm: Input graph module to transform.
        backend: Backend to use for L2Norm computation ("fla" or "torch").

    Returns:
        Transformed graph module with backend-specific L2Norm operations.
    """

    config: FuseL2NormConfig

    @classmethod
    def get_config_class(cls) -> Type[TransformConfig]:
        return FuseL2NormConfig

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        graph = gm.graph
        target_op = _BACKEND_OPS[self.config.backend]
        cnt = 0

        for node in list(graph.nodes):
            if is_op(node, torch.ops.auto_deploy.torch_l2norm):
                with graph.inserting_after(node):
                    new_node: Node = graph.call_function(
                        target_op,
                        args=node.args,
                        kwargs=node.kwargs,
                    )
                new_node.meta = node.meta.copy()
                node.replace_all_uses_with(new_node)
                graph.erase_node(node)
                cnt += 1

        info = TransformInfo(
            skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
        )

        return gm, info
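FuseL2Norm uses the standard torch.fx rewrite idiom: insert a new call_function node after the matched node, copy its meta, redirect all uses, then erase the original. Below is a self-contained sketch of that idiom using only public torch.fx APIs (a toy relu-to-gelu swap for illustration, not the auto_deploy utilities or this commit's ops):

import torch
from torch import fx


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(Toy())

# Same rewrite pattern as FuseL2Norm._apply: insert, copy meta, replace uses, erase.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(
                torch.nn.functional.gelu, args=node.args, kwargs=node.kwargs
            )
        new_node.meta = node.meta.copy()
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
print(gm(torch.randn(3)))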
@@ -0,0 +1,101 @@
import pytest
import torch
from _graph_test_helpers import run_test_transformed_gm
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.custom_ops.l2norm import *  # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

class L2Norm(torch.nn.Module):
    """L2 normalization module that normalizes along the last dimension."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        input_dtype = x.dtype
        x = x.to(torch.float32)
        sum_sq = (x * x).sum(dim=-1, keepdim=True)
        x = x * torch.rsqrt(sum_sq + self.eps)
        return x.to(input_dtype)


class L2NormNoCast(torch.nn.Module):
    """L2 normalization module without dtype casting (for float32 inputs)."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        sum_sq = (x * x).sum(dim=-1, keepdim=True)
        return x * torch.rsqrt(sum_sq + self.eps)

class TestModel(torch.nn.Module):
    def __init__(self, eps: float = 1e-6, use_no_cast: bool = False):
        super().__init__()
        self.linear1 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
        if use_no_cast:
            self.l2_norm = L2NormNoCast(eps)
        else:
            self.l2_norm = L2Norm(eps)
        self.linear2 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)

    def forward(self, x):
        x = self.linear1(x)
        x = self.l2_norm(x)
        x = self.linear2(x)
        return x

def _run_test(model, op, variant):
    def checker(gm):
        return any(is_op(n, op) for n in gm.graph.nodes)

    x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
    dynamic_shapes = {0: Dim.DYNAMIC}
    gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
    gm_transformed = InferenceOptimizer(
        None,
        {
            "match_l2norm_pattern": {
                "stage": "pattern_matcher",
            },
            "fuse_l2norm": {
                "stage": "post_load_fusion",
                "backend": variant,
            },
        },
    )(None, gm)

    run_test_transformed_gm(
        model,
        x,
        gm_transformed,
        checker,
        lambda num_p_og: num_p_og,
        dynamic_shapes=dynamic_shapes,
    )

    new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    y_transformed = gm_transformed(new_input)
    y_model = model(new_input)
    torch.testing.assert_close(y_transformed, y_model, atol=1e-3, rtol=1e-3)

@pytest.mark.parametrize("eps", [1e-2, 1e-6])
@pytest.mark.parametrize(
    "variant, op",
    [
        ("fla", torch.ops.auto_deploy.fla_l2norm.default),
        ("torch", torch.ops.auto_deploy.torch_l2norm.default),
    ],
)
def test_l2norm_fusion(eps, variant, op):
    model = TestModel(eps)
    _run_test(model, op, variant)