[ROCM] [FEAT] Integrate Aiter hipBLASLt GEMM online tuning (#40426)

Signed-off-by: hanlin12 <hanlin12@amd.com>
Signed-off-by: Han Lin <hanlin12@amd.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
Han Lin
2026-06-05 14:45:36 +08:00
committed by GitHub
parent c505cd93ef
commit 165b7864d0
6 changed files with 591 additions and 0 deletions
@@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import csv
import importlib
import importlib.util
import os
import pytest
import torch
from tests.utils import TestFP8Layer
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
FP8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTokenSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
# Probe for the optional `aiter` package without importing it.
aiter_available = importlib.util.find_spec("aiter") is not None

# Module-wide marks: skip every test here unless the platform is ROCm, reports
# FP8 support, and `aiter` is importable; also request the
# `default_vllm_config` fixture for every test.
pytestmark = [
    pytest.mark.skipif(
        not current_platform.is_rocm()
        or not current_platform.supports_fp8()
        or not aiter_available,
        reason="Requires ROCm + FP8 support + aiter",
    ),
    pytest.mark.usefixtures("default_vllm_config"),
]
@pytest.fixture
def enable_hipb_mm_kernel(monkeypatch: pytest.MonkeyPatch):
    """Turn on the three AITER flags that gate the hipb_mm kernel for one test.

    The flags are set inside a scoped ``monkeypatch.context()`` so they are
    removed *before* the final ``refresh_env_variables()`` call. Refreshing
    while the patched values are still in the environment (which is what
    happens if the shared ``monkeypatch`` fixture is used directly, since it
    is torn down after this fixture) would leave the cached
    ``rocm_aiter_ops`` flags enabled for whatever test runs next.
    """
    with monkeypatch.context() as scoped_env:
        scoped_env.setenv("VLLM_ROCM_USE_AITER", "1")
        scoped_env.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1")
        scoped_env.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1")
        # Make the cached class-level flags pick up the patched environment.
        rocm_aiter_ops.refresh_env_variables()
        yield
    # The scoped patches are undone at this point; re-read the restored
    # environment so the cached flags go back to their pre-test values.
    rocm_aiter_ops.refresh_env_variables()
def _make_config(
    *,
    weight_quant_key=kFp8StaticChannelSym,
    out_dtype: torch.dtype = torch.bfloat16,
    weight_shape: tuple[int, int] = (512, 4096),
) -> FP8ScaledMMLinearLayerConfig:
    """Build an FP8 scaled-MM layer config for the tests in this module.

    The activation quant key is always ``kFp8DynamicTokenSym`` and the input
    dtype is always bfloat16; the weight quant key, output dtype and weight
    shape can be overridden by keyword.
    """
    config_fields = {
        "weight_quant_key": weight_quant_key,
        "activation_quant_key": kFp8DynamicTokenSym,
        "weight_shape": weight_shape,
        "input_dtype": torch.bfloat16,
        "out_dtype": out_dtype,
    }
    return FP8ScaledMMLinearLayerConfig(**config_fields)
def _find_csv_row(path: str, m: int, n: int, k: int) -> dict | None:
if not os.path.exists(path):
return None
with open(path, newline="") as f:
reader = csv.DictReader(f, skipinitialspace=True)
for row in reader:
try:
if (
int(row.get("m", -1)) == m
and int(row.get("n", -1)) == n
and int(row.get("k", -1)) == k
):
return dict(row)
except (TypeError, ValueError):
continue
return None
def _skip_if_no_hipb_mm_solution(exc: RuntimeError) -> None:
if "hipblasLtMatmulAlgoGetHeuristic found 0 valid solutions" in str(exc):
pytest.skip(
"hipb_mm bpreshuffle path has no valid hipBLASLt solution on "
"this ROCm stack."
)
def _check_bpreshuffle_runtime_support(weight_shape: tuple[int, int], num_tokens: int):
    """Probe whether ``aiter.hipb_mm`` with ``bpreshuffle=True`` runs for this shape.

    Random bf16 activations of shape ``(num_tokens, weight_shape[1])`` and
    weights of shape ``weight_shape`` are quantized with
    ``aiter.pertoken_quant`` and fed through ``aiter.hipb_mm``. If the call
    fails with the "0 valid solutions" message the test is skipped; any other
    ``RuntimeError`` propagates.
    """
    import aiter
    from aiter.ops.shuffle import shuffle_weight

    activations = torch.randn(
        num_tokens, weight_shape[1], dtype=torch.bfloat16, device="cuda"
    )
    weights = torch.randn(weight_shape, dtype=torch.bfloat16, device="cuda")

    aiter.hipb_create_extension()

    fp8_dtype = current_platform.fp8_dtype()
    act_q, act_scale = aiter.pertoken_quant(activations, quant_dtype=fp8_dtype)
    wgt_q, wgt_scale = aiter.pertoken_quant(weights, quant_dtype=fp8_dtype)

    # Same operand layout the kernel under test stores after weight
    # processing: shuffled weight viewed through `.t()`, transposed scale.
    preshuffled_weight = shuffle_weight(wgt_q, layout=(16, 16)).t()
    transposed_scale = wgt_scale.t().contiguous()
    try:
        aiter.hipb_mm(
            act_q,
            preshuffled_weight,
            solution_index=-1,
            out_dtype=torch.bfloat16,
            scaleA=act_scale,
            scaleB=transposed_scale,
            scaleOut=None,
            bpreshuffle=True,
        )
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise
def test_hipb_mm_kernel_requires_hipbmm_flag(monkeypatch: pytest.MonkeyPatch):
    """The kernel must report itself unsupported when the gating flags are off."""
    # The kernel rejects when `rocm_aiter_ops.is_linear_hipbmm_enabled()` is
    # False. Disabling AITER_LINEAR and leaving the HIPBMM flag unset
    # exercises that rejection branch.
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0")
    monkeypatch.delenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", raising=False)
    try:
        rocm_aiter_ops.refresh_env_variables()
        is_supported, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.is_supported()
        assert not is_supported
        assert reason == (
            "requires setting `VLLM_ROCM_USE_AITER=1`, "
            "`VLLM_ROCM_USE_AITER_LINEAR=1`, "
            "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`."
        )
    finally:
        # Restore the environment first, then re-read it, so the cached
        # class-level flags do not stay at the patched values for later tests.
        monkeypatch.undo()
        rocm_aiter_ops.refresh_env_variables()
def test_hipb_mm_flag_enables_hip_online_tuning(
    monkeypatch: pytest.MonkeyPatch,
):
    """Re-importing the ROCm platform module with all flags on sets HIP_ONLINE_TUNING."""
    import vllm.envs as envs_module
    import vllm.platforms.rocm as rocm_module

    # The rocm.py gate requires all three AITER flags (and MI3xx) to auto-set
    # HIP_ONLINE_TUNING.
    gate_flags = (
        "VLLM_ROCM_USE_AITER",
        "VLLM_ROCM_USE_AITER_LINEAR",
        "VLLM_ROCM_USE_AITER_LINEAR_HIPBMM",
    )
    for flag_name in gate_flags:
        monkeypatch.setenv(flag_name, "1")

    def _reload_env_dependent_modules() -> None:
        # envs first, then the platform module that reads it at import time.
        importlib.reload(envs_module)
        importlib.reload(rocm_module)

    try:
        _reload_env_dependent_modules()
        for flag_name in gate_flags:
            assert getattr(envs_module, flag_name)
        assert os.environ.get("HIP_ONLINE_TUNING") == "1"
    finally:
        # Put the environment back, drop the variable the reload wrote, and
        # reload again so module state matches the restored environment.
        monkeypatch.undo()
        os.environ.pop("HIP_ONLINE_TUNING", None)
        _reload_env_dependent_modules()
        rocm_aiter_ops.refresh_env_variables()
def test_hipb_mm_kernel_can_implement_success(enable_hipb_mm_kernel):
    """The default test config is accepted and no rejection reason is given."""
    default_config = _make_config()
    accepted, rejection = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement(
        default_config
    )
    assert accepted
    assert rejection is None
@pytest.mark.parametrize(
    ("config", "expected_reason"),
    [
        # Per-tensor weight scales instead of per-channel.
        (
            _make_config(weight_quant_key=kFp8StaticTensorSym),
            "requires per token activation scales and per channel weight scales.",
        ),
        # Output dtype other than bfloat16.
        (
            _make_config(out_dtype=torch.float16),
            "requires bfloat16 output dtype.",
        ),
        # N below 16 and K not a multiple of 16.
        (
            _make_config(weight_shape=(8, 4090)),
            "requires N >= 16 and both N and K divisible by 16, "
            "received N=8 and K=4090.",
        ),
    ],
)
def test_hipb_mm_kernel_can_implement_rejects_unsupported_configs(
    enable_hipb_mm_kernel,
    config: FP8ScaledMMLinearLayerConfig,
    expected_reason: str,
):
    """Each unsupported config is rejected with its exact reason string."""
    kernel_cls = AiterHipbMMPerTokenFp8ScaledMMLinearKernel
    accepted, rejection = kernel_cls.can_implement(config)
    assert not accepted
    assert rejection == expected_reason
def test_hipb_mm_kernel_process_weights_after_loading_shuffles_weights(
    enable_hipb_mm_kernel,
):
    """Weight processing stores the shuffled weight view and transposed scale."""
    weight_shape = (512, 4096)
    kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel(
        _make_config(weight_shape=weight_shape),
        layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"),
    )

    def _frozen(tensor: torch.Tensor) -> torch.nn.Parameter:
        return torch.nn.Parameter(tensor, requires_grad=False)

    layer = torch.nn.Module()
    fp8_weight = torch.rand(weight_shape, device="cuda").to(
        current_platform.fp8_dtype()
    )
    # The layer holds the transposed view of the fp8 weight before processing.
    layer.weight = _frozen(fp8_weight.t())
    layer.weight_scale = _frozen(
        torch.rand((weight_shape[0], 1), dtype=torch.float32, device="cuda")
    )
    layer.input_scale = None
    layer.input_scale_ub = None

    weight_before = layer.weight.detach().clone()
    scale_before = layer.weight_scale.detach().clone()

    kernel.process_weights_after_loading(layer)

    # After processing, the stored weight is the shuffled tensor seen through
    # a trailing `.t()` view, and the stored weight scale is the
    # transposed-contiguous form of the original scale.
    reshuffled = rocm_aiter_ops.shuffle_weight(weight_before.t().contiguous())
    torch.testing.assert_close(layer.weight, reshuffled.t())
    torch.testing.assert_close(layer.weight_scale, scale_before.t().contiguous())
def test_hipb_mm_kernel_forward_matches_raw_aiter_hipb_mm(enable_hipb_mm_kernel):
    """The layer's forward output equals a direct ``aiter.hipb_mm`` call."""
    import aiter

    weight_shape = (512, 4096)
    _check_bpreshuffle_runtime_support(weight_shape, num_tokens=32)

    layer = TestFP8Layer(
        weight_shape=weight_shape,
        activation_quant_key=kFp8DynamicTokenSym,
        weight_quant_key=kFp8StaticChannelSym,
        input_dtype=torch.bfloat16,
        out_dtype=torch.bfloat16,
        device=torch.device("cuda"),
        force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
    )

    # hipb_mm uses a transposed-result GEMM internally, so the flattened token
    # count becomes the effective N dimension passed into hipBLASLt. Keep it
    # aligned to avoid heuristic failures for tiny N.
    batch, seq_len = 2, 16
    x = torch.randn(
        batch, seq_len, weight_shape[1], dtype=torch.bfloat16, device="cuda"
    )
    bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device="cuda")

    try:
        actual = layer(x, bias)
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise

    # Quantize the flattened activations with the layer's own quantizer so
    # both paths see the same fp8 inputs.
    flat_x = x.view(-1, x.shape[-1])
    x_q, x_scale = layer.kernel.quant_fp8(
        flat_x,
        layer.input_scale,
        layer.input_scale_ub,
    )

    try:
        # process_weights_after_loading already applies the trailing `.t()` on
        # the shuffled weight and the `.t().contiguous()` on the weight scale,
        # so the raw aiter call uses them directly.
        raw_result = aiter.hipb_mm(
            x_q,
            layer.weight,
            solution_index=-1,
            bias=bias,
            out_dtype=torch.bfloat16,
            scaleA=x_scale,
            scaleB=layer.weight_scale,
            scaleOut=None,
            bpreshuffle=True,
        )
        expected = raw_result.view(*actual.shape)
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise

    assert isinstance(layer.kernel, AiterHipbMMPerTokenFp8ScaledMMLinearKernel)
    assert actual.shape == (batch, seq_len, weight_shape[0])
    torch.testing.assert_close(actual, expected)
