/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/mathUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include <cstdint>
#include <cub/cub.cuh>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

using namespace tensorrt_llm::common;

namespace tensorrt_llm
{
namespace kernels
{

// A stateful callback functor that maintains the running sum between consecutive scans.
struct BlockPrefixCallbackOp
{
    // Running prefix
    int mRunningTotal;

    // Constructor
    __device__ BlockPrefixCallbackOp(int runningTotal)
        : mRunningTotal(runningTotal)
    {
    }

    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ int operator()(int blockAggregate)
    {
        int oldPrefix = mRunningTotal;
        mRunningTotal += blockAggregate;
        return oldPrefix;
    }
};

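// VecType maps an element type to the 16-byte vector type used for 128-bit loads/stores; GPTJEltType is the
// paired-element type that the GPT-J (interleaved) rotary embedding helpers operate on.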
template <typename T>
struct VecType
{
    using Type = T;
    using GPTJEltType = T;
};

template <>
struct VecType<float>
{
    using Type = float4;
    using GPTJEltType = float2;
};

template <>
struct VecType<half>
{
    using Type = uint4;
    using GPTJEltType = uint32_t;
};

template <>
struct VecType<__nv_bfloat16>
{
    using Type = mmha::bf16_8_t;
    using GPTJEltType = __nv_bfloat162;
};

struct __align__(16) fp8_16_t
{
    __nv_fp8x4_e4m3 x;
    __nv_fp8x4_e4m3 y;
    __nv_fp8x4_e4m3 z;
    __nv_fp8x4_e4m3 w;
};

template <>
struct VecType<__nv_fp8_e4m3>
{
    using Type = fp8_16_t;
    using GPTJEltType = __nv_fp8x2_e4m3;
};

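// Thread mapping used by loadPagedKVCacheForMLAKernel: each head (kLoraSize + kRopeSize elements) is split into
// kVecPerHead 16-byte vectors handled by kVecPerHead threads, so each thread touches the KV cache exactly once;
// one thread block covers kTokenPerBlock tokens.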
template <typename T>
struct loadPagedKVKernelTraits
{
    static constexpr int kLoraSize = 512;
    static constexpr int kRopeSize = 64;
    static constexpr int kHeadSize = kLoraSize + kRopeSize;
    using VecT = typename VecType<T>::Type;
    static constexpr int kBytesPerElem = sizeof(T);
    static constexpr int kBytesPerLoad = 16;
    static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
    static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
        "kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
    static constexpr int kVecPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
    static constexpr int kThreadPerHead = kVecPerHead; // for each head, we use kThreadPerHead threads to fetch all the
                                                       // kv cache data; each thread reads the kv cache only once.
    static constexpr int kTokenPerBlock
        = std::is_same_v<T, float> ? 4 : 8; // for each block, we fetch 4 tokens for fp32, 8 tokens for other types.
    static constexpr int kBlockSize = kThreadPerHead * kTokenPerBlock;
    static constexpr int kKVThreadPerHead = (kLoraSize * kBytesPerElem) / kBytesPerLoad;
};

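// Thread mapping used by setPagedKVCacheForMLAKernel: kThreadPerHead threads cover one (nope + rope) head; threads
// whose vector falls inside the first kQKNopeSize elements copy K/V, the remaining threads copy k_pe; one thread
// block handles kCpTokenPerBlock tokens.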
template <typename T>
struct setPagedKVKernelTraits
{
    static constexpr int kQKNopeSize = 128;
    static constexpr int kVHeadSize = 128;
    static_assert(kQKNopeSize == kVHeadSize);
    static constexpr int kRopeSize = 64;
    static constexpr int kHeadSize = kQKNopeSize + kRopeSize;
    using VecT = typename VecType<T>::Type;
    static constexpr int kBytesPerElem = sizeof(T);
    static constexpr int kBytesPerLoad = 16;
    static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
    static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
        "kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
    static constexpr int kNumHeads = 128;
    static constexpr int kThreadPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
    static constexpr int kKVThreadPerHead = (kQKNopeSize * kBytesPerElem) / kBytesPerLoad;
    static constexpr int kCpTokenPerBlock = 16;
    static constexpr int kBlockSize = kThreadPerHead * kCpTokenPerBlock;
};

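// Converts NUM elements of SrcType to FP8 (e4m3), scaling each value by scale_val, and writes the result to global
// memory with vectorized stores (8 bytes per store for 16-bit sources, 4 bytes for float).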
template <typename SrcType, int NUM>
inline __device__ void quantCopy(
    __nv_fp8_e4m3* dst_global_ptr, SrcType const* src_fragment_ptr, float const scale_val = 1.f)
{
    using DstVecType = typename std::conditional<sizeof(SrcType) == 2, float2, float>::type;
    using SrcType2 =
        typename std::conditional<sizeof(SrcType) == 2, typename TypeConverter<SrcType>::Type, float2>::type;
    static constexpr int COPY_SIZE = sizeof(DstVecType);
    static constexpr int TOTAL_COPY_SIZE = NUM * sizeof(__nv_fp8_e4m3);
    static constexpr int LOOP_NUM = TOTAL_COPY_SIZE / COPY_SIZE;
    static_assert(TOTAL_COPY_SIZE % COPY_SIZE == 0);
    static constexpr int CVT_NUM = COPY_SIZE / sizeof(__nv_fp8_e4m3) / 2;
    static_assert(COPY_SIZE % (sizeof(__nv_fp8_e4m3) * 2) == 0);
    DstVecType fragment;
    int offset = 0;
#pragma unroll
    for (int i = 0; i < LOOP_NUM; ++i)
    {
#pragma unroll
        for (int j = 0; j < CVT_NUM; ++j)
        {
            float2 val2 = cuda_cast<float2>(reinterpret_cast<SrcType2 const*>(src_fragment_ptr)[j + offset]);
            val2.x *= scale_val;
            val2.y *= scale_val;
            reinterpret_cast<__nv_fp8x2_e4m3*>(&fragment)[j] = __nv_fp8x2_e4m3(val2);
        }
        reinterpret_cast<DstVecType*>(dst_global_ptr)[i] = fragment;
        offset += CVT_NUM;
    }
}

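// Inverse of quantCopy: converts NUM FP8 (e4m3) elements back to DstType, scaling each value by scale_val, and
// writes the result to global memory with vectorized stores.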
template <typename DstType, int NUM>
inline __device__ void dequantCopy(
    DstType* dst_global_ptr, __nv_fp8_e4m3 const* src_fragment_ptr, float const scale_val = 1.f)
{
    using DstVecType = typename VecType<DstType>::Type;
    using DstType2 =
        typename std::conditional<sizeof(DstType) == 2, typename TypeConverter<DstType>::Type, float2>::type;
    static constexpr int COPY_SIZE = sizeof(DstVecType);
    static constexpr int TOTAL_COPY_SIZE = NUM * sizeof(DstType);
    static constexpr int LOOP_NUM = TOTAL_COPY_SIZE / COPY_SIZE;
    static_assert(TOTAL_COPY_SIZE % COPY_SIZE == 0);
    static constexpr int CVT_NUM = COPY_SIZE / sizeof(DstType) / 2;
    static_assert(COPY_SIZE % (sizeof(DstType) * 2) == 0);
    DstVecType fragment;
    int offset = 0;
#pragma unroll
    for (int i = 0; i < LOOP_NUM; ++i)
    {
#pragma unroll
        for (int j = 0; j < CVT_NUM; ++j)
        {
            float2 val2 = cuda_cast<float2>(reinterpret_cast<__nv_fp8x2_e4m3 const*>(src_fragment_ptr)[j + offset]);
            val2.x *= scale_val;
            val2.y *= scale_val;
            reinterpret_cast<DstType2*>(&fragment)[j] = cuda_cast<DstType2>(val2);
        }
        reinterpret_cast<DstVecType*>(dst_global_ptr)[i] = fragment;
        offset += CVT_NUM;
    }
}

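// Context (prefill) kernel for MLA: applies GPT-J style RoPE to the rope part of Q and of the latent cache, appends
// the latent entries (compressed KV + rotated k_pe) to the paged KV cache, and writes the rotated Q/K back into the
// fused qkv_output buffer. Grid mapping: blockIdx.x tiles the sequence, blockIdx.y is the batch index, and
// blockIdx.z enumerates heads plus extra blocks that copy the compressed KV. The FP8 BMM1/BMM2 scale factors used by
// the FMHA are also written here.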
template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheBuffer>
__global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const* fuse_buf, KVCacheBuffer kv_cache,
    float2 const* cos_sin_cache, size_t head_num, int head_size, int c_k, int* cu_q_seqlens,
    int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type, float* bmm1_scale,
    float* bmm2_scale, float const* quant_scale_o, float const* quant_scale_kv, float const* dequant_scale_q,
    float const* dequant_scale_kv, float host_bmm1_scale)
{

    // Constants.
    using VecT = typename VecType<T>::Type;
    using GPTJEltT = typename VecType<T>::GPTJEltType;
    constexpr auto HEAD_SIZE = ROPE_DIM;
    constexpr auto K_HEAD_SIZE = K_DIM;
    constexpr auto BYTES_PER_ELT = sizeof(T);
    constexpr auto BYTES_PER_LOAD = 16;
    constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
    static_assert((HEAD_SIZE * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "Head size needs to be multiple of 16 bytes.");
    constexpr auto VECS_PER_HEAD = HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    constexpr auto K_VECS_PER_HEAD = K_HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
    constexpr auto TOKENS_PER_BLOCK = BLOCK_SIZE / VECS_PER_HEAD;
    constexpr auto K_TOKENS_PER_BLOCK = BLOCK_SIZE / K_VECS_PER_HEAD;
    constexpr auto TOTAL_VECS_PER_HEAD = VECS_PER_HEAD + K_VECS_PER_HEAD;

    // Block/Head idx.
    size_t const batch_idx = blockIdx.y;
    size_t const head_idx = blockIdx.z;

    if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
    {

        // Calculate bmm scale for FP8 MLA
        if (cache_type == KvCacheDataType::FP8)
        {
            float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
            float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
            float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
            if (bmm1_scale)
            {
                // The scale prepared for log2 optimization.
                constexpr float kLog2e = 1.4426950408889634074f;
                // The scale after fmha bmm1.
                float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
                bmm1_scale[0] = bmm1_scale_val;
                bmm1_scale[1] = bmm1_scale_val * kLog2e;
            }
            if (bmm2_scale)
            {
                // The scale after fmha bmm2.
                bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
            }
        }
    }

    if (head_idx < head_num)
    {
        size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;

        size_t const seq_len_loop_end
            = size_t((max_input_seq_len + TOKENS_PER_BLOCK - 1) / TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK;
        float quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.f;

        // Mainloop.
        for (int local_token_idx = (threadIdx.x / VECS_PER_HEAD) + blockIdx.x * TOKENS_PER_BLOCK;
             local_token_idx < seq_len_loop_end; local_token_idx += TOKENS_PER_BLOCK * gridDim.x)
        {

            int const global_token_offset = cu_q_seqlens[batch_idx];
            int const cache_seq_len = kv_cache_lengths[batch_idx];
            int token_idx_in_kv_cache = local_token_idx;
            bool const valid_token = token_idx_in_kv_cache < cache_seq_len;
            // Limit the token_idx to cache seq length (we need all threads in this block to be involved).
            token_idx_in_kv_cache = std::min(token_idx_in_kv_cache, cache_seq_len - 1);
            local_token_idx = std::min(local_token_idx, cache_seq_len - 1);
            int const global_token_idx = local_token_idx + global_token_offset;

            auto const position_id = local_token_idx;
            float2 const* rotary_coef_cache_buffer
                = cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);

            VecT q, k;
            auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM) + c_k;
            auto const src_q_global_offset
                = static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
                + (head_size + ROPE_DIM) * head_idx + head_size;

            q = *reinterpret_cast<VecT const*>(&qkv_output[src_q_global_offset + head_dim_idx]);
            k = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);

            // Pack two elements into one for gptj rotary embedding.
#pragma unroll
            for (int elt_id = 0; elt_id < ELTS_PER_VEC / 2; elt_id++)
            {
                GPTJEltT& q_ = reinterpret_cast<GPTJEltT*>(&q)[elt_id];
                GPTJEltT& k_ = reinterpret_cast<GPTJEltT*>(&k)[elt_id];

                float2 rotary_coef_cache = rotary_coef_cache_buffer[elt_id];
                mmha::apply_rotary_embedding_gptj(q_, k_, rotary_coef_cache);
            }
            // do sync
            __syncwarp();
            if (valid_token)
            {
                if (head_idx == 0)
                {
                    auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
                    auto inBlockIdx = kv_cache.getKVLocalIdx(
                        token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
                    if (cache_type == KvCacheDataType::FP8)
                    {
                        quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                            reinterpret_cast<T const*>(&k), quant_scale_kv_val);
                    }
                    else
                        reinterpret_cast<VecT*>(kDst)[inBlockIdx] = k;
                }
                auto const dst_q_idx
                    = static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
                    + head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
                auto const dst_k_idx
                    = static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
                    + head_num * (head_size + ROPE_DIM) + head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
                reinterpret_cast<VecT*>(qkv_output)[dst_q_idx / ELTS_PER_VEC] = q;
                reinterpret_cast<VecT*>(qkv_output)[dst_k_idx / ELTS_PER_VEC] = k;
            }
        }
    }
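    // The extra grid.z blocks (head_idx >= head_num) copy the compressed KV part of the latent cache (the first c_k
    // elements per token) into the paged KV cache, quantizing to FP8 when the cache is FP8.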
    else
    {
        int block_dim = gridDim.z - head_num;
        int block_id = head_idx - head_num;
        size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;

        size_t const seq_len_loop_end
            = size_t((max_input_seq_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
        float quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.f;

        // Mainloop.
        for (int local_token_idx = (threadIdx.x / K_VECS_PER_HEAD) + gridDim.x * K_TOKENS_PER_BLOCK * block_id
                 + blockIdx.x * K_TOKENS_PER_BLOCK;
             local_token_idx < seq_len_loop_end; local_token_idx += block_dim * K_TOKENS_PER_BLOCK * gridDim.x)
        {

            int const global_token_offset = cu_q_seqlens[batch_idx];
            int const cache_seq_len = kv_cache_lengths[batch_idx];
            int token_idx_in_kv_cache = local_token_idx;
            bool const valid_token = token_idx_in_kv_cache < cache_seq_len;
            // Limit the token_idx to cache seq length (we need all threads in this block to be involved).
            token_idx_in_kv_cache = std::min(token_idx_in_kv_cache, cache_seq_len - 1);
            local_token_idx = std::min(local_token_idx, cache_seq_len - 1);
            int const global_token_idx = local_token_idx + global_token_offset;

            if (valid_token)
            {
                auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM);

                auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
                auto inBlockIdx
                    = kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, head_dim_vec_idx);
                if (cache_type == KvCacheDataType::FP8)
                {
                    quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                        fuse_buf + src_k_global_offset + head_dim_idx, quant_scale_kv_val);
                }
                else
                    reinterpret_cast<VecT*>(kDst)[inBlockIdx]
                        = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
            }
        }
    }
}

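// Generation (decode) kernel for MLA: applies RoPE to q_pe and to the k_pe part of the latent cache for the newly
// generated tokens, appends the latent entries to the paged KV cache, writes Q (optionally FP8-quantized) to the
// output buffer, and prepares FMHA metadata (fmha_tile_counter, seqQOffset, seqKVOffsets). On SM90+ the
// griddepcontrol instructions allow overlap with the preceding kernel via programmatic dependent launch.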
template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheBuffer>
__global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe, T const* fuse_buf, void* quant_q,
    KVCacheBuffer kv_cache, float2 const* cos_sin_cache, size_t head_num, int c_k, int total_s_len, int seq_len,
    int* seqQOffset, uint32_t* fmha_tile_counter, int32_t const* kv_cache_lengths, int* seqKVOffsets, int q_pe_ld,
    int q_pe_stride, KvCacheDataType cache_type, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
    float const* quant_scale_q, float const* quant_scale_kv, float const* dequant_scale_q,
    float const* dequant_scale_kv, float host_bmm1_scale)
{

    // Constants.
    using VecT = typename VecType<T>::Type;
    using GPTJEltT = typename VecType<T>::GPTJEltType;
    constexpr auto HEAD_SIZE = ROPE_DIM;
    constexpr auto K_HEAD_SIZE = K_DIM;
    constexpr auto BYTES_PER_ELT = sizeof(T);
    constexpr auto BYTES_PER_LOAD = 16;
    constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
    static_assert((HEAD_SIZE * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "Head size needs to be multiple of 16 bytes.");
    constexpr auto VECS_PER_HEAD = HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    constexpr auto K_VECS_PER_HEAD = K_HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
    constexpr auto TOKENS_PER_BLOCK = BLOCK_SIZE / VECS_PER_HEAD;
    constexpr auto K_TOKENS_PER_BLOCK = BLOCK_SIZE / K_VECS_PER_HEAD;
    constexpr auto TOTAL_VEC_PER_HEAD = VECS_PER_HEAD + K_VECS_PER_HEAD;

    // Block/Head idx.
    size_t const head_idx = blockIdx.y;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
#endif

    if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
    {
        fmha_tile_counter[0] = 0;
        seqQOffset[0] = 0;

        // Calculate bmm scale for FP8 MLA
        if (cache_type == KvCacheDataType::FP8)
        {
            float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
            float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
            float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
            if (bmm1_scale)
            {
                // The scale prepared for log2 optimization.
                constexpr float kLog2e = 1.4426950408889634074f;
                // The scale after fmha bmm1.
                float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
                bmm1_scale[0] = bmm1_scale_val;
                bmm1_scale[1] = bmm1_scale_val * kLog2e;
            }
            if (bmm2_scale)
            {
                // The scale after fmha bmm2.
                bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
            }
        }
    }

    if (head_idx <= head_num)
    {
        size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;

        int const seq_len_loop_end = size_t((total_s_len + TOKENS_PER_BLOCK - 1) / TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK;
        float const quant_scale_q_val = quant_scale_q ? quant_scale_q[0] : 1.0f;
        float const quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.0f;

        // Mainloop.
        for (int global_token_idx = (threadIdx.x / VECS_PER_HEAD) + blockIdx.x * TOKENS_PER_BLOCK;
             global_token_idx < seq_len_loop_end; global_token_idx += TOKENS_PER_BLOCK * gridDim.x)
        {
            auto batch_idx = global_token_idx / seq_len;
            auto local_token_idx = global_token_idx % seq_len;
            bool const valid_token = global_token_idx < total_s_len;
            VecT data;

            if (valid_token)
            {

                auto const position_id = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
                float2 const* rotary_coef_cache_buffer
                    = cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);

                if (head_idx == head_num)
                {
                    auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM) + c_k;

                    data = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
                }
                else
                {
                    auto const src_q_global_offset
                        = static_cast<size_t>(global_token_idx) * q_pe_stride + q_pe_ld * head_idx;

                    data = *reinterpret_cast<VecT const*>(&q_pe[src_q_global_offset + head_dim_idx]);
                }

                // Pack two elements into one for gptj rotary embedding.
#pragma unroll
                for (int elt_id = 0; elt_id < ELTS_PER_VEC / 2; elt_id++)
                {
                    GPTJEltT& data_ = reinterpret_cast<GPTJEltT*>(&data)[elt_id];

                    float2 rotary_coef_cache = rotary_coef_cache_buffer[elt_id];
                    data_ = mmha::rotary_embedding_transform(data_, rotary_coef_cache);
                }
            }

            __syncwarp();

            if (valid_token)
            {
                if (head_idx == head_num)
                {
                    auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;

                    {
                        auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
                        auto inBlockIdx = kv_cache.getKVLocalIdx(
                            token_kv_idx, 0, TOTAL_VEC_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
                        if (cache_type == KvCacheDataType::FP8)
                        {
                            quantCopy<T, ELTS_PER_VEC>(
                                reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                                reinterpret_cast<T const*>(&data), quant_scale_kv_val);
                        }
                        else
                            reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
                    }
                }
                else
                {
                    auto const dst_q_idx = static_cast<size_t>(global_token_idx) * head_num * (c_k + ROPE_DIM)
                        + head_idx * (c_k + ROPE_DIM) + c_k + head_dim_idx;
                    if (cache_type == KvCacheDataType::FP8)
                    {
                        quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(quant_q) + dst_q_idx,
                            reinterpret_cast<T const*>(&data), quant_scale_q_val);
                    }
                    else
                        reinterpret_cast<VecT*>(qkv_output)[dst_q_idx / ELTS_PER_VEC] = data;
                }
            }
        }
    }
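    // blockIdx.y values in (head_num, head_num + 8] append the compressed KV part of the latent cache (the first c_k
    // elements per token) to the paged KV cache and fill in seqQOffset for the FMHA.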
    else if (head_idx <= head_num + 8)
    {
        int block_dim = gridDim.y - head_num - 1;
        int block_id = head_idx - head_num - 1;
        size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;

        size_t const seq_len_loop_end
            = size_t((total_s_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
        float quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.0f;

        // Mainloop.
        for (int global_token_idx = (threadIdx.x / K_VECS_PER_HEAD) + gridDim.x * K_TOKENS_PER_BLOCK * block_id
                 + blockIdx.x * K_TOKENS_PER_BLOCK;
             global_token_idx < seq_len_loop_end; global_token_idx += block_dim * K_TOKENS_PER_BLOCK * gridDim.x)
        {
            auto batch_idx = global_token_idx / seq_len;
            auto local_token_idx = global_token_idx % seq_len;
            bool valid_token = global_token_idx < total_s_len;

            if (valid_token)
            {
                if (head_dim_vec_idx == 0)
                {
                    seqQOffset[batch_idx + 1] = head_num * seq_len * (batch_idx + 1);
                }

                auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
                auto const src_kv_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM);

                {
                    auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
                    auto inBlockIdx = kv_cache.getKVLocalIdx(token_kv_idx, 0, TOTAL_VEC_PER_HEAD, head_dim_vec_idx);

                    if (cache_type == KvCacheDataType::FP8)
                    {
                        quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                            fuse_buf + src_kv_global_offset + head_dim_idx, quant_scale_kv_val);
                    }
                    else
                        reinterpret_cast<VecT*>(kDst)[inBlockIdx]
                            = *reinterpret_cast<VecT const*>(&fuse_buf[src_kv_global_offset + head_dim_idx]);
                }
            }
        }
    }
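    // FP8 path only: the remaining blockIdx.y values quantize the first c_k (nope) elements of every Q head from
    // qkv_output into quant_q; the RoPE tail of each head was already written to quant_q by the first branch.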
    else
    {
        if (cache_type == KvCacheDataType::FP8)
        {
            int block_dim = gridDim.y - head_num - 1 - 8;
            int block_id = head_idx - head_num - 1 - 8;
            size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
            size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
            size_t const head_num_idx = (block_id % head_num) * (K_HEAD_SIZE + HEAD_SIZE);

            size_t const seq_len_loop_end
                = size_t((total_s_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
            float quant_scale_q_val = quant_scale_q ? quant_scale_q[0] : 1.0f;

            // Mainloop.
            for (int global_token_idx = (threadIdx.x / K_VECS_PER_HEAD)
                     + (block_id / head_num) * gridDim.x * K_TOKENS_PER_BLOCK + blockIdx.x * K_TOKENS_PER_BLOCK;
                 global_token_idx < seq_len_loop_end;
                 global_token_idx += (block_dim / head_num) * gridDim.x * K_TOKENS_PER_BLOCK)
            {
                if (global_token_idx < total_s_len)
                {
                    size_t const load_idx
                        = global_token_idx * head_num * (K_HEAD_SIZE + HEAD_SIZE) + head_num_idx + head_dim_idx;
                    quantCopy<T, ELTS_PER_VEC>(
                        reinterpret_cast<__nv_fp8_e4m3*>(quant_q) + load_idx, qkv_output + load_idx, quant_scale_q_val);
                }
            }
        }
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif

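    // Block (0, 0) additionally computes the cumulative KV sequence offsets (seqKVOffsets) with a block-wide
    // exclusive prefix sum over kv_cache_lengths.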
    // The implementation of the parallel scan in the thread block (see CUB for details).
    using BlockScan = cub::BlockScan<int, BLOCK_SIZE>;

    // Allocate storage in shared memory to do the scan.
    __shared__ typename BlockScan::TempStorage tempKVStorage;
    BlockPrefixCallbackOp prefixKVOp(0);

    if (blockIdx.x == 0 && blockIdx.y == 0)
    {
        int const batchSizeBound = total_s_len / seq_len;
        for (int batchOffset = 0; batchOffset <= batchSizeBound; batchOffset += BLOCK_SIZE)
        {
            // The index of the batch.
            int batchIdx = batchOffset + threadIdx.x;
            int seqKVLength = 0;
            if (batchIdx < batchSizeBound)
            {
                seqKVLength = kv_cache_lengths[batchIdx];
            }
            int seqKVOffset;
            BlockScan(tempKVStorage).ExclusiveSum(seqKVLength, seqKVOffset, prefixKVOp);
            if (batchIdx <= batchSizeBound)
            {
                seqKVOffsets[batchIdx] = seqKVOffset;
            }
        }
    }
}

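// Reads the paged KV cache back into contiguous buffers: compressed_kv {total_tokens, lora_size} and
// k_pe {total_tokens, rope_size}, dequantizing from FP8 when TCache is __nv_fp8_e4m3.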
template <typename T, typename TCache>
__global__ void loadPagedKVCacheForMLAKernel(T* compressed_kv_ptr, T* k_pe_ptr,
    tensorrt_llm::kernels::KVBlockArray const kv_cache, int64_t const* cu_ctx_cached_kv_lens, int max_input_seq_len,
    float const* kv_scale_quant_orig_ptr)
{
    static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
        "TCache must be either the same type as T or __nv_fp8_e4m3");
    using KT = typename tensorrt_llm::kernels::loadPagedKVKernelTraits<TCache>;

    int const batch_idx = static_cast<int>(blockIdx.y);
    float const kv_scale_quant_orig = kv_scale_quant_orig_ptr ? kv_scale_quant_orig_ptr[0] : 1.0f;

    size_t const head_dim_vec_idx = (threadIdx.x % KT::kVecPerHead);
    size_t const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad;
    bool const is_valid_kv = head_dim_vec_idx < KT::kKVThreadPerHead;

    size_t const seq_len_loop_end
        = (max_input_seq_len + KT::kTokenPerBlock - 1) / KT::kTokenPerBlock * KT::kTokenPerBlock;

    int64_t const global_token_offset = cu_ctx_cached_kv_lens[batch_idx];
    int64_t const cache_kv_len = cu_ctx_cached_kv_lens[batch_idx + 1] - cu_ctx_cached_kv_lens[batch_idx];

    for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kTokenPerBlock;
         local_token_idx < seq_len_loop_end; local_token_idx += KT::kTokenPerBlock * gridDim.x)
    {
        int token_idx_in_kv_cache = local_token_idx;
        bool const valid_token = token_idx_in_kv_cache < cache_kv_len;

        if (valid_token)
        {
            auto* kvSrc = reinterpret_cast<TCache*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
            // head_idx == 0
            auto kvBlockIdx
                = kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, KT::kVecPerHead, static_cast<int>(head_dim_vec_idx));

            auto src_data = reinterpret_cast<typename KT::VecT*>(kvSrc)[kvBlockIdx];

            int const global_token_idx = local_token_idx + global_token_offset;

            if (is_valid_kv)
            {
                // compressed_kv {total_token, lora_size}
                int const dstIdx = global_token_idx * KT::kLoraSize + head_dim_idx;

                // copy back to compressed_kv
                if constexpr (std::is_same_v<TCache, T>)
                {
                    *reinterpret_cast<typename KT::VecT*>(compressed_kv_ptr + dstIdx) = src_data;
                }
                else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                {
                    dequantCopy<T, KT::kElemPerLoad>(compressed_kv_ptr + dstIdx,
                        reinterpret_cast<__nv_fp8_e4m3 const*>(&src_data), kv_scale_quant_orig);
                }
            }
            else
            {
                // k_pe {total_token, rope_size}
                int const dstIdx = global_token_idx * KT::kRopeSize + (head_dim_idx - KT::kLoraSize);

                // copy back to k_pe
                if constexpr (std::is_same_v<TCache, T>)
                {
                    *reinterpret_cast<typename KT::VecT*>(k_pe_ptr + dstIdx) = src_data;
                }
                else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                {
                    dequantCopy<T, KT::kElemPerLoad>(
                        k_pe_ptr + dstIdx, reinterpret_cast<__nv_fp8_e4m3 const*>(&src_data), kv_scale_quant_orig);
                }
            }
        }
    }
}

// k {total_token, h, d}, v {total_token, h, d}, k_pe {total_token, h=1, d_rope}
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}
template <typename T>
__global__ void setPagedKVCacheForMLAKernel(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr,
    int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
    int kv_cache_tokens_per_block, int64_t kv_token_stride)
{
    using KT = typename tensorrt_llm::kernels::setPagedKVKernelTraits<T>;
    int const batch_idx = static_cast<int>(blockIdx.y);
    int const head_idx = static_cast<int>(blockIdx.z);
    int const head_dim_vec_idx = (threadIdx.x % KT::kThreadPerHead);
    int const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad;
    bool const is_valid_v = head_dim_idx < KT::kVHeadSize;

    size_t const seq_len_loop_end
        = (max_input_seq_len + KT::kCpTokenPerBlock - 1) / KT::kCpTokenPerBlock * KT::kCpTokenPerBlock;
    size_t const kv_cache_block_size = num_heads * kv_cache_tokens_per_block * (kv_dim + rope_dim);
    size_t const kv_cache_block_num = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
    int64_t const global_token_offset = cu_seq_lens[batch_idx];
    int64_t const cache_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx];

    for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kCpTokenPerBlock;
         local_token_idx < seq_len_loop_end; local_token_idx += KT::kCpTokenPerBlock * gridDim.x)
    {
        int token_idx_in_kv_cache = local_token_idx;
        bool const valid_token = token_idx_in_kv_cache < cache_kv_len;
        if (valid_token)
        {
            // copy k and v
            if (is_valid_v)
            {
                int ld_kv_global_offset = (global_token_offset + local_token_idx) * kv_token_stride + head_idx * kv_dim;
                int ld_kv_local_offset = head_dim_vec_idx;
                auto k_data
                    = (reinterpret_cast<typename KT::VecT const*>(k_ptr + ld_kv_global_offset))[ld_kv_local_offset];
                auto v_data
                    = (reinterpret_cast<typename KT::VecT const*>(v_ptr + ld_kv_global_offset))[ld_kv_local_offset];
                // {b, 0, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
                int st_k_global_offset = batch_idx * 2 * kv_cache_block_num * kv_cache_block_size
                    + local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
                    + head_idx * kv_cache_tokens_per_block * (kv_dim + rope_dim)
                    + (local_token_idx % kv_cache_tokens_per_block) * (kv_dim + rope_dim);
                // {b, 1, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
                int st_v_global_offset = st_k_global_offset + kv_cache_block_num * kv_cache_block_size;
                int st_k_local_offset = head_dim_vec_idx;
                int st_v_local_offset = head_dim_vec_idx;
                (reinterpret_cast<typename KT::VecT*>(output + st_k_global_offset))[st_k_local_offset] = k_data;
                (reinterpret_cast<typename KT::VecT*>(output + st_v_global_offset))[st_v_local_offset] = v_data;
            }
            // copy k_pe, only 1 head
            else
            {
                int ld_rope_global_offset = (global_token_offset + local_token_idx) * rope_dim;
                int ld_rope_local_offset = head_dim_vec_idx - KT::kKVThreadPerHead;
                auto rope_data = (reinterpret_cast<typename KT::VecT const*>(
                    k_pe_ptr + ld_rope_global_offset))[ld_rope_local_offset];
                // {b, 0, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
                int st_rope_global_offset = batch_idx * 2 * kv_cache_block_num * kv_cache_block_size
                    + local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
                    + head_idx * kv_cache_tokens_per_block * (kv_dim + rope_dim)
                    + (local_token_idx % kv_cache_tokens_per_block) * (kv_dim + rope_dim);
                int st_rope_local_offset = head_dim_vec_idx;
                (reinterpret_cast<typename KT::VecT*>(output + st_rope_global_offset))[st_rope_local_offset]
                    = rope_data;
            }
        }
        else
        {
            break;
        }
    }
}

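// Applies RoPE to q_pe (per head) and to k_pe (single head) for the uncached tokens, appends the compressed KV and
// the rotated k_pe to the paged KV cache (quantizing when TCache is __nv_fp8_e4m3), writes the rotated k_pe back
// into latent_cache, and writes the rotated q_pe back into q.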
// q {total_uncached_tokens, h, d_nope + d_rope}
// latent_cache {total_uncached_tokens, d_k + d_rope}
template <typename T, typename TCache, int BLOCK_SIZE, int K_DIM, int ROPE_DIM>
__global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* q_ptr, T* latent_cache_ptr,
    int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
    float2 const* cos_sin_cache, size_t head_num, int nope_size, float const* kv_scale_orig_quant_ptr)
{
    static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
        "TCache must be either the same type as T or __nv_fp8_e4m3");
    // Constants.
    using VecT = typename VecType<T>::Type;
    using GPTJEltT = typename VecType<T>::GPTJEltType;
    constexpr auto HEAD_SIZE = ROPE_DIM;
    constexpr auto K_HEAD_SIZE = K_DIM;
    constexpr auto BYTES_PER_ELT = sizeof(T);
    constexpr auto BYTES_PER_LOAD = 16;
    constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
    static_assert((HEAD_SIZE * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "Head size needs to be multiple of 16 bytes.");
    constexpr auto VECS_PER_HEAD = HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    constexpr auto K_VECS_PER_HEAD = K_HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
    constexpr auto TOKENS_PER_BLOCK = BLOCK_SIZE / VECS_PER_HEAD;
    constexpr auto K_TOKENS_PER_BLOCK = BLOCK_SIZE / K_VECS_PER_HEAD;
    constexpr auto TOTAL_VECS_PER_HEAD = VECS_PER_HEAD + K_VECS_PER_HEAD;

    // Block/Head idx.
    size_t const batch_idx = blockIdx.y;
    size_t const head_idx = blockIdx.z;

    int64_t const global_token_offset = cu_seq_lens[batch_idx] - cu_ctx_cached_kv_lens[batch_idx];
    int64_t const cached_kv_len = cu_ctx_cached_kv_lens[batch_idx + 1] - cu_ctx_cached_kv_lens[batch_idx];
    int64_t const uncached_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx] - cached_kv_len;

    if (head_idx <= head_num)
    {
        size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;

        size_t const seq_len_loop_end
            = size_t((max_input_uncached_seq_len + TOKENS_PER_BLOCK - 1) / TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK;
        float quant_scale_kv_val = kv_scale_orig_quant_ptr ? kv_scale_orig_quant_ptr[0] : 1.f;

        // Mainloop.
        for (int local_token_idx = (threadIdx.x / VECS_PER_HEAD) + blockIdx.x * TOKENS_PER_BLOCK;
             local_token_idx < seq_len_loop_end; local_token_idx += TOKENS_PER_BLOCK * gridDim.x)
        {

            int token_idx_in_kv_cache = local_token_idx + cached_kv_len;
            bool valid_token = local_token_idx < uncached_kv_len;
            int const global_token_idx = local_token_idx + global_token_offset;
            VecT data;

            if (valid_token)
            {
                auto const position_id = token_idx_in_kv_cache;
                float2 const* rotary_coef_cache_buffer
                    = cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);

                if (head_idx == head_num)
                {
                    auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM;
                    data = *reinterpret_cast<VecT const*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]);
                }
                else
                {
                    auto const src_q_global_offset
                        = static_cast<size_t>(global_token_idx) * head_num * (nope_size + ROPE_DIM)
                        + (nope_size + ROPE_DIM) * head_idx + nope_size;
                    data = *reinterpret_cast<VecT const*>(&q_ptr[src_q_global_offset + head_dim_idx]);
                }

                // Pack two elements into one for gptj rotary embedding.
#pragma unroll
                for (int elt_id = 0; elt_id < ELTS_PER_VEC / 2; elt_id++)
                {
                    GPTJEltT& data_ = reinterpret_cast<GPTJEltT*>(&data)[elt_id];

                    float2 rotary_coef_cache = rotary_coef_cache_buffer[elt_id];
                    data_ = mmha::rotary_embedding_transform(data_, rotary_coef_cache);
                }
            }
            // do sync
            __syncwarp();
            if (valid_token)
            {
                if (head_idx == head_num)
                {
                    auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
                    auto inBlockIdx = kv_cache.getKVLocalIdx(
                        token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
                    if constexpr (std::is_same_v<TCache, T>)
                    {
                        reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
                    }
                    else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                    {
                        quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                            reinterpret_cast<T const*>(&data), quant_scale_kv_val);
                    }
                    // copy back to latent_cache (for chunked prefill, the kv cache is not reloaded for the uncached
                    // k_pe, so keep the original, unquantized value here).
                    auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM;
                    *reinterpret_cast<VecT*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]) = data;
                }
                else
                {
                    auto const dst_q_idx = static_cast<size_t>(global_token_idx) * head_num * (nope_size + ROPE_DIM)
                        + head_idx * (nope_size + ROPE_DIM) + nope_size + head_dim_idx;
                    reinterpret_cast<VecT*>(q_ptr)[dst_q_idx / ELTS_PER_VEC] = data;
                }
            }
        }
    }
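    // Blocks with head_idx > head_num append the compressed KV part of latent_cache (the first K_DIM elements per
    // uncached token) to the paged KV cache, quantizing when TCache is __nv_fp8_e4m3.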
    else
    {
        int block_dim = gridDim.z - head_num - 1;
        int block_id = head_idx - head_num - 1;
        size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;

        size_t const seq_len_loop_end
            = size_t((max_input_uncached_seq_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
        float quant_scale_kv_val = kv_scale_orig_quant_ptr ? kv_scale_orig_quant_ptr[0] : 1.f;

        // Mainloop.
        for (int local_token_idx = (threadIdx.x / K_VECS_PER_HEAD) + gridDim.x * K_TOKENS_PER_BLOCK * block_id
                 + blockIdx.x * K_TOKENS_PER_BLOCK;
             local_token_idx < seq_len_loop_end; local_token_idx += block_dim * K_TOKENS_PER_BLOCK * gridDim.x)
        {

            int token_idx_in_kv_cache = local_token_idx + cached_kv_len;
            bool valid_token = local_token_idx < uncached_kv_len;
            int const global_token_idx = local_token_idx + global_token_offset;

            if (valid_token)
            {
                auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM);

                auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
                auto inBlockIdx
                    = kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, head_dim_vec_idx);
                if constexpr (std::is_same_v<TCache, T>)
                {
                    reinterpret_cast<VecT*>(kDst)[inBlockIdx]
                        = *reinterpret_cast<VecT const*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]);
                }
                else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                {
                    quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                        latent_cache_ptr + src_k_global_offset + head_dim_idx, quant_scale_kv_val);
                }
            }
        }
    }
}

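// Launcher for the context (prefill) kernel. Grid: (sequence tiles, batch_size, head_num + 8) with 256 threads per
// block; the 8 extra grid.z blocks handle the compressed-KV copy. When the KV cache is FP8, the fused QKV buffer is
// additionally quantized into quant_attention_input_buf.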
template <typename T, typename KVCacheBuffer>
void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream)
{
    dim3 grid(int(tensorrt_llm::common::divUp(params.max_input_seq_len, 32)), params.batch_size, params.head_num + 8);
    auto head_size = params.meta.qk_nope_head_dim;
    applyMLARopeAndAssignQKVKernelOptContext<T, 256, 512, 64, KVCacheBuffer><<<grid, 256, 0, stream>>>(
        params.attention_input_buf, params.latent_cache, kv_cache_buffer, params.cos_sin_cache, params.head_num,
        head_size, params.meta.kv_lora_rank, params.cu_q_seqlens, params.cache_seq_lens, params.max_input_seq_len,
        params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_kv,
        params.dequant_scale_q, params.dequant_scale_kv, params.host_bmm1_scale);
    if (params.attention_input_buf != nullptr && params.quant_attention_input_buf != nullptr
        && params.cache_type == KvCacheDataType::FP8)
    {
        TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing attention_input_buf to FP8");

        int const dim_q_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
        int const dim_k_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
        int const dim_v_per_head = (params.meta.v_head_dim);

        // Total dimension per token across all heads for Q, K, and V components respectively
        int const total_q_dim_all_heads = params.head_num * dim_q_per_head;
        int const total_k_dim_all_heads
            = params.head_num * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
        int const total_v_dim_all_heads
            = params.head_num * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout

        int const num_total_qkv_elements
            = params.acc_q_len * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
        size_t headDim = params.meta.kv_lora_rank + params.meta.qk_rope_head_dim;
        float const* device_qkv_scale_ptr = params.quant_scale_qkv;

        if (num_total_qkv_elements > 0)
        {
            int const threads_per_block = 256;
            int const num_blocks = (num_total_qkv_elements + threads_per_block - 1) / threads_per_block;

            TLLM_LOG_DEBUG(
                "Launching QuantizeCopyInputToFp8Kernel with num_blocks: %d, threads_per_block: %d, elements: %d",
                num_blocks, threads_per_block, num_total_qkv_elements);

            tensorrt_llm::kernels::QuantizeCopyInputToFp8Kernel<T><<<num_blocks, threads_per_block, 0, stream>>>(
                static_cast<T const*>(params.attention_input_buf),           // Source
                static_cast<__nv_fp8_e4m3*>(params.quant_attention_input_buf), // Destination
                num_total_qkv_elements, device_qkv_scale_ptr);
            sync_check_cuda_error(stream);

            cudaStreamSynchronize(stream);
        }
        else
        {
            TLLM_LOG_WARNING("MLA RoPE Context: num_total_qkv_elements is 0, skipping quantization.");
        }
    }
}

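// Launcher for the generation (decode) kernel. Uses cudaLaunchKernelEx with the programmatic stream serialization
// attribute so the kernel can overlap with its producer (PDL) when enabled via the environment; when the cache is
// FP8, grid.y is extended by head_num * 8 blocks for Q quantization.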
template <typename T, typename KVCacheBuffer>
void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream)
{
    dim3 grid(int(tensorrt_llm::common::divUp(params.acc_q_len, 32)), params.head_num + 1 + 8);
    if (params.cache_type == KvCacheDataType::FP8)
        grid.y += params.head_num * 8;
    TLLM_CHECK_WITH_INFO(params.acc_q_len % params.batch_size == 0,
        "MLA can only support input sequences with the same sequence length.");
    auto seq_len = params.acc_q_len / params.batch_size;

    auto* kernel_instance = &applyMLARopeAndAssignQKVKernelGeneration<T, 256, 512, 64, KVCacheBuffer>;
    cudaLaunchConfig_t config;
    config.gridDim = grid;
    config.blockDim = 256;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaLaunchKernelEx(&config, kernel_instance, params.attention_input_buf, params.q_pe, params.latent_cache,
        params.quant_attention_input_buf, kv_cache_buffer, params.cos_sin_cache, params.head_num,
        params.meta.kv_lora_rank, params.acc_q_len, seq_len, params.seqQOffset, params.fmha_tile_counter,
        params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld, params.q_pe_stride, params.cache_type,
        params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_q, params.quant_scale_kv,
        params.dequant_scale_q, params.dequant_scale_kv, params.host_bmm1_scale);
}

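// Launcher for loadPagedKVCacheForMLAKernel: validates the lora/rope sizes against the kernel traits and launches
// one block per (token tile, context request).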
template <typename T, typename TCache>
void invokeMLALoadPagedKV(T* compressed_kv_ptr, T* k_pe_ptr, KVBlockArray& kv_cache, int const num_contexts,
    int64_t const* cu_ctx_cached_kv_lens, int const max_input_seq_len, int const lora_size, int const rope_size,
    float const* kv_scale_quant_orig_ptr, cudaStream_t stream)
{
    using KT = typename tensorrt_llm::kernels::loadPagedKVKernelTraits<TCache>;
    // {seq_len / token_per_block, batch_size, head_num}
    TLLM_CHECK_WITH_INFO(lora_size == KT::kLoraSize, "lora_size should be equal to %d", KT::kLoraSize);
    TLLM_CHECK_WITH_INFO(rope_size == KT::kRopeSize, "rope_size should be equal to %d", KT::kRopeSize);
    TLLM_CHECK_WITH_INFO(lora_size + rope_size == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize);
    dim3 grid(static_cast<int>(tensorrt_llm::common::divUp(max_input_seq_len, KT::kTokenPerBlock)), num_contexts, 1);
    loadPagedKVCacheForMLAKernel<T, TCache><<<grid, KT::kBlockSize, 0, stream>>>(
        compressed_kv_ptr, k_pe_ptr, kv_cache, cu_ctx_cached_kv_lens, max_input_seq_len, kv_scale_quant_orig_ptr);
}

template <typename T>
void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr, int const num_requests,
    int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
    int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream)
{
    using KT = typename tensorrt_llm::kernels::setPagedKVKernelTraits<T>;
    TLLM_CHECK_WITH_INFO(kv_dim + rope_dim == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize);
    TLLM_CHECK_WITH_INFO(kv_cache_tokens_per_block % KT::kCpTokenPerBlock == 0,
        "kv_cache_tokens_per_block should be multiple of %d", KT::kCpTokenPerBlock);
    dim3 grid(tensorrt_llm::common::divUp(max_input_seq_len, KT::kCpTokenPerBlock), num_requests, num_heads);
    setPagedKVCacheForMLAKernel<T><<<grid, KT::kBlockSize, 0, stream>>>(output, k_ptr, v_ptr, k_pe_ptr, cu_seq_lens,
        max_input_seq_len, num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, kv_token_stride);
}

template <typename T, typename TCache>
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests,
    int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
    float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
    float const* kv_scale_orig_quant_ptr, cudaStream_t stream)
{
    dim3 grid(int(tensorrt_llm::common::divUp(max_input_uncached_seq_len, 32)), num_requests, head_num + 1 + 8);
    TLLM_CHECK_WITH_INFO(lora_size == 512, "lora_size should be equal to %d", 512);
    TLLM_CHECK_WITH_INFO(rope_size == 64, "rope_size should be equal to %d", 64);
    applyMLARopeAppendPagedKVAssignQKernel<T, TCache, 256, 512, 64><<<grid, 256, 0, stream>>>(kv_cache, q_ptr,
        latent_cache_ptr, cu_ctx_cached_kv_lens, cu_seq_lens, max_input_uncached_seq_len, cos_sin_cache, head_num,
        nope_size, kv_scale_orig_quant_ptr);
}

#define INSTANTIATE_MLA_ROPE(T, KVCacheBuffer)                                                                         \
    template void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream);     \
    template void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream);

INSTANTIATE_MLA_ROPE(float, KVBlockArray);
INSTANTIATE_MLA_ROPE(half, KVBlockArray);
INSTANTIATE_MLA_ROPE(float, KVLinearBuffer);
INSTANTIATE_MLA_ROPE(half, KVLinearBuffer);

#ifdef ENABLE_BF16
INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVBlockArray);
INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVLinearBuffer);
#endif

#define INSTANTIATE_RW_KVCACHE_MLA(T, TCache)                                                                          \
    template void invokeMLALoadPagedKV<T, TCache>(T * compressed_kv_ptr, T * k_pe_ptr, KVBlockArray & kv_cache,       \
        int const num_contexts, int64_t const* cu_ctx_cached_kv_lens, int const max_input_seq_len,                    \
        int const lora_size, int const rope_size, float const* kv_scale_quant_orig_ptr, cudaStream_t stream);         \
    template void invokeMLARopeAppendPagedKVAssignQ<T, TCache>(KVBlockArray & kv_cache, T * q_ptr,                    \
        T * latent_cache_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens,                           \
        int64_t const* cu_seq_lens, int const max_input_uncached_seq_len, float2 const* cos_sin_cache,                \
        size_t head_num, int nope_size, int rope_size, int lora_size, float const* kv_scale_orig_quant_ptr,           \
        cudaStream_t stream);

INSTANTIATE_RW_KVCACHE_MLA(float, float);
INSTANTIATE_RW_KVCACHE_MLA(float, __nv_fp8_e4m3);
INSTANTIATE_RW_KVCACHE_MLA(half, half);
INSTANTIATE_RW_KVCACHE_MLA(half, __nv_fp8_e4m3);
INSTANTIATE_RW_KVCACHE_MLA(__nv_bfloat16, __nv_bfloat16);
INSTANTIATE_RW_KVCACHE_MLA(__nv_bfloat16, __nv_fp8_e4m3);

#define INSTANTIATE_SET_KVCACHE_MLA(T)                                                                                 \
    template void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr,                   \
        int const num_requests, int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim,   \
        int rope_dim, int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream);

INSTANTIATE_SET_KVCACHE_MLA(float);
INSTANTIATE_SET_KVCACHE_MLA(half);
INSTANTIATE_SET_KVCACHE_MLA(__nv_bfloat16);

template <typename T_IN>
__global__ void QuantizeCopyInputToFp8Kernel(
    T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr)
{
    uint element_idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (element_idx < num_total_elements)
    {
        float scale_factor = (device_scale_ptr != nullptr) ? *device_scale_ptr : 1.0f;
        output_fp8_buffer[element_idx] = __nv_fp8_e4m3(static_cast<float>(input_buffer[element_idx]) * scale_factor);
    }
}

} // namespace kernels

} // namespace tensorrt_llm