/*
 * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include <cfloat>
#include <cstdio>
#include <fmha/numeric_types.h>
#include <fmha/utils.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

// The number of threads per warp.
enum
{
    THREADS_PER_WARP = 32
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Dst_type, typename Src_type>
struct Softmax_params
{
    // Output pointer.
    Dst_type* dst;
    // Source pointer.
    Src_type const* src;
    // Masks.
    int8_t const* mask;
    // Attention sinks (per head).
    float const* attention_sinks;
    // Softmax sum pointer.
    float* softmax_sum;
    // ALiBi
    bool has_alibi;
    // Dimensions of the problem.
    size_t b, h;
    // Precomputed constants.
    size_t bhs, hs, bs;
    // The scaling factors to apply when we convert to/from float.
    float scale_bmm1, softcapping_scale_bmm1, scale_softmax;
    // The number of reduction warps used by the fused kernel.
    int warps_n;
    // The cumulative query sequence lengths.
    int* cu_q_seqlens;
};
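
// Note on layouts, as consumed by softmax_kernel below: src/dst are indexed as
// [S_outer, B, H, S_inner] (bhs = B * H * S_inner, hs = H * S_inner), the mask as
// [S_outer, B, S_inner] (bs = B * S_inner), and softmax_sum stores a {max, sum} float pair
// per (token, head), addressed through cu_q_seqlens.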

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float to_float(uint16_t const& src, float)
{
    return fmha::half_to_float(src);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Disable warning #177-D (function was declared but never referenced), as this overload may be unused.
#pragma nv_diag_suppress 177

static inline __device__ float to_float(fmha::bf16_t const& src, float)
{
    return __bfloat162float(src);
}

#pragma nv_diag_default 177

////////////////////////////////////////////////////////////////////////////////////////////////////

// Disable warning #177-D (function was declared but never referenced), as this overload may be unused.
#pragma nv_diag_suppress 177

static inline __device__ float to_float(fmha::e4m3_t const& src, float)
{
    return float(src);
}

#pragma nv_diag_default 177

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float to_float(float const& src, float)
{
    return src;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float to_float(int const& src, float scale)
{
    float dst;

    // Convert from int to float.
    dst = static_cast<float>(src);

    // Scale.
    dst *= scale;

    return dst;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void from_float(uint16_t& dst, float const& src, float)
{
    dst = fmha::float_to_half(src);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void from_float(fmha::bf16_t& dst, float const& src, float)
{
    dst = fmha::float_to_bf16(src);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ int8_t float_to_int8_rn(float x)
{
    uint32_t dst;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
    return reinterpret_cast<int8_t const&>(dst);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void from_float(int8_t& dst, float const& src, float scale)
{
    dst = float_to_int8_rn(src * scale);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void from_float(fmha::e4m3_t& dst, float const& src, float scale)
{
    dst = fmha::e4m3_t(src * scale);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float apply_exp_(float x, float max)
{
    return isinf(x) ? 0.f : __expf(x - max);
}
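
// Note: masked positions are set to -HUGE_VALF before this is called, and max_fp32 is also -inf
// when an entire row is masked. The isinf() guard makes both cases produce 0.f instead of
// evaluating __expf(-inf - (-inf)) == NaN, so fully-masked rows stay finite.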

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n, float& sum_fp32,
    float& max_fp32, float const attention_sink)
{

    // Apply the masks.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = mask[ii][0] ? data_fp32[ii][0] : -HUGE_VALF;
    }

    // Compute the max inside the thread.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][0]);
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        max_fp32 = fmaxf(max_fp32, __shfl_xor_sync(uint32_t(-1), max_fp32, xor_mask));
    }

    // Transform the elements.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = apply_exp_(data_fp32[ii][0], max_fp32);
    }

    // Compute the sum inside the thread.
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)

#pragma unroll
    for (int ii = 0; ii < N; ii++)
    {
        sum_fp32 += data_fp32[ii][0]; //+0 +64 +128
    }

    // Emulate tmp[0] + tmp[1]
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 4);
    __syncwarp();

    // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 1);
    __syncwarp();
    // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 2);
    __syncwarp();

    // Emulate final reduction
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 8);
    __syncwarp();
    sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 16);
    __syncwarp();

#else
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        sum_fp32 += data_fp32[ii][0];
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, xor_mask);
    }
#endif

    // // DEBUG.
    // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) {
    //     printf("elt=%12.8f sum_fp32=%12.8f\n", data_fp32[0].x, sum_fp32);
    // }

    // Fix the sum if it is zero or NaN (e.g. a fully-masked row).
    if (sum_fp32 == 0.f || sum_fp32 != sum_fp32)
    {
        sum_fp32 = 1.f;
    }

    // Normalize.
    float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] *= inv_sum_fp32;
    }
}
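
// All reduce() overloads implement the same masked softmax over one row owned by a single warp:
// (1) masked logits are forced to -inf, (2) the row max m is found with a butterfly
// __shfl_xor_sync reduction, (3) each element becomes exp(x - m), (4) the elements are summed
// across the warp, and (5) the row is scaled by 1 / (sum_j exp(x_j - m) + exp(sink - m)), i.e.
// the normalizer also includes the per-head attention sink. When no sink is provided
// (attention_sink == -FLT_MAX), the extra term vanishes for any row with a finite max.
// The overloads only differ in how many elements each thread holds and in the summation order.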

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n, float& sum_fp32,
    float& max_fp32, float const attention_sink)
{
    // Apply the masks.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = mask[ii][0] ? data_fp32[ii][0] : -HUGE_VALF;
        data_fp32[ii][1] = mask[ii][1] ? data_fp32[ii][1] : -HUGE_VALF;
    }

    // Compute the max inside the thread.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][0]);
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][1]);
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        max_fp32 = fmaxf(max_fp32, __shfl_xor_sync(uint32_t(-1), max_fp32, xor_mask));
    }

    // // DEBUG.
    // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) {
    //     printf("elt=%12.8f max_fp32=%12.8f\n", data_fp32[0][0], max_fp32);
    // }
    // // END OF DEBUG.

    // Transform the elements.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = apply_exp_(data_fp32[ii][0], max_fp32);
        data_fp32[ii][1] = apply_exp_(data_fp32[ii][1], max_fp32);
    }

    // Compute the sum inside the thread.
    // float sum_fp32 = 0.f;
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
    if (warps_n == 1)
    {
        // TODO not sure if we can improve this on the gmma side without using additional regs.

        // this is intentionally O(n) instead of O(log n)
        // lanes 0 and 1 here represent the first quad.

        // need to account for offset of l0 when addressing absolute lanes.
        int const ti = threadIdx.x % 4;
        float tmp = 0.f;

        for (int ni = 0; ni < N; ni++)
        {
            float x = data_fp32[ni][0] + data_fp32[ni][1];
            tmp += x;

            for (int it = 1; it < 8; it++)
            {
                tmp += __shfl_sync(uint32_t(-1), x, 4 * it + ti);
                __syncwarp();
            }
        }

        // emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
        tmp += __shfl_xor_sync(uint32_t(-1), tmp, 1);
        __syncwarp();
        // emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
        tmp += __shfl_xor_sync(uint32_t(-1), tmp, 2);
        __syncwarp();
        sum_fp32 = __shfl_sync(uint32_t(-1), tmp, 0);
    }
    else if (warps_n == 8)
    {
        // Accumulate warp 0 and warp 4
        float tmp[2] = {0.f, 0.f};
#pragma unroll
        for (int ii = 0; ii < N; ii += 2)
        {
            tmp[0] += data_fp32[ii + 0][0];
            tmp[0] += data_fp32[ii + 0][1];
            tmp[1] += data_fp32[ii + 1][0];
            tmp[1] += data_fp32[ii + 1][1];
        }

        // Emulate tmp[0] + tmp[1]
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 4);

        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 1);
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
        // Emulate final reduction
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 8);

        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
        tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16);
        sum_fp32 = tmp[0] + tmp[1];

        sum_fp32 = __shfl_sync(uint32_t(-1), sum_fp32, 0);
    }
    else
    {

#pragma unroll
        for (int ii = 0; ii < N; ii++)
        {
            sum_fp32 += data_fp32[ii][0] + data_fp32[ii][1];
        }

        // Emulate tmp[0] + tmp[1]
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 4);
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 1);
        // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 2);
        // Emulate final reduction
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 8);
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, 16);

        sum_fp32 = __shfl_sync(uint32_t(-1), sum_fp32, 0);
    }

#else
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        sum_fp32 += data_fp32[ii][0];
        sum_fp32 += data_fp32[ii][1];
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, xor_mask);
    }
#endif

    // // DEBUG.
    // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) {
    //     printf("elt=%12.8f sum_fp32=%12.8f\n", data_fp32[0][0], sum_fp32);
    // }

    // Fix the sum if it is zero or NaN (e.g. a fully-masked row).
    if (sum_fp32 == 0.f || sum_fp32 != sum_fp32)
    {
        sum_fp32 = 1.f;
    }

    // Normalize.
    float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] *= inv_sum_fp32;
        data_fp32[ii][1] *= inv_sum_fp32;
    }
}
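
// Under USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE, the warps_n-specific branches above appear to
// replay the shuffle order of the corresponding fused-kernel reductions, so that this reference
// softmax accumulates (and therefore rounds) in the same order as the kernel it is compared against.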

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n, float& sum_fp32,
    float& max_fp32, float const attention_sink)
{

    // Apply the masks.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = mask[ii][0] ? data_fp32[ii][0] : -HUGE_VALF;
        data_fp32[ii][1] = mask[ii][1] ? data_fp32[ii][1] : -HUGE_VALF;
        data_fp32[ii][2] = mask[ii][2] ? data_fp32[ii][2] : -HUGE_VALF;
        data_fp32[ii][3] = mask[ii][3] ? data_fp32[ii][3] : -HUGE_VALF;
    }

    // Compute the max inside the thread.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][0]);
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][1]);
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][2]);
        max_fp32 = fmaxf(max_fp32, data_fp32[ii][3]);
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        max_fp32 = fmaxf(max_fp32, __shfl_xor_sync(uint32_t(-1), max_fp32, xor_mask));
    }

    // // DEBUG.
    // if( blockIdx.z == 1 && threadIdx.y == 0 && threadIdx.x == 5 ) {
    //     printf("elt=%12.8f max_fp32=%12.8f\n", data_fp32[0][0], max_fp32);
    // }

    // Transform the elements.
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] = apply_exp_(data_fp32[ii][0], max_fp32);
        data_fp32[ii][1] = apply_exp_(data_fp32[ii][1], max_fp32);
        data_fp32[ii][2] = apply_exp_(data_fp32[ii][2], max_fp32);
        data_fp32[ii][3] = apply_exp_(data_fp32[ii][3], max_fp32);
    }

    // Compute the sum inside the thread.
    // float sum_fp32 = 0.f;

    // TODO needs refactoring...

#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
    // Within a thread it should correspond to the operation done in the tmp[0]/[1] loop.

    if (warps_n == 1)
    { // E.g. 4x1: 4 threads iterate over all cores.
        // TODO not sure if we can improve this on the gmma side without using additional regs.

        // this is intentionally O(n) instead of O(log n)
        // lanes 0 and 1 here represent the first quad.

        // need to account for offset of l0 when addressing absolute lanes.
        int const ti = threadIdx.x % 2;
        float tmp[2] = {0.f, 0.f};

        for (int ni = 0; ni < N; ni++)
        {
            // +1
            float x = data_fp32[ni][0] + data_fp32[ni][1];
            float y = data_fp32[ni][2] + data_fp32[ni][3];
            tmp[0] += x;
            tmp[1] += y;

            for (int it = 1; it < 16; it++)
            {
                tmp[0] += __shfl_sync(uint32_t(-1), x, 2 * it + ti);
                __syncwarp();
                tmp[1] += __shfl_sync(uint32_t(-1), y, 2 * it + ti);
                __syncwarp();
            }
        }

        // emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
        tmp[0] += tmp[1];
        // emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
        tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
        __syncwarp();
        sum_fp32 = __shfl_sync(uint32_t(-1), tmp[0], 0);
    }
    else
    {

        // SEQLEN == 128.
        if (N == 1)
        {

            float tmp[2] = {0.f, 0.f};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 700 // GV100
            // The thread local reduction.
            tmp[0] += data_fp32[0][0];
            tmp[0] += data_fp32[0][1];
            tmp[0] += data_fp32[0][2];
            tmp[0] += data_fp32[0][3];

            // Add threads 0 and 2. Inside a thread in the impl.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
            __syncwarp();
            // Add threads 0 and 8. Inside the thread.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
            __syncwarp();
            // Add threads 0 and 16. Inside the thread.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
            __syncwarp();

            // Add threads 0 and 1.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
            __syncwarp();

            // Add threads 0 and 4. Inter-warp in the code.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
            __syncwarp();
#else
            if (warps_n == 2)
            { // 2x2
                tmp[0] += data_fp32[0][0] + data_fp32[0][1];
                tmp[1] += data_fp32[0][2] + data_fp32[0][3];

                // Emulate a_01 += a_23...
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
                __syncwarp();

                // Emulate a_01 += a_45...
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 8);
                __syncwarp();

                // Emulate a_01 += a_89...
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16);
                __syncwarp();

                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
                tmp[0] += tmp[1];

                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
                __syncwarp();

                // Emulate the final reduction in smem.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
                __syncwarp();
            }
            else
            { // 1x4
                tmp[0] += data_fp32[0][0] + data_fp32[0][1];
                tmp[1] += data_fp32[0][2] + data_fp32[0][3];

                // Add +64.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16);
                __syncwarp();

                // T0: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
                __syncwarp();
                // T1: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
                __syncwarp();

                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
                tmp[0] += tmp[1];

                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 4); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 8); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
                __syncwarp();
            }

#endif // ! GV100

            // Don't forget to put the value in sum_fp32 :)
            // sum_fp32 = tmp[0];
            sum_fp32 = __shfl_sync(uint32_t(-1), tmp[0], 0);

            // SEQLEN == 256 - compare with 1x4.
        }
        else if (N == 2 || N == 8)
        {

#pragma unroll
            for (int step = 0; step < N; step += 2)
            {

                float tmp[2] = {0.f, 0.f};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 700 // GV100

                // The thread local reduction.
                tmp[0] += data_fp32[step + 0][0];
                tmp[0] += data_fp32[step + 0][1];
                tmp[0] += data_fp32[step + 0][2];
                tmp[0] += data_fp32[step + 0][3];

                tmp[1] += data_fp32[step + 1][0];
                tmp[1] += data_fp32[step + 1][1];
                tmp[1] += data_fp32[step + 1][2];
                tmp[1] += data_fp32[step + 1][3];

                // Sum offset 0 and 128 (and so on).
                tmp[0] += tmp[1];

                // Add threads 0 and 2. Inside a thread in the impl.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
                __syncwarp();
                // Add threads 0 and 16. Inside the thread.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
                __syncwarp();

                // Add threads 0 and 1.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
                __syncwarp();

                // Add threads 0 and 4. Inter-warp in the code.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
                __syncwarp();
                // Add threads 0 and 8. Inter-warp in the code.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
                __syncwarp();
#else
                // 0.
                tmp[0] += data_fp32[step + 0][0] + data_fp32[step + 0][1];
                tmp[1] += data_fp32[step + 0][2] + data_fp32[step + 0][3];

                // Add +64.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 16);
                __syncwarp();

                // Add +128 but use temp storage due to the next round of shfl.
                float xy = data_fp32[step + 1][0] + data_fp32[step + 1][1];
                float zw = data_fp32[step + 1][2] + data_fp32[step + 1][3];

                // Add +128.
                tmp[0] += xy;
                tmp[1] += zw;

                // Add +192.
                tmp[0] += __shfl_xor_sync(uint32_t(-1), xy, 16);
                __syncwarp();
                tmp[1] += __shfl_xor_sync(uint32_t(-1), zw, 16);
                __syncwarp();

                // T0: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
                __syncwarp();
                // T1: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
                tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
                __syncwarp();

                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
                tmp[0] += tmp[1];

                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 4); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
                __syncwarp();
                // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 8); __syncwarp();
                tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
                __syncwarp();
#endif // ! GV100

                // Don't forget to put the value in sum_fp32 :)
                sum_fp32 += tmp[0];
            }
            // Emulate taking warp results from position 0, 16, 32, 48, etc.
            sum_fp32 = __shfl_sync(uint32_t(-1), sum_fp32, 0);

            // SEQLEN == 384.
        }
        else if (N == 3)
        {

            float tmp[2] = {0.f, 0.f};

            // The reduction inside the thread.
#pragma unroll
            for (int ii = 0; ii < N; ++ii)
            {
                tmp[0] += data_fp32[ii][0];
                tmp[0] += data_fp32[ii][1];
                tmp[1] += data_fp32[ii][2];
                tmp[1] += data_fp32[ii][3];
            }

            // Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
            __syncwarp();
            tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
            __syncwarp();

            // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
            tmp[0] += tmp[1];

            // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
            __syncwarp();

            // Emulate the final summation.
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
            __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
            __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
            __syncwarp();

            // Don't forget to put the value in sum_fp32 :)
            sum_fp32 += tmp[0];
            // SEQLEN == 512 - compare with 1x8.
        }
        else if (N >= 4)
        {
            // Emulate thread local
            float tmp[2] = {0.f, 0.f}; // T0, T1
#pragma unroll
            for (int step = 0; step < N; step++)
            {
                tmp[0] += data_fp32[step][0]; // + 0
                tmp[0] += data_fp32[step][1]; // + 1
                tmp[1] += data_fp32[step][2]; // + 2
                tmp[1] += data_fp32[step][3]; // + 3
            }

            // T0: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 2);
            __syncwarp();
            // T1: Emulate dst[mi] = tmp[mi][0] + tmp[mi][1];
            tmp[1] += __shfl_xor_sync(uint32_t(-1), tmp[1], 2);
            __syncwarp();

            // Emulate intra-thread
            // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 1); __syncwarp();
            tmp[0] += tmp[1];
            // Emulate dst[mi] += __shfl_xor_sync(uint32_t(-1), dst[mi], 2); __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 1);
            __syncwarp();

            // Emulate inter-thread
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 4);
            __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 8);
            __syncwarp();
            tmp[0] += __shfl_xor_sync(uint32_t(-1), tmp[0], 16);
            __syncwarp();

            // Don't forget to put the value in sum_fp32 :)
            // sum_fp32 = tmp[0];

            // Emulate taking warp results from position 0, 16, 32, 48, etc.
            sum_fp32 = __shfl_sync(uint32_t(-1), tmp[0], 0);
        }
        // Not supported.
        else
        {
            assert(false);
        }
    } // warps_n != 1
#else
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        sum_fp32 += data_fp32[ii][0];
        sum_fp32 += data_fp32[ii][1];
        sum_fp32 += data_fp32[ii][2];
        sum_fp32 += data_fp32[ii][3];
    }

    // Compute inside the warp.
#pragma unroll
    for (int xor_mask = THREADS_PER_WARP / 2; xor_mask > 0; xor_mask /= 2)
    {
        sum_fp32 += __shfl_xor_sync(uint32_t(-1), sum_fp32, xor_mask);
    }
#endif

    // // DEBUG.
    // if( blockIdx.x == 0 && threadIdx.y == 0 && threadIdx.x == 0 ) {
    //     printf("elt=%12.8f sum_fp32=%12.8f\n", data_fp32[0][0], sum_fp32);
    // }
    // // END OF DEBUG.

    // Fix the sum if it is zero or NaN (e.g. a fully-masked row).
    if (sum_fp32 == 0.f || sum_fp32 != sum_fp32)
    {
        sum_fp32 = 1.f;
    }

    // Normalize.
    float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        data_fp32[ii][0] *= inv_sum_fp32;
        data_fp32[ii][1] *= inv_sum_fp32;
        data_fp32[ii][2] *= inv_sum_fp32;
        data_fp32[ii][3] *= inv_sum_fp32;
    }
}

template <typename Data_type, int X>
struct VecX
{

    using Type = typename fmha::Uint_from_size_in_bytes<X * sizeof(Data_type)>::Type;
    static_assert(sizeof(Type) == X * sizeof(Data_type));

    union Alias
    {
        Type raw;
        Data_type elt[X];
    };

    static __device__ inline void to_floatX(
        float (&dst)[X], Type const& src, float const scale, float const attn_logit_softcapping_scale)
    {
        Alias tmp;
        tmp.raw = src;
#pragma unroll
        for (int it = 0; it < X; it++)
        {
            dst[it] = to_float(tmp.elt[it], scale);
            if (attn_logit_softcapping_scale != 0.f)
            {
                dst[it] = attn_logit_softcapping_scale * fmha::__tanhf(dst[it] / attn_logit_softcapping_scale);
            }
        }
    }

    static __device__ inline void from_floatX(Type& dst, float const (&src)[X], float const scale)
    {
        Alias tmp;
#pragma unroll
        for (int it = 0; it < X; it++)
        {
            from_float(tmp.elt[it], src[it], scale);
        }
        dst = tmp.raw;
    }
};
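
// VecX packs X elements of Data_type into one unsigned integer word of matching size, so each
// thread moves a whole fragment with a single vectorized load/store. to_floatX also applies the
// optional logit softcapping, scale * tanh(x / scale), which bounds the magnitude of the logits
// by the softcapping scale while staying close to the identity for small values.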

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float get_alibi_head_scaling_factor(int const head_id, int const num_heads)
{
    // Round down to power of 2
    int const num_heads_pow2 = (1u << (31 - __clz(num_heads)));
    if (head_id < num_heads_pow2)
    {
        return exp2f((head_id + 1) * -8.0f / num_heads_pow2);
    }
    else
    {
        float const adjusted_head_id = 2 * (head_id - num_heads_pow2) + 1;
        return exp2f(adjusted_head_id * -4.0f / num_heads_pow2);
    }
}
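
// These are the standard ALiBi head slopes: with n = 2^k heads, head i uses 2^(-8 * (i + 1) / n).
// When num_heads is not a power of two, the remaining heads take the odd-indexed slopes of the
// 2n-head schedule, which is what the adjusted_head_id branch computes.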

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Dst_type, typename Src_type, int SEQLEN, int WARPS_PER_CTA, int X = 4>
static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params)
{

    // By default, use LDG.64 for the loads and STG.64 for the stores.
    enum
    {
        ELEMENTS_PER_LDG = X,
        ELEMENTS_PER_STG = X
    };

    // The number of Vec_type per thread.
    enum
    {
        VECs_PER_THREAD = SEQLEN / THREADS_PER_WARP / ELEMENTS_PER_LDG
    };

    // DEBUG.
    static_assert(VECs_PER_THREAD * THREADS_PER_WARP * ELEMENTS_PER_LDG == SEQLEN, "");
    // END OF DEBUG.

    using VecO = VecX<Dst_type, X>;
    using VecI = VecX<Src_type, X>;
    using VecM = VecX<int8_t, X>;
    // The vector types.
    using DstX_type = typename VecO::Type;
    using SrcX_type = typename VecI::Type;

    // Make sure the sizes match our expectations.
    static_assert(sizeof(DstX_type) == X * sizeof(Dst_type));
    static_assert(sizeof(SrcX_type) == X * sizeof(Src_type));

    // The type of the mask.
    using MaskX_type = typename VecM::Type;

    // One warp per sequence.
    size_t hi = blockIdx.y * WARPS_PER_CTA + threadIdx.y;
    size_t bi = blockIdx.z;
    size_t si = blockIdx.x;

    // The data offset. Layout is S * B * H * S.
    size_t src_offset = si * params.bhs + bi * params.hs + hi * SEQLEN + threadIdx.x * ELEMENTS_PER_LDG;

    // Load the input elements.
    SrcX_type const* src_ptr = reinterpret_cast<SrcX_type const*>(&params.src[src_offset]);
    SrcX_type data_src[VECs_PER_THREAD];
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        if (hi < params.h)
        {
            data_src[ii] = src_ptr[ii * THREADS_PER_WARP];
        }
    }

    // The mask offset. Layout is S * B * S.
    size_t mask_offset = si * params.bs + bi * SEQLEN + threadIdx.x * ELEMENTS_PER_LDG;

    // Load the masks.
    MaskX_type const* mask_ptr = reinterpret_cast<MaskX_type const*>(&params.mask[mask_offset]);
    MaskX_type mask[VECs_PER_THREAD];
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        mask[ii] = mask_ptr[ii * THREADS_PER_WARP];
    }

    // Convert the data to float.
    float data_fp32[VECs_PER_THREAD][X];
    int8_t mask_[VECs_PER_THREAD][X];
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        VecI::to_floatX(data_fp32[ii], data_src[ii], params.scale_bmm1, params.softcapping_scale_bmm1);

        typename VecM::Alias tmp;
        tmp.raw = mask[ii];
#pragma unroll
        for (int it = 0; it < X; it++)
        {
            mask_[ii][it] = tmp.elt[it];
        }
    }

    if (params.has_alibi)
    {
        float const alibi_factor = get_alibi_head_scaling_factor(hi, params.h);
#pragma unroll
        for (int ii = 0; ii < VECs_PER_THREAD; ii++)
        {
#pragma unroll
            for (int jj = 0; jj < X; jj++)
            {
                int col = ii * THREADS_PER_WARP * X + threadIdx.x * X + jj;
                data_fp32[ii][jj] += alibi_factor * col;
            }
        }
    }

    // The attention sink value.
    float attention_sink = -FLT_MAX;
    if (params.attention_sinks != nullptr)
    {
        attention_sink = params.attention_sinks[hi];
    }

    // Do the reduction.
    float sum_fp32 = 0.f;
    float max_fp32 = -HUGE_VALF;
    reduce(data_fp32, mask_, params.warps_n, sum_fp32, max_fp32, attention_sink);
    if (threadIdx.x == 0)
    {
        int sum_s = params.cu_q_seqlens[bi];
        // [B, S, H, 2] {max, sum} float
        if (hi < params.h)
        {
            params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2] = max_fp32;
            params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2 + 1] = sum_fp32;
        }
    }
    // Convert back to the destination type.
    DstX_type data_dst[VECs_PER_THREAD];
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        VecO::from_floatX(data_dst[ii], data_fp32[ii], params.scale_softmax);
    }

    // Store the output elements.
    DstX_type* dst_ptr = reinterpret_cast<DstX_type*>(&params.dst[src_offset]);
#pragma unroll
    for (int ii = 0; ii < VECs_PER_THREAD; ++ii)
    {
        if (hi < params.h)
        {
            dst_ptr[ii * THREADS_PER_WARP] = data_dst[ii];
        }
    }
}
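
// Launch shape (see run_softmax below): grid = (S_outer, ceil(H / WARPS_PER_CTA), B) with
// WARPS_PER_CTA warps of THREADS_PER_WARP threads per CTA, so each warp owns one full row of
// SEQLEN scores for one (query position, batch, head) triple.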

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Dst_type, typename Src_type>
void run_softmax(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum,
    void* cu_q_seqlens, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax,
    float softcapping_scale_bmm1, int warps_n, bool has_alibi)
{

    Softmax_params<Dst_type, Src_type> params;
    memset(&params, 0, sizeof(params));

    // The different pointers.
    params.dst = reinterpret_cast<Dst_type*>(dst);
    params.src = reinterpret_cast<Src_type const*>(src);
    params.softmax_sum = reinterpret_cast<float*>(softmax_sum);
    params.cu_q_seqlens = reinterpret_cast<int*>(cu_q_seqlens);
    params.mask = reinterpret_cast<int8_t const*>(mask);
    params.attention_sinks = reinterpret_cast<float const*>(attention_sinks);
    params.has_alibi = has_alibi;

    // The dimensions and precomputed values.
    params.b = b;
    params.h = h;
    params.bhs = b * h * s_inner;
    params.hs = h * s_inner;
    params.bs = b * s_inner;

    // The scaling factors for the int8 version to convert to/from float.
    params.scale_bmm1 = scale_bmm1;
    params.softcapping_scale_bmm1 = softcapping_scale_bmm1;
    params.scale_softmax = scale_softmax;
    // The number of warps_n used to identify the reduction strategy.
    params.warps_n = warps_n;

    // Compute the grid size.
    enum
    {
        WARPS_PER_CTA = 4
    };

    dim3 grid(s_outer, (h + WARPS_PER_CTA - 1) / WARPS_PER_CTA, b);
    dim3 threads_per_cta(THREADS_PER_WARP, WARPS_PER_CTA);

    // Launch the kernel.
    if (s_inner == 32)
    {
        softmax_kernel<Dst_type, Src_type, 32, WARPS_PER_CTA, 1><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 64)
    {
        softmax_kernel<Dst_type, Src_type, 64, WARPS_PER_CTA, 2><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 96)
    {
        softmax_kernel<Dst_type, Src_type, 96, WARPS_PER_CTA, 1><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 128)
    {
        softmax_kernel<Dst_type, Src_type, 128, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 192)
    {
        softmax_kernel<Dst_type, Src_type, 192, WARPS_PER_CTA, 2><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 256)
    {
        softmax_kernel<Dst_type, Src_type, 256, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 384)
    {
        softmax_kernel<Dst_type, Src_type, 384, WARPS_PER_CTA, 2><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 512)
    {
        softmax_kernel<Dst_type, Src_type, 512, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 1024)
    {
        softmax_kernel<Dst_type, Src_type, 1024, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 2048)
    {
        softmax_kernel<Dst_type, Src_type, 2048, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 4096)
    {
        softmax_kernel<Dst_type, Src_type, 4096, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 8192)
    {
        softmax_kernel<Dst_type, Src_type, 8192, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 16384)
    {
        softmax_kernel<Dst_type, Src_type, 16384, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 32768)
    {
        softmax_kernel<Dst_type, Src_type, 32768, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else if (s_inner == 65536)
    {
        softmax_kernel<Dst_type, Src_type, 65536, WARPS_PER_CTA><<<grid, threads_per_cta>>>(params);
    }
    else
    {
        assert(false);
    }
}
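
// Note: the explicit X (elements per LDG/STG) above is chosen so that SEQLEN is an exact multiple
// of THREADS_PER_WARP * X: X = 1 for 32 and 96, X = 2 for 64, 192 and 384, and the default X = 4
// for the remaining power-of-two sizes.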

////////////////////////////////////////////////////////////////////////////////////////////////////