mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[vLLM IR] Add IR op testing and benchmarking infrastructure (#40167)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com> Co-authored-by: Theresa Shan <Theresa.Shan@amd.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,378 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Generic benchmark harness for vLLM IR ops.
|
||||
|
||||
Usage:
|
||||
python benchmarks/kernels/ir/bench_ir_ops.py
|
||||
python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm
|
||||
python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm,silu_mul
|
||||
python benchmarks/kernels/ir/bench_ir_ops.py --no-cuda-graph
|
||||
python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm --save-path ./results/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import csv
|
||||
import dataclasses
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
# Ensure repo root is on sys.path so `benchmarks` is importable as a package.
|
||||
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
if _REPO_ROOT not in sys.path:
|
||||
sys.path.insert(0, _REPO_ROOT)
|
||||
|
||||
# Suppress noisy C++ warnings from vllm kernel registration (written to fd 2
|
||||
# directly by the dynamic linker, so Python-level sys.stderr redirect won't
|
||||
# catch them).
|
||||
_saved_fd = os.dup(2)
|
||||
try:
|
||||
with open(os.devnull, "w") as _devnull:
|
||||
os.dup2(_devnull.fileno(), 2)
|
||||
import torch
|
||||
|
||||
import vllm.kernels # noqa: E402, F401
|
||||
finally:
|
||||
os.dup2(_saved_fd, 2)
|
||||
os.close(_saved_fd)
|
||||
|
||||
from tqdm import tqdm # noqa: E402
|
||||
|
||||
from benchmarks.kernels.ir.shapes import SHAPE_CONFIGS # noqa: E402 # isort: skip
|
||||
from vllm.ir.op import IrOp # noqa: E402
|
||||
from vllm.platforms import current_platform # noqa: E402
|
||||
from vllm.triton_utils import triton # noqa: E402
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class BenchConfig:
    """Immutable timing settings applied to every benchmarked case."""

    # True: time with triton's do_bench_cudagraph; False: time with do_bench.
    use_cuda_graph: bool = True
    # Forwarded as `warmup` to do_bench; not passed to do_bench_cudagraph.
    warmup: int = 25
    # Forwarded as `rep` to whichever triton timing helper is selected.
    rep: int = 100
|
||||
|
||||
|
||||
def _pkg_version(name: str) -> str:
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
with contextlib.suppress(PackageNotFoundError):
|
||||
return version(name)
|
||||
return "not installed"
|
||||
|
||||
|
||||
# Human-readable labels for metadata keys; print_metadata looks keys up here
# and falls back to the raw key when no label is listed.
_METADATA_LABELS = {
    "timestamp": "Timestamp",
    "git_commit": "Git commit",
    "vllm": "vLLM",
    "pytorch": "PyTorch",
    "cuda_runtime": "CUDA runtime",
    "triton": "Triton",
    "cutlass": "CUTLASS",
    "helion": "Helion",
    "device": "Device",
    "bench_mode": "Bench mode",
    "warmup": "Warmup",
    "rep": "Repetitions",
}
|
||||
|
||||
|
||||
def collect_env_metadata(cfg: BenchConfig) -> dict[str, str]:
    """Collect environment and benchmark settings to store with the results.

    Args:
        cfg: Benchmark configuration; its mode, warmup and rep values are
            recorded verbatim.

    Returns:
        A dict of strings with the keys ``timestamp``, ``git_commit``,
        ``vllm``, ``pytorch``, ``cuda_runtime``, ``triton``, ``cutlass``,
        ``helion``, ``device``, ``bench_mode``, ``warmup`` and ``rep``.
        ``git_commit`` is "unknown" when git is unavailable or fails.
    """
    from vllm.collect_env import get_env_info

    env = get_env_info()

    # Query the checkout that contains this script instead of the caller's
    # working directory; otherwise the recorded commit could belong to an
    # unrelated repository (or be missing) when launched from elsewhere.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    git_sha = "unknown"
    with contextlib.suppress(subprocess.CalledProcessError, FileNotFoundError):
        git_sha = (
            subprocess.check_output(
                ["git", "rev-parse", "--short", "HEAD"],
                cwd=script_dir,
                stderr=subprocess.DEVNULL,
            )
            .decode()
            .strip()
        )

    device_name = current_platform.get_device_name()

    # Warmup is only forwarded to the timer in eager mode.
    warmup_note = " ms" if not cfg.use_cuda_graph else " ms (ignored)"
    rep_note = " replays" if cfg.use_cuda_graph else " ms"

    return {
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "git_commit": git_sha,
        "vllm": str(env.vllm_version),
        "pytorch": str(env.torch_version),
        "cuda_runtime": str(env.cuda_runtime_version),
        "triton": triton.__version__,
        "cutlass": _pkg_version("nvidia-cutlass-dsl"),
        "helion": _pkg_version("helion"),
        "device": device_name,
        "bench_mode": "cuda_graph" if cfg.use_cuda_graph else "eager",
        "warmup": f"{cfg.warmup}{warmup_note}",
        "rep": f"{cfg.rep}{rep_note}",
    }
|
||||
|
||||
|
||||
def print_metadata(metadata: dict[str, str]):
    """Print *metadata* as aligned ``Label: value`` lines between two rules."""
    rule = "=" * 60
    print(rule)
    for field, value in metadata.items():
        # Unknown keys are shown under their raw name.
        label = _METADATA_LABELS.get(field, field) + ":"
        print(f"{label:<16}{value}")
    print(rule)
|
||||
|
||||
|
||||
def _clone_args(args: tuple) -> tuple:
|
||||
return tuple(a.clone() if isinstance(a, torch.Tensor) else a for a in args)
|
||||
|
||||
|
||||
# TODO(gmagogsfm): When the `maybe_inplace` PR lands, ops marked as
|
||||
# inplace=True will mutate bench_args across iterations. Both CUDA graph
|
||||
# and eager modes will accumulate drift from repeated in-place mutation.
|
||||
# We need to re-clone inputs per iteration for inplace ops.
|
||||
def _bench_one(fn, args, cfg: BenchConfig) -> float:
    """Time ``fn(*args)`` and return the 0.5-quantile result times 1000.

    The inputs are cloned once up front and the same clones are reused for
    every invocation. ``cfg.use_cuda_graph`` selects between triton's
    ``do_bench_cudagraph`` and ``do_bench``; only the latter receives
    ``cfg.warmup``.
    """
    private_args = _clone_args(args)

    def run():
        return fn(*private_args)

    median = [0.5]
    if not cfg.use_cuda_graph:
        elapsed_ms = triton.testing.do_bench(
            run, warmup=cfg.warmup, rep=cfg.rep, quantiles=median
        )
    else:
        elapsed_ms = triton.testing.do_bench_cudagraph(
            run, rep=cfg.rep, quantiles=median
        )
    # Callers report this value in the "(us)" columns.
    return elapsed_ms * 1000
|
||||
|
||||
|
||||
# TODO(gmagogsfm): Once compiled native implementation lands (#38775),
|
||||
# the benchmark baseline should be the compiled native (what vLLM runs by
|
||||
# default) rather than the uncompiled native implementation.
|
||||
def collect_timings(
    op: IrOp, shape_configs: list[dict], cfg: BenchConfig
) -> tuple[list[str], list[str], dict[str, dict[str, float]]]:
    """Benchmark every supported implementation of *op* on every shape config.

    Returns:
        ``(case_names, providers, results)`` where ``case_names`` has one
        ``key=value_...`` label per shape config, ``providers`` lists the
        implementations whose ``supported`` flag is set, and
        ``results[case][provider]`` holds the timing from ``_bench_one`` or
        NaN when the implementation rejects the generated arguments.
    """

    def render(value) -> str:
        text = str(value)
        if isinstance(value, torch.dtype):
            # "torch.float16" -> "float16"
            return text.split(".")[-1]
        return text

    case_names = []
    for shape_kwargs in shape_configs:
        parts = [f"{key}={render(value)}" for key, value in shape_kwargs.items()]
        case_names.append("_".join(parts))

    providers = [name for name, candidate in op.impls.items() if candidate.supported]

    results: dict[str, dict[str, float]] = {name: {} for name in case_names}
    for provider in providers:
        impl = op.impls[provider]
        progress = tqdm(
            zip(case_names, shape_configs),
            desc=f"{op.name} / {provider}",
            total=len(case_names),
            unit=" cases",
        )
        for case_name, shape_kwargs in progress:
            inputs = op.generate_inputs(**shape_kwargs)
            if not impl.supports_args(*inputs):
                results[case_name][provider] = float("nan")
                continue
            results[case_name][provider] = _bench_one(impl.impl_fn, inputs, cfg)

    return case_names, providers, results
|
||||
|
||||
|
||||
def analyze_results(
    op_name: str,
    case_names: list[str],
    providers: list[str],
    results: dict[str, dict[str, float]],
) -> tuple[list[dict[str, str]], list[dict[str, str]], list[str]]:
    """Turn raw timings into per-case rows and per-provider speedup summaries.

    Speedups are ``native / provider``. A speedup is reported only when both
    timings are present and strictly positive; otherwise the cell is "n/a"
    and the case is left out of the provider's summary. A human-readable
    summary for each provider is also printed to stdout.

    Args:
        op_name: Name written into the ``op`` column of the summary rows.
        case_names: Case labels, in output order; each must be a key of
            *results*.
        providers: Provider names, in column order.
        results: ``results[case][provider]`` timing; NaN or a missing entry
            means "not measured".

    Returns:
        ``(detail_rows, summary_rows, header_cols)``: one detail row per
        case, one summary row per non-native provider that has at least one
        valid speedup, and the column order for the detail rows.
    """
    native_col = "native"
    non_native = [p for p in providers if p != native_col]

    header_cols = ["case"]
    header_cols.extend(f"{p} (us)" for p in providers)
    header_cols.extend(f"{p} speedup" for p in non_native)

    nan = float("nan")
    detail_rows: list[dict[str, str]] = []
    speedup_data: dict[str, list[tuple[float, str]]] = {p: [] for p in non_native}

    for case_name in case_names:
        timings = results[case_name]
        row: dict[str, str] = {"case": case_name}

        for p in providers:
            val = timings.get(p, nan)
            row[f"{p} (us)"] = "n/a" if math.isnan(val) else f"{val:.2f}"

        native_us = timings.get(native_col, nan)
        for p in non_native:
            p_us = timings.get(p, nan)
            # `x > 0` is False for NaN, so this rejects missing, zero and
            # negative timings alike. Requiring native_us > 0 keeps the
            # speedup strictly positive, which math.log below depends on.
            if native_us > 0 and p_us > 0:
                speedup = native_us / p_us
                row[f"{p} speedup"] = f"{speedup:.2f}x"
                speedup_data[p].append((speedup, case_name))
            else:
                row[f"{p} speedup"] = "n/a"

        detail_rows.append(row)

    summary_rows: list[dict[str, str]] = []
    for p in non_native:
        entries = speedup_data[p]
        if not entries:
            continue
        speedups = [s for s, _ in entries]
        geomean = math.exp(sum(math.log(s) for s in speedups) / len(speedups))
        best_val, best_case = max(entries)
        worst_val, worst_case = min(entries)
        wins = sum(1 for s in speedups if s > 1.0)
        losses = sum(1 for s in speedups if s < 1.0)
        total = len(speedups)

        print(f"\n{p} vs native ({wins}/{total} faster, {losses}/{total} slower):")
        print(f" geomean speedup: {geomean:.2f}x")
        print(f" best: {best_val:.2f}x ({best_case})")
        print(f" worst: {worst_val:.2f}x ({worst_case})")

        summary_rows.append(
            {
                "op": op_name,
                "provider": p,
                "geomean_speedup": f"{geomean:.2f}",
                "best_speedup": f"{best_val:.2f}",
                "best_case": best_case,
                "worst_speedup": f"{worst_val:.2f}",
                "worst_case": worst_case,
                "wins": str(wins),
                "losses": str(losses),
                "total": str(total),
            }
        )

    return detail_rows, summary_rows, header_cols
|
||||
|
||||
|
||||
def write_csv(path: str, rows: list[dict[str, str]], fieldnames: list[str]):
    """Write *rows* to *path* as CSV, preceded by a header of *fieldnames*."""
    with open(path, "w", newline="") as handle:
        out = csv.DictWriter(handle, fieldnames=fieldnames)
        out.writeheader()
        for record in rows:
            out.writerow(record)
|
||||
|
||||
|
||||
def save_results(
    save_dir: str,
    op_name: str,
    detail_rows: list[dict[str, str]],
    header_cols: list[str],
    all_summary_rows: list[dict[str, str]],
    metadata: dict[str, str],
):
    """Write one op's detail CSV plus the shared summary and metadata CSVs.

    ``summary.csv`` is written only when *all_summary_rows* is non-empty; its
    columns are taken from the first summary row. ``metadata.csv`` and
    ``summary.csv`` are rewritten in full on every call.
    """

    def target(filename: str) -> str:
        return os.path.join(save_dir, filename)

    write_csv(target(f"{op_name}_detail.csv"), detail_rows, header_cols)

    if all_summary_rows:
        summary_fields = list(all_summary_rows[0].keys())
        write_csv(target("summary.csv"), all_summary_rows, summary_fields)

    write_csv(target("metadata.csv"), [metadata], list(metadata.keys()))
|
||||
|
||||
|
||||
def parse_args():
    """Parse the benchmark's command-line options from ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Benchmark vLLM IR ops")

    # (flag, add_argument keyword arguments), registered in this order.
    options = (
        (
            "--ops",
            dict(
                type=str,
                default=None,
                help="Comma-separated list of op names to benchmark (substring match)",
            ),
        ),
        (
            "--no-cuda-graph",
            dict(
                action="store_true",
                help="Disable CUDA graph; use do_bench with L2 cache flushing instead",
            ),
        ),
        (
            "--warmup",
            dict(
                type=int,
                default=25,
                help="Warmup time in ms (do_bench) or ignored with CUDA graph (default: 25)",
            ),
        ),
        (
            "--rep",
            dict(
                type=int,
                default=100,
                help="Repetition time in ms (do_bench) or number of graph replays "
                "(do_bench_cudagraph) (default: 100)",
            ),
        ),
        (
            "--save-path",
            dict(
                type=str,
                default=None,
                help="Directory to save results (default: auto-created temp dir)",
            ),
        ),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Benchmark every selected IR op and write the results as CSV files.

    Ops are taken from ``IrOp.registry``. With ``--ops``, only ops whose name
    contains one of the comma-separated fragments are run. Ops without an
    input generator are skipped with a message.

    Raises:
        RuntimeError: If a selected op with an input generator has no entry
            in ``SHAPE_CONFIGS``.
    """
    args = parse_args()
    cfg = BenchConfig(
        use_cuda_graph=not args.no_cuda_graph,
        warmup=args.warmup,
        rep=args.rep,
    )

    torch.set_default_device(current_platform.device_type)

    metadata = collect_env_metadata(cfg)
    print_metadata(metadata)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = args.save_path or os.path.join(
        tempfile.gettempdir(), f"vllm_ir_bench_{timestamp}"
    )
    os.makedirs(save_dir, exist_ok=True)

    # Drop empty fragments (e.g. from a trailing comma): "" is a substring of
    # every op name and would silently select all ops.
    op_filters = None
    if args.ops:
        op_filters = [
            name for name in (part.strip() for part in args.ops.split(",")) if name
        ]
    all_summary_rows: list[dict[str, str]] = []

    for op in IrOp.registry.values():
        if op_filters and not any(f in op.name for f in op_filters):
            continue
        if not op.has_input_generator:
            print(f"Skipping op '{op.name}': no input generator registered")
            continue
        if op.name not in SHAPE_CONFIGS:
            raise RuntimeError(
                f"No benchmark shape config for op '{op.name}'. "
                f"Add it to benchmarks/kernels/ir/shapes.py"
            )

        case_names, providers, results = collect_timings(
            op, SHAPE_CONFIGS[op.name], cfg
        )
        detail_rows, summary_rows, header_cols = analyze_results(
            op.name, case_names, providers, results
        )
        all_summary_rows.extend(summary_rows)

        # Saved after every op so earlier results exist on disk even if a
        # later op fails.
        save_results(
            save_dir,
            op.name,
            detail_rows,
            header_cols,
            all_summary_rows,
            metadata,
        )

    print(f"\nResults saved to: {save_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Shape configurations for IR op benchmarks.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
# Token counts swept by the benchmarks.
NUM_TOKENS = [1, 2, 4, 16, 64, 256, 1024, 4096, 16384]

# Hidden sizes swept by the benchmarks, with example models per size.
COMMON_HIDDEN_SIZES = [
    2048,  # Llama 3.2 1B, Qwen 3 MoE 30B-A3B, Gemma 3n
    3072,  # Gemma 7B/9B
    4096,  # Llama 3 8B, Qwen 3 8B, Mistral 7B
    5120,  # Llama 4 Scout 17B-16E
    7168,  # DeepSeek V3
    8192,  # Llama 3 70B
    16384,  # Llama 3 405B
]

# Maps an op name to the kwarg dicts handed to that op's registered input
# generator through op.generate_inputs(**kwargs). Cases are ordered with
# dtype outermost, then hidden size, then token count.
SHAPE_CONFIGS: dict[str, list[dict]] = {
    "rms_norm": [
        dict(num_tokens=tokens, hidden_size=hidden, dtype=dt)
        for dt in (torch.float16, torch.bfloat16, torch.float32)
        for hidden in COMMON_HIDDEN_SIZES
        for tokens in NUM_TOKENS
    ],
}
|
||||
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Shared test utilities for vLLM IR op correctness tests.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.ir.op import IrOp
|
||||
|
||||
# Token counts shared by the IR op correctness tests.
NUM_TOKENS = [1, 8, 17, 32, 512, 2048]

# Hidden sizes shared by the IR op correctness tests, with example models.
COMMON_HIDDEN_SIZES = [
    2048,  # Llama 3.2 1B, Qwen 3 MoE 30B-A3B, Gemma 3n
    4096,  # Llama 3 8B, Qwen 3 8B
    5120,  # Llama 4 Scout 17B-16E
    7168,  # DeepSeek V3
    8192,  # Llama 3 70B
]
|
||||
|
||||
|
||||
def clone_args(args: tuple) -> tuple:
    """Return *args* with each tensor replaced by a clone of itself.

    Entries that are not tensors are kept as the same object.
    """

    def copy_if_tensor(value):
        if isinstance(value, torch.Tensor):
            return value.clone()
        return value

    return tuple(map(copy_if_tensor, args))
|
||||
|
||||
|
||||
def supported_providers(op: IrOp) -> list[str]:
    """List the non-native providers of *op* whose ``supported`` flag is set.

    Names are returned in the iteration order of ``op.impls``.
    """
    names = []
    for name, impl in op.impls.items():
        if name == "native" or not impl.supported:
            continue
        names.append(name)
    return names
|
||||
|
||||
|
||||
def assert_close(op: "IrOp", actual, expected):
    """Assert that *actual* matches *expected* using *op*'s tolerances.

    Tensors are compared with ``torch.testing.assert_close`` using
    ``op.get_tolerance(actual.dtype)``. Tuples and lists are compared element
    by element, recursively, and must have the same length. Anything else is
    compared with ``==``.

    Raises:
        AssertionError: On any mismatch. For tensor mismatches the message
            ends with a hint on how to adjust the op's tolerance.
    """
    if isinstance(actual, torch.Tensor):
        tol = op.get_tolerance(actual.dtype)
        try:
            torch.testing.assert_close(actual, expected, **tol)
        except AssertionError as e:
            # `from None` drops the chained traceback; the original message
            # is already embedded in the new one.
            raise AssertionError(
                f"{e}\n\nTo adjust tolerance, use:\n"
                f"  ir.ops.{op.name}.override_tolerance("
                f"{actual.dtype}, atol=..., rtol=...)"
            ) from None
    elif isinstance(actual, (tuple, list)):
        # zip() stops at the shorter input, so a missing or extra output
        # would otherwise go unnoticed.
        if not isinstance(expected, (tuple, list)):
            raise AssertionError(
                f"Output structure mismatch for op '{op.name}': got a "
                f"{type(actual).__name__}, expected {type(expected).__name__}"
            )
        if len(actual) != len(expected):
            raise AssertionError(
                f"Output count mismatch for op '{op.name}': got "
                f"{len(actual)} values, expected {len(expected)}"
            )
        for a, ex in zip(actual, expected):
            assert_close(op, a, ex)
    else:
        assert actual == expected
|
||||
@@ -495,3 +495,68 @@ def test_uuid_and_oot(tmp_path: Path):
|
||||
assert uuid2 == uuid
|
||||
assert uuid2 != uuid1
|
||||
del _custom_mm.impls["impl_mm_oot"]
|
||||
|
||||
|
||||
def _test_native(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y
|
||||
|
||||
|
||||
def _make_op_with_generator(name: str = "_ig_test"):
    """Create an IrOp called *name* that has an input generator registered.

    The generator takes ``n`` (default 4) and returns two random tensors
    created with shape ``(n, 3)``.
    """
    op = IrOp(name, _test_native)

    def _gen(n: int = 4):
        x = torch.randn(n, 3)
        y = torch.randn(n, 3)
        return x, y

    # Explicit call instead of decorator syntax; same registration.
    op.register_input_generator(_gen)
    return op
|
||||
|
||||
|
||||
def _test_native_single(x: torch.Tensor) -> torch.Tensor:
    """Single-input native implementation used by the tests; returns *x* as is."""
    return x
|
||||
|
||||
|
||||
class TestInputGenerator:
    """Registration and invocation of per-op input generators."""

    def test_no_input_generator_by_default(self):
        """A freshly created op reports no generator."""
        fresh_op = IrOp("_ig_test_no_gen", _test_native_single)
        assert not fresh_op.has_input_generator

    def test_register_input_generator(self):
        """Registering a generator is reflected by has_input_generator."""
        registered_op = _make_op_with_generator("_ig_test_reg")
        assert registered_op.has_input_generator

    def test_generate_inputs_returns_tuple(self):
        """Keyword arguments reach the generator and its tuple is returned."""
        generated = _make_op_with_generator("_ig_test_tuple").generate_inputs(n=2)
        assert isinstance(generated, tuple)
        assert len(generated) == 2
        first, second = generated
        assert first.shape == (2, 3)
        assert second.shape == (2, 3)

    def test_generate_inputs_default_kwargs(self):
        """Without arguments the generator's own default (n=4) applies."""
        generated = _make_op_with_generator("_ig_test_default").generate_inputs()
        assert generated[0].shape == (4, 3)

    def test_generate_inputs_without_registration_raises(self):
        """Calling generate_inputs with nothing registered is an error."""
        bare_op = IrOp("_ig_test_no_gen_raises", _test_native_single)
        with pytest.raises(RuntimeError, match="No input generator"):
            bare_op.generate_inputs()
|
||||
|
||||
|
||||
class TestTolerance:
    """Default tolerance lookup and per-op overrides."""

    def test_override_and_get_tolerance(self):
        """An override replaces one dtype's tolerance and leaves others alone."""
        op = IrOp("_tol_test", _test_native)

        default_fp32 = {"atol": 1e-5, "rtol": 1.3e-6}
        assert op.get_tolerance(torch.float32) == default_fp32

        op.override_tolerance(torch.float32, atol=0.1, rtol=0.2)
        overridden_fp32 = {"atol": 0.1, "rtol": 0.2}
        default_fp16 = {"atol": 1e-3, "rtol": 1e-3}
        assert op.get_tolerance(torch.float32) == overridden_fp32
        assert op.get_tolerance(torch.float16) == default_fp16

    def test_get_tolerance_raises_for_unknown_dtype(self):
        """A dtype with neither a default nor an override raises ValueError."""
        op = IrOp("_tol_test_unknown", _test_native)
        with pytest.raises(ValueError, match="No tolerance defined"):
            op.get_tolerance(torch.complex64)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Meta-tests for vLLM IR op infrastructure.
|
||||
|
||||
Ensures all registered ops have input generators defined.
|
||||
Per-op correctness tests live alongside their op definitions
|
||||
(e.g. tests/kernels/ir/test_layernorm.py).
|
||||
"""
|
||||
|
||||
import vllm.kernels # noqa: F401 — registers provider implementations
|
||||
from vllm.ir.op import IrOp
|
||||
|
||||
|
||||
def test_all_ops_have_input_generator():
    """Every op in IrOp.registry must have an input generator registered."""
    missing = []
    for name, op in IrOp.registry.items():
        if not op.has_input_generator:
            missing.append(name)

    assert not missing, (
        f"IR ops without input generators: {missing}. "
        f"Register one with @ir.ops.<name>.register_input_generator"
    )
|
||||
@@ -5,17 +5,17 @@ import torch
|
||||
|
||||
# This registers op implementations
|
||||
import vllm.kernels # noqa: F401
|
||||
from tests.ir.ir_test_utils import (
|
||||
COMMON_HIDDEN_SIZES,
|
||||
NUM_TOKENS,
|
||||
assert_close,
|
||||
clone_args,
|
||||
supported_providers,
|
||||
)
|
||||
from tests.kernels.allclose_default import get_default_rtol
|
||||
from vllm import ir
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype):
|
||||
x = torch.randn(n_tokens, hidden_size, dtype=dtype)
|
||||
weight = torch.rand(hidden_size, dtype=dtype)
|
||||
return x, weight
|
||||
|
||||
|
||||
rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn
|
||||
|
||||
|
||||
@@ -40,8 +40,8 @@ def test_rms_norm_registration():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
@pytest.mark.parametrize("n_tokens", [1, 8, 17])
|
||||
@pytest.mark.parametrize("hidden_size", [16, 4096, 8192])
|
||||
@pytest.mark.parametrize("n_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", COMMON_HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("epsilon", [1e-6, 1e-5])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
|
||||
@@ -53,7 +53,9 @@ class TestRMSNorm:
|
||||
torch.set_default_device(current_platform.device_type)
|
||||
|
||||
def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
|
||||
x, weight = rms_norm_inputs(4, 8, dtype)
|
||||
x, weight, epsilon = ir.ops.rms_norm.generate_inputs(
|
||||
num_tokens=4, hidden_size=8, dtype=dtype, epsilon=epsilon
|
||||
)
|
||||
out = rms_norm_native(x, weight, epsilon=epsilon)
|
||||
|
||||
# Check shape, dtype, device
|
||||
@@ -71,59 +73,59 @@ class TestRMSNorm:
|
||||
out4 = rms_norm_native(x, None, epsilon=epsilon)
|
||||
torch.testing.assert_close(out3, out4)
|
||||
|
||||
@pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"])
|
||||
@pytest.mark.parametrize("provider", supported_providers(ir.ops.rms_norm))
|
||||
def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
|
||||
impl = ir.ops.rms_norm.impls[provider]
|
||||
if not impl.supported:
|
||||
pytest.skip(f"{provider} impl not supported on this platform")
|
||||
|
||||
x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
|
||||
args = (x, weight, epsilon, None)
|
||||
|
||||
assert impl.supported
|
||||
|
||||
if provider == "aiter" and dtype not in [torch.float16, torch.bfloat16]:
|
||||
assert not impl.supports_args(*args)
|
||||
return
|
||||
|
||||
assert impl.supports_args(*args)
|
||||
|
||||
out_impl = impl.impl_fn(*args)
|
||||
out_native = rms_norm_native(*args)
|
||||
|
||||
torch.testing.assert_close(
|
||||
out_impl, out_native, rtol=get_default_rtol(out_impl), atol=1e-3
|
||||
x, weight, eps = ir.ops.rms_norm.generate_inputs(
|
||||
num_tokens=n_tokens, hidden_size=hidden_size, dtype=dtype, epsilon=epsilon
|
||||
)
|
||||
args = (x, weight, eps)
|
||||
|
||||
if not impl.supports_args(*args):
|
||||
pytest.skip(f"{provider} does not support args")
|
||||
|
||||
ref_output = rms_norm_native(*clone_args(args))
|
||||
output = impl.impl_fn(*clone_args(args))
|
||||
assert_close(ir.ops.rms_norm, output, ref_output)
|
||||
|
||||
# check that dispatched call matches direct call
|
||||
with ir.ops.rms_norm.set_priority([provider, "native"]):
|
||||
out_impl2 = ir.ops.rms_norm(*args)
|
||||
|
||||
# exact match
|
||||
torch.testing.assert_close(out_impl2, out_impl, rtol=0.0, atol=0.0)
|
||||
out_dispatched = ir.ops.rms_norm(*args)
|
||||
out_direct = impl.impl_fn(*args)
|
||||
torch.testing.assert_close(out_dispatched, out_direct, rtol=0.0, atol=0.0)
|
||||
|
||||
# none of these support variance_size override
|
||||
assert not impl.supports_args(x, weight, epsilon, 4)
|
||||
assert not impl.supports_args(x, weight, epsilon, variance_size=4)
|
||||
assert not impl.supports_args(x, weight, eps, 4)
|
||||
assert not impl.supports_args(x, weight, eps, variance_size=4)
|
||||
|
||||
# test weight=None behavior
|
||||
out_impl_no_weight = impl.impl_fn(x, None, epsilon)
|
||||
out_impl_unit_weight = impl.impl_fn(x, torch.ones_like(weight), epsilon)
|
||||
torch.testing.assert_close(
|
||||
out_impl_no_weight,
|
||||
out_impl_unit_weight,
|
||||
rtol=get_default_rtol(out_impl_no_weight),
|
||||
atol=2e-4,
|
||||
)
|
||||
out_no_weight = impl.impl_fn(x, None, eps)
|
||||
out_unit_weight = impl.impl_fn(x, torch.ones_like(weight), eps)
|
||||
assert_close(ir.ops.rms_norm, out_no_weight, out_unit_weight)
|
||||
|
||||
@pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"])
|
||||
def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider):
|
||||
if not ir.ops.rms_norm.impls[provider].supported:
|
||||
pytest.skip(f"{provider} impl not supported on this platform")
|
||||
|
||||
x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
|
||||
args = (x, weight, epsilon, None)
|
||||
args = ir.ops.rms_norm.generate_inputs(
|
||||
num_tokens=n_tokens, hidden_size=hidden_size, dtype=dtype, epsilon=epsilon
|
||||
)
|
||||
|
||||
# When checking the torch op, we have to set priority and use dispatch
|
||||
with ir.ops.rms_norm.set_priority([provider, "native"]):
|
||||
torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.is_rocm(),
    reason="aiter is only supported on ROCm",
)
def test_aiter_rejects_unsupported_dtypes():
    """The aiter rms_norm impl must reject float32 and float64 inputs."""
    torch.set_default_device(current_platform.device_type)
    rms_norm_op = ir.ops.rms_norm
    impl = rms_norm_op.impls["aiter"]

    for dtype in (torch.float32, torch.float64):
        args = rms_norm_op.generate_inputs(
            num_tokens=8, hidden_size=4096, dtype=dtype, epsilon=1e-5
        )
        accepted = impl.supports_args(*args)
        assert not accepted, f"aiter should reject dtype={dtype}"
|
||||
|
||||
@@ -9,10 +9,13 @@ from typing import Any, ClassVar, overload
|
||||
import torch
|
||||
from torch.library import Library, infer_schema
|
||||
|
||||
from vllm.ir.tolerances import DEFAULT_TOLERANCES, ToleranceSpec
|
||||
from vllm.ir.util import hash_source, weak_cache
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import lazy, tensors_str_no_data
|
||||
|
||||
InputGenerator = Callable[..., tuple[Any, ...]]
|
||||
|
||||
vllm_ir_lib = Library("vllm_ir", "FRAGMENT")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -113,6 +116,8 @@ class IrOp:
|
||||
self.impls: dict[str, IrOpImpl] = {}
|
||||
self._priority_impls: list[IrOpImpl] = []
|
||||
self._schema_str = infer_schema(native_impl, mutates_args=[])
|
||||
self._input_generator: InputGenerator | None = None
|
||||
self._tolerance_overrides: ToleranceSpec = {}
|
||||
|
||||
# native implementation
|
||||
self.impls["native"] = IrOpImpl(
|
||||
@@ -316,6 +321,38 @@ class IrOp:
|
||||
def supported_providers(self) -> list[str]:
|
||||
return [p.provider for p in self.impls.values() if p.supported]
|
||||
|
||||
@property
def has_input_generator(self) -> bool:
    """Whether an input generator is currently stored on this op."""
    return self._input_generator is not None
|
||||
|
||||
def register_input_generator(self, fn: InputGenerator) -> InputGenerator:
    """Store *fn* as this op's input generator and return it unchanged.

    Returning *fn* allows use as a decorator. Any previously stored
    generator is replaced.
    """
    self._input_generator = fn
    return fn
|
||||
|
||||
def generate_inputs(self, **kwargs: Any) -> tuple[Any, ...]:
    """Call the registered input generator with *kwargs* and return its result.

    Raises:
        RuntimeError: If no input generator has been registered for this op.
    """
    generator = self._input_generator
    if generator is not None:
        return generator(**kwargs)

    raise RuntimeError(
        f"No input generator registered for op '{self.name}'. "
        f"Use @ir.ops.{self.name}.register_input_generator"
    )
|
||||
|
||||
def override_tolerance(
    self, dtype: torch.dtype, *, atol: float, rtol: float
) -> None:
    """Set this op's tolerance for *dtype*, replacing any earlier override.

    get_tolerance consults these overrides before DEFAULT_TOLERANCES.
    """
    self._tolerance_overrides[dtype] = {"atol": atol, "rtol": rtol}
|
||||
|
||||
def get_tolerance(self, dtype: torch.dtype) -> dict[str, float]:
    """Return the ``{"atol": ..., "rtol": ...}`` tolerance for *dtype*.

    A per-op override takes precedence over DEFAULT_TOLERANCES. The result
    is a fresh dict, so callers may modify it without changing this op's
    overrides or the module-wide defaults shared by all ops.

    Raises:
        ValueError: If *dtype* has neither an override nor a default.
    """
    if dtype in self._tolerance_overrides:
        return dict(self._tolerance_overrides[dtype])
    if dtype in DEFAULT_TOLERANCES:
        return dict(DEFAULT_TOLERANCES[dtype])
    raise ValueError(
        f"No tolerance defined for dtype {dtype} in op '{self.name}'. "
        f"Use op.override_tolerance({dtype}, atol=..., rtol=...) "
        f"or add {dtype} to DEFAULT_TOLERANCES."
    )
|
||||
|
||||
|
||||
class IrOpImpl:
|
||||
def __init__(
|
||||
|
||||
@@ -19,3 +19,18 @@ def rms_norm(
|
||||
if weight is not None:
|
||||
x = x.to(weight.dtype) * weight
|
||||
return x.to(orig_dtype)
|
||||
|
||||
|
||||
@rms_norm.register_input_generator
def _rms_norm_input_generator(
    num_tokens: int, hidden_size: int, dtype: torch.dtype, epsilon: float = 1e-5
) -> tuple:
    """Build random ``(x, weight, epsilon)`` inputs for rms_norm.

    ``x`` is created with shape ``(num_tokens, hidden_size)`` and ``weight``
    with shape ``(hidden_size,)``, both drawn with ``torch.randn`` in
    *dtype*; *epsilon* is passed through unchanged.
    """
    # x is drawn before weight; keep this order so seeded runs stay
    # reproducible.
    activations = torch.randn((num_tokens, hidden_size), dtype=dtype)
    scale = torch.randn((hidden_size,), dtype=dtype)
    return activations, scale, epsilon
|
||||
|
||||
|
||||
# Reductions in rms_norm accumulate rounding error at large shapes
|
||||
# (e.g. 32768x16384), causing a few elements out of millions to exceed
|
||||
# the default float16 tolerance.
|
||||
rms_norm.override_tolerance(torch.float16, atol=1e-2, rtol=2e-3)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
ToleranceSpec = dict[torch.dtype, dict[str, float]]
|
||||
|
||||
# Default tolerances for comparing IR op implementations against native.
|
||||
# These are intentionally conservative (permissive) to avoid false failures
|
||||
# across different hardware and kernel implementations. Ops that need tighter
|
||||
# or looser bounds should use override_tolerance.
|
||||
# "Machine epsilon" below is 2**-(mantissa bits), i.e. the gap between 1.0
# and the next representable value.
DEFAULT_TOLERANCES: ToleranceSpec = {
    # 52-bit mantissa; machine epsilon ~2.2e-16
    torch.float64: {"atol": 1e-8, "rtol": 1e-8},
    # 23-bit mantissa; machine epsilon ~1.2e-7.
    # Values from PyTorch test_transformers.py reference defaults.
    torch.float32: {"atol": 1e-5, "rtol": 1.3e-6},
    # 10-bit mantissa; machine epsilon ~9.8e-4.
    # Standard tolerance used across vLLM kernel tests.
    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
    # 7-bit mantissa; machine epsilon ~7.8e-3.
    # Wider rtol than float16 to account for the coarser mantissa.
    torch.bfloat16: {"atol": 1e-3, "rtol": 1.6e-2},
    # 3-bit mantissa; machine epsilon ~1.25e-1.
    # Derived from vLLM fp8 kernel tests (merge_attn_states, silu_mul_fp8).
    torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
    # 2-bit mantissa; machine epsilon ~2.5e-1.
    # Wider than e4m3fn due to the smaller mantissa.
    torch.float8_e5m2: {"atol": 2e-1, "rtol": 2e-1},
    # 1-bit mantissa; machine epsilon ~5e-1. Packed pair format (x2).
    # Derived from vLLM fp4 tests (test_silu_mul_nvfp4_quant: atol=3e-1).
    torch.float4_e2m1fn_x2: {"atol": 3e-1, "rtol": 3e-1},
    # Integer quantized; off-by-one from rounding is expected.
    # rtol=0 because relative error is meaningless for small integers.
    torch.int8: {"atol": 1, "rtol": 0},
}
|
||||
Reference in New Issue
Block a user