[compile] Add FlashInfer FP8 async TP fusion and preserve allreduce fusion ordering #27893 (#39505)

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
This commit is contained in:
baonudesifeizhai
2026-05-01 01:08:34 -04:00
committed by GitHub
parent 947138b6c2
commit c3868bbbe4
6 changed files with 397 additions and 30 deletions
+17 -3
View File
@@ -24,10 +24,24 @@ def mock_cuda_platform():
def _mock_platform(is_cuda: bool = True, capability: tuple[int, int] | None = None):
mock_platform = MagicMock()
mock_platform.is_cuda.return_value = is_cuda
if capability is not None:
mock_platform.get_device_capability.return_value = DeviceCapability(
*capability
device_capability = (
DeviceCapability(*capability) if capability is not None else None
)
mock_platform.get_device_capability.return_value = device_capability
def is_device_capability_family(
requested_capability: int, device_id: int = 0
) -> bool:
current_capability = mock_platform.get_device_capability(
device_id=device_id
)
if current_capability is None:
return False
return current_capability.major == (requested_capability // 10)
mock_platform.is_device_capability_family.side_effect = (
is_device_capability_family
)
with patch("vllm.platforms.current_platform", mock_platform):
yield mock_platform
+6
View File
@@ -97,6 +97,12 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"attention backend '{attn_backend.backend.name}'"
)
if attn_backend.backend.name == "FLASHINFER":
from vllm.utils.flashinfer import supports_trtllm_attention
if not supports_trtllm_attention():
matches = matches._replace(attn_quant_fusion=0)
# TODO: remove this after finishing migration from envs to model kwargs
if model_name == "openai/gpt-oss-20b":
from .common import is_blackwell
@@ -13,7 +13,6 @@ from .common import (
AttentionBackendCase,
Matches,
custom_ops_combos,
is_blackwell,
)
from .models import (
FLASHINFER_ATTN,
@@ -46,14 +45,9 @@ def test_tp2_async_tp_fp8_fusions(
custom_ops: str,
inductor_graph_partition: bool,
run_e2e_fusion_test,
monkeypatch,
):
matches = matches_fn(n_layers)
if is_blackwell():
# Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns
monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel")
# Reduce size of model and skip weight loading time
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
@@ -173,14 +167,9 @@ def test_tp2_sp_ar_rms_fp8_fusions(
custom_ops: str,
inductor_graph_partition: bool,
run_e2e_fusion_test,
monkeypatch,
):
matches = matches_fn(n_layers)
if is_blackwell():
# Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns
monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel")
# Reduce size of model and skip weight loading time
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
@@ -1,8 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from contextlib import suppress
import torch
import torch._inductor.pattern_matcher as pm
import torch.distributed.distributed_c10d as c10d
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
@@ -15,15 +19,197 @@ from vllm.distributed.parallel_state import (
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from ..vllm_inductor_pass import (
VllmFusionPatternMatcherPass,
VllmInductorPass,
VllmPatternMatcherPass,
VllmPatternReplacement,
)
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
def _flashinfer_scaled_mm_out(
    A: torch.Tensor,
    B: torch.Tensor,
    *,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out: torch.Tensor,
    bias: torch.Tensor | None = None,
    scale_result: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> None:
    """Compute a tensor-wise scaled FP8 matmul of ``A`` and ``B`` into ``out``.

    This adapter is what the fused symmetric-memory ops in this module pass
    as ``mm_out_op``. It accepts the ``bias`` / ``scale_result`` /
    ``use_fast_accum`` keywords those callers forward, but rejects any
    non-default value for them because the FlashInfer path has no equivalent.

    Args:
        A: Left operand; must be 2D.
        B: Right operand; must be 2D.
        scale_a: Scale for ``A``; must hold exactly one element.
        scale_b: Scale for ``B``; must hold exactly one element.
        out: 2D destination tensor, written in place.
        bias: Must be ``None``.
        scale_result: Must be ``None``.
        out_dtype: Result dtype; ``out.dtype`` is used when not given.
        use_fast_accum: Must be ``False``.

    Raises:
        AssertionError: If any of the restrictions above is violated.
    """
    # Deferred import: avoids a circular import during module initialization
    # when docs or other tooling import the pass without FlashInfer.
    from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm_out

    # Features of the generic scaled-mm signature that are unsupported here.
    assert bias is None, "FlashInfer symm_mem adapter does not support bias"
    assert scale_result is None, (
        "FlashInfer symm_mem adapter does not support result scaling"
    )
    assert not use_fast_accum, (
        "FlashInfer symm_mem adapter does not support use_fast_accum"
    )

    operands_are_2d = A.ndim == 2 and B.ndim == 2 and out.ndim == 2
    assert operands_are_2d, (
        "FlashInfer symm_mem adapter expects 2D inputs and output"
    )

    scales_are_tensorwise = scale_a.numel() == 1 and scale_b.numel() == 1
    assert scales_are_tensorwise, (
        "FlashInfer symm_mem adapter only supports tensor-wise FP8 scales"
    )

    target_dtype = out_dtype or out.dtype
    flashinfer_scaled_fp8_mm_out(
        A,
        B,
        scale_a,
        scale_b,
        out=out,
        out_dtype=target_dtype,
    )
def fused_flashinfer_scaled_matmul_reduce_scatter_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    out_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Fake implementation of the fused FlashInfer matmul + reduce-scatter op.

    Performs no matmul or communication; it only allocates an uninitialized
    tensor with the shape and dtype the result is expected to have.

    Args:
        A: Left operand; only its device is used.
        B, A_scale, B_scale, reduce_op, scatter_dim_after_maybe_reshape:
            Accepted to mirror the real op's signature; unused here.
        orig_scatter_dim: Dimension of ``output_shape`` that is divided by
            the group's world size.
        group_name: Name of the process group whose size is looked up.
        output_shape: Shape of the matmul result before scattering.
        out_dtype: Result dtype; ``torch.bfloat16`` is used when not given.

    Returns:
        An empty tensor on ``A.device`` with the scattered shape.
    """
    group_size = c10d._resolve_process_group(group_name).size()

    # Copy first so the caller's ``output_shape`` list is never mutated.
    scattered_shape = [*output_shape]
    scattered_shape[orig_scatter_dim] = (
        scattered_shape[orig_scatter_dim] // group_size
    )

    result_dtype = out_dtype or torch.bfloat16
    return torch.empty(scattered_shape, dtype=result_dtype, device=A.device)
def fused_flashinfer_scaled_matmul_reduce_scatter(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
reduce_op: str,
orig_scatter_dim: int,
scatter_dim_after_maybe_reshape: int,
group_name: str,
output_shape: list[int],
out_dtype: torch.dtype | None = None,
) -> torch.Tensor:
assert orig_scatter_dim == 0 and scatter_dim_after_maybe_reshape == 0, (
"FlashInfer symm_mem adapter currently only supports scatter_dim=0"
)
world_size = c10d._resolve_process_group(group_name).size()
assert A.ndim == 2 and B.ndim == 2, "FlashInfer symm_mem adapter expects 2D inputs"
assert A.is_contiguous(), "FlashInfer symm_mem adapter expects contiguous A"
assert A_scale.numel() == 1 and B_scale.numel() == 1, (
"FlashInfer symm_mem adapter only supports tensor-wise FP8 scales"
)
assert A.shape[0] % world_size == 0, (
"FlashInfer symm_mem adapter expects M divisible by world size"
)
kwargs = {
"scale_b": B_scale,
"bias": None,
"scale_result": None,
"out_dtype": out_dtype,
"use_fast_accum": False,
}
return torch.distributed._symmetric_memory._fused_scaled_matmul_reduce_scatter_impl(
mm_out_op=_flashinfer_scaled_mm_out,
A=A,
B=B,
A_scale=A_scale,
kwargs=kwargs,
out_dtype=out_dtype,
reduce_op=reduce_op,
orig_scatter_dim=orig_scatter_dim,
scatter_dim_after_maybe_reshape=scatter_dim_after_maybe_reshape,
group_name=group_name,
output_shape=output_shape,
)
def fused_all_gather_flashinfer_scaled_matmul_fake(
    A_shard: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    gather_dim: int,
    group_name: str,
    out_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Fake implementation of the fused all-gather + FlashInfer matmul op.

    Performs no communication or matmul; it only allocates an uninitialized
    tensor with the shape and dtype the result is expected to have.

    Args:
        A_shard: Local shard of the left operand; supplies the base shape
            and the device.
        B: Right operand; ``B.shape[1]`` becomes the last output dimension.
        A_scale, B_scale: Accepted to mirror the real op's signature; unused.
        gather_dim: Dimension of ``A_shard`` multiplied by the world size.
        group_name: Name of the process group whose size is looked up.
        out_dtype: Result dtype; ``torch.bfloat16`` is used when not given.

    Returns:
        An empty tensor on ``A_shard.device`` with the gathered matmul shape.
    """
    group_size = c10d._resolve_process_group(group_name).size()

    gathered_shape = list(A_shard.shape)
    gathered_shape[gather_dim] = gathered_shape[gather_dim] * group_size
    # The contraction replaces the trailing dimension with B's column count.
    gathered_shape[-1] = B.shape[1]

    result_dtype = out_dtype or torch.bfloat16
    return torch.empty(gathered_shape, dtype=result_dtype, device=A_shard.device)
def fused_all_gather_flashinfer_scaled_matmul(
A_shard: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
gather_dim: int,
group_name: str,
out_dtype: torch.dtype | None = None,
) -> torch.Tensor:
assert gather_dim == 0, (
"FlashInfer symm_mem adapter currently only supports gather_dim=0"
)
_, outputs = torch.distributed._symmetric_memory._fused_all_gather_matmul_impl(
mm_out_op=_flashinfer_scaled_mm_out,
A_shard=A_shard,
Bs=[B],
A_scale=A_scale,
kwargs_list=[
{
"scale_b": B_scale,
"bias": None,
"scale_result": None,
"out_dtype": out_dtype,
"use_fast_accum": False,
}
],
out_dtypes=[out_dtype],
gather_dim=gather_dim,
group_name=group_name,
return_A=False,
)
return outputs[0]
# Register the FlashInfer-backed matmul + reduce-scatter function as a custom
# op, paired with its fake implementation. The pattern replacement in this
# file invokes it as
# torch.ops.vllm.fused_flashinfer_scaled_matmul_reduce_scatter.default.
direct_register_custom_op(
op_name="fused_flashinfer_scaled_matmul_reduce_scatter",
op_func=fused_flashinfer_scaled_matmul_reduce_scatter,
fake_impl=fused_flashinfer_scaled_matmul_reduce_scatter_fake,
)
# Same registration for the all-gather + matmul counterpart, invoked as
# torch.ops.vllm.fused_all_gather_flashinfer_scaled_matmul.default.
direct_register_custom_op(
op_name="fused_all_gather_flashinfer_scaled_matmul",
op_func=fused_all_gather_flashinfer_scaled_matmul,
fake_impl=fused_all_gather_flashinfer_scaled_matmul_fake,
)
class BasePattern:
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
@@ -371,39 +557,169 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
)
class AsyncTPPass(VllmPatternMatcherPass):
class FlashInferBMMFP8ReduceScatterPattern(
    BasePattern, VllmPatternReplacement[..., torch.Tensor]
):
    """Rewrite ``bmm_fp8`` followed by reduce-scatter into one fused op.

    The search pattern is ``vllm.bmm_fp8`` applied to two 2D operands that
    were unsqueezed to a batch of one, reshaped back to 2D and then
    reduce-scattered along dim 0. The replacement is a single call to
    ``vllm.fused_flashinfer_scaled_matmul_reduce_scatter``.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Build the example tensors for this pattern.

        Returns:
            ``[a_2d, b_2d, a_scale, b_scale]``: two 16x16 FP8 matrices (the
            second one a transposed view) and two single-element float32
            scales, all on ``self.device``.
        """

        def fp8_square() -> torch.Tensor:
            return torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)

        def single_scale() -> torch.Tensor:
            return torch.empty([1], device=self.device, dtype=torch.float32)

        lhs = fp8_square()
        # The weight example is a transposed view of a contiguous buffer.
        rhs = fp8_square().contiguous().transpose(0, 1)
        lhs_scale = single_scale()
        rhs_scale = single_scale()
        return [lhs, rhs, lhs_scale, rhs_scale]

    @property
    def pattern(self) -> Callable[..., torch.Tensor]:
        """Return the function describing the subgraph to search for."""

        def _pattern(
            a_2d: torch.Tensor,
            b_2d: torch.Tensor,
            a_scale: torch.Tensor,
            b_scale: torch.Tensor,
        ) -> torch.Tensor:
            unsqueeze = torch.ops.aten.unsqueeze.default
            batched = torch.ops.vllm.bmm_fp8.default(
                unsqueeze(a_2d, 0),
                unsqueeze(b_2d, 0),
                a_scale,
                b_scale,
                self.dtype,
                "auto",
            )
            # Drop the leading batch dimension that the unsqueeze introduced.
            flattened = torch.ops.aten.reshape.default(
                batched, list(batched.shape[1:])
            )
            return torch.ops.vllm.reduce_scatter.default(
                flattened,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        return _pattern

    @property
    def replacement(self) -> Callable[..., torch.Tensor]:
        """Return the function producing the fused replacement subgraph."""

        def _replacement(
            a_2d: torch.Tensor,
            b_2d: torch.Tensor,
            a_scale: torch.Tensor,
            b_scale: torch.Tensor,
        ) -> torch.Tensor:
            fused_op = (
                torch.ops.vllm.fused_flashinfer_scaled_matmul_reduce_scatter.default
            )
            # Shape of the full matmul result before it is scattered.
            mm_shape = [a_2d.shape[0], b_2d.shape[1]]
            return fused_op(
                a_2d,
                b_2d,
                a_scale,
                b_scale,
                "sum",
                0,
                0,
                self.tp.device_group.group_name,
                mm_shape,
                self.dtype,
            )

        return _replacement
class FlashInferAllGatherBMMFP8Pattern(
    BasePattern, VllmPatternReplacement[..., torch.Tensor]
):
    """Rewrite all-gather followed by ``bmm_fp8`` into one fused op.

    The search pattern is ``vllm.all_gather`` along dim 0 whose result is
    unsqueezed to a batch of one and fed to ``vllm.bmm_fp8``. The replacement
    calls ``vllm.fused_all_gather_flashinfer_scaled_matmul`` and unsqueezes
    its result so the output keeps the leading batch dimension.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Build the example tensors for this pattern.

        Returns:
            ``[a_shard_2d, b_2d, a_scale, b_scale]``: an 8x16 FP8 shard, a
            16x16 FP8 matrix as a transposed view, and two single-element
            float32 scales, all on ``self.device``.
        """

        def fp8_empty(rows: int, cols: int) -> torch.Tensor:
            return torch.empty([rows, cols], device=self.device, dtype=FP8_DTYPE)

        def single_scale() -> torch.Tensor:
            return torch.empty([1], device=self.device, dtype=torch.float32)

        shard = fp8_empty(8, 16)
        # The weight example is a transposed view of a contiguous buffer.
        rhs = fp8_empty(16, 16).contiguous().transpose(0, 1)
        shard_scale = single_scale()
        rhs_scale = single_scale()
        return [shard, rhs, shard_scale, rhs_scale]

    @property
    def pattern(self) -> Callable[..., torch.Tensor]:
        """Return the function describing the subgraph to search for."""

        def _pattern(
            a_shard_2d: torch.Tensor,
            b_2d: torch.Tensor,
            a_scale: torch.Tensor,
            b_scale: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                a_shard_2d,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            unsqueeze = torch.ops.aten.unsqueeze.default
            return torch.ops.vllm.bmm_fp8.default(
                unsqueeze(gathered, 0),
                unsqueeze(b_2d, 0),
                a_scale,
                b_scale,
                self.dtype,
                "auto",
            )

        return _pattern

    @property
    def replacement(self) -> Callable[..., torch.Tensor]:
        """Return the function producing the fused replacement subgraph."""

        def _replacement(
            a_shard_2d: torch.Tensor,
            b_2d: torch.Tensor,
            a_scale: torch.Tensor,
            b_scale: torch.Tensor,
        ) -> torch.Tensor:
            fused_op = torch.ops.vllm.fused_all_gather_flashinfer_scaled_matmul.default
            fused_2d = fused_op(
                a_shard_2d,
                b_2d,
                a_scale,
                b_scale,
                0,
                self.tp.device_group.group_name,
                self.dtype,
            )
            # Re-add the batch dimension so the output matches the pattern's.
            return torch.ops.aten.unsqueeze.default(fused_2d, 0)

        return _replacement
class AsyncTPPass(VllmFusionPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
super().__init__(config, pass_name="async_tp_pass")
# Enable symmetric memory for the TP process group
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="async_tp_pass"
)
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.pm_pass)
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.pm_pass)
# These fusions are enabled only for bfloat16 models because
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
# only supports bfloat16 as the output dtype.
if self.model_dtype == torch.bfloat16:
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
self.pm_pass
)
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
self.pm_pass
)
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
self.pm_pass
)
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
self.pm_pass
)
with suppress(ImportError):
import vllm.utils.flashinfer # noqa: F401
if hasattr(torch.ops.vllm, "bmm_fp8"):
self.register(
FlashInferAllGatherBMMFP8Pattern(self.model_dtype, self.device)
)
self.register(
FlashInferBMMFP8ReduceScatterPattern(self.model_dtype, self.device)
)
self.dump_patterns(config, self.patterns)
self.dump_patterns(config, self.pm_pass)
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass is applied on top of the sequence parallelism pass,
@@ -416,5 +732,6 @@ class AsyncTPPass(VllmPatternMatcherPass):
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
self.matched_count = self.pm_pass.apply(graph)
VllmPatternMatcherPass.match_table[self.pass_name] += self.matched_count
logger.debug("Replaced %s patterns", self.matched_count)
@@ -31,6 +31,7 @@ logger = init_logger(__name__)
# Only apply sequence parallelism for models with hidden_size >= threshold
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
90: 8192, # H100: only for models with hidden_size >= 8192
100: 8192, # Blackwell family: only for models with hidden_size >= 8192
}
# Min size per GPU per device capability for sequence parallelism
@@ -38,6 +39,8 @@ SP_MIN_HIDDEN_SIZE: dict[int, int] = {
# This ensures the threshold scales appropriately with tensor parallelism
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
90: 8, # 8MB per GPU for H100
# Use a more conservative threshold on Blackwell so TP8 starts later.
100: 32,
}
@@ -67,7 +70,12 @@ def get_sequence_parallelism_threshold(
capability = current_platform.get_device_capability()
if capability is None:
return None
device_capability = capability.to_int()
# Collapse Blackwell variants (sm100/sm103/...) into one policy bucket.
if current_platform.is_device_capability_family(100):
device_capability = 100
else:
device_capability = capability.to_int()
# Check if device has configured thresholds
min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
+33
View File
@@ -715,6 +715,38 @@ def flashinfer_scaled_fp8_mm(
return output
def flashinfer_scaled_fp8_mm_out(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out: torch.Tensor,
out_dtype: torch.dtype | None = None,
) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2 and out.ndim == 2
assert a.shape[1] == b.shape[0]
assert out.shape == (a.shape[0], b.shape[1])
assert scale_a.numel() == 1 and scale_b.numel() == 1
assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
assert out.device.type == "cuda"
assert a.is_contiguous()
from flashinfer import bmm_fp8 as bmm_fp8_
bmm_fp8_(
a.unsqueeze(0),
# FlashInfer expects the weight in the same column-major view layout
# consumed by flashinfer_scaled_fp8_mm, so keep the transposed view.
b.unsqueeze(0),
scale_a,
scale_b,
out_dtype or out.dtype,
out.unsqueeze(0),
"auto",
)
return out
def flashinfer_quant_nvfp4_8x4_sf_layout(
a: torch.Tensor, a_global_sf: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -833,6 +865,7 @@ __all__ = [
"use_trtllm_attention",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",
"flashinfer_scaled_fp8_mm_out",
"flashinfer_quant_nvfp4_8x4_sf_layout",
"flashinfer_fp8_blockscale_gemm",
"should_use_flashinfer_for_blockscale_fp8_gemm",