/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

// Multi-block mmha kernel can only be selected when CUDA >= 11.7
#if (CUDART_VERSION >= 11070)
#define ENABLE_MULTI_BLOCK_OPTION
#endif

#ifdef ENABLE_MULTI_BLOCK_OPTION
#include <cub/block/block_reduce.cuh>
#include <cuda/atomic>
#include <cuda/std/bit>
#endif // ENABLE_MULTI_BLOCK_OPTION

namespace tensorrt_llm
{
namespace kernels
{

// Use HMMA to compute with FP16/BF16 inputs and FP32 accumulators.
// #define MMHA_USE_HMMA

// Pre-scale Q or P to reduce the number of instructions needed for dequantizing the KV cache.
// If you notice a decrease in accuracy when the FP8 KV cache is enabled,
// consider disabling the two flags below.
#ifdef ENABLE_FP8
// Apply the FP8 scaling to Q instead of K.
#define MMHA_FP8_SCALE_Q_INSTEAD_OF_K
// Apply the FP8 scaling to P instead of V.
#define MMHA_FP8_SCALE_P_INSTEAD_OF_V
#endif // ENABLE_FP8
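
// Note (illustrative): pre-scaling works because the dequantization scale can be folded into
// either operand of the dot product, e.g.
//   sum_i (k_scale * q_i) * k_quant_i == sum_i q_i * (k_scale * k_quant_i),
// so scaling the Dh elements of Q once is cheaper than scaling every K element read from the
// cache. The same reasoning applies to scaling P (the logits) instead of every V element.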

// Below are knobs to extend FP32 accumulation for higher FP16 accuracy.

// Does not seem to affect the accuracy that much.
#define MMHA_USE_FP32_ACCUM_FOR_FMA

// Seems to slightly improve the accuracy.
#define MMHA_USE_FP32_ACCUM_FOR_OUT

#if 0 && defined(MMHA_USE_FP32_ACCUM_FOR_OUT)
// Does not seem to improve the accuracy.
//#define MMHA_USE_FP32_ACCUM_FOR_LOGITS
#endif

namespace mmha
{

////////////////////////////////////////////////////////////////////////////////////////////////////

//
// We use the following terminology to describe the different dimensions.
//
// B:  Batch size (number of sequences),
// L:  Sequence length,
// D:  Hidden dimension,
// H:  Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
//
// The different kernels assign a threadblock to each B x H pair. The grid has size (1, B, H). We use
// 256 threads per block to maximize occupancy and performance.
//
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
// cache buffer helps with memory accesses and contains keys with bias.
//
// The layout of the cache buffer for the keys/values is [B, H, L, Dh]
// where the fastest moving dimension (contiguous data) is the rightmost one.
// Contiguous threads will read one hidden_dimension per LDG unless we need more than 32 threads.
//
// The different kernels use 1 ~ 32 threads per key (THREADS_PER_KEY). The size of the LDGs
// is always 16 bytes (8 bytes for the 8-bit cache). Each thread sums Dh / THREADS_PER_KEY elements. At
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
// HMMA instruction (Tensor Core). Each Q * K^T value is stored in shared memory in FP32.
//
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
// shared memory.
//
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
// timesteps are computed per loop iteration. As with the keys, the values are read from a cache
// except for the current timestep. The layout of the cache buffer for the values is the same as for
// the keys, i.e. [B, H, L, Dh].
//
// Note that we have remapped the key layout to make sure it shares the same pattern as the value
// layout [B, H, L, Dh]. It helps coalesce memory accesses and reduce register pressure.

////////////////////////////////////////////////////////////////////////////////////////////////////
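
// A worked example (illustrative): with FP16 inputs (sizeof(T) == 2), Dh = 128 and 256 threads per
// block, the traits below give THREADS_PER_KEY = THREADS_PER_VALUE = 128 * 2 / 16 = 16, so a
// 16-byte LDG covers K_VEC_SIZE = 8 half elements, each thread accumulates Dh / THREADS_PER_KEY = 8
// products per key, and the value loop processes THREADS_PER_BLOCK / THREADS_PER_VALUE = 16
// timesteps per iteration.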
|
|
|
|
template <typename T, int Dh_MAX>
|
|
struct Qk_vec_m_
|
|
{
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<float, 32>
|
|
{
|
|
using Type = float;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<float, 64>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<float, 128>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<float, 256>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<uint16_t, 32>
|
|
{
|
|
using Type = uint32_t;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<uint16_t, 64>
|
|
{
|
|
using Type = uint32_t;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<uint16_t, 128>
|
|
{
|
|
using Type = uint2;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<uint16_t, 256>
|
|
{
|
|
using Type = uint4;
|
|
};
|
|
#ifdef ENABLE_BF16
|
|
template <>
|
|
struct Qk_vec_m_<__nv_bfloat16, 32>
|
|
{
|
|
using Type = __nv_bfloat162;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<__nv_bfloat16, 64>
|
|
{
|
|
using Type = __nv_bfloat162;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<__nv_bfloat16, 128>
|
|
{
|
|
using Type = bf16_4_t;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<__nv_bfloat16, 256>
|
|
{
|
|
using Type = bf16_8_t;
|
|
};
|
|
#endif // ENABLE_BF16
|
|
|
|
#ifdef ENABLE_FP8
|
|
template <>
|
|
struct Qk_vec_m_<__nv_fp8_e4m3, 32>
|
|
{
|
|
using Type = fp8_4_t;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<__nv_fp8_e4m3, 64>
|
|
{
|
|
using Type = fp8_4_t;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<__nv_fp8_e4m3, 128>
|
|
{
|
|
using Type = fp8_4_t;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_m_<__nv_fp8_e4m3, 256>
|
|
{
|
|
using Type = fp8_4_t;
|
|
};
|
|
#endif // ENABLE_FP8
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T, int Dh>
|
|
struct Qk_vec_k_
|
|
{
|
|
using Type = typename Qk_vec_m_<T, Dh>::Type;
|
|
};
|
|
#ifdef ENABLE_FP8
|
|
template <>
|
|
struct Qk_vec_k_<__nv_fp8_e4m3, 32>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_k_<__nv_fp8_e4m3, 64>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_k_<__nv_fp8_e4m3, 128>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_k_<__nv_fp8_e4m3, 256>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
#endif // ENABLE_FP8
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T, int V_VEC_SIZE>
|
|
struct V_vec_m_
|
|
{
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_m_<float, 1>
|
|
{
|
|
using Type = float;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_m_<float, 2>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_m_<float, 4>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_m_<float, 8>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_m_<uint16_t, 2>
|
|
{
|
|
using Type = uint32_t;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_m_<uint16_t, 4>
|
|
{
|
|
using Type = uint2;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_m_<uint16_t, 8>
|
|
{
|
|
using Type = uint4;
|
|
};
|
|
#ifdef ENABLE_BF16
|
|
template <>
|
|
struct V_vec_m_<__nv_bfloat16, 2>
|
|
{
|
|
using Type = __nv_bfloat162;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_m_<__nv_bfloat16, 4>
|
|
{
|
|
using Type = bf16_4_t;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_m_<__nv_bfloat16, 8>
|
|
{
|
|
using Type = bf16_8_t;
|
|
};
|
|
#endif // ENABLE_BF16
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T, int V_VEC_SIZE>
|
|
struct V_vec_k_
|
|
{
|
|
using Type = typename V_vec_m_<T, V_VEC_SIZE>::Type;
|
|
};
|
|
#ifdef ENABLE_FP8
|
|
template <>
|
|
struct V_vec_k_<__nv_fp8_e4m3, 4>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_k_<__nv_fp8_e4m3, 8>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_k_<__nv_fp8_e4m3, 16>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
#endif
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Reuse V_vec traits as key and value share the same layout.
|
|
template <typename T, int K_VEC_SIZE>
|
|
struct K_vec_m_
|
|
{
|
|
using Type = typename V_vec_m_<T, K_VEC_SIZE>::Type;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T, int K_VEC_SIZE>
|
|
struct K_vec_k_
|
|
{
|
|
using Type = typename K_vec_m_<T, K_VEC_SIZE>::Type;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#ifdef MMHA_USE_FP32_ACCUM_FOR_FMA
|
|
template <typename T>
|
|
struct Qk_vec_accum_fp32_
|
|
{
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<float>
|
|
{
|
|
using Type = float;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<float2>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<float4>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
// template<> struct Qk_vec_accum_fp32_<uint16_t> { using Type = float; };
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<uint32_t>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<uint2>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<uint4>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<__nv_bfloat16>
|
|
{
|
|
using Type = float;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<__nv_bfloat162>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<bf16_4_t>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<bf16_8_t>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
|
|
#ifdef ENABLE_FP8
|
|
// template<>
|
|
// struct Qk_vec_accum_fp32_<fp8_2_t> {
|
|
// using Type = float2;
|
|
// };
|
|
template <>
|
|
struct Qk_vec_accum_fp32_<fp8_4_t>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
// template<>
|
|
// struct Qk_vec_accum_fp32_<fp8_8_t> {
|
|
// using Type = Float4_;
|
|
// };
|
|
#endif // ENABLE_FP8
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T>
|
|
struct K_vec_accum_fp32_
|
|
{
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<float>
|
|
{
|
|
using Type = float;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<float2>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<float4>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<Float8_>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<uint32_t>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<uint2>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<uint4>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<__nv_bfloat16>
|
|
{
|
|
using Type = float;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<__nv_bfloat162>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<bf16_4_t>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<bf16_8_t>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
#ifdef ENABLE_FP8
|
|
template <>
|
|
struct K_vec_accum_fp32_<__nv_fp8_e4m3>
|
|
{
|
|
using Type = float;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<fp8_2_t>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<fp8_4_t>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<fp8_8_t>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
#endif // ENABLE_FP8
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<int8_t>
|
|
{
|
|
using Type = float;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<int16_t>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<int32_t>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
template <>
|
|
struct K_vec_accum_fp32_<int64_t>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
|
|
#endif // MMHA_USE_FP32_ACCUM_FOR_FMA
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#ifdef MMHA_USE_FP32_ACCUM_FOR_OUT
|
|
template <typename T>
|
|
struct V_vec_accum_fp32_
|
|
{
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_accum_fp32_<float>
|
|
{
|
|
using Type = float;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_accum_fp32_<float2>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_accum_fp32_<float4>
|
|
{
|
|
using Type = float4;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_accum_fp32_<uint32_t>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_accum_fp32_<uint2>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_accum_fp32_<uint4>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
#ifdef ENABLE_BF16
|
|
template <>
|
|
struct V_vec_accum_fp32_<__nv_bfloat162>
|
|
{
|
|
using Type = float2;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_accum_fp32_<bf16_4_t>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
template <>
|
|
struct V_vec_accum_fp32_<bf16_8_t>
|
|
{
|
|
using Type = Float8_;
|
|
};
|
|
#endif // ENABLE_BF16
|
|
#ifdef ENABLE_FP8
|
|
// template<>
|
|
// struct V_vec_accum_fp32_<fp8_2_t> {
|
|
// using Type = float2;
|
|
// };
|
|
template <>
|
|
struct V_vec_accum_fp32_<fp8_4_t>
|
|
{
|
|
using Type = Float4_;
|
|
};
|
|
|
|
// template<>
|
|
// struct V_vec_accum_fp32_<fp8_8_t> {
|
|
// using Type = Float4_;
|
|
// };
|
|
#endif // ENABLE_FP8
|
|
#endif
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename Tout, typename Tin>
|
|
__inline__ __device__ constexpr Tout vec_conversion(Tin const& x)
|
|
{
|
|
static_assert(std::is_same<Tout, Tin>::value, "Type mismatch");
|
|
return x;
|
|
}
|
|
|
|
template <>
|
|
__inline__ __device__ Float8_ vec_conversion<Float8_, uint4>(uint4 const& a)
|
|
{
|
|
Float8_ fc;
|
|
fc.x = half2_to_float2(a.x);
|
|
fc.y = half2_to_float2(a.y);
|
|
fc.z = half2_to_float2(a.z);
|
|
fc.w = half2_to_float2(a.w);
|
|
return fc;
|
|
}
|
|
|
|
#ifdef ENABLE_BF16
|
|
template <>
|
|
__inline__ __device__ Float8_ vec_conversion<Float8_, bf16_8_t>(bf16_8_t const& a)
|
|
{
|
|
Float8_ fc;
|
|
fc.x = bf1622float2(a.x);
|
|
fc.y = bf1622float2(a.y);
|
|
fc.z = bf1622float2(a.z);
|
|
fc.w = bf1622float2(a.w);
|
|
return fc;
|
|
}
|
|
#endif // ENABLE_BF16
|
|
|
|
#ifdef ENABLE_FP8
|
|
// fp8_t
|
|
template <>
|
|
__inline__ __device__ float vec_conversion<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 const& a)
|
|
{
|
|
return float(a);
|
|
}
|
|
|
|
template <>
|
|
__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(float const& a)
|
|
{
|
|
return __nv_fp8_e4m3(a);
|
|
}
|
|
|
|
// fp8_2_t
|
|
template <>
|
|
__inline__ __device__ float2 vec_conversion<float2, fp8_2_t>(fp8_2_t const& a)
|
|
{
|
|
return float2(a);
|
|
}
|
|
|
|
template <>
|
|
__inline__ __device__ fp8_2_t vec_conversion<fp8_2_t, float2>(float2 const& a)
|
|
{
|
|
return fp8_2_t(a);
|
|
}
|
|
|
|
// fp8_4_t
|
|
template <>
|
|
__inline__ __device__ float4 vec_conversion<float4, fp8_4_t>(fp8_4_t const& a)
|
|
{
|
|
return float4(a);
|
|
}
|
|
|
|
template <>
|
|
__inline__ __device__ fp8_4_t vec_conversion<fp8_4_t, float4>(float4 const& a)
|
|
{
|
|
return fp8_4_t(a);
|
|
}
|
|
#endif // ENABLE_FP8
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int THREADS_PER_KEY, typename Q_vec, typename K_vec, int N>
|
|
inline __device__ float qk_dot_(const Q_vec (&q)[N], const K_vec (&k)[N])
|
|
{
|
|
#ifdef MMHA_USE_FP32_ACCUM_FOR_FMA
|
|
using K_vec_accum = typename K_vec_accum_fp32_<K_vec>::Type;
|
|
#else
|
|
using K_vec_accum = K_vec;
|
|
#endif
|
|
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
|
K_vec_accum qk_vec = mul<K_vec_accum, Q_vec, K_vec>(q[0], k[0]);
|
|
#pragma unroll
|
|
for (int ii = 1; ii < N; ++ii)
|
|
{
|
|
qk_vec = fma(q[ii], k[ii], qk_vec);
|
|
}
|
|
|
|
// Finalize the reduction across lanes.
|
|
float qk = sum(qk_vec);
|
|
#pragma unroll
|
|
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2)
|
|
{
|
|
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
|
}
|
|
return qk;
|
|
}
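
// Usage sketch (illustrative): with THREADS_PER_KEY = 4, the shuffle loop above xors with masks 2
// and 1, summing the partial dot products of the 4 consecutive lanes that cooperate on one key, so
// every lane of the group ends up holding the full Q*K^T value for that key.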
|
|
|
|
template <int THREADS_PER_KEY, typename Q_vec, typename K_vec, int N>
|
|
inline __device__ float qk_scale_dot_(const Q_vec (&q)[N], const K_vec (&k)[N], float const k_scale)
|
|
{
|
|
#ifdef MMHA_USE_FP32_ACCUM_FOR_FMA
|
|
using K_vec_accum = typename K_vec_accum_fp32_<K_vec>::Type;
|
|
#else
|
|
using K_vec_accum = K_vec;
|
|
#endif
|
|
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
|
K_vec_accum k_vec = mul<K_vec_accum, float, K_vec>(k_scale, k[0]);
|
|
K_vec_accum qk_vec = mul<K_vec_accum, Q_vec, K_vec_accum>(q[0], k_vec);
|
|
#pragma unroll
|
|
for (int ii = 1; ii < N; ++ii)
|
|
{
|
|
K_vec_accum k_vec = mul<K_vec_accum, float, K_vec>(k_scale, k[ii]);
|
|
qk_vec = fma(q[ii], k_vec, qk_vec);
|
|
}
|
|
|
|
// Finalize the reduction across lanes.
|
|
float qk = sum(qk_vec);
|
|
#pragma unroll
|
|
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2)
|
|
{
|
|
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
|
}
|
|
return qk;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T, int THREADS_PER_KEY>
|
|
struct Qk_dot
|
|
{
|
|
template <typename Q_vec, typename K_vec, int N>
|
|
static inline __device__ float dot(const Q_vec (&q)[N], const K_vec (&k)[N])
|
|
{
|
|
return qk_dot_<THREADS_PER_KEY>(q, k);
|
|
}
|
|
|
|
template <typename Q_vec, typename K_vec, int N>
|
|
static inline __device__ float scale_dot(const Q_vec (&q)[N], const K_vec (&k)[N], float const k_scale)
|
|
{
|
|
#ifdef MMHA_USE_HMMA
// A bare string literal always converts to true, so the original assertion could never fire.
// Make the condition dependent so it fails only when this overload is instantiated.
static_assert(sizeof(Q_vec) == 0, "HMMA doesn't support k scales");
#endif // MMHA_USE_HMMA
|
|
return qk_scale_dot_<THREADS_PER_KEY>(q, k, k_scale);
|
|
}
|
|
|
|
template <int WARP_SIZE = 32>
|
|
static inline __device__ bool is_leader(int const tidx)
|
|
{
|
|
return (tidx % THREADS_PER_KEY) == 0;
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename K_vec>
|
|
inline __device__ void hmma_fp32(float4& c, K_vec const& a, K_vec b)
|
|
{
|
|
// Not supported.
|
|
assert(false);
|
|
}
|
|
|
|
template <>
|
|
inline __device__ void hmma_fp32(float4& c, uint32_t const& a, uint32_t b)
|
|
{
|
|
asm volatile(
|
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
|
|
" {%0, %1, %2, %3}, \n"
|
|
" {%4, %5}, \n"
|
|
" {%6}, \n"
|
|
" {%0, %1, %2, %3}; \n"
|
|
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w)
|
|
: "r"(a), "r"(a), "r"(b));
|
|
}
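
// Note: for the m16n8k8 F16 HMMA shape, the A fragment occupies two registers per thread, which is
// why the packed operand `a` is passed twice above; `b` is the single-register B fragment, and the
// four FP32 accumulators in `c` are updated in place.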
|
|
|
|
template <>
|
|
inline __device__ void hmma_fp32(float4& c, uint2 const& a, uint2 b)
|
|
{
|
|
hmma_fp32(c, a.x, b.x);
|
|
hmma_fp32(c, a.y, b.y);
|
|
}
|
|
|
|
template <>
|
|
inline __device__ void hmma_fp32(float4& c, uint4 const& a, uint4 b)
|
|
{
|
|
hmma_fp32(c, a.x, b.x);
|
|
hmma_fp32(c, a.y, b.y);
|
|
hmma_fp32(c, a.z, b.z);
|
|
hmma_fp32(c, a.w, b.w);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename K_vec, int THREADS_PER_KEY, int N>
|
|
inline __device__ float qk_hmma_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
|
|
{
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
|
|
|
// Each quad computes its partial result.
|
|
float4 acc = make_float4(0.f, 0.f, 0.f, 0.f);
|
|
|
|
#pragma unroll
|
|
for (int ii = 0; ii < N; ++ii)
|
|
{
|
|
hmma_fp32(acc, q[ii], k[ii]);
|
|
}
|
|
|
|
// The position inside the warp.
|
|
int lane = threadIdx.x % 32;
|
|
|
|
// The position inside the HMMA instruction.
|
|
int row = lane / 4;
|
|
int col = lane % 4 * 2;
|
|
|
|
// The result. Only 1 thread in each quad owns a valid value.
|
|
//
|
|
// Row 0, it's lane 0 (col 0) in acc.x.
|
|
// Row 1, it's lane 4 (col 0) in acc.y.
|
|
// Row 2, it's lane 9 (col 2) in acc.x.
|
|
// Row 3, it's lane 13 (col 2) in acc.y.
|
|
// Row 4, it's lane 18 (col 4) in acc.x.
|
|
// Row 5, it's lane 22 (col 4) in acc.y.
|
|
// Row 6, it's lane 27 (col 6) in acc.x.
|
|
// Row 7, it's lane 31 (col 6) in acc.y.
|
|
//
|
|
float result = (row == col) ? acc.x : acc.y;
|
|
|
|
// Do the reduction inside the warp.
|
|
if (THREADS_PER_KEY > 4)
|
|
{
|
|
result += __shfl_xor_sync(unsigned(-1), result, 4);
|
|
}
|
|
if (THREADS_PER_KEY > 8)
|
|
{
|
|
result += __shfl_xor_sync(unsigned(-1), result, 9);
|
|
}
|
|
if (THREADS_PER_KEY > 16)
|
|
{
|
|
result += __shfl_xor_sync(unsigned(-1), result, 18);
|
|
}
|
|
|
|
// The warp leader has the correct value.
|
|
return result;
|
|
|
|
#else // !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 750
|
|
return 0.f;
|
|
#endif
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int THREADS_PER_KEY>
|
|
struct Qk_dot<uint16_t, THREADS_PER_KEY>
|
|
{
|
|
template <typename Q_vec, typename K_vec, int N>
|
|
static inline __device__ float dot(const Q_vec (&q)[N], const K_vec (&k)[N])
|
|
{
|
|
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA)
|
|
return qk_hmma_dot_<K_vec, THREADS_PER_KEY, N>(q, k);
|
|
#else
|
|
return qk_dot_<THREADS_PER_KEY>(q, k);
|
|
#endif // defined MMHA_USE_HMMA
|
|
}
|
|
|
|
template <typename Q_vec, typename K_vec, int N>
|
|
static inline __device__ float scale_dot(const Q_vec (&q)[N], const K_vec (&k)[N], float const k_scale)
|
|
{
|
|
#ifdef MMHA_USE_HMMA
// A bare string literal always converts to true, so the original assertion could never fire.
// Make the condition dependent so it fails only when this overload is instantiated.
static_assert(sizeof(Q_vec) == 0, "HMMA doesn't support k scales");
#endif // MMHA_USE_HMMA
|
|
return qk_scale_dot_<THREADS_PER_KEY>(q, k, k_scale);
|
|
}
|
|
|
|
template <int WARP_SIZE = 32>
|
|
static inline __device__ bool is_leader(int const tidx)
|
|
{
|
|
// Use HMMA.FP32, leader threads are in the diagonal roughly (0, 4, 9, 13, 18, 22, 27, 31).
|
|
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA)
|
|
int leader = 0;
|
|
// The thread position inside the warp.
|
|
int lane = tidx % WARP_SIZE;
|
|
if (THREADS_PER_KEY == 4)
|
|
{
|
|
leader = int(lane / 8);
|
|
}
|
|
else
|
|
{
|
|
leader = int(lane / THREADS_PER_KEY) * int(THREADS_PER_KEY / 8);
|
|
}
|
|
#else
|
|
bool const leader = 0;
|
|
#endif // defined MMHA_USE_HMMA
|
|
return (tidx % THREADS_PER_KEY) == leader;
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename Tk, typename V_vec_accum, typename V_vec_m, bool INT8_KV_CACHE, bool FP8_KV_CACHE>
|
|
inline __device__ void Logit_value_fma(
|
|
V_vec_accum& out, Tk const* logits_smem, V_vec_m const& v_vec, float const v_scale, bool const is_mask)
|
|
{
|
|
#if defined(MMHA_USE_FP32_ACCUM_FOR_LOGITS)
|
|
float logit = is_mask ? 0.f : reinterpret_cast<float*>(logits_smem)[0];
|
|
if constexpr (INT8_KV_CACHE)
|
|
{
|
|
V_vec_accum v_vec_ = mul<V_vec_accum, float, V_vec_m>(v_scale, v_vec);
|
|
out = fma(logit, cast_to_float(v_vec_), out);
|
|
}
|
|
else if constexpr (FP8_KV_CACHE)
|
|
{
|
|
#ifdef MMHA_FP8_SCALE_P_INSTEAD_OF_V
|
|
out = fma(logit, cast_to_float(v_vec), out);
|
|
#else
|
|
V_vec_accum v_vec_ = mul<V_vec_accum, float, V_vec_m>(v_scale, v_vec);
|
|
out = fma(logit, cast_to_float(v_vec_), out);
|
|
#endif // MMHA_FP8_SCALE_P_INSTEAD_OF_V
|
|
}
|
|
else
|
|
{
|
|
out = fma(logit, cast_to_float(v_vec), out);
|
|
}
|
|
#else // MMHA_USE_FP32_ACCUM_FOR_LOGITS
|
|
Tk logit = is_mask ? Tk(0.f) : logits_smem[0];
|
|
if constexpr (INT8_KV_CACHE)
|
|
{
|
|
V_vec_accum v_vec_ = mul<V_vec_accum, float, V_vec_m>(v_scale, v_vec);
|
|
out = fma(logit, v_vec_, out);
|
|
}
|
|
else if constexpr (FP8_KV_CACHE)
|
|
{
|
|
#ifdef MMHA_FP8_SCALE_P_INSTEAD_OF_V
|
|
out = fma(logit, v_vec, out);
|
|
#else
|
|
V_vec_accum v_vec_ = mul<V_vec_accum, float, V_vec_m>(v_scale, v_vec);
|
|
out = fma(logit, v_vec_, out);
|
|
#endif // MMHA_FP8_SCALE_P_INSTEAD_OF_V
|
|
}
|
|
else
|
|
{
|
|
out = fma(logit, v_vec, out);
|
|
}
|
|
#endif // MMHA_USE_FP32_ACCUM_FOR_LOGITS
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
|
|
inline __device__ float block_sum(float* red_smem, float sum)
|
|
{
|
|
|
|
// Decompose the thread index into warp / lane.
|
|
int warp = threadIdx.x / WARP_SIZE;
|
|
int lane = threadIdx.x % WARP_SIZE;
|
|
|
|
// Compute the sum per warp.
|
|
#pragma unroll
|
|
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)
|
|
{
|
|
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
|
}
|
|
|
|
// Warp leaders store the data to shared memory.
|
|
if (lane == 0)
|
|
{
|
|
red_smem[warp] = sum;
|
|
}
|
|
|
|
// Make sure the data is in shared memory.
|
|
__syncthreads();
|
|
|
|
// The warps compute the final sums.
|
|
if (lane < WARPS_PER_BLOCK)
|
|
{
|
|
sum = red_smem[lane];
|
|
}
|
|
|
|
// Parallel reduction inside the warp.
|
|
#pragma unroll
|
|
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2)
|
|
{
|
|
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
|
}
|
|
|
|
// Broadcast to other threads.
|
|
return __shfl_sync(uint32_t(-1), sum, 0);
|
|
}
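
// Usage sketch (illustrative): a per-thread partial sum is reduced block-wide by calling, e.g.,
//   float sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], local_sum);
// where red_smem holds 2 * WARPS_PER_BLOCK floats (one half reserved for the max reduction, the
// other half for the sum), matching the red_smem buffer declared in the kernel below.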
|
|
|
|
#if defined(MMHA_USE_FP32_ACCUM_FOR_LOGITS)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ float cast_to_float(float u)
|
|
{
|
|
return u;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ float2 cast_to_float(float2 u)
|
|
{
|
|
return u;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ float4 cast_to_float(float4 u)
|
|
{
|
|
return u;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ Float4_ cast_to_float(Float4_ u)
|
|
{
|
|
return u;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ Float8_ cast_to_float(Float8_ u)
|
|
{
|
|
return u;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ float2 cast_to_float(uint32_t u)
|
|
{
|
|
return half2_to_float2(u);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ Float4_ cast_to_float(uint2 u)
|
|
{
|
|
Float4_ tmp;
|
|
tmp.x = half2_to_float2(u.x);
|
|
tmp.y = half2_to_float2(u.y);
|
|
return tmp;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ Float8_ cast_to_float(uint4 u)
|
|
{
|
|
Float8_ tmp;
|
|
tmp.x = half2_to_float2(u.x);
|
|
tmp.y = half2_to_float2(u.y);
|
|
tmp.z = half2_to_float2(u.z);
|
|
tmp.w = half2_to_float2(u.w);
|
|
return tmp;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ float2 cast_to_float(__nv_bfloat162 u)
|
|
{
|
|
float2 tmp;
|
|
tmp = __bfloat1622float2(u);
|
|
return tmp;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ Float4_ cast_to_float(bf16_4_t u)
|
|
{
|
|
Float4_ tmp;
|
|
tmp.x = __bfloat1622float2(u.x);
|
|
tmp.y = __bfloat1622float2(u.y);
|
|
return tmp;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ Float8_ cast_to_float(bf16_8_t u)
|
|
{
|
|
Float8_ tmp;
|
|
tmp.x = __bfloat1622float2(u.x);
|
|
tmp.y = __bfloat1622float2(u.y);
|
|
tmp.z = __bfloat1622float2(u.z);
|
|
tmp.w = __bfloat1622float2(u.w);
|
|
return tmp;
|
|
}
|
|
|
|
#endif
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T>
|
|
inline __device__ __host__ T divUp(T m, T n)
|
|
{
|
|
return (m + n - 1) / n;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T>
|
|
inline __device__ __host__ T div(T m, T n)
|
|
{
|
|
return m / n;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T>
|
|
struct kernel_type_t
|
|
{
|
|
using Type = T;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Compute the largest supported head size (dh_max). It must be the smallest power-of-2 that is not strictly smaller
|
|
// than the head size (dh).
|
|
inline __device__ __host__ constexpr unsigned dh_max(unsigned dh)
|
|
{
|
|
return next_power_of_two(mmha::const_max(dh, 32u));
|
|
}
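
// For example (per the definition above): dh_max(64) == 64, dh_max(80) == 128, and head sizes
// smaller than 32 are rounded up to 32.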
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T>
|
|
inline __device__ __host__ constexpr unsigned threads_per_value(unsigned dh_max)
|
|
{
|
|
return dh_max * sizeof(T) / 16;
|
|
}
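
// For example, with T = uint16_t (FP16) and dh_max = 128 this yields 128 * 2 / 16 = 16 threads per
// value, i.e. each thread owns a 16-byte slice (V_VEC_SIZE = 8 elements) of the head dimension.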
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T, unsigned Dh_MAX>
|
|
inline __device__ __host__ constexpr unsigned threads_per_key()
|
|
{
|
|
// Since we want to perform the reduction entirely within a warp, the number of threads per key
|
|
// is capped at 32.
|
|
constexpr unsigned threads = (unsigned) (Dh_MAX * sizeof(T) / 16u);
|
|
if ((threads & (threads - 1)) != 0)
|
|
{
|
|
assert(false); // Not a power of two.
|
|
}
|
|
return std::min(32u, threads);
|
|
}
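
// For example, with T = float and Dh_MAX = 256 the raw count is 256 * 4 / 16 = 64, which is capped
// to 32 so the reduction still fits within a single warp; with T = uint16_t and Dh_MAX = 64 it is
// 64 * 2 / 16 = 8.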
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Specialized launch bounds for certain cases, which helps increase occupancy.
|
|
// Keep other cases untouched as there might be register spilling.
|
|
|
|
template <typename T, typename Tcache, unsigned THREADS_PER_BLOCK, unsigned Dh_MAX, bool DO_CROSS_ATTENTION,
|
|
bool HAS_BEAMS, bool POS_SHIFT>
|
|
struct Launch_bounds_config
|
|
{
|
|
// By default, we will not use launch bounds.
|
|
static constexpr int MAX_THREADS_PER_BLOCK = 0;
|
|
static constexpr int MIN_BLOCKS_PER_SM = 0;
|
|
};
|
|
|
|
template <>
|
|
struct Launch_bounds_config<uint16_t, __nv_fp8_e4m3, 256u, 64u, false, false, false>
|
|
{
|
|
static constexpr int MAX_THREADS_PER_BLOCK = 256u;
|
|
static constexpr int MIN_BLOCKS_PER_SM = 4u;
|
|
};
|
|
|
|
// Llama with FP8 KV Cache.
|
|
template <>
|
|
struct Launch_bounds_config<uint16_t, __nv_fp8_e4m3, 256u, 128u, false, false, false>
|
|
{
|
|
static constexpr int MAX_THREADS_PER_BLOCK = 256u;
|
|
static constexpr int MIN_BLOCKS_PER_SM = 4u;
|
|
};
|
|
|
|
// GPTJ With Beam Searching and FP8 KV Cache.
|
|
template <>
|
|
struct Launch_bounds_config<uint16_t, __nv_fp8_e4m3, 256u, 256u, false, true, false>
|
|
{
|
|
static constexpr int MAX_THREADS_PER_BLOCK = 256u;
|
|
static constexpr int MIN_BLOCKS_PER_SM = 3u;
|
|
};
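
// Illustrative effect: these specializations feed the __launch_bounds__(MAX_THEADS_PER_BLOCK,
// MIN_BLOCKS_PER_SM) annotation of the kernel below, e.g. the FP16 + FP8 KV cache, Dh_MAX = 128
// case is compiled with __launch_bounds__(256, 4), while the default (0, 0) corresponds to the
// "no launch bounds" case noted above.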
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ constexpr uint32_t shfl_mask(int threads)
|
|
{
|
|
assert(threads <= 32);
|
|
return threads == 32 ? -1u : (1u << threads) - 1u;
|
|
}
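
// For example: shfl_mask(32) == 0xffffffff (a full warp) and shfl_mask(8) == 0xff (the 8
// participating lanes).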
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T, typename T_VEC, unsigned VECS_PER_CHUNK>
|
|
__device__ inline constexpr uint2 chunk_index(unsigned tidx)
|
|
{
|
|
// The chunk associated with the thread.
|
|
auto const idx_chunk = tidx / VECS_PER_CHUNK;
|
|
|
|
// The position of the T_VEC vector in that chunk associated with the thread.
|
|
static_assert(sizeof(T_VEC) % sizeof(T) == 0);
|
|
unsigned constexpr kVecSize{sizeof(T_VEC) / sizeof(T)};
|
|
auto const idx_vec = (tidx % VECS_PER_CHUNK) * kVecSize;
|
|
|
|
return uint2{idx_chunk, idx_vec};
|
|
}
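
// A worked example (illustrative): for T = uint16_t, T_VEC = uint4 (8 elements) and
// VECS_PER_CHUNK = 16, thread tidx = 37 maps to chunk 37 / 16 = 2 and element offset
// (37 % 16) * 8 = 40 within that chunk.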
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The type of the inputs. Supported types: float, uint16_t, nv_bfloat16.
|
|
typename T,
|
|
// The type of the cache.
|
|
typename Tcache,
|
|
// The type of the shift key cache.
|
|
typename TKcache,
|
|
// Type of struct containing KV cache
|
|
typename KVCacheBuffer,
|
|
// Type of struct containing K cache to read past keys
|
|
typename KCacheBuffer,
|
|
// The hidden dimension per head.
|
|
unsigned Dh,
|
|
// The number of threads in a threadblock.
|
|
unsigned THREADS_PER_BLOCK,
|
|
// Whether cross attention is enabled
|
|
bool DO_CROSS_ATTENTION,
|
|
// Whether has beams.
|
|
bool HAS_BEAMS,
|
|
// Whether enable multi-block mode for long-sequence-length.
|
|
bool DO_MULTI_BLOCK = false,
|
|
// Whether enable position shift for streamingllm
|
|
bool POS_SHIFT = false,
|
|
// Whether compute implicit relative attention bias on the fly.
|
|
bool IMPLICIT_REL_ATTN_BIAS = false,
|
|
// The number of threads per key.
|
|
unsigned THREADS_PER_KEY = threads_per_key<T, dh_max(Dh)>(),
|
|
// The number of threads per value.
|
|
unsigned THREADS_PER_VALUE = threads_per_value<T>(dh_max(Dh)),
|
|
// The unroll factor for loading from K cache.
|
|
// Set it default to 4 for higher occupancy (by reducing registers usage).
|
|
unsigned K_LOOP_UNROLL = 4,
|
|
// The unroll factor for loading from V cache.
|
|
unsigned V_LOOP_UNROLL = 8,
|
|
// Launch bounds
|
|
unsigned MAX_THEADS_PER_BLOCK
|
|
= Launch_bounds_config<T, Tcache, THREADS_PER_BLOCK, dh_max(Dh), DO_CROSS_ATTENTION, HAS_BEAMS, POS_SHIFT>()
|
|
.MAX_THREADS_PER_BLOCK,
|
|
unsigned MIN_BLOCKS_PER_SM
|
|
= Launch_bounds_config<T, Tcache, THREADS_PER_BLOCK, dh_max(Dh), DO_CROSS_ATTENTION, HAS_BEAMS, POS_SHIFT>()
|
|
.MIN_BLOCKS_PER_SM>
|
|
__global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) masked_multihead_attention_kernel(
|
|
Multihead_attention_params<T, DO_CROSS_ATTENTION> params, KVCacheBuffer kvCacheBuffer, KCacheBuffer pastKCache)
|
|
{
|
|
|
|
using Tk = typename kernel_type_t<T>::Type;
|
|
// Use 8bit cache.
|
|
static constexpr bool ENABLE_8BITS_K_CACHE = sizeof(TKcache) == 1;
|
|
static constexpr bool ENABLE_8BITS_KV_CACHE = sizeof(Tcache) == 1;
|
|
// FP8 KV Cache.
|
|
static constexpr bool FP8_K_CACHE = std::is_same<TKcache, __nv_fp8_e4m3>::value;
|
|
static constexpr bool FP8_KV_CACHE = std::is_same<Tcache, __nv_fp8_e4m3>::value;
|
|
// INT8 KV Cache.
|
|
static constexpr bool INT8_KV_CACHE = std::is_same<Tcache, int8_t>::value;
|
|
|
|
// The size of a warp.
|
|
constexpr unsigned WARP_SIZE{32};
|
|
// The number of warps in a threadblock.
|
|
constexpr unsigned WARPS_PER_BLOCK{THREADS_PER_BLOCK / WARP_SIZE};
|
|
|
|
// The maximum hidden size per head.
|
|
constexpr auto Dh_MAX = dh_max(Dh);
|
|
constexpr bool IS_Dh_MAX = Dh == Dh_MAX;
|
|
static_assert(Dh_MAX >= WARP_SIZE);
|
|
static_assert(Dh_MAX >= Dh);
|
|
// Only instantiate a few head sizes for implicit relative attention bias in order to save compilation time.
|
|
static_assert(!IMPLICIT_REL_ATTN_BIAS || Dh == 32 || Dh == 64 || Dh == 128);
|
|
|
|
// The maximum sequence length in the cyclic kv_cache, i.e., an upper bound on L.
|
|
// Note that the maximum sequence length supported by the model might be greater than this.
|
|
// Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers.
|
|
// By default, you can assume that they are the same.
|
|
auto const cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_attention_window_size);
|
|
// The number of sink tokens in kv cache to support streamingllm
|
|
auto const sink_token_len = static_cast<unsigned>(params.sink_token_length);
|
|
// The current timestep (including paddings).
|
|
// It is only used to calculate the smem stride.
|
|
auto const timestep = static_cast<unsigned>(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep);
|
|
|
|
#ifdef ENABLE_MULTI_BLOCK_OPTION
|
|
constexpr bool MULTI_BLOCK_FLAG = DO_MULTI_BLOCK;
|
|
#else
|
|
constexpr bool MULTI_BLOCK_FLAG = false;
|
|
#endif
|
|
|
|
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
|
|
extern __shared__ char smem_[];
|
|
|
|
// The shared memory for the Q*K^T values and partial logits in softmax.
|
|
auto qk_smem = reinterpret_cast<float*>(smem_);
|
|
|
|
__shared__ float qk_current_smem[1];
|
|
|
|
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
|
|
char* logits_smem_ = smem_;
|
|
#ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS
|
|
if (sizeof(Tk) != 4)
|
|
{
|
|
auto const max_timesteps = DO_CROSS_ATTENTION ? cyclic_kv_cache_len : min(timestep, cyclic_kv_cache_len);
|
|
logits_smem_ += divUp(max_timesteps + 1, 4u) * 16;
|
|
}
|
|
Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_);
|
|
#else
|
|
float* logits_smem = reinterpret_cast<float*>(logits_smem_);
|
|
#endif
|
|
|
|
__shared__ Tk logits_current_smem[1];
|
|
|
|
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
|
|
Tk* out_smem = reinterpret_cast<Tk*>(smem_);
|
|
|
|
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
|
|
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
|
|
|
|
// A vector of Q or K elements for the current timestep.
|
|
using Qk_vec_m = typename Qk_vec_m_<T, Dh_MAX>::Type; // with memory-used precision
|
|
using Qk_vec_k = typename Qk_vec_k_<T, Dh_MAX>::Type; // with kernel-used precision
|
|
#ifdef MMHA_USE_FP32_ACCUM_FOR_FMA
|
|
using Qk_vec_accum = typename Qk_vec_accum_fp32_<Qk_vec_k>::Type;
|
|
#else
|
|
using Qk_vec_accum = Qk_vec_k;
|
|
#endif
|
|
|
|
// Make sure the hidden dimension per head is a multiple of the number of threads per key.
|
|
static_assert(Dh_MAX % THREADS_PER_KEY == 0); // trivially satisfied since THREADS_PER_KEY in {1, 2, 4}
|
|
|
|
// The number of elements per vector.
|
|
// Each thread will handle 16 bytes.
|
|
constexpr int K_VEC_SIZE = 16u / sizeof(T);
|
|
// Make sure the hidden size per head is a multiple of the vector size.
|
|
static_assert(Dh_MAX % K_VEC_SIZE == 0);
|
|
// The type of queries and keys for the math in the Q*K^T product.
|
|
using K_vec_k = typename K_vec_k_<T, K_VEC_SIZE>::Type;
|
|
// Only used when key cache is quantized to 8 bits.
|
|
using K_vec_m = typename packed_type<TKcache, num_elems<K_vec_k>::value>::type;
|
|
#ifdef MMHA_USE_FP32_ACCUM_FOR_FMA
|
|
using K_vec_accum = typename Qk_vec_accum_fp32_<K_vec_k>::Type;
|
|
#else
|
|
using K_vec_accum = K_vec_k;
|
|
#endif
|
|
|
|
// Use alignment for safely casting the shared buffers as Qk_vec_k and K_vec_k.
|
|
// Shared memory to store Q inputs.
|
|
__shared__ __align__(mmha::const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk q_smem[Dh_MAX];
|
|
__shared__ __align__(mmha::const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk k_smem[Dh_MAX];
|
|
|
|
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
|
|
static_assert(Dh_MAX % THREADS_PER_VALUE == 0); // trivially satisfied since THREADS_PER_VALUE == Dh_MAX / p
|
|
|
|
// The number of elements per vector.
|
|
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
|
|
// A vector of V elements for the current timestep.
|
|
using V_vec_k = typename V_vec_k_<T, V_VEC_SIZE>::Type;
|
|
// Only used when value cache is quantized to 8 bits.
|
|
using V_vec_m = typename packed_type<Tcache, num_elems<V_vec_k>::value>::type;
|
|
static_assert(V_VEC_SIZE == sizeof(V_vec_k) / sizeof(T));
|
|
|
|
// This could be one of the reasons to have a separate kernel for cross attention
|
|
constexpr auto bias_smem_size = DO_CROSS_ATTENTION ? Dh_MAX : 1u;
|
|
__shared__ __align__(mmha::const_max(mmha::const_max(sizeof(Qk_vec_k), sizeof(K_vec_k)), sizeof(V_vec_k)))
|
|
Tk bias_smem[bias_smem_size];
|
|
|
|
// The number of elements per vector.
|
|
constexpr unsigned QK_VEC_SIZE{sizeof(Qk_vec_m) / sizeof(T)};
|
|
// Make sure the hidden size per head is a multiple of the vector size.
|
|
static_assert(Dh_MAX % QK_VEC_SIZE == 0);
|
|
// We will use block wide reduction if needed
|
|
// The number of vectors per Dh_MAX.
|
|
constexpr unsigned QK_VECS_PER_Dh_MAX{Dh_MAX / QK_VEC_SIZE};
|
|
static_assert(THREADS_PER_BLOCK >= QK_VECS_PER_Dh_MAX);
|
|
|
|
// The batch/beam idx
|
|
auto const batch_beam_idx = blockIdx.y;
|
|
if (params.finished != nullptr && params.finished[batch_beam_idx])
|
|
{
|
|
return;
|
|
}
|
|
|
|
// The head.
|
|
unsigned const hi{blockIdx.x};
|
|
// The head index of keys and values adjusted for MQA/GQA.
|
|
int const qhead_per_kv{params.num_heads / params.num_kv_heads};
|
|
unsigned const hi_kv{hi / qhead_per_kv};
|
|
// The number of heads.
|
|
auto const num_heads = static_cast<unsigned>(params.num_heads);
|
|
// The number of heads for keys and values adjusted for MQA/GQA.
|
|
auto const num_heads_kv = static_cast<unsigned>(params.num_kv_heads);
|
|
|
|
// The thread in the block.
|
|
unsigned const tidx{threadIdx.x};
|
|
|
|
// The column tile along L dimension on K^T -- noted as T_c in flash-attention paper
|
|
unsigned const c_tile{MULTI_BLOCK_FLAG ? blockIdx.z : 0};
|
|
|
|
// Indicates whether we need to compute the K/V cache element (add KV bias, IA3, RoPE, etc.) and update the cache.
// For self-attention, it is always required.
// For cross-attention, everything is pre-computed during the encoder context phase, so it is not
// needed in this kernel.
// Therefore, HANDLE_KV is !DO_CROSS_ATTENTION, independent of the timestep.
static constexpr bool HANDLE_KV{!DO_CROSS_ATTENTION};
|
|
|
|
// While doing the product Q*K^T for the different keys we track the max.
|
|
float qk_max = -FLT_MAX;
|
|
|
|
float qk = 0.0F;
|
|
|
|
// Do we have a relative attention bias?
|
|
bool has_relative_attention_bias = params.relative_attention_bias != nullptr;
|
|
// IMPLICIT_REL_ATTN_BIAS:
|
|
// Compute relative attention bias on the fly, with relative attention table [head_num/TP, num_buckets] passed in.
|
|
// num_buckets passed as relative_attention_bias_stride, max_distance passed as params.max_distance
|
|
// this is a common optimization for both self attention and cross attention
|
|
int relative_attention_bias_stride
|
|
= params.relative_attention_bias_stride; // num_buckets might be modified below, save it beforehand
|
|
int max_distance = params.max_distance;
|
|
|
|
// The actual sequence length excluding the paddings.
|
|
// minus 1 because it includes the current timestep while tlength denotes the kv cache length.
|
|
int const tlength = DO_CROSS_ATTENTION
|
|
? params.memory_length_per_sample[batch_beam_idx] - 1
|
|
: (params.length_per_sample ? (params.length_per_sample[batch_beam_idx] - 1) : static_cast<int>(timestep));
|
|
// We will use cyclic kv cache when it exceeds the limit.
|
|
// The length position for storing new key and value.
|
|
int const cyclic_tlength = kvCacheBuffer.getKVTokenIdx(tlength);
|
|
// When the cyclic kv cache and one-more-block mode are enabled, we need to shift the index to the actual index in
// the sequence. Otherwise, if the token is not a sink token, we need to add the bubble length to the index.
|
|
bool const enable_use_seq_idx_kv = kvCacheBuffer.mEnableOneMoreBlock && tlength > cyclic_kv_cache_len;
|
|
int const shift_for_cyclic_kv = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : kvCacheBuffer.mBubbleLen;
|
|
int const shift_for_cyclic_k = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : pastKCache.mBubbleLen;
|
|
// The actual kv cache length.
|
|
// tlength is the past length actually.
|
|
int const kv_loop_length = min(tlength, cyclic_kv_cache_len);
|
|
// The context length for beam searching optimization (all points to beam 0).
|
|
// TODO: with cyclic kv cache, we set it 0 for now (will optimize in the future)
|
|
// as context kv cache might be overwritten by the new kv cache
|
|
int const beam0_context_length
|
|
= HAS_BEAMS && tlength > cyclic_kv_cache_len ? 0 : params.input_lengths[batch_beam_idx];
|
|
// The position of the current timestep, and it is used to apply the position embedding
|
|
int const current_pos_idx = (!POS_SHIFT || DO_CROSS_ATTENTION) ? tlength : kv_loop_length;
|
|
|
|
// The offset in the Q and K buffer also accounts for the batch.
|
|
auto const qk_vec_idx = tidx * QK_VEC_SIZE;
|
|
auto const is_valid_qk_vec = qk_vec_idx < Dh;
|
|
|
|
bool const load_qkv_quant = params.qkv_scale_quant_orig != nullptr;
|
|
bool const write_attention_quant = params.attention_out_scale_orig_quant != nullptr;
|
|
|
|
// Quant/Dequant scales for 8bits kv cache.
|
|
using T_scale = typename kv_cache_scale_type_t<T, Tcache>::Type;
|
|
T_scale kv_scale_orig_quant, k_scale_quant_orig;
|
|
float const k_scale_quant_orig_f = (ENABLE_8BITS_K_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
|
|
float const kv_scale_quant_orig_f = (ENABLE_8BITS_KV_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
|
|
convert_from_float(&k_scale_quant_orig, k_scale_quant_orig_f);
|
|
convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_KV_CACHE ? params.kv_scale_orig_quant[0] : 1.0f));
|
|
|
|
// Up to QK_VECS_PER_Dh_MAX threads load Q and K + the bias values for the current timestep.
|
|
// Trigger the loads from the Q and K buffers.
|
|
Qk_vec_k q, k, q_bias, k_bias;
|
|
// key without position embedding
|
|
Qk_vec_k k_wo_pos;
|
|
zero(q);
|
|
zero(k);
|
|
zero(q_bias);
|
|
zero(k_bias);
|
|
zero(k_wo_pos);
|
|
float rotary_embedding_base = params.rotary_embedding_base;
|
|
float rotary_embedding_scale = params.rotary_embedding_scale;
|
|
if (is_valid_qk_vec)
|
|
{
|
|
mmha::update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale,
|
|
params.rotary_embedding_scale_type, params.rotary_embedding_dim, params.rotary_embedding_max_positions,
|
|
current_pos_idx);
|
|
// Query
|
|
// The stride between tokens. We may be able to always use params.stride.
|
|
uint32_t q_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads * Dh);
|
|
// The offset.
|
|
auto const q_offset = tensorrt_llm::common::flat_index_strided3(batch_beam_idx, hi, qk_vec_idx, q_stride, Dh);
|
|
|
|
if (load_qkv_quant)
|
|
{
|
|
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
|
|
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
|
|
auto const q_scaling = params.qkv_scale_quant_orig[0];
|
|
auto const q_quant
|
|
= *reinterpret_cast<Packed_Int8_t const*>(&reinterpret_cast<int8_t const*>(params.q)[q_offset]);
|
|
convert_from_float(&q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
|
|
}
|
|
else
|
|
{
|
|
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<Qk_vec_m const*>(¶ms.q[q_offset]));
|
|
}
|
|
|
|
if constexpr (DO_CROSS_ATTENTION)
|
|
{
|
|
auto const k_idx = QK_VEC_SIZE * tidx;
|
|
int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi, Dh, k_idx);
|
|
Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength));
|
|
|
|
k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<Qk_vec_m const*>(&k_cache[inBlockIdx]));
|
|
}
|
|
else
|
|
{
|
|
// Key
|
|
// The stride between tokens. We may be able to always use params.stride.
|
|
uint32_t k_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads_kv * Dh);
|
|
// The offset.
|
|
auto const k_offset
|
|
= tensorrt_llm::common::flat_index_strided3(batch_beam_idx, hi_kv, qk_vec_idx, k_stride, Dh);
|
|
|
|
if (load_qkv_quant)
|
|
{
|
|
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
|
|
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
|
|
auto const k_scaling = params.qkv_scale_quant_orig[1];
|
|
auto const k_quant
|
|
= *reinterpret_cast<Packed_Int8_t const*>(&reinterpret_cast<int8_t const*>(params.k)[k_offset]);
|
|
|
|
convert_from_float(&k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
|
|
}
|
|
else
|
|
{
|
|
k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<Qk_vec_m const*>(¶ms.k[k_offset]));
|
|
}
|
|
}
|
|
|
|
if (params.q_bias != nullptr)
|
|
{
|
|
auto const q_bias_offset = tensorrt_llm::common::flat_index2(hi, qk_vec_idx, Dh);
|
|
q_bias
|
|
= vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<Qk_vec_m const*>(¶ms.q_bias[q_bias_offset]));
|
|
}
|
|
if (HANDLE_KV && params.k_bias != nullptr)
|
|
{
|
|
auto const k_bias_offset = tensorrt_llm::common::flat_index2(hi_kv, qk_vec_idx, Dh);
|
|
k_bias
|
|
= vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<Qk_vec_m const*>(¶ms.k_bias[k_bias_offset]));
|
|
}
|
|
}
|
|
|
|
// Computes the Q/K values with bias.
|
|
q = add(q, q_bias);
|
|
if (HANDLE_KV)
|
|
{
|
|
k = add(k, k_bias);
|
|
}
|
|
|
|
// The width of the beam.
|
|
auto const beam_width = static_cast<unsigned>(params.beam_width);
|
|
// The batch idx.
|
|
int const batch_idx = batch_beam_idx / beam_width;
|
|
// Do we apply IA3?
|
|
bool const do_ia3 = HANDLE_KV && params.ia3_tasks != nullptr;
|
|
// Compute the IA3 task. One per batch index.
|
|
auto const ia3_ti_hi = do_ia3
|
|
? tensorrt_llm::common::flat_index2(static_cast<unsigned>(params.ia3_tasks[batch_idx]), hi, num_heads)
|
|
: 0;
|
|
|
|
if (do_ia3 && is_valid_qk_vec)
|
|
{
|
|
k = mul<Qk_vec_k, Qk_vec_k, Qk_vec_k>(k,
|
|
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<Qk_vec_m const*>(
|
|
¶ms.ia3_key_weights[tensorrt_llm::common::flat_index2(ia3_ti_hi, qk_vec_idx, Dh)])));
|
|
}
|
|
k_wo_pos = k;
|
|
|
|
// Note we have no paddings in KV cache now.
|
|
switch (params.position_embedding_type)
|
|
{
|
|
case PositionEmbeddingType::kLEARNED_ABSOLUTE:
|
|
case PositionEmbeddingType::kRELATIVE:
|
|
case PositionEmbeddingType::kALIBI:
|
|
case PositionEmbeddingType::kALIBI_WITH_SCALE:
|
|
{
|
|
break;
|
|
}
|
|
case PositionEmbeddingType::kROPE_GPTJ:
|
|
{
|
|
if (HANDLE_KV)
|
|
{
|
|
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_embedding_base,
|
|
rotary_embedding_scale, 0, nullptr, current_pos_idx);
|
|
}
|
|
else
|
|
{
|
|
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale,
|
|
0, nullptr, current_pos_idx);
|
|
}
|
|
break;
|
|
}
|
|
case PositionEmbeddingType::kLONG_ROPE:
|
|
case PositionEmbeddingType::kROPE_GPT_NEOX:
|
|
{
|
|
bool const do_rotary = is_valid_qk_vec && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
|
|
|
|
T* q_smem_ = reinterpret_cast<T*>(smem_);
|
|
T* k_smem_ = q_smem_ + params.rotary_embedding_dim;
|
|
|
|
int const half_rotary_dim = params.rotary_embedding_dim / 2;
|
|
int const half_idx = qk_vec_idx / half_rotary_dim;
|
|
int const intra_half_idx = qk_vec_idx % half_rotary_dim;
|
|
int const smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts
|
|
|
|
assert(half_rotary_dim % QK_VEC_SIZE == 0);
|
|
|
|
if (do_rotary)
|
|
{
|
|
*reinterpret_cast<Qk_vec_k*>(q_smem_ + half_idx * smem_pitch + intra_half_idx) = q;
|
|
if (HANDLE_KV)
|
|
{
|
|
*reinterpret_cast<Qk_vec_k*>(k_smem_ + half_idx * smem_pitch + intra_half_idx) = k;
|
|
}
|
|
}
|
|
|
|
__syncthreads();
|
|
|
|
int const transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
|
|
constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
|
|
if (do_rotary)
|
|
{
|
|
mmha::vec_from_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);
|
|
if (HANDLE_KV)
|
|
{
|
|
mmha::vec_from_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
|
|
|
|
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim,
|
|
rotary_embedding_base, rotary_embedding_scale, params.rotary_embedding_m_scale,
|
|
params.rotary_embedding_scaling_factors, current_pos_idx, params.rotary_cogvlm_vision_start,
|
|
params.rotary_cogvlm_vision_length);
|
|
|
|
mmha::write_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
|
|
}
|
|
else
|
|
{
|
|
mmha::apply_rotary_embedding(q, transpose_idx / tidx_factor, params.rotary_embedding_dim,
|
|
rotary_embedding_base, rotary_embedding_scale, params.rotary_embedding_m_scale,
|
|
params.rotary_embedding_scaling_factors, current_pos_idx, params.rotary_cogvlm_vision_start,
|
|
params.rotary_cogvlm_vision_length);
|
|
}
|
|
mmha::write_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);
|
|
}
|
|
|
|
__syncthreads();
|
|
|
|
if (do_rotary)
|
|
{
|
|
q = *reinterpret_cast<Qk_vec_k*>(q_smem_ + half_idx * smem_pitch + intra_half_idx);
|
|
if (HANDLE_KV)
|
|
{
|
|
k = *reinterpret_cast<Qk_vec_k*>(k_smem_ + half_idx * smem_pitch + intra_half_idx);
|
|
}
|
|
}
|
|
|
|
__syncthreads();
|
|
break;
|
|
}
|
|
}
|
|
|
|
// For the same reason as HANDLE_KV, no compute needed in Cross-Attention's 1st step
|
|
// Store Q K vectors to shared memory, and calculate QK.
|
|
if (qk_vec_idx < Dh_MAX)
|
|
{
|
|
|
|
// Store the Q values to shared memory.
|
|
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
|
if constexpr (FP8_K_CACHE)
|
|
{
|
|
// There are many more elements from K than elements from Q so we pre-scale Q instead
|
|
// of scaling all the elements from K. It helps reduce the number of ops.
|
|
Qk_vec_k scaled_q;
|
|
zero(scaled_q);
|
|
if (is_valid_qk_vec)
|
|
{
|
|
scaled_q = mul<Qk_vec_k, Tk, Qk_vec_k>(k_scale_quant_orig, q);
|
|
}
|
|
reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = scaled_q;
|
|
}
|
|
else
|
|
#endif
|
|
{
|
|
// Set padded Dh to 0 for the correctness of QK (when Dh != Dh_Max).
|
|
Qk_vec_k zero_q;
|
|
zero(zero_q);
|
|
reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = is_valid_qk_vec ? q : zero_q;
|
|
}
|
|
|
|
// Store the K values to shared memory.
|
|
// We store K values from shared memory to global memory
|
|
// when the target position of K cache in global memory has been accessed (in the case of cyclic kv cache)
|
|
if (POS_SHIFT && !DO_CROSS_ATTENTION)
|
|
{
|
|
reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k_wo_pos;
|
|
}
|
|
else
|
|
{
|
|
reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k;
|
|
}
|
|
|
|
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
|
|
qk = dot<Qk_vec_accum, Qk_vec_k>(q, k);
|
|
if (QK_VECS_PER_Dh_MAX <= WARP_SIZE)
|
|
{
|
|
#pragma unroll
|
|
for (int mask = QK_VECS_PER_Dh_MAX / 2; mask >= 1; mask /= 2)
|
|
{
|
|
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_Dh_MAX), qk, mask);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (QK_VECS_PER_Dh_MAX > WARP_SIZE)
|
|
{
|
|
constexpr int WARPS_PER_RED = (QK_VECS_PER_Dh_MAX + WARP_SIZE - 1) / WARP_SIZE;
|
|
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
|
|
}
|
|
|
|
// Pre-compute the pointer for the relative attention bias.
|
|
T const* relative_attention_bias_ptr = nullptr;
|
|
T const* relative_attention_bias_ptr_fixed = nullptr; // record the base for offset
|
|
if (has_relative_attention_bias)
|
|
{
|
|
// "hi" is unsigned, subtracting int from unsigned int causes underflow. Cast to int
|
|
int64_t offset = IMPLICIT_REL_ATTN_BIAS
|
|
? ((int64_t) hi * relative_attention_bias_stride - tlength)
|
|
: ((int64_t) hi * relative_attention_bias_stride + tlength) * relative_attention_bias_stride;
|
|
relative_attention_bias_ptr = ¶ms.relative_attention_bias[offset];
|
|
relative_attention_bias_ptr_fixed = ¶ms.relative_attention_bias[offset];
|
|
}
|
|
|
|
// Load the value.
|
|
float relative_attention_bias = 0.f;
|
|
if (has_relative_attention_bias && tidx == 0)
|
|
{
|
|
// TODO: Use a better way to convert from T to float.
|
|
relative_attention_bias = add(relative_attention_bias, relative_attention_bias_ptr[tlength]);
|
|
}
|
|
|
|
// Store that value in shared memory. Keep the Q*K^T value in register for softmax.
|
|
if (tidx == 0)
|
|
{
|
|
// Normalize qk.
|
|
qk = qk * params.inv_sqrt_dh + relative_attention_bias;
|
|
|
|
// We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
|
|
qk_max = qk;
|
|
|
|
// Store Q*K^T to shared memory.
|
|
if (MULTI_BLOCK_FLAG)
|
|
{
|
|
qk_current_smem[0] = qk;
|
|
}
|
|
else
|
|
{
|
|
// We need to store the qk result to the end of the qk_smem for cyclic kv cache (+ 1 for smem memory
|
|
// allocation) because the previous cache will still write to the new_cache_pos of qk_smem.
|
|
qk_smem[kv_loop_length] = qk;
|
|
}
|
|
}
|
|
|
|
// Make sure the data is in shared memory.
|
|
__syncthreads();
|
|
|
|
constexpr unsigned K_ELTS_PER_CHUNK{THREADS_PER_KEY * K_VEC_SIZE};
|
|
|
|
// The positions of the cache buffer (for this B * H) and the vector within that chunk associated with this
|
|
// thread.
|
|
auto const k_idx = chunk_index<T, K_vec_k, THREADS_PER_KEY>(tidx);
|
|
|
|
// The number of vectors per thread.
|
|
constexpr unsigned K_VECS_PER_THREAD{Dh_MAX / K_ELTS_PER_CHUNK};
|
|
static_assert(Dh_MAX == K_ELTS_PER_CHUNK * K_VECS_PER_THREAD);
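// Illustrative example (not the actual template configuration): with Dh_MAX == 128,
// THREADS_PER_KEY == 4 and K_VEC_SIZE == 8, K_ELTS_PER_CHUNK == 32 and each thread owns
// K_VECS_PER_THREAD == 4 key vectors.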
|
|
|
|
// Load the Q values from shared memory. The values are reused during the loop on K.
|
|
K_vec_accum q_vec[K_VECS_PER_THREAD];
|
|
#pragma unroll
|
|
for (unsigned ii = 0; ii < K_VECS_PER_THREAD; ++ii)
|
|
{
|
|
q_vec[ii] = vec_conversion<K_vec_accum, K_vec_k>(*reinterpret_cast<K_vec_k const*>(
|
|
&q_smem[tensorrt_llm::common::flat_index2(ii, k_idx.y, K_ELTS_PER_CHUNK)]));
|
|
}
|
|
|
|
// The number of timesteps (keys) loaded per iteration, i.e., THREADS_PER_BLOCK / THREADS_PER_KEY.
|
|
constexpr unsigned K_PER_ITER{THREADS_PER_BLOCK / THREADS_PER_KEY};
|
|
// The number of keys per warp.
|
|
constexpr unsigned K_PER_WARP{WARP_SIZE / THREADS_PER_KEY};
|
|
// The number of unrolled keys per warp.
|
|
constexpr unsigned UNROLLED_K_PER_WARP = K_PER_WARP * K_LOOP_UNROLL;
|
|
// The number of unrolled keys per iteration.
|
|
constexpr unsigned UNROLLED_K_PER_ITER = K_PER_ITER * K_LOOP_UNROLL;
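// Illustrative example (not the actual template configuration): with THREADS_PER_BLOCK == 256,
// THREADS_PER_KEY == 4 and K_LOOP_UNROLL == 4, K_PER_ITER == 64, K_PER_WARP == 8,
// UNROLLED_K_PER_WARP == 32 and UNROLLED_K_PER_ITER == 256 keys per unrolled iteration.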
|
|
|
|
auto const timesteps_per_block = static_cast<unsigned>(params.timesteps_per_block);
|
|
|
|
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
|
|
// Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
|
|
int const context_length
|
|
= DO_CROSS_ATTENTION ? kv_loop_length : (HAS_BEAMS ? beam0_context_length : kv_loop_length);
|
|
// Clarifications:
// - in self attn, input_length is the input text length, tlength is the current timestep
// - in cross attn, input_length is the *decoder* input length (usually 1), tlength is the *encoder* input context length
// - in beam search, since the cache during generation is organized differently, the following KV compute needs to be
// split into context cache compute and generation cache compute
// - for self attn, no beam search: the entire cache can be treated as context cache --> context_length = tlength
// - for self attn, beam search: the cache of the input text length is context cache, the rest is generation cache -->
// context_length = input_length
// - for cross attn, no-beam/beam search: the cache length is fixed, with no context/generation distinction -->
// context_length = tlength. Suggestion: we could have a flag HANDLE_GEN_CACHE
|
|
|
|
auto const context_ti_end = MULTI_BLOCK_FLAG
|
|
? divUp(timesteps_per_block, UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP
|
|
: divUp(static_cast<unsigned>(context_length), UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP;
|
|
|
|
// The generation ti_end.
|
|
auto const generation_ti_end = MULTI_BLOCK_FLAG
|
|
? divUp(timesteps_per_block, K_PER_WARP) * K_PER_WARP
|
|
: divUp(static_cast<unsigned>(kv_loop_length), K_PER_WARP) * K_PER_WARP;
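// Both loop bounds are rounded up to a multiple of the per-warp key count so that every thread of a
// warp runs the same number of iterations (required by the shuffles); out-of-range timesteps are
// clamped when loading and masked out via is_active before anything is stored.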
|
|
|
|
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
|
|
// Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers.
|
|
// By default, you can assume that they are the same.
|
|
auto const bi_seq_len_offset = static_cast<std::size_t>(batch_beam_idx) * params.max_attention_window_size;
|
|
// Beam indices are based on the max_attention_window_size, while each layer may have a different
// cyclic_attention_window_size. So we need to rebuild the beam_indices if max_attention_window_size
// is not equal to cyclic_attention_window_size.
|
|
int const* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr;
|
|
|
|
auto const c_tile_times_timesteps_per_block = c_tile * timesteps_per_block; // 0 if !MULTI_BLOCK_FLAG
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Key cache loops for dot(Q, K).
|
|
|
|
// Is it the leader?
|
|
bool const is_leader = Qk_dot<T, THREADS_PER_KEY>::is_leader(tidx);
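// Qk_dot leaves the reduced Q*K^T value with the leader lane of each THREADS_PER_KEY group,
// so only leader threads update qk_max and write qk_smem in the loops below.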
|
|
|
|
// The slope for ALiBi.
|
|
float linear_bias_slope = 0.f;
|
|
if (params.linear_bias_slopes != nullptr)
|
|
{
|
|
// TODO: Use a cleaner code to convert from T to float.
|
|
linear_bias_slope = mul<float>(params.linear_bias_slopes[hi], 1.f);
|
|
}
|
|
|
|
// Handle only context key cache with beam searching.
|
|
// Handle both context and generation key cache without beam searching.
|
|
// Explicit batching of LDGs (by K_LOOP_UNROLL) as it doesn't depend on indirection tables.
|
|
for (int ti = k_idx.x; ti < context_ti_end; ti += UNROLLED_K_PER_ITER)
|
|
{
|
|
int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
|
|
|
// The keys loaded from the key cache.
|
|
K_vec_m k_vec_cache[K_LOOP_UNROLL][K_VECS_PER_THREAD];
|
|
|
|
#pragma unroll
|
|
for (int k_loop = 0; k_loop < K_LOOP_UNROLL; ++k_loop)
|
|
{
|
|
#pragma unroll
|
|
for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i)
|
|
{
|
|
// Make sure we read data within the bound.
|
|
// Dh OOB values will be handled by zero_q.
|
|
// Seq OOB values will be masked out when storing back to smem.
|
|
auto const jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
|
|
int valid_time_now = min(time_now + k_loop * K_PER_ITER, context_length - 1);
|
|
if (POS_SHIFT && valid_time_now >= sink_token_len)
|
|
{
|
|
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
|
|
// Otherwise, we need to add the bubble length to the index
|
|
valid_time_now += shift_for_cyclic_k;
|
|
if (enable_use_seq_idx_kv)
|
|
{
|
|
// Convert the token index in sequence to token index in K cache.
|
|
valid_time_now = pastKCache.getKVTokenIdx(valid_time_now);
|
|
}
|
|
}
|
|
int const seqIdx = batch_idx * beam_width;
|
|
|
|
// Base pointer to k cache block for beam's batch
|
|
TKcache* k_cache_batch = reinterpret_cast<TKcache*>(pastKCache.getKBlockPtr(seqIdx, valid_time_now));
|
|
|
|
int inBlockIdx = pastKCache.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
|
|
k_vec_cache[k_loop][k_vec_i] = *reinterpret_cast<K_vec_m const*>(&k_cache_batch[inBlockIdx]);
|
|
}
|
|
}
|
|
|
|
#pragma unroll
|
|
for (int k_loop = 0; k_loop < K_LOOP_UNROLL; ++k_loop)
|
|
{
|
|
int const local_time_now = time_now + k_loop * K_PER_ITER;
|
|
int const local_ti = ti + k_loop * K_PER_ITER;
|
|
|
|
// Perform the dot product and normalize qk.
|
|
//
|
|
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
|
|
K_vec_m k_vec[K_VECS_PER_THREAD];
|
|
#pragma unroll
|
|
for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i)
|
|
{
|
|
k_vec[k_vec_i] = *reinterpret_cast<K_vec_m*>(&k_vec_cache[k_loop][k_vec_i]);
|
|
}
|
|
|
|
// Is it active?
|
|
bool const is_active = local_time_now < context_length;
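// Note: out-of-range keys are still loaded (with a clamped index) so that every lane of the warp
// participates in the shuffles inside Qk_dot; is_active only gates the write of the result below.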
|
|
|
|
if constexpr (IMPLICIT_REL_ATTN_BIAS)
|
|
{
|
|
// Compute bias value on the fly (See bert_preprocess_kernels.cu::buildRelativeAttentionBias)
|
|
int relative_buckets = 0;
|
|
int relative_position = local_time_now - tlength;
|
|
int num_buckets = relative_attention_bias_stride;
|
|
// Special logic in T5 relative attention, both encoder & decoder use this, because
|
|
// relative_attention_bias is pre-computed once and passed around.
|
|
// T5 decoder attention now only uses bidirectional=False relative position logic
|
|
// (ref: tensorrt_llm/layers/attention.py compute_relative_bias())
|
|
relative_position = relative_position >= 0 ? 0 : -relative_position;
|
|
|
|
int max_exact = num_buckets / 2;
|
|
bool is_small = relative_position < max_exact;
|
|
int relative_position_if_large = max_exact
|
|
+ (int) (logf(relative_position * 1.0f / max_exact) / logf((float) max_distance / max_exact)
|
|
* (num_buckets - max_exact));
|
|
relative_position_if_large = min(relative_position_if_large, num_buckets - 1);
|
|
relative_buckets += is_small ? relative_position : relative_position_if_large;
|
|
relative_attention_bias_ptr
|
|
= relative_attention_bias_ptr_fixed + (tlength - local_time_now) + relative_buckets;
|
|
}
|
|
|
|
// Prefetch the relative attention bias.
|
|
float relative_attention_bias = 0.f;
|
|
if (is_active && has_relative_attention_bias)
|
|
{
|
|
// TODO: Use a better way to convert from T to float.
|
|
relative_attention_bias = add(relative_attention_bias, relative_attention_bias_ptr[local_time_now]);
|
|
}
|
|
|
|
// Compute the dot product between Q and K.
|
|
// Note that dot will convert 8bit vec to the accumulation data type (float by default).
|
|
float qk_ = 0.f;
|
|
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
|
if constexpr (FP8_K_CACHE)
|
|
{
|
|
qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
|
|
}
|
|
else
|
|
#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
|
{
|
|
if constexpr (ENABLE_8BITS_K_CACHE)
|
|
{
|
|
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, k_scale_quant_orig_f)
|
|
* params.inv_sqrt_dh;
|
|
}
|
|
else
|
|
{
|
|
qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
|
|
}
|
|
}
|
|
|
|
// For multi-block mode, we need to make sure it will not be OOB.
|
|
if (MULTI_BLOCK_FLAG && local_ti >= timesteps_per_block)
|
|
{
|
|
continue;
|
|
}
|
|
|
|
// Add the ALiBi bias. (ki - qi) * slope[hi].
|
|
//
|
|
// The padding tokens are located between the input context and the generated tokens.
|
|
// We need to remove the correct number of padding tokens in the distance computation.
|
|
//
|
|
// ti : 0 1 2 3 4 5 6 7 8 9(tlength)
|
|
// token: i i i i p p p o o o where i=input, p=pad, o=output.
|
|
// e.g. ti = 2, dist = (9 - 3) - 2 = 4.
|
|
//
|
|
// All the threads do the work even if it's not relevant to avoid divergence.
|
|
qk_ += linear_bias_slope * (local_time_now - tlength) + relative_attention_bias;
|
|
|
|
// There's one qk value per timestep.
|
|
// Make sure only the leader threads store qk values within the bound.
|
|
if (is_active && is_leader)
|
|
{
|
|
// Calculate the max for softmax.
|
|
qk_max = fmaxf(qk_max, qk_);
|
|
// Store the product to shared memory.
|
|
qk_smem[local_ti] = qk_;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handle generation key cache with beam searching.
|
|
// Note that it may be overlapped with the context key loop, but it won't impact the correctness.
|
|
// Can skip in cross attention mode.
|
|
if (HAS_BEAMS && !DO_CROSS_ATTENTION
|
|
&& (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length))
|
|
{
|
|
// The input length.
|
|
int const input_length_ = MULTI_BLOCK_FLAG ? beam0_context_length % timesteps_per_block : beam0_context_length;
|
|
// The beginning of the generation.
|
|
int const generation_start_ti = k_idx.x + input_length_ / K_PER_WARP * K_PER_WARP;
|
|
|
|
// Iterate over the output tokens.
|
|
for (int ti = generation_start_ti; ti < generation_ti_end; ti += K_PER_ITER)
|
|
{
|
|
int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
|
|
|
// The keys loaded from the key cache.
|
|
K_vec_m k_vec[K_VECS_PER_THREAD];
|
|
|
|
#pragma unroll
|
|
for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i)
|
|
{
|
|
int const jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
|
|
int valid_time_now = min(time_now, kv_loop_length - 1);
|
|
int beam_offset = beam_indices[valid_time_now];
|
|
if (POS_SHIFT && valid_time_now >= sink_token_len)
|
|
{
|
|
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
|
|
// Otherwise, we need to add the bubble length to the index
|
|
valid_time_now += shift_for_cyclic_k;
|
|
if (enable_use_seq_idx_kv)
|
|
{
|
|
// Convert the token index in sequence to token index in K cache.
|
|
valid_time_now = pastKCache.getKVTokenIdx(valid_time_now);
|
|
}
|
|
}
|
|
int const seqIdx = batch_idx * beam_width + beam_offset;
|
|
// Base pointer to k cache block for beam's batch, before offsetting with indirection buffer
|
|
TKcache* k_cache_batch = reinterpret_cast<TKcache*>(pastKCache.getKBlockPtr(seqIdx, valid_time_now));
|
|
|
|
int inBlockIdx = pastKCache.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
|
|
k_vec[k_vec_i] = (*reinterpret_cast<K_vec_m const*>(&k_cache_batch[inBlockIdx]));
|
|
}
|
|
|
|
// Is it active?
|
|
bool const is_active = time_now >= context_length && time_now < kv_loop_length;
|
|
|
|
if constexpr (IMPLICIT_REL_ATTN_BIAS)
|
|
{
|
|
// Compute bias value on the fly (See bert_preprocess_kernels.cu::buildRelativeAttentionBias)
|
|
int relative_buckets = 0;
|
|
int relative_position = time_now - tlength;
|
|
int num_buckets = relative_attention_bias_stride;
|
|
// Special logic in T5 relative attention, both encoder & decoder use this, because
|
|
// relative_attention_bias is pre-computed once and passed around.
|
|
// T5 decoder attention now only uses bidirectional=False relative position logic
|
|
// (ref: tensorrt_llm/layers/attention.py compute_relative_bias())
|
|
relative_position = relative_position >= 0 ? 0 : -relative_position;
|
|
|
|
int max_exact = num_buckets / 2;
|
|
bool is_small = relative_position < max_exact;
|
|
int relative_position_if_large = max_exact
|
|
+ (int) (logf(relative_position * 1.0f / max_exact) / logf((float) max_distance / max_exact)
|
|
* (num_buckets - max_exact));
|
|
relative_position_if_large = min(relative_position_if_large, num_buckets - 1);
|
|
relative_buckets += is_small ? relative_position : relative_position_if_large;
|
|
relative_attention_bias_ptr
|
|
= relative_attention_bias_ptr_fixed + (tlength - time_now) + relative_buckets;
|
|
}
|
|
|
|
// Prefetch the relative attention bias.
|
|
float relative_attention_bias = 0.f;
|
|
if (is_active && has_relative_attention_bias)
|
|
{
|
|
// TODO: Use a better way to convert from T to float.
|
|
relative_attention_bias = add(relative_attention_bias, relative_attention_bias_ptr[time_now]);
|
|
}
|
|
|
|
// Perform the dot product and normalize qk.
|
|
//
|
|
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
|
|
// Note that dot will convert 8bit vec to the accumulation data type (float by default).
|
|
float qk_ = 0.f;
|
|
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
|
if constexpr (FP8_K_CACHE)
|
|
{
|
|
qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
|
|
}
|
|
else
|
|
#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
|
{
|
|
if constexpr (ENABLE_8BITS_K_CACHE)
|
|
{
|
|
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, k_scale_quant_orig_f)
|
|
* params.inv_sqrt_dh;
|
|
}
|
|
else
|
|
{
|
|
qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
|
|
}
|
|
}
|
|
// Add the ALiBi bias. (ki - qi) * slope[hi].
|
|
//
|
|
// The padding tokens are located between the input context and the generated tokens.
|
|
// We need to remove the correct number of padding tokens in the distance computation.
|
|
//
|
|
// ti : 0 1 2 3 4 5 6 7 8 9(tlength)
|
|
// token: i i i i p p p o o o where i=input, p=pad, o=output.
|
|
// e.g. ti = 2, dist = (9 - 3) - 2 = 4.
|
|
//
|
|
// All the threads perform that step to avoid divergence.
|
|
qk_ += linear_bias_slope * (time_now - tlength) + relative_attention_bias;
|
|
|
|
// There's one qk value per timestep.
|
|
// Make sure only the leader threads store qk values within the bound.
|
|
if (is_active && is_leader)
|
|
{
|
|
// Calculate the max for softmax.
|
|
qk_max = fmaxf(qk_max, qk_);
|
|
// Store the product to shared memory.
|
|
qk_smem[ti] = qk_;
|
|
}
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Softmax.
|
|
|
|
// Perform the final reduction to compute the max inside each warp.
|
|
//
|
|
// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
|
|
// group so it's not needed to run the reduction inside the group (again).
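// The remaining reduction therefore has two stages: a shuffle reduction across the key groups
// within each warp, followed by a cross-warp reduction through red_smem.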
|
|
|
|
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA)
|
|
// Leader threads will be on the diagonal when using HMMA.
|
|
if (THREADS_PER_KEY <= 4)
|
|
{
|
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(unsigned(-1), qk_max, 4));
|
|
}
|
|
if (THREADS_PER_KEY <= 8)
|
|
{
|
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(unsigned(-1), qk_max, 9));
|
|
}
|
|
if (THREADS_PER_KEY <= 16)
|
|
{
|
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(unsigned(-1), qk_max, 18));
|
|
}
|
|
#else
|
|
#pragma unroll
|
|
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2)
|
|
{
|
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
|
}
|
|
#endif // defined MMHA_USE_HMMA
|
|
|
|
// Decompose the thread index into warp and lane.
|
|
auto const warp = tidx / WARP_SIZE;
|
|
auto const lane = tidx % WARP_SIZE;
|
|
|
|
// The warp leader writes the max to shared memory.
|
|
if (lane == 0)
|
|
{
|
|
red_smem[warp] = qk_max;
|
|
}
|
|
|
|
// Make sure the products are in shared memory.
|
|
__syncthreads();
|
|
|
|
// After the syncthreads, the target k position (cyclic kv cache) should also have been used by the k loop.
|
|
// Write the K values to the global memory cache.
|
|
//
|
|
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
|
|
// system. We designed it this way as it allows much better memory loads (and there are many
|
|
// more loads) + the stores are really "write and forget" since we won't need the ack before
|
|
// the end of the kernel. There's plenty of time for the transactions to complete.
|
|
|
|
// For MQA/GQA mode, write only with the first Q head of each group per KV head.
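// Illustrative example: with 32 query heads and 8 KV heads, qhead_per_kv == 4 and only query
// heads 0, 4, 8, ... perform the store for their KV head.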
|
|
if (HANDLE_KV && hi == (hi_kv * qhead_per_kv) && qk_vec_idx < Dh)
|
|
{
|
|
// Trigger the stores to global memory.
|
|
Qk_vec_k k_vec = *reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx]);
|
|
auto const k_idx = QK_VEC_SIZE * tidx;
|
|
int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, k_idx);
|
|
// The base pointer for the value in the cache buffer.
|
|
Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength));
|
|
|
|
if constexpr (ENABLE_8BITS_KV_CACHE)
|
|
{
|
|
store_8bits_kv_cache_vec(reinterpret_cast<Tcache*>(k_cache), k_vec, inBlockIdx, kv_scale_orig_quant);
|
|
}
|
|
else
|
|
{
|
|
*reinterpret_cast<Qk_vec_m*>(&k_cache[inBlockIdx]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k_vec);
|
|
}
|
|
}
|
|
|
|
// The warps finalize the reduction.
|
|
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
|
|
#pragma unroll
|
|
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2)
|
|
{
|
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
|
}
|
|
|
|
// Broadcast to all the threads in the warp.
|
|
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
|
|
|
// Compute the logits and start the sum.
|
|
float sum = 0.f;
|
|
|
|
// Each thread will handle one float (either qk_smem/logit).
|
|
int const logit_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
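// The loop bound is inclusive because one extra slot holds the current timestep's Q*K^T:
// qk_smem[kv_loop_length] in single-block mode, qk_current_smem[0] in multi-block mode.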
|
|
for (int ti = tidx; ti <= logit_loop_end; ti += THREADS_PER_BLOCK)
|
|
{
|
|
|
|
int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
|
|
|
// For single-block mode, we don't need the mask since it has been skipped.
|
|
if (!MULTI_BLOCK_FLAG)
|
|
{
|
|
float logit = __expf(qk_smem[time_now] - qk_max);
|
|
sum += logit;
|
|
qk_smem[time_now] = logit;
|
|
}
|
|
else
|
|
{
|
|
// Not supported yet: multi-block mode with FP8_MHA
|
|
if (time_now < kv_loop_length && ti != timesteps_per_block)
|
|
{
|
|
float logit = __expf(qk_smem[ti] - qk_max);
|
|
sum += logit;
|
|
qk_smem[ti] = logit;
|
|
}
|
|
else if (time_now == kv_loop_length)
|
|
{
|
|
float logit = __expf(qk_current_smem[0] - qk_max);
|
|
sum += logit;
|
|
qk_current_smem[0] = logit;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Compute the sum.
|
|
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
|
|
|
|
// Normalize the logits.
|
|
#ifdef MMHA_FP8_SCALE_P_INSTEAD_OF_V
|
|
float logit_scale = (FP8_KV_CACHE ? kv_scale_quant_orig_f : 1.0f);
|
|
#else
|
|
float logit_scale = 1.f;
|
|
#endif // MMHA_FP8_SCALE_P_INSTEAD_OF_V
|
|
float inv_sum = __fdividef(logit_scale, sum + 1.e-6f);
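// The small epsilon avoids a division by zero if the sum underflows. With an FP8 KV cache and
// MMHA_FP8_SCALE_P_INSTEAD_OF_V defined, the KV dequantization scale is folded into inv_sum so
// it is applied once per logit during normalization.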
|
|
|
|
int const normalization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
for (int ti = tidx; ti <= normalization_loop_end; ti += THREADS_PER_BLOCK)
|
|
{
|
|
int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
|
|
|
if (!MULTI_BLOCK_FLAG)
|
|
{
|
|
convert_from_float(&logits_smem[ti], qk_smem[ti] * inv_sum);
|
|
}
|
|
else
|
|
{
|
|
// no scaling factor inv_sum applied here, will apply the scaling factor after all blocks finished
|
|
if (time_now < kv_loop_length && ti != timesteps_per_block)
|
|
{
|
|
convert_from_float(&logits_smem[ti], qk_smem[ti]);
|
|
}
|
|
else if (time_now == kv_loop_length)
|
|
{
|
|
convert_from_float(&logits_current_smem[0], qk_current_smem[0]);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Put the Values part below so we can leverage the __syncthreads
// from the previous step.
|
|
|
|
auto const v_idx = chunk_index<T, V_vec_k, THREADS_PER_VALUE>(tidx);
|
|
// The value computed by this thread.
|
|
auto const vo = v_idx.x;
|
|
// The hidden dimensions computed by this particular thread.
|
|
auto const vi = v_idx.y;
|
|
|
|
// The number of values processed per iteration of the loop.
|
|
constexpr unsigned V_PER_ITER{THREADS_PER_BLOCK / THREADS_PER_VALUE};
|
|
// The number of unrolled values per iteration.
|
|
constexpr unsigned UNROLLED_V_PER_ITER = V_PER_ITER * V_LOOP_UNROLL;
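// Illustrative example (not the actual template configuration): with THREADS_PER_BLOCK == 256,
// THREADS_PER_VALUE == 32 and V_LOOP_UNROLL == 8, V_PER_ITER == 8 and UNROLLED_V_PER_ITER == 64.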
|
|
|
|
bool const is_valid_vi = IS_Dh_MAX || vi < Dh;
|
|
|
|
// One group of threads computes the product(s) for the current timestep.
|
|
V_vec_k v_bias;
|
|
zero(v_bias);
|
|
// if( vo == params.timestep % V_PER_ITER ) {
|
|
if (is_valid_vi && HANDLE_KV && vo == kv_loop_length % V_PER_ITER)
|
|
{
|
|
// Trigger the loads from the V bias buffer.
|
|
if (params.v_bias != nullptr)
|
|
{
|
|
auto const v_bias_offset = tensorrt_llm::common::flat_index2(hi_kv, vi, Dh);
|
|
v_bias = *reinterpret_cast<V_vec_k const*>(¶ms.v_bias[v_bias_offset]);
|
|
}
|
|
|
|
if (DO_CROSS_ATTENTION)
|
|
{
|
|
*reinterpret_cast<V_vec_k*>(&bias_smem[vi]) = v_bias;
|
|
}
|
|
}
|
|
|
|
// From the previous step (before the values),
// also make sure the logits are in shared memory.
|
|
__syncthreads();
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Value cache loops.
|
|
|
|
#ifdef MMHA_USE_FP32_ACCUM_FOR_OUT
|
|
using V_vec_accum = typename V_vec_accum_fp32_<V_vec_k>::Type;
|
|
#else
|
|
using V_vec_accum = V_vec_k;
|
|
#endif
|
|
// The partial outputs computed by each thread.
|
|
V_vec_accum out;
|
|
zero(out);
|
|
|
|
// Loop over the timesteps to compute the partial outputs.
|
|
if (is_valid_vi)
|
|
{
|
|
// Handle only context value cache with beam searching.
|
|
// Handle both context and generation value cache without beam searching.
|
|
// Explicit batching of LDGs (by V_LOOP_UNROLL) as it doesn't depend on indirection tables.
|
|
// Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
|
|
int const context_length
|
|
= DO_CROSS_ATTENTION ? kv_loop_length : (HAS_BEAMS ? beam0_context_length : kv_loop_length);
|
|
int context_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : context_length;
|
|
int generation_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
|
|
for (int ti = vo; ti < context_v_loop_end; ti += UNROLLED_V_PER_ITER)
|
|
{
|
|
V_vec_m v_vec_cache[V_LOOP_UNROLL];
|
|
#pragma unroll
|
|
for (int v_loop = 0; v_loop < V_LOOP_UNROLL; v_loop++)
|
|
{
|
|
// Fetch offset based on cache_indir when beam sampling
|
|
int time_idx = ti + v_loop * V_PER_ITER + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
|
|
time_idx = min(time_idx, kv_loop_length - 1);
|
|
if (POS_SHIFT && time_idx >= sink_token_len)
|
|
{
|
|
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
|
|
// Otherwise, we need to add the bubble length to the index
|
|
time_idx += shift_for_cyclic_kv;
|
|
if (enable_use_seq_idx_kv)
|
|
{
|
|
// Convert the token index in sequence to token index in V cache.
|
|
time_idx = kvCacheBuffer.getKVTokenIdx(time_idx);
|
|
}
|
|
}
|
|
int rowIdx = batch_idx * beam_width;
|
|
|
|
int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
|
|
// The base pointer for the value in the cache buffer.
|
|
Tcache* v_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx));
|
|
|
|
v_vec_cache[v_loop] = *reinterpret_cast<V_vec_m const*>(&v_cache_batch[inBlockIdx]);
|
|
}
|
|
|
|
#pragma unroll
|
|
for (int v_loop = 0; v_loop < V_LOOP_UNROLL; v_loop++)
|
|
{
|
|
V_vec_m v_vec = reinterpret_cast<V_vec_m*>(&v_vec_cache[v_loop])[0];
|
|
|
|
int local_time_idx = ti + v_loop * V_PER_ITER;
|
|
int time_idx = local_time_idx + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
|
|
|
|
bool const is_mask
|
|
= (MULTI_BLOCK_FLAG && local_time_idx >= timesteps_per_block) || (time_idx >= context_length);
|
|
|
|
// Load the logits from shared memory.
|
|
// Note that fma will convert 8bit vec to the accumulation data type (float by default).
|
|
Logit_value_fma<Tk, V_vec_accum, V_vec_m, INT8_KV_CACHE, FP8_KV_CACHE>(
|
|
out, reinterpret_cast<Tk*>(logits_smem + local_time_idx), v_vec, kv_scale_quant_orig_f, is_mask);
|
|
}
|
|
}
|
|
|
|
// Handle generation value cache with beam searching.
|
|
if (HAS_BEAMS && !DO_CROSS_ATTENTION)
|
|
{
|
|
auto const generation_start_ti
|
|
= MULTI_BLOCK_FLAG ? vo : (vo + (beam0_context_length / V_PER_ITER) * V_PER_ITER);
|
|
// Only the last few blocks need to handle the generation value cache.
|
|
if (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length)
|
|
{
|
|
for (int ti = generation_start_ti; ti < generation_v_loop_end; ti += V_PER_ITER)
|
|
{
|
|
// Fetch offset based on cache_indir when beam sampling
|
|
int time_idx = ti + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
|
|
int local_time_idx = ti;
|
|
if (time_idx < beam0_context_length || (MULTI_BLOCK_FLAG && time_idx >= kv_loop_length))
|
|
{
|
|
continue;
|
|
}
|
|
int rowIdx = batch_idx * beam_width + beam_indices[time_idx];
|
|
|
|
if (POS_SHIFT && time_idx >= sink_token_len)
|
|
{
|
|
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
|
|
// Otherwise, we need to add the bubble length to the index
|
|
time_idx += shift_for_cyclic_kv;
|
|
if (enable_use_seq_idx_kv)
|
|
{
|
|
// Convert the token index in sequence to token index in V cache.
|
|
time_idx = kvCacheBuffer.getKVTokenIdx(time_idx);
|
|
}
|
|
}
|
|
|
|
int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
|
|
// The base pointer for the value in the cache buffer.
|
|
Tcache* v_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx));
|
|
V_vec_m v_vec = reinterpret_cast<V_vec_m const*>(&v_cache_batch[inBlockIdx])[0];
|
|
|
|
// Load the logits from shared memory.
|
|
// Note that fma will convert 8bit vec to the accumulation data type (float by default).
|
|
Logit_value_fma<Tk, V_vec_accum, V_vec_m, INT8_KV_CACHE, FP8_KV_CACHE>(
|
|
out, reinterpret_cast<Tk*>(logits_smem + local_time_idx), v_vec, kv_scale_quant_orig_f, false);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Make sure we can overwrite the v cache if using cyclic kv cache.
|
|
__syncthreads();
|
|
|
|
// Get the c_tile_id that handles the current timestep.
|
|
int const ctile_idx = tlength / timesteps_per_block;
|
|
|
|
// One group of threads computes the product(s) for the current timestep.
|
|
if (vo == kv_loop_length % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx)))
|
|
{
|
|
int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, vi);
|
|
// The base pointer for the value in the cache buffer.
|
|
Tcache* v_cache_base = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(batch_beam_idx, cyclic_tlength));
|
|
|
|
V_vec_k v;
|
|
if (DO_CROSS_ATTENTION)
|
|
{
|
|
v = vec_conversion<V_vec_k, V_vec_k>(*reinterpret_cast<V_vec_k const*>(&v_cache_base[inBlockIdx]));
|
|
}
|
|
else
|
|
{
|
|
// Trigger the loads from the V buffer.
|
|
// The stride between tokens. We may be able to always use params.stride.
|
|
uint32_t v_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads_kv * Dh);
|
|
// The offset.
|
|
auto const v_offset = tensorrt_llm::common::flat_index_strided3(batch_beam_idx, hi_kv, vi, v_stride, Dh);
|
|
|
|
if (load_qkv_quant)
|
|
{
|
|
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
|
|
using Packed_Float_t = typename packed_type<float, num_elems<V_vec_k>::value>::type;
|
|
auto const v_scaling = params.qkv_scale_quant_orig[2];
|
|
auto const v_quant
|
|
= *reinterpret_cast<Packed_Int8_t const*>(&reinterpret_cast<int8_t const*>(params.v)[v_offset]);
|
|
|
|
convert_from_float(&v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
|
|
}
|
|
else
|
|
{
|
|
v = *reinterpret_cast<V_vec_k const*>(¶ms.v[v_offset]);
|
|
}
|
|
}
|
|
|
|
if (HANDLE_KV)
|
|
{
|
|
// Compute the V values with bias.
|
|
v = add(v, v_bias);
|
|
|
|
if (do_ia3)
|
|
{
|
|
v = mul<V_vec_k, V_vec_k, V_vec_k>(v,
|
|
*reinterpret_cast<V_vec_k const*>(
|
|
¶ms.ia3_value_weights[tensorrt_llm::common::flat_index2(ia3_ti_hi, vi, Dh)]));
|
|
}
|
|
}
|
|
|
|
// Store the values with bias back to global memory in the cache for V.
|
|
//*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
|
|
// For MQA/GQA mode, write only with the first Q head of each group per KV head.
|
|
if (hi == (hi_kv * qhead_per_kv))
|
|
{
|
|
if (ENABLE_8BITS_KV_CACHE)
|
|
{
|
|
store_8bits_kv_cache_vec(v_cache_base, v, inBlockIdx, kv_scale_orig_quant);
|
|
}
|
|
else
|
|
{
|
|
*reinterpret_cast<V_vec_k*>(&v_cache_base[inBlockIdx]) = v;
|
|
}
|
|
}
|
|
|
|
// Initialize the output value with the current timestep.
|
|
#if defined(MMHA_USE_FP32_ACCUM_FOR_LOGITS)
|
|
// out = fma(logits_smem[params.timestep], cast_to_float(v), out);
|
|
if (!MULTI_BLOCK_FLAG)
|
|
{
|
|
out = fma(logits_smem[kv_loop_length], cast_to_float(v), out);
|
|
}
|
|
else
|
|
{
|
|
out = fma(logits_current_smem[0], cast_to_float(v), out);
|
|
}
|
|
#else // MMHA_USE_FP32_ACCUM_FOR_LOGITS
|
|
// out = fma(logits_smem[params.timestep], v, out);
|
|
if (!MULTI_BLOCK_FLAG)
|
|
{
|
|
out = fma(logits_smem[kv_loop_length], v, out);
|
|
}
|
|
else
|
|
{ // MULTI_BLOCK_FLAG // Not supported yet: multi-block mode with FP8_MHA
|
|
out = fma(logits_current_smem[0], v, out);
|
|
}
|
|
#endif // MMHA_USE_FP32_ACCUM_FOR_LOGITS
|
|
}
|
|
// Make sure we can start writing to shared memory.
|
|
__syncthreads();
|
|
|
|
// Run the final reduction amongst the different groups computing different partial outputs.
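// Binary-tree reduction: in each step the upper half of the active groups stores its partial sums
// to out_smem and the lower half accumulates them, halving active_groups until group 0 holds the
// complete output (e.g. V_PER_ITER == 8 gives active_groups = 8, 4, 2).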
|
|
#pragma unroll
|
|
for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2)
|
|
{
|
|
|
|
// The midpoint in the number of active groups.
|
|
int midpoint = active_groups / 2;
|
|
|
|
// The upper part of active threads store to shared memory.
|
|
if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh))
|
|
{
|
|
#ifdef MMHA_USE_FP32_ACCUM_FOR_OUT
|
|
convert_from_float(reinterpret_cast<V_vec_k*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
|
|
#else
|
|
*reinterpret_cast<V_vec_k*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
|
|
#endif
|
|
}
|
|
__syncthreads();
|
|
|
|
// The bottom warps update their values.
|
|
if (vo < midpoint && (Dh == Dh_MAX || vi < Dh))
|
|
{
|
|
out = add(*reinterpret_cast<V_vec_k const*>(&out_smem[vo * Dh + vi]), out);
|
|
}
|
|
__syncthreads();
|
|
}
|
|
|
|
// Quantized output only supports fp8 currently, which should be used together with FP8 Context FMHA.
|
|
using Quantized_t = __nv_fp8_e4m3;
|
|
using Quantized_vec = typename packed_type<__nv_fp8_e4m3, num_elems<V_vec_accum>::value>::type;
|
|
auto const bhi = tensorrt_llm::common::flat_index2(batch_beam_idx, hi, num_heads);
|
|
auto const bhi_seq_len_tile = bhi * params.seq_len_tile;
|
|
// Output the final values.
|
|
if (vo == 0 && (Dh == Dh_MAX || vi < Dh))
|
|
{
|
|
auto const bhvi = tensorrt_llm::common::flat_index2(bhi, vi, Dh);
|
|
#ifdef MMHA_USE_FP32_ACCUM_FOR_OUT
|
|
if (!MULTI_BLOCK_FLAG)
|
|
{
|
|
if (write_attention_quant)
|
|
{
|
|
out = mul<V_vec_accum, float>(*params.attention_out_scale_orig_quant, out);
|
|
Quantized_vec final_out;
|
|
convert_to_fp8(&final_out, out);
|
|
*reinterpret_cast<Quantized_vec*>(reinterpret_cast<Quantized_t*>(params.out) + bhvi) = final_out;
|
|
}
|
|
else
|
|
{
|
|
// This makes sure we have coalesced memory access.
|
|
V_vec_k final_out;
|
|
convert_from_float(&final_out, out);
|
|
*reinterpret_cast<V_vec_k*>(¶ms.out[bhvi]) = final_out;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
// Write the partial output to partial_out.
int partial_out_offset = c_tile * params.batch_size * num_heads * params.hidden_size_per_head;
// Write the partial statistics to partial_max and partial_sum.
|
|
int partial_stats_offset = bhi_seq_len_tile + c_tile;
|
|
|
|
// This makes sure we have coalesced memory access.
|
|
V_vec_k partial_out;
|
|
convert_from_float(&partial_out, out);
|
|
*reinterpret_cast<V_vec_k*>(¶ms.partial_out[partial_out_offset + bhvi]) = partial_out;
|
|
convert_from_float(reinterpret_cast<float*>(¶ms.partial_max[partial_stats_offset]), qk_max);
|
|
convert_from_float(reinterpret_cast<float*>(¶ms.partial_sum[partial_stats_offset]), sum);
|
|
}
|
|
#else // MMHA_USE_FP32_ACCUM_FOR_OUT
|
|
*reinterpret_cast<V_vec_accum*>(¶ms.out[bhvi]) = out;
|
|
#endif // MMHA_USE_FP32_ACCUM_FOR_OUT
|
|
}
|
|
|
|
#ifdef ENABLE_MULTI_BLOCK_OPTION
|
|
if (MULTI_BLOCK_FLAG)
|
|
{
|
|
|
|
cuda::atomic_ref<int, cuda::thread_scope_device> count_ref{params.block_counter[bhi]};
|
|
bool last_block{false};
|
|
if (tidx == 0)
|
|
{
|
|
if (count_ref.fetch_add(1, cuda::memory_order_acq_rel) == (gridDim.z - 1))
|
|
{
|
|
last_block = true;
|
|
}
|
|
}
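// Only thread 0 increments block_counter[bhi]; the __syncthreads_or below broadcasts its
// last_block flag so the whole threadblock either enters or skips the final reduction together.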
|
|
|
|
////////////////////
|
|
////////////////////
|
|
// Make sure every threadblock finishes the previous computation before the last threadblock (for
// each B and H) enters the section below. The last threadblock performs the final reduction by
// combining all the partial max/sum values and partial outputs.
|
|
////////////////////
|
|
////////////////////
|
|
if (__syncthreads_or(last_block))
|
|
{
|
|
|
|
////////////////////
|
|
// Find the global max from all partial max -> use CUB BlockReduce
|
|
////////////////////
|
|
|
|
float final_max = -FLT_MAX;
|
|
float thread_partial_max = -FLT_MAX;
|
|
thread_partial_max = params.partial_max[bhi_seq_len_tile + min(tidx, gridDim.z - 1)];
|
|
|
|
// Make sure we can start writing to shared memory.
|
|
__syncthreads();
|
|
|
|
// Specialize BlockReduce for a 1D block of THREADS_PER_BLOCK threads of type float
|
|
typedef cub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
|
|
// Allocate shared memory for BlockReduce
|
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
|
// Obtain a segment of consecutive items that are blocked across threads (final_max from above)
|
|
// Compute the block-wide max for thread0
|
|
final_max = BlockReduce(temp_storage).Reduce(thread_partial_max, cub::Max(), gridDim.z);
|
|
|
|
__shared__ float final_max_smem;
|
|
if (tidx == 0)
|
|
{
|
|
final_max_smem = final_max;
|
|
}
|
|
__syncthreads();
|
|
|
|
// Finish the final_max computation
|
|
final_max = final_max_smem;
|
|
|
|
////////////////////
|
|
// Reduction for global sum over all partial sum (scaled by the exponential term from global max) -> use
|
|
// gridDim.z threads
|
|
////////////////////
|
|
|
|
float final_sum = 0.f;
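// Log-sum-exp combination: final_sum = sum_i exp(partial_max_i - final_max) * partial_sum_i,
// which equals the softmax denominator that a single pass with the global max would have produced.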
|
|
if (tidx < gridDim.z)
|
|
{
|
|
thread_partial_max = params.partial_max[bhi_seq_len_tile + tidx];
|
|
auto const thread_partial_sum = params.partial_sum[bhi_seq_len_tile + tidx];
|
|
final_sum += __expf(thread_partial_max - final_max) * thread_partial_sum;
|
|
}
|
|
|
|
// Compute the final_sum.
|
|
final_sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], final_sum);
|
|
|
|
////////////////////
|
|
// Reduction for final output (scaled by the exponential term from global max) -> use THREADS_PER_VALUE
|
|
// * gridDim.z threads
|
|
////////////////////
|
|
|
|
// Shared memory to store partial outputs for each oi. -> size: gridDim.z * Dh * 4 Bytes. Reuse qk_smem.
|
|
T* out_oi_smem = reinterpret_cast<T*>(smem_);
|
|
|
|
auto const o_idx = chunk_index<T, V_vec_k, THREADS_PER_VALUE>(tidx);
|
|
|
|
// Init partial out for accumulation.
|
|
V_vec_k zero_k;
|
|
zero(zero_k);
|
|
V_vec_k thread_accumulated_out = zero_k;
|
|
|
|
// The hidden dimensions computed by this particular thread. (refer to vi)
|
|
auto const oi = o_idx.y;
|
|
|
|
// The partial output region this thread takes care of
|
|
auto const oo = o_idx.x;
|
|
|
|
// Each thread may handle more than one partial output.
|
|
for (int tile_idx = o_idx.x; tile_idx < gridDim.z; tile_idx += V_PER_ITER)
|
|
{
|
|
// Load partial output
|
|
int thread_partial_out_offset = tile_idx * params.batch_size * num_heads * params.hidden_size_per_head;
|
|
// Load the partial max (different from thread_partial_max above since the thread indexing changes here)
|
|
float thread_partial_max_for_out = params.partial_max[bhi_seq_len_tile + tile_idx];
|
|
// Load the partial outputs.
|
|
V_vec_k thread_partial_out
|
|
= *reinterpret_cast<V_vec_k const*>(¶ms.partial_out[thread_partial_out_offset + bhi * Dh + oi]);
|
|
// Apply the correction factor.
|
|
Tk factor_compute;
|
|
convert_from_float(&factor_compute, __expf(thread_partial_max_for_out - final_max));
|
|
thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(factor_compute, thread_partial_out);
|
|
thread_accumulated_out = add(thread_partial_out, thread_accumulated_out);
|
|
}
|
|
|
|
// Run the final reduction amongst the different groups computing different partial outputs.
|
|
#pragma unroll
|
|
for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2)
|
|
{
|
|
|
|
// The midpoint in the number of active groups.
|
|
int midpoint = active_groups / 2;
|
|
|
|
// The upper part of active threads store to shared memory.
|
|
if (oo >= midpoint && oo < active_groups && (Dh == Dh_MAX || oi < Dh))
|
|
{
|
|
*reinterpret_cast<V_vec_k*>(&out_oi_smem[(oo - midpoint) * Dh + oi]) = thread_accumulated_out;
|
|
}
|
|
__syncthreads();
|
|
|
|
// The bottom warps update their values.
|
|
if (oo < midpoint && (Dh == Dh_MAX || oi < Dh))
|
|
{
|
|
thread_accumulated_out
|
|
= add(thread_accumulated_out, *reinterpret_cast<V_vec_k const*>(&out_oi_smem[oo * Dh + oi]));
|
|
}
|
|
__syncthreads();
|
|
}
|
|
|
|
////////////////////
|
|
// Final output O * inv_sum
|
|
////////////////////
|
|
|
|
if (oo == 0 && (Dh == Dh_MAX || oi < Dh))
|
|
{
|
|
auto const inv_sum = __fdividef(
|
|
write_attention_quant ? *params.attention_out_scale_orig_quant : 1.f, final_sum + 1.e-6f);
|
|
|
|
Tk inv_sum_compute;
|
|
convert_from_float(&inv_sum_compute, inv_sum);
|
|
|
|
thread_accumulated_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_accumulated_out);
|
|
|
|
if (write_attention_quant)
|
|
{
|
|
Quantized_vec final_out;
|
|
convert_to_fp8(&final_out, thread_accumulated_out);
|
|
*reinterpret_cast<Quantized_vec*>(reinterpret_cast<Quantized_t*>(params.out) + bhi * Dh + oi)
|
|
= final_out;
|
|
}
|
|
else
|
|
{
|
|
*reinterpret_cast<V_vec_k*>(¶ms.out[bhi * Dh + oi]) = thread_accumulated_out;
|
|
}
|
|
}
|
|
|
|
// Reset qk_current_smem and block_counter for the next timestep
|
|
if (tidx == 0)
|
|
{
|
|
params.block_counter[bhi] = 0;
|
|
}
|
|
}
|
|
}
|
|
#endif // ENABLE_MULTI_BLOCK_OPTION
|
|
}

} // namespace mmha

} // namespace kernels
} // namespace tensorrt_llm