def test_hipb_mm_kernel_forward_accuracy(enable_hipb_mm_kernel):
    """Kernel output should match a dequantized fp32 reference within
    fp8 per-token / per-channel quantization noise."""
    weight_shape = (512, 4096)  # (N, K)
    num_tokens = 32
    _check_bpreshuffle_runtime_support(weight_shape, num_tokens=num_tokens)

    fp8_dtype = current_platform.fp8_dtype()
    fp8_max = torch.finfo(fp8_dtype).max
    device = torch.device("cuda")

    def _quantize_per_row(values: torch.Tensor):
        """Return (fp8 tensor, fp32 row scale, fp32 dequantized tensor)."""
        row_amax = values.abs().amax(dim=1, keepdim=True).to(torch.float32)
        row_scale = (row_amax / fp8_max).clamp(min=1e-12)
        quantized = (
            (values.to(torch.float32) / row_scale)
            .clamp(-fp8_max, fp8_max)
            .to(fp8_dtype)
        )
        return quantized, row_scale, quantized.to(torch.float32) * row_scale

    # Build a bf16 weight and quantize per output channel (one scale per row).
    w_bf16 = torch.randn(weight_shape, dtype=torch.bfloat16, device=device)
    w_fp8, w_scale, w_dequant = _quantize_per_row(w_bf16)

    bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device=device)

    # Pre-`process_weights_after_loading` convention: weight stored as the
    # `[K, N]` view of the fp8 tensor.
    layer = torch.nn.Module()
    layer.weight = torch.nn.Parameter(w_fp8.t(), requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(w_scale, requires_grad=False)
    layer.input_scale = None
    layer.input_scale_ub = None

    kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel(
        _make_config(weight_shape=weight_shape),
        layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"),
    )
    kernel.process_weights_after_loading(layer)

    x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device=device)
    try:
        out = kernel.apply_weights(layer, x, bias)
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise

    # Reference: quantize x per-token with the same amax/fp8_max recipe, then
    # run the matmul in fp32 against the dequantized weight. This isolates
    # plumbing / reduction bugs from inherent fp8 quantization noise.
    _, _, x_dequant = _quantize_per_row(x)
    reference_fp32 = x_dequant @ w_dequant.t() + bias.to(torch.float32)
    expected = reference_fp32.to(torch.bfloat16)

    assert out.shape == (num_tokens, weight_shape[0])
    # K=4096 fp8 reduction leaves room for accumulation order drift and
    # catastrophic cancellation on near-zero outputs; tolerances are loose
    # enough to absorb that but tight enough to catch wrong layouts, missing
    # bias, swapped scales, etc.
    torch.testing.assert_close(out, expected, atol=5.0, rtol=0.1)
