[perf] Add gemma RMS AR fusion (#42646)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
This commit is contained in:
Jiahan Chang (Cyrus)
2026-06-04 16:33:59 +08:00
committed by GitHub
parent 1bdc60ed53
commit d0975a4b50
3 changed files with 225 additions and 16 deletions
@@ -14,6 +14,7 @@ from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
AllReduceFusionPass,
RocmAiterAllReduceFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.fix_functionalization import (
FixFunctionalizationPass,
)
@@ -33,7 +34,7 @@ from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
@@ -91,6 +92,49 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceGemmaRMSNormModel(torch.nn.Module):
    """Toy model chaining four all-reduce -> GemmaRMSNorm stages.

    The first norm is applied without a residual; the remaining three each
    take the running residual and return an updated one. Consecutive stages
    are separated by a matmul against a random square weight.
    """

    def __init__(
        self,
        hidden_size=16,
        token_num=16,
        eps=1e-6,
        dtype: torch.dtype = torch.float16,
    ):
        # NOTE: ``token_num`` and ``dtype`` are accepted but not read here.
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps

        norm_layers = []
        for _ in range(4):
            norm_layers.append(GemmaRMSNorm(hidden_size, eps))
        self.norm = norm_layers

        # Non-trivial weight (~Gemma range) so (1 + w) exercises the scale path.
        for layer in self.norm:
            layer.weight.data.normal_(mean=0.0, std=0.1)

        # Three square projection matrices, one between each pair of norms.
        proj_shape = (hidden_size, hidden_size)
        self.w = [torch.rand(*proj_shape) for _ in range(3)]

    def forward(self, x):
        # Pass the graph input through relu first so that a pattern never
        # receives a graph input directly as one of its arguments.
        activated = torch.relu(x)

        # Stage 0: the all-reduce output is both the norm input and the
        # initial residual; this norm call takes no residual argument.
        residual = tensor_model_parallel_all_reduce(activated)
        hidden = self.norm[0](residual)

        # Stages 1..3: matmul -> all-reduce -> norm with residual.
        for weight, norm_layer in zip(self.w, self.norm[1:]):
            projected = torch.mm(hidden, weight)
            reduced = tensor_model_parallel_all_reduce(projected)
            hidden, residual = norm_layer(reduced, residual)

        return hidden

    def ops_in_model_before(self):
        """Return the op list this model reports for the pre-pass graph."""
        all_reduce_op = torch.ops.vllm.all_reduce.default
        return [all_reduce_op]

    def ops_in_model_after(self):
        """Return the op list this model reports for the post-pass graph."""
        fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
        return [fused_op]
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
@@ -209,6 +253,15 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
"test_model, enable_quant_fp8_custom_op, use_aiter",
[
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
pytest.param(
TestAllReduceGemmaRMSNormModel,
False,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
pytest.param(
TestAllReduceRMSNormStaticQuantFP8Model,
True,
@@ -404,4 +457,9 @@ def all_reduce_fusion_pass_on_test_model(
)
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
if test_model_cls is TestAllReduceGemmaRMSNormModel:
fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
fused_nodes = list(find_op_nodes(fused_op, backend.graph_post_pass))
assert fused_nodes
assert all(n.kwargs.get("weight_bias") == 1.0 for n in fused_nodes)
del all_reduce_fusion_pass