/*
 * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include <cassert>
#include <cfloat>
#include <cstdint>
#include <cstring>
#include <fmha/numeric_types.h> // Assumed location of the fmha types/helpers (fmha::bf16_t, fmha::e4m3_t, ...).

////////////////////////////////////////////////////////////////////////////////////////////////////

// The number of threads per warp.
enum
{
    THREADS_PER_WARP = 32
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Dst_type, typename Src_type>
struct Softmax_params
{
    // Output pointer.
    Dst_type* dst;
    // Source pointer.
    Src_type const* src;
    // Masks.
    int8_t const* mask;
    // Attention sinks (per head).
    float const* attention_sinks;
    // Softmax sum pointer.
    float* softmax_sum;
    // ALiBi.
    bool has_alibi;

    // Dimensions of the problem.
    size_t b, h;
    // Precomputed constants.
    size_t bhs, hs, bs;
    // The scaling factors to apply when we convert to/from float.
    float scale_bmm1, softcapping_scale_bmm1, scale_softmax;
    // The number of reduction warps used by the fused kernel.
    int warps_n;

    int* cu_q_seqlens;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float to_float(uint16_t const& src, float)
{
    return fmha::half_to_float(src);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Disable warning #177-D because this function has not been used elsewhere.
#pragma nv_diag_suppress 177

static inline __device__ float to_float(fmha::bf16_t const& src, float)
{
    return __bfloat162float(src);
}

#pragma nv_diag_default 177

////////////////////////////////////////////////////////////////////////////////////////////////////

// Disable warning #177-D because this function has not been used elsewhere.
#pragma nv_diag_suppress 177

static inline __device__ float to_float(fmha::e4m3_t const& src, float)
{
    return float(src);
}

#pragma nv_diag_default 177

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float to_float(float const& src, float)
{
    return src;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float to_float(int const& src, float scale)
{
    float dst;
    // Convert from int to float.
    dst = static_cast<float>(src);
    // Scale.
    dst *= scale;
    return dst;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void from_float(uint16_t& dst, float const& src, float)
{
    dst = fmha::float_to_half(src);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void from_float(fmha::bf16_t& dst, float const& src, float)
{
    dst = fmha::float_to_bf16(src);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ int8_t float_to_int8_rn(float x)
{
    uint32_t dst;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
    return reinterpret_cast<int8_t const&>(dst);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void from_float(int8_t& dst, float const& src, float scale)
{
    dst = float_to_int8_rn(src * scale);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void from_float(fmha::e4m3_t& dst, float const& src, float scale)
{
    dst = fmha::e4m3_t(src * scale);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float apply_exp_(float x, float max)
{
    return isinf(x) ? 0.f : __expf(x - max);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n,
    float& sum_fp32, float& max_fp32, float const attention_sink)
{
    // Apply the masks.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = mask[ii][0] ? data_fp32[ii][0] : -HUGE_VALF;
    }

    // Compute the max inside the thread.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][0]);
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        max_fp32 = fmaxf(max_fp32, __shfl_xor_sync(uint32_t(-1), max_fp32, xor_mask));
    }

    // Transform the elements.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = apply_exp_(data_fp32[ii][0], max_fp32);
    }

    // Compute the sum inside the thread.
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
#pragma unroll
    for (int ii = 0; ii < N; ii++)
    {
        sum_fp32 += data_fp32[ii][0]; // +0 +64 +128
    }
    // Emulate tmp[0] + tmp[1].
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 4);
    __syncwarp();
    // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 1);
    __syncwarp();
    // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 2);
    __syncwarp();
    // Emulate the final reduction.
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 8);
    __syncwarp();
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 16);
    __syncwarp();
#else
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        sum_fp32 += data_fp32[ii][0];
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, xor_mask);
    }
#endif

    // // DEBUG.
    // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) {
    //     printf("elt=%12.8f sum_fp32=%12.8f\n", data_fp32[0].x, sum_fp32);
    // }

    // Fix the sum if needed.
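    // A zero or NaN sum means the whole row was masked out; force 1.f so the normalization below
    // stays well defined (the fully-masked row then remains all zeros).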
    if (sum_fp32 == 0.f || sum_fp32 != sum_fp32)
    {
        sum_fp32 = 1.f;
    }

    // Normalize.
    float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] *= inv_sum_fp32;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n,
    float& sum_fp32, float& max_fp32, float const attention_sink)
{
    // Apply the masks.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = mask[ii][0] ? data_fp32[ii][0] : -HUGE_VALF;
        data_fp32[ii][1] = mask[ii][1] ? data_fp32[ii][1] : -HUGE_VALF;
    }

    // Compute the max inside the thread.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][0]);
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][1]);
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        max_fp32 = fmaxf(max_fp32, __shfl_xor_sync(uint32_t(-1), max_fp32, xor_mask));
    }

    // // DEBUG.
    // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) {
    //     printf("elt=%12.8f max_fp32=%12.8f\n", data_fp32[0][0], max_fp32);
    // }
    // // END OF DEBUG.

    // Transform the elements.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = apply_exp_(data_fp32[ii][0], max_fp32);
        data_fp32[ii][1] = apply_exp_(data_fp32[ii][1], max_fp32);
    }

    // Compute the sum inside the thread.
    // float sum_fp32 = 0.f;
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
    if (warps_n == 1)
    {
        // TODO: not sure if we can improve this on the GMMA side without using additional regs.
        // This is intentionally O(N) instead of O(log N).
        // Lanes 0 and 1 here represent the first quad.
        // Need to account for the offset of lane 0 when addressing absolute lanes.
        int const ti = threadIdx.x % 4;
        float tmp = 0.f;
        for (int ni = 0; ni < N; ni++)
        {
            float x = data_fp32[ni][0] + data_fp32[ni][1];
            tmp += x;
            for (int it = 1; it < 8; it++)
            {
                tmp += __shfl_sync(uint32_t(-1), x, 4 * it + ti);
                __syncwarp();
            }
        }
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
        tmp += __shfl_xor_sync(uint32_t(-1), tmp, 1);
        __syncwarp();
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
        tmp += __shfl_xor_sync(uint32_t(-1), tmp, 2);
        __syncwarp();
        sum_fp32 = __shfl_sync(uint32_t(-1), tmp, 0);
    }
    else if (warps_n == 8)
    {
        // Accumulate warp 0 and warp 4.
        float tmp[2] = {0.f, 0.f};
#pragma unroll
        for (int ii = 0; ii < N; ii += 2)
        {
            tmp[0] += data_fp32[ii + 0][0];
            tmp[0] += data_fp32[ii + 0][1];
            tmp[1] += data_fp32[ii + 1][0];
            tmp[1] += data_fp32[ii + 1][1];
        }
        // Emulate tmp[0] + tmp[1].
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 4);
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 1);
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
        // Emulate the final reduction.
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 8);
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16);
        sum_fp32 = tmp[0] + tmp[1];
        sum_fp32 = __shfl_sync(uint32_t(-1), sum_fp32, 0);
    }
    else
    {
#pragma unroll
        for (int ii = 0; ii < N; ii++)
        {
            sum_fp32 += data_fp32[ii][0] + data_fp32[ii][1];
        }
        // Emulate tmp[0] + tmp[1].
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 4);
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 1);
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 2);
        // Emulate the final reduction.
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 8);
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 16);
        sum_fp32 = __shfl_sync(uint32_t(-1), sum_fp32, 0);
    }
#else
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        sum_fp32 += data_fp32[ii][0];
        sum_fp32 += data_fp32[ii][1];
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, xor_mask);
    }
#endif

    // // DEBUG.
    // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) {
    //     printf("elt=%12.8f sum_fp32=%12.8f\n", data_fp32[0][0], sum_fp32);
    // }

    // Fix the sum if needed.
    if (sum_fp32 == 0.f || sum_fp32 != sum_fp32)
    {
        sum_fp32 = 1.f;
    }

    // Normalize.
    float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] *= inv_sum_fp32;
        data_fp32[ii][1] *= inv_sum_fp32;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n,
    float& sum_fp32, float& max_fp32, float const attention_sink)
{
    // Apply the masks.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = mask[ii][0] ? data_fp32[ii][0] : -HUGE_VALF;
        data_fp32[ii][1] = mask[ii][1] ? data_fp32[ii][1] : -HUGE_VALF;
        data_fp32[ii][2] = mask[ii][2] ? data_fp32[ii][2] : -HUGE_VALF;
        data_fp32[ii][3] = mask[ii][3] ? data_fp32[ii][3] : -HUGE_VALF;
    }

    // Compute the max inside the thread.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][0]);
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][1]);
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][2]);
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][3]);
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        max_fp32 = fmaxf(max_fp32, __shfl_xor_sync(uint32_t(-1), max_fp32, xor_mask));
    }

    // // DEBUG.
    // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) {
    //     printf("elt=%12.8f max_fp32=%12.8f\n", data_fp32[0][0], max_fp32);
    // }

    // Transform the elements.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = apply_exp_(data_fp32[ii][0], max_fp32);
        data_fp32[ii][1] = apply_exp_(data_fp32[ii][1], max_fp32);
        data_fp32[ii][2] = apply_exp_(data_fp32[ii][2], max_fp32);
        data_fp32[ii][3] = apply_exp_(data_fp32[ii][3], max_fp32);
    }

    // Compute the sum inside the thread.
    // float sum_fp32 = 0.f;

    // TODO: needs refactoring...
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
    // Within a thread it should correspond to the operation done in the tmp[0]/[1] loop.
    if (warps_n == 1)
    {
        // E.g. 4x1: 4 threads iterate over all cores.
        // TODO: not sure if we can improve this on the GMMA side without using additional regs.
        // This is intentionally O(N) instead of O(log N).
        // Lanes 0 and 1 here represent the first quad.
        // Need to account for the offset of lane 0 when addressing absolute lanes.
        int const ti = threadIdx.x % 2;
        float tmp[2] = {0.f, 0.f};
        for (int ni = 0; ni < N; ni++)
        {
            // +1
            float x = data_fp32[ni][0] + data_fp32[ni][1];
            float y = data_fp32[ni][2] + data_fp32[ni][3];
            tmp[0] += x;
            tmp[1] += y;
            for (int it = 1; it < 16; it++)
            {
                tmp[0] += __shfl_sync(uint32_t(-1), x, 2 * it + ti);
                __syncwarp();
                tmp[1] += __shfl_sync(uint32_t(-1), y, 2 * it + ti);
                __syncwarp();
            }
        }
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
        tmp[0] += tmp[1];
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
        __syncwarp();
        sum_fp32 = __shfl_sync(uint32_t(-1), tmp[0], 0);
    }
    else
    {
        // SEQLEN == 128.
        if (N == 1)
        {
            float tmp[2] = {0.f, 0.f};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 700 // GV100
            // The thread local reduction.
            tmp[0] += data_fp32[0][0];
            tmp[0] += data_fp32[0][1];
            tmp[0] += data_fp32[0][2];
            tmp[0] += data_fp32[0][3];
            // Add threads 0 and 2. Inside a thread in the impl.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
            __syncwarp();
            // Add threads 0 and 8. Inside the thread.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
            __syncwarp();
            // Add threads 0 and 16. Inside the thread.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
            __syncwarp();
            // Add threads 0 and 1.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
            __syncwarp();
            // Add threads 0 and 4. Inter-warp in the code.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
            __syncwarp();
#else
            if (warps_n == 2)
            {
                // 2x2
                tmp[0] += data_fp32[0][0] + data_fp32[0][1];
                tmp[1] += data_fp32[0][2] + data_fp32[0][3];
                // Emulate a_01 += a_23...
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
                __syncwarp();
                // Emulate a_01 += a_45...
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 8);
                __syncwarp();
                // Emulate a_01 += a_89...
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
                tmp[0] += tmp[1];
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
                __syncwarp();
                // Emulate the final reduction in smem.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
                __syncwarp();
            }
            else
            {
                // 1x4
                tmp[0] += data_fp32[0][0] + data_fp32[0][1];
                tmp[1] += data_fp32[0][2] + data_fp32[0][3];
                // Add +64.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16);
                __syncwarp();
                // T0: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
                __syncwarp();
                // T1: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
                tmp[0] += tmp[1];
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 4); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 8); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
                __syncwarp();
            }
#endif // ! GV100

            // Don't forget to put the value in sum_fp32 :)
            // sum_fp32 = tmp[0];
            sum_fp32 = __shfl_sync(uint32_t(-1), tmp[0], 0);

            // SEQLEN == 256 - compare with 1x4.
        }
        else if (N == 2 || N == 8)
        {
#pragma unroll
            for (int step = 0; step < N; step += 2)
            {
                float tmp[2] = {0.f, 0.f};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 700 // GV100
                // The thread local reduction.
                tmp[0] += data_fp32[step + 0][0];
                tmp[0] += data_fp32[step + 0][1];
                tmp[0] += data_fp32[step + 0][2];
                tmp[0] += data_fp32[step + 0][3];
                tmp[1] += data_fp32[step + 1][0];
                tmp[1] += data_fp32[step + 1][1];
                tmp[1] += data_fp32[step + 1][2];
                tmp[1] += data_fp32[step + 1][3];
                // Sum offset 0 and 128 (and so on).
                tmp[0] += tmp[1];
                // Add threads 0 and 2. Inside a thread in the impl.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
                __syncwarp();
                // Add threads 0 and 16. Inside the thread.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
                __syncwarp();
                // Add threads 0 and 1.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
                __syncwarp();
                // Add threads 0 and 4. Inter-warp in the code.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
                __syncwarp();
                // Add threads 0 and 8. Inter-warp in the code.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
                __syncwarp();
#else
                // 0.
                tmp[0] += data_fp32[step + 0][0] + data_fp32[step + 0][1];
                tmp[1] += data_fp32[step + 0][2] + data_fp32[step + 0][3];
                // Add +64.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16);
                __syncwarp();
                // Add +128 but use temp storage due to the next round of shfl.
                float xy = data_fp32[step + 1][0] + data_fp32[step + 1][1];
                float zw = data_fp32[step + 1][2] + data_fp32[step + 1][3];
                // Add +128.
                tmp[0] += xy;
                tmp[1] += zw;
                // Add +192.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), xy, 16);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), zw, 16);
                __syncwarp();
                // T0: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
                __syncwarp();
                // T1: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
                tmp[0] += tmp[1];
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 4); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 8); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
                __syncwarp();
#endif // ! GV100

                // Don't forget to put the value in sum_fp32 :)
                sum_fp32 += tmp[0];
            }

            // Emulate taking warp results from position 0, 16, 32, 48, etc.
            sum_fp32 = __shfl_sync(uint32_t(-1), sum_fp32, 0);

            // SEQLEN == 384.
        }
        else if (N == 3)
        {
            float tmp[2] = {0.f, 0.f};
            // The reduction inside the thread.
#pragma unroll
            for (int ii = 0; ii < N; ++ii)
            {
                tmp[0] += data_fp32[ii][0];
                tmp[0] += data_fp32[ii][1];
                tmp[1] += data_fp32[ii][2];
                tmp[1] += data_fp32[ii][3];
            }
            // Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
            __syncwarp();
            tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
            __syncwarp();
            // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
            tmp[0] += tmp[1];
            // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
            __syncwarp();
            // Emulate the final summation.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
            __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
            __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
            __syncwarp();
            // Don't forget to put the value in sum_fp32 :)
            sum_fp32 += tmp[0];

            // SEQLEN == 512 - compare with 1x8.
        }
        else if (N >= 4)
        {
            // Emulate thread local.
            float tmp[2] = {0.f, 0.f}; // T0, T1
#pragma unroll
            for (int step = 0; step < N; step++)
            {
                tmp[0] += data_fp32[step][0]; // + 0
                tmp[0] += data_fp32[step][1]; // + 1
                tmp[1] += data_fp32[step][2]; // + 2
                tmp[1] += data_fp32[step][3]; // + 3
            }
            // T0: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
            __syncwarp();
            // T1: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
            tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
            __syncwarp();
            // Emulate intra-thread.
            // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
            tmp[0] += tmp[1];
            // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
            __syncwarp();
            // Emulate inter-thread.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
            __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
            __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
            __syncwarp();
            // Don't forget to put the value in sum_fp32 :)
            // sum_fp32 = tmp[0];
            // Emulate taking warp results from position 0, 16, 32, 48, etc.
            sum_fp32 = __shfl_sync(uint32_t(-1), tmp[0], 0);

            // Not supported.
        }
        else
        {
            assert(false);
        }
    } // warps_n == 1
#else
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        sum_fp32 += data_fp32[ii][0];
        sum_fp32 += data_fp32[ii][1];
        sum_fp32 += data_fp32[ii][2];
        sum_fp32 += data_fp32[ii][3];
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, xor_mask);
    }
#endif

    // // DEBUG.
    // if( blockIdx.x == 0 && threadIdx.y == 0 && threadIdx.x == 0 ) {
    //     printf("elt=%12.8f sum_fp32=%12.8f\n", data_fp32[0][0], sum_fp32);
    // }
    // // END OF DEBUG.

    // Fix the sum if needed.
    if (sum_fp32 == 0.f || sum_fp32 != sum_fp32)
    {
        sum_fp32 = 1.f;
    }

    // Normalize.
    float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] *= inv_sum_fp32;
        data_fp32[ii][1] *= inv_sum_fp32;
        data_fp32[ii][2] *= inv_sum_fp32;
        data_fp32[ii][3] *= inv_sum_fp32;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Data_type, int X>
struct VecX
{
    using Type = typename fmha::Uint_from_size_in_bytes<X * sizeof(Data_type)>::Type;

    static_assert(sizeof(Type) == X * sizeof(Data_type));

    union Alias
    {
        Type raw;
        Data_type elt[X];
    };

    static __device__ inline void to_floatX(
        float (&dst)[X], Type const& src, float const scale, float const attn_logit_softcapping_scale)
    {
        Alias tmp;
        tmp.raw = src;
#pragma unroll
        for (int it = 0; it < X; it++)
        {
            dst[it] = to_float(tmp.elt[it], scale);
            if (attn_logit_softcapping_scale != 0.f)
            {
                dst[it] = attn_logit_softcapping_scale * fmha::__tanhf(dst[it] / attn_logit_softcapping_scale);
            }
        }
    }

    static __device__ inline void from_floatX(Type& dst, float const (&src)[X], float const scale)
    {
        Alias tmp;
#pragma unroll
        for (int it = 0; it < X; it++)
        {
            from_float(tmp.elt[it], src[it], scale);
        }
        dst = tmp.raw;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float get_alibi_head_scaling_factor(int const head_id, int const num_heads)
{
    // Round down to a power of 2.
    int const num_heads_pow2 = (1u << (31 - __clz(num_heads)));
    if (head_id < num_heads_pow2)
    {
        return exp2f((head_id + 1) * -8.0f / num_heads_pow2);
    }
    else
    {
        float const adjusted_head_id = 2 * (head_id - num_heads_pow2) + 1;
        return exp2f(adjusted_head_id * -4.0f / num_heads_pow2);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Dst_type, typename Src_type, int SEQLEN, int WARPS_PER_CTA, int X>
static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params)
{
    // By default, use LDG.64 for the loads and STG.64 for the stores.
    enum { ELEMENTS_PER_LDG = X, ELEMENTS_PER_STG = X };
    // The number of Vec_type per thread.
    enum { VECs_PER_THREAD = SEQLEN / THREADS_PER_WARP / ELEMENTS_PER_LDG };

    // DEBUG.
    static_assert(VECs_PER_THREAD * THREADS_PER_WARP * ELEMENTS_PER_LDG == SEQLEN, "");
    // END OF DEBUG.

    using VecO = VecX<Dst_type, ELEMENTS_PER_STG>;
    using VecI = VecX<Src_type, ELEMENTS_PER_LDG>;
    using VecM = VecX<int8_t, ELEMENTS_PER_LDG>;

    // The vector types.
    using DstX_type = typename VecO::Type;
    using SrcX_type = typename VecI::Type;

    // Make sure the sizes match our expectations.
    static_assert(sizeof(DstX_type) == X * sizeof(Dst_type));
    static_assert(sizeof(SrcX_type) == X * sizeof(Src_type));

    // The type of the mask.
    using MaskX_type = typename VecM::Type;

    // One warp per sequence.
    size_t hi = blockIdx.y * WARPS_PER_CTA + threadIdx.y;
    size_t bi = blockIdx.z;
    size_t si = blockIdx.x;

    // The data offset. Layout is S * B * H * S.
    size_t src_offset = si * params.bhs + bi * params.hs + hi * SEQLEN + threadIdx.x * ELEMENTS_PER_LDG;

    // Load the input elements.
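    // Each thread reads VECs_PER_THREAD vectors of X elements, strided by THREADS_PER_WARP vectors,
    // so one warp covers the full row of SEQLEN scores.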
    SrcX_type const* src_ptr = reinterpret_cast<SrcX_type const*>(&params.src[src_offset]);
    SrcX_type data_src[VECs_PER_THREAD];
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        if (hi < params.h)
        {
            data_src[ii] = src_ptr[ii * THREADS_PER_WARP];
        }
    }

    // The mask offset. Layout is S * B * S.
    size_t mask_offset = si * params.bs + bi * SEQLEN + threadIdx.x * ELEMENTS_PER_LDG;

    // Load the masks.
    MaskX_type const* mask_ptr = reinterpret_cast<MaskX_type const*>(&params.mask[mask_offset]);
    MaskX_type mask[VECs_PER_THREAD];
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        mask[ii] = mask_ptr[ii * THREADS_PER_WARP];
    }

    // Convert the data to float.
    float data_fp32[VECs_PER_THREAD][X];
    int8_t mask_[VECs_PER_THREAD][X];
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        VecI::to_floatX(data_fp32[ii], data_src[ii], params.scale_bmm1, params.softcapping_scale_bmm1);
        typename VecM::Alias tmp;
        tmp.raw = mask[ii];
#pragma unroll
        for (int it = 0; it < X; it++)
        {
            mask_[ii][it] = tmp.elt[it];
        }
    }

    if (params.has_alibi)
    {
        float const alibi_factor = get_alibi_head_scaling_factor(hi, params.h);
#pragma unroll
        for (int ii = 0; ii < VECs_PER_THREAD; ii++)
        {
#pragma unroll
            for (int jj = 0; jj < X; jj++)
            {
                int col = ii * THREADS_PER_WARP * X + threadIdx.x * X + jj;
                data_fp32[ii][jj] += alibi_factor * col;
            }
        }
    }

    // The attention sink value.
    float attention_sink = -FLT_MAX;
    if (params.attention_sinks != nullptr)
    {
        attention_sink = params.attention_sinks[hi];
    }

    // Do the reduction.
    float sum_fp32 = 0.f;
    float max_fp32 = -HUGE_VALF;
    reduce(data_fp32, mask_, params.warps_n, sum_fp32, max_fp32, attention_sink);

    if (threadIdx.x == 0)
    {
        int sum_s = params.cu_q_seqlens[bi];
        // [B, S, H, 2] {max, sum} float
        if (hi < params.h)
        {
            params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2] = max_fp32;
            params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2 + 1] = sum_fp32;
        }
    }

    // Reconvert to half.
    DstX_type data_dst[VECs_PER_THREAD];
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        VecO::from_floatX(data_dst[ii], data_fp32[ii], params.scale_softmax);
    }

    // Store the output elements.
    DstX_type* dst_ptr = reinterpret_cast<DstX_type*>(&params.dst[src_offset]);
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        if (hi < params.h)
        {
            dst_ptr[ii * THREADS_PER_WARP] = data_dst[ii];
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Dst_type, typename Src_type>
void run_softmax(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum,
    void* cu_q_seqlens, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax,
    float softcapping_scale_bmm1, int warps_n, bool has_alibi)
{
    Softmax_params<Dst_type, Src_type> params;
    memset(&params, 0, sizeof(params));

    // The different pointers.
    params.dst = reinterpret_cast<Dst_type*>(dst);
    params.src = reinterpret_cast<Src_type const*>(src);
    params.softmax_sum = reinterpret_cast<float*>(softmax_sum);
    params.cu_q_seqlens = reinterpret_cast<int*>(cu_q_seqlens);
    params.mask = reinterpret_cast<int8_t const*>(mask);
    params.attention_sinks = reinterpret_cast<float const*>(attention_sinks);
    params.has_alibi = has_alibi;

    // The dimensions and precomputed values.
    params.b = b;
    params.h = h;
    params.bhs = b * h * s_inner;
    params.hs = h * s_inner;
    params.bs = b * s_inner;

    // The scaling factors for the int8 version to convert to/from float.
    params.scale_bmm1 = scale_bmm1;
    params.softcapping_scale_bmm1 = softcapping_scale_bmm1;
    params.scale_softmax = scale_softmax;

    // The number of warps_n used to identify the reduction strategy.
    params.warps_n = warps_n;

    // Compute the grid size.
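    // One warp handles one row (one outer position, batch, head triple); a CTA packs WARPS_PER_CTA
    // heads, and the grid covers (s_outer, ceil(h / WARPS_PER_CTA), b).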
    enum { WARPS_PER_CTA = 4 };
    dim3 grid(s_outer, (h + WARPS_PER_CTA - 1) / WARPS_PER_CTA, b);
    dim3 threads_per_cta(THREADS_PER_WARP, WARPS_PER_CTA);

    // Launch the kernel. The per-thread vector width X is chosen so that
    // SEQLEN == VECs_PER_THREAD * THREADS_PER_WARP * X and a matching reduce() overload exists.
    if (s_inner == 32)
    {
        softmax_kernel<Dst_type, Src_type, 32, WARPS_PER_CTA, 1><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 64)
    {
        softmax_kernel<Dst_type, Src_type, 64, WARPS_PER_CTA, 2><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 96)
    {
        softmax_kernel<Dst_type, Src_type, 96, WARPS_PER_CTA, 1><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 128)
    {
        softmax_kernel<Dst_type, Src_type, 128, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 192)
    {
        softmax_kernel<Dst_type, Src_type, 192, WARPS_PER_CTA, 2><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 256)
    {
        softmax_kernel<Dst_type, Src_type, 256, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 384)
    {
        softmax_kernel<Dst_type, Src_type, 384, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 512)
    {
        softmax_kernel<Dst_type, Src_type, 512, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 1024)
    {
        softmax_kernel<Dst_type, Src_type, 1024, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 2048)
    {
        softmax_kernel<Dst_type, Src_type, 2048, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 4096)
    {
        softmax_kernel<Dst_type, Src_type, 4096, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 8192)
    {
        softmax_kernel<Dst_type, Src_type, 8192, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 16384)
    {
        softmax_kernel<Dst_type, Src_type, 16384, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 32768)
    {
        softmax_kernel<Dst_type, Src_type, 32768, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 65536)
    {
        softmax_kernel<Dst_type, Src_type, 65536, WARPS_PER_CTA, 4><<<grid, threads_per_cta>>>(params);
    }
    else
    {
        assert(false);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
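
// Illustrative sketch (an assumption, not part of the original API surface): one way a host-side
// test driver might instantiate run_softmax for an fp16 problem stored as uint16_t. The wrapper
// name and the choice of no sinks, no softcapping and no ALiBi are hypothetical; only run_softmax
// itself comes from this file.
static inline void run_softmax_fp16_example(void* dst, void const* src, void const* mask, void* softmax_sum,
    void* cu_q_seqlens, int s, int b, int h, float scale_bmm1, float scale_softmax, int warps_n)
{
    // dst/src hold fp16 values stored as uint16_t, the mask is int8_t, softmax_sum is float.
    run_softmax<uint16_t, uint16_t>(dst, src, mask, /*attention_sinks=*/nullptr, softmax_sum, cu_q_seqlens,
        /*s_inner=*/s, /*s_outer=*/s, b, h, scale_bmm1, scale_softmax, /*softcapping_scale_bmm1=*/0.f, warps_n,
        /*has_alibi=*/false);
}

////////////////////////////////////////////////////////////////////////////////////////////////////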