Enable per_token_group_quant/_C_stable_libtorch for ROCm (#42758)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2026-06-03 01:23:28 -05:00
committed by GitHub
parent e0081ef8cf
commit 71df063c49
12 changed files with 145 additions and 109 deletions
@@ -6,6 +6,7 @@ import pytest
import torch
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
from vllm.model_executor.layers.quantization.utils.quant_utils import get_fp8_min_max
from vllm.platforms import current_platform
@@ -16,7 +17,9 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("tma_aligned", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test on CUDA/ROCm."
)
def test_per_token_group_quant_fp8(
shape, column_major: bool, tma_aligned: bool, scale_ue8m0: bool, group_size: int
):
@@ -37,7 +40,7 @@ def test_per_token_group_quant_fp8(
)
# triton ref
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
x,
group_size,
@@ -77,7 +80,8 @@ def test_per_token_group_quant_fp8(
)
@pytest.mark.parametrize("poisoned_scales", [False, True])
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
not current_platform.is_cuda_alike(),
reason="DeepGEMM not available on this platform",
)
def test_per_token_group_quant_fp8_packed(
num_tokens, hidden_dim, group_size, poisoned_scales
@@ -99,8 +103,8 @@ def test_per_token_group_quant_fp8_packed(
if poisoned_scales:
# Call the kernel with poisoned scale buffer to
# ensure padded indices are correctly zeroed.
fp8_dtype = torch.float8_e4m3fn
finfo = torch.finfo(fp8_dtype)
fp8_dtype = current_platform.fp8_dtype()
fp8_min, fp8_max = get_fp8_min_max()
out_q = torch.empty_like(x, dtype=fp8_dtype)
out_s_packed = torch.empty_strided(
(mn, k_num_packed),
@@ -115,8 +119,8 @@ def test_per_token_group_quant_fp8_packed(
out_s_packed,
group_size,
1e-10,
finfo.min,
finfo.max,
fp8_min,
fp8_max,
)
else:
out_q, out_s_packed = fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
@@ -126,7 +130,7 @@ def test_per_token_group_quant_fp8_packed(
)
# Triton reference (row-major float32 scales, UE8M0)
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
x,
group_size,
@@ -157,7 +161,8 @@ def test_per_token_group_quant_fp8_packed(
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
not current_platform.is_cuda_alike(),
reason="DeepGEMM not available on this platform",
)
def test_per_token_group_quant_fp8_packed_all_zero():
"""All-zero input must produce well-defined UE8M0 scale bytes via the eps
@@ -216,7 +221,8 @@ def test_per_token_group_quant_fp8_packed_all_zero():
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
not current_platform.is_cuda_alike(),
reason="DeepGEMM not available on this platform",
)
def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
"""Inputs whose absmax/max_8bit produces a non-power-of-2 force the
@@ -244,7 +250,7 @@ def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
use_ue8m0=True,
)
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
x,
group_size,
@@ -286,7 +292,8 @@ def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
],
)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
not current_platform.is_cuda_alike(),
reason="DeepGEMM not available on this platform",
)
def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
num_tokens, hidden_dim
@@ -305,8 +312,8 @@ def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
k_num_packed = (groups_per_row + 3) // 4
tma_aligned_mn = ((mn + 3) // 4) * 4
fp8_dtype = torch.float8_e4m3fn
finfo = torch.finfo(fp8_dtype)
fp8_dtype = current_platform.fp8_dtype()
fp8_min, fp8_max = get_fp8_min_max()
# Allocate output_q with the padded mn extent and pre-fill with 0xFF
# so the kernel cannot rely on a clean buffer.
out_q = torch.empty((tma_aligned_mn, hidden_dim), device=device, dtype=fp8_dtype)
@@ -320,11 +327,11 @@ def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
)
torch.ops._C.per_token_group_fp8_quant_packed(
x, out_q, out_s_packed, group_size, 1e-10, finfo.min, finfo.max
x, out_q, out_s_packed, group_size, 1e-10, fp8_min, fp8_max
)
# Live rows must match the Triton reference.
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, _ = fp8_utils.per_token_group_quant_fp8(x, group_size, use_ue8m0=True)
assert torch.equal(out_q[:mn], ref_q), "Live region mismatch"
@@ -356,7 +363,7 @@ def test_per_token_group_quant_int8(shape, group_size: int):
)
# triton ref
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, ref_s = int8_utils.per_token_group_quant_int8(
x,
group_size,