mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[perf] Add gemma RMS AR fusion (#42646)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
This commit is contained in:
committed by
GitHub
parent
1bdc60ed53
commit
d0975a4b50
@@ -14,6 +14,7 @@ from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
|
||||
AllReduceFusionPass,
|
||||
RocmAiterAllReduceFusionPass,
|
||||
)
|
||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||
from vllm.compilation.passes.utility.fix_functionalization import (
|
||||
FixFunctionalizationPass,
|
||||
)
|
||||
@@ -33,7 +34,7 @@ from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
@@ -91,6 +92,49 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceGemmaRMSNormModel(torch.nn.Module):
    """Test model that alternates tensor-parallel all-reduce with GemmaRMSNorm.

    The first all-reduce result is normalized on its own and also kept as the
    residual; each of the three following all-reduce results is normalized
    together with the running residual. `token_num` and `dtype` are accepted
    but not used anywhere in this class.
    """

    def __init__(
        self,
        hidden_size=16,
        token_num=16,
        eps=1e-6,
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps

        # One norm per all-reduce in forward(); kept in a plain Python list.
        self.norm = [GemmaRMSNorm(hidden_size, eps) for _ in range(4)]
        # Non-trivial weight (~Gemma range) so (1 + w) exercises the scale path.
        for layer in self.norm:
            layer.weight.data.normal_(mean=0.0, std=0.1)

        # Square projection matrices applied between consecutive norms.
        square = (hidden_size, hidden_size)
        self.w = [torch.rand(square) for _ in range(3)]

    def forward(self, x):
        """Run the four all-reduce + norm stages and return the last norm output."""
        # avoid having graph input be an arg to a pattern directly
        activated = torch.relu(x)

        # Stage 0: the reduced tensor feeds the norm and seeds the residual.
        resid = tensor_model_parallel_all_reduce(activated)
        out = self.norm[0](resid)

        # Stages 1-3: project, all-reduce, then norm with the running residual.
        for norm, weight in zip(self.norm[1:], self.w):
            projected = torch.mm(out, weight)
            reduced = tensor_model_parallel_all_reduce(projected)
            out, resid = norm(reduced, resid)

        return out

    def ops_in_model_before(self):
        """Return the ops reported as present before the pass under test runs."""
        all_reduce_op = torch.ops.vllm.all_reduce.default
        return [all_reduce_op]

    def ops_in_model_after(self):
        """Return the ops reported as present after the pass under test runs."""
        fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
        return [fused_op]
|
||||
|
||||
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
@@ -209,6 +253,15 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
"test_model, enable_quant_fp8_custom_op, use_aiter",
|
||||
[
|
||||
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
|
||||
pytest.param(
|
||||
TestAllReduceGemmaRMSNormModel,
|
||||
False,
|
||||
False,
|
||||
marks=pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="Not supported on ROCm platform",
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
TestAllReduceRMSNormStaticQuantFP8Model,
|
||||
True,
|
||||
@@ -404,4 +457,9 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
)
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
if test_model_cls is TestAllReduceGemmaRMSNormModel:
|
||||
fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
|
||||
fused_nodes = list(find_op_nodes(fused_op, backend.graph_post_pass))
|
||||
assert fused_nodes
|
||||
assert all(n.kwargs.get("weight_bias") == 1.0 for n in fused_nodes)
|
||||
del all_reduce_fusion_pass
|
||||
|
||||
Reference in New Issue
Block a user