diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/benchmarks/kernels/__init__.py b/benchmarks/kernels/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/benchmarks/kernels/ir/__init__.py b/benchmarks/kernels/ir/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/benchmarks/kernels/ir/bench_ir_ops.py b/benchmarks/kernels/ir/bench_ir_ops.py new file mode 100644 index 00000000000..b23c4e8ae32 --- /dev/null +++ b/benchmarks/kernels/ir/bench_ir_ops.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Generic benchmark harness for vLLM IR ops. + +Usage: + python benchmarks/kernels/ir/bench_ir_ops.py + python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm + python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm,silu_mul + python benchmarks/kernels/ir/bench_ir_ops.py --no-cuda-graph + python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm --save-path ./results/ +""" + +import argparse +import contextlib +import csv +import dataclasses +import datetime +import math +import os +import subprocess +import sys +import tempfile + +# Ensure repo root is on sys.path so `benchmarks` is importable as a package. +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +# Suppress noisy C++ warnings from vllm kernel registration (written to fd 2 +# directly by the dynamic linker, so Python-level sys.stderr redirect won't +# catch them). +_saved_fd = os.dup(2) +try: + with open(os.devnull, "w") as _devnull: + os.dup2(_devnull.fileno(), 2) + import torch + + import vllm.kernels # noqa: E402, F401 +finally: + os.dup2(_saved_fd, 2) + os.close(_saved_fd) + +from tqdm import tqdm # noqa: E402 + +from benchmarks.kernels.ir.shapes import SHAPE_CONFIGS # noqa: E402 # isort: skip +from vllm.ir.op import IrOp # noqa: E402 +from vllm.platforms import current_platform # noqa: E402 +from vllm.triton_utils import triton # noqa: E402 + + +@dataclasses.dataclass(frozen=True) +class BenchConfig: + use_cuda_graph: bool = True + warmup: int = 25 + rep: int = 100 + + +def _pkg_version(name: str) -> str: + from importlib.metadata import PackageNotFoundError, version + + with contextlib.suppress(PackageNotFoundError): + return version(name) + return "not installed" + + +_METADATA_LABELS = { + "timestamp": "Timestamp", + "git_commit": "Git commit", + "vllm": "vLLM", + "pytorch": "PyTorch", + "cuda_runtime": "CUDA runtime", + "triton": "Triton", + "cutlass": "CUTLASS", + "helion": "Helion", + "device": "Device", + "bench_mode": "Bench mode", + "warmup": "Warmup", + "rep": "Repetitions", +} + + +def collect_env_metadata(cfg: BenchConfig) -> dict[str, str]: + from vllm.collect_env import get_env_info + + env = get_env_info() + + git_sha = "unknown" + with contextlib.suppress(subprocess.CalledProcessError, FileNotFoundError): + git_sha = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], stderr=subprocess.DEVNULL + ) + .decode() + .strip() + ) + + device_name = current_platform.get_device_name() + + warmup_note = " ms" if not cfg.use_cuda_graph else " ms (ignored)" + rep_note = " replays" if cfg.use_cuda_graph else " ms" + + return { + "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "git_commit": git_sha, + "vllm": str(env.vllm_version), + "pytorch": str(env.torch_version), + "cuda_runtime": str(env.cuda_runtime_version), + "triton": triton.__version__, + "cutlass": _pkg_version("nvidia-cutlass-dsl"), + "helion": _pkg_version("helion"), + "device": device_name, + "bench_mode": "cuda_graph" if cfg.use_cuda_graph else "eager", + "warmup": f"{cfg.warmup}{warmup_note}", + "rep": f"{cfg.rep}{rep_note}", + } + + +def print_metadata(metadata: dict[str, str]): + print("=" * 60) + for key, val in metadata.items(): + print(f"{_METADATA_LABELS.get(key, key) + ':':<16}{val}") + print("=" * 60) + + +def _clone_args(args: tuple) -> tuple: + return tuple(a.clone() if isinstance(a, torch.Tensor) else a for a in args) + + +# TODO(gmagogsfm): When the `maybe_inplace` PR lands, ops marked as +# inplace=True will mutate bench_args across iterations. Both CUDA graph +# and eager modes will accumulate drift from repeated in-place mutation. +# We need to re-clone inputs per iteration for inplace ops. +def _bench_one(fn, args, cfg: BenchConfig) -> float: + bench_args = _clone_args(args) + bench_fn = lambda: fn(*bench_args) + + if cfg.use_cuda_graph: + ms = triton.testing.do_bench_cudagraph(bench_fn, rep=cfg.rep, quantiles=[0.5]) + else: + ms = triton.testing.do_bench( + bench_fn, warmup=cfg.warmup, rep=cfg.rep, quantiles=[0.5] + ) + return ms * 1000 + + +# TODO(gmagogsfm): Once compiled native implementation lands (#38775), +# the benchmark baseline should be the compiled native (what vLLM runs by +# default) rather than the uncompiled native implementation. +def collect_timings( + op: IrOp, shape_configs: list[dict], cfg: BenchConfig +) -> tuple[list[str], list[str], dict[str, dict[str, float]]]: + def fmt(v) -> str: + return str(v).split(".")[-1] if isinstance(v, torch.dtype) else str(v) + + case_names = [ + "_".join(f"{k}={fmt(v)}" for k, v in kwargs.items()) for kwargs in shape_configs + ] + providers = [n for n, impl in op.impls.items() if impl.supported] + + results: dict[str, dict[str, float]] = {c: {} for c in case_names} + for provider in providers: + impl = op.impls[provider] + desc = f"{op.name} / {provider}" + for case_name, kwargs in tqdm( + zip(case_names, shape_configs), + desc=desc, + total=len(case_names), + unit=" cases", + ): + args = op.generate_inputs(**kwargs) + if impl.supports_args(*args): + results[case_name][provider] = _bench_one(impl.impl_fn, args, cfg) + else: + results[case_name][provider] = float("nan") + + return case_names, providers, results + + +def analyze_results( + op_name: str, + case_names: list[str], + providers: list[str], + results: dict[str, dict[str, float]], +) -> tuple[list[dict[str, str]], list[dict[str, str]], list[str]]: + native_col = "native" + non_native = [p for p in providers if p != native_col] + + header_cols = ["case"] + for p in providers: + header_cols.append(f"{p} (us)") + for p in non_native: + header_cols.append(f"{p} speedup") + + detail_rows: list[dict[str, str]] = [] + speedup_data: dict[str, list[tuple[float, str]]] = {p: [] for p in non_native} + + for case_name in case_names: + timings = results[case_name] + row: dict[str, str] = {"case": case_name} + + for p in providers: + val = timings.get(p, float("nan")) + row[f"{p} (us)"] = f"{val:.2f}" if not math.isnan(val) else "n/a" + + native_us = timings.get(native_col, float("nan")) + for p in non_native: + p_us = timings.get(p, float("nan")) + if not math.isnan(native_us) and not math.isnan(p_us) and p_us > 0: + speedup = native_us / p_us + row[f"{p} speedup"] = f"{speedup:.2f}x" + speedup_data[p].append((speedup, case_name)) + else: + row[f"{p} speedup"] = "n/a" + + detail_rows.append(row) + + summary_rows: list[dict[str, str]] = [] + for p in non_native: + entries = speedup_data[p] + if not entries: + continue + speedups = [s for s, _ in entries] + geomean = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) + best_val, best_case = max(entries) + worst_val, worst_case = min(entries) + wins = sum(1 for s in speedups if s > 1.0) + losses = sum(1 for s in speedups if s < 1.0) + total = len(speedups) + + print(f"\n{p} vs native ({wins}/{total} faster, {losses}/{total} slower):") + print(f" geomean speedup: {geomean:.2f}x") + print(f" best: {best_val:.2f}x ({best_case})") + print(f" worst: {worst_val:.2f}x ({worst_case})") + + summary_rows.append( + { + "op": op_name, + "provider": p, + "geomean_speedup": f"{geomean:.2f}", + "best_speedup": f"{best_val:.2f}", + "best_case": best_case, + "worst_speedup": f"{worst_val:.2f}", + "worst_case": worst_case, + "wins": str(wins), + "losses": str(losses), + "total": str(total), + } + ) + + return detail_rows, summary_rows, header_cols + + +def write_csv(path: str, rows: list[dict[str, str]], fieldnames: list[str]): + with open(path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +def save_results( + save_dir: str, + op_name: str, + detail_rows: list[dict[str, str]], + header_cols: list[str], + all_summary_rows: list[dict[str, str]], + metadata: dict[str, str], +): + write_csv( + os.path.join(save_dir, f"{op_name}_detail.csv"), + detail_rows, + header_cols, + ) + if all_summary_rows: + write_csv( + os.path.join(save_dir, "summary.csv"), + all_summary_rows, + list(all_summary_rows[0].keys()), + ) + write_csv( + os.path.join(save_dir, "metadata.csv"), + [metadata], + list(metadata.keys()), + ) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Benchmark vLLM IR ops") + parser.add_argument( + "--ops", + type=str, + default=None, + help="Comma-separated list of op names to benchmark (substring match)", + ) + parser.add_argument( + "--no-cuda-graph", + action="store_true", + help="Disable CUDA graph; use do_bench with L2 cache flushing instead", + ) + parser.add_argument( + "--warmup", + type=int, + default=25, + help="Warmup time in ms (do_bench) or ignored with CUDA graph (default: 25)", + ) + parser.add_argument( + "--rep", + type=int, + default=100, + help="Repetition time in ms (do_bench) or number of graph replays " + "(do_bench_cudagraph) (default: 100)", + ) + parser.add_argument( + "--save-path", + type=str, + default=None, + help="Directory to save results (default: auto-created temp dir)", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + cfg = BenchConfig( + use_cuda_graph=not args.no_cuda_graph, + warmup=args.warmup, + rep=args.rep, + ) + + torch.set_default_device(current_platform.device_type) + + metadata = collect_env_metadata(cfg) + print_metadata(metadata) + + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + save_dir = args.save_path or os.path.join( + tempfile.gettempdir(), f"vllm_ir_bench_{timestamp}" + ) + os.makedirs(save_dir, exist_ok=True) + + op_filters = [f.strip() for f in args.ops.split(",")] if args.ops else None + all_summary_rows: list[dict[str, str]] = [] + + for op in IrOp.registry.values(): + if op_filters and not any(f in op.name for f in op_filters): + continue + if not op.has_input_generator: + print(f"Skipping op '{op.name}': no input generator registered") + continue + if op.name not in SHAPE_CONFIGS: + raise RuntimeError( + f"No benchmark shape config for op '{op.name}'. " + f"Add it to benchmarks/kernels/ir/shapes.py" + ) + + case_names, providers, results = collect_timings( + op, SHAPE_CONFIGS[op.name], cfg + ) + detail_rows, summary_rows, header_cols = analyze_results( + op.name, case_names, providers, results + ) + all_summary_rows.extend(summary_rows) + + save_results( + save_dir, + op.name, + detail_rows, + header_cols, + all_summary_rows, + metadata, + ) + + print(f"\nResults saved to: {save_dir}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/ir/shapes.py b/benchmarks/kernels/ir/shapes.py new file mode 100644 index 00000000000..6cc44cf6cec --- /dev/null +++ b/benchmarks/kernels/ir/shapes.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shape configurations for IR op benchmarks. +""" + +import torch + +NUM_TOKENS = [1, 2, 4, 16, 64, 256, 1024, 4096, 16384] +COMMON_HIDDEN_SIZES = [ + 2048, # Llama 3.2 1B, Qwen 3 MoE 30B-A3B, Gemma 3n + 3072, # Gemma 7B/9B + 4096, # Llama 3 8B, Qwen 3 8B, Mistral 7B + 5120, # Llama 4 Scout 17B-16E + 7168, # DeepSeek V3 + 8192, # Llama 3 70B + 16384, # Llama 3 405B +] + +# Each entry maps an op name to a list of kwarg dicts that will be passed +# to that op's registered input generator via op.generate_inputs(**kwargs). +SHAPE_CONFIGS: dict[str, list[dict]] = { + "rms_norm": [ + {"num_tokens": n, "hidden_size": d, "dtype": dtype} + for dtype in [torch.float16, torch.bfloat16, torch.float32] + for d in COMMON_HIDDEN_SIZES + for n in NUM_TOKENS + ], +} diff --git a/tests/ir/ir_test_utils.py b/tests/ir/ir_test_utils.py new file mode 100644 index 00000000000..a82206b6f72 --- /dev/null +++ b/tests/ir/ir_test_utils.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared test utilities for vLLM IR op correctness tests. +""" + +import torch + +from vllm.ir.op import IrOp + +NUM_TOKENS = [1, 8, 17, 32, 512, 2048] +COMMON_HIDDEN_SIZES = [ + 2048, # Llama 3.2 1B, Qwen 3 MoE 30B-A3B, Gemma 3n + 4096, # Llama 3 8B, Qwen 3 8B + 5120, # Llama 4 Scout 17B-16E + 7168, # DeepSeek V3 + 8192, # Llama 3 70B +] + + +def clone_args(args: tuple) -> tuple: + return tuple(a.clone() if isinstance(a, torch.Tensor) else a for a in args) + + +def supported_providers(op: IrOp) -> list[str]: + return [ + name for name, impl in op.impls.items() if name != "native" and impl.supported + ] + + +def assert_close(op: IrOp, actual, expected): + if isinstance(actual, torch.Tensor): + tol = op.get_tolerance(actual.dtype) + try: + torch.testing.assert_close(actual, expected, **tol) + except AssertionError as e: + raise AssertionError( + f"{e}\n\nTo adjust tolerance, use:\n" + f" ir.ops.{op.name}.override_tolerance(" + f"{actual.dtype}, atol=..., rtol=...)" + ) from None + elif isinstance(actual, (tuple, list)): + for a, ex in zip(actual, expected): + assert_close(op, a, ex) + else: + assert actual == expected diff --git a/tests/ir/test_op.py b/tests/ir/test_op.py index 8d4245a04a9..524497916b6 100644 --- a/tests/ir/test_op.py +++ b/tests/ir/test_op.py @@ -495,3 +495,68 @@ def test_uuid_and_oot(tmp_path: Path): assert uuid2 == uuid assert uuid2 != uuid1 del _custom_mm.impls["impl_mm_oot"] + + +def _test_native(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + +def _make_op_with_generator(name: str = "_ig_test"): + op = IrOp(name, _test_native) + + @op.register_input_generator + def _gen(n: int = 4): + x = torch.randn(n, 3) + y = torch.randn(n, 3) + return x, y + + return op + + +def _test_native_single(x: torch.Tensor) -> torch.Tensor: + return x + + +class TestInputGenerator: + def test_no_input_generator_by_default(self): + op = IrOp("_ig_test_no_gen", _test_native_single) + assert not op.has_input_generator + + def test_register_input_generator(self): + op = _make_op_with_generator("_ig_test_reg") + assert op.has_input_generator + + def test_generate_inputs_returns_tuple(self): + op = _make_op_with_generator("_ig_test_tuple") + result = op.generate_inputs(n=2) + assert isinstance(result, tuple) + assert len(result) == 2 + assert result[0].shape == (2, 3) + assert result[1].shape == (2, 3) + + def test_generate_inputs_default_kwargs(self): + op = _make_op_with_generator("_ig_test_default") + result = op.generate_inputs() + assert result[0].shape == (4, 3) + + def test_generate_inputs_without_registration_raises(self): + op = IrOp("_ig_test_no_gen_raises", _test_native_single) + with pytest.raises(RuntimeError, match="No input generator"): + op.generate_inputs() + + +class TestTolerance: + def test_override_and_get_tolerance(self): + op = IrOp("_tol_test", _test_native) + + tol = op.get_tolerance(torch.float32) + assert tol == {"atol": 1e-5, "rtol": 1.3e-6} + + op.override_tolerance(torch.float32, atol=0.1, rtol=0.2) + assert op.get_tolerance(torch.float32) == {"atol": 0.1, "rtol": 0.2} + assert op.get_tolerance(torch.float16) == {"atol": 1e-3, "rtol": 1e-3} + + def test_get_tolerance_raises_for_unknown_dtype(self): + op = IrOp("_tol_test_unknown", _test_native) + with pytest.raises(ValueError, match="No tolerance defined"): + op.get_tolerance(torch.complex64) diff --git a/tests/kernels/ir/test_ir_ops.py b/tests/kernels/ir/test_ir_ops.py new file mode 100644 index 00000000000..1ee36b8f4c9 --- /dev/null +++ b/tests/kernels/ir/test_ir_ops.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Meta-tests for vLLM IR op infrastructure. + +Ensures all registered ops have input generators defined. +Per-op correctness tests live alongside their op definitions +(e.g. tests/kernels/ir/test_layernorm.py). +""" + +import vllm.kernels # noqa: F401 — registers provider implementations +from vllm.ir.op import IrOp + + +def test_all_ops_have_input_generator(): + missing = [name for name, op in IrOp.registry.items() if not op.has_input_generator] + assert not missing, ( + f"IR ops without input generators: {missing}. " + f"Register one with @ir.ops..register_input_generator" + ) diff --git a/tests/kernels/ir/test_layernorm.py b/tests/kernels/ir/test_layernorm.py index 3d21169098d..7510ae5010f 100644 --- a/tests/kernels/ir/test_layernorm.py +++ b/tests/kernels/ir/test_layernorm.py @@ -5,17 +5,17 @@ import torch # This registers op implementations import vllm.kernels # noqa: F401 +from tests.ir.ir_test_utils import ( + COMMON_HIDDEN_SIZES, + NUM_TOKENS, + assert_close, + clone_args, + supported_providers, +) from tests.kernels.allclose_default import get_default_rtol from vllm import ir from vllm.platforms import current_platform - -def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype): - x = torch.randn(n_tokens, hidden_size, dtype=dtype) - weight = torch.rand(hidden_size, dtype=dtype) - return x, weight - - rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn @@ -40,8 +40,8 @@ def test_rms_norm_registration(): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) -@pytest.mark.parametrize("n_tokens", [1, 8, 17]) -@pytest.mark.parametrize("hidden_size", [16, 4096, 8192]) +@pytest.mark.parametrize("n_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", COMMON_HIDDEN_SIZES) @pytest.mark.parametrize("epsilon", [1e-6, 1e-5]) @pytest.mark.skipif( not current_platform.is_cuda_alike() and not current_platform.is_xpu(), @@ -53,7 +53,9 @@ class TestRMSNorm: torch.set_default_device(current_platform.device_type) def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon): - x, weight = rms_norm_inputs(4, 8, dtype) + x, weight, epsilon = ir.ops.rms_norm.generate_inputs( + num_tokens=4, hidden_size=8, dtype=dtype, epsilon=epsilon + ) out = rms_norm_native(x, weight, epsilon=epsilon) # Check shape, dtype, device @@ -71,59 +73,59 @@ class TestRMSNorm: out4 = rms_norm_native(x, None, epsilon=epsilon) torch.testing.assert_close(out3, out4) - @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"]) + @pytest.mark.parametrize("provider", supported_providers(ir.ops.rms_norm)) def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider): impl = ir.ops.rms_norm.impls[provider] - if not impl.supported: - pytest.skip(f"{provider} impl not supported on this platform") - - x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype) - args = (x, weight, epsilon, None) - - assert impl.supported - - if provider == "aiter" and dtype not in [torch.float16, torch.bfloat16]: - assert not impl.supports_args(*args) - return - - assert impl.supports_args(*args) - - out_impl = impl.impl_fn(*args) - out_native = rms_norm_native(*args) - - torch.testing.assert_close( - out_impl, out_native, rtol=get_default_rtol(out_impl), atol=1e-3 + x, weight, eps = ir.ops.rms_norm.generate_inputs( + num_tokens=n_tokens, hidden_size=hidden_size, dtype=dtype, epsilon=epsilon ) + args = (x, weight, eps) + + if not impl.supports_args(*args): + pytest.skip(f"{provider} does not support args") + + ref_output = rms_norm_native(*clone_args(args)) + output = impl.impl_fn(*clone_args(args)) + assert_close(ir.ops.rms_norm, output, ref_output) # check that dispatched call matches direct call with ir.ops.rms_norm.set_priority([provider, "native"]): - out_impl2 = ir.ops.rms_norm(*args) - - # exact match - torch.testing.assert_close(out_impl2, out_impl, rtol=0.0, atol=0.0) + out_dispatched = ir.ops.rms_norm(*args) + out_direct = impl.impl_fn(*args) + torch.testing.assert_close(out_dispatched, out_direct, rtol=0.0, atol=0.0) # none of these support variance_size override - assert not impl.supports_args(x, weight, epsilon, 4) - assert not impl.supports_args(x, weight, epsilon, variance_size=4) + assert not impl.supports_args(x, weight, eps, 4) + assert not impl.supports_args(x, weight, eps, variance_size=4) # test weight=None behavior - out_impl_no_weight = impl.impl_fn(x, None, epsilon) - out_impl_unit_weight = impl.impl_fn(x, torch.ones_like(weight), epsilon) - torch.testing.assert_close( - out_impl_no_weight, - out_impl_unit_weight, - rtol=get_default_rtol(out_impl_no_weight), - atol=2e-4, - ) + out_no_weight = impl.impl_fn(x, None, eps) + out_unit_weight = impl.impl_fn(x, torch.ones_like(weight), eps) + assert_close(ir.ops.rms_norm, out_no_weight, out_unit_weight) @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"]) def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider): if not ir.ops.rms_norm.impls[provider].supported: pytest.skip(f"{provider} impl not supported on this platform") - x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype) - args = (x, weight, epsilon, None) + args = ir.ops.rms_norm.generate_inputs( + num_tokens=n_tokens, hidden_size=hidden_size, dtype=dtype, epsilon=epsilon + ) # When checking the torch op, we have to set priority and use dispatch with ir.ops.rms_norm.set_priority([provider, "native"]): torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args) + + +@pytest.mark.skipif( + not current_platform.is_rocm(), + reason="aiter is only supported on ROCm", +) +def test_aiter_rejects_unsupported_dtypes(): + torch.set_default_device(current_platform.device_type) + impl = ir.ops.rms_norm.impls["aiter"] + for dtype in [torch.float32, torch.float64]: + args = ir.ops.rms_norm.generate_inputs( + num_tokens=8, hidden_size=4096, dtype=dtype, epsilon=1e-5 + ) + assert not impl.supports_args(*args), f"aiter should reject dtype={dtype}" diff --git a/vllm/ir/op.py b/vllm/ir/op.py index 1cbd78e28f9..5d7c01be1bb 100644 --- a/vllm/ir/op.py +++ b/vllm/ir/op.py @@ -9,10 +9,13 @@ from typing import Any, ClassVar, overload import torch from torch.library import Library, infer_schema +from vllm.ir.tolerances import DEFAULT_TOLERANCES, ToleranceSpec from vllm.ir.util import hash_source, weak_cache from vllm.logger import init_logger from vllm.logging_utils import lazy, tensors_str_no_data +InputGenerator = Callable[..., tuple[Any, ...]] + vllm_ir_lib = Library("vllm_ir", "FRAGMENT") logger = init_logger(__name__) @@ -113,6 +116,8 @@ class IrOp: self.impls: dict[str, IrOpImpl] = {} self._priority_impls: list[IrOpImpl] = [] self._schema_str = infer_schema(native_impl, mutates_args=[]) + self._input_generator: InputGenerator | None = None + self._tolerance_overrides: ToleranceSpec = {} # native implementation self.impls["native"] = IrOpImpl( @@ -316,6 +321,38 @@ class IrOp: def supported_providers(self) -> list[str]: return [p.provider for p in self.impls.values() if p.supported] + @property + def has_input_generator(self) -> bool: + return self._input_generator is not None + + def register_input_generator(self, fn: InputGenerator) -> InputGenerator: + self._input_generator = fn + return fn + + def generate_inputs(self, **kwargs: Any) -> tuple[Any, ...]: + if self._input_generator is None: + raise RuntimeError( + f"No input generator registered for op '{self.name}'. " + f"Use @ir.ops.{self.name}.register_input_generator" + ) + return self._input_generator(**kwargs) + + def override_tolerance( + self, dtype: torch.dtype, *, atol: float, rtol: float + ) -> None: + self._tolerance_overrides[dtype] = {"atol": atol, "rtol": rtol} + + def get_tolerance(self, dtype: torch.dtype) -> dict[str, float]: + if dtype in self._tolerance_overrides: + return self._tolerance_overrides[dtype] + if dtype in DEFAULT_TOLERANCES: + return DEFAULT_TOLERANCES[dtype] + raise ValueError( + f"No tolerance defined for dtype {dtype} in op '{self.name}'. " + f"Use op.override_tolerance({dtype}, atol=..., rtol=...) " + f"or add {dtype} to DEFAULT_TOLERANCES." + ) + class IrOpImpl: def __init__( diff --git a/vllm/ir/ops/layernorm.py b/vllm/ir/ops/layernorm.py index ac0c38a9e4d..981d5e3bd83 100644 --- a/vllm/ir/ops/layernorm.py +++ b/vllm/ir/ops/layernorm.py @@ -19,3 +19,18 @@ def rms_norm( if weight is not None: x = x.to(weight.dtype) * weight return x.to(orig_dtype) + + +@rms_norm.register_input_generator +def _rms_norm_input_generator( + num_tokens: int, hidden_size: int, dtype: torch.dtype, epsilon: float = 1e-5 +) -> tuple: + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + weight = torch.randn(hidden_size, dtype=dtype) + return (x, weight, epsilon) + + +# Reductions in rms_norm accumulate rounding error at large shapes +# (e.g. 32768x16384), causing a few elements out of millions to exceed +# the default float16 tolerance. +rms_norm.override_tolerance(torch.float16, atol=1e-2, rtol=2e-3) diff --git a/vllm/ir/tolerances.py b/vllm/ir/tolerances.py new file mode 100644 index 00000000000..a794a939765 --- /dev/null +++ b/vllm/ir/tolerances.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +ToleranceSpec = dict[torch.dtype, dict[str, float]] + +# Default tolerances for comparing IR op implementations against native. +# These are intentionally conservative (permissive) to avoid false failures +# across different hardware and kernel implementations. Ops that need tighter +# or looser bounds should use override_tolerance. +DEFAULT_TOLERANCES: ToleranceSpec = { + # 52-bit mantissa; machine epsilon ~1.1e-16 + torch.float64: {"atol": 1e-8, "rtol": 1e-8}, + # 23-bit mantissa; machine epsilon ~1.2e-7. + # Values from PyTorch test_transformers.py reference defaults. + torch.float32: {"atol": 1e-5, "rtol": 1.3e-6}, + # 10-bit mantissa; machine epsilon ~9.8e-4. + # Standard tolerance used across vLLM kernel tests. + torch.float16: {"atol": 1e-3, "rtol": 1e-3}, + # 7-bit mantissa; machine epsilon ~7.8e-3. + # Wider rtol than float16 to account for the coarser mantissa. + torch.bfloat16: {"atol": 1e-3, "rtol": 1.6e-2}, + # 3-bit mantissa; machine epsilon ~6.25e-2. + # Derived from vLLM fp8 kernel tests (merge_attn_states, silu_mul_fp8). + torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1}, + # 2-bit mantissa; machine epsilon ~1.25e-1. + # Wider than e4m3fn due to the smaller mantissa. + torch.float8_e5m2: {"atol": 2e-1, "rtol": 2e-1}, + # 1-bit mantissa; machine epsilon ~2.5e-1. Packed pair format (x2). + # Derived from vLLM fp4 tests (test_silu_mul_nvfp4_quant: atol=3e-1). + torch.float4_e2m1fn_x2: {"atol": 3e-1, "rtol": 3e-1}, + # Integer quantized; off-by-one from rounding is expected. + # rtol=0 because relative error is meaningless for small integers. + torch.int8: {"atol": 1, "rtol": 0}, +}