commit bf16fbd86c (parent 7b7f1e2ba1)
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@@ -48,6 +48,8 @@ transforms:
   match_rope_layout:
     stage: pattern_matcher
     expected_layout: bsnd
+  match_rmsnorm_pattern:
+    stage: pattern_matcher
   ############################################################################################
   # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
   ############################################################################################
@@ -88,6 +88,65 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
     return torch.empty_like(input)
 
 
+@torch.library.custom_op("auto_deploy::torch_rmsnorm_gated", mutates_args=())
+def torch_rmsnorm_gated(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    gate: torch.Tensor | None,
+    eps: float,
+    group_size: int,
+    norm_before_gate: bool = False,
+) -> torch.Tensor:
+    """Custom operator for Torch gated RMSNorm implementation.
+
+    Group RMSNorm with optional SiLU gating, using pure PyTorch operations.
+
+    Args:
+        x: Input tensor of shape [..., H].
+        weight: Scaling weights of shape [H].
+        gate: Optional gate tensor with same shape as x, or None.
+        eps: Small constant for numerical stability.
+        group_size: Size of groups for grouped normalization. H must be divisible by group_size.
+        norm_before_gate: If True, apply gating after normalization. If False, apply before.
+
+    Returns:
+        Normalized and optionally gated tensor of shape like x.
+    """
+    dtype = x.dtype
+    weight = weight.float()
+    x = x.float()
+    z = gate.float() if gate is not None else gate
+
+    if z is not None and not norm_before_gate:
+        x = x * F.silu(z)
+
+    if group_size is None:
+        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
+        out = x * rstd * weight
+    else:
+        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
+        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
+        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
+
+    if z is not None and norm_before_gate:
+        out *= F.silu(z)
+
+    return out.to(dtype)
+
+
+@torch_rmsnorm_gated.register_fake
+def _(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    gate: torch.Tensor | None,
+    eps: float,
+    group_size: int,
+    norm_before_gate: bool = False,
+) -> torch.Tensor:
+    """Fake implementation for the custom operator during tracing."""
+    return x.new_empty(x.shape, dtype=x.dtype)
+
+
 @torch.library.custom_op("auto_deploy::triton_rmsnorm_gated", mutates_args=())
 def triton_rmsnorm_gated(
     x: torch.Tensor,
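For intuition, here is a minimal numerical sketch of the math torch_rmsnorm_gated implements with norm_before_gate=False (SiLU gating applied before the per-group normalization). Shapes and values are illustrative, not taken from the commit:

    import torch
    import torch.nn.functional as F

    B, H, group_size, eps = 2, 8, 4, 1e-6
    x = torch.randn(B, H)
    gate = torch.randn(B, H)
    weight = torch.randn(H)

    h = x * F.silu(gate)                              # gate first (norm_before_gate=False)
    h_group = h.view(B, H // group_size, group_size)  # split hidden dim into groups
    rstd = torch.rsqrt(h_group.square().mean(-1, keepdim=True) + eps)
    out = (h_group * rstd).reshape(B, H) * weight     # normalize per group, then scale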
@@ -32,7 +32,10 @@ from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
 def _make_allreduce_residual_rmsnorm_pattern(
     add_order: str = "residual_first", strategy: str = "AUTO"
 ):
-    """Factory function to create pattern functions for allreduce+residual+rmsnorm fusion.
+    """Factory function to create pattern functions for allreduce+residual+torch_rmsnorm fusion.
+
+    This pattern matches the graph after match_rmsnorm_pattern has replaced
+    RMSNorm patterns with torch_rmsnorm ops.
 
     Args:
         add_order: Either "residual_first" (residual + x) or "x_first" (x + residual)
@@ -45,15 +48,14 @@ def _make_allreduce_residual_rmsnorm_pattern(
     def pattern_fn(
         x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253
     ):
-        """Pattern: trtllm_dist_all_reduce(x) -> add residual -> RMSNorm
+        """Pattern: trtllm_dist_all_reduce(x) -> add residual -> torch_rmsnorm
 
         Reference PyTorch composition:
             y = trtllm_dist_all_reduce(x)
             z = residual + y (or y + residual)
-            normed = RMSNorm(z, weight, eps)
+            normed = torch_rmsnorm(z, weight, eps)
         Returns (normed, z)
         """
-        input_dtype = x.dtype
         hidden_states = torch.ops.auto_deploy.trtllm_dist_all_reduce(x, strategy)
 
         # Handle addition order
@@ -62,11 +64,8 @@ def _make_allreduce_residual_rmsnorm_pattern(
         else:  # x_first
             add = hidden_states + residual
 
-        hidden_states = add.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + eps)
-
-        normed = weight * hidden_states.to(input_dtype)
+        # Use torch_rmsnorm op (already replaced by match_rmsnorm_pattern)
+        normed = torch.ops.auto_deploy.torch_rmsnorm(add, weight, eps)
 
         return normed, add
 
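For reference, the inline float32 composition removed here and the standardized op it is replaced with compute the same normalization; a small plain-PyTorch sketch with illustrative shapes:

    import torch

    add = torch.randn(2, 512)
    weight = torch.randn(512)
    eps = 1e-6

    # Inline form removed above
    h = add.to(torch.float32)
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    ref = weight * h.to(add.dtype)

    # Standardized form matched instead (requires the auto_deploy custom ops):
    # out = torch.ops.auto_deploy.torch_rmsnorm(add, weight, eps)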
@@ -94,6 +93,9 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
     This transform only applies when TRT-LLM ops are used (MPI mode), as it provides
     optimized fused kernels. The torch backend (demollm mode) does not benefit from
     this fusion and uses unfused operations.
+
+    Note: This transform expects torch_rmsnorm ops in the graph, which are created
+    by the match_rmsnorm_pattern transform that runs earlier in the pipeline.
     """
 
     def _apply(
@@ -114,7 +116,6 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
             0.1253,  # eps
         ]
 
-        op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)}
         scalar_workaround = {"eps": 0.1253}
 
         # ============================================================================
@@ -139,7 +140,6 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
             replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
             patterns=patterns,
             dummy_args=dummy_args,
-            op_ignore_types=op_ignore_types,
             scalar_workaround=scalar_workaround,
         )
 
@@ -149,7 +149,6 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
             replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
             patterns=patterns,
             dummy_args=dummy_args,
-            op_ignore_types=op_ignore_types,
             scalar_workaround=scalar_workaround,
         )
 
@@ -1,17 +1,17 @@
 """Graph transform to optimize RMSNorm execution using FlashInfer."""
 
 from functools import partial
 from typing import Tuple, Type
 
 import torch
 from pydantic import Field
-from torch.fx import GraphModule
+from torch.fx import GraphModule, Node
 
-from ...custom_ops.rms_norm import gated_rms_norm_ref
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 
 # It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
+from ...utils.node_utils import is_op
 from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
 from ..interface import (
     BaseTransform,
@@ -66,6 +66,22 @@ def _rms_norm_pattern_float32_weights(
     return (weight.to(torch.float32) * data).to(input_dtype)
 
 
+def _rms_norm_to_torch_rmsnorm(
+    data: torch.Tensor, weight: torch.Tensor, eps: float
+) -> torch.Tensor:
+    """Replace RMSNorm pattern with torch_rmsnorm op (standardized representation).
+
+    Args:
+        data: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Normalized and scaled tensor using torch_rmsnorm.
+    """
+    return torch.ops.auto_deploy.torch_rmsnorm(data, weight, eps)
+
+
 def _rms_norm_replacement(
     data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
 ) -> torch.Tensor:
@@ -87,6 +103,109 @@ def _rms_norm_replacement(
     return _BACKEND_OPS[backend.lower()](data, weight, eps)
 
 
+@TransformRegistry.register("match_rmsnorm_pattern")
+class MatchRMSNormPattern(BaseTransform):
+    """Matches RMSNorm patterns in the graph and replaces them with torch_rmsnorm op.
+
+    This transform runs in the pattern_matcher stage and standardizes RMSNorm patterns
+    to use torch_rmsnorm op, which can later be fused to a specific backend in the
+    post_load_fusion stage.
+
+    Args:
+        gm: Input graph module to transform.
+
+    Returns:
+        Transformed graph module with standardized torch_rmsnorm operations.
+    """
+
+    config: TransformConfig
+
+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return TransformConfig
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        graph = gm.graph
+        patterns = ADPatternMatcherPass()
+
+        # Pattern matching for regular RMSNorm
+        bs = 2
+        hidden_size = 512
+
+        def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
+            return [
+                torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
+                torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
+                eps,
+            ]
+
+        # Define configurations for different data types
+        configs = [
+            (torch.bfloat16, torch.bfloat16),
+            (torch.float16, torch.float16),
+            (torch.float32, torch.float32),
+        ]
+
+        # Register patterns for each configuration - replace with torch_rmsnorm
+        search_fns = [
+            _rms_norm_pattern,
+            _rms_norm_pattern_float32_weights,
+        ]
+        for search_fn in search_fns:
+            for input_dtype, weight_dtype in configs:
+                register_ad_pattern(
+                    search_fn=search_fn,
+                    replace_fn=_rms_norm_to_torch_rmsnorm,
+                    patterns=patterns,
+                    dummy_args=dummy_args(input_dtype, weight_dtype),
+                    op_ignore_types={},
+                    scalar_workaround={"eps": 1e-6},
+                )
+
+        # Pattern matching for gated RMSNorm
+        B, S, H = 2, 3, 4096
+        group_size = 512
+        eps = 1e-5
+
+        def make_dummy_args_gated(group_size: int, eps: float) -> list:
+            x = torch.randn(B, S, H, dtype=torch.float32)
+            w = torch.randn(H, dtype=torch.float32)
+            g = torch.randn(B, S, H, dtype=torch.float32)
+            return [x, w, g, eps, group_size]
+
+        op_ignore_types = {
+            torch.ops.aten.reshape.default: (int, list, tuple),
+            torch.ops.aten.view.default: (int, list, tuple),
+            torch.ops.aten.mean.dim: (list, tuple),
+            torch.ops.aten.to.dtype: (torch.dtype,),
+        }
+
+        # Register pattern for gated RMSNorm - replace with torch_rmsnorm_gated
+        register_ad_pattern(
+            search_fn=_gated_rmsnorm_pattern_ref,
+            replace_fn=_gated_rmsnorm_to_torch_rmsnorm_gated,
+            patterns=patterns,
+            dummy_args=make_dummy_args_gated(group_size, eps),
+            op_ignore_types=op_ignore_types,
+            scalar_workaround={"eps": eps, "group_size": group_size},
+            skip_duplicates=True,
+        )
+
+        cnt = patterns.apply(graph)
+
+        info = TransformInfo(
+            skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
+        )
+
+        return gm, info
+
+
 class FuseRMSNormConfig(TransformConfig):
     """Configuration for the RMSNorm fusion transform."""
 
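Together with fuse_rmsnorm below, this gives a two-stage flow: standardize to torch_rmsnorm in the pattern_matcher stage, then lower to a backend in post_load_fusion. A hypothetical config sketch mirroring the test setups at the end of this diff (the exact key combination is an assumption):

    optimizer_config = {
        "match_rmsnorm_pattern": {"stage": "pattern_matcher"},
        "fuse_rmsnorm": {
            "stage": "post_load_fusion",
            "rmsnorm_backend": "flashinfer",    # or "triton" / "torch"
            "gated_rmsnorm_backend": "triton",  # currently the only gated backend
        },
    }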
@@ -102,11 +221,10 @@ class FuseRMSNormConfig(TransformConfig):
 
 @TransformRegistry.register("fuse_rmsnorm")
 class FuseRMSNorm(BaseTransform):
-    """Matches and replaces RMSNorm patterns (regular and gated) in the graph with optimized implementations.
+    """Fuses torch_rmsnorm ops with the selected backend implementation.
 
-    This function sets up pattern matching to identify both regular and gated RMSNorm operations in the graph
-    and replaces them with optimized implementations. It uses dummy tensors to register
-    the pattern matching rules.
+    This transform runs in the post_load_fusion stage and replaces torch_rmsnorm ops
+    with the specified backend implementation (flashinfer, triton, or torch).
 
     Args:
         gm: Input graph module to transform.
@@ -114,7 +232,7 @@ class FuseRMSNorm(BaseTransform):
         gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton").
 
     Returns:
-        Transformed graph module with optimized RMSNorm operations.
+        Transformed graph module with backend-specific RMSNorm operations.
     """
 
     config: FuseRMSNormConfig
@@ -144,72 +262,39 @@ class FuseRMSNorm(BaseTransform):
             )
 
         graph = gm.graph
-        patterns = ADPatternMatcherPass()
+        backend = self.config.rmsnorm_backend.lower()
+        target_op = _BACKEND_OPS[backend]
+        cnt = 0
 
-        # Pattern matching for regular RMSNorm
-        bs = 2
-        hidden_size = 512
+        # Replace torch_rmsnorm ops with the selected backend
+        for node in list(graph.nodes):
+            if is_op(node, torch.ops.auto_deploy.torch_rmsnorm):
+                # Replace with the selected backend op
+                with graph.inserting_after(node):
+                    new_node: Node = graph.call_function(
+                        target_op,
+                        args=node.args,
+                        kwargs=node.kwargs,
+                    )
+                node.replace_all_uses_with(new_node)
+                graph.erase_node(node)
+                cnt += 1
 
-        def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
-            return [
-                torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
-                torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
-                eps,
-            ]
+        # Replace torch_rmsnorm_gated ops with triton_rmsnorm_gated
+        for node in list(graph.nodes):
+            if is_op(node, torch.ops.auto_deploy.torch_rmsnorm_gated):
+                # Replace with triton_rmsnorm_gated op
+                with graph.inserting_after(node):
+                    new_node: Node = graph.call_function(
+                        torch.ops.auto_deploy.triton_rmsnorm_gated,
+                        args=node.args,
+                        kwargs=node.kwargs,
+                    )
+                node.replace_all_uses_with(new_node)
+                graph.erase_node(node)
+                cnt += 1
 
-        # Define configurations for different data types
-        configs = [
-            (torch.bfloat16, torch.bfloat16),
-            (torch.float16, torch.float16),
-            (torch.float32, torch.float32),
-        ]
-
-        # Register patterns for each configuration
-        search_fns = [
-            _rms_norm_pattern,
-            _rms_norm_pattern_float32_weights,
-        ]
-        for search_fn in search_fns:
-            for input_dtype, weight_dtype in configs:
-                register_ad_pattern(
-                    search_fn=search_fn,
-                    replace_fn=partial(_rms_norm_replacement, backend=self.config.rmsnorm_backend),
-                    patterns=patterns,
-                    dummy_args=dummy_args(input_dtype, weight_dtype),
-                    op_ignore_types={},
-                    scalar_workaround={"eps": 1e-6},
-                )
-
-        # Pattern matching for gated RMSNorm
-        B, S, H = 2, 3, 4096
-        group_size = 512
-        eps = 1e-5
-
-        def make_dummy_args_gated(group_size: int, eps: float) -> list:
-            x = torch.randn(B, S, H, dtype=torch.float32)
-            w = torch.randn(H, dtype=torch.float32)
-            g = torch.randn(B, S, H, dtype=torch.float32)
-            return [x, w, g, eps, group_size]
-
-        op_ignore_types = {
-            torch.ops.aten.reshape.default: (int, list, tuple),
-            torch.ops.aten.view.default: (int, list, tuple),
-            torch.ops.aten.mean.dim: (list, tuple),
-            torch.ops.aten.to.dtype: (torch.dtype,),
-        }
-
-        # Register pattern for gated RMSNorm
-        register_ad_pattern(
-            search_fn=_gated_rmsnorm_pattern_ref,
-            replace_fn=_gated_rmsnorm_replacement,
-            patterns=patterns,
-            dummy_args=make_dummy_args_gated(group_size, eps),
-            op_ignore_types=op_ignore_types,
-            scalar_workaround={"eps": eps, "group_size": group_size},
-            skip_duplicates=True,
-        )
-
-        cnt = patterns.apply(graph)
         gm.recompile()
 
         info = TransformInfo(
             skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
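Because the fusion is now a 1:1 op swap rather than a pattern match, a plain walk over graph.nodes suffices. A simplified, self-contained sketch of the same FX idiom (it retargets the node in place instead of inserting and erasing, which is equivalent when the old and new ops share a signature):

    import torch
    from torch import fx

    def swap_call_target(gm: fx.GraphModule, old_target, new_target) -> int:
        """Retarget every call_function node from old_target to new_target."""
        count = 0
        for node in list(gm.graph.nodes):
            if node.op == "call_function" and node.target is old_target:
                node.target = new_target  # args/kwargs are reused unchanged
                count += 1
        gm.recompile()
        return count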
@@ -239,13 +324,25 @@ def _gated_rmsnorm_pattern_ref(
     return y
 
 
-def _gated_rmsnorm_replacement(
+def _gated_rmsnorm_to_torch_rmsnorm_gated(
     x: torch.Tensor,
     weight: torch.Tensor,
     gate: torch.Tensor,
     eps: float,
     group_size: int,
 ) -> torch.Tensor:
-    return torch.ops.auto_deploy.triton_rmsnorm_gated(
+    """Replace gated RMSNorm pattern with torch_rmsnorm_gated op (standardized representation).
+
+    Args:
+        x: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        gate: Gate tensor for gated normalization.
+        eps: Small constant for numerical stability.
+        group_size: Size of groups for grouped normalization.
+
+    Returns:
+        Normalized and gated tensor using torch_rmsnorm_gated.
+    """
+    return torch.ops.auto_deploy.torch_rmsnorm_gated(
         x, weight, gate, float(eps), int(group_size), False
     )
@@ -87,6 +87,9 @@ def _test_allreduce_fusion(port: int, ModuleCls, strategy: str):
     gm_transformed = InferenceOptimizer(
         None,
         {
+            "match_rmsnorm_pattern": {
+                "stage": "pattern_matcher",
+            },
             "detect_sharding": {
                 "stage": "post_export",
                 "allreduce_strategy": strategy,
@@ -64,6 +64,9 @@ def _run_test(model, op, variant):
     gm_transformed = InferenceOptimizer(
         None,
         {
+            "match_rmsnorm_pattern": {
+                "stage": "pattern_matcher",
+            },
            "fuse_rmsnorm": {
                 "stage": "post_load_fusion",
                 "gated_rmsnorm_backend": "triton",