mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Enable per_token_group_quant/_C_stable_libtorch for ROCm (#42758)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import get_fp8_min_max
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@@ -16,7 +17,9 @@ from vllm.platforms import current_platform
|
||||
@pytest.mark.parametrize("tma_aligned", [False, True])
|
||||
@pytest.mark.parametrize("scale_ue8m0", [False, True])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test on CUDA/ROCm."
|
||||
)
|
||||
def test_per_token_group_quant_fp8(
|
||||
shape, column_major: bool, tma_aligned: bool, scale_ue8m0: bool, group_size: int
|
||||
):
|
||||
@@ -37,7 +40,7 @@ def test_per_token_group_quant_fp8(
|
||||
)
|
||||
|
||||
# triton ref
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
@@ -77,7 +80,8 @@ def test_per_token_group_quant_fp8(
|
||||
)
|
||||
@pytest.mark.parametrize("poisoned_scales", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="DeepGEMM not available on this platform",
|
||||
)
|
||||
def test_per_token_group_quant_fp8_packed(
|
||||
num_tokens, hidden_dim, group_size, poisoned_scales
|
||||
@@ -99,8 +103,8 @@ def test_per_token_group_quant_fp8_packed(
|
||||
if poisoned_scales:
|
||||
# Call the kernel with poisoned scale buffer to
|
||||
# ensure padded indices are correctly zeroed.
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
finfo = torch.finfo(fp8_dtype)
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
out_q = torch.empty_like(x, dtype=fp8_dtype)
|
||||
out_s_packed = torch.empty_strided(
|
||||
(mn, k_num_packed),
|
||||
@@ -115,8 +119,8 @@ def test_per_token_group_quant_fp8_packed(
|
||||
out_s_packed,
|
||||
group_size,
|
||||
1e-10,
|
||||
finfo.min,
|
||||
finfo.max,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
)
|
||||
else:
|
||||
out_q, out_s_packed = fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
@@ -126,7 +130,7 @@ def test_per_token_group_quant_fp8_packed(
|
||||
)
|
||||
|
||||
# Triton reference (row-major float32 scales, UE8M0)
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
@@ -157,7 +161,8 @@ def test_per_token_group_quant_fp8_packed(
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="DeepGEMM not available on this platform",
|
||||
)
|
||||
def test_per_token_group_quant_fp8_packed_all_zero():
|
||||
"""All-zero input must produce well-defined UE8M0 scale bytes via the eps
|
||||
@@ -216,7 +221,8 @@ def test_per_token_group_quant_fp8_packed_all_zero():
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="DeepGEMM not available on this platform",
|
||||
)
|
||||
def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
|
||||
"""Inputs whose absmax/max_8bit produces a non-power-of-2 force the
|
||||
@@ -244,7 +250,7 @@ def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
|
||||
use_ue8m0=True,
|
||||
)
|
||||
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
@@ -286,7 +292,8 @@ def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="DeepGEMM not available on this platform",
|
||||
)
|
||||
def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
|
||||
num_tokens, hidden_dim
|
||||
@@ -305,8 +312,8 @@ def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
|
||||
k_num_packed = (groups_per_row + 3) // 4
|
||||
tma_aligned_mn = ((mn + 3) // 4) * 4
|
||||
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
finfo = torch.finfo(fp8_dtype)
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
# Allocate output_q with the padded mn extent and pre-fill with 0xFF
|
||||
# so the kernel cannot rely on a clean buffer.
|
||||
out_q = torch.empty((tma_aligned_mn, hidden_dim), device=device, dtype=fp8_dtype)
|
||||
@@ -320,11 +327,11 @@ def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
|
||||
)
|
||||
|
||||
torch.ops._C.per_token_group_fp8_quant_packed(
|
||||
x, out_q, out_s_packed, group_size, 1e-10, finfo.min, finfo.max
|
||||
x, out_q, out_s_packed, group_size, 1e-10, fp8_min, fp8_max
|
||||
)
|
||||
|
||||
# Live rows must match the Triton reference.
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, _ = fp8_utils.per_token_group_quant_fp8(x, group_size, use_ue8m0=True)
|
||||
assert torch.equal(out_q[:mn], ref_q), "Live region mismatch"
|
||||
|
||||
@@ -356,7 +363,7 @@ def test_per_token_group_quant_int8(shape, group_size: int):
|
||||
)
|
||||
|
||||
# triton ref
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, ref_s = int8_utils.per_token_group_quant_int8(
|
||||
x,
|
||||
group_size,
|
||||
|
||||
Reference in New Issue
Block a user