mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Enable perf_token_group_quant/_C_stable_libtorch for ROCm (#42758)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
+2
-2
@@ -633,6 +633,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/libtorch_stable/activation_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/gptq/q_gemm.cu"
|
||||
"csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/libtorch_stable/pos_encoding_kernels.cu"
|
||||
@@ -657,8 +659,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/libtorch_stable/permute_cols.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
|
||||
+31
-10
@@ -14,7 +14,18 @@ import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from torch.utils.hipify.hipify_python import hipify
|
||||
from torch.utils.hipify.hipify_python import get_hip_file_path, hipify
|
||||
|
||||
|
||||
def _expected_hip_build_path(source_abs: str, output_directory: str) -> str:
|
||||
"""Match torch.utils.hipify.hipify_python.preprocessor fout_path naming."""
|
||||
rel = os.path.relpath(source_abs, output_directory)
|
||||
return os.path.abspath(
|
||||
os.path.join(
|
||||
output_directory, get_hip_file_path(rel, is_pytorch_extension=True)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -53,7 +64,11 @@ if __name__ == "__main__":
|
||||
hipify_result = hipify(
|
||||
project_directory=args.project_dir,
|
||||
output_directory=args.output_dir,
|
||||
header_include_dirs=[],
|
||||
# Hipify resolves quoted includes next to the including file first; vLLM
|
||||
# uses paths relative to csrc/ (e.g. "libtorch_stable/torch_utils.h"
|
||||
# from quantization/w8a8/fp8/*.cu). Without an include root here, those
|
||||
# headers are never found and are not hipified or rewritten in dependents.
|
||||
header_include_dirs=["."],
|
||||
includes=includes,
|
||||
extra_files=extra_files,
|
||||
show_detailed=True,
|
||||
@@ -64,14 +79,20 @@ if __name__ == "__main__":
|
||||
hipified_sources = []
|
||||
for source in args.sources:
|
||||
s_abs = os.path.abspath(source)
|
||||
hipified_s_abs = (
|
||||
hipify_result[s_abs].hipified_path
|
||||
if (
|
||||
s_abs in hipify_result
|
||||
and hipify_result[s_abs].hipified_path is not None
|
||||
)
|
||||
else s_abs
|
||||
)
|
||||
if s_abs in hipify_result and hipify_result[s_abs].hipified_path is not None:
|
||||
path = hipify_result[s_abs].hipified_path
|
||||
# PyTorch skips writing when is_pytorch_extension and text unchanged;
|
||||
# hipified_path then stays *.cu. CMake expects *.hip under output_dir.
|
||||
if s_abs.endswith(".cu") and path.endswith(".cu"):
|
||||
dest = _expected_hip_build_path(s_abs, args.output_dir)
|
||||
if os.path.normpath(path) != os.path.normpath(dest):
|
||||
os.makedirs(os.path.dirname(dest), exist_ok=True)
|
||||
shutil.copy2(path, dest)
|
||||
hipified_s_abs = dest
|
||||
else:
|
||||
hipified_s_abs = path
|
||||
else:
|
||||
hipified_s_abs = s_abs
|
||||
hipified_sources.append(hipified_s_abs)
|
||||
|
||||
assert len(hipified_sources) == len(args.sources)
|
||||
|
||||
@@ -3,10 +3,6 @@
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& perm);
|
||||
|
||||
void per_token_group_quant_fp8(const torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& output_q,
|
||||
torch::stable::Tensor& output_s,
|
||||
@@ -28,6 +24,10 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
|
||||
int64_t group_size, double eps, double int8_min,
|
||||
double int8_max);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& perm);
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
||||
|
||||
@@ -7,7 +7,11 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_fp8.h>
|
||||
#else
|
||||
#include <cuda_fp8.h>
|
||||
#endif
|
||||
|
||||
#include "libtorch_stable/quantization/vectorization.cuh"
|
||||
#include "libtorch_stable/quantization/vectorization_utils.cuh"
|
||||
@@ -15,12 +19,23 @@
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
__device__ __forceinline__ float GroupReduceMax(float val) {
|
||||
#ifdef USE_ROCM
|
||||
// 16-thread logical groups may pack up to four per 64-lane wavefront; use a
|
||||
// 64-bit mask and explicit width so shuffles stay within each group.
|
||||
const int lane_in_wave = threadIdx.x % warpSize;
|
||||
const unsigned long long mask = 0xFFFFull << ((lane_in_wave / 16) * 16);
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 8, 16));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 4, 16));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 2, 16));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 1, 16));
|
||||
#else
|
||||
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
|
||||
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
|
||||
#endif
|
||||
return val;
|
||||
}
|
||||
|
||||
@@ -237,10 +252,12 @@ void per_token_group_quant_8bit(const torch::stable::Tensor& input,
|
||||
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
|
||||
if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) {
|
||||
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
|
||||
} else if (dst_type == torch::headeronly::ScalarType::Char) {
|
||||
if (dst_type == torch::headeronly::ScalarType::Char) {
|
||||
LAUNCH_KERNEL(scalar_t, int8_t);
|
||||
} else {
|
||||
VLLM_STABLE_DISPATCH_FP8_TYPES(
|
||||
dst_type, "per_token_group_quant_8bit_fp8",
|
||||
([&] { LAUNCH_KERNEL(scalar_t, fp8_t); }));
|
||||
}
|
||||
}));
|
||||
|
||||
@@ -317,10 +334,18 @@ __global__ void per_token_group_quant_8bit_packed_register_kernel(
|
||||
|
||||
// 8-lane subgroup shuffle reduce (octet of the warp). The mask selects the
|
||||
// 8 lanes within the warp that share a group.
|
||||
#ifdef USE_ROCM
|
||||
const int lane_in_wave = threadIdx.x % warpSize;
|
||||
const unsigned long long mask = 0xFFull << (lane_in_wave & ~7);
|
||||
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 4, 8));
|
||||
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 2, 8));
|
||||
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 1, 8));
|
||||
#else
|
||||
unsigned mask = 0xffu << (threadIdx.x & 24u);
|
||||
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 4));
|
||||
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 2));
|
||||
local_absmax = fmaxf(local_absmax, __shfl_xor_sync(mask, local_absmax, 1));
|
||||
#endif
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
y_s = fmaxf(y_s, 1e-10f);
|
||||
@@ -503,15 +528,12 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
|
||||
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "per_token_group_quant_8bit_packed_register", ([&] {
|
||||
if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) {
|
||||
LAUNCH_REG_KERNEL(scalar_t, __nv_fp8_e4m3);
|
||||
} else if (dst_type == torch::headeronly::ScalarType::Char) {
|
||||
if (dst_type == torch::headeronly::ScalarType::Char) {
|
||||
LAUNCH_REG_KERNEL(scalar_t, int8_t);
|
||||
} else {
|
||||
STD_TORCH_CHECK(
|
||||
false,
|
||||
"per_token_group_quant_8bit_packed only supports FP8/INT8 "
|
||||
"outputs.");
|
||||
VLLM_STABLE_DISPATCH_FP8_TYPES(
|
||||
dst_type, "per_token_group_quant_8bit_packed_fp8",
|
||||
([&] { LAUNCH_REG_KERNEL(scalar_t, fp8_t); }));
|
||||
}
|
||||
}));
|
||||
|
||||
|
||||
@@ -7,11 +7,6 @@
|
||||
// Note: We register under namespace "_C" so ops are accessible as
|
||||
// torch.ops._C.<op_name> for compatibility with existing code.
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
#ifndef USE_ROCM
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Compute per-token-group FP8 quantized tensor and scaling factor.
|
||||
// The dummy arguments are here so we can correctly fuse with RMSNorm.
|
||||
ops.def(
|
||||
@@ -32,6 +27,11 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||
"()");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
@@ -508,11 +508,6 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
#ifndef USE_ROCM
|
||||
ops.impl("permute_cols", TORCH_BOX(&permute_cols));
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Per-token group quantization
|
||||
ops.impl("per_token_group_fp8_quant", TORCH_BOX(&per_token_group_quant_fp8));
|
||||
ops.impl("per_token_group_fp8_quant_packed",
|
||||
@@ -520,6 +515,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
ops.impl("per_token_group_quant_int8",
|
||||
TORCH_BOX(&per_token_group_quant_int8));
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.impl("permute_cols", TORCH_BOX(&permute_cols));
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// CUTLASS scaled_mm ops
|
||||
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
|
||||
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
|
||||
|
||||
@@ -6,11 +6,7 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_runtime.h>
|
||||
#else
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#include <deque>
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import get_fp8_min_max
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@@ -16,7 +17,9 @@ from vllm.platforms import current_platform
|
||||
@pytest.mark.parametrize("tma_aligned", [False, True])
|
||||
@pytest.mark.parametrize("scale_ue8m0", [False, True])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test on CUDA/ROCm."
|
||||
)
|
||||
def test_per_token_group_quant_fp8(
|
||||
shape, column_major: bool, tma_aligned: bool, scale_ue8m0: bool, group_size: int
|
||||
):
|
||||
@@ -37,7 +40,7 @@ def test_per_token_group_quant_fp8(
|
||||
)
|
||||
|
||||
# triton ref
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
@@ -77,7 +80,8 @@ def test_per_token_group_quant_fp8(
|
||||
)
|
||||
@pytest.mark.parametrize("poisoned_scales", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="DeepGEMM not available on this platform",
|
||||
)
|
||||
def test_per_token_group_quant_fp8_packed(
|
||||
num_tokens, hidden_dim, group_size, poisoned_scales
|
||||
@@ -99,8 +103,8 @@ def test_per_token_group_quant_fp8_packed(
|
||||
if poisoned_scales:
|
||||
# Call the kernel with poisoned scale buffer to
|
||||
# ensure padded indices are correctly zeroed.
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
finfo = torch.finfo(fp8_dtype)
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
out_q = torch.empty_like(x, dtype=fp8_dtype)
|
||||
out_s_packed = torch.empty_strided(
|
||||
(mn, k_num_packed),
|
||||
@@ -115,8 +119,8 @@ def test_per_token_group_quant_fp8_packed(
|
||||
out_s_packed,
|
||||
group_size,
|
||||
1e-10,
|
||||
finfo.min,
|
||||
finfo.max,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
)
|
||||
else:
|
||||
out_q, out_s_packed = fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
@@ -126,7 +130,7 @@ def test_per_token_group_quant_fp8_packed(
|
||||
)
|
||||
|
||||
# Triton reference (row-major float32 scales, UE8M0)
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
@@ -157,7 +161,8 @@ def test_per_token_group_quant_fp8_packed(
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="DeepGEMM not available on this platform",
|
||||
)
|
||||
def test_per_token_group_quant_fp8_packed_all_zero():
|
||||
"""All-zero input must produce well-defined UE8M0 scale bytes via the eps
|
||||
@@ -216,7 +221,8 @@ def test_per_token_group_quant_fp8_packed_all_zero():
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="DeepGEMM not available on this platform",
|
||||
)
|
||||
def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
|
||||
"""Inputs whose absmax/max_8bit produces a non-power-of-2 force the
|
||||
@@ -244,7 +250,7 @@ def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
|
||||
use_ue8m0=True,
|
||||
)
|
||||
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
@@ -286,7 +292,8 @@ def test_per_token_group_quant_fp8_packed_mantissa_rounds_up():
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGEMM not available on this platform"
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="DeepGEMM not available on this platform",
|
||||
)
|
||||
def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
|
||||
num_tokens, hidden_dim
|
||||
@@ -305,8 +312,8 @@ def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
|
||||
k_num_packed = (groups_per_row + 3) // 4
|
||||
tma_aligned_mn = ((mn + 3) // 4) * 4
|
||||
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
finfo = torch.finfo(fp8_dtype)
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
# Allocate output_q with the padded mn extent and pre-fill with 0xFF
|
||||
# so the kernel cannot rely on a clean buffer.
|
||||
out_q = torch.empty((tma_aligned_mn, hidden_dim), device=device, dtype=fp8_dtype)
|
||||
@@ -320,11 +327,11 @@ def test_per_token_group_quant_fp8_packed_zero_fills_padded_output_q(
|
||||
)
|
||||
|
||||
torch.ops._C.per_token_group_fp8_quant_packed(
|
||||
x, out_q, out_s_packed, group_size, 1e-10, finfo.min, finfo.max
|
||||
x, out_q, out_s_packed, group_size, 1e-10, fp8_min, fp8_max
|
||||
)
|
||||
|
||||
# Live rows must match the Triton reference.
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, _ = fp8_utils.per_token_group_quant_fp8(x, group_size, use_ue8m0=True)
|
||||
assert torch.equal(out_q[:mn], ref_q), "Live region mismatch"
|
||||
|
||||
@@ -356,7 +363,7 @@ def test_per_token_group_quant_int8(shape, group_size: int):
|
||||
)
|
||||
|
||||
# triton ref
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
with patch("vllm.platforms.current_platform.is_cuda_alike", return_value=False):
|
||||
ref_q, ref_s = int8_utils.per_token_group_quant_int8(
|
||||
x,
|
||||
group_size,
|
||||
|
||||
@@ -36,14 +36,13 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
||||
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
|
||||
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
get_fp8_min_max,
|
||||
kFp8Dynamic64Sym,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
@@ -83,12 +84,11 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
||||
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
|
||||
@@ -327,9 +327,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
dtype=self.quant_matcher.quant_key.dtype,
|
||||
)
|
||||
assert scale is not None
|
||||
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
|
||||
_, result, scale = auto_functionalized(
|
||||
self.quant_matcher.QUANT_OP,
|
||||
@@ -430,9 +428,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
dtype=self.quant_matcher.quant_key.dtype,
|
||||
)
|
||||
assert scale is not None
|
||||
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
|
||||
_, result, scale = auto_functionalized(
|
||||
self.quant_matcher.QUANT_OP,
|
||||
@@ -645,31 +641,30 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Only register group quant patterns on CUDA where the C++ op exists
|
||||
if current_platform.is_cuda():
|
||||
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
|
||||
for has_col_major_scales in [True, False]:
|
||||
for is_e8m0 in [True, False]:
|
||||
for is_tma_aligned in [False, True]:
|
||||
# Fuse fused_add_rms_norm + fp8 group quant
|
||||
FusedAddRMSNormGroupQuantPattern(
|
||||
epsilon,
|
||||
FP8_DTYPE,
|
||||
group_shape=group_shape,
|
||||
is_e8m0=is_e8m0,
|
||||
has_col_major_scales=has_col_major_scales,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
).register(self.patterns)
|
||||
# Only register group quant patterns on CUDA/ROCm where the C++ op exists
|
||||
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
|
||||
for has_col_major_scales in [True, False]:
|
||||
for is_e8m0 in [True, False]:
|
||||
for is_tma_aligned in [False, True]:
|
||||
# Fuse fused_add_rms_norm + fp8 group quant
|
||||
FusedAddRMSNormGroupQuantPattern(
|
||||
epsilon,
|
||||
FP8_DTYPE,
|
||||
group_shape=group_shape,
|
||||
is_e8m0=is_e8m0,
|
||||
has_col_major_scales=has_col_major_scales,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
).register(self.patterns)
|
||||
|
||||
# Fuse rms_norm + fp8 group quant
|
||||
RMSNormGroupQuantPattern(
|
||||
epsilon,
|
||||
FP8_DTYPE,
|
||||
group_shape=group_shape,
|
||||
is_e8m0=is_e8m0,
|
||||
has_col_major_scales=has_col_major_scales,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
).register(self.patterns)
|
||||
# Fuse rms_norm + fp8 group quant
|
||||
RMSNormGroupQuantPattern(
|
||||
epsilon,
|
||||
FP8_DTYPE,
|
||||
group_shape=group_shape,
|
||||
is_e8m0=is_e8m0,
|
||||
has_col_major_scales=has_col_major_scales,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
|
||||
@@ -158,11 +158,6 @@ class QuantFP8(CustomOp):
|
||||
if use_aiter_per_token_quant:
|
||||
return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
|
||||
|
||||
# Fallback to native implementation for group quantization.
|
||||
if self.is_group_quant:
|
||||
assert scale is None, "Dynamic group quantization does not use scale"
|
||||
return self._quantize_group_native(x)
|
||||
|
||||
# Fallback to CUDA implementation
|
||||
return self.forward_cuda(x, scale, scale_ub)
|
||||
|
||||
|
||||
@@ -575,7 +575,9 @@ def per_token_group_quant_fp8(
|
||||
|
||||
# prefer CUDA/XPU kernel if available
|
||||
# TODO(bnell): this causes some fp8 moe test to fail.
|
||||
if (current_platform.is_cuda() or current_platform.is_xpu()) and x.is_contiguous():
|
||||
if (
|
||||
current_platform.is_cuda_alike() or current_platform.is_xpu()
|
||||
) and x.is_contiguous():
|
||||
torch.ops._C.per_token_group_fp8_quant(
|
||||
x,
|
||||
x_q,
|
||||
@@ -664,8 +666,7 @@ def per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
)
|
||||
assert x.stride(-1) == 1, "`x` groups must be contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min, fp8_max = finfo.min, finfo.max
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
|
||||
# compute DeepGEMM-style packed scale tensor shape.
|
||||
hidden_dim = x.shape[-1]
|
||||
@@ -681,10 +682,10 @@ def per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
# CUDA kernel path only (DeepGEMM + E8M0 is CUDA-specific).
|
||||
assert current_platform.is_cuda(), (
|
||||
"per_token_group_quant_fp8_packed_for_deepgemm is only valid on CUDA "
|
||||
"platforms using DeepGEMM."
|
||||
# Native kernel (libtorch stable); used with DeepGEMM on CUDA and
|
||||
# available on ROCm for the same packed UE8M0 scale layout.
|
||||
assert current_platform.is_cuda_alike(), (
|
||||
"per_token_group_quant_fp8_packed_for_deepgemm requires a CUDA or ROCm GPU."
|
||||
)
|
||||
|
||||
x_contiguous = x.contiguous()
|
||||
|
||||
@@ -235,8 +235,8 @@ def per_token_group_quant_int8(
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
# prefer CUDA kernel if available
|
||||
if current_platform.is_cuda():
|
||||
# Prefer native stable kernel on CUDA/ROCm when available.
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.ops._C.per_token_group_quant_int8(
|
||||
x, x_q, x_s, group_size, eps, float(int8_min), float(int8_max)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user