[ROCm] Add gfx1102/gfx1103 support (#40037)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
This commit is contained in:
Matthias Gehre
2026-04-23 10:32:04 +02:00
committed by GitHub
parent 4a79262e0f
commit 4b7869d6bc
3 changed files with 21 additions and 35 deletions
+1 -1
View File
@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13")
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201")
# ROCm installation prefix. Default to /opt/rocm but allow override via
# -DROCM_PATH=/your/rocm/path when invoking cmake.
+2 -11
View File
@@ -40,15 +40,6 @@ using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
#define __HIP__FP8MFMA__
#endif
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1150__) || defined(__gfx1151__))
#define __HIP__GFX11__
#endif
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
@@ -1629,7 +1620,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
}
}
#elif defined(__HIP__GFX11__)
#elif defined(__GFX11__)
using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
@@ -2388,7 +2379,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}
#elif defined(__HIP__GFX12__)
#elif defined(__GFX12__)
using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
+18 -23
View File
@@ -26,16 +26,11 @@
#define __HIP__GFX9__
#endif
#if defined(__HIPCC__) && \
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1150__) || \
defined(__gfx1151__) || defined(__gfx1200__) || defined(__gfx1201__))
// Combined RDNA macro (gfx11 + gfx12) - both use 32-wide wavefronts
#if defined(__GFX11__) || defined(__GFX12__)
#define __HIP__GFX1X__
#endif
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
#define __HIP__MI3XX__
#endif
@@ -1845,7 +1840,7 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
return out_c;
}
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
#if defined(__HIP__MI3XX__) || defined(__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1893,7 +1888,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B;
while (m < M) {
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
float sum[N][YTILE] = {};
#else
@@ -1931,7 +1926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
for (uint32_t n = 0; n < N; n++) {
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
for (int y = 0; y < YTILE; ++y) {
#pragma unroll
@@ -1955,7 +1950,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
// Final reduction
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
@@ -1993,7 +1988,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#endif
const bool writeback_lane =
#ifdef __HIP__GFX12__
#ifdef __GFX12__
threadIdx.x == (THRDS - 1);
#else
threadIdx.x == 0;
@@ -2009,7 +2004,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
#ifdef __HIP__GFX12__
#ifdef __GFX12__
float result = sum[n][y] * sA * sB;
#else
float result = sum[n][y][0] * sA * sB;
@@ -2027,7 +2022,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
#else // !defined(__HIP__MI3XX__) && !defined(__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
@@ -2039,9 +2034,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
#endif // defined(__HIP__MI3XX__) || defined(__GFX12__)
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
#if defined(__HIP__MI3XX__) || defined(__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -2088,7 +2083,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B;
while (m < M) {
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
float sum[N][YTILE] = {};
#else
@@ -2128,7 +2123,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
for (uint32_t n = 0; n < N; n++) {
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
for (int y = 0; y < YTILE; ++y) {
#pragma unroll
@@ -2152,7 +2147,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
// Final reduction
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
@@ -2190,7 +2185,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#endif
const bool writeback_lane =
#ifdef __HIP__GFX12__
#ifdef __GFX12__
threadIdx.x == (THRDS - 1);
#else
threadIdx.x == 0;
@@ -2206,7 +2201,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
#ifdef __HIP__GFX12__
#ifdef __GFX12__
float result = sum[n][y] * sA * sB;
#else
float result = sum[n][y][0] * sA * sB;
@@ -2224,7 +2219,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
#else // !defined(__HIP__MI3XX__) && !defined(__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
@@ -2236,7 +2231,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
#endif // defined(__HIP__MI3XX__) || defined(__GFX12__)
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,