From d0975a4b50140a9d953f00955a1cbb2a4945edef Mon Sep 17 00:00:00 2001 From: "Jiahan Chang (Cyrus)" <173873397+jiahanc@users.noreply.github.com> Date: Thu, 4 Jun 2026 16:33:59 +0800 Subject: [PATCH] [perf] Add gemma RMS AR fusion (#42646) Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Jiangyun Zhu --- .../distributed/test_fusion_all_reduce.py | 60 ++++++- .../passes/fusion/allreduce_rms_fusion.py | 163 +++++++++++++++++- vllm/model_executor/layers/layernorm.py | 18 +- 3 files changed, 225 insertions(+), 16 deletions(-) diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index 1a175b8dd33..4805863057d 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -14,6 +14,7 @@ from vllm.compilation.passes.fusion.allreduce_rms_fusion import ( AllReduceFusionPass, RocmAiterAllReduceFusionPass, ) +from vllm.compilation.passes.fx_utils import find_op_nodes from vllm.compilation.passes.utility.fix_functionalization import ( FixFunctionalizationPass, ) @@ -33,7 +34,7 @@ from vllm.distributed.parallel_state import ( init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) @@ -91,6 +92,49 @@ class TestAllReduceRMSNormModel(torch.nn.Module): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] +class TestAllReduceGemmaRMSNormModel(torch.nn.Module): + def __init__( + self, + hidden_size=16, + token_num=16, + eps=1e-6, + dtype: torch.dtype = torch.float16, + ): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm = [GemmaRMSNorm(hidden_size, eps) for _ in range(4)] + # Non-trivial weight (~Gemma range) so (1 + w) exercises the scale path. + for n in self.norm: + n.weight.data.normal_(mean=0.0, std=0.1) + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] + + def forward(self, x): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(x) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + z2 = torch.mm(y, self.w[0]) + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) + + z3 = torch.mm(y2, self.w[1]) + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) + + z4 = torch.mm(y3, self.w[2]) + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) + return y4 + + def ops_in_model_before(self): + return [torch.ops.vllm.all_reduce.default] + + def ops_in_model_after(self): + return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + + class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): quant_key = kFp8StaticTensorSym @@ -209,6 +253,15 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): "test_model, enable_quant_fp8_custom_op, use_aiter", [ (TestAllReduceRMSNormModel, False, IS_AITER_FOUND), + pytest.param( + TestAllReduceGemmaRMSNormModel, + False, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="Not supported on ROCm platform", + ), + ), pytest.param( TestAllReduceRMSNormStaticQuantFP8Model, True, @@ -404,4 +457,9 @@ def all_reduce_fusion_pass_on_test_model( ) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) + if test_model_cls is TestAllReduceGemmaRMSNormModel: + fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default + fused_nodes = list(find_op_nodes(fused_op, backend.graph_post_pass)) + assert fused_nodes + assert all(n.kwargs.get("weight_bias") == 1.0 for n in fused_nodes) del all_reduce_fusion_pass diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 569fac667eb..324b0266b4d 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -44,6 +44,24 @@ from .matcher_utils import MatcherQuantFP8 FP8_DTYPE = current_platform.fp8_dtype() +_IR_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default +_IR_FUSED_ADD_RMS_NORM_OP = torch.ops.vllm_ir.fused_add_rms_norm.default + + +def _norm_input_weight_dtype_match(match: pm.Match) -> bool: + """Prevent fusion when the norm input and weight dtypes differ (e.g. a Gemma + fp32 weight.float()+1 gamma), covering rms_norm and fused_add_rms_norm.""" + for node in match.nodes: + if node.target == _IR_RMS_NORM_OP: + x, weight = node.args[0], node.args[1] + elif node.target == _IR_FUSED_ADD_RMS_NORM_OP: + x, weight = node.args[0], node.args[2] + else: + continue + if isinstance(x, fx.Node) and isinstance(weight, fx.Node): + return x.meta["val"].dtype == weight.meta["val"].dtype + return True + # The empirical value for small batch PDL_ADVANCE_LAUNCH_TOKENS = 16 @@ -132,6 +150,7 @@ if flashinfer_comm is not None: quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, scale_factor: torch.Tensor | None = None, + weight_bias: float = 0.0, ) -> None: num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() @@ -208,6 +227,7 @@ if flashinfer_comm is not None: layout_code=layout_code, use_oneshot=use_oneshot, fp32_acc=fp32_acc, + weight_bias=weight_bias, trigger_completion_at_end=num_tokens > PDL_ADVANCE_LAUNCH_TOKENS, ) @@ -225,6 +245,7 @@ if flashinfer_comm is not None: quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, scale_factor: torch.Tensor | None = None, + weight_bias: float = 0.0, ) -> None: pass @@ -399,14 +420,142 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): # allreduce_in, residual return allreduce[1], allreduce[2] + # extra_check routes a Gemma fp32 gamma to AllReduceFusedAddGemmaRMSNormPattern. pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + pattern, + replacement, + self.get_inputs(), + pm.fwd_only, + pm_pass, + extra_check=_norm_input_weight_dtype_match, ) # Same pattern, but only return the output and not residual # (helpful for end of graph where residual is not used again) first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] + pm.register_replacement( + first_return_only(pattern), # type: ignore[no-untyped-call] + first_return_only(replacement), # type: ignore[no-untyped-call] + self.get_inputs(), + pm.fwd_only, + pm_pass, + extra_check=_norm_input_weight_dtype_match, + ) + + +class AllReduceGemmaRMSNormPattern(BasePattern): + """Gemma-style variant of AllReduceRMSNormPattern (no residual).""" + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str | None, + allreduce_params: FlashInferFusedAllReduceParams, + ) -> None: + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def get_inputs(self) -> list[torch.Tensor]: + return [self.empty(5, 16), self.empty(16)] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms = vllm.ir.ops.rms_norm( + allreduce_output, weight.float() + 1.0, self.epsilon + ) + return rms, allreduce_output + + def replacement( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + residual = torch.zeros_like(input) + rms_result = torch.empty_like(input) + assert flashinfer_comm is not None, "FlashInfer must be enabled" + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=rms_result, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + weight_bias=1.0, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + return allreduce[3], allreduce[1] + + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + pm.fwd_only, + pm_pass, + ) + + +class AllReduceFusedAddGemmaRMSNormPattern(BasePattern): + """Gemma-style variant of AllReduceFusedAddRMSNormPattern (with residual).""" + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str | None, + allreduce_params: FlashInferFusedAllReduceParams, + ) -> None: + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def get_inputs(self) -> list[torch.Tensor]: + input = self.empty(5, 16) + residual = self.empty(5, 16) + weight = self.empty(16) + return [residual, input.to(self.dtype), weight] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms, residual = vllm.ir.ops.fused_add_rms_norm( + allreduce_output, residual, weight.float() + 1.0, self.epsilon + ) + return rms, residual + + def replacement( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert flashinfer_comm is not None, "FlashInfer must be enabled" + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=None, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + weight_bias=1.0, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + return allreduce[1], allreduce[2] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] + pm.register_replacement( first_return_only(pattern), # type: ignore[no-untyped-call] first_return_only(replacement), # type: ignore[no-untyped-call] @@ -881,6 +1030,18 @@ class AllReduceFusionPass(VllmPatternMatcherPass): self.device, self.allreduce_params, ).register(self.patterns) + AllReduceGemmaRMSNormPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedAddGemmaRMSNormPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 23027c821d5..13b0ae78131 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -159,20 +159,10 @@ class GemmaRMSNorm(CustomOp): residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" - orig_dtype = x.dtype - weight = self.weight.data.float() + 1.0 - if residual is not None: - x = ( - x.float() + residual.float() - if orig_dtype == torch.float16 - else x + residual - ) - residual = x - # ir.ops.rms_norm handles fp32 upcast internally - out = ir.ops.rms_norm(x, weight, self.variance_epsilon) - return ( - out.to(orig_dtype) if residual is None else (out.to(orig_dtype), residual) - ) + weight = self.weight.float() + 1.0 + if residual is None: + return ir.ops.rms_norm(x, weight, self.variance_epsilon) + return ir.ops.fused_add_rms_norm(x, residual, weight, self.variance_epsilon) def forward_cuda( self,