/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

// The multi-block mmha kernel can only be selected when CUDA >= 11.7.
#if (CUDART_VERSION >= 11070)
#define ENABLE_MULTI_BLOCK_OPTION
#endif

#ifdef ENABLE_MULTI_BLOCK_OPTION
#include <cub/block/block_reduce.cuh>
#include <cuda/atomic>
#include <cuda/std/bit>
#endif // ENABLE_MULTI_BLOCK_OPTION

namespace tensorrt_llm
{
namespace kernels
{

// #define MMHA_USE_HMMA_FOR_REDUCTION

// Below are knobs to extend FP32 accumulation for higher FP16 accuracy.

// Does not seem to affect the accuracy that much.
#define MMHA_USE_FP32_ACUM_FOR_FMA

// Seems to slightly improve the accuracy.
#define MMHA_USE_FP32_ACUM_FOR_OUT

#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
// Does not seem to improve the accuracy.
//#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#endif

namespace mmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////

//
// We use the following terminology to describe the different dimensions.
//
// B:  Batch size (number of sequences),
// L:  Sequence length,
// D:  Hidden dimension,
// H:  Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
//
// The different kernels assign a threadblock to each B x H pair (the head index comes from
// blockIdx.x and the batch/beam index from blockIdx.y). We use 256 threads per block to
// maximize occupancy and performance.
//
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
// cache buffer helps with memory accesses and contains keys with bias.
//
// The layout of the cache buffer for the keys/values is [B, H, L, Dh]
// where the fastest moving dimension (contiguous data) is the rightmost one.
// Contiguous threads will read one hidden_dimension per LDG unless we need more than 32 threads.
//
// The different kernels use 1 ~ 32 threads per key (THREADS_PER_KEY). The size of the LDGs
// is always 16 bytes (8 bytes for the 8-bit cache). Each thread sums Dh / THREADS_PER_KEY elements. At
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
// HMMA instruction (Tensor Core). Each Q * K^T value is stored in shared memory in FP32.
//
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
// shared memory.
//
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
// timesteps are computed per loop iteration. As with the keys, the values are read from a cache
// except for the current timestep. The layout of the cache buffer for the values is the same as
// the key's, which is [B, H, L, Dh].
//
// Note that we have remapped the key layout so that it shares the same [B, H, L, Dh] pattern as
// the values. This helps coalesce memory accesses and reduce register pressure.

////////////////////////////////////////////////////////////////////////////////////////////////////

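// Illustrative worked example (not exhaustive; assuming T = uint16_t (fp16) and Dh = 128):
//   - Dh_MAX = 128, each 16-byte LDG covers K_VEC_SIZE = 8 fp16 elements,
//   - THREADS_PER_KEY = 16, so 16 threads cooperate on one key and each loads one 16-byte chunk,
//   - THREADS_PER_VALUE = 16, so each thread owns V_VEC_SIZE = 8 fp16 elements of a value row.
// Other (T, Dh) combinations follow the same formulas defined further below.
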
template <typename T, int Dh_MAX>
struct Qk_vec_m_
{
};

template <>
struct Qk_vec_m_<float, 32>
{
    using Type = float;
};

template <>
struct Qk_vec_m_<float, 64>
{
    using Type = float2;
};

template <>
struct Qk_vec_m_<float, 128>
{
    using Type = float4;
};

template <>
struct Qk_vec_m_<float, 256>
{
    using Type = float4;
};

template <>
struct Qk_vec_m_<uint16_t, 32>
{
    using Type = uint32_t;
};

template <>
struct Qk_vec_m_<uint16_t, 64>
{
    using Type = uint32_t;
};

template <>
struct Qk_vec_m_<uint16_t, 128>
{
    using Type = uint2;
};

template <>
struct Qk_vec_m_<uint16_t, 256>
{
    using Type = uint4;
};

#ifdef ENABLE_BF16
template <>
struct Qk_vec_m_<__nv_bfloat16, 32>
{
    using Type = __nv_bfloat162;
};

template <>
struct Qk_vec_m_<__nv_bfloat16, 64>
{
    using Type = __nv_bfloat162;
};

template <>
struct Qk_vec_m_<__nv_bfloat16, 128>
{
    using Type = bf16_4_t;
};

template <>
struct Qk_vec_m_<__nv_bfloat16, 256>
{
    using Type = bf16_8_t;
};
#endif // ENABLE_BF16

#ifdef ENABLE_FP8
template <>
struct Qk_vec_m_<__nv_fp8_e4m3, 32>
{
    using Type = fp8_4_t;
};

template <>
struct Qk_vec_m_<__nv_fp8_e4m3, 64>
{
    using Type = fp8_4_t;
};

template <>
struct Qk_vec_m_<__nv_fp8_e4m3, 128>
{
    using Type = fp8_4_t;
};

template <>
struct Qk_vec_m_<__nv_fp8_e4m3, 256>
{
    using Type = fp8_4_t;
};
#endif // ENABLE_FP8

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int Dh>
struct Qk_vec_k_
{
    using Type = typename Qk_vec_m_<T, Dh>::Type;
};

#ifdef ENABLE_FP8
template <>
struct Qk_vec_k_<__nv_fp8_e4m3, 32>
{
    using Type = float4;
};

template <>
struct Qk_vec_k_<__nv_fp8_e4m3, 64>
{
    using Type = float4;
};

template <>
struct Qk_vec_k_<__nv_fp8_e4m3, 128>
{
    using Type = float4;
};

template <>
struct Qk_vec_k_<__nv_fp8_e4m3, 256>
{
    using Type = float4;
};
#endif // ENABLE_FP8

////////////////////////////////////////////////////////////////////////////////////////////////////

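// Illustrative example of how the two trait families are used (assuming fp16 with Dh_MAX = 128):
//   Qk_vec_m_<uint16_t, 128>::Type == uint2  (memory precision: 4 packed fp16 values, 8 bytes),
//   Qk_vec_k_<uint16_t, 128>::Type == uint2  (same math precision for non-fp8 types),
// whereas for the fp8 cache the memory type stays fp8_4_t but the kernel type is widened to float4.
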
template <typename T, int V_VEC_SIZE>
struct V_vec_m_
{
};

template <>
struct V_vec_m_<float, 1>
{
    using Type = float;
};

template <>
struct V_vec_m_<float, 2>
{
    using Type = float2;
};

template <>
struct V_vec_m_<float, 4>
{
    using Type = float4;
};

template <>
struct V_vec_m_<float, 8>
{
    using Type = Float8_;
};

template <>
struct V_vec_m_<uint16_t, 2>
{
    using Type = uint32_t;
};

template <>
struct V_vec_m_<uint16_t, 4>
{
    using Type = uint2;
};

template <>
struct V_vec_m_<uint16_t, 8>
{
    using Type = uint4;
};

#ifdef ENABLE_BF16
template <>
struct V_vec_m_<__nv_bfloat16, 2>
{
    using Type = __nv_bfloat162;
};

template <>
struct V_vec_m_<__nv_bfloat16, 4>
{
    using Type = bf16_4_t;
};

template <>
struct V_vec_m_<__nv_bfloat16, 8>
{
    using Type = bf16_8_t;
};
#endif // ENABLE_BF16

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int V_VEC_SIZE>
struct V_vec_k_
{
    using Type = typename V_vec_m_<T, V_VEC_SIZE>::Type;
};

#ifdef ENABLE_FP8
template <>
struct V_vec_k_<__nv_fp8_e4m3, 4>
{
    using Type = float4;
};

template <>
struct V_vec_k_<__nv_fp8_e4m3, 8>
{
    using Type = float4;
};

template <>
struct V_vec_k_<__nv_fp8_e4m3, 16>
{
    using Type = float4;
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Reuse V_vec traits as key and value share the same layout.
template <typename T, int K_VEC_SIZE>
struct K_vec_m_
{
    using Type = typename V_vec_m_<T, K_VEC_SIZE>::Type;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int K_VEC_SIZE>
struct K_vec_k_
{
    using Type = typename K_vec_m_<T, K_VEC_SIZE>::Type;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template <typename T>
struct Qk_vec_acum_fp32_
{
};

template <>
struct Qk_vec_acum_fp32_<float>
{
    using Type = float;
};

template <>
struct Qk_vec_acum_fp32_<float2>
{
    using Type = float2;
};

template <>
struct Qk_vec_acum_fp32_<float4>
{
    using Type = float4;
};

// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
template <>
struct Qk_vec_acum_fp32_<uint32_t>
{
    using Type = float2;
};

template <>
struct Qk_vec_acum_fp32_<uint2>
{
    using Type = Float4_;
};

template <>
struct Qk_vec_acum_fp32_<uint4>
{
    using Type = Float8_;
};

template <>
struct Qk_vec_acum_fp32_<__nv_bfloat16>
{
    using Type = float;
};

template <>
struct Qk_vec_acum_fp32_<__nv_bfloat162>
{
    using Type = float2;
};

template <>
struct Qk_vec_acum_fp32_<bf16_4_t>
{
    using Type = Float4_;
};

template <>
struct Qk_vec_acum_fp32_<bf16_8_t>
{
    using Type = Float8_;
};

#ifdef ENABLE_FP8
// template<>
// struct Qk_vec_acum_fp32_<fp8_2_t> {
//     using Type = float2;
// };
template <>
struct Qk_vec_acum_fp32_<fp8_4_t>
{
    using Type = Float4_;
};

// template<>
// struct Qk_vec_acum_fp32_<fp8_8_t> {
//     using Type = Float4_;
// };
#endif // ENABLE_FP8

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct K_vec_acum_fp32_
{
};

template <>
struct K_vec_acum_fp32_<float>
{
    using Type = float;
};

template <>
struct K_vec_acum_fp32_<float2>
{
    using Type = float2;
};

template <>
struct K_vec_acum_fp32_<float4>
{
    using Type = float4;
};

template <>
struct K_vec_acum_fp32_<Float8_>
{
    using Type = Float8_;
};

template <>
struct K_vec_acum_fp32_<uint32_t>
{
    using Type = float2;
};

template <>
struct K_vec_acum_fp32_<uint2>
{
    using Type = Float4_;
};

template <>
struct K_vec_acum_fp32_<uint4>
{
    using Type = Float8_;
};

template <>
struct K_vec_acum_fp32_<__nv_bfloat16>
{
    using Type = float;
};

template <>
struct K_vec_acum_fp32_<__nv_bfloat162>
{
    using Type = float2;
};

template <>
struct K_vec_acum_fp32_<bf16_4_t>
{
    using Type = Float4_;
};

template <>
struct K_vec_acum_fp32_<bf16_8_t>
{
    using Type = Float8_;
};

#ifdef ENABLE_FP8
// template<>
// struct K_vec_acum_fp32_<fp8_2_t> {
//     using Type = float2;
// };
template <>
struct K_vec_acum_fp32_<fp8_4_t>
{
    using Type = Float4_;
};

// template<>
// struct K_vec_acum_fp32_<fp8_8_t> {
//     using Type = Float4_;
// };
#endif // ENABLE_FP8
#endif // MMHA_USE_FP32_ACUM_FOR_FMA

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template <typename T>
struct V_vec_acum_fp32_
{
};

template <>
struct V_vec_acum_fp32_<float>
{
    using Type = float;
};

template <>
struct V_vec_acum_fp32_<float2>
{
    using Type = float2;
};

template <>
struct V_vec_acum_fp32_<float4>
{
    using Type = float4;
};

template <>
struct V_vec_acum_fp32_<uint32_t>
{
    using Type = float2;
};

template <>
struct V_vec_acum_fp32_<uint2>
{
    using Type = Float4_;
};

template <>
struct V_vec_acum_fp32_<uint4>
{
    using Type = Float8_;
};

#ifdef ENABLE_BF16
template <>
struct V_vec_acum_fp32_<__nv_bfloat162>
{
    using Type = float2;
};

template <>
struct V_vec_acum_fp32_<bf16_4_t>
{
    using Type = Float4_;
};

template <>
struct V_vec_acum_fp32_<bf16_8_t>
{
    using Type = Float8_;
};
#endif // ENABLE_BF16

#ifdef ENABLE_FP8
// template<>
// struct V_vec_acum_fp32_<fp8_2_t> {
//     using Type = float2;
// };
template <>
struct V_vec_acum_fp32_<fp8_4_t>
{
    using Type = Float4_;
};

// template<>
// struct V_vec_acum_fp32_<fp8_8_t> {
//     using Type = Float4_;
// };
#endif // ENABLE_FP8
#endif // MMHA_USE_FP32_ACUM_FOR_OUT

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Tout, typename Tin>
__inline__ __device__ constexpr Tout vec_conversion(const Tin& x)
{
    static_assert(std::is_same<Tout, Tin>::value, "Type mismatch");
    return x;
}

#ifdef ENABLE_FP8
// fp8_t
template <>
__inline__ __device__ float vec_conversion<float, __nv_fp8_e4m3>(const __nv_fp8_e4m3& a)
{
    return float(a);
}

template <>
__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a)
{
    return __nv_fp8_e4m3(a);
}

// fp8_2_t
template <>
__inline__ __device__ float2 vec_conversion<float2, fp8_2_t>(const fp8_2_t& a)
{
    return float2(a);
}

template <>
__inline__ __device__ fp8_2_t vec_conversion<fp8_2_t, float2>(const float2& a)
{
    return fp8_2_t(a);
}

// fp8_4_t
template <>
__inline__ __device__ float4 vec_conversion<float4, fp8_4_t>(const fp8_4_t& a)
{
    return float4(a);
}

template <>
__inline__ __device__ fp8_4_t vec_conversion<fp8_4_t, float4>(const float4& a)
{
    return fp8_4_t(a);
}
#endif // ENABLE_FP8

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
{
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
    using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
#else
    using K_vec_acum = K_vec;
#endif
    // Compute the parallel products for Q*K^T (treat vector lanes separately).
    K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
    for (int ii = 1; ii < N; ++ii)
    {
        qk_vec = fma(q[ii], k[ii], qk_vec);
    }

    // Finalize the reduction across lanes.
    float qk = sum(qk_vec);
#pragma unroll
    for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2)
    {
        qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
    }
    return qk;
}
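
// Illustrative example of the reduction above (assuming THREADS_PER_KEY = 4): the butterfly
// shuffle runs with masks 2 and 1, so each aligned group of 4 lanes exchanges partial sums
// among themselves only, and every lane in the group ends up holding the full dot product
// for its key.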

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int THREADS_PER_KEY>
struct Qk_dot
{
    template <typename K_vec, int N>
    static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
    {
        return qk_dot_<THREADS_PER_KEY>(q, k);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
{
    float4 c;
    float zero = 0.f;
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
        "    {%0, %1, %2, %3}, \n"
        "    {%4, %5}, \n"
        "    {%6}, \n"
        "    {%7, %7, %7, %7}; \n"

        : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
        : "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
    using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
#else
    using K_vec_acum = uint32_t;
#endif
    K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
    for (int ii = 1; ii < N; ++ii)
    {
        qk_vec = fma(q[ii], k[ii], qk_vec);
    }
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
    uint32_t qk_vec_ = float2_to_half2(qk_vec);
    return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
    return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
    return 0.f;
#endif
}
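
// Note on the constant above (reading of the existing code, not new behavior): 0x3c00 is the
// fp16 encoding of 1.0, so 0x3c003c00 packs a pair of fp16 ones. Multiplying the packed
// per-lane partial products by a vector of ones through the m16n8k8 HMMA effectively performs
// the cross-lane sum on the Tensor Core, and the reduced dot product is read back from c.x.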

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Qk_dot<uint16_t, 4>
{
    template <typename K_vec, int N>
    static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
    {
        return qk_dot_<4>(q, k);
    }

    template <int N>
    static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
    {
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
        return qk_hmma_dot_(q, k);
#else
        return qk_dot_<4>(q, k);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{

    // Decompose the thread index into warp / lane.
    int warp = threadIdx.x / WARP_SIZE;
    int lane = threadIdx.x % WARP_SIZE;

    // Compute the sum per warp.
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)
    {
        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    }

    // Warp leaders store the data to shared memory.
    if (lane == 0)
    {
        red_smem[warp] = sum;
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // The warps compute the final sums.
    if (lane < WARPS_PER_BLOCK)
    {
        sum = red_smem[lane];
    }

    // Parallel reduction inside the warp.
#pragma unroll
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2)
    {
        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    }

    // Broadcast to other threads.
    return __shfl_sync(uint32_t(-1), sum, 0);
}
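
// Illustrative example (assuming a 256-thread block, i.e. WARPS_PER_BLOCK = 8): each warp first
// reduces its 32 partial sums with masks 16, 8, 4, 2, 1; lane 0 of every warp writes one value
// to red_smem; after the barrier, the first 8 lanes of each warp reload those 8 values and
// reduce them with masks 4, 2, 1; finally the shuffle from lane 0 broadcasts the block total.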

#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float cast_to_float(float u)
{
    return u;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 cast_to_float(float2 u)
{
    return u;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 cast_to_float(float4 u)
{
    return u;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ cast_to_float(Float4_ u)
{
    return u;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ cast_to_float(Float8_ u)
{
    return u;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 cast_to_float(uint32_t u)
{
    return half2_to_float2(u);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ cast_to_float(uint2 u)
{
    Float4_ tmp;
    tmp.x = half2_to_float2(u.x);
    tmp.y = half2_to_float2(u.y);
    return tmp;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ cast_to_float(uint4 u)
{
    Float8_ tmp;
    tmp.x = half2_to_float2(u.x);
    tmp.y = half2_to_float2(u.y);
    tmp.z = half2_to_float2(u.z);
    tmp.w = half2_to_float2(u.w);
    return tmp;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 cast_to_float(__nv_bfloat162 u)
{
    float2 tmp;
    tmp = __bfloat1622float2(u);
    return tmp;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ cast_to_float(bf16_4_t u)
{
    Float4_ tmp;
    tmp.x = __bfloat1622float2(u.x);
    tmp.y = __bfloat1622float2(u.y);
    return tmp;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ cast_to_float(bf16_8_t u)
{
    Float8_ tmp;
    tmp.x = __bfloat1622float2(u.x);
    tmp.y = __bfloat1622float2(u.y);
    tmp.z = __bfloat1622float2(u.z);
    tmp.w = __bfloat1622float2(u.w);
    return tmp;
}

#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
inline __device__ __host__ T divUp(T m, T n)
{
    return (m + n - 1) / n;
}
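
// Example: divUp(10u, 4u) == 3 and divUp(8u, 4u) == 2 (ceiling division).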

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
inline __device__ __host__ T div(T m, T n)
{
    return m / n;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct kernel_type_t
{
    using Type = T;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compute the largest supported head size (dh_max). It must be the smallest power-of-2 that is not strictly smaller
// than the head size (dh).
inline __device__ __host__ constexpr unsigned dh_max(unsigned dh)
{
    return next_power_of_two(mmha::const_max(dh, 32u));
}
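
// Example: dh_max(64) == 64, dh_max(80) == 128, and dh_max(16) == 32 (clamped up to the warp size).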

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
inline __device__ __host__ constexpr unsigned threads_per_value(unsigned dh_max)
{
    return dh_max * sizeof(T) / 16;
}
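
// Example: threads_per_value<uint16_t>(128u) == 16 and threads_per_value<float>(128u) == 32,
// so each thread owns one 16-byte slice of a value row.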

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, unsigned Dh_MAX>
inline __device__ __host__ constexpr unsigned threads_per_key()
{
    // Since we want to perform the reduction entirely within a warp, the number of threads per key
    // is capped at 32.
    constexpr unsigned threads = (unsigned) (Dh_MAX * sizeof(T) / 16u);
    if ((threads & (threads - 1)) != 0)
    {
        assert(false); // Not a power of two.
    }
    return std::min(32u, threads);
}
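
// Example: threads_per_key<uint16_t, 128u>() == 16, threads_per_key<float, 64u>() == 16,
// and threads_per_key<uint16_t, 32u>() == 4; each such group reduces its dot product within a warp.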

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ constexpr uint32_t shfl_mask(int threads)
{
    assert(threads <= 32);
    return threads == 32 ? -1u : (1u << threads) - 1u;
}
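
// Example: shfl_mask(32) == 0xffffffff, shfl_mask(16) == 0x0000ffff, shfl_mask(4) == 0x0000000f.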

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, typename T_VEC, unsigned VECS_PER_CHUNK>
__device__ inline constexpr uint2 chunk_index(unsigned tidx)
{
    // The chunk associated with the thread.
    auto const idx_chunk = tidx / VECS_PER_CHUNK;

    // The position of the T_VEC vector in that chunk associated with the thread.
    static_assert(sizeof(T_VEC) % sizeof(T) == 0);
    unsigned constexpr kVecSize{sizeof(T_VEC) / sizeof(T)};
    auto const idx_vec = (tidx % VECS_PER_CHUNK) * kVecSize;

    return uint2{idx_chunk, idx_vec};
}
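
// Worked example (assuming T = uint16_t, T_VEC = uint4, VECS_PER_CHUNK = 2): kVecSize is
// 16 / 2 = 8, so chunk_index(5) returns {5 / 2, (5 % 2) * 8} == {2, 8}, i.e. thread 5 works on
// chunk (key/value) 2 starting at element offset 8.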

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    // The type of the inputs. Supported types: float, uint16_t, nv_bfloat16.
    typename T,
    // The type of the cache.
    typename Tcache,
    // Type of struct containing KV cache
    typename KVCacheBuffer,
    // The hidden dimension per head.
    unsigned Dh,
    // The number of threads in a threadblock.
    unsigned THREADS_PER_BLOCK,
    // Whether beam search is used.
    bool HAS_BEAMS,
    // Whether to enable multi-block mode for long sequence lengths.
    bool DO_MULTI_BLOCK = false,
    // The number of threads per key.
    unsigned THREADS_PER_KEY = threads_per_key<T, dh_max(Dh)>(),
    // The number of threads per value.
    unsigned THREADS_PER_VALUE = threads_per_value<T>(dh_max(Dh)),
    // The unroll factor for loading from K cache.
    unsigned K_LOOP_UNROLL = 8,
    // The unroll factor for loading from V cache.
    // Defaults to 4 for higher occupancy (by reducing register usage).
    unsigned V_LOOP_UNROLL = 4>
__global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> params, KVCacheBuffer kvCacheBuffer)
{

    using Tk = typename kernel_type_t<T>::Type;
    // Use 8-bit cache.
    static constexpr bool ENABLE_8BITS_CACHE = sizeof(Tcache) == 1;

    // The size of a warp.
    constexpr unsigned WARP_SIZE{32};
    // The number of warps in a threadblock.
    constexpr unsigned WARPS_PER_BLOCK{THREADS_PER_BLOCK / WARP_SIZE};

    // The maximum hidden size per head.
    constexpr auto Dh_MAX = dh_max(Dh);
    constexpr bool IS_Dh_MAX = Dh == Dh_MAX;
    static_assert(Dh_MAX >= WARP_SIZE);
    static_assert(Dh_MAX >= Dh);

    // The maximum sequence length in the kv_cache, i.e., an upper bound on L.
    // Note that the maximum sequence length supported by the model might be greater than this.
    const auto max_seq_len = static_cast<unsigned>(params.memory_max_len);
    assert(max_seq_len > 0);
    // The current timestep (including paddings).
    // It is only used to calculate the smem stride.
    const auto timestep = static_cast<unsigned>(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep);

#ifdef ENABLE_MULTI_BLOCK_OPTION
    constexpr bool MULTI_BLOCK_FLAG = DO_MULTI_BLOCK;
#else
    constexpr bool MULTI_BLOCK_FLAG = false;
#endif

    // Use smem_size_in_bytes (above) to determine the amount of shared memory.
    extern __shared__ char smem_[];

    // The shared memory for the Q*K^T values and partial logits in softmax.
    auto qk_smem = reinterpret_cast<float*>(smem_);

    __shared__ float qk_current_smem[1];

    // The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
    char* logits_smem_ = smem_;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
    if (sizeof(Tk) != 4)
    {
        // TODO - change to tlength
        const auto max_timesteps = min(timestep, max_seq_len);
        logits_smem_ += divUp(max_timesteps + 1, 4u) * 16;
    }
    Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_);
#else
    float* logits_smem = reinterpret_cast<float*>(logits_smem_);
#endif

    __shared__ Tk logits_current_smem[1];

    // The shared memory to do the final reduction for the output values. Reuse qk_smem.
    Tk* out_smem = reinterpret_cast<Tk*>(smem_);

    // The shared memory buffers for the block-wide reductions. One for max, one for sum.
    __shared__ float red_smem[WARPS_PER_BLOCK * 2];

    // A vector of Q or K elements for the current timestep.
    using Qk_vec_m = typename Qk_vec_m_<T, Dh_MAX>::Type; // with memory-used precision
    using Qk_vec_k = typename Qk_vec_k_<T, Dh_MAX>::Type; // with kernel-used precision

    // Make sure the hidden dimension per head is a multiple of the number of threads per key.
    static_assert(Dh_MAX % THREADS_PER_KEY == 0); // trivially satisfied since THREADS_PER_KEY in {1, 2, 4}

    // The number of elements per vector.
    // Each thread will handle 16 bytes.
    constexpr int K_VEC_SIZE = 16u / sizeof(T);
    // Make sure the hidden size per head is a multiple of the vector size.
    static_assert(Dh_MAX % K_VEC_SIZE == 0);
    // The type of queries and keys for the math in the Q*K^T product.
    using K_vec_k = typename K_vec_k_<T, K_VEC_SIZE>::Type;
    // Only used when the key cache is quantized to 8 bits.
    using K_vec_m = typename packed_type<Tcache, num_elems<K_vec_k>::value>::type;

    // Use alignment for safely casting the shared buffers as Qk_vec_k and K_vec_k.
    // Shared memory to store Q inputs.
    __shared__ __align__(mmha::const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk q_smem[Dh_MAX];

    // Make sure the hidden dimension per head is a multiple of the number of threads per value.
    static_assert(Dh_MAX % THREADS_PER_VALUE == 0); // trivially satisfied since THREADS_PER_VALUE == Dh_MAX / p

    // The number of elements per vector.
    constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
    // A vector of V elements for the current timestep.
    using V_vec_k = typename V_vec_k_<T, V_VEC_SIZE>::Type;
    // Only used when the value cache is quantized to 8 bits.
    using V_vec_m = typename packed_type<Tcache, num_elems<V_vec_k>::value>::type;
    static_assert(V_VEC_SIZE == sizeof(V_vec_k) / sizeof(T));

    // The number of elements per vector.
    constexpr unsigned QK_VEC_SIZE{sizeof(Qk_vec_m) / sizeof(T)};
    // Make sure the hidden size per head is a multiple of the vector size.
    static_assert(Dh_MAX % QK_VEC_SIZE == 0);
    // We will use block-wide reduction if needed.
    // The number of vectors per Dh_MAX.
    constexpr unsigned QK_VECS_PER_Dh_MAX{Dh_MAX / QK_VEC_SIZE};
    static_assert(THREADS_PER_BLOCK >= QK_VECS_PER_Dh_MAX);

    // The batch/beam idx.
    const auto bi = blockIdx.y;
    if (params.finished != nullptr && params.finished[bi])
    {
        return;
    }
    // The head.
    const unsigned hi{blockIdx.x};
    // The head index of keys and values adjusted for MQA/GQA.
    const int qhead_per_kv{params.num_heads / params.num_kv_heads};
    const unsigned hi_kv{hi / qhead_per_kv};
    // The number of heads.
    const auto num_heads = static_cast<unsigned>(params.num_heads);
    // The number of heads for keys and values adjusted for MQA/GQA.
    const auto num_heads_kv = static_cast<unsigned>(params.num_kv_heads);

    // The thread in the block.
    const unsigned tidx{threadIdx.x};

    // The column tile along the L dimension of K^T -- noted as T_c in the flash-attention paper.
    const unsigned c_tile{MULTI_BLOCK_FLAG ? blockIdx.z : 0};

    // While doing the product Q*K^T for the different keys we track the max.
    float qk_max = -FLT_MAX;

    float qk = 0.0F;

    // The actual sequence length excluding the paddings.
    // Minus 1 because it includes the current timestep, while tlength denotes the kv cache length.
    const int tlength = params.length_per_sample ? (params.length_per_sample[bi] - 1) : static_cast<int>(timestep);
    // The context length for the beam-search optimization (all entries point to beam 0).
    const int input_length = params.input_lengths[bi];

    // The offset in the Q and K buffer also accounts for the batch.
    const auto qk_vec_idx = tidx * QK_VEC_SIZE;
    const auto is_valid_qk_vec = qk_vec_idx < Dh;

    const bool load_qkv_quant = params.qkv_scale_quant_orig != nullptr;
    const bool write_attention_quant = params.attention_out_scale_orig_quant != nullptr;

    // Quant/dequant scales for the 8-bit kv cache.
    using T_scale = typename kv_cache_scale_type_t<T, Tcache>::Type;
    T_scale kv_scale_quant_orig, kv_scale_orig_quant;
    convert_from_float(&kv_scale_quant_orig, (ENABLE_8BITS_CACHE ? params.kv_scale_quant_orig[0] : 1.0f));
    convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_CACHE ? params.kv_scale_orig_quant[0] : 1.0f));

    // Up to QK_VECS_PER_Dh_MAX threads load Q and K + the bias values for the current timestep.
    // Trigger the loads from the Q and K buffers.
    Qk_vec_k q, k, q_bias, k_bias;
    zero(q);
    zero(k);
    zero(q_bias);
    zero(k_bias);
    if (is_valid_qk_vec)
    {
        // Query
        // The stride between tokens. We may be able to always use params.stride.
        uint32_t q_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads * Dh);
        // The offset.
        const auto q_offset = tensorrt_llm::common::flat_index_strided3(bi, hi, qk_vec_idx, q_stride, Dh);

        if (load_qkv_quant)
        {
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
            using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
            const auto q_scaling = params.qkv_scale_quant_orig[0];
            const auto q_quant
                = *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
            convert_from_float(&q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
        }
        else
        {
            q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[q_offset]));
        }

        // Key
        // The stride between tokens. We may be able to always use params.stride.
        uint32_t k_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads_kv * Dh);
        // The offset.
        const auto k_offset = tensorrt_llm::common::flat_index_strided3(bi, hi_kv, qk_vec_idx, k_stride, Dh);

        if (load_qkv_quant)
        {
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
            using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
            const auto k_scaling = params.qkv_scale_quant_orig[1];
            const auto k_quant
                = *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);

            convert_from_float(&k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
        }
        else
        {
            k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[k_offset]));
        }

        if (params.q_bias != nullptr)
        {
            const auto q_bias_offset = tensorrt_llm::common::flat_index2(hi, qk_vec_idx, Dh);
            q_bias
                = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q_bias[q_bias_offset]));
        }
        if (params.k_bias != nullptr)
        {
            const auto k_bias_offset = tensorrt_llm::common::flat_index2(hi_kv, qk_vec_idx, Dh);
            k_bias
                = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[k_bias_offset]));
        }
    }

    // Compute the Q/K values with bias.
    q = add(q, q_bias);
    k = add(k, k_bias);

    const bool do_ia3 = params.ia3_tasks != nullptr;
    const auto beam_width = static_cast<unsigned>(params.beam_width);
    const auto ia3_ti_hi = do_ia3
        ? tensorrt_llm::common::flat_index2(static_cast<unsigned>(params.ia3_tasks[bi / beam_width]), hi, num_heads)
        : 0;

    if (do_ia3 && is_valid_qk_vec)
    {
        k = mul<Qk_vec_k, Qk_vec_k, Qk_vec_k>(k,
            vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(
                &params.ia3_key_weights[tensorrt_llm::common::flat_index2(ia3_ti_hi, qk_vec_idx, Dh)])));
    }

    // Note that we have no paddings in the KV cache now.
    switch (params.position_embedding_type)
    {
    case PositionEmbeddingType::kLEARNED_ABSOLUTE:
    case PositionEmbeddingType::kALIBI: break;
    case PositionEmbeddingType::kROPE_GPTJ:
    {
        apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength);
        break;
    }
    case PositionEmbeddingType::kROPE_GPT_NEOX:
    {
        const bool do_rotary = is_valid_qk_vec && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;

        T* q_smem_ = reinterpret_cast<T*>(smem_);
        T* k_smem = q_smem_ + params.rotary_embedding_dim;

        const int half_rotary_dim = params.rotary_embedding_dim / 2;
        const int half_idx = qk_vec_idx / half_rotary_dim;
        const int intra_half_idx = qk_vec_idx % half_rotary_dim;
        const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts

        assert(half_rotary_dim % QK_VEC_SIZE == 0);

        if (do_rotary)
        {
            *reinterpret_cast<Qk_vec_k*>(q_smem_ + half_idx * smem_pitch + intra_half_idx) = q;
            *reinterpret_cast<Qk_vec_k*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
        }

        __syncthreads();

        const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
        constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
        if (do_rotary)
        {
            mmha::vec_from_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);
            mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);

            mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength);

            mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
            mmha::write_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);
        }

        __syncthreads();

        if (do_rotary)
        {
            q = *reinterpret_cast<Qk_vec_k*>(q_smem_ + half_idx * smem_pitch + intra_half_idx);
            k = *reinterpret_cast<Qk_vec_k*>(k_smem + half_idx * smem_pitch + intra_half_idx);
        }

        __syncthreads();
        break;
    }
    }

    if (qk_vec_idx < Dh_MAX)
    {

        // Store the Q values to shared memory.
        // Set padded Dh to 0 for the correctness of QK (when Dh != Dh_Max).
        Qk_vec_k zero_q;
        zero(zero_q);

        *reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx]) = is_valid_qk_vec ? q : zero_q;
        // Write the K values to the global memory cache.
        //
        // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
        // system. We designed it this way as it allows much better memory loads (and there are many
        // more loads) + the stores are really "write and forget" since we won't need the ack before
        // the end of the kernel. There's plenty of time for the transactions to complete.

        // For MQA/GQA mode, write only with the first Q head of each group per KV head.
        if (hi == (hi_kv * qhead_per_kv) && (IS_Dh_MAX || is_valid_qk_vec))
        {
            // Trigger the stores to global memory.
            const auto k_idx = QK_VEC_SIZE * tidx;
            const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tlength, hi_kv, Dh, k_idx);
            // The base pointer for the value in the cache buffer.
            Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(bi, tlength));

            if constexpr (ENABLE_8BITS_CACHE)
            {
                store_8bits_kv_cache_vec(reinterpret_cast<Tcache*>(k_cache), k, inBlockIdx, kv_scale_orig_quant);
            }
            else
            {
                *reinterpret_cast<Qk_vec_m*>(&k_cache[inBlockIdx]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
            }
        }

        // Compute \sum_i Q[i] * K^T[i] for the current timestep.
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
        using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec_k>::Type;
#else
        using Qk_vec_acum = Qk_vec_k;
#endif
        qk = dot<Qk_vec_acum, Qk_vec_k>(q, k);
        if (QK_VECS_PER_Dh_MAX <= WARP_SIZE)
        {
#pragma unroll
            for (int mask = QK_VECS_PER_Dh_MAX / 2; mask >= 1; mask /= 2)
            {
                qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_Dh_MAX), qk, mask);
            }
        }
    }

    if (QK_VECS_PER_Dh_MAX > WARP_SIZE)
    {
        constexpr int WARPS_PER_RED = (QK_VECS_PER_Dh_MAX + WARP_SIZE - 1) / WARP_SIZE;
        qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
    }

    // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
    if (tidx == 0)
    {
        // Normalize qk.
        qk *= params.inv_sqrt_dh;
        if (params.relative_attention_bias != nullptr)
        {
            qk = add(qk,
                params.relative_attention_bias[hi * params.relative_attention_bias_stride
                        * params.relative_attention_bias_stride
                    + tlength * params.relative_attention_bias_stride + tlength]);
        }
        // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.

        qk_max = qk;
        // qk_smem[params.timestep] = qk;
        if (MULTI_BLOCK_FLAG)
        {
            qk_current_smem[0] = qk;
        }
        else
        {
            qk_smem[tlength] = qk;
        }
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    constexpr unsigned K_ELTS_PER_CHUNK{THREADS_PER_KEY * K_VEC_SIZE};

    // The positions of the cache buffer (for this B * H) and the vector within that chunk associated with this
    // thread.
    const auto k_idx = chunk_index<T, K_vec_k, THREADS_PER_KEY>(tidx);

    // The number of vectors per thread.
    constexpr unsigned K_VECS_PER_THREAD{Dh_MAX / K_ELTS_PER_CHUNK};
    static_assert(Dh_MAX == K_ELTS_PER_CHUNK * K_VECS_PER_THREAD);

    // Load the Q values from shared memory. The values are reused during the loop on K.
    K_vec_k q_vec[K_VECS_PER_THREAD];
#pragma unroll
    for (unsigned ii = 0; ii < K_VECS_PER_THREAD; ++ii)
    {
        q_vec[ii] = *reinterpret_cast<const K_vec_k*>(
            &q_smem[tensorrt_llm::common::flat_index2(ii, k_idx.y, K_ELTS_PER_CHUNK)]);
    }

    // The number of timesteps loaded per iteration, i.e., (THREADS_PER_BLOCK * THREADS_PER_BLOCK) / 256 <= 256
    constexpr unsigned K_PER_ITER{THREADS_PER_BLOCK / THREADS_PER_KEY};
    // The number of keys per warp.
    constexpr unsigned K_PER_WARP{WARP_SIZE / THREADS_PER_KEY};
    // The number of unrolled keys per warp.
    constexpr unsigned UNROLLED_K_PER_WARP = K_PER_WARP * K_LOOP_UNROLL;
    // The number of unrolled keys per iteration.
    constexpr unsigned UNROLLED_K_PER_ITER = K_PER_ITER * K_LOOP_UNROLL;
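
    // Illustrative values (assuming fp16, Dh = 128, THREADS_PER_BLOCK = 256, K_LOOP_UNROLL = 8):
    // THREADS_PER_KEY = 16, so K_PER_ITER = 16 keys per loop iteration, K_PER_WARP = 2,
    // UNROLLED_K_PER_WARP = 16 and UNROLLED_K_PER_ITER = 128 timesteps covered per unrolled pass.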
|
|
|
|
// Base pointer for the row of pointers to k cache blocks
|
|
void** k_cache_base_row_ptr = reinterpret_cast<void**>(kvCacheBuffer.getRowPtr(KVIdxType::K_IDX, bi));
|
|
|
|
const auto timesteps_per_block = static_cast<unsigned>(params.timesteps_per_block);
|
|
|
|
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
|
|
// Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
|
|
const int context_length = HAS_BEAMS ? input_length : tlength;
|
|
const auto context_ti_end = MULTI_BLOCK_FLAG
|
|
|
|
? divUp(timesteps_per_block, UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP
|
|
: divUp(static_cast<unsigned>(context_length), UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP;
|
|
|
|
// The generation ti_end.
|
|
const auto generation_ti_end = MULTI_BLOCK_FLAG ? divUp(timesteps_per_block, K_PER_WARP) * K_PER_WARP
|
|
: divUp(static_cast<unsigned>(tlength), K_PER_WARP) * K_PER_WARP;
|
|
|
|
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
|
|
const auto bi_seq_len_offset = static_cast<std::size_t>(bi) * max_seq_len;
|
|
const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr;
|
|
|
|
const auto c_tile_times_timesteps_per_block = c_tile * timesteps_per_block; // 0 if !MULTI_BLOCK_FLAG
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Key cache loops for dot(Q, K).
|
|
|
|
// Handle only context key cache with beam searching.
|
|
// Handle both context and generation key cache without beam searching.
|
|
// Explict batching of LDGs (by K_LOOP_UNROLL) as it doesn't depend on indirection tables.
|
|
for (int ti = k_idx.x; ti < context_ti_end; ti += UNROLLED_K_PER_ITER)
|
|
{
|
|
const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
|
|
|
// The keys loaded from the key cache.
|
|
K_vec_m k_vec_cache[K_LOOP_UNROLL][K_VECS_PER_THREAD];
|
|
|
|
#pragma unroll
|
|
for (int k_loop = 0; k_loop < K_LOOP_UNROLL; ++k_loop)
|
|
{
|
|
#pragma unroll
|
|
for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i)
|
|
{
|
|
// Make sure we read data within the bound.
|
|
// Dh OOB values will be handled by zero_q.
|
|
// Seq OOB values will be masked out when storing back to smem.
|
|
auto const jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
|
|
const int valid_time_now = min(time_now + k_loop * K_PER_ITER, context_length - 1);
|
|
const int seqIdx = bi / beam_width * beam_width;
|
|
// Base pointer to k cache block for beam's batch
|
|
Tcache* k_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(seqIdx, valid_time_now));
|
|
|
|
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
|
|
k_vec_cache[k_loop][k_vec_i] = *reinterpret_cast<const K_vec_m*>(&k_cache_batch[inBlockIdx]);
|
|
}
|
|
}
|
|
|
|
#pragma unroll
|
|
for (int k_loop = 0; k_loop < K_LOOP_UNROLL; ++k_loop)
|
|
{
|
|
const int local_time_now = time_now + k_loop * K_PER_ITER;
|
|
const int local_ti = ti + k_loop * K_PER_ITER;
|
|
// Perform the dot product and normalize qk.
|
|
//
|
|
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
|
|
K_vec_k k_vec[K_VECS_PER_THREAD];
|
|
#pragma unroll
|
|
for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i)
|
|
{
|
|
// we move quantization to here for better batching of inflight LDGs.
|
|
if constexpr (ENABLE_8BITS_CACHE)
|
|
{
|
|
convert_from_8bit_kv_cache<K_vec_m, K_vec_k, Tcache, T_scale>(
|
|
&k_vec[k_vec_i], k_vec_cache[k_loop][k_vec_i], kv_scale_quant_orig);
|
|
}
|
|
else
|
|
{
|
|
// K_vek is same as K_vec_cache in this case.
|
|
k_vec[k_vec_i] = *reinterpret_cast<K_vec_k*>(&k_vec_cache[k_loop][k_vec_i]);
|
|
}
|
|
}
|
|
|
|
float qk_{Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh};
|
|
|
|
// For multi-block mode, we still need to make sure it will not be OOB.
|
|
if (MULTI_BLOCK_FLAG && local_ti >= timesteps_per_block)
|
|
{
|
|
continue;
|
|
}
|
|
|
|
// Store the product to shared memory. There's one qk value per timestep. Update the max.
|
|
if (local_time_now < context_length && tidx % THREADS_PER_KEY == 0)
|
|
{
|
|
if (params.relative_attention_bias != nullptr)
|
|
{
|
|
qk_ = add(qk_,
|
|
params.relative_attention_bias[hi * params.relative_attention_bias_stride
|
|
* params.relative_attention_bias_stride
|
|
+ tlength * params.relative_attention_bias_stride + local_time_now]);
|
|
}
|
|
if (params.linear_bias_slopes != nullptr)
|
|
{
|
|
// Apply the linear position bias: (ki - qi) * slope[hi].
|
|
// The padding token locates between the input context and the generated tokens.
|
|
// We need to remove the number of padding tokens in the distance computation.
|
|
// ti : 0 1 2 3 4 5 6 7 8 9(tlength)
|
|
// token: i i i i p p p o o o where i=input, p=pad, o=output.
|
|
// e.g. ti = 2, dist = (9 - 3) - 2 = 4.
|
|
float dist = local_time_now - tlength;
|
|
|
|
qk_ += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
|
|
}
|
|
|
|
// Calculate the max for softmax, and store qk back to smem.
|
|
// Don't need mask here as we remove paddings in kv cache.
|
|
qk_max = fmaxf(qk_max, qk_);
|
|
qk_smem[local_ti] = qk_;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handle generation key cache with beam searching.
|
|
// Note that it may be overlapped with the context key loop, but it won't impact the corretness.
|
|
if (HAS_BEAMS)
|
|
{
|
|
// For multi-block mode, the last few blocks will handle the generation key cache.
|
|
if (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > input_length)
|
|
{
|
|
const int generation_start_ti = k_idx.x
|
|
+ ((MULTI_BLOCK_FLAG ? input_length % timesteps_per_block : input_length) / K_PER_WARP) * K_PER_WARP;
|
|
for (int ti = generation_start_ti; ti < generation_ti_end; ti += K_PER_ITER)
|
|
{
|
|
const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
|
|
|
// The keys loaded from the key cache.
|
|
K_vec_k k_vec[K_VECS_PER_THREAD];
|
|
|
|
#pragma unroll
|
|
for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i)
|
|
{
|
|
const int jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
|
|
const int valid_time_now = min(time_now, tlength - 1);
|
|
int beam_offset = beam_indices[valid_time_now];
|
|
const int seqIdx = bi / beam_width * beam_width + beam_offset;
|
|
// Base pointer to k cache block for beam's batch, before offsetting with indirection buffer
|
|
Tcache* k_cache_batch
|
|
= reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(seqIdx, valid_time_now));
|
|
|
|
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
|
|
if constexpr (ENABLE_8BITS_CACHE)
|
|
{
|
|
load_8bits_kv_cache_vec(&k_vec[k_vec_i], k_cache_batch, inBlockIdx, kv_scale_quant_orig);
|
|
}
|
|
else
|
|
{
|
|
k_vec[k_vec_i] = (*reinterpret_cast<const K_vec_k*>(&k_cache_batch[inBlockIdx]));
|
|
}
|
|
}
|
|
|
|
// Perform the dot product and normalize qk.
|
|
//
|
|
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
|
|
float qk_{Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh};
|
|
|
|
// Store the product to shared memory. There's one qk value per timestep. Update the max.
|
|
if (time_now >= input_length && time_now < tlength && tidx % THREADS_PER_KEY == 0)
|
|
{
|
|
if (params.relative_attention_bias != nullptr)
|
|
{
|
|
qk_ = add(qk_,
|
|
params.relative_attention_bias[hi * params.relative_attention_bias_stride
|
|
* params.relative_attention_bias_stride
|
|
+ tlength * params.relative_attention_bias_stride + time_now]);
|
|
}
|
|
if (params.linear_bias_slopes != nullptr)
|
|
{
|
|
// Apply the linear position bias: (ki - qi) * slope[hi].
|
|
// The padding token locates between the input context and the generated tokens.
|
|
// We need to remove the number of padding tokens in the distance computation.
|
|
// ti : 0 1 2 3 4 5 6 7 8 9(tlength)
|
|
// token: i i i i p p p o o o where i=input, p=pad, o=output.
|
|
// e.g. ti = 2, dist = (9 - 3) - 2 = 4.
|
|
float dist = time_now - tlength;
|
|
|
|
qk_ += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
|
|
}
|
|
|
|
// Calculate the max for softmax, and store qk back to smem.
|
|
qk_max = fmaxf(qk_max, qk_);
|
|
qk_smem[ti] = qk_;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Softmax.
|
|
|
|
// Perform the final reduction to compute the max inside each warp.
|
|
//
|
|
// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
|
|
// group so it's not needed to run the reduction inside the group (again).
|
|
#pragma unroll
|
|
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2)
|
|
{
|
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
|
}
|
|
|
|
// Decompose the thread index into warp and lane.
|
|
const auto warp = tidx / WARP_SIZE;
|
|
const auto lane = tidx % WARP_SIZE;
|
|
|
|
// The warp leader writes the max to shared memory.
|
|
if (lane == 0)
|
|
{
|
|
red_smem[warp] = qk_max;
|
|
}
|
|
|
|
// Make sure the products are in shared memory.
|
|
__syncthreads();
|
|
|
|
// The warps finalize the reduction.
|
|
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
|
|
#pragma unroll
|
|
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2)
|
|
{
|
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
|
}
|
|
|
|
// Broadcast to all the threads in the warp.
|
|
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
|
|
|
// Compute the logits and start the sum.
|
|
float sum = 0.f;
|
|
|
|
// Each thread will handle one float (either qk_smem/logit).
|
|
const int logit_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength;
|
|
for (int ti = tidx; ti <= logit_loop_end; ti += THREADS_PER_BLOCK)
|
|
{
|
|
|
|
const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
|
|
|
// For single-block mode, we don't need the mask since it has been skipped.
|
|
if (!MULTI_BLOCK_FLAG)
|
|
{
|
|
float logit = __expf(qk_smem[time_now] - qk_max);
|
|
sum += logit;
|
|
qk_smem[time_now] = logit;
|
|
}
|
|
else
|
|
{
|
|
// Not supported yet: multi-block mode with FP8_MHA
|
|
if (time_now < tlength && ti != timesteps_per_block)
|
|
{
|
|
float logit = __expf(qk_smem[ti] - qk_max);
|
|
sum += logit;
|
|
qk_smem[ti] = logit;
|
|
}
|
|
else if (time_now == tlength)
|
|
{
|
|
float logit = __expf(qk_current_smem[0] - qk_max);
|
|
sum += logit;
|
|
qk_current_smem[0] = logit;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Compute the sum.
|
|
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
|
|
|
|
// Normalize the logits.
|
|
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
|
|
|
|
const int normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength;
|
|
for (int ti = tidx; ti <= normlization_loop_end; ti += THREADS_PER_BLOCK)
|
|
{
|
|
|
|
const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
|
|
|
if (!MULTI_BLOCK_FLAG)
|
|
{
|
|
convert_from_float(&logits_smem[ti], qk_smem[ti] * inv_sum);
|
|
}
|
|
else
|
|
{
|
|
// no scaling factor inv_sum applied here, will apply the scaling factor after all blocks finished
|
|
if (time_now < tlength && ti != timesteps_per_block)
|
|
{
|
|
convert_from_float(&logits_smem[ti], qk_smem[ti]);
|
|
}
|
|
else if (time_now == tlength)
|
|
{
|
|
convert_from_float(&logits_current_smem[0], qk_current_smem[0]);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Put Values part below so we leverage __syncthreads
|
|
// from the previous step
|
|
|
|
const auto v_idx = chunk_index<T, V_vec_k, THREADS_PER_VALUE>(tidx);
|
|
// The value computed by this thread.
|
|
const auto vo = v_idx.x;
|
|
// The hidden dimensions computed by this particular thread.
|
|
const auto vi = v_idx.y;
|
|
// Base pointer for the row of pointers to v cache blocks
|
|
void** v_cache_base_row_ptr = reinterpret_cast<void**>(kvCacheBuffer.getRowPtr(KVIdxType::V_IDX, bi));
|
|
// Base pointer for the row of pointers to v cache blocks for beam's batch, before offsetting with indirection
|
|
// buffer
|
|
void** v_cache_batch_row_ptr
|
|
= reinterpret_cast<void**>(kvCacheBuffer.getRowPtr(KVIdxType::V_IDX, bi / beam_width * beam_width));
|
|
|
|
// The number of values processed per iteration of the loop.
|
|
constexpr unsigned V_PER_ITER{THREADS_PER_BLOCK / THREADS_PER_VALUE};
|
|
// The number of unrolled keys per ieration.
|
|
constexpr unsigned UNROLLED_V_PER_ITER = V_PER_ITER * V_LOOP_UNROLL;
|
|
|
|
bool const is_valid_vi = IS_Dh_MAX || vi < Dh;
|
|
|
|
// One group of threads computes the product(s) for the current timestep.
|
|
V_vec_k v_bias;
|
|
zero(v_bias);
|
|
// if( vo == params.timestep % V_PER_ITER ) {
|
|
if (is_valid_vi && vo == tlength % V_PER_ITER)
|
|
{
|
|
// Trigger the loads from the V bias buffer.
|
|
if (params.v_bias != nullptr)
|
|
{
|
|
const auto v_bias_offset = tensorrt_llm::common::flat_index2(hi_kv, vi, Dh);
|
|
v_bias = *reinterpret_cast<const V_vec_k*>(¶ms.v_bias[v_bias_offset]);
|
|
}
|
|
}
|
|
|
|
// From previous, before values, step
|
|
// Also make sure the logits are in shared memory.
|
|
__syncthreads();
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Value cache loops.
|
|
|
|
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
|
|
using V_vec_acum = typename V_vec_acum_fp32_<V_vec_k>::Type;
|
|
#else
|
|
using V_vec_acum = V_vec_k;
|
|
#endif
|
|
// The partial outputs computed by each thread.
|
|
V_vec_acum out;
|
|
zero(out);
|
|
|
|
// Loop over the timesteps to compute the partial outputs.
|
|
if (is_valid_vi)
|
|
{
|
|
// Handle only context value cache with beam searching.
|
|
// Handle both context and generation value cache without beam searching.
|
|
// Explict batching of LDGs (by V_LOOP_UNROLL) as it doesn't depend on indirection tables.
|
|
// Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
|
|
const int context_length = HAS_BEAMS ? input_length : tlength;
|
|
int context_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : context_length;
|
|
int generation_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength;
|
|
        for (int ti = vo; ti < context_v_loop_end; ti += UNROLLED_V_PER_ITER)
        {
            V_vec_m v_vec_cache[V_LOOP_UNROLL];
#pragma unroll
            for (int v_loop = 0; v_loop < V_LOOP_UNROLL; v_loop++)
            {
                // Fetch offset based on cache_indir when beam sampling.
                int time_idx = ti + v_loop * V_PER_ITER + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
                time_idx = min(time_idx, tlength - 1);
                int rowIdx = bi / beam_width * beam_width;

                const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
                // The base pointer for the value in the cache buffer.
                Tcache* v_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx));

                v_vec_cache[v_loop] = *reinterpret_cast<const V_vec_m*>(&v_cache_batch[inBlockIdx]);
            }

#pragma unroll
            for (int v_loop = 0; v_loop < V_LOOP_UNROLL; v_loop++)
            {
                V_vec_k v_vec;
                // We move the dequantization here for better batching of the in-flight LDGs.
                if constexpr (ENABLE_8BITS_CACHE)
                {
                    convert_from_8bit_kv_cache<V_vec_m, V_vec_k, Tcache, T_scale>(
                        &v_vec, v_vec_cache[v_loop], kv_scale_quant_orig);
                }
                else
                {
                    // V_vec_k is the same type as V_vec_m in this case.
                    v_vec = *reinterpret_cast<V_vec_k*>(&v_vec_cache[v_loop]);
                }

                int local_time_idx = ti + v_loop * V_PER_ITER;
                int time_idx = local_time_idx + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
                const bool is_mask
                    = (MULTI_BLOCK_FLAG && local_time_idx >= timesteps_per_block) || (time_idx >= context_length);
                // Load the logits from shared memory.
                if (!is_mask)
                {
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
                    float logit = logits_smem[local_time_idx];
                    out = fma(logit, cast_to_float(v_vec), out);
#else // MMHA_USE_FP32_ACUM_FOR_LOGITS
                    Tk logit = logits_smem[local_time_idx];
                    out = fma(logit, v_vec, out);
#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS
                }
            }
        }

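        // Explanatory note: for generated tokens each beam may have followed a different path, so the cache row
        // to read is resolved through beam_indices (filled from the cache_indir table): the batch base row
        // bi / beam_width * beam_width plus the beam index recorded for that timestep.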
        // Handle the generation value cache with beam searching.
        if (HAS_BEAMS)
        {
            const auto generation_start_ti = MULTI_BLOCK_FLAG ? vo : (vo + (input_length / V_PER_ITER) * V_PER_ITER);
            // Only the last few blocks need to handle the generation value cache.
            if (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > input_length)
            {
                for (int ti = generation_start_ti; ti < generation_v_loop_end; ti += V_PER_ITER)
                {
                    // Fetch offset based on cache_indir when beam sampling.
                    int time_idx = ti + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
                    int local_time_idx = ti;
                    if (time_idx < input_length || (MULTI_BLOCK_FLAG && time_idx >= tlength))
                    {
                        continue;
                    }
                    int rowIdx = bi / beam_width * beam_width + beam_indices[time_idx];

                    V_vec_k v;
                    const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
                    // The base pointer for the value in the cache buffer.
                    Tcache* v_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx));

                    if (ENABLE_8BITS_CACHE)
                    {
                        load_8bits_kv_cache_vec(&v, v_cache_batch, inBlockIdx, kv_scale_quant_orig);
                    }
                    else
                    {
                        v = *reinterpret_cast<const V_vec_k*>(&v_cache_batch[inBlockIdx]);
                    }

                    // Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
                    float logit = logits_smem[local_time_idx];
                    out = fma(logit, cast_to_float(v), out);
#else // MMHA_USE_FP32_ACUM_FOR_LOGITS
                    Tk logit = logits_smem[local_time_idx];
                    out = fma(logit, v, out);
#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS
                }
            }
        }
    }

    // One group of threads computes the product(s) for the current timestep.
    if (vo == tlength % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == gridDim.z - 1)))
    {
        const int tokenIdx = tlength;
        const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tokenIdx, hi_kv, Dh, vi);
        // The base pointer for the value in the cache buffer.
        Tcache* v_cache_base = reinterpret_cast<Tcache*>(kvCacheBuffer.getBlockPtr(v_cache_base_row_ptr, tokenIdx));

        V_vec_k v;
        // Trigger the loads from the V buffer.
        // The stride between tokens. We may be able to always use params.stride.
        uint32_t v_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads_kv * Dh);
        // The offset.
        const auto v_offset = tensorrt_llm::common::flat_index_strided3(bi, hi_kv, vi, v_stride, Dh);
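        // Illustrative note on the int8 path below: when the fused QKV input is int8-quantized, the value is
        // recovered as v = qkv_scale_quant_orig[2] * float(v_quant) and converted back to T before the bias is
        // added; otherwise V is read directly as T.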
        if (load_qkv_quant)
        {
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
            using Packed_Float_t = typename packed_type<float, num_elems<V_vec_k>::value>::type;
            const auto v_scaling = params.qkv_scale_quant_orig[2];
            const auto v_quant
                = *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);

            convert_from_float(&v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
        }
        else
        {
            v = *reinterpret_cast<const V_vec_k*>(&params.v[v_offset]);
        }

        // Compute the V values with bias.
        v = add(v, v_bias);

        if (do_ia3)
        {
            v = mul<V_vec_k, V_vec_k, V_vec_k>(v,
                *reinterpret_cast<const V_vec_k*>(
                    &params.ia3_value_weights[tensorrt_llm::common::flat_index2(ia3_ti_hi, vi, Dh)]));
        }

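        // Illustrative example (example head counts, not asserted by the code): with 32 Q heads and 8 KV heads,
        // qhead_per_kv == 4, so only Q heads 0, 4, 8, ... (the first head of each group) perform the write
        // below; the other Q heads of the group share the same KV cache entry.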
        // Store the values with bias back to global memory in the cache for V.
        //*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
        // For MQA/GQA mode, write only with the first Q head of each group per KV head.
        if (hi == (hi_kv * qhead_per_kv))
        {
            if (ENABLE_8BITS_CACHE)
            {
                store_8bits_kv_cache_vec(v_cache_base, v, inBlockIdx, kv_scale_orig_quant);
            }
            else
            {
                *reinterpret_cast<V_vec_k*>(&v_cache_base[inBlockIdx]) = v;
            }
        }

        // Initialize the output value with the current timestep.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
        // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
        if (!MULTI_BLOCK_FLAG)
        {
            out = fma(logits_smem[tlength], cast_to_float(v), out);
        }
        else
        {
            out = fma(logits_current_smem[0], cast_to_float(v), out);
        }
#else // MMHA_USE_FP32_ACUM_FOR_LOGITS
        // out = fma(logits_smem[params.timestep], v, out);
        if (!MULTI_BLOCK_FLAG)
        {
            out = fma(logits_smem[tlength], v, out);
        }
        else
        { // MULTI_BLOCK_FLAG // Not supported yet: multi-block mode with FP8_MHA.
            out = fma(logits_current_smem[0], v, out);
        }
#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS
    }

    // Make sure we can start writing to shared memory.
    __syncthreads();

    // Run the final reduction amongst the different groups computing different partial outputs.
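    // Tree-reduction sketch over the V_PER_ITER timestep groups (illustrative with V_PER_ITER == 16):
    //   round 1: groups 8..15 store their partial vectors to out_smem, groups 0..7 add them in;
    //   round 2: groups 4..7 store, groups 0..3 add; ... until group 0 holds the full sum.
    // Each half-step is separated by __syncthreads() because out_smem is reused across rounds.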
#pragma unroll
    for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2)
    {
        // The midpoint in the number of active groups.
        int midpoint = active_groups / 2;

        // The upper part of active threads store to shared memory.
        if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh))
        {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
            convert_from_float(reinterpret_cast<V_vec_k*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
#else
            *reinterpret_cast<V_vec_k*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
        }
        __syncthreads();

        // The bottom warps update their values.
        if (vo < midpoint && (Dh == Dh_MAX || vi < Dh))
        {
            out = add(*reinterpret_cast<const V_vec_k*>(&out_smem[vo * Dh + vi]), out);
        }
        __syncthreads();
    }

    const auto bhi = tensorrt_llm::common::flat_index2(bi, hi, num_heads);
    const auto bhi_seq_len_tile = bhi * params.max_seq_len_tile;
    // Output the final values.
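    // Multi-block bookkeeping used below (derived from the indexing in this block): block c_tile writes its
    // un-normalized partial output at offset c_tile * batch_size * num_heads * hidden_size_per_head + bhi * Dh + vi
    // in partial_out, and its block-local max / sum at offset bhi * max_seq_len_tile + c_tile in
    // partial_max / partial_sum.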
    if (vo == 0 && (Dh == Dh_MAX || vi < Dh))
    {
        const auto bhvi = tensorrt_llm::common::flat_index2(bhi, vi, Dh);
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
        if (write_attention_quant)
        {
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
            out = mul<V_vec_acum, float>(*params.attention_out_scale_orig_quant, out);
            *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhvi])) = cast_to_int8(out);
        }
        else
        {
            if (!MULTI_BLOCK_FLAG)
            {
                // This makes sure we have coalesced memory access.
                V_vec_k final_out;
                convert_from_float(&final_out, out);
                *reinterpret_cast<V_vec_k*>(&params.out[bhvi]) = final_out;
            }
            else
            {
                // Offset for writing the partial output to partial_out.
                int partial_out_offset = c_tile * params.batch_size * num_heads * params.hidden_size_per_head;
                // Offset for writing the partial statistics to partial_max and partial_sum.
                int partial_stats_offset = bhi_seq_len_tile + c_tile;

                // This makes sure we have coalesced memory access.
                V_vec_k partial_out;
                convert_from_float(&partial_out, out);
                *reinterpret_cast<V_vec_k*>(&params.partial_out[partial_out_offset + bhvi]) = partial_out;
                convert_from_float(reinterpret_cast<float*>(&params.partial_max[partial_stats_offset]), qk_max);
                convert_from_float(reinterpret_cast<float*>(&params.partial_sum[partial_stats_offset]), sum);
            }
        }
#else // MMHA_USE_FP32_ACUM_FOR_OUT
        *reinterpret_cast<V_vec_acum*>(&params.out[bhvi]) = out;
#endif // MMHA_USE_FP32_ACUM_FOR_OUT
    }

#ifdef ENABLE_MULTI_BLOCK_OPTION
    if (MULTI_BLOCK_FLAG)
    {
        cuda::atomic_ref<int, cuda::thread_scope_device> count_ref{params.block_counter[bhi]};
        bool last_block{false};
        if (tidx == 0)
        {
            if (count_ref.fetch_add(1, cuda::memory_order_acq_rel) == (gridDim.z - 1))
            {
                last_block = true;
            }
        }

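        // Explanatory note: block_counter[bhi] counts how many of the gridDim.z blocks working on this
        // (batch, head) pair have finished writing their partial results. The block whose fetch_add returns
        // gridDim.z - 1 is the last one; __syncthreads_or below broadcasts that flag to every thread of the
        // threadblock so the whole block enters the final reduction together.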
        ////////////////////
        ////////////////////
        // Make sure every threadblock has finished its partial computation. Only the last threadblock for each
        // (B, H) pair enters the region below and performs the final reduction by combining all the partial
        // max/sum values and partial outputs.
        ////////////////////
        ////////////////////
        if (__syncthreads_or(last_block))
        {

            ////////////////////
            // Find the global max from all partial max values -> use CUB BlockReduce.
            ////////////////////

            float final_max = -FLT_MAX;
            float thread_partial_max = -FLT_MAX;
            if (tidx < gridDim.z)
                thread_partial_max = params.partial_max[bhi_seq_len_tile + tidx];
            // final_max = fmaxf(final_max, thread_partial_max);

            // Make sure we can start writing to shared memory.
            __syncthreads();

            // Specialize BlockReduce for a 1D block of THREADS_PER_BLOCK threads of type float.
            typedef cub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
            // Allocate shared memory for BlockReduce.
            __shared__ typename BlockReduce::TempStorage temp_storage;
            // Compute the block-wide max over the first gridDim.z items; the result is only valid in thread 0.
            final_max = BlockReduce(temp_storage).Reduce(thread_partial_max, cub::Max(), gridDim.z);

            __shared__ float final_max_smem;
            if (tidx == 0)
            {
                final_max_smem = final_max;
            }
            __syncthreads();

            // Finish the final_max computation by broadcasting thread 0's result to all threads.
            final_max = final_max_smem;

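            // Global softmax sum sketch (follows from the code below): final_sum = sum over tiles t of
            //   __expf(partial_max[t] - final_max) * partial_sum[t],
            // i.e. each block's local sum is rescaled from its local max to the global max before being
            // accumulated.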
            ////////////////////
            // Reduction for the global sum over all partial sums (scaled by the exponential term from the
            // global max) -> use gridDim.z threads.
            ////////////////////

            float final_sum = 0.f;
            if (tidx < gridDim.z)
            {
                thread_partial_max = params.partial_max[bhi_seq_len_tile + tidx];
                const auto thread_partial_sum = params.partial_sum[bhi_seq_len_tile + tidx];
                final_sum += __expf(thread_partial_max - final_max) * thread_partial_sum;
            }

            // Compute the final_sum.
            final_sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], final_sum);

            ////////////////////
            // Reduction for the final output (scaled by the exponential term from the global max) -> use
            // THREADS_PER_VALUE * gridDim.z threads.
            ////////////////////

            // Shared memory to store partial outputs for each oi. -> size: gridDim.z * Dh * 4 bytes. Reuse qk_smem.
            T* out_oi_smem = reinterpret_cast<T*>(smem_);

            // Number of threads to utilize: THREADS_PER_VALUE * gridDim.z (THREADS_PER_VALUE for the vectorized
            // output and gridDim.z for all the partial outputs). It must not exceed THREADS_PER_BLOCK.
            int threads_boundary = THREADS_PER_VALUE * gridDim.z;
            assert(threads_boundary <= THREADS_PER_BLOCK);

            const auto o_idx = chunk_index<T, V_vec_k, THREADS_PER_VALUE>(tidx);
            // The partial output region this thread takes care of.
            const auto oo = o_idx.x;
            // The hidden dimensions computed by this particular thread. (refer to vi)
            const auto oi = o_idx.y;

            // Load the partial output.
            int thread_partial_out_offset = oo * params.batch_size * num_heads * params.hidden_size_per_head;
            // Load the partial max (different from thread_partial_max since the thread index mapping changes here).
            float thread_partial_max_for_out = params.partial_max[bhi_seq_len_tile + oo];

            // Load the partial outputs.
            V_vec_k thread_partial_out
                = *reinterpret_cast<const V_vec_k*>(&params.partial_out[thread_partial_out_offset + bhi * Dh + oi]);

            if (tidx >= threads_boundary)
            {
                zero(thread_partial_out);
            }

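            // Rescale this tile's partial output from its local max to the global max before the cross-tile
            // reduction below: factor = __expf(partial_max[oo] - final_max).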
            Tk factor_compute;
            convert_from_float(&factor_compute, __expf(thread_partial_max_for_out - final_max));

            thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(factor_compute, thread_partial_out);

            // Make sure we can start writing to shared memory.
            __syncthreads();

            // The reduction must start from a power of two, so round gridDim.z up with bit_ceil.
            const auto reduction_iteration = static_cast<int>(cuda::std::bit_ceil(gridDim.z));

            // Run the final reduction amongst the different groups computing different partial outputs.
#pragma unroll
            for (int active_groups = reduction_iteration; active_groups >= 2; active_groups /= 2)
            {
                // The midpoint in the number of active groups.
                int midpoint = active_groups / 2;

                // The upper part of active threads store to shared memory.
                if (oo >= midpoint && oo < active_groups && (Dh == Dh_MAX || oi < Dh))
                {
                    *reinterpret_cast<V_vec_k*>(&out_oi_smem[(oo - midpoint) * Dh + oi]) = thread_partial_out;
                }
                __syncthreads();

                // The bottom warps update their values.
                if (oo < midpoint && (Dh == Dh_MAX || oi < Dh))
                {
                    thread_partial_out
                        = add(thread_partial_out, *reinterpret_cast<const V_vec_k*>(&out_oi_smem[oo * Dh + oi]));
                }
                __syncthreads();
            }

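            // Final result sketch (matches the code below): out = (sum over tiles t of
            // __expf(m_t - M) * partial_out_t) * inv_sum, where M is the global max and
            // inv_sum = 1 / (final_sum + 1e-6).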
            ////////////////////
            // Final output O * inv_sum
            ////////////////////

            if (oo == 0 && (Dh == Dh_MAX || oi < Dh))
            {
                const auto inv_sum = __fdividef(1.f, final_sum + 1.e-6f);

                Tk inv_sum_compute;
                convert_from_float(&inv_sum_compute, inv_sum);

                thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_partial_out);

                *reinterpret_cast<V_vec_k*>(&params.out[bhi * Dh + oi]) = thread_partial_out;
            }

            // Reset block_counter for the next timestep.
            if (tidx == 0)
            {
                params.block_counter[bhi] = 0;
            }
        }
    }
#endif // ENABLE_MULTI_BLOCK_OPTION
}

} // namespace mmha

} // namespace kernels
} // namespace tensorrt_llm