diff --git a/tests/compile/conftest.py b/tests/compile/conftest.py index 6aafac7bcad..1263cce04c6 100644 --- a/tests/compile/conftest.py +++ b/tests/compile/conftest.py @@ -24,10 +24,24 @@ def mock_cuda_platform(): def _mock_platform(is_cuda: bool = True, capability: tuple[int, int] | None = None): mock_platform = MagicMock() mock_platform.is_cuda.return_value = is_cuda - if capability is not None: - mock_platform.get_device_capability.return_value = DeviceCapability( - *capability + device_capability = ( + DeviceCapability(*capability) if capability is not None else None + ) + mock_platform.get_device_capability.return_value = device_capability + + def is_device_capability_family( + requested_capability: int, device_id: int = 0 + ) -> bool: + current_capability = mock_platform.get_device_capability( + device_id=device_id ) + if current_capability is None: + return False + return current_capability.major == (requested_capability // 10) + + mock_platform.is_device_capability_family.side_effect = ( + is_device_capability_family + ) with patch("vllm.platforms.current_platform", mock_platform): yield mock_platform diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index b017f88881c..b4922a3fdc5 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -97,6 +97,12 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): f"attention backend '{attn_backend.backend.name}'" ) + if attn_backend.backend.name == "FLASHINFER": + from vllm.utils.flashinfer import supports_trtllm_attention + + if not supports_trtllm_attention(): + matches = matches._replace(attn_quant_fusion=0) + # TODO: remove this after finishing migration from envs to model kwargs if model_name == "openai/gpt-oss-20b": from .common import is_blackwell diff --git a/tests/compile/fusions_e2e/test_tp2_async_tp.py b/tests/compile/fusions_e2e/test_tp2_async_tp.py index 609377e6895..baa7bdef0a7 100644 --- a/tests/compile/fusions_e2e/test_tp2_async_tp.py +++ b/tests/compile/fusions_e2e/test_tp2_async_tp.py @@ -13,7 +13,6 @@ from .common import ( AttentionBackendCase, Matches, custom_ops_combos, - is_blackwell, ) from .models import ( FLASHINFER_ATTN, @@ -46,14 +45,9 @@ def test_tp2_async_tp_fp8_fusions( custom_ops: str, inductor_graph_partition: bool, run_e2e_fusion_test, - monkeypatch, ): matches = matches_fn(n_layers) - if is_blackwell(): - # Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns - monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel") - # Reduce size of model and skip weight loading time model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["load_format"] = "dummy" @@ -173,14 +167,9 @@ def test_tp2_sp_ar_rms_fp8_fusions( custom_ops: str, inductor_graph_partition: bool, run_e2e_fusion_test, - monkeypatch, ): matches = matches_fn(n_layers) - if is_blackwell(): - # Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns - monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel") - # Reduce size of model and skip weight loading time model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["load_format"] = "dummy" diff --git a/vllm/compilation/passes/fusion/collective_fusion.py b/vllm/compilation/passes/fusion/collective_fusion.py index 7c14931f497..2b74eae8dd3 100644 --- a/vllm/compilation/passes/fusion/collective_fusion.py +++ b/vllm/compilation/passes/fusion/collective_fusion.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable +from contextlib import suppress + import torch import torch._inductor.pattern_matcher as pm +import torch.distributed.distributed_c10d as c10d import torch.fx as fx from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group @@ -15,15 +19,197 @@ from vllm.distributed.parallel_state import ( ) from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op from ..inductor_pass import enable_fake_mode -from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass +from ..vllm_inductor_pass import ( + VllmFusionPatternMatcherPass, + VllmInductorPass, + VllmPatternMatcherPass, + VllmPatternReplacement, +) FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) +def _flashinfer_scaled_mm_out( + A: torch.Tensor, + B: torch.Tensor, + *, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out: torch.Tensor, + bias: torch.Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> None: + # Import lazily to avoid a circular import during module initialization + # when docs or other tooling import the pass without FlashInfer. + from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm_out + + assert bias is None, "FlashInfer symm_mem adapter does not support bias" + assert scale_result is None, ( + "FlashInfer symm_mem adapter does not support result scaling" + ) + assert not use_fast_accum, ( + "FlashInfer symm_mem adapter does not support use_fast_accum" + ) + assert A.ndim == 2 and B.ndim == 2 and out.ndim == 2, ( + "FlashInfer symm_mem adapter expects 2D inputs and output" + ) + assert scale_a.numel() == 1 and scale_b.numel() == 1, ( + "FlashInfer symm_mem adapter only supports tensor-wise FP8 scales" + ) + + flashinfer_scaled_fp8_mm_out( + A, + B, + scale_a, + scale_b, + out=out, + out_dtype=out_dtype or out.dtype, + ) + + +def fused_flashinfer_scaled_matmul_reduce_scatter_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + world_size = c10d._resolve_process_group(group_name).size() + result_shape = list(output_shape) + result_shape[orig_scatter_dim] //= world_size + return torch.empty( + result_shape, + dtype=out_dtype or torch.bfloat16, + device=A.device, + ) + + +def fused_flashinfer_scaled_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + assert orig_scatter_dim == 0 and scatter_dim_after_maybe_reshape == 0, ( + "FlashInfer symm_mem adapter currently only supports scatter_dim=0" + ) + world_size = c10d._resolve_process_group(group_name).size() + assert A.ndim == 2 and B.ndim == 2, "FlashInfer symm_mem adapter expects 2D inputs" + assert A.is_contiguous(), "FlashInfer symm_mem adapter expects contiguous A" + assert A_scale.numel() == 1 and B_scale.numel() == 1, ( + "FlashInfer symm_mem adapter only supports tensor-wise FP8 scales" + ) + assert A.shape[0] % world_size == 0, ( + "FlashInfer symm_mem adapter expects M divisible by world size" + ) + + kwargs = { + "scale_b": B_scale, + "bias": None, + "scale_result": None, + "out_dtype": out_dtype, + "use_fast_accum": False, + } + return torch.distributed._symmetric_memory._fused_scaled_matmul_reduce_scatter_impl( + mm_out_op=_flashinfer_scaled_mm_out, + A=A, + B=B, + A_scale=A_scale, + kwargs=kwargs, + out_dtype=out_dtype, + reduce_op=reduce_op, + orig_scatter_dim=orig_scatter_dim, + scatter_dim_after_maybe_reshape=scatter_dim_after_maybe_reshape, + group_name=group_name, + output_shape=output_shape, + ) + + +def fused_all_gather_flashinfer_scaled_matmul_fake( + A_shard: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + gather_dim: int, + group_name: str, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + world_size = c10d._resolve_process_group(group_name).size() + output_shape = list(A_shard.shape) + output_shape[gather_dim] *= world_size + output_shape[-1] = B.shape[1] + return torch.empty( + output_shape, + dtype=out_dtype or torch.bfloat16, + device=A_shard.device, + ) + + +def fused_all_gather_flashinfer_scaled_matmul( + A_shard: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + gather_dim: int, + group_name: str, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + assert gather_dim == 0, ( + "FlashInfer symm_mem adapter currently only supports gather_dim=0" + ) + _, outputs = torch.distributed._symmetric_memory._fused_all_gather_matmul_impl( + mm_out_op=_flashinfer_scaled_mm_out, + A_shard=A_shard, + Bs=[B], + A_scale=A_scale, + kwargs_list=[ + { + "scale_b": B_scale, + "bias": None, + "scale_result": None, + "out_dtype": out_dtype, + "use_fast_accum": False, + } + ], + out_dtypes=[out_dtype], + gather_dim=gather_dim, + group_name=group_name, + return_A=False, + ) + return outputs[0] + + +direct_register_custom_op( + op_name="fused_flashinfer_scaled_matmul_reduce_scatter", + op_func=fused_flashinfer_scaled_matmul_reduce_scatter, + fake_impl=fused_flashinfer_scaled_matmul_reduce_scatter_fake, +) + +direct_register_custom_op( + op_name="fused_all_gather_flashinfer_scaled_matmul", + op_func=fused_all_gather_flashinfer_scaled_matmul, + fake_impl=fused_all_gather_flashinfer_scaled_matmul_fake, +) + + class BasePattern: def __init__(self, dtype: torch.dtype, device: str | None) -> None: self.dtype = dtype @@ -371,39 +557,169 @@ class AllGatherCutlassScaledMMPattern(BasePattern): ) -class AsyncTPPass(VllmPatternMatcherPass): +class FlashInferBMMFP8ReduceScatterPattern( + BasePattern, VllmPatternReplacement[..., torch.Tensor] +): + def get_inputs(self) -> list[torch.Tensor]: + a_2d = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + b_2d = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) + a_scale = torch.empty([1], device=self.device, dtype=torch.float32) + b_scale = torch.empty([1], device=self.device, dtype=torch.float32) + return [a_2d, b_2d, a_scale, b_scale] + + @property + def pattern(self) -> Callable[..., torch.Tensor]: + def _pattern( + a_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + ) -> torch.Tensor: + bmm = torch.ops.vllm.bmm_fp8.default( + torch.ops.aten.unsqueeze.default(a_2d, 0), + torch.ops.aten.unsqueeze.default(b_2d, 0), + a_scale, + b_scale, + self.dtype, + "auto", + ) + output = torch.ops.aten.reshape.default(bmm, list(bmm.shape[1:])) + return torch.ops.vllm.reduce_scatter.default( + output, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name, + ) + + return _pattern + + @property + def replacement(self) -> Callable[..., torch.Tensor]: + def _replacement( + a_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.vllm.fused_flashinfer_scaled_matmul_reduce_scatter.default( + a_2d, + b_2d, + a_scale, + b_scale, + "sum", + 0, + 0, + self.tp.device_group.group_name, + [a_2d.shape[0], b_2d.shape[1]], + self.dtype, + ) + + return _replacement + + +class FlashInferAllGatherBMMFP8Pattern( + BasePattern, VllmPatternReplacement[..., torch.Tensor] +): + def get_inputs(self) -> list[torch.Tensor]: + a_shard_2d = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) + b_2d = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) + a_scale = torch.empty([1], device=self.device, dtype=torch.float32) + b_scale = torch.empty([1], device=self.device, dtype=torch.float32) + return [a_shard_2d, b_2d, a_scale, b_scale] + + @property + def pattern(self) -> Callable[..., torch.Tensor]: + def _pattern( + a_shard_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + ) -> torch.Tensor: + all_gather = torch.ops.vllm.all_gather.default( + a_shard_2d, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name, + ) + return torch.ops.vllm.bmm_fp8.default( + torch.ops.aten.unsqueeze.default(all_gather, 0), + torch.ops.aten.unsqueeze.default(b_2d, 0), + a_scale, + b_scale, + self.dtype, + "auto", + ) + + return _pattern + + @property + def replacement(self) -> Callable[..., torch.Tensor]: + def _replacement( + a_shard_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + ) -> torch.Tensor: + fused = torch.ops.vllm.fused_all_gather_flashinfer_scaled_matmul.default( + a_shard_2d, + b_2d, + a_scale, + b_scale, + 0, + self.tp.device_group.group_name, + self.dtype, + ) + return torch.ops.aten.unsqueeze.default(fused, 0) + + return _replacement + + +class AsyncTPPass(VllmFusionPatternMatcherPass): @enable_fake_mode def __init__(self, config: VllmConfig) -> None: - super().__init__(config) + super().__init__(config, pass_name="async_tp_pass") - # Enable symmetric memory for the TP process group enable_symm_mem_for_group(get_tp_group().device_group.group_name) - self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="async_tp_pass" - ) - GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns) + GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.pm_pass) - AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns) + AllGatherGEMMPattern(self.model_dtype, self.device).register(self.pm_pass) # These fusions are enabled only for bfloat16 models because # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling # only supports bfloat16 as the output dtype. if self.model_dtype == torch.bfloat16: ScaledMMReduceScatterPattern(self.model_dtype, self.device).register( - self.patterns + self.pm_pass ) AllGatherScaledMMPattern(self.model_dtype, self.device).register( - self.patterns + self.pm_pass ) CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register( - self.patterns + self.pm_pass ) AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register( - self.patterns + self.pm_pass ) + with suppress(ImportError): + import vllm.utils.flashinfer # noqa: F401 + if hasattr(torch.ops.vllm, "bmm_fp8"): + self.register( + FlashInferAllGatherBMMFP8Pattern(self.model_dtype, self.device) + ) + self.register( + FlashInferBMMFP8ReduceScatterPattern(self.model_dtype, self.device) + ) - self.dump_patterns(config, self.patterns) + self.dump_patterns(config, self.pm_pass) def is_applicable_for_range(self, compile_range: Range) -> bool: # This pass is applied on top of the sequence parallelism pass, @@ -416,5 +732,6 @@ class AsyncTPPass(VllmPatternMatcherPass): @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: - self.matched_count = self.patterns.apply(graph) + self.matched_count = self.pm_pass.apply(graph) + VllmPatternMatcherPass.match_table[self.pass_name] += self.matched_count logger.debug("Replaced %s patterns", self.matched_count) diff --git a/vllm/compilation/passes/fusion/sequence_parallelism.py b/vllm/compilation/passes/fusion/sequence_parallelism.py index 1eae92ecb6a..35885eeb0b8 100644 --- a/vllm/compilation/passes/fusion/sequence_parallelism.py +++ b/vllm/compilation/passes/fusion/sequence_parallelism.py @@ -31,6 +31,7 @@ logger = init_logger(__name__) # Only apply sequence parallelism for models with hidden_size >= threshold SP_MIN_HIDDEN_SIZE: dict[int, int] = { 90: 8192, # H100: only for models with hidden_size >= 8192 + 100: 8192, # Blackwell family: only for models with hidden_size >= 8192 } # Min size per GPU per device capability for sequence parallelism @@ -38,6 +39,8 @@ SP_MIN_HIDDEN_SIZE: dict[int, int] = { # This ensures the threshold scales appropriately with tensor parallelism SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = { 90: 8, # 8MB per GPU for H100 + # Use a more conservative threshold on Blackwell so TP8 starts later. + 100: 32, } @@ -67,7 +70,12 @@ def get_sequence_parallelism_threshold( capability = current_platform.get_device_capability() if capability is None: return None - device_capability = capability.to_int() + + # Collapse Blackwell variants (sm100/sm103/...) into one policy bucket. + if current_platform.is_device_capability_family(100): + device_capability = 100 + else: + device_capability = capability.to_int() # Check if device has configured thresholds min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 5672aef301e..828ff08a067 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -715,6 +715,38 @@ def flashinfer_scaled_fp8_mm( return output +def flashinfer_scaled_fp8_mm_out( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out: torch.Tensor, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 and out.ndim == 2 + assert a.shape[1] == b.shape[0] + assert out.shape == (a.shape[0], b.shape[1]) + assert scale_a.numel() == 1 and scale_b.numel() == 1 + assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn + assert out.device.type == "cuda" + assert a.is_contiguous() + + from flashinfer import bmm_fp8 as bmm_fp8_ + + bmm_fp8_( + a.unsqueeze(0), + # FlashInfer expects the weight in the same column-major view layout + # consumed by flashinfer_scaled_fp8_mm, so keep the transposed view. + b.unsqueeze(0), + scale_a, + scale_b, + out_dtype or out.dtype, + out.unsqueeze(0), + "auto", + ) + return out + + def flashinfer_quant_nvfp4_8x4_sf_layout( a: torch.Tensor, a_global_sf: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: @@ -833,6 +865,7 @@ __all__ = [ "use_trtllm_attention", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", + "flashinfer_scaled_fp8_mm_out", "flashinfer_quant_nvfp4_8x4_sf_layout", "flashinfer_fp8_blockscale_gemm", "should_use_flashinfer_for_blockscale_fp8_gemm",