From 4b7869d6bc64f5b124e2403891b4c2e29713bbf5 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 23 Apr 2026 10:32:04 +0200 Subject: [PATCH] [ROCm] Add gfx1102/gfx1103 support (#40037) Signed-off-by: Matthias Gehre --- CMakeLists.txt | 2 +- csrc/rocm/attention.cu | 13 ++----------- csrc/rocm/skinny_gemms.cu | 41 +++++++++++++++++---------------------- 3 files changed, 21 insertions(+), 35 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8f859c9cc40..e79c5b9f912 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201") # ROCm installation prefix. Default to /opt/rocm but allow override via # -DROCM_PATH=/your/rocm/path when invoking cmake. diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index a339c5641bb..9e6c0726d19 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -40,15 +40,6 @@ using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz; #define __HIP__FP8MFMA__ #endif -#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1150__) || defined(__gfx1151__)) - #define __HIP__GFX11__ -#endif - -#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__)) - #define __HIP__GFX12__ -#endif - #if defined(NDEBUG) #undef NDEBUG #include @@ -1629,7 +1620,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } -#elif defined(__HIP__GFX11__) +#elif defined(__GFX11__) using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; @@ -2388,7 +2379,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( out_ptr[threadIdx.x] = from_float(acc); } -#elif defined(__HIP__GFX12__) +#elif defined(__GFX12__) using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 60e10e53391..3342db37be9 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -26,16 +26,11 @@ #define __HIP__GFX9__ #endif -#if defined(__HIPCC__) && \ - (defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1150__) || \ - defined(__gfx1151__) || defined(__gfx1200__) || defined(__gfx1201__)) +// Combined RDNA macro (gfx11 + gfx12) - both use 32-wide wavefronts +#if defined(__GFX11__) || defined(__GFX12__) #define __HIP__GFX1X__ #endif -#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__)) - #define __HIP__GFX12__ -#endif - #if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__)) #define __HIP__MI3XX__ #endif @@ -1845,7 +1840,7 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, return out_c; } -#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) +#if defined(__HIP__MI3XX__) || defined(__GFX12__) template __global__ void __launch_bounds__(WvPrGrp* THRDS) @@ -1893,7 +1888,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) float sB = *s_B; while (m < M) { - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ // gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8 float sum[N][YTILE] = {}; #else @@ -1931,7 +1926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t n = 0; n < N; n++) { - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ // gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4) for (int y = 0; y < YTILE; ++y) { #pragma unroll @@ -1955,7 +1950,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } // Final reduction - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ // gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { @@ -1993,7 +1988,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #endif const bool writeback_lane = - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ threadIdx.x == (THRDS - 1); #else threadIdx.x == 0; @@ -2009,7 +2004,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { if (y + m >= M) break; // To avoid mem access fault. - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ float result = sum[n][y] * sA * sB; #else float result = sum[n][y][0] * sA * sB; @@ -2027,7 +2022,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__) +#else // !defined(__HIP__MI3XX__) && !defined(__GFX12__) template __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, @@ -2039,9 +2034,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) +#endif // defined(__HIP__MI3XX__) || defined(__GFX12__) -#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) +#if defined(__HIP__MI3XX__) || defined(__GFX12__) template __global__ void __launch_bounds__(WvPrGrp* THRDS) @@ -2088,7 +2083,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) float sB = *s_B; while (m < M) { - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ // gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8 float sum[N][YTILE] = {}; #else @@ -2128,7 +2123,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t n = 0; n < N; n++) { - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ // gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4) for (int y = 0; y < YTILE; ++y) { #pragma unroll @@ -2152,7 +2147,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } // Final reduction - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ // gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { @@ -2190,7 +2185,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #endif const bool writeback_lane = - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ threadIdx.x == (THRDS - 1); #else threadIdx.x == 0; @@ -2206,7 +2201,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { if (y + m >= M) break; // To avoid mem access fault. - #ifdef __HIP__GFX12__ + #ifdef __GFX12__ float result = sum[n][y] * sA * sB; #else float result = sum[n][y][0] * sA * sB; @@ -2224,7 +2219,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__) +#else // !defined(__HIP__MI3XX__) && !defined(__GFX12__) template __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, @@ -2236,7 +2231,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) +#endif // defined(__HIP__MI3XX__) || defined(__GFX12__) void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a, const std::optional& in_bias, at::Tensor& out_c,