def test_hipb_mm_kernel_online_tuning_writes_csv(
    enable_hipb_mm_kernel,
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
):
    """With HIP_ONLINE_TUNING=1, a forward pass leaves a tuning CSV in the cwd."""
    weight_shape = (256, 4096)
    num_tokens = 16
    tuning_csv = tmp_path / "hip_online_tuning_res.csv"

    _check_bpreshuffle_runtime_support(weight_shape, num_tokens=num_tokens)

    # Run from the temporary directory, where the CSV is expected to appear.
    monkeypatch.setenv("HIP_ONLINE_TUNING", "1")
    monkeypatch.chdir(tmp_path)

    layer = TestFP8Layer(
        weight_shape=weight_shape,
        activation_quant_key=kFp8DynamicTokenSym,
        weight_quant_key=kFp8StaticChannelSym,
        input_dtype=torch.bfloat16,
        out_dtype=torch.bfloat16,
        device=torch.device("cuda"),
        force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
    )

    # The effective heuristic N dimension is the flattened token count.
    x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device="cuda")

    try:
        out = layer(x)
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise

    torch.accelerator.synchronize()

    assert out.shape == (num_tokens, weight_shape[0])
    assert tuning_csv.exists()

    # hipb_mm records the internal GEMM dimensions used by hipBLASLt after its
    # transposed-result transformation.
    recorded = _find_csv_row(
        str(tuning_csv),
        m=weight_shape[0],
        n=x.shape[0],
        k=weight_shape[1],
    )
    assert recorded is not None
