/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <float.h>
#include <fmha/hopper/tma_types.h>
#include <fmha/paged_kv_cache.h>
#include <fstream>
#include <fused_multihead_attention_api.h>
#include <iostream>
#include <math.h>
#include <numeric>
#include <string>
#include <vector>

using Launch_params = bert::Fused_multihead_attention_launch_params;
using Attention_mask_type = fmha::Attention_mask_type;
using Attention_input_layout = fmha::Attention_input_layout;
using Kv_block_array = fmha::Kv_block_array;

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
    void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
    bool has_alibi);

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
    void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
    int warps_n, bool has_alibi);

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
    void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
    bool has_alibi);

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
    void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
    bool has_alibi);

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
    void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i,
    float softcapping_scale_bmm1, int warps_n, bool has_alibi);

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale);

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_conversion_fp32_to_fp16(void* dst, void const* src, int s, int b, int h, int d);

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_conversion_fp32_to_bf16(void* dst, void const* src, int s, int b, int h, int d);

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d, float scale_o);

////////////////////////////////////////////////////////////////////////////////////////////////////

void run_sage_quant(unsigned int batch_size, unsigned int head_num, unsigned int head_size, unsigned int max_seq_len,
    // device var
    void const* q, void const* k, void const* v, int stride_q, int stride_k, int stride_v, int const* cu_seqlens_q,
    int const* cu_seqlens_kv, int block_size_q, int block_size_k, int block_size_v,
    // output
    void* quant_q, void* quant_k, void* quant_v, float* scales_q, float* scales_k, float* scales_v);
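
// (SageAttention-style per-block quantization: quant_{q,k,v} receive the quantized Q/K/V tensors
// and scales_{q,k,v} the corresponding per-block dequantization scales, presumably one scale per
// [head, block] tile of size block_size_{q,k,v} along the sequence dimension.)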

////////////////////////////////////////////////////////////////////////////////////////////////////

void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type,
    float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
    void* qkv_d, void* vt_d, void* mask_d, void* attention_sinks_d, void* p_d, void* s_d, void* tmp_d, void* o_d,
    void* softmax_sum_d, void* cu_q_seqlens_d, const size_t b, const size_t s, const size_t h, const size_t d,
    const size_t dv, int const runs, int const warps_m, int const warps_n, bool const has_alibi)
{

    cudaStream_t stream = 0;
    // The stride between rows of the QKV matrix.
    size_t qkv_stride = get_size_in_bytes(d, data_type);

    // 1st GEMM.
    uint32_t alpha, beta = 0u;

    for (int ii = 0; ii < runs; ++ii)
    {

        // If we run the INT8 kernel, defer the scaling of P to softmax.
        set_alpha(alpha, data_type == DATA_TYPE_INT8 ? 1.f : scale_bmm1, acc_type);
        // P = Q x K'
        bmm1(static_cast<char*>(qkv_d) + 0 * qkv_stride, static_cast<char*>(qkv_d) + 1 * qkv_stride, p_d, &alpha, &beta,
            stream);
        // Softmax.
        if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16)
        {
            run_softmax_fp16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
                softcapping_scale_bmm1, warps_n, has_alibi);
        }
        else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32)
        {
            run_softmax_bf16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
                softcapping_scale_bmm1, warps_n, has_alibi);
        }
        else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32)
        {
            run_softmax_fp32(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
                softcapping_scale_bmm1, warps_n, has_alibi);
        }
        else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32)
        {
            run_softmax_e4m3(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
                scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
        }
        else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32)
        {
            run_softmax_int8(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1,
                scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
        }
        else
        {
            assert(false && "Reference Softmax: Unsupported type config");
        }

        // 2nd GEMM.
        set_alpha(alpha, 1.f, acc_type);

        void* out_d = o_d;

        // We may have to do a final conversion.
        if (data_type != acc_type)
        {
            out_d = tmp_d;
        }
        // O = S x V
        bmm2(static_cast<char*>(s_d),
            static_cast<char*>(vt_d), // static_cast<char *>(qkv_d) + 2 * qkv_stride,
            out_d, &alpha, &beta, stream);
        // Conversion to output type.
        if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16)
        {
            // Noop.
        }
        else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32)
        {
            run_conversion_fp32_to_fp16(o_d, out_d, s, b, h, dv);
        }
        else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32)
        {
            run_conversion_fp32_to_bf16(o_d, out_d, s, b, h, dv);
        }
        else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32)
        {
            run_conversion_fp32_to_e4m3(o_d, out_d, s, b, h, dv, scale_bmm2);
        }
        else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32)
        {
            // Quantize the output in the second step.
            run_conversion_int32_to_int8(o_d, out_d, s, b, h, dv, scale_bmm2);
        }
    }
}
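
// (Reference path above, as a sketch of the math assuming the usual fused-MHA convention:
// P = scale_bmm1 * Q * K^T, S = row-wise softmax of P with the optional mask / ALiBi bias /
// logit softcapping handled inside run_softmax_*, and O = S * V, where V has been pre-transposed
// into vt_d so that BMM2 can run as a TN GEMM.)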

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline void set_params(bert::Fused_multihead_attention_params_v1& params,
    // types
    Data_type data_type, Data_type acc_type,
    // sizes
    const size_t b, const size_t s, const size_t h, const size_t d, const size_t packed_mask_stride,
    // device pointers
    void* qkv_d, void* packed_mask_d, void* o_d, void* p_d, void* s_d,
    // scale factors
    float const scale_bmm1, float const scale_softmax, float const scale_bmm2,
    // flags
    bool const has_alibi)
{
    memset(&params, 0, sizeof(params));

    // Set the pointers.
    params.qkv_ptr = qkv_d;
    params.qkv_stride_in_bytes = get_size_in_bytes(b * h * 3 * d, data_type);
    // params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
    params.packed_mask_ptr = packed_mask_d;
    // params.packed_mask_stride_in_bytes = mmas_m * threads_per_cta * sizeof(uint32_t);
    params.packed_mask_stride_in_bytes = packed_mask_stride * sizeof(uint32_t);
    params.o_ptr = o_d;
    params.o_stride_in_bytes = get_size_in_bytes(b * h * d, data_type);
    params.has_alibi = has_alibi;
    params.alibi_params = fmha::AlibiParams(h);

#if defined(STORE_P)
    params.p_ptr = p_d;
    params.p_stride_in_bytes = get_size_in_bytes(b * h * s, acc_type);
#endif // defined(STORE_P)

#if defined(STORE_S)
    params.s_ptr = s_d;
    params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);
#endif // defined(STORE_S)

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.s = s;
    params.d = d;

    // Set the different scale values.
    Data_type scale_type1 = (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? acc_type : DATA_TYPE_FP32;
    Data_type scale_type2 = (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? data_type : DATA_TYPE_FP32;

    set_alpha(params.scale_bmm1, scale_bmm1, scale_type1);
    set_alpha(params.scale_softmax, scale_softmax, scale_type1);
    set_alpha(params.scale_bmm2, scale_bmm2, scale_type2);

    // Do we enable the trick to replace I2F with FP math in the 2nd GEMM?
    if (data_type == DATA_TYPE_INT8)
    {
        params.enable_i2f_trick
            = -double(1 << 22) * double(scale_bmm2) <= -128.f && double(1 << 22) * double(scale_bmm2) >= 127.f;
    }
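    // (I2F trick: the INT32->FP conversion in the 2nd GEMM is replaced by FP math built around the
    // 2^22 magic-number constant; the range check above presumably ensures that scale_bmm2 keeps
    // the scaled accumulator values within the representable INT8 output range.)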
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline void set_params(bert::Fused_multihead_attention_params_v2& params, const Launch_params launch_params,
    // types
    Data_type data_type, Data_type acc_type, Data_type output_dtype,
    // attention input layout
    Attention_input_layout input_layout,
    // sizes
    const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, const size_t d,
    const size_t dv, const size_t total, const size_t num_grouped_heads, const size_t sliding_window_size,
    const size_t chunked_attention_size,
    // paged kv cache block size.
    const size_t tokens_per_block,
    // device pointers
    void* qkv_packed_d,
    // contiguous q.
    void* q_d,
    // separate k.
    void* k_d,
    // separate v.
    void* v_d,
    // contiguous kv.
    void* kv_d,
    // start address of the paged kv pool.
    void* paged_kv_pool_ptr,
    // offsets for different blocks in terms of the start address.
    int32_t* paged_block_offsets,
    // mask input.
    void* packed_mask_d, void* cu_mask_rows_d,
    // attention sinks.
    void* attention_sinks_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, void* s_d,
    void* softmax_stats_d, void* scale_bmm2_d,
    // scale factors
    float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
    // flags
    bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, bool const has_alibi,
    float const skip_softmax_threshold_scale_factor)
{

    memset(&params, 0, sizeof(params));

    params.o_ptr = o_packed_d;
    params.o_stride_in_bytes = get_size_in_bytes(h * dv, output_dtype);

    if (interleaved)
    {
        params.q_stride_in_bytes = total;
        params.o_stride_in_bytes = total;
    }

    if (input_layout == Attention_input_layout::PACKED_QKV)
    {
        // For grouped- or multi-query attention (h denotes num_q_heads; h' denotes h_kv):
        //   qkv_layout = [b, s, [q_hd, k_h'd, v_h'd]]
        //   qkv_stride = (h+2*h')d * bytes_per_elt
        // Otherwise:
        //   qkv_layout = [b, s, 3, h, d] or [b, s, h, 3, d]
        //   qkv_stride = 3hd * bytes_per_elt
        params.qkv_ptr = qkv_packed_d;
        params.q_stride_in_bytes = params.k_stride_in_bytes = params.v_stride_in_bytes
            = get_size_in_bytes(h * d + h_kv * d + h_kv * dv, data_type);
    }
    else
    {
        // Layout [B, S, H, D].
        params.q_ptr = q_d;
        params.q_stride_in_bytes = get_size_in_bytes(h * d, data_type);

        if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV)
        {
            // Layout [B, S, 2, H, D].
            params.kv_ptr = kv_d;
            params.k_stride_in_bytes = params.v_stride_in_bytes = get_size_in_bytes(h_kv * (d + dv), data_type);
        }
        else if (input_layout == Attention_input_layout::Q_PAGED_KV)
        {
            int max_blocks_per_sequence = (s_kv + tokens_per_block - 1) / tokens_per_block;
            params.paged_kv_cache = Kv_block_array(b, max_blocks_per_sequence, tokens_per_block,
                get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type), paged_kv_pool_ptr);
            params.paged_kv_cache.mBlockOffsets = paged_block_offsets;
            params.k_stride_in_bytes = get_size_in_bytes(tokens_per_block * d, data_type);
            params.v_stride_in_bytes = get_size_in_bytes(tokens_per_block * dv, data_type);
        }
        else if (input_layout == Attention_input_layout::SEPARATE_Q_K_V)
        {
            // Layout [B, S, H_kv, D].
            params.k_ptr = k_d;
            // Layout [B, S, H_kv, Dv].
            params.v_ptr = v_d;
            params.k_stride_in_bytes = get_size_in_bytes(h_kv * d, data_type);
            params.v_stride_in_bytes = get_size_in_bytes(h_kv * dv, data_type);
        }
    }

    // Packed mask.
    params.packed_mask_ptr = packed_mask_d;
    // The N dimension has to be aligned.
    params.packed_mask_stride_in_bytes = (align_to(int64_t(s_kv), int64_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT))) / 8;
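    // (The packed custom mask stores one bit per attention position, so the row stride in bytes is
    // the N-aligned kv length divided by 8.)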

    // Attention sinks.
    params.attention_sinks = reinterpret_cast<float*>(attention_sinks_d);

#if defined(STORE_P)
    params.p_ptr = p_d;
    params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type);
#endif // defined(STORE_P)

#if defined(STORE_S)
    params.s_ptr = s_d;
    params.s_stride_in_bytes = get_size_in_bytes(b * h * s_kv, data_type);
#endif // defined(STORE_S)

    params.softmax_stats_ptr = softmax_stats_d;
    params.softmax_stats_stride_in_bytes = get_size_in_bytes(h * 2, DATA_TYPE_FP32);

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.s = s_q;
    params.d = d;
    params.dv = dv;
    params.num_grouped_heads = num_grouped_heads;
    params.sliding_window_size = sliding_window_size;
    assert((chunked_attention_size == 0 || (chunked_attention_size & (chunked_attention_size - 1)) == 0)
        && "chunked_attention_size has to be a power of 2");
    params.log2_chunked_attention_size = chunked_attention_size > 0 ? std::log2(chunked_attention_size) : 0;
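    // (Storing log2 of the chunk size works because the size is required to be a power of two
    // above; presumably this lets the kernel map a token index to its attention chunk with a shift
    // rather than a division.)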

    // cumulative q or kv sequence lengths.
    params.cu_q_seqlens = static_cast<int*>(cu_q_seqlens_d);
    params.cu_kv_seqlens = static_cast<int*>(cu_kv_seqlens_d);
    // cumulative mask sequence lengths.
    params.cu_mask_rows = static_cast<int*>(cu_mask_rows_d);

    // Set the different scale values.
    Data_type scale_type1 = (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? acc_type : DATA_TYPE_FP32;
    Data_type scale_softmax_type = scale_type1;
    Data_type scale_type2 = (data_type == DATA_TYPE_FP16) || (data_type == DATA_TYPE_BF16) ? data_type : DATA_TYPE_FP32;
    if (data_type == DATA_TYPE_E4M3)
    {
        scale_type1 = acc_type;
        scale_type2 = acc_type;
    }

    // Fuse 1.0f / softcapping_scale into scale_bmm1.
    bool const enable_attn_logit_softcapping = softcapping_scale_bmm1 != 0.f;
    float fused_scale_bmm1 = enable_attn_logit_softcapping ? scale_bmm1 / softcapping_scale_bmm1 : scale_bmm1;

    // use specialized hopper kernels without alibi support.
    // alibi or softcapping_scale cannot utilize the exp2f with fused_scale optimization.
    if (launch_params.warp_specialization && !has_alibi && !enable_attn_logit_softcapping)
    {
        set_alpha(params.scale_bmm1, fused_scale_bmm1 * float(M_LOG2E), DATA_TYPE_FP32);
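        // (Folding M_LOG2E into the scale lets the kernel evaluate exp(x * scale) as
        // exp2(x * scale * log2(e)), i.e. via the cheaper exp2f path mentioned above.)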
    }
    else
    {
        set_alpha(params.scale_bmm1, fused_scale_bmm1, scale_type1);
    }
    set_alpha(params.scale_softmax, scale_softmax, scale_softmax_type);
    set_alpha(params.scale_bmm2, scale_bmm2, scale_type2);
    params.scale_bmm2_d = reinterpret_cast<uint32_t*>(scale_bmm2_d);
    params.softcapping_scale_bmm1 = softcapping_scale_bmm1;

    FMHA_CHECK_CUDA(cudaMemcpy(params.scale_bmm2_d, &params.scale_bmm2, sizeof(uint32_t), cudaMemcpyHostToDevice));

    // attention type, h_kv < h if MQA or GQA
    params.h_kv = h_kv;
    assert(h % h_kv == 0 && "MQA/GQA needs h to be divisible by h_kv!");
    params.h_q_per_kv = h / h_kv;
    params.has_alibi = has_alibi;
    params.alibi_params = fmha::AlibiParams(h);

    // Set flags
    params.is_s_padded = is_s_padded;
    params.use_int8_scale_max = use_int8_scale_max;

    // Do we enable the trick to replace I2F with FP math in the 2nd GEMM?
    if (data_type == DATA_TYPE_INT8)
    {
        params.enable_i2f_trick
            = -double(1 << 22) * double(scale_bmm2) <= -128.f && double(1 << 22) * double(scale_bmm2) >= 127.f;
    }

    // Skip-softmax attention
    params.skip_softmax_threshold_scale_factor = skip_softmax_threshold_scale_factor;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, const size_t s,
    const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout,
    bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma,
    bool const force_non_flash_attention, bool const force_non_warp_specialization,
    bool const force_non_granular_tiling, bool const force_fp32_acc, float const skip_softmax_threshold_scale_factor,
    // device props
    const cudaDeviceProp props)
{

    // Set launch params to choose kernels
    launch_params.ignore_b1opt = ignore_b1opt;
    launch_params.force_unroll = force_unroll;
    launch_params.force_fp32_acc = force_fp32_acc;
    launch_params.interleaved = interleaved;
    launch_params.attention_mask_type = attention_mask_type;
    launch_params.attention_input_layout = input_layout;

    // Set SM count and L2 cache size (used to determine launch blocks/grids to maximize performance)
    launch_params.multi_processor_count = props.multiProcessorCount;
    launch_params.device_l2_cache_size = props.l2CacheSize;

    // threshold for adopting flash attention or warp_specialized kernels.
    launch_params.flash_attention
        = (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3)
        && (s >= 16 && d >= 16) && !force_non_flash_attention;

    // enable warp_specialized kernels when s >= 512 on hopper
    // note that warp_specialized kernels need flash attention + tma
    launch_params.warp_specialization
        = (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) && sm == 90
        && launch_params.flash_attention && !force_non_warp_specialization;
    // warp specialization kernels on hopper need tma
    launch_params.use_tma = use_tma || launch_params.warp_specialization;

    // use granular tiling on Ampere-style flash attention
    launch_params.use_granular_tiling
        = !force_non_granular_tiling && launch_params.flash_attention && !launch_params.warp_specialization && sm >= 80;

    if (launch_params.use_granular_tiling && (data_type == DATA_TYPE_E4M3 && sm == 80))
    {
        printf(
            "Fallback to non-granular-tiling kernels as tiled e4m3 kernels "
            "are not supported on Ada currently.\n");
        launch_params.use_granular_tiling = false;
    }

    // Enable skip softmax attention or not.
    launch_params.enable_skip_softmax = skip_softmax_threshold_scale_factor > 0.f;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char** argv)
{

    // The device. Reset on destruction
    CudaDevice device;
    int sm = device.sm;
    cudaDeviceProp props = device.props;

    GpuTimer timer;

    // The batch size.
    size_t b = 128;
    // The number of heads.
    size_t h = 16;
    // The dimension of the Q, K and V vectors.
    size_t d = 64;
    // The dimension of V if set to non-zero, otherwise dimension of V equals to that of Q
    size_t dv = 0;
    // The length of the sequence.
    size_t s = 384;
    // Number of grouped heads in the seqlen dimension.
    size_t num_grouped_heads = 1;
    // Sliding Window Attention
    // Only pay attention to [max(0, query_idx - sliding_window_size), query_idx].
    size_t sliding_window_size = size_t(INT_MAX);
    // The chunked-attention size.
    size_t chunked_attention_size = 0;

    // The data type of the kernel.
    Data_type data_type = DATA_TYPE_FP16;
    // The type of the intermediate P matrix.
    Data_type acc_type = DATA_TYPE_FP16;
    // The type of the output.
    Data_type output_dtype = DATA_TYPE_FP16;
    // Is the output type set ?
    bool is_output_dtype_set = false;

    // The scaling factors.
    float scale_bmm1 = 0.f, scale_softmax = 0.f, scale_bmm2 = 0.25f;
    // The number of runs.
    int runs = 1, warm_up_runs = 0;
    // Do we use 1s for Q, K, V.
    bool use_1s_q = false, use_1s_k = false, use_1s_v = false;
    // The range of the different inputs.
    int range_q = 5, range_k = 3, range_v = 5;
    // The scale.
    float scale_q = 0.f, scale_k = 0.f, scale_v = 0.f;
    // The threshold for dropout. By default, drop 10%.
    float dropout = 0.1f;
    // Do we skip the checks.
    bool skip_checks = false;
    // The tolerance when checking results.
    float epsilon = -1.f; // data_type == DATA_TYPE_FP16 ? 0.015f : 0.f;
    // Use causal mask / padding_mask / sliding_or_chunked_causal mask / custom_mask input.
    Attention_mask_type attention_mask_type = Attention_mask_type::PADDING;
    // Use padded format for input QKV tensor & output O tensor.
    // Instead of variable lengths [total, h, 3, d] where total = b1*s1 + b2*s2 + ... bn*sn,
    // use padded length [b, max_s, h, 3, d] where max_s is the maximum expected seq len
    bool is_s_padded = false;

    // minimum sequence length for sampling variable seqlens
    uint32_t min_s = -1;

    // run interleaved kernels and transpose input and output accordingly
    bool interleaved = false;
    bool ignore_b1opt = false;
    bool force_unroll = false;
    // used by kernels that have different acc data types (like hmma, qmma)
    bool force_fp32_acc = false;
    bool force_non_flash_attention = false;
    // enable warp specialization kernels on sm 90
    bool force_non_warp_specialization = (sm != 90);
    bool use_int8_scale_max = false;
    bool verbose = true;
    bool save_softmax = false;

    // use granular tiling
    // supported only by Ampere-based Flash Attention at this moment
    bool force_non_granular_tiling = false;

    // set all sequence lengths to min(s, min_s)
    bool fix_s = false;

    bool v1 = false;

    // use TMA or not. ignored if not in SM90
    bool use_tma = false;

    // use alibi.
    bool has_alibi = false;

    // Use softcapping_scale_bmm1 (scale * __tanhf(x / scale)).
    float softcapping_scale_bmm1 = 0.f;

    // In multi-query or grouped-query attention (MQA/GQA), several Q heads are associated with one KV head
    bool multi_query_attention = false;
    size_t h_kv = 0;

    // The attention input layout.
    Attention_input_layout input_layout = Attention_input_layout::PACKED_QKV;

    // TRTLLM uses 64 by default in paged kv cache.
    size_t tokens_per_block = 64;

    // Attention that has different q and kv lengths.
    size_t s_q = 0;
    // different q and kv sequence lengths.
    bool different_q_kv_lengths = false;

    // SageAttention block sizes
    int sage_block_size_q = 0, sage_block_size_k = 0, sage_block_size_v = 0;

    // Use attention sinks (added to the denominator of softmax)
    bool use_attention_sinks = false;

    // Skip-softmax attention
    float skip_softmax_threshold_scale_factor = 0;

    // Read the parameters from the command-line.
    for (int ii = 1; ii < argc; ++ii)
    {
        if (!strcmp(argv[ii], "-1s"))
        {
            use_1s_k = use_1s_q = use_1s_v = true;
        }
        else if (!strcmp(argv[ii], "-1s-k"))
        {
            use_1s_k = true;
        }
        else if (!strcmp(argv[ii], "-1s-q"))
        {
            use_1s_q = true;
        }
        else if (!strcmp(argv[ii], "-1s-v"))
        {
            use_1s_v = true;
        }
        else if (!strcmp(argv[ii], "-b") && ++ii < argc)
        {
            b = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-d") && ++ii < argc)
        {
            d = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-dv") && ++ii < argc)
        {
            dv = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-s-q") && ++ii < argc)
        {
            s_q = strtol(argv[ii], nullptr, 10);
            different_q_kv_lengths = true;
        }
        else if (!strcmp(argv[ii], "-dropout") && ++ii < argc)
        {
            dropout = (float) strtod(argv[ii], nullptr);
        }
        else if (!strcmp(argv[ii], "-epsilon") && ++ii < argc)
        {
            epsilon = (float) strtod(argv[ii], nullptr);
        }
        else if (!strcmp(argv[ii], "-h") && ++ii < argc)
        {
            h = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-int8"))
        {
            data_type = DATA_TYPE_INT8;
            acc_type = DATA_TYPE_INT32;
        }
        else if (!strcmp(argv[ii], "-fp16"))
        {
            data_type = DATA_TYPE_FP16;
            acc_type = DATA_TYPE_FP16;
        }
        else if (!strcmp(argv[ii], "-fp16-fp32"))
        {
            data_type = DATA_TYPE_FP16;
            acc_type = DATA_TYPE_FP32;
            force_fp32_acc = true;
        }
        else if (!strcmp(argv[ii], "-bf16"))
        {
            data_type = DATA_TYPE_BF16;
            acc_type = DATA_TYPE_FP32;
            force_fp32_acc = true;
        }
        else if (!strcmp(argv[ii], "-e4m3"))
        {
            data_type = DATA_TYPE_E4M3;
            // Technically not the acc type.
            acc_type = DATA_TYPE_FP32;
            force_fp32_acc = true;
        }
        else if (!strcmp(argv[ii], "-e4m3-fp16"))
        { // Ada QMMA only
            data_type = DATA_TYPE_E4M3;
            // Technically not the acc type.
            acc_type = DATA_TYPE_FP16;
        }
        else if (!strcmp(argv[ii], "-e4m3-fp32"))
        {
            data_type = DATA_TYPE_E4M3;
            // Technically not the acc type.
            acc_type = DATA_TYPE_FP32;
            force_fp32_acc = true;
        }
        else if (!strcmp(argv[ii], "-fp16-output"))
        {
            output_dtype = DATA_TYPE_FP16;
            is_output_dtype_set = true;
        }
        else if (!strcmp(argv[ii], "-bf16-output"))
        {
            output_dtype = DATA_TYPE_BF16;
            is_output_dtype_set = true;
        }
        else if (!strcmp(argv[ii], "-num-grouped-heads") && ++ii < argc)
        {
            num_grouped_heads = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-range-k") && ++ii < argc)
        {
            range_k = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-range-q") && ++ii < argc)
        {
            range_q = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-range-v") && ++ii < argc)
        {
            range_v = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-runs") && ++ii < argc)
        {
            runs = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-s") && ++ii < argc)
        {
            s = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-sliding-window-size") && ++ii < argc)
        {
            sliding_window_size = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-chunked-attention-size") && ++ii < argc)
        {
            chunked_attention_size = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-scale-bmm1") && ++ii < argc)
        {
            scale_bmm1 = (float) strtod(argv[ii], nullptr);
        }
        else if (!strcmp(argv[ii], "-scale-bmm2") && ++ii < argc)
        {
            scale_bmm2 = (float) strtod(argv[ii], nullptr);
        }
        else if (!strcmp(argv[ii], "-scale-k") && ++ii < argc)
        {
            scale_k = (float) strtod(argv[ii], nullptr);
        }
        else if (!strcmp(argv[ii], "-scale-softmax") && ++ii < argc)
        {
            scale_softmax = (float) strtod(argv[ii], nullptr);
        }
        else if (!strcmp(argv[ii], "-scale-q") && ++ii < argc)
        {
            scale_q = (float) strtod(argv[ii], nullptr);
        }
        else if (!strcmp(argv[ii], "-scale-v") && ++ii < argc)
        {
            scale_v = (float) strtod(argv[ii], nullptr);
        }
        else if (!strcmp(argv[ii], "-skip-checks"))
        {
            skip_checks = true;
        }
        else if (!strcmp(argv[ii], "-warm-up-runs") && ++ii < argc)
        {
            warm_up_runs = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-min-s") && ++ii < argc)
        {
            min_s = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-il"))
        {
            interleaved = true;
        }
        else if (!strcmp(argv[ii], "-causal-mask"))
        {
            attention_mask_type = Attention_mask_type::CAUSAL;
        }
        else if (!strcmp(argv[ii], "-sliding-or-chunked-causal-mask"))
        {
            attention_mask_type = Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL;
        }
        else if (!strcmp(argv[ii], "-custom-mask"))
        {
            attention_mask_type = Attention_mask_type::CUSTOM_MASK;
        }
        else if (!strcmp(argv[ii], "-multi-query-attention") || !strcmp(argv[ii], "-mqa"))
        {
            h_kv = 1;
            multi_query_attention = true; // subset of GQA
        }
        else if ((!strcmp(argv[ii], "-grouped-query-attention") || !strcmp(argv[ii], "-gqa")) && ++ii < argc)
        {
            h_kv = strtol(argv[ii], nullptr, 10);
            multi_query_attention = true;
        }
        else if (!strcmp(argv[ii], "-contiguous-q-kv"))
        {
            input_layout = Attention_input_layout::CONTIGUOUS_Q_KV;
        }
        else if (!strcmp(argv[ii], "-paged-kv"))
        {
            input_layout = Attention_input_layout::Q_PAGED_KV;
        }
        else if (!strcmp(argv[ii], "-separate-q-k-v"))
        {
            input_layout = Attention_input_layout::SEPARATE_Q_K_V;
        }
        else if (!strcmp(argv[ii], "-tokens-per-block") && ++ii < argc)
        {
            tokens_per_block = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-pad-s"))
        {
            is_s_padded = true;
        }
        else if (!strcmp(argv[ii], "-ignore-b1opt"))
        {
            ignore_b1opt = true;
        }
        else if (!strcmp(argv[ii], "-force-unroll"))
        {
            force_unroll = true;
        }
        else if (!strcmp(argv[ii], "-force-non-flash-attention"))
        {
            force_non_flash_attention = true;
            force_non_warp_specialization = true;
        }
        else if (!strcmp(argv[ii], "-force-flash-attention"))
        {
            fprintf(stderr,
                "Deprecation warning: -force-flash-attention is no longer valid; use "
                "-force-non-flash-attention instead, as Flash Attention is enabled by default.\n");
        }
        else if (!strcmp(argv[ii], "-force-non-warp-specialization"))
        {
            force_non_warp_specialization = true;
        }
        else if (!strcmp(argv[ii], "-force-non-granular-tiling") || !strcmp(argv[ii], "-force-non-tiled"))
        {
            force_non_granular_tiling = true;
        }
        else if (!strcmp(argv[ii], "-fix-s"))
        {
            fix_s = true;
        }
        else if (!strcmp(argv[ii], "-scale-max"))
        {
            use_int8_scale_max = true;
        }
        else if (!strcmp(argv[ii], "-v") && ++ii < argc)
        {
            int v = strtol(argv[ii], nullptr, 10);
            verbose = v != 0;
        }
        else if (!strcmp(argv[ii], "-v1"))
        {
            v1 = true;
        }
        else if (!strcmp(argv[ii], "-use-tma"))
        {
            use_tma = true;
            // flash attention + tma + non_warp_specialized kernels are not supported
            // use non_flash_attention + tma + non_warp_specialized instead
            if (force_non_warp_specialization)
            {
                force_non_flash_attention = true;
            }
        }
        else if (!strcmp(argv[ii], "-alibi"))
        {
            has_alibi = true;
        }
        else if (!strcmp(argv[ii], "-softcapping-scale-bmm1") && ++ii < argc)
        {
            softcapping_scale_bmm1 = (float) strtod(argv[ii], nullptr);
        }
        else if (!strcmp(argv[ii], "-save-softmax"))
        {
            save_softmax = true;
        }
        else if (!strcmp(argv[ii], "-sage-block-q") && ++ii < argc)
        {
            sage_block_size_q = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-sage-block-k") && ++ii < argc)
        {
            sage_block_size_k = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-sage-block-v") && ++ii < argc)
        {
            sage_block_size_v = strtol(argv[ii], nullptr, 10);
        }
        else if (!strcmp(argv[ii], "-use-attention-sinks"))
        {
            use_attention_sinks = true;
        }
        else if (!strcmp(argv[ii], "-skip-softmax-threshold-scale-factor") && ++ii < argc)
        {
            skip_softmax_threshold_scale_factor = strtof(argv[ii], nullptr);
        }
        else
        {
            fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]);
            return -1;
        }
    }
    if (save_softmax == true)
    {
        bool is_MLA = (d == 192 && dv == 128);
        if (((!is_MLA) && input_layout != Attention_input_layout::CONTIGUOUS_Q_KV)
            || (is_MLA && input_layout != Attention_input_layout::SEPARATE_Q_K_V))
        {
            fprintf(stderr,
                "For normal attention, only the '-contiguous-q-kv' layout supports "
                "'-save-softmax'. For MLA, only the '-separate-q-k-v' layout supports "
                "'-save-softmax'.\n");
            exit(1);
        }
    }
    // Sanitize
    if (min_s == -1)
        min_s = s;
    min_s = std::min<uint32_t>(s, min_s);
    h_kv = multi_query_attention ? h_kv : h;

    // Check if the options are valid.
    if (different_q_kv_lengths)
    {
        assert(input_layout != Attention_input_layout::PACKED_QKV
            && "Packed QKV input layout is not supported with different q and kv lengths.");
        assert(s >= s_q && "q seqlen has to be smaller than or equal to the kv seqlen !");
    }
    else
    {
        s_q = s;
    }

    // Sliding window attention (only pay attention to sliding-window-size long previous tokens).
    if (sliding_window_size < s)
    {
        assert(
            chunked_attention_size == 0 && "chunked_attention_size should not be used when sliding_window_size is set");
        attention_mask_type = Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL;
    }
    // Chunked attention.
    if (chunked_attention_size > 0)
    {
        assert((chunked_attention_size & (chunked_attention_size - 1)) == 0
            && "chunked_attention_size has to be a power of 2");
        attention_mask_type = Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL;
    }

    // Set the norm.
    if (scale_bmm1 == 0.f)
    {
        scale_bmm1 = 1.f / sqrtf((float) d);
    }

    // Set the output type if not set by user.
    if (!is_output_dtype_set)
    {
        output_dtype = data_type;
    }

    // Force the softmax scale to 1.f for the FP16 kernel.
    if (data_type == DATA_TYPE_FP16)
    {
        scale_softmax = 1.f;
    }
    else if (data_type == DATA_TYPE_INT8 && scale_softmax == 0.f)
    {
        scale_softmax = std::max(512.f, (float) s);
    }
    else if (data_type == DATA_TYPE_E4M3 && scale_softmax == 0.f)
    {
        scale_softmax = 1.f; // For E4M3 this is hardcoded as the largest power-of-2 below E4M3_MAX
    }

    // Sage Attention uses the e4m3 data type
    if (sage_block_size_q > 0 || sage_block_size_k > 0 || sage_block_size_v > 0)
    {
        scale_softmax = 1.f;
        scale_bmm2 = 1.f;
        force_fp32_acc = true;
        acc_type = DATA_TYPE_FP32;
    }

    // Define the scaling factor for the different inputs.
    if (scale_q == 0.f)
    {
        scale_q = 1.f;
    }
    if (scale_k == 0.f)
    {
        scale_k = 1.f;
    }
    if (scale_v == 0.f)
    {
        // BF16 here just for debug.
        scale_v = (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16) ? 0.125f : 1.f;
    }
    if (has_alibi && attention_mask_type == Attention_mask_type::PADDING)
    {
        attention_mask_type = Attention_mask_type::CAUSAL;
    }

    // BF16 only supports FP32 acc_type.
    if (data_type == DATA_TYPE_BF16 && acc_type != DATA_TYPE_FP32)
    {
        fprintf(stderr, "Only FP32 accumulation is supported for BF16 I/O\n");
        exit(1);
    }

    // Set the tolerance if not already set by the user.
    if (epsilon < 0.f)
    {
        switch (data_type)
        {
        case DATA_TYPE_FP16: epsilon = 0.015f; break;
        case DATA_TYPE_BF16: epsilon = 0.025f; break;
        case DATA_TYPE_E4M3: epsilon = 0.15f; break;
        default: epsilon = 0.f;
        }
        // The accuracy of SageAttention is expected to fall between fp8 and fp16/bf16.
        if (sage_block_size_q > 0 || sage_block_size_k > 0 || sage_block_size_v > 0)
        {
            epsilon = 0.05f;
        }
    }

    // Let the dimension of V equal that of Q if not set by the user.
    if (dv == 0)
    {
        dv = d;
    }

    // Debug info -- only in verbose mode.
    if (verbose)
    {
        // Running the following command.
        printf("Command.......: %s", argv[0]);
        for (int ii = 1; ii < argc; ++ii)
        {
            printf(" %s", argv[ii]);
        }
        printf("\n");

        // Device info.
        printf("Device........: %s\n", props.name);
        printf("Arch.(sm).....: %d\n", sm);
        printf("#.of.SMs......: %d\n", props.multiProcessorCount);

        // Problem info.
        printf("Batch ........: %lu\n", b);
        printf("Heads ........: %lu\n", h);
        printf("Dimension ....: %lu\n", d);
        printf("Dimension of V ....: %lu\n", dv);
        printf("Seq length ...: %lu\n", s);
        printf("Warm-up runs .: %d\n", warm_up_runs);
        printf("Runs..........: %d\n\n", runs);

        // The scaling factors for the 3 operations.
        printf("Scale bmm1 ...: %.6f\n", scale_bmm1);
        printf("Scale softmax.: %.6f\n", scale_softmax);
        printf("Scale bmm2 ...: %.6f\n", scale_bmm2);
        printf("\n");
    }

    // determine the launch params to select kernels
    Launch_params launch_params;
    determine_launch_params(launch_params, data_type, sm, s, d, attention_mask_type, input_layout, interleaved,
        ignore_b1opt, force_unroll, use_tma, force_non_flash_attention, force_non_warp_specialization,
        force_non_granular_tiling, force_fp32_acc, skip_softmax_threshold_scale_factor, props);

    // The Q, K and V matrices are packed into one big matrix of size S x B x H x (2 * D + Dv).
    const size_t qkv_size = s * b * h * (2 * d + dv);
    // Allocate on the host.
    float* qkv_h = (float*) malloc(qkv_size * sizeof(float));
    // The size in bytes.
    const size_t qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type);
    // Allocate on the device.
    void *qkv_sbh3d_d = nullptr, *qkv_bsh3d_d = nullptr;
    FMHA_CHECK_CUDA(cudaMalloc(&qkv_sbh3d_d, qkv_size_in_bytes));
    FMHA_CHECK_CUDA(cudaMalloc(&qkv_bsh3d_d, qkv_size_in_bytes));

    // Contiguous KV cache buffer.
    // The shape is [B, 2, S, H, D].
    const size_t kv_size = b * s * h_kv * (d + dv);
    // The size in bytes.
    const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
    // Allocate on the host.
    void* contiguous_kv_h = malloc(kv_size_in_bytes);
    // Memset the buffer.
    memset(contiguous_kv_h, 0, kv_size_in_bytes);
    // Allocate on the device.
    void* contiguous_kv_d;
    FMHA_CHECK_CUDA(cudaMalloc(&contiguous_kv_d, kv_size_in_bytes));

    // Paged KV Cache buffer.
    // The shape is [B, 2, Blocks_per_sequence], and each block's buffer shape is [H, Tokens_per_block, Dh].
    void** kv_cache_ptrs_h = nullptr;
    void* kv_cache_pool_ptr = nullptr;
    int32_t *kv_cache_block_offsets_h, *kv_cache_block_offsets_d = nullptr;
    const size_t max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block;
    const size_t num_total_blocks = b * 2 * max_blocks_per_seq;
    kv_cache_ptrs_h = (void**) malloc(num_total_blocks * sizeof(void*));
    kv_cache_block_offsets_h = (int32_t*) malloc(num_total_blocks * sizeof(int32_t));
    const size_t paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type);
    FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_block_offsets_d), num_total_blocks * sizeof(int32_t)));
    const size_t kv_cache_pool_sz
        = get_size_in_bytes(num_total_blocks * tokens_per_block * h_kv * (d + dv) / 2, data_type);
    FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_pool_ptr), kv_cache_pool_sz));
    size_t ptr_index = 0;
    size_t abs_offset = 0;
    for (size_t bi = 0; bi < b; bi++)
    {
        for (int kv_offset = 0; kv_offset < 2; kv_offset++)
        {
            size_t block_size = get_size_in_bytes(tokens_per_block * h_kv * (kv_offset == 0 ? d : dv), data_type);
            for (size_t block_i = 0; block_i < max_blocks_per_seq; block_i++)
            {
                kv_cache_ptrs_h[ptr_index]
                    = reinterpret_cast<void*>(reinterpret_cast<char*>(kv_cache_pool_ptr) + abs_offset);
                assert(abs_offset % paged_kv_block_size_in_bytes == 0);
                kv_cache_block_offsets_h[ptr_index] = abs_offset / paged_kv_block_size_in_bytes;
                ptr_index++;
                abs_offset += block_size;
            }
        }
    }
    assert(ptr_index == num_total_blocks && abs_offset == kv_cache_pool_sz);
    FMHA_CHECK_CUDA(cudaMemcpy(
        kv_cache_block_offsets_d, kv_cache_block_offsets_h, num_total_blocks * sizeof(int32_t), cudaMemcpyDefault));
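    // (Layout of the paged-KV pool built above: for each sequence, max_blocks_per_seq K blocks are
    // followed by max_blocks_per_seq V blocks, laid out back to back in the pool. The offsets are
    // recorded in units of the paged KV block size, which is presumably the unit that
    // Kv_block_array::mBlockOffsets expects.)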

    // Q will always be [B, S, H, Dh] with paged kv cache.
    void* q_d;
    const size_t q_size = s * b * h * d;
    FMHA_CHECK_CUDA(cudaMalloc(&q_d, get_size_in_bytes(q_size, data_type)));

    // K has [B, S, H_kv, D] with separate kv cache.
    void* k_d;
    const size_t k_size = s * b * h_kv * d;
    FMHA_CHECK_CUDA(cudaMalloc(&k_d, get_size_in_bytes(k_size, data_type)));

    // V has [B, S, H_kv, Dv] with separate kv cache.
    void* v_d;
    const size_t v_size = s * b * h_kv * dv;
    FMHA_CHECK_CUDA(cudaMalloc(&v_d, get_size_in_bytes(v_size, data_type)));

    // Scale bmm2 (per-tensor).
    void* scale_bmm2_d;
    FMHA_CHECK_CUDA(cudaMalloc(&scale_bmm2_d, sizeof(uint32_t)));

    // The mask for dropout or any mask patterns.
    const size_t mask_size = s * b * s;
    // Allocate on the host.
    float* mask_h = (float*) malloc(mask_size * sizeof(float));
    // The size in bytes.
    const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
    // Allocate on the device.
    void* mask_d = nullptr;
    if (!skip_checks)
    {
        FMHA_CHECK_CUDA(cudaMalloc(&mask_d, mask_size_in_bytes));
    }

    // The decomposition of threads and warps for BMM1.
    size_t warps_m, warps_n, warps_k;
    std::tie(warps_m, warps_n, warps_k) = get_warps(launch_params, sm, data_type, s, b, d, v1 ? 1 : 2);

    // print launch configuration
    printf(
        "v1=%d il=%d s_q=%lu, s=%lu b=%lu h=%lu/%lu d=%lu/%lu dtype=%s, output_dtype=%s, "
        "flash_attn=%s, "
        "warp_spec=%s, mask=%s, "
        "alibi=%s, attn=%s, qkv_layout=%s, wm=%lu wn=%lu\n",
        v1, interleaved, s_q, s, b, h, h_kv, d, dv, data_type_to_name(data_type).c_str(),
        data_type_to_name(output_dtype).c_str(),
        launch_params.flash_attention ? (launch_params.use_granular_tiling ? "true_tiled" : "true") : "false",
        launch_params.warp_specialization ? "true" : "false", mask_type_to_string(attention_mask_type).c_str(),
        has_alibi ? "true" : "false", h_kv == 1 ? "mqa" : (h_kv == h ? "mha" : "gqa"),
        attention_input_layout_to_string(input_layout).c_str(), warps_m, warps_n);

    // For multi-CTA cases, determine the size of the CTA wave.
    int heads_per_wave, ctas_per_head;
    get_grid_size(heads_per_wave, ctas_per_head, sm, data_type, b, s, h, d,
        false, // disable multi-cta kernels by default
        v1 ? 1 : 2);

    // The number of threads per CTA.
    const size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
    // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension.
    size_t mmas_m = (s + 16 * warps_m - 1) / (16 * warps_m);
    // The number of mmas in the N dimension.
    size_t mmas_n = (s + 16 * warps_n - 1) / (16 * warps_n);
    // We do not support more than 4 MMAS in the N dimension (as each MMA needs 8 bits in the mask).
    assert(!v1 || mmas_n <= 4);
    // The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA.
    size_t packed_mask_size = b * mmas_m * threads_per_cta;
    // Flash attention on Ampere and Hopper, which supports multiple mmas_n
    if (!v1 && !force_non_flash_attention && attention_mask_type == Attention_mask_type::CUSTOM_MASK)
    {

        // We need to align q and k sequence lengths.
        size_t rounded_q_s = align_to(s, size_t(fmha::FLASH_ATTEN_MASK_M_ALIGNMENT));
        size_t rounded_k_s = align_to(s, size_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT));
        // The number of mmas in the M dimension (MMA_M = 64).
        mmas_m = rounded_q_s / fmha::FLASH_ATTEN_MASK_MMA_M;
        // The number of mmas in the N dimension (MMA_N = 64).
        mmas_n = rounded_k_s / fmha::FLASH_ATTEN_MASK_MMA_N;
        // Each thread holds 32 bit (2 rows, 16 cols -> 8 core MMAs) in one MMA here.
        packed_mask_size = b * mmas_m * mmas_n * threads_per_cta;
    }
    // The size in bytes.
    const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
    // Allocate on the host.
    uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes);
    // Set it to 0 (indicates that all elements are valid).
    memset(packed_mask_h, 0, packed_mask_size_in_bytes);
    // Allocate on the device.
    void* packed_mask_d = nullptr;

    // The size of the attention sinks.
    const size_t attention_sinks_size_in_bytes = h * sizeof(float);

    // The attention sinks.
    void* attention_sinks_d = nullptr;
    if (use_attention_sinks)
    {
        // Allocate on the host.
        float* attention_sinks_h = (float*) malloc(attention_sinks_size_in_bytes);
        // Randomly initialize the attention sinks.
        random_init("attention_sinks", attention_sinks_h, 1, h, 1, false, 5.f, 1.f, verbose);
        // Allocate on the device.
        FMHA_CHECK_CUDA(cudaMalloc(&attention_sinks_d, attention_sinks_size_in_bytes));
        // Copy from the host to the device.
        FMHA_CHECK_CUDA(
            cudaMemcpy(attention_sinks_d, attention_sinks_h, attention_sinks_size_in_bytes, cudaMemcpyDefault));
    }

    // The O matrix is packed as S * B * H * Dv.
    const size_t o_size = s * b * h * dv;
    // Allocate on the host.
    float* o_h = (float*) malloc(o_size * sizeof(float));
    // The size in bytes.
    const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type);
    // Allocate on the device.
    void* o_d = nullptr;
    FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes));

    // The softmax_stats_d vector is used to store the max/sum of the softmax per token
    void* softmax_stats_d;
    FMHA_CHECK_CUDA(cudaMalloc(&softmax_stats_d, 2 * sizeof(float) * b * s * h));
    FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * sizeof(float) * b * s * h));

    // The size in bytes.
    const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
    // Allocate on the device.
    void* tmp_d = nullptr;
    if (data_type != acc_type)
    {
        FMHA_CHECK_CUDA(cudaMalloc(&tmp_d, tmp_size_in_bytes));
    }

    // Allocate the reference on the host.
    float* o_ref_h = (float*) malloc(o_size * sizeof(float));
    float* softmax_stats_ref_h = (float*) malloc(2 * b * s * h * sizeof(float));
    float* softmax_stats_h = (float*) malloc(2 * b * s * h * sizeof(float));

    // The P matrix is stored as one big matrix of size S x B x H x S.
    const size_t p_size = s * b * h * s;
    // The size in bytes.
    const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
    // Allocate on the device.
    void* p_d = nullptr;
    if (!skip_checks)
    {
        FMHA_CHECK_CUDA(cudaMalloc(&p_d, p_size_in_bytes));
    }

    // Allocate the reference on the host.
    float* p_ref_h = (float*) malloc(p_size * sizeof(float));
#if defined(STORE_P)
    // Allocate on the host.
    float* p_h = (float*) malloc(p_size * sizeof(float));
#endif // defined(STORE_P)

    // The size in bytes of the S matrix (the data type may be different from P for int8).
    const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type);
    // Allocate on the device.
    void* s_d = nullptr;
    if (!skip_checks)
    {
        FMHA_CHECK_CUDA(cudaMalloc(&s_d, s_size_in_bytes));
    }

    // Allocate the reference on the host.
    float* s_ref_h = (float*) malloc(p_size * sizeof(float));

    // Allocate on the host.
    float* s_h = (float*) malloc(p_size * sizeof(float));
    // Make sure we set the seed for reproducible results.
    srand(1234UL);

    // Set the Q, K and V matrices.
    random_init("Q", qkv_h + 0 * d, d, s * b * h, 2 * d + dv, use_1s_q, range_q, scale_q, verbose);
    random_init("K", qkv_h + 1 * d, d, s * b * h, 2 * d + dv, use_1s_k, range_k, scale_k, verbose);
    random_init("V", qkv_h + 2 * d, dv, s * b * h, 2 * d + dv, use_1s_v, range_v, scale_v, verbose);
    // iota_init("Q", qkv_h + 0 * d, d, s * b * h, 3 * d, use_1s_q, range_q, scale_q, verbose, true, 0);
    // iota_init("K", qkv_h + 1 * d, d, s * b * h, 3 * d, use_1s_k, range_k, scale_k, verbose, true, 128);
    // iota_init("V", qkv_h + 2 * d, d, s * b * h, 3 * d, use_1s_v, range_v, scale_v, verbose, true, 256);

    // Multi-query or grouped-query attention for reference input
    if (multi_query_attention)
    {
        for (size_t sbi = 0; sbi < s * b; sbi++)
        {
            for (size_t hi = 0; hi < h; hi++)
            {
                for (size_t di = 0; di < d; di++)
                {
                    // E.g., h=8, h_kv=4
                    // hi:            0, 1, 2, 3, 4, 5, 6, 7
                    // hi_kv_scatter: 0, 0, 2, 2, 4, 4, 6, 6
                    int const h_per_group = h / h_kv;
                    int const hi_kv_scatter = (hi / h_per_group) * h_per_group;
                    size_t src_offset = sbi * h * 3 * d + hi_kv_scatter * 3 * d + di; // [sbi, hi_kv_scatter, 0, di]
                    size_t dst_offset = sbi * h * 3 * d + hi * 3 * d + di;            // [sbi, hi, 0, di]

                    // make sure all heads of kv in a group share the same d
                    qkv_h[dst_offset + 1 * d]
                        = qkv_h[src_offset + 1 * d]; // qkv[sbi, hi, 1, di] = qkv[sbi, hi_kv_scatter, 1, di]
                    qkv_h[dst_offset + 2 * d]
                        = qkv_h[src_offset + 2 * d]; // qkv[sbi, hi, 2, di] = qkv[sbi, hi_kv_scatter, 2, di]
                }
            }
        }
    }

    // WAR FOR MISSING CUBLAS FP8 NN SUPPORT.
    // Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V.
    float* vt_h = (float*) malloc(o_size * sizeof(float));
    void* vt_d = nullptr;
    FMHA_CHECK_CUDA(cudaMalloc(&vt_d, o_size_in_bytes));
    for (size_t it = 0; it < o_size; it++)
    {
        // vt is B x H x D x S
        size_t si = it % s;
        size_t di = (it / s) % dv;
        size_t hi = ((it / s) / dv) % h;
        size_t bi = (((it / s) / dv) / h) % b;
        // qkv is S x B x H x (2D + Dv)
        size_t qkv_idx = si * b * h * (2 * d + dv) + bi * h * (2 * d + dv) + hi * (2 * d + dv) + 2 * d // index V here
            + di;
        vt_h[it] = qkv_h[qkv_idx];
    }
    FMHA_CHECK_CUDA(cuda_memcpy_h2d(vt_d, vt_h, o_size, data_type));

    // // DEBUG.
    // float sum = 0.f;
    // for( size_t si = 0; si < s; ++si ) {
    //     float v = qkv_h[si*b*h*3*d + 2*d];
    //     printf("V[%3d]=%8.3f\n", si, v);
    //     sum += v;
    // }
    // printf("Sum of V = %8.3f\n", sum);
    // // END OF DEBUG.

    // Copy from the host to the device.
    FMHA_CHECK_CUDA(cuda_memcpy_h2d(qkv_sbh3d_d, qkv_h, qkv_size, data_type));

    // Create the buffer of mask.
    // if(verbose) {printf("Init .........: mask\n"); }
    // random_init_with_zeroes_or_ones(mask_h, b*s, false, 1.f - dropout, verbose);

    std::vector<uint32_t> seqlens(b, 0); // randomly draw a batch of sequence lengths >= min_s
    std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(),
        [=](const uint32_t)
        {
            if (fix_s)
            {
                return std::min(uint32_t(s), min_s);
            }
            if (s == min_s)
            {
                return min_s;
            }
            uint32_t s_ = s - min_s + 1;
            uint32_t ret = min_s + (rand() % s_);
            assert(ret <= s);
            return ret;
        });

    // Compute the prefix sum of the sequence lengths.
    std::vector<int> cu_seqlens(b + 1, 0);
    for (int it = 0; it < b; it++)
    {
        cu_seqlens[it + 1] = cu_seqlens[it] + seqlens[it];
    }
    int total = cu_seqlens.back();
    seqlens.emplace_back(total);
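    // (Example: with b = 2 and drawn seqlens {3, 5}, cu_seqlens becomes {0, 3, 8} and total = 8;
    // the total is appended to seqlens, presumably so downstream packing code can read it from the
    // same vector.)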
|
|
|
|
// Different q and kv sequence lengths.
|
|
std::vector<uint32_t> q_seqlens = seqlens;
|
|
std::vector<int> cu_q_seqlens = cu_seqlens;
|
|
if (different_q_kv_lengths)
|
|
{
|
|
for (int it = 0; it < b; it++)
|
|
{
|
|
q_seqlens[it] = s_q;
|
|
cu_q_seqlens[it + 1] = cu_q_seqlens[it] + q_seqlens[it];
|
|
}
|
|
}
|
|
|
|
// Compute the prefix sum of the mask sequence lengths.
|
|
std::vector<int> cu_mask_rows(b + 1, 0);
|
|
// The mask_h row offset in each sequence to support s_q < s_kv.
|
|
// we only need the last s_q rows in the [s, s] mask_h.
|
|
std::vector<int> mask_h_row_offsets(b);
|
|
for (int it = 0; it < b; it++)
|
|
{
|
|
// The actual q sequence length.
|
|
int actual_q_seqlen = q_seqlens[it];
|
|
// The mask_h row offset.
|
|
mask_h_row_offsets[it] = seqlens[it] - q_seqlens[it];
|
|
// Round up the sequence length to multiple of 128.
|
|
int mask_seqlen = align_to(actual_q_seqlen, fmha::FLASH_ATTEN_MASK_M_ALIGNMENT);
|
|
cu_mask_rows[it + 1] = cu_mask_rows[it] + mask_seqlen;
|
|
}
|
|
|
|
// transfer to device
|
|
void *cu_seqlens_d, *cu_q_seqlens_d, *cu_mask_rows_d;
|
|
FMHA_CHECK_CUDA(cudaMalloc(&cu_seqlens_d, sizeof(int) * cu_seqlens.size()));
|
|
FMHA_CHECK_CUDA(cudaMalloc(&cu_q_seqlens_d, sizeof(int) * cu_q_seqlens.size()));
|
|
FMHA_CHECK_CUDA(cudaMalloc(&cu_mask_rows_d, sizeof(int) * cu_mask_rows.size()));
|
|
FMHA_CHECK_CUDA(
|
|
cudaMemcpy(cu_seqlens_d, cu_seqlens.data(), sizeof(int) * cu_seqlens.size(), cudaMemcpyHostToDevice));
|
|
FMHA_CHECK_CUDA(
|
|
cudaMemcpy(cu_q_seqlens_d, cu_q_seqlens.data(), sizeof(int) * cu_q_seqlens.size(), cudaMemcpyHostToDevice));
|
|
FMHA_CHECK_CUDA(
|
|
cudaMemcpy(cu_mask_rows_d, cu_mask_rows.data(), sizeof(int) * cu_mask_rows.size(), cudaMemcpyHostToDevice));

    size_t qkv_packed_size = cu_seqlens.back() * h * (2 * d + dv);
    size_t qkv_packed_size_in_bytes = get_size_in_bytes(qkv_packed_size, data_type);
    void* qkv_packed_d = nullptr;
    FMHA_CHECK_CUDA(cudaMalloc(&qkv_packed_d, qkv_packed_size_in_bytes));

    // Specify device buffers for multi-query attention or grouped-query attention
    // TODO: Use the same buffer for all cases, and allow to set name to aid tracing/debugging
    // e.g.,
    // Buffer<float> qkv_buf(size);
    // if( packed ) { qkv_buf.set_name("QKV_packed[total, h, 3, d]"); }
    // else { qkv_buf.set_name("QKV_padded[b, s, h, 3, d]"); }
    // qkv_buf.copy_to_device();
    // float *qkv_buf_d = qkv_buf.get_device_buf();
    // Or, more aggressively, use torch::Tensor from PyTorch ATen
    size_t mqa_qkv_packed_size = cu_seqlens.back() * (h + 2 * h_kv) * d;
    size_t mqa_qkv_packed_size_in_bytes = get_size_in_bytes(mqa_qkv_packed_size, data_type);
    size_t mqa_qkv_size = b * s * (h + 2 * h_kv) * d; // original padded tensor
    size_t mqa_qkv_size_in_bytes = get_size_in_bytes(mqa_qkv_size, data_type);
    void* mqa_qkv_packed_d = nullptr;
    void* mqa_qkv_d = nullptr;
    FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_packed_d, mqa_qkv_packed_size_in_bytes));
    FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_d, mqa_qkv_size_in_bytes));

    const size_t o_packed_size = cu_seqlens.back() * h * dv;
    // Allocate on the host.
    float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float));
    void* o_packed_d = nullptr;

    size_t o_packed_size_in_bytes = get_size_in_bytes(o_packed_size, output_dtype);
    FMHA_CHECK_CUDA(cudaMalloc(&o_packed_d, o_packed_size_in_bytes));

    // qkv_packed_h is TotalH3D
    std::vector<float> qkv_packed_h(qkv_packed_size);
    extract_and_transpose_input<float>(qkv_packed_h.data(), qkv_h, seqlens, s, b, h, d, dv, 3, false);
    if (interleaved)
    {
        x_vec32(true, qkv_packed_h.data(), h, total, 3);
    }

    // qkv_h is SBH3D
    // qkv_bsh3d_h is BSH3D
    std::vector<float> qkv_bsh3d_h(qkv_size);
    extract_and_transpose_input<float>(qkv_bsh3d_h.data(), qkv_h, seqlens, s, b, h, d, dv, 3, is_s_padded);
    if (interleaved)
    {
        x_vec32(true, qkv_bsh3d_h.data(), h, b * h, 3);
    }

    std::vector<float> mqa_qkv_packed_h(mqa_qkv_packed_size);
    std::vector<float> mqa_qkv_h(mqa_qkv_size);
    // For now MLA doesn't use MQA; it may be enabled in the future.
    if (d == dv)
    {
        // From qkv[s, h, 3, d] to mqa_qkv[s, h + 2*h_kv, d]
        // where
        // Q is qkv[s, h, 0, d],
        // K is qkv[s, h, 1, d],
        // V is qkv[s, h, 2, d]
        // and
        // MQA_Q is mqa_qkv[s, h, [ 0 : h - 1], d],
        // MQA_K is mqa_qkv[s, h, [ h : h + h_kv - 1], d],
        // MQA_V is mqa_qkv[s, h, [h + h_kv : h + 2*h_kv - 1], d]
        for (size_t si = 0; si < cu_seqlens.back(); si++)
        {
            for (size_t hi = 0; hi < h; hi++)
            {
                for (size_t di = 0; di < d; di++)
                {
                    // Q: [si, hi, di] <- [si, hi, 0, di]
                    mqa_qkv_packed_h[si * (h + 2 * h_kv) * d + hi * d + di]
                        = qkv_packed_h[si * h * 3 * d + hi * 3 * d + 0 * d + di];
                    if (hi < h_kv)
                    {
                        // E.g., h=8, h_kv=4
                        // src kv id:     0, 0, 1, 1, 2, 2, 3, 3
                        // hi:            0, 1, 2, 3, 4, 5, 6, 7
                        // hi_kv_scatter: 0, 2, 4, 6, x, x, x, x
                        int const h_per_group = h / h_kv;
                        int const hi_kv_scatter = hi * h_per_group;
                        // K: [si, h + hi, di] <- [si, hi_kv_scatter, 1, di]
                        mqa_qkv_packed_h[si * (h + 2 * h_kv) * d + (h + hi) * d + di]
                            = qkv_packed_h[si * 3 * h * d + hi_kv_scatter * 3 * d + 1 * d + di];
                        // V: [si, h + h_kv + hi, di] <- [si, hi_kv_scatter, 2, di]
                        mqa_qkv_packed_h[si * (h + 2 * h_kv) * d + (h + h_kv + hi) * d + di]
                            = qkv_packed_h[si * 3 * h * d + hi_kv_scatter * 3 * d + 2 * d + di];
                    }
                }
            }
        }

        // From qkv_bsh3d_h[b, s, h, 3, d] to mqa_qkv[b, s, h + 2*h_kv, d].
        for (size_t bi = 0; bi < b; bi++)
        {
            int actual_s = seqlens[bi];
            for (size_t si = 0; si < actual_s; si++)
            {
                for (size_t hi = 0; hi < h; hi++)
                {
                    for (size_t di = 0; di < d; di++)
                    {
                        mqa_qkv_h[bi * s * (h + 2 * h_kv) * d + si * (h + 2 * h_kv) * d + hi * d + di]
                            = qkv_bsh3d_h[bi * s * h * 3 * d + si * h * 3 * d + hi * 3 * d + 0 * d + di];
                        if (hi < h_kv)
                        {
                            // E.g., h=8, h_kv=4
                            // src kv id:     0, 0, 1, 1, 2, 2, 3, 3
                            // hi:            0, 1, 2, 3, 4, 5, 6, 7
                            // hi_kv_scatter: 0, 2, 4, 6, x, x, x, x
                            int const h_per_group = h / h_kv;
                            int const hi_kv_scatter = hi * h_per_group;
                            mqa_qkv_h[bi * s * (h + 2 * h_kv) * d + si * (h + 2 * h_kv) * d + (h + hi) * d + di]
                                = qkv_bsh3d_h[bi * s * h * 3 * d + si * h * 3 * d + hi_kv_scatter * 3 * d + 1 * d + di];
                            mqa_qkv_h[bi * s * (h + 2 * h_kv) * d + si * (h + 2 * h_kv) * d + (h + h_kv + hi) * d + di]
                                = qkv_bsh3d_h[bi * s * h * 3 * d + si * h * 3 * d + hi_kv_scatter * 3 * d + 2 * d + di];
                        }
                    }
                }
            }
        }
    }
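    // When d == dv, the remaps above produce MQA/GQA tensors with layout [tokens, h + 2 * h_kv, d]:
    // the h query heads first, then the h_kv key heads, then the h_kv value heads.
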
    // if( verbose ) {
    //     print_tensor(qkv_packed_h.data() + 0 * d, d, total * h, 3 * d, "Packed Q[bs, h, d]");
    //     print_tensor(qkv_packed_h.data() + 1 * d, d, total * h, 3 * d, "Packed K[bs, h, d]");
    //     print_tensor(qkv_packed_h.data() + 2 * d, d, total * h, 3 * d, "Packed V[bs, h, d]");

    //     print_tensor(mqa_qkv_packed_h.data() + 0 * d, h * d, total, (h + 2 * h_kv) * d, "Packed MQA Q[bs, h*d]");
    //     print_tensor(mqa_qkv_packed_h.data() + h * d, h_kv * d, total, (h + 2 * h_kv) * d, "Packed MQA K[bs, h_kv*d]");
    //     print_tensor(mqa_qkv_packed_h.data() + h * d + h_kv * d, h_kv * d, total, (h + 2 * h_kv) * d, "Packed MQA V[bs, h_kv*d]");

    //     print_tensor(qkv_bsh3d_h.data() + 0 * d, d, b * h * s, 3 * d, "Padded Q[b, s, h, d]");
    //     print_tensor(qkv_bsh3d_h.data() + 1 * d, d, b * h * s, 3 * d, "Padded K[b, s, h, d]");
    //     print_tensor(qkv_bsh3d_h.data() + 2 * d, d, b * h * s, 3 * d, "Padded V[b, s, h, d]");

    //     print_tensor(mqa_qkv_h.data() + 0 * d, h * d, b * s, (h + 2 * h_kv) * d, "Padded MQA Q[b, s, h*d]");
    //     print_tensor(mqa_qkv_h.data() + h * d, h_kv * d, b * s, (h + 2 * h_kv) * d, "Padded MQA K[b, s, h_kv*d]");
    //     print_tensor(mqa_qkv_h.data() + h * d + h_kv * d, h_kv * d, b * s, (h + 2 * h_kv) * d, "Padded MQA V[b, s, h_kv*d]");
    // }

    // Contiguous KV Cache and Separate KV Cache.
    store_q_and_contiguous_kv_cache(q_d, k_d, v_d, contiguous_kv_h, contiguous_kv_d,
        reinterpret_cast<float const*>(qkv_packed_h.data()), reinterpret_cast<int const*>(cu_seqlens.data()),
        reinterpret_cast<int const*>(cu_q_seqlens.data()), b, s, h, h_kv, d, dv, data_type);

    // Paged KV Cache.
    store_paged_kv_cache(kv_cache_ptrs_h, reinterpret_cast<float const*>(qkv_packed_h.data()),
        reinterpret_cast<int const*>(cu_seqlens.data()), max_blocks_per_seq, tokens_per_block, b, h, h_kv, d, dv,
        data_type);

    // Copy the packed, padded, MQA packed and MQA padded data buffers to the device.
    // TODO: use the same buffer for all cases.
    FMHA_CHECK_CUDA(cuda_memcpy_h2d(qkv_packed_d, qkv_packed_h.data(), qkv_packed_size, data_type));
    FMHA_CHECK_CUDA(cuda_memcpy_h2d(mqa_qkv_packed_d, mqa_qkv_packed_h.data(), mqa_qkv_packed_size, data_type));
    FMHA_CHECK_CUDA(cuda_memcpy_h2d(mqa_qkv_d, mqa_qkv_h.data(), mqa_qkv_size, data_type));
    FMHA_CHECK_CUDA(cuda_memcpy_h2d(qkv_bsh3d_d, qkv_bsh3d_h.data(), qkv_size, data_type));

    // Is MTP used?
    bool is_mtp = (d == 576 && dv == 512);

    for (size_t so = 0; so < s; ++so)
    { // s_q
        for (size_t bi = 0; bi < b; ++bi)
        {
            int actual_seqlen = seqlens[bi];
            for (size_t si = 0; si < s; ++si)
            { // s_kv
                // Are both the query and the key inside the sequence?
                bool valid = (si < actual_seqlen) && (so < actual_seqlen);
                // FIXME: add a random mask generator when attention_mask_type == Attention_mask_type::CUSTOM_MASK.
                if (attention_mask_type == Attention_mask_type::CUSTOM_MASK
                    || attention_mask_type == Attention_mask_type::CAUSAL
                    || attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL)
                {
                    valid = valid && (so >= si);
                }
                if (attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL)
                {
                    if (chunked_attention_size > 0)
                    {
                        int chunk_idx = so / chunked_attention_size;
                        valid = valid && (si >= (chunk_idx * chunked_attention_size));
                    }
                    else
                    {
                        // Use signed arithmetic to avoid unsigned wrap-around when so + 1 < sliding_window_size.
                        valid = valid && (si >= std::max(int(so) + 1 - int(sliding_window_size), 0));
                    }
                }
                if (is_mtp)
                {
                    // Only the last s_q tokens are used for verifying the results.
                    // Use signed arithmetic so that idx >= 0 actually filters out earlier tokens.
                    int idx = int(so) - (actual_seqlen - int(s_q));
                    int num_mtp_tokens = int(s_q) / int(num_grouped_heads);
                    int mtp_token_idx = idx / int(num_grouped_heads);
                    valid = idx >= 0 && int(si) < (actual_seqlen - num_mtp_tokens + 1 + mtp_token_idx)
                        && (int(so) < actual_seqlen);
                }
                if (!skip_checks)
                {
                    // The mask is stored as floats.
                    mask_h[so * b * s + bi * s + si] = valid ? 1.f : 0.f; // mask dims [s_q, b, s_kv]
                }
            }
        }
    }
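    // Illustration of the sliding/chunked predicates above: with sliding_window_size = 3 (and no chunking),
    // query row so = 5 attends to keys si in [3, 5]; with chunked_attention_size = 4, it attends to si in [4, 5].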

    if (verbose)
    {
        printf("Sequence lengths (first 10 batches): ");
        for (int bi = 0; bi < seqlens.size() && bi < 10; bi++)
        {
            printf("%d, ", seqlens[bi]);
        }
        printf("\n");
    }

    if (v1)
    {
        assert(!interleaved && "Interleaved not supported in v1");
        assert(mmas_n <= 4 && "Not supported");

        FMHA_CHECK_CUDA(cudaMalloc(&packed_mask_d, packed_mask_size_in_bytes));
        if (sm == 70)
        {
            pack_mask_sm70(packed_mask_h, mask_h, s, b, mmas_m, mmas_n, warps_m, warps_n, threads_per_cta);
        }
        else
        {
            pack_mask(packed_mask_h, mask_h, s, b, mmas_m, mmas_n, warps_m, warps_n, threads_per_cta);
        }

        // Copy the packed mask to the device.
        if (!skip_checks)
        {
            FMHA_CHECK_CUDA(
                cudaMemcpy(packed_mask_d, packed_mask_h, packed_mask_size_in_bytes, cudaMemcpyHostToDevice));
        }
    }
    else if (attention_mask_type == Attention_mask_type::CUSTOM_MASK)
    {
        FMHA_CHECK_CUDA(cudaMalloc(&packed_mask_d, packed_mask_size_in_bytes));
        assert(fmha::FLASH_ATTEN_MASK_MMA_M == warps_m * 16 && "Not supported");
        assert(fmha::FLASH_ATTEN_MASK_MMA_N / 8 == 8 && "Not supported");
        pack_flash_attention_mask(packed_mask_h, mask_h, b, s, warps_m, warps_n, threads_per_cta, mmas_n,
            fmha::FLASH_ATTEN_MASK_MMA_N / 8, mask_h_row_offsets.data(), cu_mask_rows.data());

        // Copy the packed mask to the device.
        FMHA_CHECK_CUDA(cudaMemcpy(packed_mask_d, packed_mask_h, packed_mask_size_in_bytes, cudaMemcpyHostToDevice));
    }

    // Copy the mask to the device.
    if (!skip_checks)
    {
        FMHA_CHECK_CUDA(cuda_memcpy_h2d(mask_d, mask_h, mask_size, DATA_TYPE_INT8));
    }

    // Non-owning views of the IO buffers.
    void* qkv_d_view = nullptr;
    void* o_d_view = nullptr;
    int o_view_size = 0;
    if (is_s_padded)
    {
        qkv_d_view = multi_query_attention ? mqa_qkv_d : qkv_bsh3d_d;
        o_d_view = o_d;
        o_view_size = o_size;
    }
    else
    {
        qkv_d_view = multi_query_attention ? mqa_qkv_packed_d : qkv_packed_d;
        o_d_view = o_packed_d;
        o_view_size = o_packed_size;
    }
    void* softmax_stats_ptr = save_softmax ? softmax_stats_d : nullptr;
    // Set the params.
    bert::Fused_multihead_attention_params_v1 params_v1;
    set_params(params_v1, data_type, acc_type, b, s, h, d, mmas_m * threads_per_cta, qkv_sbh3d_d, packed_mask_d, o_d,
        p_d, s_d, scale_bmm1, scale_softmax, scale_bmm2, has_alibi);

    bert::Fused_multihead_attention_params_v2 params_v2;
    set_params(params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, b, s_q, s, h, h_kv, d, dv,
        total, num_grouped_heads, sliding_window_size, chunked_attention_size,
        // Paged kv cache.
        tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d,
        packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d,
        softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
        use_int8_scale_max, interleaved, is_s_padded, has_alibi, skip_softmax_threshold_scale_factor);
#ifdef SKIP_SOFTMAX_STAT
    FMHA_CHECK_CUDA(cudaMalloc(&params_v2.skip_softmax_total_blocks, sizeof(uint32_t)));
    FMHA_CHECK_CUDA(cudaMalloc(&params_v2.skip_softmax_skipped_blocks, sizeof(uint32_t)));
    FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_total_blocks, 0, sizeof(uint32_t)));
    FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_skipped_blocks, 0, sizeof(uint32_t)));
#endif

    // The total number of tokens is needed to set the TMA descriptors on the host.
    launch_params.total_q_seqlen = q_seqlens[b];
    launch_params.total_kv_seqlen = seqlens[b];
    // Set enable_attn_logit_softcapping to select the right kernel.
    launch_params.enable_attn_logit_softcapping = softcapping_scale_bmm1 != 0.f;

    // Allocate barriers and locks.
    void* counters_d = nullptr;
    if (ctas_per_head > 1)
    {
        size_t sz = heads_per_wave * sizeof(int);
        FMHA_CHECK_CUDA(cudaMalloc((void**) &counters_d, 3 * sz));
    }
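    // counters_d packs three arrays of heads_per_wave ints each: max barriers, sum barriers and locks
    // (see the pointer arithmetic below).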

    // Allocate scratch storage for softmax.
    void *max_scratch_d = nullptr, *sum_scratch_d = nullptr;
    if (ctas_per_head > 1)
    {
        size_t sz = heads_per_wave * ctas_per_head * threads_per_cta * sizeof(float);
        FMHA_CHECK_CUDA(cudaMalloc((void**) &max_scratch_d, sz));
        FMHA_CHECK_CUDA(cudaMalloc((void**) &sum_scratch_d, sz));
    }

    // Allocate temporary storage for the parallel reduction.
    void* o_scratch_d = nullptr;
    if (ctas_per_head > 1 && data_type != DATA_TYPE_FP16)
    {
        size_t sz = heads_per_wave * threads_per_cta * MAX_STGS_PER_LOOP * sizeof(uint4);
        FMHA_CHECK_CUDA(cudaMalloc((void**) &o_scratch_d, sz));
    }

    // Allocate tile id for dynamic scheduling
    void* tile_id_counter_d = nullptr;
    FMHA_CHECK_CUDA(cudaMalloc((void**) &tile_id_counter_d, sizeof(uint32_t)));

    // The number of heads computed per wave.
    params_v1.heads_per_wave = heads_per_wave;
    params_v2.heads_per_wave = heads_per_wave;

    // Barriers for the global sync in the multi-CTA kernel(s).
    params_v1.counters = (int*) counters_d + 0 * heads_per_wave;
    params_v2.counters = (int*) counters_d + 0 * heads_per_wave;
    params_v1.max_barriers = (int*) counters_d + 0 * heads_per_wave;
    params_v2.max_barriers = (int*) counters_d + 0 * heads_per_wave;
    params_v1.sum_barriers = (int*) counters_d + 1 * heads_per_wave;
    params_v2.sum_barriers = (int*) counters_d + 1 * heads_per_wave;
    params_v1.locks = (int*) counters_d + 2 * heads_per_wave;
    params_v2.locks = (int*) counters_d + 2 * heads_per_wave;

    // Scratch storage for softmax.
    params_v1.max_scratch_ptr = (float*) max_scratch_d;
    params_v2.max_scratch_ptr = (float*) max_scratch_d;
    params_v1.sum_scratch_ptr = (float*) sum_scratch_d;
    params_v2.sum_scratch_ptr = (float*) sum_scratch_d;

    // Scratch storage for output.
    params_v1.o_scratch_ptr = (int*) o_scratch_d;
    params_v2.o_scratch_ptr = (int*) o_scratch_d;

    // Tile id counter for dynamic scheduling
    params_v2.tile_id_counter_ptr = (uint32_t*) tile_id_counter_d;
    // params_paged_v2.tile_id_counter_ptr = (uint32_t*) tile_id_counter_d;

    if (sage_block_size_q > 0 || sage_block_size_k > 0 || sage_block_size_v > 0)
    {
        assert(input_layout == Attention_input_layout::PACKED_QKV && "for now this test only supports PACKED_QKV");
        assert(d == dv && "for now SageAttention doesn't support different QKV dims");
        assert(((sm == 90 && !force_non_warp_specialization) || (sm == 89))
            && "only hopper and ada kernels support SageAttention");
        fmha::e4m3_t* quant_qkv;
        FMHA_CHECK_CUDA(cudaMalloc((void**) &quant_qkv, qkv_packed_size));
        params_v2.sage.q.block_size = sage_block_size_q;
        params_v2.sage.q.max_nblock = (s + sage_block_size_q - 1) / sage_block_size_q;
        FMHA_CHECK_CUDA(
            cudaMalloc((void**) &params_v2.sage.q.scales, params_v2.sage.q.max_nblock * h * b * sizeof(float)));
        params_v2.sage.k.block_size = sage_block_size_k;
        params_v2.sage.k.max_nblock = (s + sage_block_size_k - 1) / sage_block_size_k;
        FMHA_CHECK_CUDA(
            cudaMalloc((void**) &params_v2.sage.k.scales, params_v2.sage.k.max_nblock * h * b * sizeof(float)));
        params_v2.sage.v.block_size = sage_block_size_v;
        params_v2.sage.v.max_nblock = (s + sage_block_size_v - 1) / sage_block_size_v;
        FMHA_CHECK_CUDA(
            cudaMalloc((void**) &params_v2.sage.v.scales, params_v2.sage.v.max_nblock * h * b * sizeof(float)));
#if 1
        {
            // Simple test: all scales are the same.
            constexpr float const_scale = 0.618f;
            fmha::e4m3_t* quant_qkv_h = (fmha::e4m3_t*) malloc(qkv_packed_size);
            for (size_t i = 0; i < qkv_packed_size; i++)
            {
                quant_qkv_h[i] = fmha::e4m3_t(qkv_packed_h[i] / const_scale);
            }
            FMHA_CHECK_CUDA(cudaMemcpy(quant_qkv, quant_qkv_h, qkv_packed_size, cudaMemcpyHostToDevice));
            free(quant_qkv_h);
            auto init_scales = [&](bert::Fused_multihead_attention_params_v2::SageAttention::Scales& x)
            {
                std::vector<float> scales(x.max_nblock * h * b, const_scale);
                FMHA_CHECK_CUDA(
                    cudaMemcpy(x.scales, scales.data(), sizeof(float) * scales.size(), cudaMemcpyHostToDevice));
            };
            init_scales(params_v2.sage.q);
            init_scales(params_v2.sage.k);
            init_scales(params_v2.sage.v);
        }
#else
        {
            // Use the external quant kernel.
            run_sage_quant(b, h, d, s, params_v2.qkv_ptr,
                (char*) params_v2.qkv_ptr + get_size_in_bytes(h * d, data_type),
                (char*) params_v2.qkv_ptr + get_size_in_bytes(2 * h * d, data_type),
                params_v2.q_stride_in_bytes,
                params_v2.k_stride_in_bytes,
                params_v2.v_stride_in_bytes,
                params_v2.cu_q_seqlens, params_v2.cu_kv_seqlens, sage_block_size_q, sage_block_size_k,
                sage_block_size_v, quant_qkv, quant_qkv + h * d, quant_qkv + 2 * h * d, params_v2.sage.q.scales,
                params_v2.sage.k.scales, params_v2.sage.v.scales);
        }
#endif
        // No need to free the old params_v2.qkv_ptr; it is released at the end.
        params_v2.qkv_ptr = quant_qkv;
        params_v2.q_stride_in_bytes = params_v2.k_stride_in_bytes = params_v2.v_stride_in_bytes
            = get_size_in_bytes((h + 2 * h_kv) * d, DATA_TYPE_E4M3);
    }
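    // With the constant-scale path above (#if 1), dequantizing with scale == const_scale reproduces the original
    // Q/K/V values up to e4m3 rounding, so the unquantized reference computation below still applies.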

#if defined(DEBUG_HAS_PRINT_BUFFER)
    auto& params = params_v2;
    constexpr size_t bytes = 32 * 1024;
    void* print_ptr = nullptr;
    FMHA_CHECK_CUDA(cudaMalloc(&params.print_ptr, bytes));
    std::vector<float> print_buffer(bytes / sizeof(float));
#endif
    // Run a few warm-up kernels.
    for (int ii = 0; ii < warm_up_runs; ++ii)
    {
        if (v1)
        {
            run_fmha_v1(params_v1, launch_params, data_type, output_dtype, sm, 0);
        }
        else
        {
            run_fmha_v2(params_v2, launch_params, data_type, output_dtype, sm, 0);
        }
    }
    FMHA_CHECK_CUDA(cudaPeekAtLastError());

    float non_fused_elapsed = INFINITY;
    if (!skip_checks)
    {
        // Run cuBLAS.

        RefBMM bmm1(data_type_to_cuda(data_type), // a
            data_type_to_cuda(data_type),         // b
            data_type_to_cuda(acc_type),          // d
            data_type_to_cublas(acc_type),        // compute
            data_type_to_cuda(acc_type),          // scale
            false,                                // Q
            true,                                 // K'
            s,                                    // m
            s,                                    // n
            d,                                    // k
            b * h * (2 * d + dv),                 // ld Q
            b * h * (2 * d + dv),                 // ld K
            b * h * s,                            // ld P
            (2 * d + dv),                         // stride Q
            (2 * d + dv),                         // stride K
            s,                                    // stride P
            b * h                                 // batch count
        );

        /*
        RefBMM bmm2(data_type_to_cuda(data_type), // a
            data_type_to_cuda(data_type),         // b
            data_type_to_cuda(acc_type),          // d
            data_type_to_cublas(acc_type),        // compute
            data_type_to_cuda(acc_type),          // scale
            false,                                // S
            false,                                // V
            s,                                    // m
            d,                                    // n
            s,                                    // k
            b * h * s,                            // ld S
            b * h * 3 * d,                        // ld V
            b * h * d,                            // ld O
            s,                                    // stride S
            3 * d,                                // stride V
            d,                                    // stride O
            b * h                                 // batch count
        );
        */

        // WAR FOR MISSING CUBLAS FP8 NN SUPPORT.
        // Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V.
        RefBMM bmm2(data_type_to_cuda(data_type), // a
            data_type_to_cuda(data_type),         // b
            data_type_to_cuda(acc_type),          // d
            data_type_to_cublas(acc_type),        // compute
            data_type_to_cuda(acc_type),          // scale
            false,                                // S
            true,                                 // V'
            s,                                    // m
            dv,                                   // n
            s,                                    // k
            b * h * s,                            // ld S
            s,                                    // ld V
            b * h * dv,                           // ld O
            s,                                    // stride S
            s * dv,                               // stride V
            dv,                                   // stride O
            b * h                                 // batch count
        );
        timer.start();
        ground_truth(bmm1, bmm2, data_type, acc_type, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
            qkv_sbh3d_d,
            vt_d, // WAR: pass in V'.
            mask_d, attention_sinks_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs,
            warps_m, warps_n, has_alibi);
        timer.stop();
        FMHA_CHECK_CUDA(cudaPeekAtLastError());
        FMHA_CHECK_CUDA(cudaDeviceSynchronize());
        non_fused_elapsed = timer.millis();

#if defined(STORE_P)
        FMHA_CHECK_CUDA(cuda_memcpy_d2h(p_ref_h, p_d, p_size, acc_type));
#endif // defined(STORE_P)

#if defined(STORE_S)
        FMHA_CHECK_CUDA(cuda_memcpy_d2h(s_ref_h, s_d, p_size, data_type));
#endif // defined(STORE_S)

        // Read the results.
        FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_ref_h, o_d, o_size, data_type));
        FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_ref_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
    }

    // Fill-in p/s/o with garbage data.
    // WAR: if the sequence is padded, we zero-fill the output buffer because the kernel will not write to the
    // padded area, while the host check expects to read it.
    if (!skip_checks)
    {
        FMHA_CHECK_CUDA(cudaMemset(p_d, 0xdc, p_size_in_bytes));
        FMHA_CHECK_CUDA(cudaMemset(s_d, 0xdc, s_size_in_bytes));
    }
    FMHA_CHECK_CUDA(cudaMemset(o_d, 0x00, o_size_in_bytes));
    FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * b * s * h * sizeof(float)));

    // Run the kernel.
    timer.start();
    for (int ii = 0; ii < runs; ++ii)
    {
        if (v1)
        {
            run_fmha_v1(params_v1, launch_params, data_type, output_dtype, sm, 0);
        }
        else
        {
            run_fmha_v2(params_v2, launch_params, data_type, output_dtype, sm, 0);
        }
    }
    timer.stop();
    FMHA_CHECK_CUDA(cudaPeekAtLastError());

    FMHA_CHECK_CUDA(cudaDeviceSynchronize());
    float fused_elapsed = timer.millis();

#if defined(STORE_P)
    FMHA_CHECK_CUDA(cuda_memcpy_d2h(p_h, p_d, p_size, acc_type));
    printf("\nChecking .....: P = norm * K^T * Q\n");

    // DEBUG.
    printf("seqlens[0]=%d\n", seqlens[0]);
    // END OF DEBUG.

    // Clear the invalid region of P.
    set_mat<float>(p_ref_h, seqlens, s, b, h, s, 0.f, true);
    set_mat<float>(p_h, seqlens, s, b, h, s, 0.f, true);

    // Do the check.
    check_results(p_h, p_ref_h, s, s * b * h, s, 0.f, true, true);
#endif // defined(STORE_P)

#if defined(STORE_S)
    FMHA_CHECK_CUDA(cuda_memcpy_d2h(s_h, s_d, p_size, data_type));
    printf("\nChecking .....: S = softmax(P)\n");
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
    float softmax_epsilon = data_type == DATA_TYPE_FP16 ? 1e-3f : 0.f;
#else
    float softmax_epsilon = 1.e-3f;
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)

    // Clear the invalid region of S.
    set_mat<float>(s_ref_h, seqlens, s, b, h, s, 0.f);
    set_mat<float>(s_h, seqlens, s, b, h, s, 0.f);

    // Do the check.
    check_results(s_h, s_ref_h, s, s * b * h, s, softmax_epsilon, true, true);
#endif // defined(STORE_S)

    // Check the final results.
    int status = -1;
    if (skip_checks)
    {
        status = 0;
        printf("\n");
        print_results(true, false);
    }
    else
    {
        if (v1)
        {
            FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_h, o_d, o_size, output_dtype));
            status = check_results(o_h, o_ref_h, d, s * b * h, d, epsilon, verbose, true);
        }
        else
        {
            std::vector<float> o_ref_trans_h(o_size);

            FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_h, o_d_view, o_view_size, output_dtype));
            FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));

            if (interleaved)
            {
                // revert batch-interleaved format: 3 x h/32 x total x d x 32 => total x h x 3 x d
                x_vec32(false, o_h, h, is_s_padded ? b * h : total, 1);
            }

            // Extract the last s_q tokens from the output.
            extract_and_transpose_output<float>(
                o_ref_trans_h.data(), o_ref_h, seqlens, q_seqlens, s, s_q, b, h, dv, is_s_padded);
            if (verbose)
            {
                printf("\nChecking .....: O = V * S\n");
            }
            status = check_results(o_h, o_ref_trans_h.data(), dv, is_s_padded ? s_q * b * h : cu_q_seqlens.back() * h,
                dv, epsilon, verbose, true);
            if (save_softmax)
            {
                auto errors = check_softmax_results(softmax_stats_h, softmax_stats_ref_h, b, s, h, seqlens, cu_seqlens);
                status = status | ((errors.first + errors.second) > 0);
            }
        }
        if (status != 0)
        { // if there was an error, print the config of the run
            printf("v1=%d il=%d s=%lu b=%lu h=%lu dv=%lu dtype=%s\n", v1, interleaved, s, b, h, dv,
                data_type_to_name(data_type).c_str());
        }
        if (!verbose)
        { // this just prints the SUCCESS/ERROR line
            print_results(true, true, status == 0);
        }
    }

    // Accounts for tensor core flops only; excludes flops spent in softmax.
    size_t total_flops = 0;
    // Remove the last entry (the total sequence length).
    seqlens.pop_back();
    for (auto& s_ : seqlens)
    {
        size_t s_size = size_t(s_);
        total_flops += 2ull * h * (s_q * s_size * d + s_q * dv * s_size); // 1st BMM + 2nd BMM
    }
    total_flops = attention_mask_type == Attention_mask_type::CAUSAL ? total_flops / 2 : total_flops;
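    // Per sequence and per head: BMM1 is a (s_q x d) * (d x s_kv) GEMM (2 * s_q * s_kv * d flops) and BMM2 is a
    // (s_q x s_kv) * (s_kv x dv) GEMM (2 * s_q * s_kv * dv flops); causal masking keeps roughly half of the score
    // matrix, hence the division by 2.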

    size_t total_bytes = o_packed_size_in_bytes + qkv_packed_size_in_bytes;
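    // Note for the prints below: fused_elapsed is in milliseconds, so flops / (ms * 1e9) yields Tflop/s and
    // bytes / (ms * 1e6) yields GB/s.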
    if (verbose)
    {
        // Runtimes.
        printf("\n");
        if (!skip_checks)
        {
            printf("Non-fused time: %.6f ms\n", non_fused_elapsed / float(runs));
        }
        printf("Fused time ...: %.6f us\n", fused_elapsed * 1000 / float(runs));
        printf("Tensor core ..: %.2f Tflop/s\n", total_flops / (fused_elapsed / float(runs) / 1e-9));
        printf("Bandwidth ....: %.2f GB/s\n", total_bytes / (fused_elapsed / float(runs) / 1e-6));
        if (!skip_checks)
        {
            printf("Ratio ........: %.2fx\n", non_fused_elapsed / fused_elapsed);
        }
    }
    else
    {
        printf("Elapsed ......: %.6f us (%.2fx), %.2f Tflop/s, %.2f GB/s\n", fused_elapsed * 1000 / float(runs),
            non_fused_elapsed / fused_elapsed, total_flops / (fused_elapsed / float(runs) / 1e-9),
            total_bytes / (fused_elapsed / float(runs) / 1e-6));
    }
#ifdef SKIP_SOFTMAX_STAT
    if (skip_softmax_threshold_scale_factor > 0)
    {
        uint32_t total_blocks, skipped_blocks;
        FMHA_CHECK_CUDA(
            cudaMemcpy(&total_blocks, params_v2.skip_softmax_total_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
        FMHA_CHECK_CUDA(cudaMemcpy(
            &skipped_blocks, params_v2.skip_softmax_skipped_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
        printf("Skip-Softmax .: %u / %u = %.2f%%\n", skipped_blocks, total_blocks,
            total_blocks ? 100.f * skipped_blocks / total_blocks : 0.f);
    }
#endif
#if defined(DEBUG_HAS_PRINT_BUFFER)
    FMHA_CHECK_CUDA(cuda_memcpy_d2h(print_buffer.data(), params.print_ptr, print_buffer.size(), DATA_TYPE_FP32));

    printf("\n====================\n");
    for (int it = 0; it < 16; it++)
    {
        printf("% .4f ", print_buffer[it]);
    }
    printf("\n====================\n");

    FMHA_CHECK_CUDA(cudaFree(params.print_ptr));

#endif
    // Release memory.
    FMHA_CHECK_CUDA(cudaFree(qkv_sbh3d_d));
    FMHA_CHECK_CUDA(cudaFree(qkv_packed_d));
    FMHA_CHECK_CUDA(cudaFree(scale_bmm2_d));
    FMHA_CHECK_CUDA(cudaFree(mqa_qkv_d));
    FMHA_CHECK_CUDA(cudaFree(mqa_qkv_packed_d));
    FMHA_CHECK_CUDA(cudaFree(qkv_bsh3d_d));
    FMHA_CHECK_CUDA(cudaFree(mask_d));
    FMHA_CHECK_CUDA(cudaFree(packed_mask_d));
    FMHA_CHECK_CUDA(cudaFree(q_d));
    FMHA_CHECK_CUDA(cudaFree(k_d));
    FMHA_CHECK_CUDA(cudaFree(v_d));
    FMHA_CHECK_CUDA(cudaFree(p_d));
    FMHA_CHECK_CUDA(cudaFree(s_d));
    FMHA_CHECK_CUDA(cudaFree(o_d));
    FMHA_CHECK_CUDA(cudaFree(o_packed_d));
    FMHA_CHECK_CUDA(cudaFree(tmp_d));
    FMHA_CHECK_CUDA(cudaFree(cu_seqlens_d));
    FMHA_CHECK_CUDA(cudaFree(cu_q_seqlens_d));
    FMHA_CHECK_CUDA(cudaFree(cu_mask_rows_d));
    FMHA_CHECK_CUDA(cudaFree(max_scratch_d));
    FMHA_CHECK_CUDA(cudaFree(sum_scratch_d));
    FMHA_CHECK_CUDA(cudaFree(o_scratch_d));
    FMHA_CHECK_CUDA(cudaFree(counters_d));
    FMHA_CHECK_CUDA(cudaFree(tile_id_counter_d));
    FMHA_CHECK_CUDA(cudaFree(kv_cache_pool_ptr));
    FMHA_CHECK_CUDA(cudaFree(kv_cache_block_offsets_d));
    FMHA_CHECK_CUDA(cudaFree(contiguous_kv_d));
    FMHA_CHECK_CUDA(cudaFree(softmax_stats_d));
    FMHA_CHECK_CUDA(cudaFree(attention_sinks_d));
#ifdef SKIP_SOFTMAX_STAT
    FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_total_blocks));
    FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_skipped_blocks));
#endif

    free(qkv_h);
    free(mask_h);
    free(packed_mask_h);
    free(s_h);
    free(o_h);
    free(o_ref_h);
    free(o_packed_h);
    free(softmax_stats_h);
    free(softmax_stats_ref_h);
    free(contiguous_kv_h);
    free(kv_cache_ptrs_h);
    free(kv_cache_block_offsets_h);

    free(p_ref_h);
#if defined(STORE_P)
    free(p_h);
#endif // defined(STORE_P)
    free(s_ref_h);

    return status;
}

////////////////////////////////////////////////////////////////////////////////////////////////////