mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm] Add gfx1102/gfx1103 support (#40037)
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
This commit is contained in:
+1
-1
@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13")
|
||||
|
||||
# Supported AMD GPU architectures.
|
||||
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201")
|
||||
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201")
|
||||
|
||||
# ROCm installation prefix. Default to /opt/rocm but allow override via
|
||||
# -DROCM_PATH=/your/rocm/path when invoking cmake.
|
||||
|
||||
+2
-11
@@ -40,15 +40,6 @@ using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
|
||||
#define __HIP__FP8MFMA__
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__) || \
|
||||
defined(__gfx1150__) || defined(__gfx1151__))
|
||||
#define __HIP__GFX11__
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
|
||||
#define __HIP__GFX12__
|
||||
#endif
|
||||
|
||||
#if defined(NDEBUG)
|
||||
#undef NDEBUG
|
||||
#include <assert.h>
|
||||
@@ -1629,7 +1620,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
#elif defined(__HIP__GFX11__)
|
||||
#elif defined(__GFX11__)
|
||||
|
||||
using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
|
||||
|
||||
@@ -2388,7 +2379,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
|
||||
}
|
||||
|
||||
#elif defined(__HIP__GFX12__)
|
||||
#elif defined(__GFX12__)
|
||||
|
||||
using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float;
|
||||
|
||||
|
||||
+18
-23
@@ -26,16 +26,11 @@
|
||||
#define __HIP__GFX9__
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC__) && \
|
||||
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1150__) || \
|
||||
defined(__gfx1151__) || defined(__gfx1200__) || defined(__gfx1201__))
|
||||
// Combined RDNA macro (gfx11 + gfx12) - both use 32-wide wavefronts
|
||||
#if defined(__GFX11__) || defined(__GFX12__)
|
||||
#define __HIP__GFX1X__
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
|
||||
#define __HIP__GFX12__
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
|
||||
#define __HIP__MI3XX__
|
||||
#endif
|
||||
@@ -1845,7 +1840,7 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
return out_c;
|
||||
}
|
||||
|
||||
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
|
||||
#if defined(__HIP__MI3XX__) || defined(__GFX12__)
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
@@ -1893,7 +1888,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
float sB = *s_B;
|
||||
|
||||
while (m < M) {
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
|
||||
float sum[N][YTILE] = {};
|
||||
#else
|
||||
@@ -1931,7 +1926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
for (uint32_t n = 0; n < N; n++) {
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
#pragma unroll
|
||||
@@ -1955,7 +1950,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
|
||||
// Final reduction
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
@@ -1993,7 +1988,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#endif
|
||||
|
||||
const bool writeback_lane =
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
threadIdx.x == (THRDS - 1);
|
||||
#else
|
||||
threadIdx.x == 0;
|
||||
@@ -2009,7 +2004,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
float result = sum[n][y] * sA * sB;
|
||||
#else
|
||||
float result = sum[n][y][0] * sA * sB;
|
||||
@@ -2027,7 +2022,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
m += CuCount * _WvPrGrp * YTILE;
|
||||
}
|
||||
}
|
||||
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
|
||||
#else // !defined(__HIP__MI3XX__) && !defined(__GFX12__)
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
|
||||
@@ -2039,9 +2034,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
|
||||
#endif // defined(__HIP__MI3XX__) || defined(__GFX12__)
|
||||
|
||||
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
|
||||
#if defined(__HIP__MI3XX__) || defined(__GFX12__)
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
@@ -2088,7 +2083,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
float sB = *s_B;
|
||||
|
||||
while (m < M) {
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
|
||||
float sum[N][YTILE] = {};
|
||||
#else
|
||||
@@ -2128,7 +2123,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
for (uint32_t n = 0; n < N; n++) {
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
#pragma unroll
|
||||
@@ -2152,7 +2147,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
|
||||
// Final reduction
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
@@ -2190,7 +2185,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#endif
|
||||
|
||||
const bool writeback_lane =
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
threadIdx.x == (THRDS - 1);
|
||||
#else
|
||||
threadIdx.x == 0;
|
||||
@@ -2206,7 +2201,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
#ifdef __HIP__GFX12__
|
||||
#ifdef __GFX12__
|
||||
float result = sum[n][y] * sA * sB;
|
||||
#else
|
||||
float result = sum[n][y][0] * sA * sB;
|
||||
@@ -2224,7 +2219,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
m += CuCount * _WvPrGrp * YTILE;
|
||||
}
|
||||
}
|
||||
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
|
||||
#else // !defined(__HIP__MI3XX__) && !defined(__GFX12__)
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
|
||||
@@ -2236,7 +2231,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
|
||||
const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
|
||||
#endif // defined(__HIP__MI3XX__) || defined(__GFX12__)
|
||||
|
||||
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
|
||||
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
|
||||
|
||||
Reference in New Issue
Block a user