+73
View File
@@ -27,6 +27,16 @@ except ImportError:
# on ROCm the fp8_dtype always calls is_fp8_fnuz
# which is a host op, so we cache it once here.
FP8_DTYPE = current_platform.fp8_dtype()
# Device indices for which `aiter.hipb_create_extension()` has already run.
_HIPB_MM_INITIALIZED_DEVICES: set[int] = set()


def _ensure_hipb_mm_extension_initialized() -> None:
    """Call ``aiter.hipb_create_extension()`` once per accelerator device index.

    The current device index is looked up on every call; the extension is
    created only the first time a given index is seen.
    """
    import aiter

    device_index = torch.accelerator.current_device_index()
    if device_index in _HIPB_MM_INITIALIZED_DEVICES:
        return
    aiter.hipb_create_extension()
    _HIPB_MM_INITIALIZED_DEVICES.add(device_index)
def is_aiter_found() -> bool:
@@ -625,6 +635,43 @@ def _rocm_aiter_preshuffled_per_token_w8a8_gemm_fake(
return torch.empty(m, n, dtype=output_dtype, device=A.device)
def _rocm_aiter_hipb_mm_fp8_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Run ``aiter.hipb_mm`` with ``bpreshuffle=True`` on scaled operands.

    ``As`` and ``Bs`` are forwarded as ``scaleA`` and ``scaleB``; no output
    scale is applied. ``B`` and ``Bs`` are passed through untouched, so the
    caller must supply them in the layout ``hipb_mm`` expects.
    """
    import aiter

    # The hipb extension has to exist on the current device before the GEMM.
    _ensure_hipb_mm_extension_initialized()

    # solution_index=-1: presumably lets hipBLASLt choose the solution itself
    # -- TODO confirm against the aiter documentation.
    gemm_options = {
        "solution_index": -1,
        "bias": bias,
        "out_dtype": output_dtype,
        "scaleA": As,
        "scaleB": Bs,
        "scaleOut": None,
        "bpreshuffle": True,
    }
    return aiter.hipb_mm(A, B, **gemm_options)
def _rocm_aiter_hipb_mm_fp8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[1]
return torch.empty(m, n, dtype=output_dtype, device=A.device)
def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
A: torch.Tensor,
B: torch.Tensor,
@@ -1308,6 +1355,7 @@ class rocm_aiter_ops:
# TODO: Consolidate under _LINEAR_ENABLED
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
_FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
_LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
# TODO: Consolidate under _LINEAR_ENABLED
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
@@ -1340,6 +1388,7 @@ class rocm_aiter_ops:
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
cls._LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
@@ -1512,6 +1561,13 @@ class rocm_aiter_ops:
return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950()
@classmethod
@if_aiter_supported
def is_linear_hipbmm_enabled(cls) -> bool:
    """Report whether the AITER hipb_mm linear path is enabled.

    The result is truthy only when ``is_linear_enabled()`` holds,
    ``on_mi3xx()`` holds, and the cached ``_LINEAR_HIPBMM_ENABLED`` flag is
    set. The checks run in that order and stop at the first falsy value.
    """
    from vllm.platforms.rocm import on_mi3xx

    linear_enabled = cls.is_linear_enabled()
    if not linear_enabled:
        return linear_enabled
    on_supported_gpu = on_mi3xx()
    if not on_supported_gpu:
        return on_supported_gpu
    return cls._LINEAR_HIPBMM_ENABLED
@classmethod
@if_aiter_supported
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
@@ -1668,6 +1724,12 @@ class rocm_aiter_ops:
fake_impl=_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_hipb_mm_fp8",
op_func=_rocm_aiter_hipb_mm_fp8_impl,
fake_impl=_rocm_aiter_hipb_mm_fp8_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
@@ -1858,6 +1920,17 @@ class rocm_aiter_ops:
A, B, As, Bs, bias, output_dtype
)
@staticmethod
def hipb_mm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Dispatch to the ``vllm::rocm_aiter_hipb_mm_fp8`` op.

    All arguments are forwarded positionally and unchanged.
    """
    hipb_mm_op = torch.ops.vllm.rocm_aiter_hipb_mm_fp8
    return hipb_mm_op(A, B, As, Bs, bias, output_dtype)
@staticmethod
def triton_gemm_a8w8_blockscale(
A: torch.Tensor,
+4
View File
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_LINEAR_HIPBMM: bool = False
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_AITER_MOE_DISPATCH_POLICY: int = 0
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -1117,6 +1118,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_USE_AITER_LINEAR": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1")
),
"VLLM_ROCM_USE_AITER_LINEAR_HIPBMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "False").lower() in ("true", "1")
),
# Whether to use aiter moe ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MOE": lambda: (
@@ -128,6 +128,7 @@ from vllm.model_executor.kernels.linear.scaled_mm import (
)
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterFp8BlockScaledMMKernel,
AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
AiterInt8ScaledMMLinearKernel,
AiterPerTokenFp8ScaledMMLinearKernel,
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel,
@@ -285,6 +286,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [
AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel,
AiterPerTokenFp8ScaledMMLinearKernel,
ROCmFP8ScaledMMLinearKernel,
@@ -1024,6 +1026,7 @@ __all__ = [
"FP8ScaledMMLinearLayerConfig",
"Int8ScaledMMLinearLayerConfig",
"ScaledMMLinearLayerConfig",
"AiterHipbMMPerTokenFp8ScaledMMLinearKernel",
"AiterPreshuffledPerTokenFp8ScaledMMLinearKernel",
"AiterPerTokenFp8ScaledMMLinearKernel",
"NvFp4LinearKernel",
@@ -212,6 +212,99 @@ class AiterPreshuffledPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
)
class AiterHipbMMPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-MM linear kernel backed by ``rocm_aiter_ops.hipb_mm_fp8``.

    Accepts only per-token activation scales with per-channel weight scales
    and bfloat16 output. Weights are shuffled once after loading so the GEMM
    can consume them directly.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check platform, feature flags and aiter availability.

        Returns ``(True, None)`` when usable, otherwise ``(False, reason)``.
        ``compute_capability`` is not consulted.
        """
        failure: str | None = None
        if not current_platform.is_rocm():
            failure = "requires ROCm."
        elif not rocm_aiter_ops.is_linear_hipbmm_enabled():
            failure = (
                "requires setting `VLLM_ROCM_USE_AITER=1`, "
                "`VLLM_ROCM_USE_AITER_LINEAR=1`, "
                "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`."
            )
        else:
            try:
                import aiter  # noqa: F401
            except Exception:
                failure = "requires aiter library to be installed."
            else:
                # The installed aiter may not expose hipb_mm.
                if not hasattr(aiter, "hipb_mm"):
                    failure = "requires aiter hipb_mm support."
        return failure is None, failure

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Check whether layer config ``c`` fits this kernel.

        Rejections are tested in a fixed order: missing weight shape,
        non-bfloat16 output, non per-token/per-channel scales, then N/K
        alignment. The first failing check determines the reason string.
        """
        per_token_activations = c.activation_quant_key.scale.group_shape.is_per_token()
        per_token_per_channel = (
            per_token_activations
            and c.weight_quant_key.scale.group_shape.is_per_channel()
        )

        if c.weight_shape is None:
            return False, "weight_shape is required for Aiter kernels"

        out_features, in_features = c.weight_shape

        if c.out_dtype is not torch.bfloat16:
            return False, "requires bfloat16 output dtype."

        if not per_token_per_channel:
            return (
                False,
                "requires per token activation scales and per channel weight scales.",
            )

        dims_aligned = (
            out_features >= 16 and out_features % 16 == 0 and in_features % 16 == 0
        )
        if not dims_aligned:
            return (
                False,
                "requires N >= 16 and both N and K divisible by 16, "
                f"received N={out_features} and K={in_features}.",
            )

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace the layer's weight and weight scale with GEMM-ready forms.

        The weight becomes ``shuffle_weight(w.t().contiguous()).t()``. The
        weight scale becomes ``w_s.t().contiguous()`` when it has more than
        one dimension, and is left alone otherwise.
        """
        weight_name, scale_name, *_ = self.layer_param_names
        weight, weight_scale, *_ = self._get_layer_params(layer)

        # The `.t()` on the shuffled weight is kept as a non-contiguous view —
        # materializing it with `.contiguous()` would re-arrange the bytes and
        # break the `bpreshuffle` layout.
        preshuffled = rocm_aiter_ops.shuffle_weight(weight.t().contiguous())
        weight_param = torch.nn.Parameter(preshuffled.t(), requires_grad=False)
        replace_parameter(layer, weight_name, weight_param)

        if weight_scale.ndim <= 1:
            return
        scale_param = torch.nn.Parameter(
            weight_scale.t().contiguous(), requires_grad=False
        )
        replace_parameter(layer, scale_name, scale_param)

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the hipb_mm FP8 GEMM and view the result as ``output_shape``.

        Note that ``output_shape`` is modified in place: its last entry is
        overwritten with ``B.shape[1]`` before the result is reshaped.
        """
        output_shape[-1] = B.shape[1]
        gemm_out = rocm_aiter_ops.hipb_mm_fp8(A, B, As, Bs, bias, out_dtype)
        return gemm_out.view(*output_shape)
class AiterPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
+12
View File
@@ -134,6 +134,7 @@ def _sync_hip_cuda_env_vars():
# Sync at import time - catches misconfigurations from process start.
_sync_hip_cuda_env_vars()
# AMDSMI utils
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
@@ -312,6 +313,17 @@ def on_gfx950() -> bool:
return _ON_GFX950
# Turn on hipBLASLt online tuning when the AITER hipBLASLt GEMM path is
# requested on MI3xx. This runs at import time so the variable is set early,
# before hipBLASLt initializes.
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR:
    if envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM and on_mi3xx():
        os.environ["HIP_ONLINE_TUNING"] = "1"
@cache
def use_rocm_custom_paged_attention(
qtype: torch.dtype,