mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Kernel][ROCm] Native W4A16 kernel for AMD RDNA3 (gfx1100) — fp16 + bf16 (#41394)
Signed-off-by: JartX <sagformas@epdcenter.es>
This commit is contained in:
@@ -1284,6 +1284,14 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/rocm/skinny_gemms.cu"
|
||||
"csrc/rocm/attention.cu")
|
||||
|
||||
set(VLLM_ROCM_HAS_GFX1100 OFF)
|
||||
if(VLLM_GPU_ARCHES MATCHES "gfx1100")
|
||||
set(VLLM_ROCM_HAS_GFX1100 ON)
|
||||
list(APPEND VLLM_ROCM_EXT_SRC
|
||||
"csrc/rocm/q_gemm_rdna3.cu"
|
||||
"csrc/rocm/q_gemm_rdna3_wmma.cu")
|
||||
endif()
|
||||
|
||||
define_extension_target(
|
||||
_rocm_C
|
||||
DESTINATION vllm
|
||||
@@ -1293,6 +1301,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
if(VLLM_ROCM_HAS_GFX1100)
|
||||
target_compile_definitions(_rocm_C PRIVATE VLLM_ROCM_GFX1100)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Must run after the last HIP `define_extension_target` so every extension
|
||||
|
||||
@@ -18,6 +18,15 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const at::Tensor& scale_a, const at::Tensor& scale_b,
|
||||
const int64_t CuCount);
|
||||
|
||||
torch::Tensor gptq_gemm_rdna3(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_qzeros, torch::Tensor b_scales,
|
||||
torch::Tensor b_g_idx, bool use_v2_format);
|
||||
|
||||
torch::Tensor gptq_gemm_rdna3_wmma(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_qzeros,
|
||||
torch::Tensor b_scales,
|
||||
torch::Tensor b_g_idx, bool use_v2_format);
|
||||
|
||||
void paged_attention(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
|
||||
@@ -0,0 +1,780 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
//
|
||||
// W4A16 GPTQ kernel for RDNA3 (gfx1100 / RX 7900 XTX class), templated on the
|
||||
// activation dtype (half or __hip_bfloat16). Adapted from exllamav2's 4-bit
|
||||
// kernel (csrc/quantization/gptq/q_gemm.cu) with the following changes:
|
||||
//
|
||||
// 1. Direct write to the T-typed output via packed CAS-loop on a 64-bit
|
||||
// word (atomic_add_pk4_{f16,bf16}). gfx11 has no native
|
||||
// v_global_atomic_pk_add_{f16,bf16}, so the kernel emulates one with
|
||||
// global_atomic_cmpswap_b64. This avoids the M*N*4-byte FP32 scratch
|
||||
// buffer + memset + cast-pass that an fp32-accumulator design would
|
||||
// need; the caller passes a zero-initialised T-typed output tensor
|
||||
// and every block atomically adds its partial sum into it.
|
||||
//
|
||||
// 2. The bf16 path uses a dedicated bit-trick that avoids the fp16-only
|
||||
// "upper nibble * 16" trick, which would overflow the 7-bit bf16
|
||||
// mantissa. See qdq_4_rdna3.cuh for details.
|
||||
//
|
||||
// 3. Wave32 geometry sized for high CU saturation: THREADS_X=256
|
||||
// (8 waves per block) and BLOCK_KN_SIZE=256, with each thread
|
||||
// computing 4 N output columns. gridDim.z = K / BLOCK_KN_SIZE
|
||||
// splits K and the output is atomically accumulated. fp16 uses
|
||||
// v_dot2_f32_f16 (__builtin_amdgcn_fdot2) for the inner dot;
|
||||
// bf16 has no v_pk_fma_bf16 on gfx11, so its inner dot uses v_dot2_f32_bf16
|
||||
// (__builtin_amdgcn_fdot2_f32_bf16). M_COUNT ∈ {1,2,4,8} is selected at launch
|
||||
// based on size_m.
|
||||
//
|
||||
// 4. The bf16 dispatch with M >= 16 forwards to the WMMA kernel in
|
||||
// q_gemm_rdna3_wmma.cu (separate translation unit) where
|
||||
// v_wmma_f32_16x16x16_bf16_w32 wins. The fp16 path always stays
|
||||
// scalar (the bit-trick dequant beats WMMA below M=64).
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#include "qdq_4_rdna3.cuh"
|
||||
|
||||
#if defined(__HIPCC__) && defined(__gfx1100__)
|
||||
#define __HIP__RDNA3__
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq_rdna3 {
|
||||
|
||||
// BLOCK_KN_SIZE = 256 (was 128 in exllama). Each block covers 256 K
|
||||
// elements and THREADS_X*4 = 1024 N columns. For Qwen-class K=4096 this
|
||||
// halves gridDim.z (32 → 16) and therefore halves the atomic count per
|
||||
// output position vs the exllama default. THREADS_X=256 = 8 waves on RDNA3
|
||||
// wave32; with ~32 wave slots per CU we still fit 4 blocks per CU at peak.
|
||||
//
|
||||
// We tried BLOCK_KN_SIZE=512 (microbench on Qwen3.6-27B): bf16 improved
|
||||
// 5-10% at large M (atomic CAS halved), but fp16 decode regressed up to
|
||||
// +40% on qkv-square (32 → 45 μs at M=1). Cause: 16 waves/block × 16
|
||||
// total blocks for [M=1, K=N=4096] only saturates ~8 of the 96 CUs,
|
||||
// breaking memory-latency hiding for the fp16 path which is already
|
||||
// memory-bound. Reverted to 256; bf16 keeps most of its gains from the
|
||||
// fp32 dequant rewrite alone.
|
||||
#define BLOCK_KN_SIZE 256
|
||||
#define THREADS_X 256
|
||||
|
||||
// Device code below is RDNA3-only; non-RDNA3 device passes fall through to
|
||||
// the empty __global__ stub at the #else below for symbol parity.
|
||||
#if defined(__HIP__RDNA3__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-dtype helpers. We avoid heavy template metaprogramming and just provide
|
||||
// overloaded inline functions; the kernel below selects via `if constexpr`.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Type-generic zero — both half and bf16_t in HIP/ROCm have a converting
|
||||
// constructor from float, but going through __float2half_rn / __float2bfloat16
|
||||
// is the unambiguously correct path on every ROCm version.
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T tzero();
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ half tzero<half>() {
|
||||
return __float2half_rn(0.0f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ bf16_t tzero<bf16_t>() {
|
||||
return __float2bfloat16(0.0f);
|
||||
}
|
||||
|
||||
// 8-element fp16 dot product with an fp32 accumulator.
//
// dq    : 4 half2 pairs of dequantised weights.
// a_ptr : 8 consecutive fp16 activations, read as 4 half2 pairs.
//
// Each step issues `__builtin_amdgcn_fdot2` (v_dot2_f32_f16), which folds
// a.x*b.x + a.y*b.y into the fp32 accumulator in one instruction. The builtin
// is spelled out because hipcc was observed not to pattern-match the
// `__hfma2 + cast + add` form into v_dot2; keeping the running sum in fp32
// also avoids rounding the intermediate sums to fp16.
__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr) {
  const half2* act_pairs = (const half2*)a_ptr;
  float acc = 0.0f;
#pragma unroll
  for (int pair = 0; pair < 4; ++pair) {
    acc = __builtin_amdgcn_fdot2(dq[pair], act_pairs[pair], acc,
                                 /*clamp=*/false);
  }
  return acc;
}
|
||||
|
||||
// 8-element bf16 dot product with an fp32 accumulator.
//
// dq    : 4 bf162_t pairs of dequantised weights.
// a_ptr : 8 consecutive bf16 activations.
//
// gfx11 has no packed bf16 FMA, so both operands are widened to fp32 by bit
// manipulation and accumulated with fp32 FMAs. A bf16 value is the upper 16
// bits of the corresponding fp32, so:
//   * the low lane of a packed pair is shifted up by 16 bits,
//   * the high lane only needs its low half masked away.
__forceinline__ __device__ float dot22_8_f(bf162_t (&dq)[4],
                                           const bf16_t* a_ptr) {
  float acc = 0.0f;
#pragma unroll
  for (int pair = 0; pair < 4; ++pair) {
    // Bit-copy each packed pair into a 32-bit word (no type-punned loads).
    uint32_t act_bits;
    uint32_t deq_bits;
    __builtin_memcpy(&act_bits, a_ptr + 2 * pair, sizeof(uint32_t));
    __builtin_memcpy(&deq_bits, &dq[pair], sizeof(uint32_t));

    const float act_lo = __uint_as_float(act_bits << 16);
    const float act_hi = __uint_as_float(act_bits & 0xFFFF0000u);
    const float deq_lo = __uint_as_float(deq_bits << 16);
    const float deq_hi = __uint_as_float(deq_bits & 0xFFFF0000u);

    // Low lane first, then high lane, matching the element order in memory.
    acc = __fmaf_rn(deq_lo, act_lo, acc);
    acc = __fmaf_rn(deq_hi, act_hi, acc);
  }
  return acc;
}
|
||||
|
||||
// fp32-input dot product: paired with dequant_4bit_8_bf16_f32 which already
|
||||
// produces fp32 dq[8]. Saves the bf16→fp32 widening that the bf162_t
|
||||
// overload above does for dq (still need to widen A from bf16). Wins more
|
||||
// at high N: the bf162_t version's per-call widening cost scales with the
|
||||
// number of dequants × M_COUNT × 4 dot calls; the fp32 version pays only
|
||||
// for A widening (M_COUNT × 4 × 4 widens, half as many).
|
||||
// 8-element dot product of fp32 weights against bf16 activations.
//
// dq    : 8 dequantised weights, already in fp32.
// a_ptr : 8 consecutive bf16 activations.
//
// Only the activations need widening here: each packed bf16 pair is bit-copied
// into a 32-bit word, the low lane is shifted into the fp32 upper half and the
// high lane is masked in place. Accumulation is a chain of fp32 FMAs.
__forceinline__ __device__ float dot22_8_f(float (&dq)[8],
                                           const bf16_t* a_ptr) {
  float acc = 0.0f;
#pragma unroll
  for (int pair = 0; pair < 4; ++pair) {
    uint32_t act_bits;
    __builtin_memcpy(&act_bits, a_ptr + 2 * pair, sizeof(uint32_t));

    const float act_lo = __uint_as_float(act_bits << 16);
    const float act_hi = __uint_as_float(act_bits & 0xFFFF0000u);

    acc = __fmaf_rn(dq[2 * pair], act_lo, acc);
    acc = __fmaf_rn(dq[2 * pair + 1], act_hi, acc);
  }
  return acc;
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Packed atomic-add via CAS-loop on a 64-bit word (4 fp16/bf16 lanes per CAS).
|
||||
// RDNA3 (gfx11) does NOT have native v_global_atomic_pk_add_f16 / _bf16 (those
|
||||
// landed on gfx940 / gfx1250 respectively), so this lowers to
|
||||
// global_atomic_cmpswap_b64 plus retry. We use this in the kernel epilogue to
|
||||
// write 4 output columns per row in a single atomic operation — half the
|
||||
// atomic instruction count and half the contention vs two 32-bit CAS calls.
|
||||
//
|
||||
// Writing directly to fp16/bf16 (instead of through an FP32 scratch buffer +
|
||||
// cast pass) saves M*N*4 bytes of allocation, the memset, and the epilogue
|
||||
// cast pass that an fp32-accumulator design would need.
|
||||
//
|
||||
// 64-bit alignment: the kernel writes at `out + n` where n = offset_n + t*4
|
||||
// (always multiple of 4), and partition_weight_shape[1] is required to be a
|
||||
// multiple of 8 by can_implement(), so every (m, n) write target is 8-byte
|
||||
// aligned. Required by global_atomic_cmpswap_b64.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Atomically adds four fp16 values to the four consecutive fp16 slots starting
// at `addr`, using a compare-and-swap retry loop on the enclosing 64-bit word.
//
// addr : first of the four destination elements; it is reinterpreted as an
//        `unsigned long long*`, so it must be 8-byte aligned.
// v01  : addends for elements 0 and 1.
// v23  : addends for elements 2 and 3.
__forceinline__ __device__ void atomic_add_pk4_f16(half* addr, half2 v01,
                                                   half2 v23) {
  // View of one 64-bit word as two packed half2 pairs.
  union Packed64 {
    unsigned long long u;
    half2 h2[2];
  };

  unsigned long long* const word = reinterpret_cast<unsigned long long*>(addr);
  unsigned long long expected = *word;
  for (;;) {
    Packed64 before, after;
    before.u = expected;
    after.h2[0] = __hadd2(before.h2[0], v01);
    after.h2[1] = __hadd2(before.h2[1], v23);

    // The swap only lands if nobody changed the word since `expected` was
    // observed; otherwise retry against the value that is there now.
    const unsigned long long seen = atomicCAS(word, expected, after.u);
    if (seen == expected) return;
    expected = seen;
  }
}
|
||||
|
||||
// Atomically adds four bf16 values to the four consecutive bf16 slots starting
// at `addr`, using a compare-and-swap retry loop on the enclosing 64-bit word.
//
// addr : first of the four destination elements; it is reinterpreted as an
//        `unsigned long long*`, so it must be 8-byte aligned.
// v01  : addends for elements 0 and 1.
// v23  : addends for elements 2 and 3.
__forceinline__ __device__ void atomic_add_pk4_bf16(bf16_t* addr, bf162_t v01,
                                                    bf162_t v23) {
  // View of one 64-bit word as two packed bf16 pairs.
  union Packed64 {
    unsigned long long u;
    bf162_t b2[2];
  };

  unsigned long long* const word = reinterpret_cast<unsigned long long*>(addr);
  unsigned long long expected = *word;
  for (;;) {
    Packed64 before, after;
    before.u = expected;
    after.b2[0] = __hadd2(before.b2[0], v01);
    after.b2[1] = __hadd2(before.b2[1], v23);

    // The swap only lands if nobody changed the word since `expected` was
    // observed; otherwise retry against the value that is there now.
    const unsigned long long seen = atomicCAS(word, expected, after.u);
    if (seen == expected) return;
    expected = seen;
  }
}
|
||||
|
||||
// Load one row's worth of 4 packed zeros (column n..n+3) from a [groups, N/8]
|
||||
// uint32 tensor. n is a multiple of 4 by construction (n = offset_n + t*4 with
|
||||
// offset_n = blockIdx.x * 512), so the 4 nibbles always live within one or two
|
||||
// uint32 words; in practice within one because n & 7 is 0 or 4.
|
||||
__forceinline__ __device__ void load4_zeros(const uint32_t* qzeros_row, int n,
|
||||
int (&zeros)[4]) {
|
||||
int qcol = n / 8;
|
||||
int shift = (n & 0x07) * 4;
|
||||
uint32_t d = qzeros_row[qcol] >> shift;
|
||||
zeros[0] = (int)(d & 0xF);
|
||||
zeros[1] = (int)((d >> 4) & 0xF);
|
||||
zeros[2] = (int)((d >> 8) & 0xF);
|
||||
zeros[3] = (int)((d >> 12) & 0xF);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ void load4_scales(const T* scales_row, int n,
|
||||
T (&scales)[4]) {
|
||||
scales[0] = scales_row[n + 0];
|
||||
scales[1] = scales_row[n + 1];
|
||||
scales[2] = scales_row[n + 2];
|
||||
scales[3] = scales_row[n + 3];
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main kernel.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template <typename T, int M_COUNT>
|
||||
__global__ void gemm_q4_kernel_rdna3(
|
||||
const T* __restrict__ a, const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_qzeros, const T* __restrict__ b_scales,
|
||||
T* __restrict__ c, const int size_m, const int size_n, const int size_k,
|
||||
const int groups, const int zero_offset, const int* __restrict__ b_q_perm) {
|
||||
const int t = threadIdx.x;
|
||||
const int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
const int offset_m = blockIdx.y * M_COUNT;
|
||||
const int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
const int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
const int n = offset_n + t * 4;
|
||||
|
||||
// LDS layout: [M_COUNT][BLOCK_KN_SIZE + LDS_PAD]. The PAD=8 elements per M
|
||||
// row break the natural 256-element/512-byte alignment that would otherwise
|
||||
// collide on the same LDS bank when a thread reads block_a[0..M_COUNT-1][k]
|
||||
// (same k, different m). Row stride becomes 264 elements * 2B = 528B = 132
|
||||
// 4-byte banks, so m-stride hits banks (m*132)%32 = (m*4)%32 — distinct for
|
||||
// all M_COUNT ≤ 8. Cost: 16B LDS per block, irrelevant.
|
||||
constexpr int LDS_PAD = 8;
|
||||
__shared__ T block_a[M_COUNT][BLOCK_KN_SIZE + LDS_PAD];
|
||||
|
||||
// Stage A: each thread loads 1 K element per M row into LDS (with optional
|
||||
// act-order permutation). THREADS_X == BLOCK_KN_SIZE so this is a 1:1 map.
|
||||
// For M_COUNT > 1 with size_m not a multiple of M_COUNT, slots past size_m
|
||||
// are zero-padded so the dot product contribution is 0 (we then skip the
|
||||
// atomic write for those rows below).
|
||||
//
|
||||
// M=1 fast path: skip LDS staging + __syncthreads entirely. All 256 threads
|
||||
// read the SAME 8-element A window per inner step (a_off is uniform across
|
||||
// the block), so the cache-line broadcast through L1 makes global reads as
|
||||
// cheap as LDS reads. Measured: ~1% on 4B b=1, ~6% on 27B b=1 in=128.
|
||||
static_assert(BLOCK_KN_SIZE == THREADS_X,
|
||||
"BLOCK_KN_SIZE must equal THREADS_X (1 K element per thread)");
|
||||
// The M=1 fast path (skip LDS) only has a global-read code path for bf16
|
||||
// (the v_dot2_f32_bf16 branch). The fp16 inner loop still indexes
|
||||
// block_a[m][a_off] unconditionally, so for fp16 we MUST stage A through
|
||||
// LDS even at M=1 to avoid reading uninitialized shared memory.
|
||||
constexpr bool USE_LDS_A = (M_COUNT > 1) || std::is_same<T, half>::value;
|
||||
if constexpr (USE_LDS_A) {
|
||||
if (offset_k + t < end_k) {
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M_COUNT; ++m) {
|
||||
T av;
|
||||
if (offset_m + m < size_m) {
|
||||
const T* a_row = a + (offset_m + m) * size_k;
|
||||
if (b_q_perm)
|
||||
av = a_row[b_q_perm[offset_k + t]];
|
||||
else
|
||||
av = a_row[offset_k + t];
|
||||
} else {
|
||||
av = tzero<T>(); // zero-pad invalid M rows
|
||||
}
|
||||
block_a[m][t] = av;
|
||||
}
|
||||
}
|
||||
|
||||
// Threads beyond the right edge of N have nothing to do. Note: we must NOT
|
||||
// return before __syncthreads() if any thread in the block participates in
|
||||
// the LDS load above — but here all THREADS_X (=256) threads always do,
|
||||
// regardless of whether their `n` is in bounds.
|
||||
__syncthreads();
|
||||
} else if (b_q_perm) {
|
||||
// bf16 M=1 fast path skips LDS, but its global read below is sequential
|
||||
// and cannot apply act-order. When a permutation is present, stage the
|
||||
// single A row through LDS (as fp16 / M>1 do) so the read picks it up.
|
||||
// b_q_perm is block-uniform, so the __syncthreads is non-divergent.
|
||||
if (offset_k + t < end_k)
|
||||
block_a[0][t] = a[offset_m * size_k + b_q_perm[offset_k + t]];
|
||||
__syncthreads();
|
||||
}
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Group bookkeeping. We require size_k % groups == 0 (groupsize divides K).
|
||||
const int groupsize = size_k / groups;
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = (group + 1) * groupsize;
|
||||
|
||||
// qweight stride: weights are [K/8, N] uint32 with K packed at dim 0.
|
||||
int qk = offset_k / 8;
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
// Per-column dequant constants. We hold one set of (z, y) pairs per column.
|
||||
// fp16 uses the exllama (z1z16, y1y16) double-pair to enable the upper-
|
||||
// nibble-*16 trick. bf16 uses fp32 scalars (z, y) because the dequant
|
||||
// produces fp32 directly — see prep_zero_scale_bf16_f32 / the FMA
|
||||
// bypass for the missing v_pk_fma_bf16 on gfx11.
|
||||
half2 z1z16_h[4][2], y1y16_h[4][2];
|
||||
float z_b_f[4], y_b_f[4];
|
||||
|
||||
auto refresh_group = [&](int g) {
|
||||
const uint32_t* qz_row = b_qzeros + g * (size_n / 8);
|
||||
const T* sc_row = b_scales + g * size_n;
|
||||
int zeros[4];
|
||||
T scales[4];
|
||||
load4_zeros(qz_row, n, zeros);
|
||||
load4_scales<T>(sc_row, n, scales);
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
prep_zero_scale_fp16((uint32_t)(zeros[i] + zero_offset), scales[i],
|
||||
z1z16_h[i], y1y16_h[i]);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
prep_zero_scale_bf16_f32((uint32_t)(zeros[i] + zero_offset), scales[i],
|
||||
z_b_f[i], y_b_f[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
refresh_group(group);
|
||||
|
||||
float block_c[M_COUNT][4];
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M_COUNT; ++m) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) block_c[m][j] = 0.0f;
|
||||
}
|
||||
|
||||
// Note on group-transition granularity: we check `k == nextgroup` at the
|
||||
// start of each outer iteration (which advances K by 32). This is correct
|
||||
// when group_size >= 32 OR group_size divides 32 evenly (groupsize is one
|
||||
// of {1,2,4,8,16,32,64,128,...}). For group_size in {16, 8, 4, ...} the
|
||||
// inner loop would cross a group boundary between j-iterations; we require
|
||||
// group_size >= 32 here, mirroring exllama's assumption.
|
||||
//
|
||||
// Software pipelining: we issue all 4 vectorized weight loads up front
|
||||
// before any dequant/FMA depends on them. This gives the AMDGPU backend
|
||||
// freedom to schedule the global_loads early and overlap their latency
|
||||
// with dequant + v_pk_fma_f16 of earlier iterations. Cost: 4×int4 = 16
|
||||
// VGPRs in flight per thread, plenty of headroom on RDNA3.
|
||||
int k = offset_k;
|
||||
while (k < end_k) {
|
||||
if (k == nextgroup) {
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
refresh_group(group);
|
||||
}
|
||||
|
||||
// Prefetch all four j-iterations' weight words. The compiler emits 4
|
||||
// global_load_b128 instructions back-to-back; the dependent dequant +
|
||||
// FMA work below hides their latency.
|
||||
int4 b_w[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
b_w[j] = *(const int4*)(b_ptr + j * size_n);
|
||||
}
|
||||
b_ptr += 4 * size_n;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const int a_off = (k - offset_k) + 8 * j;
|
||||
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8_fp16((uint32_t)b_w[j].x, dq[0], z1z16_h[0], y1y16_h[0]);
|
||||
dequant_4bit_8_fp16((uint32_t)b_w[j].y, dq[1], z1z16_h[1], y1y16_h[1]);
|
||||
dequant_4bit_8_fp16((uint32_t)b_w[j].z, dq[2], z1z16_h[2], y1y16_h[2]);
|
||||
dequant_4bit_8_fp16((uint32_t)b_w[j].w, dq[3], z1z16_h[3], y1y16_h[3]);
|
||||
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M_COUNT; ++m) {
|
||||
const half* a_ptr = reinterpret_cast<const half*>(&block_a[m][a_off]);
|
||||
block_c[m][0] += dot22_8_f(dq[0], a_ptr);
|
||||
block_c[m][1] += dot22_8_f(dq[1], a_ptr);
|
||||
block_c[m][2] += dot22_8_f(dq[2], a_ptr);
|
||||
block_c[m][3] += dot22_8_f(dq[3], a_ptr);
|
||||
}
|
||||
} else if constexpr (M_COUNT == 1) {
|
||||
// bf16 decode (M=1), v_dot2_f32_bf16 path. Mirrors the data-flow of
|
||||
// Hybrid PR #40977's wvSplitK_int4 kernel exactly so clang's
|
||||
// InstCombine cannot fold the bf16→fp32 widening (LLVM #76000):
|
||||
// * activations and magic-value weights share a fp32-aliased
|
||||
// union (bytes written as uint32, read as bf16x2_t for the
|
||||
// dot — pointer-cast opacity defeats the fold)
|
||||
// * sum_a computed via a *second* v_dot2 with bf162(1,1) as the
|
||||
// second operand, avoiding any explicit bf16→fp32 widen of A
|
||||
// * bias correction y_b_f * partial + z_b_f * sum_a, identical
|
||||
// to the previous fp32-FMA-chain path
|
||||
//
|
||||
// Net: 20 v_dot2_f32_bf16 + 8 fp32 FMA per int32 weight vs the
|
||||
// previous 40 fp32 FMA. v_dot2 runs at full rate on gfx1100, so
|
||||
// the substitution is ~2× cheaper for the inner accumulator.
|
||||
typedef short __attribute__((ext_vector_type(2))) bf16x2_t;
|
||||
constexpr uint32_t BF16_MAGIC = 0x43004300u; // bf162(128, 128)
|
||||
constexpr uint32_t BF16_ONES = 0x3F803F80u; // bf162(1.0, 1.0)
|
||||
union pack4 {
|
||||
float f[4];
|
||||
uint32_t u[4];
|
||||
};
|
||||
|
||||
uint32_t w[4];
|
||||
__builtin_memcpy(w, &b_w[j], sizeof(int4));
|
||||
|
||||
// Load 8 bf16 activations as 4 uint32s (= 4 bf16x2 pairs) into a
|
||||
// fp32-aliased union. Storing as uint32 keeps the IR-level type
|
||||
// opaque so the inner v_dot2 cannot be folded to fp32 widening.
|
||||
//
|
||||
// A is read direct from global (no LDS staging — see USE_LDS_A above),
|
||||
// except under act-order, where it comes from the permuted LDS copy.
|
||||
pack4 a_pack;
|
||||
{
|
||||
const uint32_t* a_words =
|
||||
b_q_perm
|
||||
? reinterpret_cast<const uint32_t*>(&block_a[0][a_off])
|
||||
: reinterpret_cast<const uint32_t*>(a + offset_k + a_off);
|
||||
a_pack.u[0] = a_words[0];
|
||||
a_pack.u[1] = a_words[1];
|
||||
a_pack.u[2] = a_words[2];
|
||||
a_pack.u[3] = a_words[3];
|
||||
}
|
||||
|
||||
// sum_a = Σ a[i]. Computed via 4× v_dot2_f32_bf16 with bf162(1,1) as
|
||||
// the second operand — every bf16 pair contributes 1·a_lo + 1·a_hi.
|
||||
// No fp32 widening of activations: the bytes go straight from LDS
|
||||
// through v_dot2 into the fp32 accumulator.
|
||||
float sum_a = 0.0f;
|
||||
#pragma unroll
|
||||
for (int b = 0; b < 4; ++b) {
|
||||
sum_a = __builtin_amdgcn_fdot2_f32_bf16(
|
||||
*((bf16x2_t*)(&a_pack.f[b])), *((const bf16x2_t*)&BF16_ONES),
|
||||
sum_a, /*clamp=*/false);
|
||||
}
|
||||
|
||||
// unroll 1 keeps q_pack alive only one col at a time (8 fp32 VGPRs
|
||||
// recycled across cols), avoiding straight-line expansion that
|
||||
// would inflate live-range to 32 VGPRs.
|
||||
#pragma unroll 1
|
||||
for (int col = 0; col < 4; ++col) {
|
||||
// Build dequant magic values bf16(128 + nibble) directly into a
|
||||
// fp32-aliased union via uint32 stores. No fp32 in the data flow
|
||||
// until v_dot2 consumes the bytes.
|
||||
pack4 q_pack;
|
||||
const uint32_t qa = w[col];
|
||||
q_pack.u[0] = ((qa >> 0) & 0x000F000Fu) | BF16_MAGIC;
|
||||
q_pack.u[1] = ((qa >> 4) & 0x000F000Fu) | BF16_MAGIC;
|
||||
q_pack.u[2] = ((qa >> 8) & 0x000F000Fu) | BF16_MAGIC;
|
||||
q_pack.u[3] = ((qa >> 12) & 0x000F000Fu) | BF16_MAGIC;
|
||||
|
||||
// partial = Σ (128 + nibble[i]) · a[i], via 4× v_dot2_f32_bf16.
|
||||
float partial = 0.0f;
|
||||
#pragma unroll
|
||||
for (int b = 0; b < 4; ++b) {
|
||||
partial = __builtin_amdgcn_fdot2_f32_bf16(
|
||||
*((bf16x2_t*)(&a_pack.f[b])), *((bf16x2_t*)(&q_pack.f[b])),
|
||||
partial, /*clamp=*/false);
|
||||
}
|
||||
|
||||
// block_c += y_b_f * partial + z_b_f * sum_a
|
||||
// y_b_f = scale, z_b_f = -(128+zero)*scale
|
||||
// partial holds (128 + nibble) · a; subtracting (128+zero)·sum_a
|
||||
// and scaling yields scale · (nibble - zero) · a as required.
|
||||
block_c[0][col] =
|
||||
__fmaf_rn(y_b_f[col], partial,
|
||||
__fmaf_rn(z_b_f[col], sum_a, block_c[0][col]));
|
||||
}
|
||||
} else {
|
||||
// bf16 M_COUNT > 1 path with v_dot2_f32_bf16. Same opacity trick as
|
||||
// the M=1 branch: activations + magic-value weights stored in
|
||||
// fp32-aliased unions, dot via __builtin_amdgcn_fdot2_f32_bf16 with
|
||||
// pointer-cast to bf16x2_t. sum_a[m] computed via second v_dot2
|
||||
// with BF16_ONES; bias correction (y_b_f * partial + z_b_f * sum_a)
|
||||
// applied after the dot. Magic values built once per col and reused
|
||||
// across all M rows — amortizes dequant cost across M_COUNT.
|
||||
typedef short __attribute__((ext_vector_type(2))) bf16x2_t;
|
||||
constexpr uint32_t BF16_MAGIC = 0x43004300u; // bf162(128, 128)
|
||||
constexpr uint32_t BF16_ONES = 0x3F803F80u; // bf162(1.0, 1.0)
|
||||
union pack4 {
|
||||
float f[4];
|
||||
uint32_t u[4];
|
||||
};
|
||||
|
||||
uint32_t w[4];
|
||||
__builtin_memcpy(w, &b_w[j], sizeof(int4));
|
||||
|
||||
// Load M_COUNT × 8 bf16 activations as 4 uint32s each into pack4
|
||||
// unions. Stored as uint32 to keep IR-level types opaque (defeats
|
||||
// InstCombine fold). At M_COUNT=8 this is 32 fp32 VGPRs — within RDNA3
|
||||
// budget.
|
||||
pack4 a_pack[M_COUNT];
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M_COUNT; ++m) {
|
||||
const uint32_t* a_words =
|
||||
reinterpret_cast<const uint32_t*>(&block_a[m][a_off]);
|
||||
a_pack[m].u[0] = a_words[0];
|
||||
a_pack[m].u[1] = a_words[1];
|
||||
a_pack[m].u[2] = a_words[2];
|
||||
a_pack[m].u[3] = a_words[3];
|
||||
}
|
||||
|
||||
// sum_a[m] = Σ a[m][i] via 4× v_dot2 with bf162(1,1) — no fp32 widen.
|
||||
float sum_a[M_COUNT];
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M_COUNT; ++m) {
|
||||
float s = 0.0f;
|
||||
#pragma unroll
|
||||
for (int b = 0; b < 4; ++b) {
|
||||
s = __builtin_amdgcn_fdot2_f32_bf16(*((bf16x2_t*)(&a_pack[m].f[b])),
|
||||
*((const bf16x2_t*)&BF16_ONES),
|
||||
s, /*clamp=*/false);
|
||||
}
|
||||
sum_a[m] = s;
|
||||
}
|
||||
|
||||
// Per col: build magic-value pack, dot against all M activations.
|
||||
// unroll 1 keeps q_pack live one col at a time (8 fp32 VGPRs recycled)
|
||||
// — same register-pressure trick as the previous fp32 path.
|
||||
#pragma unroll 1
|
||||
for (int col = 0; col < 4; ++col) {
|
||||
pack4 q_pack;
|
||||
const uint32_t qa = w[col];
|
||||
q_pack.u[0] = ((qa >> 0) & 0x000F000Fu) | BF16_MAGIC;
|
||||
q_pack.u[1] = ((qa >> 4) & 0x000F000Fu) | BF16_MAGIC;
|
||||
q_pack.u[2] = ((qa >> 8) & 0x000F000Fu) | BF16_MAGIC;
|
||||
q_pack.u[3] = ((qa >> 12) & 0x000F000Fu) | BF16_MAGIC;
|
||||
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M_COUNT; ++m) {
|
||||
float partial = 0.0f;
|
||||
#pragma unroll
|
||||
for (int b = 0; b < 4; ++b) {
|
||||
partial = __builtin_amdgcn_fdot2_f32_bf16(
|
||||
*((bf16x2_t*)(&a_pack[m].f[b])), *((bf16x2_t*)(&q_pack.f[b])),
|
||||
partial, /*clamp=*/false);
|
||||
}
|
||||
// block_c += y_b_f * partial + z_b_f * sum_a (same correction as
|
||||
// M=1)
|
||||
block_c[m][col] =
|
||||
__fmaf_rn(y_b_f[col], partial,
|
||||
__fmaf_rn(z_b_f[col], sum_a[m], block_c[m][col]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
k += 32; // 4 weight words * 8 nibbles = 32 K elements
|
||||
}
|
||||
|
||||
// Pack the 4 FP32 partial sums into 2 packed pairs and atomically add all
|
||||
// four lanes in a single 64-bit CAS write directly to the T-typed output
|
||||
// (caller pre-zeros it). On gfx11 the packed atomic is a CAS-loop, but with
|
||||
// a single b64 op we halve the atomic instruction count vs two b32 CAS
|
||||
// calls, AND save the FP32 buffer + memset + cast pass entirely.
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M_COUNT; ++m) {
|
||||
if (offset_m + m >= size_m) continue; // skip padding rows past size_m
|
||||
T* out = c + (offset_m + m) * size_n + n;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
half2 r01 = __halves2half2(__float2half_rn(block_c[m][0]),
|
||||
__float2half_rn(block_c[m][1]));
|
||||
half2 r23 = __halves2half2(__float2half_rn(block_c[m][2]),
|
||||
__float2half_rn(block_c[m][3]));
|
||||
atomic_add_pk4_f16(out, r01, r23);
|
||||
} else {
|
||||
bf162_t r01;
|
||||
r01.x = __float2bfloat16(block_c[m][0]);
|
||||
r01.y = __float2bfloat16(block_c[m][1]);
|
||||
bf162_t r23;
|
||||
r23.x = __float2bfloat16(block_c[m][2]);
|
||||
r23.y = __float2bfloat16(block_c[m][3]);
|
||||
atomic_add_pk4_bf16(out, r01, r23);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#else // non-RDNA3 device pass: empty __global__ for symbol parity.
|
||||
|
||||
template <typename T, int M_COUNT>
|
||||
__global__ void gemm_q4_kernel_rdna3(const T*, const uint32_t*, const uint32_t*,
|
||||
const T*, T*, const int, const int,
|
||||
const int, const int, const int,
|
||||
const int*) {}
|
||||
|
||||
#endif // __HIP__RDNA3__ || !__HIP_DEVICE_COMPILE__
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Launcher.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template <typename T, int M_COUNT>
|
||||
void launch_gemm_q4_for_mcount(const T* a, const uint32_t* b_q_weight,
|
||||
const uint32_t* b_qzeros, const T* b_scales,
|
||||
const int* b_q_perm, T* c, int size_m,
|
||||
int size_n, int size_k, int groups,
|
||||
int zero_offset, cudaStream_t stream) {
|
||||
dim3 block(THREADS_X);
|
||||
dim3 grid((size_n + BLOCK_KN_SIZE * 4 - 1) / (BLOCK_KN_SIZE * 4),
|
||||
(size_m + M_COUNT - 1) / M_COUNT,
|
||||
(size_k + BLOCK_KN_SIZE - 1) / BLOCK_KN_SIZE);
|
||||
|
||||
gemm_q4_kernel_rdna3<T, M_COUNT><<<grid, block, 0, stream>>>(
|
||||
a, b_q_weight, b_qzeros, b_scales, c, size_m, size_n, size_k, groups,
|
||||
zero_offset, b_q_perm);
|
||||
}
|
||||
|
||||
// Dispatch to the largest M_COUNT template that doesn't waste more than
|
||||
// half a tile. Caps at 8: above that, the WMMA-prefill kernel (M >= 16) is
|
||||
// the right tool, not bigger M_COUNT in the scalar dot-product path.
|
||||
//
|
||||
// Tile-waste table:
|
||||
// M=1 -> M_COUNT=1 (no waste)
|
||||
// M=2,3 -> M_COUNT=2 (M=3 wastes 1/2 of last tile)
|
||||
// M=4-7 -> M_COUNT=4 (worst case M=5: wastes 3/4 of last tile)
|
||||
// M=8-15-> M_COUNT=8 (worst case M=9: wastes 7/8 of last tile)
|
||||
// "Wasted" rows are zero-padded in LDS and skip the atomic write, so they
|
||||
// only burn instructions on the last block, never affect correctness.
|
||||
template <typename T>
|
||||
void launch_gemm_q4(const T* a, const uint32_t* b_q_weight,
|
||||
const uint32_t* b_qzeros, const T* b_scales,
|
||||
const int* b_q_perm, T* c, int size_m, int size_n,
|
||||
int size_k, int groups, bool use_v2_format,
|
||||
cudaStream_t stream) {
|
||||
const int zero_offset = use_v2_format ? 0 : 1;
|
||||
|
||||
if (size_m == 1) {
|
||||
launch_gemm_q4_for_mcount<T, 1>(a, b_q_weight, b_qzeros, b_scales, b_q_perm,
|
||||
c, size_m, size_n, size_k, groups,
|
||||
zero_offset, stream);
|
||||
} else if (size_m <= 3) {
|
||||
launch_gemm_q4_for_mcount<T, 2>(a, b_q_weight, b_qzeros, b_scales, b_q_perm,
|
||||
c, size_m, size_n, size_k, groups,
|
||||
zero_offset, stream);
|
||||
} else if (size_m <= 7) {
|
||||
launch_gemm_q4_for_mcount<T, 4>(a, b_q_weight, b_qzeros, b_scales, b_q_perm,
|
||||
c, size_m, size_n, size_k, groups,
|
||||
zero_offset, stream);
|
||||
} else {
|
||||
// M_COUNT=8 covers M up to 15 here; M >= 16 should ideally take the
|
||||
// WMMA path, but if it falls through we still produce correct output —
|
||||
// just leaving 3-5× of throughput on the table for prefill workloads.
|
||||
launch_gemm_q4_for_mcount<T, 8>(a, b_q_weight, b_qzeros, b_scales, b_q_perm,
|
||||
c, size_m, size_n, size_k, groups,
|
||||
zero_offset, stream);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gptq_rdna3
|
||||
} // namespace vllm
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public entry point.
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// Inputs:
|
||||
// a [M, K] half or bfloat16
|
||||
// b_q_weight[K/8, N] uint32 (already shuffled via gptq_shuffle)
|
||||
// b_qzeros [groups, N/8] uint32 (packed 4-bit zeros)
|
||||
// b_scales [groups, N] half or bfloat16
|
||||
// b_g_idx [K] or empty int32 (act-order permutation; empty=identity)
|
||||
// use_v2_format bool (true = GPTQv2, no +1 zero offset)
|
||||
//
|
||||
// Output:
|
||||
// c [M, N] same dtype as a
|
||||
|
||||
torch::Tensor gptq_gemm_rdna3_wmma(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_qzeros,
|
||||
torch::Tensor b_scales,
|
||||
torch::Tensor b_g_idx, bool use_v2_format);
|
||||
|
||||
// W4A16 GPTQ GEMM for RDNA3: returns c[M, N] = a[M, K] @ dequant(b_q_weight).
//
// Prefill-sized inputs (bf16 with M >= 16, fp16 with M >= 64, and both K and
// N multiples of 16) are forwarded to gptq_gemm_rdna3_wmma. Everything else
// runs the scalar kernel, which adds partial results into a zero-initialised
// output with atomics.
//
// Expected layouts:
//   a           [M, K]         half or bfloat16
//   b_q_weight  [K/8, N]       32-bit packed words, already shuffled
//   b_qzeros    [groups, N/8]  32-bit packed 4-bit zeros
//   b_scales    [groups, N]    same dtype as a
//   b_g_idx     [K] or empty   int32 act-order permutation (empty = identity)
//   use_v2_format              true = GPTQv2 (no +1 on the stored zero)
torch::Tensor gptq_gemm_rdna3(torch::Tensor a, torch::Tensor b_q_weight,
                              torch::Tensor b_qzeros, torch::Tensor b_scales,
                              torch::Tensor b_g_idx, bool use_v2_format) {
  if (a.dim() == 2 && b_q_weight.dim() == 2 && a.size(1) % 16 == 0 &&
      b_q_weight.size(1) % 16 == 0 &&
      ((a.scalar_type() == torch::kBFloat16 && a.size(0) >= 16) ||
       (a.scalar_type() == torch::kHalf && a.size(0) >= 64))) {
    return gptq_gemm_rdna3_wmma(a, b_q_weight, b_qzeros, b_scales, b_g_idx,
                                use_v2_format);
  }

  TORCH_CHECK(a.is_cuda(), "a must be a CUDA/HIP tensor");
  TORCH_CHECK(b_q_weight.is_cuda(), "b_q_weight must be a CUDA/HIP tensor");
  TORCH_CHECK(b_qzeros.is_cuda(), "b_qzeros must be a CUDA/HIP tensor");
  TORCH_CHECK(b_scales.is_cuda(), "b_scales must be a CUDA/HIP tensor");
  TORCH_CHECK(a.dim() == 2, "a must be 2D [M, K]");
  TORCH_CHECK(b_q_weight.dim() == 2, "b_q_weight must be 2D [K/8, N]");
  TORCH_CHECK(b_qzeros.dim() == 2, "b_qzeros must be 2D [groups, N/8]");
  TORCH_CHECK(b_scales.dim() == 2, "b_scales must be 2D [groups, N]");
  TORCH_CHECK(
      a.scalar_type() == torch::kHalf || a.scalar_type() == torch::kBFloat16,
      "a must be half or bfloat16");
  TORCH_CHECK(a.scalar_type() == b_scales.scalar_type(),
              "b_scales dtype must match a");
  // The kernel reinterprets these buffers as uint32_t words.
  TORCH_CHECK(b_q_weight.element_size() == 4,
              "b_q_weight must hold 32-bit packed words");
  TORCH_CHECK(b_qzeros.element_size() == 4,
              "b_qzeros must hold 32-bit packed words");
  // Raw data pointers are handed to the kernel, so strides are ignored there.
  // Activations are cheap to compact; weights are rejected instead of being
  // silently copied on every call.
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight must be contiguous");
  TORCH_CHECK(b_qzeros.is_contiguous(), "b_qzeros must be contiguous");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales must be contiguous");
  a = a.contiguous();

  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto stream = at::cuda::getCurrentCUDAStream();

  const int size_m = static_cast<int>(a.size(0));
  const int size_k = static_cast<int>(a.size(1));
  const int size_n = static_cast<int>(b_q_weight.size(1));
  const int groups = static_cast<int>(b_qzeros.size(0));

  TORCH_CHECK(b_q_weight.size(0) * 8 == size_k,
              "b_q_weight first dim must be K/8");
  TORCH_CHECK(b_scales.size(0) == groups,
              "b_scales must have same group count as qzeros");
  TORCH_CHECK(b_scales.size(1) == size_n, "b_scales last dim must be N");
  TORCH_CHECK(size_n % 8 == 0, "N must be a multiple of 8 (64-bit atomic CAS)");
  TORCH_CHECK(b_qzeros.size(1) * 8 == size_n, "b_qzeros last dim must be N/8");

  // Must be zero-filled: the kernel accumulates into it.
  auto opts = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  at::Tensor c = torch::zeros({size_m, size_n}, opts);

  // Nothing to compute. Returning here also avoids launching a kernel with a
  // zero-sized grid dimension.
  if (size_m == 0 || size_n == 0 || size_k == 0) {
    return c;
  }
  TORCH_CHECK(groups > 0, "b_qzeros must contain at least one group");

  const int* g_idx_ptr = nullptr;
  if (!b_g_idx.device().is_meta() && b_g_idx.numel() > 0) {
    TORCH_CHECK(b_g_idx.is_cuda(), "b_g_idx must be a CUDA/HIP tensor");
    TORCH_CHECK(b_g_idx.scalar_type() == torch::kInt32,
                "b_g_idx must be int32");
    g_idx_ptr = static_cast<const int*>(b_g_idx.data_ptr());
  }

  const auto* weight_ptr = static_cast<const uint32_t*>(b_q_weight.data_ptr());
  const auto* zeros_ptr = static_cast<const uint32_t*>(b_qzeros.data_ptr());

  if (a.scalar_type() == torch::kHalf) {
    vllm::gptq_rdna3::launch_gemm_q4<half>(
        static_cast<const half*>(a.data_ptr()), weight_ptr, zeros_ptr,
        static_cast<const half*>(b_scales.data_ptr()), g_idx_ptr,
        static_cast<half*>(c.data_ptr()), size_m, size_n, size_k, groups,
        use_v2_format, stream);
  } else {
    using bf16_t = vllm::gptq_rdna3::bf16_t;
    vllm::gptq_rdna3::launch_gemm_q4<bf16_t>(
        static_cast<const bf16_t*>(a.data_ptr()), weight_ptr, zeros_ptr,
        static_cast<const bf16_t*>(b_scales.data_ptr()), g_idx_ptr,
        static_cast<bf16_t*>(c.data_ptr()), size_m, size_n, size_k, groups,
        use_v2_format, stream);
  }

  return c;
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,239 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
//
|
||||
// W4A16 dequant primitives for RDNA3 (gfx1100/gfx1101/gfx1102), templated on
|
||||
// the activation/scale dtype (half or __hip_bfloat16). The fp16 path reuses
|
||||
// the classic exllamav2 bit-trick:
|
||||
//
|
||||
// (qa & 0x000F000F) | 0x64006400 -> half2(1024+q_lo, 1024+q_hi)
|
||||
// (qa & 0x00F000F0) | 0x64006400 -> half2(1024+q_lo*16, 1024+q_hi*16)
|
||||
//
|
||||
// The "*16 then divide by 16 in the FMA" trick for the upper-nibble pairs
|
||||
// works in fp16 because the mantissa (10 bits) is wide enough to hold a value
|
||||
// shifted by 4 bits. In bf16 the mantissa is only 7 bits, so shifting an upper
|
||||
// nibble into bits [7:4] would spill into the exponent. To avoid that, the
|
||||
// bf16 path shifts each pair of nibbles down to bits [3:0]/[19:16] with a
|
||||
// single right-shift before the OR with 0x43004300 (= bf162(128, 128)).
|
||||
|
||||
#ifndef _qdq_4_rdna3_cuh
|
||||
#define _qdq_4_rdna3_cuh
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq_rdna3 {
|
||||
|
||||
using bf16_t = __hip_bfloat16;
|
||||
using bf162_t = __hip_bfloat162;
|
||||
|
||||
// Reorders the eight 4-bit weights packed in q[0] so that even-indexed and
// odd-indexed weights live in separate 16-bit halves:
//   before (MSB..LSB): q[7] q[6] q[5] q[4] q[3] q[2] q[1] q[0]
//   after  (MSB..LSB): q[7] q[5] q[3] q[1] q[6] q[4] q[2] q[0]
//
// Afterwards q[2k]   occupies bits [4k      : 4k+3]      (lower half)
//            q[2k+1] occupies bits [16 + 4k : 16 + 4k+3] (upper half)
// so one mask such as 0x000F000F picks an (even, odd) pair at matching
// positions of the two halves, ready to be OR-ed with a magic constant and
// reinterpreted as half2 / bfloat162.
__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q) {
  const uint32_t packed = q[0];
  uint32_t even_half = 0;
  uint32_t odd_half = 0;
#pragma unroll
  for (int pair = 0; pair < 4; ++pair) {
    // Each input byte carries one even weight (low nibble) and the odd
    // weight that follows it (high nibble).
    const uint32_t byte = (packed >> (8 * pair)) & 0xFFu;
    even_half |= (byte & 0x0Fu) << (4 * pair);
    odd_half |= (byte >> 4) << (4 * pair);
  }
  q[0] = even_half | (odd_half << 16);
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// fp16 path
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Precompute scale-baked constants for a single zero/scale pair.
|
||||
// z1z16[0] = scale * (-1024 - zero) (used for "low" pairs)
|
||||
// z1z16[1] = scale * (-64 - zero) (used for "high" pairs)
|
||||
// y1y16[0] = scale * 1 (low pairs are q + 1024)
|
||||
// y1y16[1] = scale * (1/16) (high pairs are q*16 + 1024)
|
||||
__forceinline__ __device__ void prep_zero_scale_fp16(uint32_t zero, half scale,
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2]) {
|
||||
// half(-1024 - zero) via the exllamav2 bit-trick:
|
||||
// half bits 0xE400 == -1024.0 ; ORing the zero into mantissa subtracts it.
|
||||
union {
|
||||
uint16_t u;
|
||||
half h;
|
||||
} z1u;
|
||||
z1u.u = (uint16_t)(0xE400 | zero);
|
||||
half z1 = z1u.h;
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn((int)zero));
|
||||
|
||||
half2 scale2 = __half2half2(scale);
|
||||
z1z16[0] = __hmul2(scale2, __half2half2(z1));
|
||||
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||
|
||||
half y1 = __float2half_rn(1.0f);
|
||||
half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||
}
|
||||
|
||||
// Dequantize one int32 (8 shuffled 4-bit weights) into 4 half2 pairs:
|
||||
// dq[0] = (q[0], q[1]) * scale - zero*scale
|
||||
// dq[1] = (q[2], q[3]) * scale - zero*scale
|
||||
// dq[2] = (q[4], q[5]) * scale - zero*scale
|
||||
// dq[3] = (q[6], q[7]) * scale - zero*scale
|
||||
__forceinline__ __device__ void dequant_4bit_8_fp16(uint32_t qa, half2 (&dq)[4],
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2]) {
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
union {
|
||||
uint32_t u;
|
||||
half2 h2;
|
||||
} q0, q1, q2, q3;
|
||||
q0.u = (qa & 0x000F000F) | c0; // half2(q[0]+1024, q[1]+1024)
|
||||
q1.u = (qa & 0x00F000F0) | c0; // half2(q[2]*16+1024, q[3]*16+1024)
|
||||
uint32_t qa_hi = qa >> 8;
|
||||
q2.u = (qa_hi & 0x000F000F) | c0; // half2(q[4]+1024, q[5]+1024)
|
||||
q3.u = (qa_hi & 0x00F000F0) | c0; // half2(q[6]*16+1024, q[7]*16+1024)
|
||||
|
||||
dq[0] = __hfma2(q0.h2, y1y16[0], z1z16[0]);
|
||||
dq[1] = __hfma2(q1.h2, y1y16[1], z1z16[1]);
|
||||
dq[2] = __hfma2(q2.h2, y1y16[0], z1z16[0]);
|
||||
dq[3] = __hfma2(q3.h2, y1y16[1], z1z16[1]);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// bf16 path
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Bit-trick magic for bf16:
|
||||
// bf16(128) == 0x4300 (sign 0, exp 134, mantissa 0).
|
||||
// For nibble n in [0..15], bits [3:0] of mantissa hold n exactly because
|
||||
// bf16's ULP at 128 is 1 (mantissa step = 2^(7-7) = 1). So
|
||||
// ((qa & 0x000F000F) | 0x43004300) bitcasts to bfloat162(128+n_lo, 128+n_hi).
|
||||
//
|
||||
// Because bf16's mantissa is only 7 bits, we cannot use the fp16 "upper nibble
|
||||
// * 16" trick. Instead each pair of nibbles is shifted down to [3:0]/[19:16]
|
||||
// via a single 4/8/12-bit right-shift before the OR. That costs one extra
|
||||
// shift per pair vs fp16, but keeps the FMA structure identical.
|
||||
__forceinline__ __device__ void prep_zero_scale_bf16(uint32_t zero,
|
||||
bf16_t scale,
|
||||
bf162_t& z_prep,
|
||||
bf162_t& y_prep) {
|
||||
// z = scale * -(128 + zero); y = scale.
|
||||
float scale_f = __bfloat162float(scale);
|
||||
float zf = -(128.0f + (float)zero) * scale_f;
|
||||
bf16_t zb = __float2bfloat16(zf);
|
||||
z_prep = __bfloat162bfloat162(zb);
|
||||
y_prep = __bfloat162bfloat162(scale);
|
||||
}
|
||||
|
||||
// Expands one shuffled 32-bit word into four bfloat162 pairs of dequantized
// weights: dq[i] = (q[2i], q[2i+1]), each element equal to (q - zero) * scale.
// z_prep / y_prep must come from prep_zero_scale_bf16.
__forceinline__ __device__ void dequant_4bit_8_bf16(uint32_t qa,
                                                    bf162_t (&dq)[4],
                                                    bf162_t z_prep,
                                                    bf162_t y_prep) {
  const uint32_t bias_128 = 0x43004300;  // bit pattern of bfloat162(128, 128)
  const uint32_t nibble_pair = 0x000F000F;

#pragma unroll
  for (int i = 0; i < 4; ++i) {
    // Shifting by 4*i moves pair i's nibbles to bits [3:0] and [19:16].
    union {
      uint32_t bits;
      bf162_t pair;
    } biased;
    biased.bits = ((qa >> (4 * i)) & nibble_pair) | bias_128;

    // (128 + q) * scale + (-(128 + zero) * scale) == (q - zero) * scale
    dq[i] = __hfma2(biased.pair, y_prep, z_prep);
  }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// bf16-input → fp32-output dequant (RDNA3 scalar path).
|
||||
//
|
||||
// RDNA3 (gfx1100) has no v_pk_fma_bf16; packed bf16 FMA lowers to a slow
|
||||
// fallback. Rather than computing dq in bf16 and widening at FMA time in
|
||||
// the dot product, we widen to fp32 here once (a free left-shift by 16) and
|
||||
// emit the (q - zero) * scale FMA directly in fp32. This:
|
||||
// * Replaces 4× slow bf16 packed FMA with 8× fast fp32 FMA per int32.
|
||||
// * Eliminates 4× bf16→fp32 widens that the dot product would do.
|
||||
// * Keeps the dot product accumulator in fp32 without a roundtrip.
|
||||
//
|
||||
// Output: fp32 dq[8], one element per K position (consumed by the
|
||||
// fp32-overload of dot22_8_f in q_gemm_rdna3.cu).
|
||||
// fp32 counterpart of prep_zero_scale_bf16: widens the bf16 scale once and
// returns the FMA operands used by dequant_4bit_8_bf16_f32.
//   z_prep = -(128 + zero) * scale
//   y_prep = scale
__forceinline__ __device__ void prep_zero_scale_bf16_f32(uint32_t zero,
                                                         bf16_t scale,
                                                         float& z_prep,
                                                         float& y_prep) {
  const float widened_scale = __bfloat162float(scale);
  const float negated_bias = -(128.0f + (float)zero);
  z_prep = negated_bias * widened_scale;
  y_prep = widened_scale;
}
|
||||
|
||||
// Pure-q dequant for the M_COUNT=1 factored path: outputs the unscaled fp32
|
||||
// values 128+nibble, without folding scale/zero. The caller folds scale/zb
|
||||
// into the accumulator outside the inner loop using a precomputed sum_a,
|
||||
// which saves ~27% of the FMA count vs the per-col-dequant approach above
|
||||
// (only beneficial at M_COUNT=1; break-even at M_COUNT=2).
|
||||
//
|
||||
// Cost: 0 FMAs (pure bit-trick + as_float reinterprets).
|
||||
__forceinline__ __device__ void dequant_4bit_8_bf16_q_only(uint32_t qa,
|
||||
float (&q_f32)[8]) {
|
||||
const uint32_t c0 = 0x43004300;
|
||||
const uint32_t q0 = ((qa >> 0) & 0x000F000F) | c0;
|
||||
const uint32_t q1 = ((qa >> 4) & 0x000F000F) | c0;
|
||||
const uint32_t q2 = ((qa >> 8) & 0x000F000F) | c0;
|
||||
const uint32_t q3 = ((qa >> 12) & 0x000F000F) | c0;
|
||||
q_f32[0] = __uint_as_float((q0 & 0xFFFFu) << 16);
|
||||
q_f32[1] = __uint_as_float(q0 & 0xFFFF0000u);
|
||||
q_f32[2] = __uint_as_float((q1 & 0xFFFFu) << 16);
|
||||
q_f32[3] = __uint_as_float(q1 & 0xFFFF0000u);
|
||||
q_f32[4] = __uint_as_float((q2 & 0xFFFFu) << 16);
|
||||
q_f32[5] = __uint_as_float(q2 & 0xFFFF0000u);
|
||||
q_f32[6] = __uint_as_float((q3 & 0xFFFFu) << 16);
|
||||
q_f32[7] = __uint_as_float(q3 & 0xFFFF0000u);
|
||||
}
|
||||
|
||||
// Expands one shuffled 32-bit word into eight fp32 dequantized weights,
// dq[i] = (q[i] - zero) * scale, using one fp32 FMA per weight.
// z_prep / y_prep must come from prep_zero_scale_bf16_f32.
__forceinline__ __device__ void dequant_4bit_8_bf16_f32(uint32_t qa,
                                                        float (&dq)[8],
                                                        float z_prep,
                                                        float y_prep) {
  const uint32_t bias_128 = 0x43004300;  // bit pattern of bfloat162(128, 128)

#pragma unroll
  for (int i = 0; i < 4; ++i) {
    const uint32_t biased = ((qa >> (4 * i)) & 0x000F000F) | bias_128;

    // bf16(128 + q) -> fp32(128 + q): put the 16 bf16 bits in the top half
    // of the word; the extra mantissa bits are zero and the exponent is
    // unchanged.
    const float even_weight = __uint_as_float(biased << 16);
    const float odd_weight = __uint_as_float(biased & 0xFFFF0000u);

    // (128 + q) * scale + (-(128 + zero) * scale) == (q - zero) * scale
    dq[2 * i] = __fmaf_rn(even_weight, y_prep, z_prep);
    dq[2 * i + 1] = __fmaf_rn(odd_weight, y_prep, z_prep);
  }
}
|
||||
|
||||
} // namespace gptq_rdna3
|
||||
} // namespace vllm
|
||||
|
||||
#endif // _qdq_4_rdna3_cuh
|
||||
@@ -39,6 +39,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
" Tensor scale_b, int CuCount) -> ()");
|
||||
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
|
||||
|
||||
#ifdef VLLM_ROCM_GFX1100
|
||||
// W4A16 GPTQ kernels for AMD RDNA3 (gfx1100).
|
||||
rocm_ops.def(
|
||||
"gptq_gemm_rdna3(Tensor a, Tensor b_q_weight, Tensor b_qzeros, "
|
||||
"Tensor b_scales, Tensor b_g_idx, bool use_v2_format) -> Tensor");
|
||||
rocm_ops.impl("gptq_gemm_rdna3", torch::kCUDA, &gptq_gemm_rdna3);
|
||||
|
||||
rocm_ops.def(
|
||||
"gptq_gemm_rdna3_wmma(Tensor a, Tensor b_q_weight, Tensor b_qzeros, "
|
||||
"Tensor b_scales, Tensor b_g_idx, bool use_v2_format) -> Tensor");
|
||||
rocm_ops.impl("gptq_gemm_rdna3_wmma", torch::kCUDA, &gptq_gemm_rdna3_wmma);
|
||||
#endif
|
||||
|
||||
// Custom attention op
|
||||
// Compute the attention between an input query and the cached
|
||||
// keys/values using PagedAttention.
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Correctness tests for the ROCm RDNA3 W4A16 GPTQ kernel (gfx1100).
|
||||
|
||||
Exercises ``RDNA3W4A16LinearKernel`` end-to-end: it builds a layer with
|
||||
GPTQ-format checkpoint parameters, runs ``process_weights_after_loading``
|
||||
(weight shuffle + zero-point synthesis), then ``apply_weights``, and compares
|
||||
the result against an fp32 reference dequant-and-matmul.
|
||||
|
||||
The kernel is exposed via ``torch.ops._rocm_C.gptq_gemm_rdna3`` and is only
|
||||
built for gfx11; tests are skipped elsewhere.
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_rdna3_w4a16.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_rocm():
|
||||
pytest.skip("RDNA3 W4A16 kernel is ROCm-only", allow_module_level=True)
|
||||
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import ( # noqa: E402
|
||||
MPLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.rdna3_w4a16 import ( # noqa: E402
|
||||
RDNA3W4A16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import ( # noqa: E402
|
||||
pack_quantized_values_into_int32,
|
||||
)
|
||||
from vllm.model_executor.parameter import ( # noqa: E402
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.platforms.rocm import on_gfx1100 # noqa: E402
|
||||
from vllm.scalar_type import scalar_types # noqa: E402
|
||||
from vllm.utils.torch_utils import set_random_seed # noqa: E402
|
||||
|
||||
device = "cuda"
|
||||
|
||||
WEIGHT_TYPE = scalar_types.uint4b8 # symmetric int4, bias = 8
|
||||
PACK_FACTOR = 8 # 8 x 4-bit nibbles per int32
|
||||
|
||||
# Skip everything in this module unless we are on the only architecture the
|
||||
# kernel is built/registered for.
|
||||
gfx1100_only = pytest.mark.skipif(
|
||||
not (
|
||||
on_gfx1100()
|
||||
and hasattr(torch.ops, "_rocm_C")
|
||||
and hasattr(torch.ops._rocm_C, "gptq_gemm_rdna3")
|
||||
),
|
||||
reason="requires gfx1100 with the _rocm_C.gptq_gemm_rdna3 op built in",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reference implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _reference(
|
||||
x_mk: torch.Tensor,
|
||||
q_int4_kn: torch.Tensor,
|
||||
scales_gn: torch.Tensor,
|
||||
zeros_gn: torch.Tensor | None,
|
||||
group_size: int,
|
||||
bias: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
"""fp32 reference for the RDNA3 W4A16 op.
|
||||
|
||||
x_mk: [M, K] fp16/bf16 activations.
|
||||
q_int4_kn: [K, N] int32 raw stored nibbles in [0, 15].
|
||||
scales_gn: [K//G, N] per-group scales (act dtype).
|
||||
zeros_gn: [K//G, N] int32 raw stored zero points in [0, 15], or None
|
||||
for the symmetric path (kernel synthesizes stored zero = 7).
|
||||
group_size: G.
|
||||
|
||||
The kernel applies the GPTQv1 "+1" zero-point quirk, so the effective
|
||||
zero is ``stored_zero + 1`` (symmetric path: 7 + 1 == bias == 8).
|
||||
"""
|
||||
K, N = q_int4_kn.shape
|
||||
s_full = scales_gn.repeat_interleave(group_size, dim=0).to(torch.float32) # [K,N]
|
||||
if zeros_gn is None:
|
||||
z_full = torch.full(
|
||||
(K, N), float(WEIGHT_TYPE.bias), device=x_mk.device, dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
z_full = (zeros_gn + 1).repeat_interleave(group_size, dim=0).to(torch.float32)
|
||||
w_fp = (q_int4_kn.to(torch.float32) - z_full) * s_full # [K, N]
|
||||
out = x_mk.to(torch.float32) @ w_fp # [M, N]
|
||||
if bias is not None:
|
||||
out = out + bias.to(torch.float32)
|
||||
return out.to(x_mk.dtype)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer construction (GPTQ checkpoint format)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_layer(
    q_int4_kn: torch.Tensor,
    scales_gn: torch.Tensor,
    zeros_gn: torch.Tensor | None,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """Return a bare module holding GPTQ-format parameters.

    Registers ``qweight`` (nibbles packed along K into int32, ``[K // 8, N]``)
    and ``scales`` (cast to ``dtype``). ``qzeros`` (zero points packed along N
    into int32, ``[K // G, N // 8]``) is registered only when ``zeros_gn`` is
    given.
    """

    def _skip_loading(*args, **kwargs):
        # The tensors are supplied directly, so the weight loader is a no-op.
        return None

    class DummyLayer(torch.nn.Module):
        pass

    packed_weight = pack_quantized_values_into_int32(
        q_int4_kn, WEIGHT_TYPE, packed_dim=0
    )

    layer = DummyLayer()
    layer.register_parameter(
        "qweight",
        PackedvLLMParameter(
            data=packed_weight,
            weight_loader=_skip_loading,
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=PACK_FACTOR,
        ),
    )
    layer.register_parameter(
        "scales",
        GroupQuantScaleParameter(
            data=scales_gn.to(dtype),
            weight_loader=_skip_loading,
            input_dim=0,
            output_dim=1,
        ),
    )

    if zeros_gn is None:
        return layer

    packed_zeros = pack_quantized_values_into_int32(
        zeros_gn, WEIGHT_TYPE, packed_dim=1
    )
    layer.register_parameter(
        "qzeros",
        PackedvLLMParameter(
            data=packed_zeros,
            weight_loader=_skip_loading,
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=PACK_FACTOR,
        ),
    )
    return layer
|
||||
|
||||
|
||||
def _run_kernel(
    x_mk: torch.Tensor,
    q_int4_kn: torch.Tensor,
    scales_gn: torch.Tensor,
    zeros_gn: torch.Tensor | None,
    group_size: int,
    bias: torch.Tensor | None,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Drive ``RDNA3W4A16LinearKernel`` on the given tensors.

    Steps: build the layer config, assert ``can_implement`` accepts it, build
    the dummy layer, run ``process_weights_after_loading`` and return
    ``apply_weights(layer, x_mk, bias=bias)``. Zero points are enabled exactly
    when ``zeros_gn`` is provided; act-order (``g_idx``) is never used.
    """
    K, N = q_int4_kn.shape
    weight_shape = (K, N)
    uses_zero_points = zeros_gn is not None

    config = MPLinearLayerConfig(
        full_weight_shape=weight_shape,
        partition_weight_shape=weight_shape,
        weight_type=WEIGHT_TYPE,
        act_type=dtype,
        group_size=group_size,
        zero_points=uses_zero_points,
        has_g_idx=False,
    )
    supported, why_not = RDNA3W4A16LinearKernel.can_implement(config)
    assert supported, f"can_implement rejected a supported config: {why_not}"

    layer = _build_layer(q_int4_kn, scales_gn, zeros_gn, dtype)
    zero_point_name = "qzeros" if uses_zero_points else None
    kernel = RDNA3W4A16LinearKernel(
        config,
        w_q_param_name="qweight",
        w_s_param_name="scales",
        w_zp_param_name=zero_point_name,
        w_gidx_param_name=None,
    )
    kernel.process_weights_after_loading(layer)
    return kernel.apply_weights(layer, x_mk, bias=bias)
|
||||
|
||||
|
||||
# Relative-L2 tolerance per dtype. The bf16 path widens dequantized weights
|
||||
# to fp32 and accumulates in fp32, so it matches the reference almost exactly
|
||||
# (<0.4% incl. the WMMA prefill path). The fp16 path uses the exllamav2
|
||||
# "+1024" bit-trick (see qdq_4_rdna3.cuh): the dequantized weight is recovered
|
||||
# as the fp16 difference of two ~1024*scale magnitudes, which sheds low-order
|
||||
# mantissa bits and leaves ~2-3% relative noise that accumulates over K. We
|
||||
# compare on the relative Frobenius norm rather than elementwise, since the
|
||||
# bit-trick noise produces large *relative* errors on individual near-zero
|
||||
# outputs that carry negligible absolute weight.
|
||||
_REL_L2_TOL = {torch.float16: 5e-2, torch.bfloat16: 1e-2}
|
||||
|
||||
|
||||
def _assert_close(out: torch.Tensor, ref: torch.Tensor, dtype: torch.dtype):
|
||||
rel_l2 = (out.to(torch.float32) - ref.to(torch.float32)).norm() / ref.to(
|
||||
torch.float32
|
||||
).norm()
|
||||
tol = _REL_L2_TOL[dtype]
|
||||
assert rel_l2 < tol, f"relative L2 error {rel_l2:.4f} exceeds {tol} for {dtype}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forward correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# (M, K, N, group_size). M spans the scalar decode path (small M) and the
|
||||
# WMMA prefill path (M >= 16 on the bf16 dispatch). K/N satisfy the kernel's
|
||||
# divisibility constraints (K % G == 0, K % 8 == 0, N % 8 == 0).
|
||||
MKNG_SHAPES = [
|
||||
(1, 128, 128, 128), # single group, decode
|
||||
(2, 256, 256, 128), # two groups
|
||||
(8, 256, 512, 64), # M=8 scalar, smaller group
|
||||
(16, 512, 256, 128), # M=16 -> WMMA path for bf16
|
||||
(32, 512, 512, 64), # larger prefill
|
||||
]
|
||||
|
||||
|
||||
@gfx1100_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_zp", [False, True], ids=["no_zp", "with_zp"])
@pytest.mark.parametrize(
    "M,K,N,G", MKNG_SHAPES, ids=[f"m{m}_k{k}_n{n}_g{g}" for m, k, n, g in MKNG_SHAPES]
)
def test_rdna3_w4a16_matches_reference(dtype, has_zp, M, K, N, G, dist_init):
    """Kernel output matches the fp32 reference for every shape/dtype combo.

    Inputs are drawn from a fixed seed: activations ~ 0.25 * N(0, 1), weights
    uniform over all 16 nibble values, scales in [0.01, 0.06), and (when
    ``has_zp``) stored zero points uniform over [0, 15].
    """
    set_random_seed(0)
    assert K % G == 0 and K % PACK_FACTOR == 0 and N % PACK_FACTOR == 0

    num_groups = K // G

    # Keep the draw order fixed; it determines the seeded values.
    activations = 0.25 * torch.randn((M, K), device=device, dtype=torch.float32)
    x_mk = activations.to(dtype)
    q_int4_kn = torch.randint(0, 16, (K, N), device=device, dtype=torch.int32)
    raw_scales = (
        0.05 * torch.rand((num_groups, N), device=device, dtype=torch.float32) + 0.01
    )
    scales_gn = raw_scales.to(dtype)
    zeros_gn = None
    if has_zp:
        zeros_gn = torch.randint(
            0, 16, (num_groups, N), device=device, dtype=torch.int32
        )

    out = _run_kernel(x_mk, q_int4_kn, scales_gn, zeros_gn, G, None, dtype)
    ref = _reference(x_mk, q_int4_kn, scales_gn, zeros_gn, G, None)

    assert out.shape == (M, N) and out.dtype == dtype
    _assert_close(out, ref, dtype)
|
||||
|
||||
|
||||
@gfx1100_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("M", [1, 32], ids=["decode", "prefill"])
def test_rdna3_w4a16_bias(dtype, M, dist_init):
    """Bias is applied for decode-sized (M=1) and prefill-sized (M=32) input.

    NOTE(review): assuming ``apply_weights`` goes through
    ``torch.ops._rocm_C.gptq_gemm_rdna3``, M=32 reaches the WMMA kernel only
    for bf16. The C++ dispatcher switches fp16 to WMMA at M >= 64, so the fp16
    "prefill" case still runs the scalar kernel.
    """
    set_random_seed(0)
    K, N, G = 512, 256, 128
    num_groups = K // G

    # Keep the draw order fixed; it determines the seeded values.
    activations = 0.25 * torch.randn((M, K), device=device, dtype=torch.float32)
    x_mk = activations.to(dtype)
    q_int4_kn = torch.randint(0, 16, (K, N), device=device, dtype=torch.int32)
    raw_scales = (
        0.05 * torch.rand((num_groups, N), device=device, dtype=torch.float32) + 0.01
    )
    scales_gn = raw_scales.to(dtype)
    bias = (0.1 * torch.randn(N, device=device, dtype=torch.float32)).to(dtype)

    # Symmetric path only (no explicit zero points).
    out = _run_kernel(x_mk, q_int4_kn, scales_gn, None, G, bias, dtype)
    ref = _reference(x_mk, q_int4_kn, scales_gn, None, G, bias)

    _assert_close(out, ref, dtype)
|
||||
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Kernel-selection / gating tests for the ROCm RDNA3 W4A16 GPTQ kernel.
|
||||
|
||||
Verifies that ``choose_mp_linear_kernel`` resolves a supported W4A16 GPTQ
|
||||
config to ``RDNA3W4A16LinearKernel`` on gfx1100 (it is registered ahead of
|
||||
``TritonW4A16LinearKernel`` in the ROCm priority list), and that
|
||||
``RDNA3W4A16LinearKernel.can_implement`` rejects the configs it does not
|
||||
support so selection falls through to the next kernel.
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_rdna3_w4a16_selection.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_rocm():
|
||||
pytest.skip("RDNA3 W4A16 kernel is ROCm-only", allow_module_level=True)
|
||||
|
||||
from vllm.model_executor.kernels.linear import ( # noqa: E402
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import ( # noqa: E402
|
||||
MPLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.rdna3_w4a16 import ( # noqa: E402
|
||||
RDNA3W4A16LinearKernel,
|
||||
)
|
||||
from vllm.platforms.rocm import on_gfx1100 # noqa: E402
|
||||
from vllm.scalar_type import scalar_types # noqa: E402
|
||||
|
||||
WEIGHT_TYPE = scalar_types.uint4b8 # symmetric int4, bias = 8
|
||||
|
||||
# The kernel is only selectable when running on gfx1100 with the custom op
|
||||
# compiled in; otherwise can_implement rejects and selection falls through.
|
||||
gfx1100_only = pytest.mark.skipif(
|
||||
not (
|
||||
on_gfx1100()
|
||||
and hasattr(torch.ops, "_rocm_C")
|
||||
and hasattr(torch.ops._rocm_C, "gptq_gemm_rdna3")
|
||||
),
|
||||
reason="requires gfx1100 with the _rocm_C.gptq_gemm_rdna3 op built in",
|
||||
)
|
||||
|
||||
|
||||
@gfx1100_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_selection_prefers_rdna3(dtype):
    """On gfx1100 a supported W4A16 GPTQ config selects the RDNA3 kernel."""
    weight_shape = (1024, 256)
    config = MPLinearLayerConfig(
        full_weight_shape=weight_shape,
        partition_weight_shape=weight_shape,
        weight_type=WEIGHT_TYPE,
        act_type=dtype,
        group_size=128,
        zero_points=False,
        has_g_idx=False,
    )
    chosen = choose_mp_linear_kernel(config)
    assert chosen.__name__ == "RDNA3W4A16LinearKernel"
|
||||
|
||||
|
||||
@gfx1100_only
@pytest.mark.parametrize(
    "weight_type,group_size,N,full_k,expected_ok",
    [
        (scalar_types.uint4b8, 128, 256, 1024, True),  # nominal: supported
        (scalar_types.uint4b8, -1, 256, 1024, False),  # channelwise unsupported
        (scalar_types.uint4b8, 128, 252, 1024, False),  # N not a multiple of 8
        (scalar_types.uint4b8, 96, 256, 1024, False),  # group does not divide K
        (scalar_types.uint8b128, 128, 256, 1024, False),  # wrong quant type
    ],
    ids=["ok", "channelwise", "bad_n", "group_ndiv_k", "wrong_qtype"],
)
def test_can_implement(weight_type, group_size, N, full_k, expected_ok):
    """``can_implement`` accepts the nominal config and rejects each bad one.

    Every case uses fp16 activations, no zero points and no act-order; only
    the quant type, group size and N vary (K is 1024 throughout).
    """
    weight_shape = (full_k, N)
    config = MPLinearLayerConfig(
        full_weight_shape=weight_shape,
        partition_weight_shape=weight_shape,
        weight_type=weight_type,
        act_type=torch.float16,
        group_size=group_size,
        zero_points=False,
        has_g_idx=False,
    )
    verdict, explanation = RDNA3W4A16LinearKernel.can_implement(config)
    assert verdict is expected_ok, explanation
|
||||
@@ -650,6 +650,51 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None
|
||||
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
|
||||
|
||||
|
||||
def gptq_gemm_rdna3(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_qzeros: torch.Tensor,
    b_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_v2_format: bool,
) -> torch.Tensor:
    """Invoke the ``_rocm_C.gptq_gemm_rdna3`` custom op.

    All six arguments are forwarded positionally, unchanged and in the order
    declared here, and the op's result is returned as-is.

    NOTE(review): nothing is validated on the Python side; presumably the op
    checks shapes and dtypes itself. Confirm against the C++ implementation.
    """
    rdna3_gemm = torch.ops._rocm_C.gptq_gemm_rdna3
    return rdna3_gemm(a, b_q_weight, b_qzeros, b_scales, b_g_idx, use_v2_format)
|
||||
|
||||
|
||||
# Register the fake implementation only when the extension exposes the op.
if hasattr(torch.ops, "_rocm_C") and hasattr(torch.ops._rocm_C, "gptq_gemm_rdna3"):

    @register_fake("_rocm_C::gptq_gemm_rdna3")
    def _gptq_gemm_rdna3_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_qzeros: torch.Tensor,
        b_scales: torch.Tensor,
        b_g_idx: torch.Tensor,
        use_v2_format: bool,
    ) -> torch.Tensor:
        """Fake implementation of ``_rocm_C::gptq_gemm_rdna3``.

        Performs no computation. Returns an uninitialized tensor with
        ``a.size(0)`` rows and ``b_q_weight.size(1)`` columns, using the
        dtype and device of ``a``. The remaining arguments are unused.
        """
        rows = a.size(0)
        cols = b_q_weight.size(1)
        return torch.empty((rows, cols), dtype=a.dtype, device=a.device)
|
||||
|
||||
|
||||
# Register the fake implementation only when the extension exposes the op.
if hasattr(torch.ops, "_rocm_C") and hasattr(torch.ops._rocm_C, "gptq_gemm_rdna3_wmma"):

    @register_fake("_rocm_C::gptq_gemm_rdna3_wmma")
    def _gptq_gemm_rdna3_wmma_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_qzeros: torch.Tensor,
        b_scales: torch.Tensor,
        b_g_idx: torch.Tensor,
        use_v2_format: bool,
    ) -> torch.Tensor:
        """Fake implementation of ``_rocm_C::gptq_gemm_rdna3_wmma``.

        Performs no computation. Returns an uninitialized tensor with
        ``a.size(0)`` rows and ``b_q_weight.size(1)`` columns, using the
        dtype and device of ``a``. The remaining arguments are unused.
        """
        rows = a.size(0)
        cols = b_q_weight.size(1)
        return torch.empty((rows, cols), dtype=a.dtype, device=a.device)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
|
||||
|
||||
@register_fake("_C::allspark_w8a16_gemm")
|
||||
|
||||
@@ -51,6 +51,9 @@ from vllm.model_executor.kernels.linear.mixed_precision.machete import (
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
|
||||
MarlinLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.rdna3_w4a16 import (
|
||||
RDNA3W4A16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.triton_w4a16 import (
|
||||
TritonW4A16LinearKernel,
|
||||
)
|
||||
@@ -339,6 +342,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
TritonW4A16LinearKernel,
|
||||
],
|
||||
PlatformEnum.ROCM: [
|
||||
RDNA3W4A16LinearKernel,
|
||||
TritonW4A16LinearKernel,
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
|
||||
@@ -29,6 +29,9 @@ from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
|
||||
MPLinearKernel,
|
||||
MPLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.rdna3_w4a16 import (
|
||||
RDNA3W4A16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.triton_w4a16 import (
|
||||
TritonW4A16LinearKernel,
|
||||
)
|
||||
@@ -48,6 +51,7 @@ __all__ = [
|
||||
"ExllamaLinearKernel",
|
||||
"MacheteLinearKernel",
|
||||
"MarlinLinearKernel",
|
||||
"RDNA3W4A16LinearKernel",
|
||||
"TritonW4A16LinearKernel",
|
||||
"XPUW4A8IntLinearKernel",
|
||||
"XPUwNa16LinearKernel",
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""W4A16 GPTQ kernel for AMD RDNA3 (gfx1100) — fp16 + bf16.
|
||||
|
||||
Drop-in replacement for ExllamaLinearKernel on RDNA3 that adds native bf16
|
||||
support. The HIP kernel lives in ``csrc/rocm/q_gemm_rdna3.cu``
|
||||
and is exposed via ``torch.ops._rocm_C.gptq_gemm_rdna3``.
|
||||
|
||||
Registered ahead of TritonW4A16LinearKernel for the ROCm-RDNA3 path; falls
|
||||
through to the Triton kernel on ROCm devices other than gfx1100 (e.g. CDNA/MI300).
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_quantized_values_into_int32,
|
||||
)
|
||||
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class RDNA3W4A16LinearKernel(MPLinearKernel):
    """W4A16 linear kernel that dispatches to ``ops.gptq_gemm_rdna3``.

    ``can_implement`` accepts a layer only on ROCm when ``on_gfx1100()`` is
    true and the ``_rocm_C.gptq_gemm_rdna3`` op is registered, with
    ``uint4b8`` weights, fp16/bf16 activations and grouped (non-channelwise)
    quantization.
    """

    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the capability floor (60) reported for this kernel."""
        # ROCm gates via on_gfx1100() in can_implement.
        return 60

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether this kernel can serve the layer described by ``c``.

        Returns ``(True, None)`` when every check passes, otherwise
        ``(False, reason)`` for the first failing check. Checks run in this
        order: platform, gfx1100, op presence, activation dtype, weight quant
        type, group size, output-feature alignment, act-order with a
        partitioned input dimension.
        """
        if not current_platform.is_rocm():
            return False, "RDNA3 W4A16 kernel is ROCm-only"

        # Imported only after the platform check above has passed.
        from vllm.platforms.rocm import on_gfx1100

        if not on_gfx1100():
            return False, "RDNA3 W4A16 kernel requires gfx1100"

        # The HIP op is registered by the C++ extension; if a user is running
        # against a vLLM build that doesn't include it (e.g. partial rebuild),
        # fall through gracefully to the next kernel in the registry.
        op_registered = hasattr(torch.ops, "_rocm_C") and hasattr(
            torch.ops._rocm_C, "gptq_gemm_rdna3"
        )
        if not op_registered:
            return (
                False,
                "torch.ops._rocm_C.gptq_gemm_rdna3 missing — rebuild C++ extension",
            )

        half_dtypes = (torch.float16, torch.bfloat16)
        if c.act_type not in half_dtypes:
            return False, "RDNA3 W4A16 kernel only supports fp16 and bf16"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                f"RDNA3 W4A16 kernel; supported: {cls.SUPPORTED_QUANT_TYPES}",
            )

        # Non-positive group sizes are rejected before the modulo below.
        if c.group_size <= 0:
            return (
                False,
                "RDNA3 W4A16 kernel does not support channelwise quantization",
            )

        if c.full_weight_shape[0] % c.group_size != 0:
            return (
                False,
                f"Group size ({c.group_size}) does not evenly divide K "
                f"({c.full_weight_shape[0]})",
            )

        # Output features must be a multiple of the pack factor (8 nibbles per
        # int32) and of 8 so that qzeros (packed 4-bit per col) align cleanly
        # against the BLOCK_KN_SIZE*4 = 512 N-stride and per-thread 4 columns.
        if c.partition_weight_shape[1] % 8 != 0:
            return (
                False,
                "Output features must be a multiple of 8 for the RDNA3 "
                "W4A16 kernel (qzeros packing)",
            )

        if c.has_g_idx:
            if c.partition_weight_shape[0] != c.full_weight_shape[0]:
                return (
                    False,
                    "Act-order with TP-partitioned input features is not "
                    "supported by the RDNA3 W4A16 kernel",
                )

        return True, None

    # ----- Weight prep (identical layout/shuffle as ExllamaLinearKernel) -----

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Rewrite ``layer``'s quantized parameters into the kernel's layout.

        Steps, in order:
          1. If the config has no zero points, attach a synthesized, packed
             ``qzeros`` parameter.
          2. With act-order, replace g_idx by its argsort; otherwise attach an
             empty int ``g_idx`` parameter.
          3. Permute and shuffle the packed weights via ``ops.gptq_shuffle``.
          4. Permute the scales and cast them to the activation dtype.
        """
        cfg = self.config
        device = getattr(layer, self.w_q_name).device

        # Synthesize zero points if the checkpoint doesn't carry them.
        if not cfg.zero_points:
            self.w_zp_name = "qzeros"
            zp_shape = (
                cfg.partition_weight_shape[0] // cfg.group_size,  # groups
                cfg.partition_weight_shape[1],  # output features
            )

            if not cfg.weight_type.has_bias():
                raise NotImplementedError(
                    "RDNA3 W4A16 kernel: zero-bias 4-bit quant requires "
                    "explicit zero points (GPTQv1 +1 quirk)."
                )

            # GPTQv1 quirk: the kernel adds 1 to the stored zero, so we
            # encode (bias - 1) here. See exllama.py for the link to the
            # documentation of this checkpoint-format wart.
            unpacked_zp = torch.full(
                zp_shape,
                cfg.weight_type.bias - 1,
                dtype=torch.int32,
                device=device,
            )
            packed_zp = pack_quantized_values_into_int32(
                unpacked_zp, cfg.weight_type, packed_dim=1
            )
            setattr(
                layer,
                self.w_zp_name,
                torch.nn.Parameter(packed_zp, requires_grad=False),
            )

        # Act-order: convert g_idx to the inverse permutation array exllama
        # expects (kernel reads a[perm[k]] instead of using groups indirected
        # by g_idx[k]).
        if cfg.has_g_idx:
            self._transform_param(
                layer,
                self.w_gidx_name,
                lambda g: torch.argsort(g).to(torch.int),
            )  # type: ignore
        else:
            # No act-order: attach a zero-length int tensor as g_idx.
            self.w_gidx_name = "g_idx"
            placeholder = torch.empty((0,), dtype=torch.int, device=device)
            setattr(
                layer,
                self.w_gidx_name,
                torch.nn.Parameter(placeholder, requires_grad=False),
            )

        def shuffle_weights(param):
            assert isinstance(param, BasevLLMParameter)
            assert self.w_gidx_name is not None
            perm = getattr(layer, self.w_gidx_name)

            permute_param_layout_(param, input_dim=0, output_dim=1, packed_dim=0)
            packed = param.data.contiguous()
            # Same 4-bit shuffle as exllama. The RDNA3 kernel reads weights in
            # the same shuffled int32 layout and uses the (qa & 0x000F000F)
            # bit-trick on top.
            ops.gptq_shuffle(packed, perm, cfg.weight_type.size_bits)
            return packed

        def cast_scales(param):
            assert isinstance(param, BasevLLMParameter)
            permute_param_layout_(param, input_dim=0, output_dim=1)
            param.data = param.data.contiguous()
            # Keep scales in the activation dtype (fp16 OR bf16) — the kernel
            # branches on dtype internally.
            return param.to(dtype=cfg.act_type)

        self._transform_param(layer, self.w_q_name, shuffle_weights)
        self._transform_param(layer, self.w_s_name, cast_scales)

    # ----- Forward --------------------------------------------------------

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized matmul for ``x`` and optionally add ``bias``.

        ``x`` is flattened to 2-D over its leading dimensions before calling
        ``ops.gptq_gemm_rdna3``. The result is reshaped so the leading
        dimensions of ``x`` are kept and the last dimension equals the
        partition's output-feature count. ``bias`` is added in place to the
        op's output when given.
        """
        n_out = self.config.partition_weight_shape[1]

        flat_x = x.reshape(-1, x.shape[-1])
        result_shape = (*x.shape[:-1], n_out)

        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

        assert w_zp is not None, "Zero points are required by RDNA3 W4A16"
        assert w_g_idx is not None, "g_idx tensor (possibly empty) required"

        # Trailing False is the op's use_v2_format flag.
        output = ops.gptq_gemm_rdna3(flat_x, w_q, w_zp, w_s, w_g_idx, False)

        if bias is not None:
            output.add_(bias)
        return output.reshape(result_shape)
|
||||
@@ -190,6 +190,9 @@ def _get_gcn_arch() -> str:
|
||||
# Value returned by _get_gcn_arch() (annotated as str); the flags below are
# substring tests against it.
_GCN_ARCH = _get_gcn_arch()

_ON_GFX1X = any(family in _GCN_ARCH for family in ("gfx11", "gfx12"))
_ON_GFX11 = "gfx11" in _GCN_ARCH
_ON_GFX1100 = "gfx1100" in _GCN_ARCH
_ON_GFX1151 = "gfx1151" in _GCN_ARCH
_ON_GFX12X = "gfx12" in _GCN_ARCH
_ON_MI3XX = any(family in _GCN_ARCH for family in ("gfx942", "gfx950"))
_ON_GFX9 = any(family in _GCN_ARCH for family in ("gfx90a", "gfx942", "gfx950"))
|
||||
@@ -273,6 +276,18 @@ def on_gfx1x() -> bool:
|
||||
return _ON_GFX1X
|
||||
|
||||
|
||||
def on_gfx11() -> bool:
    """Return the module-level ``_ON_GFX11`` flag ("gfx11" in ``_GCN_ARCH``)."""
    return _ON_GFX11
|
||||
|
||||
|
||||
def on_gfx1100() -> bool:
    """Return the module-level ``_ON_GFX1100`` flag ("gfx1100" in ``_GCN_ARCH``)."""
    return _ON_GFX1100
|
||||
|
||||
|
||||
def on_gfx1151() -> bool:
    """Return the module-level ``_ON_GFX1151`` flag ("gfx1151" in ``_GCN_ARCH``)."""
    return _ON_GFX1151
|
||||
|
||||
|
||||
def on_gfx12x() -> bool:
    """Return the module-level ``_ON_GFX12X`` flag ("gfx12" in ``_GCN_ARCH``)."""
    return _ON_GFX12X
|
||||
|
||||
|
||||
Reference in New Issue
Block a user