[None][feat] Enable rms norm fusion for Nemotron MOE (#8563)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
Co-authored-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Suyog Gupta 2025-10-22 21:09:42 -07:00 committed by GitHub
parent a7c2c8c212
commit 2956978da3
5 changed files with 84 additions and 25 deletions

View File

@@ -124,7 +124,7 @@ transforms:
     # TODO (lucaslie): add backend selection as part of configurable inference optimizers
     # check if we can fuse rmsnorm
     stage: post_load_fusion
-    backend: flashinfer
+    backend: triton
     requires_shape_prop: true
   fuse_gated_rmsnorm:
     stage: post_load_fusion
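
The backend field selects which fused RMSNorm custom op replaces a matched pattern once the transform fires. The string-to-op pairing below is taken from the variant/op pairs exercised by the unit tests at the end of this diff; the dict name is made up, and the real dispatch lives inside _rms_norm_replacement, which this hunk does not show.

import torch

# Illustrative mapping only (hypothetical name). The ops are custom ops
# registered by the AutoDeploy package and are referenced exactly as in the tests below.
RMSNORM_BACKEND_OPS = {
    "flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
    "triton": torch.ops.auto_deploy.triton_rms_norm,  # new default set by this change
    "torch": torch.ops.auto_deploy.torch_rmsnorm,
}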

View File

@@ -250,6 +250,12 @@ def _invoke_kernel(
     EM = sorted_token_ids.numel()
     if EM == 0:
         return
+    if A.size(0) < config["BLOCK_SIZE_M"]:
+        # optimize for small batch_size.
+        # We assume that top_ids of each token is unique,
+        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
+        # and we can skip some invalid blocks.
+        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
 
     def _grid(META):
         return (
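
A rough sense of what the new clamp saves: with a small batch, only batch_size * top_k routing slots can be occupied, so most of the per-expert padded blocks in sorted_token_ids are empty and the launch grid along M can shrink accordingly. The numbers below are made up for illustration; the variable names simply mirror the kernel arguments.

import math

# Hypothetical small-batch configuration (illustrative values only).
batch_size = 2        # A.size(0)
top_k = 4             # routed experts per token
num_experts = 128
block_size_m = 64     # config["BLOCK_SIZE_M"]

# Worst case: sorted_token_ids is padded so every expert owns at least one full block.
em_padded = num_experts * block_size_m                          # 8192 entries
# The clamp from the diff: at most batch_size * top_k experts can have any work.
em_clamped = min(em_padded, batch_size * top_k * block_size_m)  # 512 entries

blocks_before = math.ceil(em_padded / block_size_m)  # 128 blocks along M
blocks_after = math.ceil(em_clamped / block_size_m)  # 8 blocks along M
print(blocks_before, blocks_after)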

View File

@@ -14,7 +14,9 @@ def rms_norm_kernel(
     N_COLS: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    """Rms norm kernel."""
+    """Rms norm kernel.
+    Forces weights to be in float32 for the kernel.
+    """
     prog_id = tl.program_id(0)
     offsets = tl.arange(0, BLOCK_N)
@@ -26,7 +28,7 @@ def rms_norm_kernel(
     var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
     out = xf / tl.sqrt(var + eps)
-    out = (w * out).to(x.dtype)
+    out = (w.to(tl.float32) * out).to(x.dtype)
     out_ptr = output + prog_id * input_row_stride
     tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
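
For readers less used to Triton, a plain PyTorch sketch of what one row of the updated kernel computes: activations and weights are promoted to float32, the row is RMS-normalized, and the scaled result is cast back to the input dtype. This is a reference re-implementation for illustration, not code from the diff.

import torch

def rms_norm_reference(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
    # Mirror the kernel math: variance is the mean of squares over the last dim.
    xf = x.to(torch.float32)
    var = xf.pow(2).mean(-1, keepdim=True)
    out = xf * torch.rsqrt(var + eps)
    # The change in this commit: the weight is also promoted to float32 before scaling.
    return (w.to(torch.float32) * out).to(x.dtype)

x = torch.randn(2, 1024, dtype=torch.float16)
w = torch.randn(1024, dtype=torch.float16)
print(rms_norm_reference(x, w, 1e-6).dtype)  # torch.float16, same as the input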

View File

@@ -46,6 +46,26 @@ def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
     return weight * data.to(input_dtype)
 
 
+def _rms_norm_pattern_float32_weights(
+    data: torch.Tensor, weight: torch.Tensor, eps: float
+) -> torch.Tensor:
+    """Implements the RMSNorm pattern for pattern matching.
+
+    Args:
+        data: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Normalized and scaled tensor.
+    """
+    input_dtype = data.dtype
+    data = data.to(torch.float32)
+    variance = data.pow(2).mean(-1, keepdim=True)
+    data = data * torch.rsqrt(variance + eps)
+    return (weight.to(torch.float32) * data).to(input_dtype)
+
+
 def _rms_norm_replacement(
     data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
 ) -> torch.Tensor:
@@ -132,15 +152,20 @@ class FuseRMSNorm(BaseTransform):
         ]
 
         # Register patterns for each configuration
-        for input_dtype, weight_dtype in configs:
-            register_ad_pattern(
-                search_fn=_rms_norm_pattern,
-                replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
-                patterns=patterns,
-                dummy_args=dummy_args(input_dtype, weight_dtype),
-                op_ignore_types={},
-                scalar_workaround={"eps": 1e-6},
-            )
+        search_fns = [
+            _rms_norm_pattern,
+            _rms_norm_pattern_float32_weights,
+        ]
+        for search_fn in search_fns:
+            for input_dtype, weight_dtype in configs:
+                register_ad_pattern(
+                    search_fn=search_fn,
+                    replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
+                    patterns=patterns,
+                    dummy_args=dummy_args(input_dtype, weight_dtype),
+                    op_ignore_types={},
+                    scalar_workaround={"eps": 1e-6},
+                )
 
         cnt = patterns.apply(graph)
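
Why a second search_fn is needed at all: the pattern matcher compares graph structure, and the Nemotron-H style norm casts the weight to float32 before the multiply, which traces to a different set of nodes than weight * data.to(input_dtype), even though the two are numerically almost identical. A standalone sketch of that near-equivalence, reusing the two formulas above but none of the transform machinery:

import torch

def pattern_default(data, weight, eps=1e-6):
    x = data.to(torch.float32)
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return weight * x.to(data.dtype)

def pattern_float32_weights(data, weight, eps=1e-6):
    x = data.to(torch.float32)
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (weight.to(torch.float32) * x).to(data.dtype)

data = torch.randn(2, 1024, dtype=torch.float16)
weight = torch.randn(1024, dtype=torch.float16)
# Same math up to fp16 rounding, but structurally different traced graphs,
# so each variant has to be registered as its own search pattern.
torch.testing.assert_close(
    pattern_default(data, weight),
    pattern_float32_weights(data, weight),
    atol=1e-2,
    rtol=1e-2,
)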

View File

@@ -23,11 +23,28 @@ class RMSNorm(torch.nn.Module):
         return self.weight * hidden_states.to(input_dtype)
 
 
+class NemotronH_RMSNorm(torch.nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size, device="cuda"))
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
+
+
 class TestModel(torch.nn.Module):
-    def __init__(self, eps: float = 1e-6):
+    def __init__(self, eps: float = 1e-6, use_nemotron_h: bool = False):
         super().__init__()
         self.linear1 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
-        self.rms_norm = RMSNorm(1024, eps).to(torch.float16)
+        if use_nemotron_h:
+            self.rms_norm = NemotronH_RMSNorm(1024, eps).to(torch.float16)
+        else:
+            self.rms_norm = RMSNorm(1024, eps).to(torch.float16)
         self.linear2 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
 
     def forward(self, x):
@@ -37,20 +54,10 @@ class TestModel(torch.nn.Module):
         return x
 
 
-@pytest.mark.parametrize("eps", [1e-2, 1e-6])
-@pytest.mark.parametrize(
-    "variant, op",
-    [
-        ("flashinfer", torch.ops.auto_deploy.flashinfer_rms_norm),
-        ("triton", torch.ops.auto_deploy.triton_rms_norm),
-        ("torch", torch.ops.auto_deploy.torch_rmsnorm),
-    ],
-)
-def test_rmsnorm_fusion(eps, variant, op):
+def _run_test(model, op, variant):
     def checker(gm):
         return any(is_op(n, op) for n in gm.graph.nodes)
 
-    model = TestModel(eps)
     x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
     dynamic_shapes = {0: Dim("batch_size", max=8)}
     gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
@@ -77,3 +84,22 @@ def test_rmsnorm_fusion(eps, variant, op):
     y_transformed = gm_transformed(new_input)
     y_model = model(new_input)
     torch.testing.assert_close(y_transformed, y_model, atol=1e-3, rtol=1e-3)
+
+
+@pytest.mark.parametrize("eps", [1e-2, 1e-6])
+@pytest.mark.parametrize(
+    "variant, op",
+    [
+        ("flashinfer", torch.ops.auto_deploy.flashinfer_rms_norm),
+        ("torch", torch.ops.auto_deploy.torch_rmsnorm),
+    ],
+)
+def test_rmsnorm_fusion(eps, variant, op):
+    model = TestModel(eps)
+    _run_test(model, op, variant)
+
+
+def test_rmsnorm_fusion_nemotron_h():
+    # Only the triton backend supports the nemotron h rmsnorm
+    model = TestModel(eps=1e-6, use_nemotron_h=True)
+    _run_test(model, torch.ops.auto_deploy.triton_rms_norm, "triton")
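
The checker helper inside _run_test simply scans the exported graph for a node that calls the fused op. A toy, self-contained illustration of that idea, using torch.fx.symbolic_trace in place of torch_export_to_gm and torch.relu as a stand-in for the fused RMSNorm op:

import torch
from torch.fx import symbolic_trace

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = symbolic_trace(Tiny())
# is_op(n, op) in the real test plays the same role as this predicate.
found = any(n.op == "call_function" and n.target is torch.relu for n in gm.graph.nodes)
print(found)  # True: the op we look for shows up as a graph node after tracing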