mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-12 22:14:03 +08:00)

commit bf16fbd86c
parent 7b7f1e2ba1
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@@ -48,6 +48,8 @@ transforms:
   match_rope_layout:
     stage: pattern_matcher
     expected_layout: bsnd
+  match_rmsnorm_pattern:
+    stage: pattern_matcher
   ############################################################################################
   # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
   ############################################################################################
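The new entry standardizes RMSNorm subgraphs during pattern matching; the fuse_rmsnorm transform then binds them to a backend in the post_load_fusion stage. A minimal InferenceOptimizer config exercising both stages might look like the following (a sketch assembled from the test changes at the end of this commit, not a complete pipeline config):

    optimizer_config = {
        "match_rmsnorm_pattern": {"stage": "pattern_matcher"},
        "fuse_rmsnorm": {
            "stage": "post_load_fusion",
            "gated_rmsnorm_backend": "triton",
        },
    }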
@@ -88,6 +88,65 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
     return torch.empty_like(input)
 
 
+@torch.library.custom_op("auto_deploy::torch_rmsnorm_gated", mutates_args=())
+def torch_rmsnorm_gated(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    gate: torch.Tensor | None,
+    eps: float,
+    group_size: int,
+    norm_before_gate: bool = False,
+) -> torch.Tensor:
+    """Custom operator for the pure-PyTorch gated RMSNorm implementation.
+
+    Group RMSNorm with optional SiLU gating, using pure PyTorch operations.
+
+    Args:
+        x: Input tensor of shape [..., H].
+        weight: Scaling weights of shape [H].
+        gate: Optional gate tensor with the same shape as x, or None.
+        eps: Small constant for numerical stability.
+        group_size: Size of groups for grouped normalization. H must be divisible by group_size.
+        norm_before_gate: If True, apply gating after normalization. If False, apply it before.
+
+    Returns:
+        Normalized and optionally gated tensor with the same shape as x.
+    """
+    dtype = x.dtype
+    weight = weight.float()
+    x = x.float()
+    z = gate.float() if gate is not None else gate
+
+    if z is not None and not norm_before_gate:
+        x = x * F.silu(z)
+
+    if group_size is None:
+        rstd = 1 / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + eps)
+        out = x * rstd * weight
+    else:
+        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
+        rstd = 1 / torch.sqrt(x_group.square().mean(dim=-1, keepdim=True) + eps)
+        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
+
+    if z is not None and norm_before_gate:
+        out *= F.silu(z)
+
+    return out.to(dtype)
+
+
+@torch_rmsnorm_gated.register_fake
+def _(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    gate: torch.Tensor | None,
+    eps: float,
+    group_size: int,
+    norm_before_gate: bool = False,
+) -> torch.Tensor:
+    """Fake implementation for the custom operator during tracing."""
+    return x.new_empty(x.shape, dtype=x.dtype)
+
+
 @torch.library.custom_op("auto_deploy::triton_rmsnorm_gated", mutates_args=())
 def triton_rmsnorm_gated(
     x: torch.Tensor,
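A quick smoke test for the new op (a sketch, not part of this commit; it assumes the auto_deploy custom-ops module has been imported so the op is registered, and mirrors the dummy shapes used by the pattern matcher later in this change):

    import torch

    x = torch.randn(2, 3, 4096)
    w = torch.randn(4096)
    g = torch.randn_like(x)

    # Grouped RMSNorm (group_size=512); with norm_before_gate=False the SiLU
    # gate is applied to the input before normalization.
    y = torch.ops.auto_deploy.torch_rmsnorm_gated(x, w, g, 1e-5, 512, False)
    assert y.shape == x.shape and y.dtype == x.dtype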
@@ -32,7 +32,10 @@ from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
 def _make_allreduce_residual_rmsnorm_pattern(
     add_order: str = "residual_first", strategy: str = "AUTO"
 ):
-    """Factory function to create pattern functions for allreduce+residual+rmsnorm fusion.
+    """Factory function to create pattern functions for allreduce+residual+torch_rmsnorm fusion.
+
+    This pattern matches the graph after match_rmsnorm_pattern has replaced
+    RMSNorm patterns with torch_rmsnorm ops.
 
     Args:
         add_order: Either "residual_first" (residual + x) or "x_first" (x + residual)
@@ -45,15 +48,14 @@ def _make_allreduce_residual_rmsnorm_pattern(
     def pattern_fn(
         x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253
     ):
-        """Pattern: trtllm_dist_all_reduce(x) -> add residual -> RMSNorm
+        """Pattern: trtllm_dist_all_reduce(x) -> add residual -> torch_rmsnorm
 
         Reference PyTorch composition:
             y = trtllm_dist_all_reduce(x)
             z = residual + y (or y + residual)
-            normed = RMSNorm(z, weight, eps)
+            normed = torch_rmsnorm(z, weight, eps)
         Returns (normed, z)
         """
-        input_dtype = x.dtype
         hidden_states = torch.ops.auto_deploy.trtllm_dist_all_reduce(x, strategy)
 
         # Handle addition order
@@ -62,11 +64,8 @@ def _make_allreduce_residual_rmsnorm_pattern(
         else:  # x_first
             add = hidden_states + residual
 
-        hidden_states = add.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + eps)
-
-        normed = weight * hidden_states.to(input_dtype)
+        # Use torch_rmsnorm op (already replaced by match_rmsnorm_pattern)
+        normed = torch.ops.auto_deploy.torch_rmsnorm(add, weight, eps)
 
         return normed, add
 
@@ -94,6 +93,9 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
     This transform only applies when TRT-LLM ops are used (MPI mode), as it provides
     optimized fused kernels. The torch backend (demollm mode) does not benefit from
     this fusion and uses unfused operations.
+
+    Note: This transform expects torch_rmsnorm ops in the graph, which are created
+    by the match_rmsnorm_pattern transform that runs earlier in the pipeline.
     """
 
     def _apply(
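For reference, the subgraph this transform now matches has a compact eager-mode equivalent (an illustrative sketch; the auto_deploy ops must be registered, and the TRT-LLM allreduce only runs in MPI mode):

    import torch

    def allreduce_residual_rmsnorm_ref(x, residual, weight, eps=0.1253, strategy="AUTO"):
        # all-reduce the partial result, add the residual, then normalize
        y = torch.ops.auto_deploy.trtllm_dist_all_reduce(x, strategy)
        z = residual + y  # or y + residual, depending on add_order
        normed = torch.ops.auto_deploy.torch_rmsnorm(z, weight, eps)
        return normed, z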
@@ -114,7 +116,6 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
             0.1253,  # eps
         ]
 
-        op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)}
         scalar_workaround = {"eps": 0.1253}
 
         # ============================================================================
@@ -139,7 +140,6 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
             replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
             patterns=patterns,
             dummy_args=dummy_args,
-            op_ignore_types=op_ignore_types,
             scalar_workaround=scalar_workaround,
         )
 
@@ -149,7 +149,6 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
             replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
             patterns=patterns,
             dummy_args=dummy_args,
-            op_ignore_types=op_ignore_types,
             scalar_workaround=scalar_workaround,
         )
 
@@ -1,17 +1,17 @@
 """Graph transform to optimize RMSNorm execution using FlashInfer."""
 
-from functools import partial
 from typing import Tuple, Type
 
 import torch
 from pydantic import Field
-from torch.fx import GraphModule
+from torch.fx import GraphModule, Node
 
 from ...custom_ops.rms_norm import gated_rms_norm_ref
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 
 # It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
+from ...utils.node_utils import is_op
 from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
 from ..interface import (
     BaseTransform,
@@ -66,6 +66,22 @@ def _rms_norm_pattern_float32_weights(
     return (weight.to(torch.float32) * data).to(input_dtype)
 
 
+def _rms_norm_to_torch_rmsnorm(
+    data: torch.Tensor, weight: torch.Tensor, eps: float
+) -> torch.Tensor:
+    """Replace RMSNorm pattern with torch_rmsnorm op (standardized representation).
+
+    Args:
+        data: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Normalized and scaled tensor using torch_rmsnorm.
+    """
+    return torch.ops.auto_deploy.torch_rmsnorm(data, weight, eps)
+
+
 def _rms_norm_replacement(
     data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
 ) -> torch.Tensor:
@@ -87,6 +103,109 @@ def _rms_norm_replacement(
     return _BACKEND_OPS[backend.lower()](data, weight, eps)
 
 
+@TransformRegistry.register("match_rmsnorm_pattern")
+class MatchRMSNormPattern(BaseTransform):
+    """Matches RMSNorm patterns in the graph and replaces them with torch_rmsnorm op.
+
+    This transform runs in the pattern_matcher stage and standardizes RMSNorm patterns
+    to use torch_rmsnorm op, which can later be fused to a specific backend in the
+    post_load_fusion stage.
+
+    Args:
+        gm: Input graph module to transform.
+
+    Returns:
+        Transformed graph module with standardized torch_rmsnorm operations.
+    """
+
+    config: TransformConfig
+
+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return TransformConfig
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        graph = gm.graph
+        patterns = ADPatternMatcherPass()
+
+        # Pattern matching for regular RMSNorm
+        bs = 2
+        hidden_size = 512
+
+        def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
+            return [
+                torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
+                torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
+                eps,
+            ]
+
+        # Define configurations for different data types
+        configs = [
+            (torch.bfloat16, torch.bfloat16),
+            (torch.float16, torch.float16),
+            (torch.float32, torch.float32),
+        ]
+
+        # Register patterns for each configuration - replace with torch_rmsnorm
+        search_fns = [
+            _rms_norm_pattern,
+            _rms_norm_pattern_float32_weights,
+        ]
+        for search_fn in search_fns:
+            for input_dtype, weight_dtype in configs:
+                register_ad_pattern(
+                    search_fn=search_fn,
+                    replace_fn=_rms_norm_to_torch_rmsnorm,
+                    patterns=patterns,
+                    dummy_args=dummy_args(input_dtype, weight_dtype),
+                    op_ignore_types={},
+                    scalar_workaround={"eps": 1e-6},
+                )
+
+        # Pattern matching for gated RMSNorm
+        B, S, H = 2, 3, 4096
+        group_size = 512
+        eps = 1e-5
+
+        def make_dummy_args_gated(group_size: int, eps: float) -> list:
+            x = torch.randn(B, S, H, dtype=torch.float32)
+            w = torch.randn(H, dtype=torch.float32)
+            g = torch.randn(B, S, H, dtype=torch.float32)
+            return [x, w, g, eps, group_size]
+
+        op_ignore_types = {
+            torch.ops.aten.reshape.default: (int, list, tuple),
+            torch.ops.aten.view.default: (int, list, tuple),
+            torch.ops.aten.mean.dim: (list, tuple),
+            torch.ops.aten.to.dtype: (torch.dtype,),
+        }
+
+        # Register pattern for gated RMSNorm - replace with torch_rmsnorm_gated
+        register_ad_pattern(
+            search_fn=_gated_rmsnorm_pattern_ref,
+            replace_fn=_gated_rmsnorm_to_torch_rmsnorm_gated,
+            patterns=patterns,
+            dummy_args=make_dummy_args_gated(group_size, eps),
+            op_ignore_types=op_ignore_types,
+            scalar_workaround={"eps": eps, "group_size": group_size},
+            skip_duplicates=True,
+        )
+
+        cnt = patterns.apply(graph)
+
+        info = TransformInfo(
+            skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
+        )
+
+        return gm, info
+
+
 class FuseRMSNormConfig(TransformConfig):
     """Configuration for the RMSNorm fusion transform."""
 
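Concretely, the standardization collapses the decomposed float32 RMSNorm arithmetic into a single op that later passes can dispatch on. A before/after sketch (x, weight, and eps are placeholders; the "before" form mirrors the decomposed pattern removed from the allreduce fusion above):

    # Before: decomposed pattern as it appears in exported graphs
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(-1, keepdim=True)
    normed = weight * (x32 * torch.rsqrt(variance + eps)).to(x.dtype)

    # After match_rmsnorm_pattern: one standardized op
    normed = torch.ops.auto_deploy.torch_rmsnorm(x, weight, eps)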
@@ -102,11 +221,10 @@ class FuseRMSNormConfig(TransformConfig):
 
 @TransformRegistry.register("fuse_rmsnorm")
 class FuseRMSNorm(BaseTransform):
-    """Matches and replaces RMSNorm patterns (regular and gated) in the graph with optimized implementations.
+    """Fuses torch_rmsnorm ops with the selected backend implementation.
 
-    This function sets up pattern matching to identify both regular and gated RMSNorm operations in the graph
-    and replaces them with optimized implementations. It uses dummy tensors to register
-    the pattern matching rules.
+    This transform runs in the post_load_fusion stage and replaces torch_rmsnorm ops
+    with the specified backend implementation (flashinfer, triton, or torch).
 
     Args:
         gm: Input graph module to transform.
@@ -114,7 +232,7 @@ class FuseRMSNorm(BaseTransform):
         gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton").
 
     Returns:
-        Transformed graph module with optimized RMSNorm operations.
+        Transformed graph module with backend-specific RMSNorm operations.
     """
 
     config: FuseRMSNormConfig
@@ -144,72 +262,39 @@ class FuseRMSNorm(BaseTransform):
         )
 
         graph = gm.graph
-        patterns = ADPatternMatcherPass()
+        backend = self.config.rmsnorm_backend.lower()
+        target_op = _BACKEND_OPS[backend]
+        cnt = 0
 
-        # Pattern matching for regular RMSNorm
-        bs = 2
-        hidden_size = 512
-
-        def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
-            return [
-                torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
-                torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
-                eps,
-            ]
-
-        # Define configurations for different data types
-        configs = [
-            (torch.bfloat16, torch.bfloat16),
-            (torch.float16, torch.float16),
-            (torch.float32, torch.float32),
-        ]
-
-        # Register patterns for each configuration
-        search_fns = [
-            _rms_norm_pattern,
-            _rms_norm_pattern_float32_weights,
-        ]
-        for search_fn in search_fns:
-            for input_dtype, weight_dtype in configs:
-                register_ad_pattern(
-                    search_fn=search_fn,
-                    replace_fn=partial(_rms_norm_replacement, backend=self.config.rmsnorm_backend),
-                    patterns=patterns,
-                    dummy_args=dummy_args(input_dtype, weight_dtype),
-                    op_ignore_types={},
-                    scalar_workaround={"eps": 1e-6},
-                )
-
-        # Pattern matching for gated RMSNorm
-        B, S, H = 2, 3, 4096
-        group_size = 512
-        eps = 1e-5
-
-        def make_dummy_args_gated(group_size: int, eps: float) -> list:
-            x = torch.randn(B, S, H, dtype=torch.float32)
-            w = torch.randn(H, dtype=torch.float32)
-            g = torch.randn(B, S, H, dtype=torch.float32)
-            return [x, w, g, eps, group_size]
-
-        op_ignore_types = {
-            torch.ops.aten.reshape.default: (int, list, tuple),
-            torch.ops.aten.view.default: (int, list, tuple),
-            torch.ops.aten.mean.dim: (list, tuple),
-            torch.ops.aten.to.dtype: (torch.dtype,),
-        }
-
-        # Register pattern for gated RMSNorm
-        register_ad_pattern(
-            search_fn=_gated_rmsnorm_pattern_ref,
-            replace_fn=_gated_rmsnorm_replacement,
-            patterns=patterns,
-            dummy_args=make_dummy_args_gated(group_size, eps),
-            op_ignore_types=op_ignore_types,
-            scalar_workaround={"eps": eps, "group_size": group_size},
-            skip_duplicates=True,
-        )
-
-        cnt = patterns.apply(graph)
+        # Replace torch_rmsnorm ops with the selected backend
+        for node in list(graph.nodes):
+            if is_op(node, torch.ops.auto_deploy.torch_rmsnorm):
+                # Replace with the selected backend op
+                with graph.inserting_after(node):
+                    new_node: Node = graph.call_function(
+                        target_op,
+                        args=node.args,
+                        kwargs=node.kwargs,
+                    )
+                node.replace_all_uses_with(new_node)
+                graph.erase_node(node)
+                cnt += 1
+
+        # Replace torch_rmsnorm_gated ops with triton_rmsnorm_gated
+        for node in list(graph.nodes):
+            if is_op(node, torch.ops.auto_deploy.torch_rmsnorm_gated):
+                # Replace with triton_rmsnorm_gated op
+                with graph.inserting_after(node):
+                    new_node: Node = graph.call_function(
+                        torch.ops.auto_deploy.triton_rmsnorm_gated,
+                        args=node.args,
+                        kwargs=node.kwargs,
+                    )
+                node.replace_all_uses_with(new_node)
+                graph.erase_node(node)
+                cnt += 1
+
+        gm.recompile()
 
         info = TransformInfo(
             skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
@@ -239,13 +324,25 @@ def _gated_rmsnorm_pattern_ref(
     return y
 
 
-def _gated_rmsnorm_replacement(
+def _gated_rmsnorm_to_torch_rmsnorm_gated(
     x: torch.Tensor,
     weight: torch.Tensor,
     gate: torch.Tensor,
     eps: float,
     group_size: int,
 ) -> torch.Tensor:
-    return torch.ops.auto_deploy.triton_rmsnorm_gated(
+    """Replace gated RMSNorm pattern with torch_rmsnorm_gated op (standardized representation).
+
+    Args:
+        x: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        gate: Gate tensor for gated normalization.
+        eps: Small constant for numerical stability.
+        group_size: Size of groups for grouped normalization.
+
+    Returns:
+        Normalized and gated tensor using torch_rmsnorm_gated.
+    """
+    return torch.ops.auto_deploy.torch_rmsnorm_gated(
         x, weight, gate, float(eps), int(group_size), False
     )
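The rewritten _apply walks the graph directly instead of re-running the pattern matcher, since the standardized ops can be found by op identity alone. The node-swap idiom is plain torch.fx; a self-contained toy version (no auto_deploy dependency, swapping relu for sigmoid purely for illustration):

    import torch
    from torch.fx import symbolic_trace

    def f(x):
        return torch.relu(x)

    gm = symbolic_trace(f)
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch.relu:
            # Insert the replacement, rewire all users, then drop the old node.
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(
                    torch.sigmoid, args=node.args, kwargs=node.kwargs
                )
            node.replace_all_uses_with(new_node)
            gm.graph.erase_node(node)
    gm.recompile()

    assert torch.equal(gm(torch.zeros(3)), torch.sigmoid(torch.zeros(3)))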
@@ -87,6 +87,9 @@ def _test_allreduce_fusion(port: int, ModuleCls, strategy: str):
     gm_transformed = InferenceOptimizer(
         None,
         {
+            "match_rmsnorm_pattern": {
+                "stage": "pattern_matcher",
+            },
             "detect_sharding": {
                 "stage": "post_export",
                 "allreduce_strategy": strategy,
@@ -64,6 +64,9 @@ def _run_test(model, op, variant):
     gm_transformed = InferenceOptimizer(
         None,
         {
+            "match_rmsnorm_pattern": {
+                "stage": "pattern_matcher",
+            },
             "fuse_rmsnorm": {
                 "stage": "post_load_fusion",
                 "gated_rmsnorm_backend": "triton",