Enable perf_token_group_quant/_C_stable_libtorch for ROCm (#42758)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2026-06-03 01:23:28 -05:00
committed by GitHub
parent e0081ef8cf
commit 71df063c49
12 changed files with 145 additions and 109 deletions
+2 -2
View File
@@ -633,6 +633,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/activation_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/gptq/q_gemm.cu"
"csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu"
"csrc/libtorch_stable/pos_encoding_kernels.cu"
@@ -657,8 +659,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/libtorch_stable/permute_cols.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
set_gencode_flags_for_srcs(
+31 -10
View File
@@ -14,7 +14,18 @@ import argparse
import os
import shutil
from torch.utils.hipify.hipify_python import hipify
from torch.utils.hipify.hipify_python import get_hip_file_path, hipify
def _expected_hip_build_path(source_abs: str, output_directory: str) -> str:
"""Match torch.utils.hipify.hipify_python.preprocessor fout_path naming."""
rel = os.path.relpath(source_abs, output_directory)
return os.path.abspath(
os.path.join(
output_directory, get_hip_file_path(rel, is_pytorch_extension=True)
)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -53,7 +64,11 @@ if __name__ == "__main__":
hipify_result = hipify(
project_directory=args.project_dir,
output_directory=args.output_dir,
header_include_dirs=[],
# Hipify resolves quoted includes next to the including file first; vLLM
# uses paths relative to csrc/ (e.g. "libtorch_stable/torch_utils.h"
# from quantization/w8a8/fp8/*.cu). Without an include root here, those
# headers are never found and are not hipified or rewritten in dependents.
header_include_dirs=["."],
includes=includes,
extra_files=extra_files,
show_detailed=True,
@@ -64,14 +79,20 @@ if __name__ == "__main__":
hipified_sources = []
for source in args.sources:
s_abs = os.path.abspath(source)
hipified_s_abs = (
hipify_result[s_abs].hipified_path
if (
s_abs in hipify_result
and hipify_result[s_abs].hipified_path is not None
)
else s_abs
)
if s_abs in hipify_result and hipify_result[s_abs].hipified_path is not None:
path = hipify_result[s_abs].hipified_path
# PyTorch skips writing when is_pytorch_extension and text unchanged;
# hipified_path then stays *.cu. CMake expects *.hip under output_dir.
if s_abs.endswith(".cu") and path.endswith(".cu"):
dest = _expected_hip_build_path(s_abs, args.output_dir)
if os.path.normpath(path) != os.path.normpath(dest):
os.makedirs(os.path.dirname(dest), exist_ok=True)
shutil.copy2(path, dest)
hipified_s_abs = dest
else:
hipified_s_abs = path
else:
hipified_s_abs = s_abs
hipified_sources.append(hipified_s_abs)
assert len(hipified_sources) == len(args.sources)
+4 -4
View File
@@ -3,10 +3,6 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#ifndef USE_ROCM
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
torch::stable::Tensor const& perm);
void per_token_group_quant_fp8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s,
@@ -28,6 +24,10 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
int64_t group_size, double eps, double int8_min,
double int8_max);
#ifndef USE_ROCM
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
torch::stable::Tensor const& perm);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
@@ -7,7 +7,11 @@
#include <cmath>
#include <cuda_fp8.h>
#ifdef USE_ROCM
#include <hip/hip_fp8.h>
#else
#include <cuda_fp8.h>
#endif
#include "libtorch_stable/quantization/vectorization.cuh"
#include "libtorch_stable/quantization/vectorization_utils.cuh"
@@ -15,12 +19,23 @@
#include "libtorch_stable/torch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val) {
#ifdef USE_ROCM
// 16-thread logical groups may pack up to four per 64-lane wavefront; use a
// 64-bit mask and explicit width so shuffles stay within each group.
const int lane_in_wave = threadIdx.x % warpSize;
const unsigned long long mask = 0xFFFFull << ((lane_in_wave / 16) * 16);
val = fmaxf(val, __shfl_xor_sync(mask, val, 8, 16));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4, 16));
val = fmaxf(val, __shfl_xor_sync(mask, val, 2, 16));
val = fmaxf(val, __shfl_xor_sync(mask, val, 1, 16));
#else
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
#endif
return val;
}
@@ -237,10 +252,12 @@ void per_token_group_quant_8bit(const torch::stable::Tensor& input,
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
} else if (dst_type == torch::headeronly::ScalarType::Char) {
if (dst_type == torch::headeronly::ScalarType::Char) {
LAUNCH_KERNEL(scalar_t, int8_t);
} else {
VLLM_STABLE_DISPATCH_FP8_TYPES(
dst_type, "per_token_group_quant_8bit_fp8",
([&] { LAUNCH_KERNEL(scalar_t, fp8_t); }));
}
}));
@@ -317,10 +334,18 @@ __global__ void per_token_group_quant_8bit_packed_register_kernel(
// 8-lane subgroup shuffle reduce (octet of the warp). The mask selects the
// 8 lanes within the warp that share a group.
#ifdef USE_ROCM
const int lane_in_wave = threadIdx.x % warpSize;
const unsigned long long mask = 0xFFull << (lane_in_wave & ~7);
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 4, 8));
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 2, 8));
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 1, 8));
#else
unsigned mask = 0xffu << (threadIdx.x & 24u);
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 4));
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 2));
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 1));
#endif
float y_s = local_absmax / max_8bit;
y_s = fmaxf(y_s, 1e-10f);
@@ -503,15 +528,12 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "per_token_group_quant_8bit_packed_register", ([&] {
if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) {
LAUNCH_REG_KERNEL(scalar_t, __nv_fp8_e4m3);
} else if (dst_type == torch::headeronly::ScalarType::Char) {
if (dst_type == torch::headeronly::ScalarType::Char) {
LAUNCH_REG_KERNEL(scalar_t, int8_t);
} else {
STD_TORCH_CHECK(
false,
"per_token_group_quant_8bit_packed only supports FP8/INT8 "
"outputs.");
VLLM_STABLE_DISPATCH_FP8_TYPES(
dst_type, "per_token_group_quant_8bit_packed_fp8",
([&] { LAUNCH_REG_KERNEL(scalar_t, fp8_t); }));
}
}));
+10 -10
View File
@@ -7,11 +7,6 @@
// Note: We register under namespace "_C" so ops are accessible as
// torch.ops._C.<op_name> for compatibility with existing code.
STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
#ifndef USE_ROCM
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
#ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
// The dummy arguments are here so we can correctly fuse with RMSNorm.
ops.def(
@@ -32,6 +27,11 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
#ifndef USE_ROCM
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
#ifndef USE_ROCM
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
@@ -508,11 +508,6 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
#ifndef USE_ROCM
ops.impl("permute_cols", TORCH_BOX(&permute_cols));
#endif
#ifndef USE_ROCM
// Per-token group quantization
ops.impl("per_token_group_fp8_quant", TORCH_BOX(&per_token_group_quant_fp8));
ops.impl("per_token_group_fp8_quant_packed",
@@ -520,6 +515,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("per_token_group_quant_int8",
TORCH_BOX(&per_token_group_quant_int8));
#ifndef USE_ROCM
ops.impl("permute_cols", TORCH_BOX(&permute_cols));
#endif
#ifndef USE_ROCM
// CUTLASS scaled_mm ops
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
+1 -5
View File
@@ -6,11 +6,7 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>
#ifndef USE_ROCM
#include <cuda_runtime.h>
#else
#include <hip/hip_runtime.h>
#endif
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <deque>
@@ -6,6 +6,7 @@ import pytest
import torch
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
from vllm.model_executor.layers.quantization.utils.quant_utils import get_fp8_min_max
from vllm.platforms import current_platform
@@ -16,7 +17,9 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("tma_aligned", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test on CUDA/ROCm."
)
def test_per_token_group_quant_fp8(
shape, column_major: bool, tma_aligned: bool, scale_ue8m0: bool, group_size: int
):
@@ -37,7 +40,7 @@ def test_per_token_group_quant_fp8(
)
# triton ref
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
x,
group_size,
@@ -77,7 +80,8 @@ def test_per_token_group_quant_fp8(
)
@pytest.mark.parametrize("poisoned_scales", [False, True])
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
not current_platform.is_cuda_alike(),
reason="DeepGEMM not available on this platform",
)
def test_per_token_group_quant_fp8_packed(
num_tokens, hidden_dim, group_size, poisoned_scales
@@ -99,8 +103,8 @@ def test_per_token_group_quant_fp8_packed(
if poisoned_scales:
# Call the kernel with poisoned scale buffer to
# ensure padded indices are correctly zeroed.
fp8_dtype = torch.float8_e4m3fn
finfo = torch.finfo(fp8_dtype)
fp8_dtype = current_platform.fp8_dtype()
fp8_min, fp8_max = get_fp8_min_max()
out_q = torch.empty_like(x, dtype=fp8_dtype)
out_s_packed = torch.empty_strided(
(mn, k_num_packed),
@@ -115,8 +119,8 @@ def test_per_token_group_quant_fp8_packed(
out_s_packed,
group_size,
1e-10,
finfo.min,
finfo.max,
fp8_min,
fp8_max,
)
else:
out_q, out_s_packed = fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
@@ -126,7 +130,7 @@ def test_per_token_group_quant_fp8_packed(
)
# Triton reference (row-major float32 scales, UE8M0)
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
x,
group_size,
@@ -157,7 +161,8 @@ def test_per_token_group_quant_fp8_packed(
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
not current_platform.is_cuda_alike(),
reason="DeepGEMM not available on this platform",
)
def test_per_token_group_quant_fp8_packed_all_zero():
"""All-zero input must produce well-defined UE8M0 scale bytes via the eps
@@ -216,7 +221,8 @@ def test_per_token_group_quant_fp8_packed_all_zero():
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
not current_platform.is_cuda_alike(),
reason="DeepGEMM not available on this platform",
)
def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
"""Inputs whose absmax/max_8bit produces a non-power-of-2 force the
@@ -244,7 +250,7 @@ def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
use_ue8m0=True,
)
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
x,
group_size,
@@ -286,7 +292,8 @@ def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
],
)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
not current_platform.is_cuda_alike(),
reason="DeepGEMM not available on this platform",
)
def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
num_tokens, hidden_dim
@@ -305,8 +312,8 @@ def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
k_num_packed = (groups_per_row + 3) // 4
tma_aligned_mn = ((mn + 3) // 4) * 4
fp8_dtype = torch.float8_e4m3fn
finfo = torch.finfo(fp8_dtype)
fp8_dtype = current_platform.fp8_dtype()
fp8_min, fp8_max = get_fp8_min_max()
# Allocate output_q with the padded mn extent and pre-fill with 0xFF
# so the kernel cannot rely on a clean buffer.
out_q = torch.empty((tma_aligned_mn, hidden_dim), device=device, dtype=fp8_dtype)
@@ -320,11 +327,11 @@ def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
)
torch.ops._C.per_token_group_fp8_quant_packed(
x, out_q, out_s_packed, group_size, 1e-10, finfo.min, finfo.max
x, out_q, out_s_packed, group_size, 1e-10, fp8_min, fp8_max
)
# Live rows must match the Triton reference.
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, _ = fp8_utils.per_token_group_quant_fp8(x, group_size, use_ue8m0=True)
assert torch.equal(out_q[:mn], ref_q), "Live region mismatch"
@@ -356,7 +363,7 @@ def test_per_token_group_quant_int8(shape, group_size: int):
)
# triton ref
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
ref_q, ref_s = int8_utils.per_token_group_quant_int8(
x,
group_size,
@@ -36,14 +36,13 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
get_fp8_min_max,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
@@ -83,12 +84,11 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
class FusedRMSQuantKey(NamedTuple):
@@ -327,9 +327,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
dtype=self.quant_matcher.quant_key.dtype,
)
assert scale is not None
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
fp8_min, fp8_max = get_fp8_min_max()
_, result, scale = auto_functionalized(
self.quant_matcher.QUANT_OP,
@@ -430,9 +428,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
dtype=self.quant_matcher.quant_key.dtype,
)
assert scale is not None
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
fp8_min, fp8_max = get_fp8_min_max()
_, result, scale = auto_functionalized(
self.quant_matcher.QUANT_OP,
@@ -645,31 +641,30 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
for has_col_major_scales in [True, False]:
for is_e8m0 in [True, False]:
for is_tma_aligned in [False, True]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
# Only register group quant patterns on CUDA/ROCm where the C++ op exists
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
for has_col_major_scales in [True, False]:
for is_e8m0 in [True, False]:
for is_tma_aligned in [False, True]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@@ -158,11 +158,6 @@ class QuantFP8(CustomOp):
if use_aiter_per_token_quant:
return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
# Fallback to native implementation for group quantization.
if self.is_group_quant:
assert scale is None, "Dynamic group quantization does not use scale"
return self._quantize_group_native(x)
# Fallback to CUDA implementation
return self.forward_cuda(x, scale, scale_ub)
@@ -575,7 +575,9 @@ def per_token_group_quant_fp8(
# prefer CUDA/XPU kernel if available
# TODO(bnell): this causes some fp8 moe test to fail.
if (current_platform.is_cuda() or current_platform.is_xpu()) and x.is_contiguous():
if (
current_platform.is_cuda_alike() or current_platform.is_xpu()
) and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(
x,
x_q,
@@ -664,8 +666,7 @@ def per_token_group_quant_fp8_packed_for_deepgemm(
)
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min, fp8_max = finfo.min, finfo.max
fp8_min, fp8_max = get_fp8_min_max()
# compute DeepGEMM-style packed scale tensor shape.
hidden_dim = x.shape[-1]
@@ -681,10 +682,10 @@ def per_token_group_quant_fp8_packed_for_deepgemm(
dtype=torch.int32,
)
# CUDA kernel path only (DeepGEMM + E8M0 is CUDA-specific).
assert current_platform.is_cuda(), (
"per_token_group_quant_fp8_packed_for_deepgemm is only valid on CUDA "
"platforms using DeepGEMM."
# Native kernel (libtorch stable); used with DeepGEMM on CUDA and
# available on ROCm for the same packed UE8M0 scale layout.
assert current_platform.is_cuda_alike(), (
"per_token_group_quant_fp8_packed_for_deepgemm requires a CUDA or ROCm GPU."
)
x_contiguous = x.contiguous()
@@ -235,8 +235,8 @@ def per_token_group_quant_int8(
device=x.device,
dtype=torch.float32,
)
# prefer CUDA kernel if available
if current_platform.is_cuda():
# Prefer native stable kernel on CUDA/ROCm when available.
if current_platform.is_cuda_alike():
torch.ops._C.per_token_group_quant_int8(
x, x_q, x_s, group_size, eps, float(int8_min), float(int8_max)
)