[None][feat] Enable rms norm fusion for Nemotron MOE (#8563)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
Co-authored-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
parent a7c2c8c212
commit 2956978da3
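In short, the diff below teaches the auto-deploy RMSNorm fusion to also match the Nemotron-H style norm, which performs the weight scaling in float32 before casting back to the activation dtype, and switches the default fuse_rmsnorm backend in the transforms config to triton. A minimal sketch of that norm variant, mirroring the NemotronH_RMSNorm test module and the new search pattern added further down (the helper name here is made up for this sketch):

import torch


def nemotron_h_style_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Nemotron-H style RMSNorm: normalize and scale in float32, then cast back to x.dtype."""
    xf = x.to(torch.float32)
    variance = xf.pow(2).mean(-1, keepdim=True)
    xf = xf * torch.rsqrt(variance + eps)
    return (weight.to(torch.float32) * xf).to(x.dtype)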
@@ -124,7 +124,7 @@ transforms:
   # TODO (lucaslie): add backend selection as part of configurable inference optimizers
   # check if we can fuse rmsnorm
     stage: post_load_fusion
-    backend: flashinfer
+    backend: triton
     requires_shape_prop: true
   fuse_gated_rmsnorm:
     stage: post_load_fusion
@@ -250,6 +250,12 @@ def _invoke_kernel(
     EM = sorted_token_ids.numel()
     if EM == 0:
         return
+    if A.size(0) < config["BLOCK_SIZE_M"]:
+        # optimize for small batch_size.
+        # We assume that top_ids of each token is unique,
+        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
+        # and we can skip some invalid blocks.
+        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])

     def _grid(META):
         return (
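The added shortcut caps EM for small batches so the launch grid skips blocks of sorted_token_ids that can only contain padding. A hypothetical worked example of the cap (all values below are illustrative stand-ins, not taken from the diff):

# Illustrative stand-ins for the kernel arguments.
batch_size = 2          # A.size(0)
top_k = 4
BLOCK_SIZE_M = 16       # config["BLOCK_SIZE_M"]
num_sorted = 2048       # sorted_token_ids.numel(), padded per expert block

EM = num_sorted
if batch_size < BLOCK_SIZE_M:
    # Same formula as in the diff: cap EM so blocks holding only padding are skipped.
    EM = min(num_sorted, batch_size * top_k * BLOCK_SIZE_M)

print(EM)  # 128 instead of 2048, so far fewer blocks are launched along the token dimension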
@@ -14,7 +14,9 @@ def rms_norm_kernel(
     N_COLS: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    """Rms norm kernel."""
+    """Rms norm kernel.
+    Forces weights to be in float32 for the kernel.
+    """
     prog_id = tl.program_id(0)
     offsets = tl.arange(0, BLOCK_N)

@@ -26,7 +28,7 @@ def rms_norm_kernel(

     var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
     out = xf / tl.sqrt(var + eps)
-    out = (w * out).to(x.dtype)
+    out = (w.to(tl.float32) * out).to(x.dtype)

     out_ptr = output + prog_id * input_row_stride
     tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
@@ -46,6 +46,26 @@ def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> t
     return weight * data.to(input_dtype)


+def _rms_norm_pattern_float32_weights(
+    data: torch.Tensor, weight: torch.Tensor, eps: float
+) -> torch.Tensor:
+    """Implements the RMSNorm pattern for pattern matching.
+
+    Args:
+        data: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Normalized and scaled tensor.
+    """
+    input_dtype = data.dtype
+    data = data.to(torch.float32)
+    variance = data.pow(2).mean(-1, keepdim=True)
+    data = data * torch.rsqrt(variance + eps)
+    return (weight.to(torch.float32) * data).to(input_dtype)
+
+
 def _rms_norm_replacement(
     data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
 ) -> torch.Tensor:
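The only difference from the existing _rms_norm_pattern above is where the cast back to the input dtype happens: the existing pattern casts the data back before the weight multiply, while the new variant multiplies by the float32-cast weight and casts only at the end. A hypothetical smoke test, assuming _rms_norm_pattern_float32_weights is in scope (shapes and dtypes are illustrative):

import torch

x = torch.randn(2, 1024, dtype=torch.float16)
w = torch.randn(1024, dtype=torch.float16)

y = _rms_norm_pattern_float32_weights(x, w, 1e-6)
assert y.shape == x.shape
assert y.dtype == torch.float16  # cast back to the input dtype only at the very end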
@@ -132,15 +152,20 @@ class FuseRMSNorm(BaseTransform):
         ]

         # Register patterns for each configuration
-        for input_dtype, weight_dtype in configs:
-            register_ad_pattern(
-                search_fn=_rms_norm_pattern,
-                replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
-                patterns=patterns,
-                dummy_args=dummy_args(input_dtype, weight_dtype),
-                op_ignore_types={},
-                scalar_workaround={"eps": 1e-6},
-            )
+        search_fns = [
+            _rms_norm_pattern,
+            _rms_norm_pattern_float32_weights,
+        ]
+        for search_fn in search_fns:
+            for input_dtype, weight_dtype in configs:
+                register_ad_pattern(
+                    search_fn=search_fn,
+                    replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
+                    patterns=patterns,
+                    dummy_args=dummy_args(input_dtype, weight_dtype),
+                    op_ignore_types={},
+                    scalar_workaround={"eps": 1e-6},
+                )

         cnt = patterns.apply(graph)

@@ -23,11 +23,28 @@ class RMSNorm(torch.nn.Module):
         return self.weight * hidden_states.to(input_dtype)


+class NemotronH_RMSNorm(torch.nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size, device="cuda"))
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
+
+
 class TestModel(torch.nn.Module):
-    def __init__(self, eps: float = 1e-6):
+    def __init__(self, eps: float = 1e-6, use_nemotron_h: bool = False):
         super().__init__()
         self.linear1 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
-        self.rms_norm = RMSNorm(1024, eps).to(torch.float16)
+        if use_nemotron_h:
+            self.rms_norm = NemotronH_RMSNorm(1024, eps).to(torch.float16)
+        else:
+            self.rms_norm = RMSNorm(1024, eps).to(torch.float16)
         self.linear2 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)

     def forward(self, x):
@@ -37,20 +54,10 @@ class TestModel(torch.nn.Module):
         return x


-@pytest.mark.parametrize("eps", [1e-2, 1e-6])
-@pytest.mark.parametrize(
-    "variant, op",
-    [
-        ("flashinfer", torch.ops.auto_deploy.flashinfer_rms_norm),
-        ("triton", torch.ops.auto_deploy.triton_rms_norm),
-        ("torch", torch.ops.auto_deploy.torch_rmsnorm),
-    ],
-)
-def test_rmsnorm_fusion(eps, variant, op):
+def _run_test(model, op, variant):
     def checker(gm):
         return any(is_op(n, op) for n in gm.graph.nodes)

-    model = TestModel(eps)
     x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
     dynamic_shapes = {0: Dim("batch_size", max=8)}
     gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
@@ -77,3 +84,22 @@ def test_rmsnorm_fusion(eps, variant, op):
     y_transformed = gm_transformed(new_input)
     y_model = model(new_input)
     torch.testing.assert_close(y_transformed, y_model, atol=1e-3, rtol=1e-3)
+
+
+@pytest.mark.parametrize("eps", [1e-2, 1e-6])
+@pytest.mark.parametrize(
+    "variant, op",
+    [
+        ("flashinfer", torch.ops.auto_deploy.flashinfer_rms_norm),
+        ("torch", torch.ops.auto_deploy.torch_rmsnorm),
+    ],
+)
+def test_rmsnorm_fusion(eps, variant, op):
+    model = TestModel(eps)
+    _run_test(model, op, variant)
+
+
+def test_rmsnorm_fusion_nemotron_h():
+    # Only the triton backend supports the nemotron h rmsnorm
+    model = TestModel(eps=1e-6, use_nemotron_h=True)
+    _run_test(model, torch.ops.auto_deploy.triton_rms_norm, "triton")