[None] [feat] Add model gpt-oss (#6645)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
hlu1 2025-08-07 00:04:18 -07:00 committed by GitHub
parent 6c1f7d8b91
commit 8207d5fd39
2102 changed files with 33998 additions and 8186 deletions

View File

@ -122,6 +122,16 @@ public:
return QuantMode(BaseType(1u) << 14);
}
static constexpr QuantMode w4a8Mxfp4Mxfp8() noexcept
{
return QuantMode(BaseType(1u) << 15);
}
static constexpr QuantMode w4a16Mxfp4() noexcept
{
return QuantMode(BaseType(1u) << 16);
}
constexpr BaseType value() const noexcept
{
return mValue;
@ -202,6 +212,16 @@ public:
return isSet(w4a8Mxfp4Fp8());
}
constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept
{
return isSet(w4a8Mxfp4Mxfp8());
}
constexpr bool hasW4a16Mxfp4() const noexcept
{
return isSet(w4a16Mxfp4());
}
constexpr bool hasKvCacheQuant() const noexcept
{
return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache();
@ -209,7 +229,8 @@ public:
static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken,
bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq,
bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8)
bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8,
bool useW4a8Mxfp4Mxfp8, bool useW4a16Mxfp4)
{
QuantMode quantMode{};
if (quantizeWeights)
@ -278,25 +299,35 @@ public:
quantMode += w4a8Mxfp4Fp8();
}
if (useW4a8Mxfp4Mxfp8)
{
quantMode += w4a8Mxfp4Mxfp8();
}
if (useW4a16Mxfp4)
{
quantMode += w4a16Mxfp4();
}
return quantMode;
}
static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
{
return fromDescription(
true, true, perToken, perChannel, false, false, false, false, false, false, false, false, false, false);
return fromDescription(true, true, perToken, perChannel, false, false, false, false, false, false, false, false,
false, false, false, false);
}
static constexpr QuantMode useQServe(bool perGroup)
{
return fromDescription(
true, true, false, false, perGroup, true, false, false, false, false, true, false, false, false);
return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true, false, false,
false, false, false);
}
static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
{
return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false,
false, false, false);
false, false, false, false, false);
}
static QuantMode const fromQuantAlgo(
@ -353,28 +384,38 @@ public:
}
else if (quantAlgo == "FP8")
{
quantMode = fromDescription(
false, false, false, false, false, false, false, false, true, false, false, false, false, false);
quantMode = fromDescription(false, false, false, false, false, false, false, false, true, false, false,
false, false, false, false, false);
}
else if (quantAlgo == "FP8_ROWWISE")
{
quantMode = fromDescription(
false, false, true, true, false, false, false, false, false, true, false, false, false, false);
quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true, false, false,
false, false, false, false);
}
else if (quantAlgo == "FP4")
{
quantMode = fromDescription(
false, false, false, false, false, false, false, false, false, false, false, true, false, false);
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
true, false, false, false, false);
}
else if (quantAlgo == "FP8_BLOCK_SCALES")
{
quantMode = fromDescription(
false, false, false, false, false, false, false, false, false, false, false, false, true, false);
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
false, true, false, false, false);
}
else if (quantAlgo == "W4A8_MXFP4_FP8")
{
quantMode = fromDescription(
false, false, false, false, false, false, false, false, false, false, false, false, false, true);
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
false, false, true, false, false);
}
else if (quantAlgo == "W4A8_MXFP4_MXFP8")
{
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
false, false, false, true, false);
}
else if (quantAlgo == "W4A16_MXFP4")
{
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
false, false, false, false, true);
}
if (kvCacheQuantAlgo == "INT8")
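For reference, the two new modes follow the same one-bit-per-feature pattern as the existing QuantMode flags (bits 15 and 16). Below is a minimal, self-contained sketch of that pattern; QuantModeSketch and its members are illustrative stand-ins, not the tensorrt_llm API.

```cpp
#include <cstdint>
#include <iostream>

// Minimal stand-in for the bit-flag pattern used by QuantMode above.
class QuantModeSketch
{
public:
    using BaseType = std::uint32_t;

    // The new MXFP4 modes occupy bits 15 and 16, mirroring the diff.
    static constexpr QuantModeSketch w4a8Mxfp4Mxfp8() noexcept
    {
        return QuantModeSketch(BaseType(1u) << 15);
    }

    static constexpr QuantModeSketch w4a16Mxfp4() noexcept
    {
        return QuantModeSketch(BaseType(1u) << 16);
    }

    constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept
    {
        return isSet(w4a8Mxfp4Mxfp8());
    }

    constexpr bool hasW4a16Mxfp4() const noexcept
    {
        return isSet(w4a16Mxfp4());
    }

    constexpr QuantModeSketch& operator+=(QuantModeSketch other) noexcept
    {
        mValue |= other.mValue;
        return *this;
    }

    constexpr QuantModeSketch() noexcept = default;

private:
    constexpr explicit QuantModeSketch(BaseType value) noexcept
        : mValue(value)
    {
    }

    constexpr bool isSet(QuantModeSketch mode) const noexcept
    {
        return (mValue & mode.mValue) == mode.mValue;
    }

    BaseType mValue = 0;
};

int main()
{
    // Mirrors the "W4A16_MXFP4" branch of fromQuantAlgo: only that bit ends up set.
    QuantModeSketch mode{};
    mode += QuantModeSketch::w4a16Mxfp4();
    std::cout << mode.hasW4a16Mxfp4() << " " << mode.hasW4a8Mxfp4Mxfp8() << std::endl; // prints "1 0"
    return 0;
}
```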

View File

@ -50,7 +50,7 @@ def getSMVersion():
ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
@pytest.mark.parametrize('flag', [
"-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
"-softcapping-scale-bmm1 30", "-contiguous-q-kv"
"-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks"
])
@pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
@ -117,8 +117,8 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
shell=True,
check=True)
# alibi and softcapping-scale-bmm1 are mutually exclusive.
if '-softcapping-scale-bmm1' not in flag:
    # alibi is incompatible with softcapping-scale-bmm1 and use-attention-sinks.
if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag:
subprocess.run(
f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
shell=True,

View File

@ -326,9 +326,6 @@ struct Compute
uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]);
Compute_tile_o ctile_o(0, smem_v);
// BMM2 epilogue
Tile_o_epilogue tile_o_epilogue(params);
// Mutex between two compute groups.
OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER);
// Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions).
@ -368,6 +365,9 @@ struct Compute
sage_scale_row = head_info.bidb * params.h + head_info.bidh;
}
// BMM2 epilogue
Tile_o_epilogue tile_o_epilogue(params, head_info);
int q_step_idx = warpgroup_id;
// Compute work.
@ -490,7 +490,7 @@ struct Compute
if (valid_run)
{
// Final step's update.
tile_o_epilogue.scale(ctile_o, p_sum);
tile_o_epilogue.scale(ctile_o, p_max, p_sum);
// Store o_tile to gmem.
gmem_o.store(ctile_o.acc_);
}

View File

@ -454,7 +454,7 @@ struct Softmax_base
#pragma unroll
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
{
uint32_t const scale = float_to_half2(correction_[mi]);
const uint32_t scale = float_to_half2(correction_[mi]);
// Assume only N has multiple MMAs (MMAS_M = 1).
// MMAS_N > 1 when N dimension is split.
@ -477,9 +477,15 @@ struct Softmax_base
}
// BMM1 scale.
uint32_t const scale_bmm1_;
const uint32_t scale_bmm1_;
// BMM1 softcapping scale.
float const softcapping_scale_bmm1_;
// The sliding window size.
int const sliding_window_size_;
// The log2 attention chunk size.
int const log2_chunked_attention_size_;
// The thread idx in the warp group.
int tidx_;
// The col index for the mma thread layout.
@ -487,15 +493,10 @@ struct Softmax_base
// The row index for the mma thread layout.
int quad_row_;
// The sliding window size.
int const sliding_window_size_;
// The log2 attention chunk size.
int const log2_chunked_attention_size_;
// The packed mask ptr.
uint32_t const* packed_mask_ptr_;
// The packed mask k-dim stride in bytes;
int64_t const params_packed_mask_stride_in_bytes_;
const int64_t params_packed_mask_stride_in_bytes_;
// Unpacked BMM1 output buffer.
float elt_[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2];
@ -1072,20 +1073,53 @@ struct Tile_o_epilogue_base
// The MMA tile for the BMM2.
using Mma_tile_o = typename Kernel_traits::Mma_tile_o;
template <typename Params>
inline __device__ Tile_o_epilogue_base(Params const& params)
// Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
enum
{
; // nothing to construct.
EXP2F_OPTIMIZATION = Kernel_traits::EXP2F_OPTIMIZATION
};
template <typename Params, typename Block_info>
inline __device__ Tile_o_epilogue_base(Params const& params, Block_info& block_info)
{
has_attention_sink_ = params.attention_sinks != nullptr;
head_idx_ = block_info.bidh;
attention_sink_ = has_attention_sink_ ? params.attention_sinks[block_info.bidh] : 0.f;
// This is only needed when the exp2f optimization is enabled, in which case params.scale_bmm1 is always a float.
scale_bmm1_f_ = reinterpret_cast<float const&>(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1);
};
// Add the attention sink contribution to the softmax sum.
inline __device__ void add_attention_sink(float& sum, float max)
{
if (has_attention_sink_)
{
// The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled.
if constexpr (EXP2F_OPTIMIZATION)
{
sum += exp2f(attention_sink_ * M_LOG2E - max * scale_bmm1_f_);
}
else
{
sum += expf(attention_sink_ - max);
}
}
}
// Scale ctile_o output by 1/sum
inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M])
inline __device__ void scale(
Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M])
{
// Final step's update.
#pragma unroll
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
{
global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi];
// The global sum.
float global_sum_mi = global_sum[mi];
// Add the attention sink to the global sum.
add_attention_sink(global_sum_mi, global_max[mi]);
// The scale.
float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi;
// Assume only N has multiple MMAs (MMAS_M = 1).
#pragma unroll
@ -1096,12 +1130,21 @@ struct Tile_o_epilogue_base
{
float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi);
float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1);
reg0 *= global_sum[mi];
reg1 *= global_sum[mi];
reg0 *= scale;
reg1 *= scale;
}
}
}
}
// Whether the attention sink is enabled.
bool has_attention_sink_ = false;
// The attention sink value.
float attention_sink_ = 0.f;
// The float scale of bmm1 outputs.
float scale_bmm1_f_ = 1.f;
// The head idx.
int head_idx_ = 0;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -1138,14 +1181,21 @@ struct Tile_o_epilogue<Hopper_hgmma_fp16_traits, Kernel_traits>
using Base::Tile_o_epilogue_base;
// Scale ctile_o output by 1/sum
inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M])
inline __device__ void scale(
Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M])
{
// Final step's update.
#pragma unroll
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
{
global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi];
uint32_t const scale = float_to_half2(global_sum[mi]);
// The global sum.
float global_sum_mi = global_sum[mi];
// Add the attention sink to the global sum.
this->add_attention_sink(global_sum_mi, global_max[mi]);
// The scale.
float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi;
// The scale packed as a half2.
const uint32_t scale_h = float_to_half2(scale);
// Assume only N has multiple MMAs (MMAS_M = 1).
#pragma unroll
@ -1155,7 +1205,7 @@ struct Tile_o_epilogue<Hopper_hgmma_fp16_traits, Kernel_traits>
for (int ni = 0; ni < Mma_tile_o::CORES_N; ni++)
{
uint32_t& reg = ctile_o.acc_[0][mma_ni].reg(ni * Mma_tile_o::CORES_M + mi);
reg = hmul2(reg, scale);
reg = hmul2(reg, scale_h);
}
}
}
@ -1215,27 +1265,58 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
// The MMA tile for the BMM2.
using Mma_tile_o = typename Base::Mma_tile_o;
// Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
enum
{
EXP2F_OPTIMIZATION = Base::EXP2F_OPTIMIZATION
};
// Ctor.
template <typename Params>
inline __device__ Tile_o_epilogue(Params const& params)
: Base(params)
template <typename Params, typename Block_info>
inline __device__ Tile_o_epilogue(Params const& params, Block_info& block_info)
: Base(params, block_info)
, scale_bmm2_(*params.scale_bmm2_d)
{
}
// Add the attention sink to the global sum.
inline __device__ void add_attention_sink(float& sum, float max)
{
if (this->has_attention_sink_)
{
// The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled.
// Take the log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE) into account as the same scale has been applied to sum.
float quant_scale_in_log2 = log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE);
if constexpr (EXP2F_OPTIMIZATION)
{
sum += exp2f(this->attention_sink_ * M_LOG2E - max * this->scale_bmm1_f_ + quant_scale_in_log2);
}
else
{
sum += expf(this->attention_sink_ - max + quant_scale_in_log2);
}
}
}
// Scale ctile_o output by 1/sum
inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M])
inline __device__ void scale(
Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M])
{
// Final step's update.
#pragma unroll
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
{
// The global sum.
float global_sum_mi = global_sum[mi];
// Add the attention sink to the global sum.
add_attention_sink(global_sum_mi, global_max[mi]);
#ifdef UNIFIED_EPILOGUE_SCALE
// Descaling factor
float const scale_bmm2_f_ = reinterpret_cast<float&>(scale_bmm2_);
global_sum[mi] = global_sum[mi] == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum[mi];
// The scale.
float scale = global_sum_mi == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum_mi;
#else
global_sum[mi] = global_sum[mi] == 0.f ? 1.0f : 1.0f / global_sum[mi];
float scale = global_sum_mi == 0.f ? 1.0f : 1.0f / global_sum_mi;
#endif
// Assume only N has multiple MMAs (MMAS_M = 1).
#pragma unroll
@ -1246,8 +1327,8 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
{
float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi);
float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1);
reg0 *= global_sum[mi];
reg1 *= global_sum[mi];
reg0 *= scale;
reg1 *= scale;
}
}
}
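A note on the exp2f branch above: when EXP2F_OPTIMIZATION is enabled the running probabilities are computed with exp2f and the bmm1 scale (which in that mode typically already folds in log2(e)) is fused into the FMAs, so the sink term must be expressed in the same base before it joins the sum; the e4m3 specialization additionally adds log2f(SOFTMAX_FP_QUANT_SCALE) because the running sum already carries that quantization scale. The sketch below only checks the base-conversion identity under that folded-scale assumption; the names and the scale convention are assumptions, not taken from the kernel.

```cpp
#include <cmath>
#include <cstdio>

// log2(e), spelled out to stay portable (M_LOG2E is not guaranteed by ISO C++).
constexpr float kLog2e = 1.4426950408889634f;

// Reference form: the row max handed in is already multiplied by the softmax scale.
float sinkTermExpf(float sink, float rawRowMax, float softmaxScale)
{
    return std::exp(sink - softmaxScale * rawRowMax);
}

// exp2f form: assumes scaleBmm1 = softmaxScale * log2(e) and a raw (unscaled) row max,
// matching the "global max needs to be scaled by the bmm1 scale" comment above.
float sinkTermExp2f(float sink, float rawRowMax, float softmaxScale)
{
    float const scaleBmm1 = softmaxScale * kLog2e;
    return std::exp2(sink * kLog2e - rawRowMax * scaleBmm1);
}

int main()
{
    float const sink = 0.3f, rawRowMax = 4.2f, softmaxScale = 0.125f;
    // Both forms compute exp(sink - softmaxScale * rawRowMax); they agree up to rounding.
    std::printf("%f vs %f\n", sinkTermExpf(sink, rawRowMax, softmaxScale),
        sinkTermExp2f(sink, rawRowMax, softmaxScale));
    return 0;
}
```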

View File

@ -29,33 +29,36 @@ using Kv_block_array = fmha::Kv_block_array;
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
int warps_n, bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n,
void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i,
float softcapping_scale_bmm1, int warps_n, bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale);
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -81,11 +84,11 @@ void run_sage_quant(unsigned int batch_size, unsigned int head_num, unsigned int
////////////////////////////////////////////////////////////////////////////////////////////////////
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type,
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type,
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
void* qkv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d,
void* cu_q_seqlens_d, size_t const b, size_t const s, size_t const h, size_t const d, size_t const dv,
int const runs, int const warps_m, int const warps_n, bool const has_alibi)
void* qkv_d, void* vt_d, void* mask_d, void* attention_sinks_d, void* p_d, void* s_d, void* tmp_d, void* o_d,
void* softmax_sum_d, void* cu_q_seqlens_d, const size_t b, const size_t s, const size_t h, const size_t d,
const size_t dv, int const runs, int const warps_m, int const warps_n, bool const has_alibi)
{
cudaStream_t stream = 0;
@ -106,28 +109,28 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty
// Softmax.
if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16)
{
run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1,
warps_n, has_alibi);
run_softmax_fp16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
softcapping_scale_bmm1, warps_n, has_alibi);
}
else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32)
{
run_softmax_bf16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1,
warps_n, has_alibi);
run_softmax_bf16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
softcapping_scale_bmm1, warps_n, has_alibi);
}
else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32)
{
run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1,
warps_n, has_alibi);
run_softmax_fp32(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
softcapping_scale_bmm1, warps_n, has_alibi);
}
else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32)
{
run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_softmax,
softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax_e4m3(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
}
else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32)
{
run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1, scale_softmax,
softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax_int8(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1,
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
}
else
{
@ -179,7 +182,7 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params,
// types
Data_type data_type, Data_type acc_type,
// sizes
size_t const b, size_t const s, size_t const h, size_t const d, size_t const packed_mask_stride,
const size_t b, const size_t s, const size_t h, const size_t d, const size_t packed_mask_stride,
// device pointers
void* qkv_d, void* packed_mask_d, void* o_d, void* p_d, void* s_d,
// scale factors
@ -235,17 +238,17 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params,
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline void set_params(bert::Fused_multihead_attention_params_v2& params, Launch_params const launch_params,
static inline void set_params(bert::Fused_multihead_attention_params_v2& params, const Launch_params launch_params,
// types
Data_type data_type, Data_type acc_type, Data_type output_dtype,
// attention input layout
Attention_input_layout input_layout,
// sizes
size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const h_kv, size_t const d,
size_t const dv, size_t const total, const size_t num_grouped_heads, const size_t sliding_window_size,
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, const size_t d,
const size_t dv, const size_t total, const size_t num_grouped_heads, const size_t sliding_window_size,
const size_t chunked_attention_size,
// paged kv cache block size.
size_t const tokens_per_block,
const size_t tokens_per_block,
// device pointers
void* qkv_packed_d,
// contiguous q.
@ -261,8 +264,10 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
// offsets for different blocks in terms of the start address.
int32_t* paged_block_offsets,
// mask input.
void* packed_mask_d, void* cu_mask_rows_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d,
void* s_d, void* softmax_stats_d, void* scale_bmm2_d,
void* packed_mask_d, void* cu_mask_rows_d,
// attention sinks.
void* attention_sinks_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, void* s_d,
void* softmax_stats_d, void* scale_bmm2_d,
// scale factors
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
// flags
@ -329,6 +334,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
// The N dimension has to be aligned.
params.packed_mask_stride_in_bytes = (align_to(int64_t(s_kv), int64_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT))) / 8;
// Attention sinks.
params.attention_sinks = reinterpret_cast<float*>(attention_sinks_d);
#if defined(STORE_P)
params.p_ptr = p_d;
params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type);
@ -412,13 +420,13 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, size_t const s,
size_t const d, Attention_mask_type const attention_mask_type, Attention_input_layout const input_layout,
static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, const size_t s,
const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout,
bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma,
bool const force_non_flash_attention, bool const force_non_warp_specialization,
bool const force_non_granular_tiling, bool const force_fp32_acc,
// device props
cudaDeviceProp const props)
const cudaDeviceProp props)
{
// Set launch params to choose kernels
@ -573,6 +581,9 @@ int main(int argc, char** argv)
// SageAttention block sizes
int sage_block_size_q = 0, sage_block_size_k = 0, sage_block_size_v = 0;
// Use attention sinks (added to the denominator of softmax)
bool use_attention_sinks = false;
// Read the parameters from the command-line.
for (int ii = 1; ii < argc; ++ii)
{
@ -865,13 +876,16 @@ int main(int argc, char** argv)
{
sage_block_size_v = strtol(argv[ii], nullptr, 10);
}
else if (!strcmp(argv[ii], "-use-attention-sinks"))
{
use_attention_sinks = true;
}
else
{
fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]);
return -1;
}
}
if (save_softmax == true)
{
if (input_layout != Attention_input_layout::CONTIGUOUS_Q_KV)
@ -1043,11 +1057,11 @@ int main(int argc, char** argv)
force_non_granular_tiling, force_fp32_acc, props);
// The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D.
size_t const qkv_size = s * b * h * (2 * d + dv);
const size_t qkv_size = s * b * h * (2 * d + dv);
// Allocate on the host.
float* qkv_h = (float*) malloc(qkv_size * sizeof(float));
// The size in bytes.
size_t const qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type);
const size_t qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type);
// Allocate on the device.
void *qkv_sbh3d_d = nullptr, *qkv_bsh3d_d = nullptr;
FMHA_CHECK_CUDA(cudaMalloc(&qkv_sbh3d_d, qkv_size_in_bytes));
@ -1057,7 +1071,7 @@ int main(int argc, char** argv)
// The shape is [B, 2, S, H, D].
const size_t kv_size = b * s * h_kv * (d + dv);
// The size in bytes.
size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
// Allocate on the host.
void* contiguous_kv_h = malloc(kv_size_in_bytes);
// Memset the buffer.
@ -1071,13 +1085,13 @@ int main(int argc, char** argv)
void** kv_cache_ptrs_h = nullptr;
void* kv_cache_pool_ptr = nullptr;
int32_t *kv_cache_block_offsets_h, *kv_cache_block_offsets_d = nullptr;
size_t const max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block;
size_t const num_total_blocks = b * 2 * max_blocks_per_seq;
const size_t max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block;
const size_t num_total_blocks = b * 2 * max_blocks_per_seq;
kv_cache_ptrs_h = (void**) malloc(num_total_blocks * sizeof(void*));
kv_cache_block_offsets_h = (int32_t*) malloc(num_total_blocks * sizeof(int32_t));
size_t const paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type);
const size_t paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type);
FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_block_offsets_d), num_total_blocks * sizeof(int32_t)));
size_t const kv_cache_pool_sz
const size_t kv_cache_pool_sz
= get_size_in_bytes(num_total_blocks * tokens_per_block * h_kv * (d + dv) / 2, data_type);
FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_pool_ptr), kv_cache_pool_sz));
size_t ptr_index = 0;
@ -1104,7 +1118,7 @@ int main(int argc, char** argv)
// Q will always be [B, S, H, Dh] with paged kv cache.
void* q_d;
size_t const q_size = s * b * h * d;
const size_t q_size = s * b * h * d;
FMHA_CHECK_CUDA(cudaMalloc(&q_d, get_size_in_bytes(q_size, data_type)));
// K has [B, S, H_kv, D] with separate kv cache.
@ -1122,11 +1136,11 @@ int main(int argc, char** argv)
FMHA_CHECK_CUDA(cudaMalloc(&scale_bmm2_d, sizeof(uint32_t)));
// The mask for dropout or any mask patterns.
size_t const mask_size = s * b * s;
const size_t mask_size = s * b * s;
// Allocate on the host.
float* mask_h = (float*) malloc(mask_size * sizeof(float));
// The size in bytes.
size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
// Allocate on the device.
void* mask_d = nullptr;
if (!skip_checks)
@ -1158,7 +1172,7 @@ int main(int argc, char** argv)
v1 ? 1 : 2);
// The number of threads per CTA.
size_t const threads_per_cta = warps_m * warps_n * warps_k * 32;
const size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
// The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension.
size_t mmas_m = (s + 16 * warps_m - 1) / (16 * warps_m);
// The number of mmas in the N dimension.
@ -1182,7 +1196,7 @@ int main(int argc, char** argv)
packed_mask_size = b * mmas_m * mmas_n * threads_per_cta;
}
// The size in bytes.
size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
// Allocate on the host.
uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes);
// Set it to 0 (indicates that all elements are valid).
@ -1190,12 +1204,30 @@ int main(int argc, char** argv)
// Allocate on the device.
void* packed_mask_d = nullptr;
// The size of the attention sinks.
const size_t attention_sinks_size_in_bytes = h * sizeof(float);
// The attention sinks.
void* attention_sinks_d = nullptr;
if (use_attention_sinks)
{
// Allocate on the host.
float* attention_sinks_h = (float*) malloc(attention_sinks_size_in_bytes);
// Randomly initialize the attention sinks.
random_init("attention_sinks", attention_sinks_h, 1, h, 1, false, 5.f, 1.f, verbose);
// Allocate on the device.
FMHA_CHECK_CUDA(cudaMalloc(&attention_sinks_d, attention_sinks_size_in_bytes));
// Copy from the host to the device.
FMHA_CHECK_CUDA(
cudaMemcpy(attention_sinks_d, attention_sinks_h, attention_sinks_size_in_bytes, cudaMemcpyDefault));
}
// The O matrix is packed as S * B * H * D.
size_t const o_size = s * b * h * dv;
const size_t o_size = s * b * h * dv;
// Allocate on the host.
float* o_h = (float*) malloc(o_size * sizeof(float));
// The size in bytes.
size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type);
const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type);
// Allocate on the device.
void* o_d = nullptr;
FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes));
@ -1206,7 +1238,7 @@ int main(int argc, char** argv)
FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * sizeof(float) * b * s * h));
// The size in bytes.
size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
// Allocate on the device.
void* tmp_d = nullptr;
if (data_type != acc_type)
@ -1220,9 +1252,9 @@ int main(int argc, char** argv)
float* softmax_sum_h = (float*) malloc(b * s * h * sizeof(float));
// The P matrix is stored as one big matrix of size S x B x H x S.
size_t const p_size = s * b * h * s;
const size_t p_size = s * b * h * s;
// The size in bytes.
size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
// Allocate on the device.
void* p_d = nullptr;
if (!skip_checks)
@ -1238,7 +1270,7 @@ int main(int argc, char** argv)
#endif // defined(STORE_P)
// The size in bytes of the S matrix (the data type may be different from P for int8).
size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type);
const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type);
// Allocate on the device.
void* s_d = nullptr;
if (!skip_checks)
@ -1327,7 +1359,7 @@ int main(int argc, char** argv)
std::vector<uint32_t> seqlens(b, 0); // randomly draw a batch of sequence lengths >= min_s
std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(),
[=](uint32_t const)
[=](const uint32_t)
{
if (fix_s)
{
@ -1415,7 +1447,7 @@ int main(int argc, char** argv)
FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_packed_d, mqa_qkv_packed_size_in_bytes));
FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_d, mqa_qkv_size_in_bytes));
size_t const o_packed_size = cu_seqlens.back() * h * dv;
const size_t o_packed_size = cu_seqlens.back() * h * dv;
// Allocate on the host.
float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float));
void* o_packed_d = nullptr;
@ -1676,9 +1708,9 @@ int main(int argc, char** argv)
total, num_grouped_heads, sliding_window_size, chunked_attention_size,
// Paged kv cache.
tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d,
packed_mask_d, cu_mask_rows_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, softmax_stats_ptr,
scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, use_int8_scale_max, interleaved,
is_s_padded, has_alibi);
packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d,
softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
use_int8_scale_max, interleaved, is_s_padded, has_alibi);
// total number of tokens is needed to set TMA desc on the host.
launch_params.total_q_seqlen = q_seqlens[b];
@ -1894,8 +1926,8 @@ int main(int argc, char** argv)
ground_truth(bmm1, bmm2, data_type, acc_type, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
qkv_sbh3d_d,
vt_d, // WAR pass in V'
mask_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs, warps_m, warps_n,
has_alibi);
mask_d, attention_sinks_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs,
warps_m, warps_n, has_alibi);
timer.stop();
FMHA_CHECK_CUDA(cudaPeekAtLastError());
FMHA_CHECK_CUDA(cudaDeviceSynchronize());
@ -2009,7 +2041,6 @@ int main(int argc, char** argv)
// Extract the last s_q tokens from the output.
extract_and_transpose_output<float>(
o_ref_trans_h.data(), o_ref_h, seqlens, q_seqlens, s, s_q, b, h, dv, is_s_padded);
if (verbose)
{
printf("\nChecking .....: O = V * S\n");
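For orientation, the host-side plumbing in this test is the usual allocate/initialize/copy sequence followed by storing the device pointer into the v2 params (set_params does params.attention_sinks = reinterpret_cast<float*>(attention_sinks_d)). A condensed, hedged sketch with a stand-in params struct and error checking omitted:

```cpp
#include <cuda_runtime.h>
#include <cstdlib>
#include <vector>

// Stand-in for the relevant corner of Fused_multihead_attention_params_v2.
struct ParamsV2Sketch
{
    float* attention_sinks = nullptr; // per-head sinks, or nullptr to disable the term.
};

int main()
{
    size_t const h = 16; // number of heads
    std::vector<float> attention_sinks_h(h, 1.0f); // random_init(...) in the real test

    void* attention_sinks_d = nullptr;
    cudaMalloc(&attention_sinks_d, h * sizeof(float));
    cudaMemcpy(attention_sinks_d, attention_sinks_h.data(), h * sizeof(float), cudaMemcpyDefault);

    ParamsV2Sketch params;
    params.attention_sinks = reinterpret_cast<float*>(attention_sinks_d);
    (void) params; // the real test launches the kernels and the ground_truth softmax here.

    cudaFree(attention_sinks_d);
    return 0;
}
```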

View File

@ -197,6 +197,9 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
// The stride between rows of softmax_stats_ptr
int64_t softmax_stats_stride_in_bytes;
// The attention sinks (per head).
float* attention_sinks;
// array of length b+1 holding prefix sum of actual q sequence lengths.
int* cu_q_seqlens;
// array of length b+1 holding prefix sum of actual kv sequence lengths.

View File

@ -87,6 +87,8 @@ struct Fused_multihead_attention_params_v2
fmha::Kv_block_array paged_kv_cache;
// The mask to implement drop-out.
void* packed_mask_ptr;
// The attention sinks (per head).
float* attention_sinks;
// The O matrix (output).
void* o_ptr;
// The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max

View File

@ -23,28 +23,30 @@ using Launch_params = bert::Fused_multihead_attention_launch_params;
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
int warps_n, bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n,
void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i,
float softcapping_scale_bmm1, int warps_n, bool has_alibi);
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale);
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -57,10 +59,10 @@ void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h
////////////////////////////////////////////////////////////////////////////////////////////////////
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type,
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type,
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, void* q_d, void* kv_d, void* vt_d,
void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d, void* cu_seqlens_q_d,
size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, int const runs,
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, int const runs,
int const warps_m, int const warps_n, bool has_alibi)
{
@ -84,20 +86,22 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty
// Softmax.
if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16)
{
run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
run_softmax_fp16(
s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
}
else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32)
{
run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
run_softmax_fp32(
s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
}
else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32)
{
run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax, 0.f,
warps_n, has_alibi);
run_softmax_e4m3(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax,
0.f, warps_n, has_alibi);
}
else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32)
{
run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1,
run_softmax_int8(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1,
scale_softmax, 0.f, warps_n, has_alibi);
}
else
@ -148,8 +152,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_mhca& param
// types
Data_type data_type, Data_type acc_type,
// sizes
size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, size_t const d_padded,
size_t const total,
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, const size_t d_padded,
const size_t total,
// device pointers
void* q_packed_d, void* kv_packed_d, void* cu_seqlens_q_d, void* cu_seqlens_kv_d, void* o_packed_d, void* p_d,
void* s_d,
@ -515,17 +519,17 @@ int main(int argc, char** argv)
launch_params.use_tma = use_tma;
// The Q matrix of size S_Q x B x H x D.
size_t const q_size = s_q * b * h * d;
const size_t q_size = s_q * b * h * d;
// The K and V matrices are packed into one big matrix of size S_KV x B x H x 2 x D.
size_t const kv_size = s_kv_padded * b * h * 2 * d;
const size_t kv_size = s_kv_padded * b * h * 2 * d;
// Allocate on the host.
float* q_h = (float*) malloc(q_size * sizeof(float));
// Allocate on the host.
float* kv_h = (float*) malloc(kv_size * sizeof(float));
// The size in bytes.
size_t const q_size_in_bytes = get_size_in_bytes(q_size, data_type);
const size_t q_size_in_bytes = get_size_in_bytes(q_size, data_type);
// The size in bytes.
size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
// Allocate on the device.
void* q_d = nullptr;
FMHA_CHECK_CUDA(cudaMalloc(&q_d, q_size_in_bytes));
@ -534,11 +538,11 @@ int main(int argc, char** argv)
FMHA_CHECK_CUDA(cudaMalloc(&kv_d, kv_size_in_bytes));
// The mask for dropout.
size_t const mask_size = s_q * b * s_kv_padded;
const size_t mask_size = s_q * b * s_kv_padded;
// Allocate on the host.
float* mask_h = (float*) malloc(mask_size * sizeof(float));
// The size in bytes.
size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
// Allocate on the device.
void* mask_d = nullptr;
FMHA_CHECK_CUDA(cudaMalloc(&mask_d, mask_size_in_bytes));
@ -554,28 +558,28 @@ int main(int argc, char** argv)
v1 ? 1 : 2);
// The number of threads per CTA.
size_t const threads_per_cta = warps_m * warps_n * warps_k * 32;
const size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
// The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension.
size_t const mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m);
const size_t mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m);
// The number of mmas in the N dimension.
size_t const mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n);
const size_t mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n);
// We do not support more than 4 MMAS in the N dimension (as each MMA needs 8 bits in the mask).
assert(!v1 || mmas_n <= 4);
// The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA.
size_t const packed_mask_size = b * mmas_m * threads_per_cta;
const size_t packed_mask_size = b * mmas_m * threads_per_cta;
// The size in bytes.
size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
// Allocate on the host.
uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes);
// Allocate on the device.
void* packed_mask_d = nullptr;
// The O matrix is packed as S_Q * B * H * D.
size_t const o_size = s_q * b * h * d;
const size_t o_size = s_q * b * h * d;
// Allocate on the host.
float* o_h = (float*) malloc(o_size * sizeof(float));
// The size in bytes.
size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type);
const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type);
// Allocate on the device.
void* o_d = nullptr;
FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes));
@ -587,7 +591,7 @@ int main(int argc, char** argv)
FMHA_CHECK_CUDA(cudaMemset(softmax_max_d, 0x00, sizeof(float) * b * s_q * h));
// The size in bytes.
size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
// Allocate on the device.
void* tmp_d = nullptr;
if (data_type != acc_type)
@ -599,9 +603,9 @@ int main(int argc, char** argv)
float* o_ref_h = (float*) malloc(o_size * sizeof(float));
// The P matrix is stored as one big matrix of size S_Q x B x H x S_KV.
size_t const p_size = s_q * b * h * s_kv_padded;
const size_t p_size = s_q * b * h * s_kv_padded;
// The size in bytes.
size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
// Allocate on the device.
void* p_d = nullptr;
FMHA_CHECK_CUDA(cudaMalloc(&p_d, p_size_in_bytes));
@ -614,7 +618,7 @@ int main(int argc, char** argv)
#endif // defined(STORE_P)
// The size in bytes of the S matrix (the data type may be different from P for int8).
size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type);
const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type);
// Allocate on the device.
void* s_d = nullptr;
FMHA_CHECK_CUDA(cudaMalloc(&s_d, s_size_in_bytes));
@ -634,9 +638,9 @@ int main(int argc, char** argv)
// WAR fOR MISSING CUBLAS FP8 NN SUPPORT.
// Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V.
size_t const v_size = s_kv_padded * b * h * d;
const size_t v_size = s_kv_padded * b * h * d;
// The size in bytes.
size_t const v_size_in_bytes = get_size_in_bytes(v_size, data_type);
const size_t v_size_in_bytes = get_size_in_bytes(v_size, data_type);
float* vt_h = (float*) malloc(v_size * sizeof(float));
void* vt_d = nullptr;
FMHA_CHECK_CUDA(cudaMalloc(&vt_d, v_size_in_bytes));
@ -676,7 +680,7 @@ int main(int argc, char** argv)
= [min_s, fix_s, b](int s, std::vector<uint32_t>& seqlens, std::vector<int>& cu_seqlens, void** cu_seqlens_d)
{
std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(),
[=](uint32_t const)
[=](const uint32_t)
{
if (fix_s)
{
@ -728,7 +732,7 @@ int main(int argc, char** argv)
void* kv_packed_d = nullptr;
FMHA_CHECK_CUDA(cudaMalloc(&kv_packed_d, kv_packed_size_in_bytes));
size_t const o_packed_size = cu_seqlens_q.back() * h * d;
const size_t o_packed_size = cu_seqlens_q.back() * h * d;
// Allocate on the host.
float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float));
float* o_ref_packed_h = (float*) malloc(o_packed_size * sizeof(float));

View File

@ -12,9 +12,10 @@
#include "softmax_impl.h"
void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi)
void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
{
run_softmax<fmha::bf16_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f,
softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<fmha::bf16_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer,
b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi);
}

View File

@ -12,9 +12,10 @@
#include "softmax_impl.h"
void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi)
void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
{
run_softmax<uint16_t, uint16_t>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f,
softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<uint16_t, uint16_t>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b,
h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi);
}

View File

@ -12,9 +12,10 @@
#include "softmax_impl.h"
void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi)
void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
{
run_softmax<fmha::fp16_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f,
softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<fmha::fp16_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer,
b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi);
}

View File

@ -12,10 +12,10 @@
#include "softmax_impl.h"
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
int warps_n, bool has_alibi)
{
run_softmax<fmha::e4m3_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f,
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<fmha::e4m3_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer,
b, h, 0.f, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
}

View File

@ -10,6 +10,7 @@
* its affiliates is strictly prohibited.
*/
#include <cfloat>
#include <cstdio>
#include <fmha/numeric_types.h>
#include <fmha/utils.h>
@ -33,6 +34,8 @@ struct Softmax_params
Src_type const* src;
// Masks.
int8_t const* mask;
// Attention sinks (per head).
float const* attention_sinks;
// Softmax sum pointer.
float* softmax_sum;
// ALiBi
@ -148,7 +151,8 @@ static inline __device__ float apply_exp_(float x, float max)
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n, float& sum_fp32)
static inline __device__ void reduce(
float (&data_fp32)[N][1], const int8_t (&mask)[N][1], int warps_n, float& sum_fp32, float const attention_sink)
{
// Apply the masks.
@ -233,7 +237,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma
}
// Normalize.
float inv_sum_fp32 = 1.f / sum_fp32;
float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
@ -244,7 +248,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n, float& sum_fp32)
static inline __device__ void reduce(
float (&data_fp32)[N][2], const int8_t (&mask)[N][2], int warps_n, float& sum_fp32, float const attention_sink)
{
// Apply the masks.
#pragma unroll
@ -401,7 +406,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma
}
// Normalize.
float inv_sum_fp32 = 1.f / sum_fp32;
float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
@ -413,7 +418,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n, float& sum_fp32)
static inline __device__ void reduce(
float (&data_fp32)[N][4], const int8_t (&mask)[N][4], int warps_n, float& sum_fp32, float const attention_sink)
{
// Apply the masks.
@ -824,7 +830,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&ma
}
// Normalize.
float inv_sum_fp32 = 1.f / sum_fp32;
float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
@ -994,9 +1000,16 @@ static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params)
}
}
// The attention sink value.
float attention_sink = -FLT_MAX;
if (params.attention_sinks != nullptr)
{
attention_sink = params.attention_sinks[hi];
}
// Do the reduction.
float sum_fp32 = 0.f;
reduce(data_fp32, mask_, params.warps_n, sum_fp32);
reduce(data_fp32, mask_, params.warps_n, sum_fp32, attention_sink);
if (threadIdx.x == 0)
{
int sum_s = params.cu_q_seqlens[bi];
@ -1025,9 +1038,9 @@ static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params)
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Dst_type, typename Src_type>
void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum, void* cu_q_seqlens, int s_inner,
int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
void run_softmax(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum,
void* cu_q_seqlens, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax,
float softcapping_scale_bmm1, int warps_n, bool has_alibi)
{
Softmax_params<Dst_type, Src_type> params;
@ -1039,6 +1052,7 @@ void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum
params.softmax_sum = reinterpret_cast<float*>(softmax_sum);
params.cu_q_seqlens = reinterpret_cast<int*>(cu_q_seqlens);
params.mask = reinterpret_cast<int8_t const*>(mask);
params.attention_sinks = reinterpret_cast<float const*>(attention_sinks);
params.has_alibi = has_alibi;
// The dimensions and precomputed values.
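The reduce helpers above treat a missing sink as -FLT_MAX, so expf(attention_sink - max_fp32) underflows to exactly zero and the denominator is untouched; with a real sink the row is normalized by sum + expf(sink - max). A tiny single-row host reference of that behaviour (mask handling dropped for brevity; names are illustrative):

```cpp
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <vector>

// Single-row reference of the sink-aware softmax normalization: the (optional)
// attention sink contributes only to the denominator, never to an output row.
std::vector<float> softmaxRowWithSink(std::vector<float> logits, float attention_sink = -FLT_MAX)
{
    float maxVal = -FLT_MAX;
    for (float x : logits)
    {
        maxVal = std::fmax(maxVal, x);
    }
    float sum = 0.f;
    for (float& x : logits)
    {
        x = std::exp(x - maxVal);
        sum += x;
    }
    // With attention_sink == -FLT_MAX this extra term underflows to 0.
    float const invSum = 1.f / (sum + std::exp(attention_sink - maxVal));
    for (float& x : logits)
    {
        x *= invSum;
    }
    return logits;
}

int main()
{
    auto const p0 = softmaxRowWithSink({1.f, 2.f, 3.f});       // plain softmax
    auto const p1 = softmaxRowWithSink({1.f, 2.f, 3.f}, 0.5f); // probabilities sum to < 1
    std::printf("%f %f\n", p0[2], p1[2]);
    return 0;
}
```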

View File

@ -12,10 +12,10 @@
#include "softmax_impl.h"
void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1,
int warps_n, bool has_alibi)
void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax,
float softcapping_scale_bmm1, int warps_n, bool has_alibi)
{
run_softmax<int8_t, int32_t>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, scale_bmm1,
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<int8_t, int32_t>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h,
scale_bmm1, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
}

View File

@ -1379,6 +1379,19 @@ __device__ inline ThrdRegRowMax mergeRowMax(
return mergedRowMax;
}
__device__ inline void addAttentionSinks(
ThrdRegRowMax& globalRowSum, ThrdRegRowMax const globalRowMax, float const* attentionSinks)
{
for (uint32_t i = 0; i < globalRowSum.size; i++)
{
uint32_t srcOffset = warp_size * i + laneId();
if (srcOffset < headGrpSize)
{
globalRowSum[i] += expf(attentionSinks[srcOffset] - globalRowMax[i]);
}
}
}
#ifdef NDEBUG
__device__ __forceinline__
#else
@ -1405,6 +1418,7 @@ CUBIN_EXPORT __global__
#if SPEC_DEC
MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32)].
#endif
float const* attentionSinks, // [headGrpSize]
#ifdef NDEBUG
KVCacheList<usePagedKVCache> const& cacheList,
#if BEAM_WIDTH > 1
@ -2371,6 +2385,12 @@ CUBIN_EXPORT __global__
float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F);
if (seqIterInit < nbSeqIters)
{ // otherwise rcpRowSum will be NAN.
// When multi-block mode is enabled, the attention sinks are applied later, in the multi-block reduction.
if (!isMultiBlock && attentionSinks != nullptr)
{
// Attention sinks are per head.
addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp);
}
ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum);
#if LOW_PREC_OUTPUT
voScale *= rcpOutScale[0];
@ -2559,6 +2579,11 @@ CUBIN_EXPORT __global__
assert(std::isfinite(mergedRowSum[0]));
}
}
if (attentionSinks != nullptr)
{
// Attention sinks are per head.
addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp);
}
__syncthreads();
rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum));
GemmOutRegTile const mergedOutTile = toFp16(sumAcc);
@ -2615,6 +2640,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32)] uint2 (each bit represents mask for one col
// position).
#endif
float const* attentionSinks, // [headGrpSize]
KVCacheList<usePagedKVCache> const cacheList,
#if BEAM_WIDTH > 1
BeamSearchParams const beamSearchParams,
@ -2640,7 +2666,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
#if SPEC_DEC
mask,
#endif
cacheList,
attentionSinks, cacheList,
#if BEAM_WIDTH > 1
beamSearchParams,
#endif
@ -2667,6 +2693,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#else
InputHead const* q,
#endif
float const* attentionSinks, // [headGrpSize]
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
@ -2760,7 +2787,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SPEC_DEC
mask,
#endif
cacheList,
attentionSinks, cacheList,
#if BEAM_WIDTH > 1
beamSearchParams,
#endif
@ -2788,7 +2815,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SPEC_DEC
mask,
#endif
cacheList,
attentionSinks, cacheList,
#if BEAM_WIDTH > 1
beamSearchParams,
#endif

View File

@ -101,6 +101,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
#else
InputHead const* q,
#endif
float const* attentionSinks, // [headGrpSize]
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
@ -140,6 +141,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#else
InputHead const* q,
#endif
float const* attentionSinks, // [headGrpSize]
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,

View File

@ -428,6 +428,7 @@ __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src);
__device__ void storeGemm0AccToShm(
uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc);
__device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec);
__device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound);
#else
__device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
__device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd);
@ -453,7 +454,8 @@ __device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec
template <bool dstIsStrided = false, typename DstHead>
__device__ void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst,
SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar,
ShmQWiseVec const& accColSum, uint32_t nbKHeads = 0 /* only for final result in spec dec. */);
ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec,
uint32_t nbKHeads = 0 /* only for final result in spec dec. */);
#else
__device__ void transposeVTile(
uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src);
@ -651,6 +653,7 @@ CUBIN_EXPORT __global__
#else
IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads],
#endif
float const* attentionSinks, // [headGrpSize]
KVCacheList<usePagedKVCache> const cacheList,
#if USE_BEAM_SEARCH
BeamSearchParams const beamSearchParams,
@ -1252,7 +1255,7 @@ CUBIN_EXPORT __global__
IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast<IOHead>();
#if SWAP_AB
finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
smem.gemm1WarpGrpBar, smem.gemm1AccColSum);
smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, nullptr);
#else
finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
smem.gemm1AccColSum, 1, ctaNbValidTokens);
@ -1262,9 +1265,16 @@ CUBIN_EXPORT __global__
{
uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp);
OutputHead* const dst = &output[outOffset];
ShmQWiseVec const* attentionSinksVec = nullptr;
if (attentionSinks != nullptr)
{
attentionSinksVec
= reinterpret_cast<ShmQWiseVec const*>(attentionSinks + headGrpSize * idxHeadGrp);
}
#if SWAP_AB
finalizeAndWriteOut_sync<SPEC_DEC>(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc,
xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, nbKHeads);
xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, attentionSinksVec,
nbKHeads);
#else
finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens);
@ -1585,6 +1595,17 @@ CUBIN_EXPORT __global__
}
unused(bar.consumed.arrive());
}
// Add the attention sinks.
if (attentionSinks != nullptr)
{
for (uint32_t i = 0; i < headsPerWarp; i++)
{
uint32_t const idxHead = wid + nbMathWarps * i;
float sink = expf(
attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - states[i].max);
states[i].sum += sink;
}
}
__syncthreads();
uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp);
auto const dst = &output[outOffset];
@ -2029,6 +2050,22 @@ __device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smem
return ret;
}
__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound)
{
RegColWiseVec ret;
constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols);
auto const idx = laneId() % nbThrdsPerInstNBase;
#pragma unroll
for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++)
{
static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols));
ret[i] = reinterpret_cast<
Vec<Vec<float, GmmaAccCoreMat::cols>, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>(
gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)];
}
return ret;
}
__device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd)
{
uint32_t const idxInQuad = laneId() % 4;
@ -2878,12 +2915,19 @@ __device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRa
template <bool dstIsStrided, typename DstHead>
__device__ inline void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst,
SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar,
ShmQWiseVec const& accColSum, uint32_t nbKHeads)
ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads)
{
// @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of mufu.rcp
// static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of
// mufu.rcp");
auto const regColSum = loadShmColWiseVecWithDup(accColSum);
auto regColSum = loadShmColWiseVecWithDup(accColSum);
if (attentionSinksVec != nullptr)
{
auto const regAccColMax = loadShmColWiseVecWithDup(accColMax);
auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1);
auto regColSinks = expf(regAttentionSinks - regAccColMax);
regColSum = regColSum + regColSinks;
}
auto const regOutScale = __frcp_rn(regColSum) * xvoScale;
rescaleAcc(acc, regOutScale);
@ -3175,6 +3219,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#else
InputHead const* q,
#endif
float const* attentionSinks, // [headGrpSize]
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
@ -3286,7 +3331,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#else
q,
#endif
cacheList,
attentionSinks, cacheList,
#if USE_BEAM_SEARCH
beamSearchParams,
#endif
@ -3322,7 +3367,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#else
q,
#endif
cacheList,
attentionSinks, cacheList,
#if USE_BEAM_SEARCH
beamSearchParams,
#endif

View File

@ -1859,12 +1859,13 @@ CUtensorMap makeTensorMapForQ(
#endif // IS_MLA
void launchMLA(cudaDeviceProp const& prop,
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
float qScale, OutputHead* output, InputHead const* q,
float* attentionSinks, // [headGrpSize], not supported.
#if USE_PAGED_KV_CACHE
GMemCacheHead* pool, // global pool of pages
GMemCacheHead* pool, // global pool of pages
KVCachePageIndex const*
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
#else
GMemKVCacheHead* kvCacheData,
#endif

View File

@ -45,7 +45,7 @@ using Vector = Matrix<Type, Size, 1>;
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize)
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
{
uint32_t const nbTiles = divUp(seqLen, tileSize);
auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval();
@ -113,6 +113,16 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
}
rowSum += tileRowSum;
}
// Add the attention sinks.
if (attentionSinks != nullptr)
{
for (uint32_t i = 0; i < headGrpSize; i++)
{
rowSum[i] += expf(attentionSinks[i] - rowMax[i]);
}
}
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> out
= gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array());
std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); });
@ -123,7 +133,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q, \
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, \
float qScale, float kvScale, float xScale, uint32_t slidingWinSize)
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
INSTANTIATE_refFlashAttention(CacheElem, 64, false, false);
INSTANTIATE_refFlashAttention(CacheElem, 64, false, true);
@ -143,7 +153,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
#else
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize)
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
{
#endif
float const rcpXScale = 1.f / xScale;
@ -184,7 +194,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
Eigen::Matrix<float, headGrpSize, Eigen::Dynamic, Eigen::RowMajor> x
= (gemm0Acc.colwise() - rowMax).array().exp().eval();
Eigen::Vector<float, headGrpSize> const rowSum = x.rowwise().sum().eval();
Eigen::Vector<float, headGrpSize> rowSum = x.rowwise().sum().eval();
std::for_each(x.data(), x.data() + x.size(), [&](float& e) { e = float(MathElem(e * rcpXScale)); });
@ -200,6 +210,18 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
}
}
}
// Add the attention sinks.
#if !SPEC_DEC
if (attentionSinks != nullptr)
{
for (uint32_t i = 0; i < headGrpSize; i++)
{
rowSum[i] += expf(attentionSinks[i] - rowMax[i]);
}
}
#endif
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> out
= gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array());
std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); });
@ -217,7 +239,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
refAttention<prec, isPaged, useBeamSearch>(IOHead const* q, CacheSeq<isPaged, useBeamSearch> const& k, \
CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale, float kvScale, float xScale, \
uint32_t slidingWinSize)
uint32_t slidingWinSize, float* attentionSinks)
#endif
INSTANTIATE_refAttention(InputElem, false, false);
INSTANTIATE_refAttention(InputElem, false, true);

View File

@ -83,7 +83,7 @@ struct CacheSeq<true, true>
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize);
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
template <typename MathElem, bool isPaged, bool useBeamSearch>
#if SPEC_DEC
@ -93,7 +93,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
#else
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize);
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
#endif
template <uint32_t ropeStyle>

View File

@ -130,7 +130,7 @@ template <uint32_t nbKHeads>
#endif
#endif
void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false,
bool saveData = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
{
#if IS_MLA
if (nbKHeads != 1)
@ -613,6 +613,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
}
}
// Allocate the attention sinks (per head)
auto attentionSinks = ManagedMemBuf<float>(nbQHeads);
// The attention sinks pointer (nullptr when attention sinks are disabled).
float* attentionSinksPtr = hasAttentionSinks ? reinterpret_cast<float*>(attentionSinks.get()) : nullptr;
// Initialize the attention sinks (use large values to expose potential bugs).
for (uint32_t i = 0; i < nbQHeads; i++)
{
// Range: [2, 5]
attentionSinks.get()[i] = 2.f + float(i % 4);
}
if (verbose)
{
printf("migrating data to gpu\n");
@ -640,6 +651,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
#if BEAM_WIDTH > 1
cacheIndir.prefetch(dev, stream);
#endif
attentionSinks.prefetch(dev, stream);
};
prefetchToDevice(device);
checkCuda(cudaMemsetAsync(semaphores.get(), 0, 4 * nbSemaphores, stream));
@ -720,6 +732,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
&qHeads[0][0][0],
#endif
#endif
attentionSinksPtr,
#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE
cacheKHeads.get(), cacheVHeads.get(),
#else
@ -1028,10 +1041,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
hostMask, qSeqLen, q_len);
#else
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refOutput;
auto const refAttentionSinks
= hasAttentionSinks ? attentionSinksPtr + headGrpSize * idxKHead : nullptr;
if (useQGMMA)
{
refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize,
refAttentionSinks);
// refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
// vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
}
@ -1039,8 +1055,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
{
// refOutput = refFlashAttention<InputElem, 64>(&qHeads[req][b][headGrpSize * idxKHead],
// kCacheSeq, vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale);
refOutput = refAttention<InputElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
refOutput
= refAttention<InputElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, vCacheSeq,
seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks);
}
#endif
if (lowPrecOutput)
@ -1196,11 +1213,23 @@ TEST(RefCheck, llama_V2_70b)
runTest<2>(2, 514, false, true);
runTest<1>(1, 4096, false, true);
#if SLIDING_WINDOW
runTest<2>(2, 4096, false, true, false, false, ~0, 256);
runTest<2>(2, 400, false, true, false, false, ~0U, 256);
runTest<2>(2, 4096, false, true, false, false, false, ~0, 256);
runTest<2>(2, 400, false, true, false, false, false, ~0U, 256);
#endif
runTest<8>(120, 367, false, true);
// runTest<8>(1792, 2048, false, true);
runTest<8>(1792, 2048, false, true);
}
TEST(RefCheck, attention_sinks)
{
auto runAttentionSinksTest = [](uint32_t batchSize, uint32_t seqLen)
{ runTest<8>(batchSize, seqLen, false, true, false, false, /*hasAttentionSinks*/ true); };
runAttentionSinksTest(2, 2);
runAttentionSinksTest(2, 15);
runAttentionSinksTest(2, 256);
runAttentionSinksTest(2, 514);
runAttentionSinksTest(1, 4096);
}
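When iterating on this path it is convenient to run only the new check, e.g. with a gtest filter such as --gtest_filter=RefCheck.attention_sinks (the exact test binary name depends on the local build configuration).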
TEST(Perf, tracing_long)
@ -1264,7 +1293,7 @@ TEST(Perf, mlperf_gptj)
#ifndef NDEBUG
GTEST_SKIP() << "Skipping perf tests for debug build";
#endif
runTest<32>(396, 800 + 224, true, false, false, false, 800);
runTest<32>(396, 800 + 224, true, false, false, false, false, 800);
}
TEST(Perf, mlperf_llama)

View File

@ -53,6 +53,7 @@ using namespace CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
using CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput;
using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::CutlassMoeFCRunner;
using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams;
using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation;
static BufferManager::CudaStreamPtr streamPtr;
@ -984,7 +985,7 @@ public:
mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
mFinalOutput + mFinalOutputSize * mBufferIndex,
@ -996,7 +997,7 @@ public:
mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
mFinalOutput + mFinalOutputSize * mBufferIndex,

View File

@ -55,6 +55,7 @@ struct FusedQKVMaskedAttentionDispatchParams
T const* qkv_bias;
T const* relative_attention_bias;
bool const* attention_mask;
float const* attention_sinks;
float const* logn_scaling_ptr;
int const* cache_indir;
void* context_buf;
@ -71,6 +72,7 @@ struct FusedQKVMaskedAttentionDispatchParams
RotaryScalingType rotary_embedding_scale_type;
float rotary_embedding_scale;
float const* rotary_embedding_inv_freq_cache;
float2 const* rotary_embedding_cos_sin_cache;
float rotary_embedding_short_m_scale;
float rotary_embedding_long_m_scale;
int rotary_embedding_max_positions;
@ -225,6 +227,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.output = generationsParams.context_buf;
xqaParams.qkv = generationsParams.attention_input;
xqaParams.cache_indir = generationsParams.cache_indir;
xqaParams.attention_sinks = generationsParams.attention_sinks;
xqaParams.kv_scale_orig_quant = generationsParams.kv_scale_orig_quant;
xqaParams.kv_scale_quant_orig = generationsParams.kv_scale_quant_orig;
xqaParams.host_past_key_value_lengths = generationsParams.host_past_key_value_lengths;
@ -596,6 +599,7 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
params.rotary_embedding_scale_type = input_params.rotary_embedding_scale_type;
params.rotary_embedding_scale = input_params.rotary_embedding_scale;
params.rotary_embedding_inv_freq_cache = input_params.rotary_embedding_inv_freq_cache;
params.rotary_embedding_cos_sin_cache = input_params.rotary_embedding_cos_sin_cache;
params.rotary_embedding_short_m_scale = input_params.rotary_embedding_short_m_scale;
params.rotary_embedding_long_m_scale = input_params.rotary_embedding_long_m_scale;
params.rotary_embedding_max_positions = input_params.rotary_embedding_max_positions;
@ -620,6 +624,9 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
params.attention_mask = input_params.attention_mask;
params.attention_mask_stride = input_params.attention_mask_stride;
// Attention sinks.
params.attention_sinks = input_params.attention_sinks;
// The slope of linear position bias per head, e.g., ALiBi.
if (input_params.linear_bias_slopes != nullptr)
{
@ -1691,6 +1698,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
fmhaParams.outputPtr
= mCpSize > 1 ? gatherOutBuffer : params.context_buf; // only use [totalLength, h / cpSize, Dh]
fmhaParams.outputSfPtr = params.context_buf_sf;
fmhaParams.attentionSinksPtr = params.attention_sinks;
fmhaParams.packedMaskPtr = params.attention_packed_mask;
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
{
@ -2220,6 +2228,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
dispatch_params.relative_attention_bias_stride = relative_attention_bias_stride;
dispatch_params.attention_mask = params.attention_mask;
dispatch_params.attention_mask_stride = params.attention_mask_stride;
dispatch_params.attention_sinks = params.attention_sinks;
dispatch_params.max_distance = max_distance;
dispatch_params.cache_indir = params.cache_indir;
dispatch_params.context_buf = mCpSize > 1 ? mhaOutput : params.context_buf; //
@ -2267,6 +2276,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType;
dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale;
dispatch_params.rotary_embedding_inv_freq_cache = params.rotary_inv_freq;
dispatch_params.rotary_embedding_cos_sin_cache = params.rotary_cos_sin;
dispatch_params.rotary_embedding_short_m_scale = mRotaryEmbeddingShortMscale;
dispatch_params.rotary_embedding_long_m_scale = mRotaryEmbeddingLongMscale;
dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions;

View File

@ -65,6 +65,8 @@ public:
T const* qkv_bias = nullptr;
// Attention mask input, which has shape of [batch_size, attention_mask_stride].
bool const* attention_mask = nullptr;
// Attention sinks with shape of [num_heads_q] float.
float const* attention_sinks = nullptr;
// Rotary inv_freq cache buffer to avoid re-computing.
float const* rotary_inv_freq = nullptr;
// Rotary cos sin cache buffer to avoid re-computing.

View File

@ -27,6 +27,78 @@
namespace cutlass::gemm::collective::detail
{
using namespace cute;
typedef uint32_t __nv_fp4x8_storage_t;
typedef uint32_t __nv_bf16x2_storage_t;
typedef cutlass::uint128_t __nv_bf16x8_storage_t;
constexpr int int4_group_size = 128;
constexpr int mxfp4_group_size = 32;
inline __device__ unsigned prmt(unsigned hi, unsigned lo, unsigned select_code)
{
unsigned res = 0;
asm volatile(
"{\n"
"prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(res)
: "r"(lo), "r"(hi), "r"(select_code));
return res;
}
__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_bf16(unsigned const index)
{
const __nv_fp8x4_storage_t h4b_lut = 0x03020100U; // 7654
const __nv_fp8x4_storage_t l4b_lut = 0xFFFEFC00U; // 3210
__nv_fp8x4_storage_t lut_res = prmt(h4b_lut, l4b_lut, index);
return lut_res;
}
__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(const __nv_fp4x8_storage_t fp4x8)
{
__nv_bf16x8_storage_t bf16x8_raw = {0, 0};
__nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw);
unsigned zero_padding = 0x00000000U;
unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U;
unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U);
__nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7654
__nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3210
bf16x2_raw[0] = prmt(zero_padding, l4b_2to9_bits, 0x1707U) >> 2U; // 1 0
bf16x2_raw[1] = prmt(zero_padding, l4b_2to9_bits, 0x3727U) >> 2U; // 3 2
bf16x2_raw[2] = prmt(h4b_2to9_bits, zero_padding, 0x5040U) >> 2U; // 5 4
bf16x2_raw[3] = prmt(h4b_2to9_bits, zero_padding, 0x7060U) >> 2U; // 7 6
__nv_bf16x2_storage_t bf16x2_0to1_bits;
__nv_fp8x4_storage_t h_fp8x2_0to1_bits = (fp4x8 & 0x0000C0C0U); // 3 1
__nv_fp8x4_storage_t l_fp8x2_0to1_bits = (fp4x8 & 0x00000C0CU) << 4U; // 2 0
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x4707U); // 1 0
bf16x2_raw[0] = bf16x2_raw[0] | bf16x2_0to1_bits;
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x5717U); // 3 2
bf16x2_raw[1] = bf16x2_raw[1] | bf16x2_0to1_bits;
h_fp8x2_0to1_bits = (fp4x8 & 0xC0C00000U); // 7 5
l_fp8x2_0to1_bits = (fp4x8 & 0x0C0C0000U) << 4U; // 6 4
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x6020U); // 5 4
bf16x2_raw[2] = bf16x2_raw[2] | bf16x2_0to1_bits;
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x7030U); // 7 6
bf16x2_raw[3] = bf16x2_raw[3] | bf16x2_0to1_bits;
return bf16x8_raw;
}
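The PRMT lookup above rebuilds bf16 bit patterns directly from the packed nibbles; a scalar host-side reference that decodes a single e2m1 (fp4) nibble, which can be used to sanity-check the packed path, might look like the following (helper name is illustrative, not part of this commit):

#include <cstdint>
#include <cstdio>

// Scalar reference for one e2m1 (fp4) nibble: 1 sign bit, 2 exponent bits, 1 mantissa bit.
// The eight representable magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
float decodeE2M1(uint8_t nibble)
{
    static float const kLut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    float const magnitude = kLut[nibble & 0x7u];
    return (nibble & 0x8u) ? -magnitude : magnitude;
}

int main()
{
    // Walk the 8 nibbles of one 32-bit register, lowest nibble first, which is the packing
    // consumed by the fp4x8 -> bf16x8 converter above.
    uint32_t const fp4x8 = 0x76543210u;
    for (int i = 0; i < 8; ++i)
    {
        std::printf("nibble %d -> %g\n", i, decodeE2M1(static_cast<uint8_t>((fp4x8 >> (4 * i)) & 0xFu)));
    }
    return 0;
}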
template <class Collective>
struct MixedGroupedGemmInputUtils
{
@ -46,6 +118,7 @@ private:
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
static constexpr auto ModeHasScales = Collective::ModeHasScales;
static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;
static constexpr auto UseFP4ToBF16LookupTable = Collective::UseFP4ToBF16LookupTable;
public:
static constexpr auto elements_per_smem_scale()
@ -239,6 +312,27 @@ public:
}
}
// The core converter uses a lookup table to convert packed fp4 values to bf16.
template <class EngineIn, class LayoutIn, class EngineOut,
class LayoutOut>
CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( // Accept mutable temporaries
Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>&& dst)
{
fp4tobf16_lookup_table_convert(src, dst);
}
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut>
CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert(
Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>& dst)
{
// View the input as reg
auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0);
auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0);
dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_);
}
/// Utilities to dequantize A.
template <class Layout>
CUTLASS_DEVICE static void static_check_scale(Layout const& tensor)
@ -253,7 +347,6 @@ public:
static_check_scale(flatten(Layout{}));
}
// dequantize_A_kblock is here!!!
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
CUTLASS_DEVICE static void dequantize_A_kblock(Tensor<EngineIn, LayoutIn> const& tCrA_load,
Tensor<EngineOut, LayoutOut>& tCrA_mma, cute::tuple<Ts...>& partitioned_extra_info, int const k_block)
@ -288,8 +381,6 @@ public:
}
else if constexpr (UseScaleLookupTable)
{
// this path
constexpr int num_elements = decltype(size(src))::value;
static_assert(is_same_v<RealSwappedElementA, cutlass::int4b_t>,
"Lookup table only supports int4 being the quant type now.");
@ -424,7 +515,6 @@ public:
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
using SrcType = typename EngineIn::value_type;
using DstType = typename EngineOut::value_type;
Tensor src = tCrA_load(_, _, k_block);
Tensor dst = tCrA_mma(_, _, k_block);
@ -441,7 +531,14 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i)
{
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
if constexpr (UseFP4ToBF16LookupTable)
{
fp4tobf16_lookup_table_convert(src_vm(_, i), dst_vm(_, i));
}
else
{
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
}
}
}

View File

@ -30,37 +30,12 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#define GROUP_SIZE 128
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
using namespace cute;
template <int N>
CUTE_HOST_DEVICE void warpgroup_wait_()
{
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N);
asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory");
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group<N> without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
}
CUTLASS_DEVICE void warpgroup_wait_dispatch(int onthefly_count)
{
switch (onthefly_count)
{
case 0: warpgroup_wait_<0>(); break;
case 4: warpgroup_wait_<4>(); break;
case 8: warpgroup_wait_<8>(); break;
case 12: warpgroup_wait_<12>(); break;
default: assert(false && "Invalid onthefly_count value");
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
@ -91,7 +66,7 @@ public:
private:
template <class T>
friend struct detail::MixedGroupedGemmInputUtils;
using CollectiveType = CollectiveMma<DispatchPolicy, TileShape_, ElementAOptionalTuple, StrideA_,
using CollectiveType = CollectiveMmaArrayMixedInput<DispatchPolicy, TileShape_, ElementAOptionalTuple, StrideA_,
ElementBOptionalTuple, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_>;
using Utils = detail::MixedGroupedGemmInputUtils<CollectiveType>;
@ -146,6 +121,11 @@ public:
static_assert(cutlass::gemm::detail::is_mn_major<NonVoidStrideScale>(),
"Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled].");
static constexpr bool IsMXFP4 = cute::is_same_v<ElementA, cutlass::float_e2m1_t>;
// Group size 128 for int4 weights
// Group size 32 for mxfp4 weights
static constexpr int ScalingGroupSize = IsMXFP4 ? detail::mxfp4_group_size : detail::int4_group_size;
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
@ -268,6 +248,8 @@ public:
|| KernelConversionMode == ConversionMode::ConvertAndScaleWithZero;
static constexpr bool UseScaleLookupTable
= KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v<ElementScale>;
static constexpr bool UseFP4ToBF16LookupTable = KernelConversionMode == ConversionMode::ConvertAndScale
&& cute::is_same_v<ElementA, cutlass::float_e2m1_t> && cute::is_same_v<ElementB, cutlass::bfloat16_t>;
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{});
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{});
static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB);
@ -705,7 +687,7 @@ public:
{
// The real scale_k that actually works
// auto scale_k = K / mainloop_params.chunk_size;
auto scale_k = K / GROUP_SIZE;
auto scale_k = K / ScalingGroupSize;
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l)
Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
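As a plain illustration of what ScalingGroupSize implies for the scale tensor shape: there is one scale per group of ScalingGroupSize weights along K, so scale_k = K / ScalingGroupSize. A minimal host-side sketch under that assumption (names are hypothetical, zero points and the exact scale encoding are omitted):

#include <cassert>
#include <cstdio>
#include <vector>

// Host-side sketch of group-wise dequantization: one scale per `groupSize` weights along K.
// groupSize corresponds to ScalingGroupSize above: 128 for int4 weights, 32 for mxfp4 weights.
std::vector<float> dequantizeGroupwise(
    std::vector<int> const& wQuant, std::vector<float> const& scales, int groupSize)
{
    int const K = static_cast<int>(wQuant.size());
    assert(K % groupSize == 0);
    assert(static_cast<int>(scales.size()) == K / groupSize); // scale_k = K / groupSize
    std::vector<float> wDequant(K);
    for (int k = 0; k < K; ++k)
    {
        wDequant[k] = static_cast<float>(wQuant[k]) * scales[k / groupSize];
    }
    return wDequant;
}

int main()
{
    std::vector<int> const wQuant{1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<float> const scales{0.5f, 0.25f}; // K = 8, groupSize = 4 -> scale_k = 2
    for (float w : dequantizeGroupwise(wQuant, scales, 4))
    {
        std::printf("%g\n", w);
    }
    return 0;
}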
@ -872,7 +854,6 @@ public:
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
{
// zero copy
auto tZgZ = get<2>(extra_input_partitions);
auto tZsZ = get<3>(extra_input_partitions);
if (cute::elect_one_sync())
@ -979,7 +960,8 @@ public:
return make_tensor_like<RealSwappedElementA>(tCsA(_, _, _, Int<0>{}));
}
}();
Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// tCrB is just a view of the tensor tCsB
Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
//
@ -1013,8 +995,8 @@ public:
multiply_add<ElementAccumulator> fma;
constexpr int NumMMAsPerChunk = GROUP_SIZE / cute::get<0, 1>(tCsB.shape())();
constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / GROUP_SIZE;
constexpr int NumMMAsPerChunk = ScalingGroupSize / cute::get<0, 1>(tCsB.shape())();
constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / ScalingGroupSize;
cute::array<decltype(make_fragment_like(accum)), NumChunksPerTileK> intermediate_array;
constexpr int K_BLOCK_MAX = size<2>(tCrA_load);
@ -1045,8 +1027,6 @@ public:
// src: tCrA_load, dst: tCrA_mma
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id)
@ -1079,10 +1059,11 @@ public:
}
}
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_)
{
warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk);
warpgroup_fence_operand(intermediate_array[chunk_id_]);
// Apply the group-wise scaling
@ -1129,7 +1110,6 @@ public:
Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info,
copy_partitions_extra_info, 1, smem_pipe_read.index());
warpgroup_wait<K_WAIT_MAX>();
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
}
}
@ -1169,8 +1149,6 @@ public:
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_WAIT_MAX>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage,
// so we can release prior barrier
if (k_block == K_BLOCK_MAX - 1)
{
pipeline.consumer_release(
@ -1187,10 +1165,11 @@ public:
{
// The last k_block
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_)
{
warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk);
warpgroup_fence_operand(intermediate_array[chunk_id_]);
// Apply the group-wise scaling
@ -1257,7 +1236,6 @@ public:
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_WAIT_MAX>();
if (k_block == K_BLOCK_MAX - 1)
{
// release prior barrier
@ -1318,7 +1296,7 @@ public:
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
// warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count)
{
@ -1462,7 +1440,7 @@ public:
{
NonVoidElementScale const* ptr_S = nullptr;
// auto scale_k = K / mainloop_params.chunk_size;
auto scale_k = K / GROUP_SIZE;
auto scale_k = K / ScalingGroupSize;
Tensor tensor_scale = make_tensor(
detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]);
cute::detail::fill_tma_gmem_shape_stride(
@ -1472,7 +1450,7 @@ public:
{
ElementZero const* ptr_Z = nullptr;
// auto scale_k = K / mainloop_params.chunk_size;
auto scale_k = K / GROUP_SIZE;
auto scale_k = K / ScalingGroupSize;
Tensor tensor_zero = make_tensor(
detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]);
cute::detail::fill_tma_gmem_shape_stride(

View File

@ -256,9 +256,9 @@ public:
constexpr int SF_VEC_SIZE = 16;
using PackedVec = PackedVec<DType>;
PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&val);
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt, token_id,
m_access_id_in_token, std::nullopt, m_params.hidden_dim,
reinterpret_cast<uint32_t*>(m_params.scale_out), m_params.layout);
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt, token_id, m_access_id_in_token,
std::nullopt, m_params.hidden_dim / SF_VEC_SIZE, reinterpret_cast<uint32_t*>(m_params.scale_out),
m_params.layout);
reinterpret_cast<uint32_t*>(m_params.quant_out)[m_access_id]
= cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, m_scale_factor, sf_out);
}

View File

@ -132,7 +132,7 @@ struct AllReduceFusionParams
float rms_eps;
float* scale_factor;
bool use_oneshot;
FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED;
QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED;
cudaStream_t stream;
AllReduceFusionPattern pattern;
bool trigger_completion_at_end = true;

View File

@ -99,15 +99,15 @@ __device__ struct __attribute__((aligned(32))) LamportFlags
uint32_t* offset_access_ptr;
uint32_t* buffer_flags;
__device__ explicit LamportFlags(uint32_t* buffer_flags)
__device__ explicit LamportFlags(uint32_t* buffer_flags, uint32_t buffer_size)
: offset_access_ptr(&buffer_flags[4])
, buffer_flags(buffer_flags)
, buffer_size(buffer_size)
{
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
buffer_size = flag.z;
input_offset = flag.x * (buffer_size << 1U);
clear_offset = flag.y * (buffer_size << 1U);
num_tokens_prev = flag.w;
num_tokens_prev = flag.z;
}
__device__ void cta_arrive()
@ -135,7 +135,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
buffer_flags[0] = (flag.x + 1) % 3;
buffer_flags[1] = (flag.y + 1) % 3;
buffer_flags[3] = num_tokens;
buffer_flags[2] = num_tokens;
*(offset_access_ptr) = 0;
}
}
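The constructor and flag-update code above encode a three-slot rotation: the first two flag words pick the current input and clear slots (advancing mod 3 in lockstep every launch), and the third word remembers how many tokens the previous launch wrote so they can be cleared lazily. A host-side sketch of that bookkeeping (initial slot assignment and member names are assumptions for illustration):

#include <cstdint>
#include <cstdio>

// Host-side sketch of the triple-buffer rotation suggested by LamportFlags.
struct BufferFlagsSketch
{
    uint32_t inputSlot = 0;     // buffer_flags[0]
    uint32_t clearSlot = 2;     // buffer_flags[1]; keeps a fixed distance from inputSlot
    uint32_t prevNumTokens = 0; // buffer_flags[2]; tokens written by the previous launch

    uint64_t inputOffset(uint32_t bufferSize) const { return uint64_t(inputSlot) * (uint64_t(bufferSize) << 1U); }
    uint64_t clearOffset(uint32_t bufferSize) const { return uint64_t(clearSlot) * (uint64_t(bufferSize) << 1U); }

    void advance(uint32_t numTokens) // performed once per launch
    {
        inputSlot = (inputSlot + 1U) % 3U;
        clearSlot = (clearSlot + 1U) % 3U;
        prevNumTokens = numTokens;
    }
};

int main()
{
    BufferFlagsSketch flags;
    for (int launch = 0; launch < 4; ++launch)
    {
        std::printf("launch %d: input slot %u, clear slot %u, prev tokens %u\n",
            launch, flags.inputSlot, flags.clearSlot, flags.prevNumTokens);
        flags.advance(16U * static_cast<uint32_t>(launch + 1));
    }
    return 0;
}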
@ -144,7 +144,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags
template <int WORLD_SIZE, typename T>
__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results)
int buffer_M, int token_dim, int rank, uint32_t buffer_size, uint32_t* buffer_flags, bool wait_for_results)
{
int elt = blockIdx.y * blockDim.x + threadIdx.x;
if (elt >= token_dim)
@ -155,7 +155,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
cudaGridDependencySynchronize();
#endif
LamportFlags flags(buffer_flags);
LamportFlags flags(buffer_flags, buffer_size);
// Capture the number of tokens in previous iteration so that we can properly clear the buffer
// The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up
@ -217,15 +217,17 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Similarly clear broadcast buffer here
for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
if (elt < token_dim)
{
uint32_t clr_token_idx = token + clr_tok * gridDim.x;
if (clr_token_idx < buffer_M)
// Similarly clear broadcast buffer here
for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
{
input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
= fromFloat<T>(-0.f);
uint32_t clr_token_idx = token + clr_tok * gridDim.x;
if (clr_token_idx < buffer_M)
{
input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
= fromFloat<T>(-0.f);
}
}
}
@ -240,20 +242,24 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
// blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
{
uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
uint64_t elt_load_offset = blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
if (elt_load_offset < token_dim)
{
uint64_t current_pos = blockIdx.x * token_dim + elt_load_offset;
void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
// We have 2 assumptions here:
// 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
// 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
float2 val = loadfloat2(lamport_ptr);
while (isNegZero(*(T*) &val))
{
val = loadfloat2(lamport_ptr);
}
if (output_ptr)
{
*((float2*) &output_ptr[current_pos]) = val;
void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
// We have 2 assumptions here:
// 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
// 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
float2 val = loadfloat2(lamport_ptr);
while (isNegZero(*(T*) &val))
{
val = loadfloat2(lamport_ptr);
}
if (output_ptr)
{
*((float2*) &output_ptr[current_pos]) = val;
}
}
}
@ -263,10 +269,11 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
}
#define LAUNCH_ALL_REDUCE_KERNEL(WORLD_SIZE, T) \
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel<WORLD_SIZE, T>, \
reinterpret_cast<T*>(params.output), reinterpret_cast<T*>(params.input), \
reinterpret_cast<T**>(params.buffer_ptrs_dev), (T*) params.multicast_ptr, params.num_tokens, params.buffer_M, \
params.token_dim, params.rank, reinterpret_cast<uint32_t*>(params.buffer_flags), params.wait_for_results));
TLLM_CUDA_CHECK( \
cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel<WORLD_SIZE, T>, reinterpret_cast<T*>(params.output), \
reinterpret_cast<T*>(params.input), reinterpret_cast<T**>(params.buffer_ptrs_dev), \
(T*) params.multicast_ptr, params.num_tokens, params.buffer_M, params.token_dim, params.rank, \
params.buffer_size, reinterpret_cast<uint32_t*>(params.buffer_flags), params.wait_for_results));
void twoshot_allreduce_op(AllReduceParams const& params)
{
@ -369,20 +376,33 @@ inline __device__ T add(T a, T b)
}
#define FINAL_MASK 0xffffffff
#define WARP_SIZE 32
template <typename T>
__inline__ __device__ T warpReduceSum(T val)
{
// Get the actual number of active threads in this warp
int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1)));
unsigned int mask = (1U << active_warp_size) - 1;
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
for (int offset = 16; offset > 0; offset >>= 1)
{
if (offset < active_warp_size)
{
val = add<T>(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE));
}
}
return val;
}
inline __device__ float block_reduce_sum(float val)
{
__shared__ float smem[32];
int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32;
__shared__ float smem[WARP_SIZE];
int lane_id = threadIdx.x % WARP_SIZE;
int warp_id = threadIdx.x / WARP_SIZE;
int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps
val = warpReduceSum(val);
if (lane_id == 0)
{
@ -391,6 +411,7 @@ inline __device__ float block_reduce_sum(float val)
__syncthreads();
val = lane_id < warp_num ? smem[lane_id] : 0.f;
val = warpReduceSum(val);
return val;
}
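The XOR-shuffle loop in warpReduceSum is a butterfly reduction: after log2(32) steps every lane holds the full sum. A host-side simulation of a full 32-lane warp (illustrative only; it does not model the partial-warp masking above):

#include <cstdio>

// Host-side simulation of the XOR butterfly performed by warpReduceSum for a full warp:
// at each step every lane adds the value held by its partner lane (lane ^ offset).
int main()
{
    constexpr int kWarpSize = 32;
    float lanes[kWarpSize];
    for (int i = 0; i < kWarpSize; ++i)
    {
        lanes[i] = static_cast<float>(i + 1); // 1 + 2 + ... + 32 = 528
    }
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        float next[kWarpSize];
        for (int lane = 0; lane < kWarpSize; ++lane)
        {
            next[lane] = lanes[lane] + lanes[lane ^ offset]; // stands in for __shfl_xor_sync
        }
        for (int lane = 0; lane < kWarpSize; ++lane)
        {
            lanes[lane] = next[lane];
        }
    }
    std::printf("every lane now holds the full sum: %g\n", lanes[0]);
    return 0;
}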
@ -410,7 +431,7 @@ __device__ float4 loadfloat4(void const* ptr)
template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>
__global__ void __launch_bounds__(128, 1)
RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon,
T_IN const* residual, int batch_size, uint32_t* buffer_flags)
T_IN const* residual, int batch_size, uint32_t buffer_size, uint32_t* buffer_flags)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
static bool const LAMPORT = true;
@ -433,7 +454,7 @@ __global__ void __launch_bounds__(128, 1)
int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];
LamportFlags flags(buffer_flags);
LamportFlags flags(buffer_flags, buffer_size);
T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size];
cudaTriggerProgrammaticLaunchCompletion();
@ -598,16 +619,15 @@ __global__ void __launch_bounds__(128, 1)
#endif
}
template <typename T, int H_DIM>
template <typename T, int H_DIM, int NUM_THREADS>
void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T const* gamma, double epsilon,
T const* residual, uint32_t* buffer_flags, int batch, cudaStream_t stream)
T const* residual, uint32_t buffer_size, uint32_t* buffer_flags, int batch, cudaStream_t stream)
{
// The input to RMSNorm is the buffer written by the two-shot all-reduce.
// Use the prenorm output to determine the actual size in use.
float _epsilon{static_cast<float>(epsilon)};
static constexpr int NUM_THREADS = 128;
static constexpr int CGA_THREADS = NUM_THREADS;
constexpr int iters = H_DIM / CGA_THREADS;
@ -628,28 +648,34 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons
&RMSNorm<H_DIM, NUM_THREADS, 1, T, T>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
config.dynamicSmemBytes = shmem_size;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &RMSNorm<H_DIM, NUM_THREADS, 1, T, T>, prenorm_output, normed_output,
input, gamma, _epsilon, residual, batch, buffer_flags));
input, gamma, _epsilon, residual, batch, buffer_size, buffer_flags));
}
#define LAUNCH_RMSNORM_KERNEL(T, H_DIM) \
twoshot_rmsnorm<T, H_DIM>(static_cast<T*>(params.residual_output), static_cast<T*>(params.output), \
#define LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS) \
twoshot_rmsnorm<T, H_DIM, NUM_THREADS>(static_cast<T*>(params.residual_output), static_cast<T*>(params.output), \
static_cast<T const*>(params.input), static_cast<T const*>(params.gamma), params.epsilon, \
static_cast<T const*>(params.residual), params.buffer_flags, params.batch, params.stream)
static_cast<T const*>(params.residual), params.buffer_size, params.buffer_flags, params.batch, params.stream)
void twoshot_rmsnorm_op(RMSNormParams const& params)
{
auto dtype = params.dtype;
#define CASE_DISPATCH_RMSNORM(T, H_DIM, NUM_THREADS) \
case H_DIM: LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS); break;
#define TYPE_DISPATCH_RMSNORM(T) \
CASE_DISPATCH_RMSNORM(T, 2048, 128) \
CASE_DISPATCH_RMSNORM(T, 2880, 120) \
CASE_DISPATCH_RMSNORM(T, 4096, 128) \
CASE_DISPATCH_RMSNORM(T, 5120, 128) \
CASE_DISPATCH_RMSNORM(T, 7168, 128) \
CASE_DISPATCH_RMSNORM(T, 8192, 128)
if (dtype == nvinfer1::DataType::kFLOAT)
{
switch (params.hidden_dim)
{
case 2048: LAUNCH_RMSNORM_KERNEL(float, 2048); break;
case 4096: LAUNCH_RMSNORM_KERNEL(float, 4096); break;
// Llama-4 Hidden Dimension
case 5120: LAUNCH_RMSNORM_KERNEL(float, 5120); break;
// DeepSeek Hidden Dimension
case 7168: LAUNCH_RMSNORM_KERNEL(float, 7168); break;
case 8192: LAUNCH_RMSNORM_KERNEL(float, 8192); break;
TYPE_DISPATCH_RMSNORM(float);
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
}
}
@ -657,13 +683,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
{
switch (params.hidden_dim)
{
case 2048: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 2048); break;
case 4096: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 4096); break;
// Llama-4 Hidden Dimension
case 5120: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 5120); break;
// DeepSeek Hidden Dimension
case 7168: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 7168); break;
case 8192: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 8192); break;
TYPE_DISPATCH_RMSNORM(__nv_bfloat16);
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
}
}
@ -671,13 +691,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
{
switch (params.hidden_dim)
{
case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break;
case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break;
// Llama-4 Hidden Dimension
case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break;
// DeepSeek Hidden Dimension
case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break;
case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break;
TYPE_DISPATCH_RMSNORM(__nv_half);
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
}
}
@ -685,6 +699,8 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
{
TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype.");
}
#undef TYPE_DISPATCH_RMSNORM
#undef CASE_DISPATCH_RMSNORM
}
} // namespace tensorrt_llm::kernels::mnnvl
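For reference, each CASE_DISPATCH_RMSNORM entry expands through LAUNCH_RMSNORM_KERNEL into the corresponding twoshot_rmsnorm<T, H_DIM, NUM_THREADS> call. The newly added 2880 case runs with 120 threads rather than the default 128 because the kernel iterates H_DIM / NUM_THREADS times per thread and 2880 is not divisible by 128 (2880 / 120 = 24), presumably to cover the hidden size of the newly supported model.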

View File

@ -30,6 +30,7 @@ struct AllReduceParams
int buffer_M;
int num_tokens;
int token_dim;
uint32_t buffer_size;
void** buffer_ptrs_dev;
void* multicast_ptr;
void* buffer_flags;
@ -50,6 +51,7 @@ struct RMSNormParams
void const* gamma;
double epsilon;
void* residual;
uint32_t buffer_size;
uint32_t* buffer_flags;
int batch;
int hidden_dim;
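Callers of both ops must now populate the new buffer_size field alongside buffer_flags; it is forwarded to the kernels explicitly and replaces the size that was previously read back from the flag word inside LamportFlags.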

View File

@ -150,8 +150,8 @@ __device__ __forceinline__ void fused_op(
constexpr int SF_VEC_SIZE = 16;
using PackedVec = PackedVec<DType>;
PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&norm_val);
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt /* batchIdx */,
token_id, access_id_in_token, std::nullopt /* numRows */, params.hidden_dim,
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt /* batchIdx */, token_id,
access_id_in_token, std::nullopt /* numRows */, params.hidden_dim / SF_VEC_SIZE,
reinterpret_cast<uint32_t*>(params.scale_out), params.layout);
reinterpret_cast<uint32_t*>(params.quant_out)[access_id]
= cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, *params.scale_factor, sf_out);

View File

@ -55,7 +55,7 @@ struct AllReduceFusionParams
void* rms_gamma;
float rms_eps;
float* scale_factor;
FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED;
QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED;
cudaStream_t stream;
};

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5bb139b12206a563daec9fa473dda422319bde5ae5f965d37cf5ca67d325c49
size 1005546

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c4357a935656d47414a459939720b66311c67213f450168715e1cb0238653768
size 1066324

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a0671e7cbbed9f51dc0c47e4b970e2f72067d629ff6562c9d65f9cd55c68578
size 361861
oid sha256:c709dce149c0f4500539e495c90d1da2d86cec28c4187ee9494b015642e158cf
size 363441

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ec9817bebb07483ce29d8d91c45d35c2c05f0101bfa70146fba5a6576a6b825
size 1091614
oid sha256:b9170581da010aca67f4bafd9f6f59aaaf5fd1958a1fdd336aa208146599ac06
size 1094770

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0540cdb398818ec54a60c34b462c158e169347db73d244d633669d74211696ba
size 1467312
oid sha256:2147a246067f7ea74ca382fbc8c02a26332479e5205ecfbe08fb84161a3a87ec
size 1483888

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69bdfba64f1faff30ed8389a28b7b9ef37c0d180b1df643722b280011c8f74e8
size 692990
oid sha256:279bd48b8ac53690bb4e37dffbe9060428db80c1417ff29c6f4d4a10ab35a7c9
size 700094

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c8173308813999ab64ba8236016b23fbfd3f3f1501f61290bf71ea027ead2920
size 642456
oid sha256:db5d186ce70d7a94cae2b6619b3449ca557903944beba1ee738d2ee425792d74
size 652718

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f41ae066b01b2a9c3b5165535f743461a9a1d559f6fcd0a00a04c554f8a50962
size 414757
oid sha256:089a98cf8ab0bbd7530e69821c42220ea02578b740bff62a3e6e33de45209114
size 416335

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab0be8e667d459e13135f96469613f1c095e47187b24e5d40c7c57583351a076
size 1194236
oid sha256:1f0cc486ec5e9c1720f495a2a5e7c26d42e737694d307d4746a08b6ead5cc225
size 1197394

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03d86280f76994e2e01d43747cb5c811496b8340d031ebb0c3bdd46437422994
size 1654394
oid sha256:398965e34c1a4c747b42d8836c04934daaa43903b7931586ed12120e17a61f76
size 1672548

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:35c5715bcb1a16c343f3a28be105fb6fee1bbca24cf832f71a7d0f20cf9a0b3e
size 365015
oid sha256:77cbd7d45164d24be73e021bc0a8745b4f021e4369a254e216ee00b36d3c7263
size 366593

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a3335a8d4b2c0ca63f006c3f957d57aa3f808ef06d4adda322c311a333286d84
oid sha256:3a3f74fbe72ef54b9c028d957353c1ecbff1d20bcc9619ff17ee37471934a2ab
size 1126352

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fdc0bf099862d352b3b765e117437240a82e4749d3efd104881647dd4ea14562
oid sha256:b3af082c6742f385d0d2c96489ff1de314458eb992d6d5a251c737f8ec912e79
size 644092

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ccd938df8f78af4eae306c6e9e669599c2baf6f095f956318470063c560fbd3c
size 1091610
oid sha256:8e26f3b8cc173301b3cf07ba1ca7893b6f140432410b0b298361ecff597604c2
size 1095556

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce4d35ab4c7b65476f0dcec635db1791fcb718afd6b3531338712f5b2bc9aa84
size 1460204
oid sha256:32220d11bc3542e9edcc36d51b4866bf40044213114d7e237e003afc1fc7c464
size 1478358

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d088ce37b21d335ba1f92034cf97f78fc968d7fecaa0c4f9ec83a0d5165f1d99
oid sha256:3ee5ae75df4866d848e90616562345d3740b17b68c90f06329dc074dba5217a9
size 482709

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:40653ec672098e2cb1f94c473fa67852efcf6b49a6e8109e4fcf39422281acb4
oid sha256:817ae5c1eb8a8c6f22a76ab0b88075fd3391d06abb7dd6d9ab51206b809cd69d
size 657930

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:96348957990518db6f51af7c681a71e625dede568cc8f8303dd2de8ad09bfc28
oid sha256:680734da0abb1c3029dce32e892687f649c4219f66574acb15ab88471f508263
size 677218

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4687df80ac2fa9454b0564b0a80d78cfaedc2c7796c8f3a1010dd7ebbf722c83
oid sha256:c27e871dd680022920081c30c5e239613e53b42129680fdb1d17668b5c5ddd9a
size 369401

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d8b9985065f5f2c62b74c05f8eed02b1909c96656b26fbd7779cc57a2146b037
size 947140
oid sha256:3e1ecaa635067924b692b665241d86e1d8c1d60a19290de7adde1ff2ca7dbeb0
size 956612

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:23599e63b07ad966df921daf3cb97a9ed5cde27eeda0fd96ba5abd835b48f89a
size 590779
oid sha256:d3018c622303f89c6f22f037ec99eaeaeea9cfe8911e22463b48a22c13116805
size 592357

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd1c452565583b20913d835de9b14c2f19c0cc431bc926ea6c92295362a85bca
size 1813864
oid sha256:a7a381f2855236f418a40124a5254401c95001d5e15c074a704e22cc7ed89aa2
size 1818600

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b20de2c6bb3081564ddfbf7ece80fb2c17e66f4e7ff0e0969da4e4655e90d1ec
size 2407418
oid sha256:9bb49ace4dedc4faa3de2b9c22e09db0f3990129ce7ab4afb6419c38a5d48a16
size 2427152

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33a0e8bb2391128e688e5c6356f09a5ed189ce5c1bcdeef4efc0ce0415dc2849
size 555245
oid sha256:9769d7cb9754718798be515c84c45ff48e43322573f3f12e31c2e42e99d8dbd4
size 557613

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b014f41b1cfdf6ed2729778841213a36440191eb3c087346a02c21510bd3f0e
size 665794
oid sha256:134f4a73e0e6b02b717319ec49e3b3ea0a585cad385a1f300e6c5761f12de9d7
size 671320

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd77afeb7dcd1ff8d6be80788b20e92e4fbc8c3026ba12d1d522c99316754a7c
size 1740442
oid sha256:7935b0f053a79a7e620c0efe274fa5b4c840fc9c6e439a381c4d380446e1cb68
size 1744388

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b674707d02aac297b66d523de8b11618ca1598c49eeaf7ce9b1c9d516ce95c4b
size 2247958
oid sha256:74ecbbaa19b2efe97a3b12c488f0e03c2102f16c460239df4bfc19976fc4365e
size 2266902

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7556f88488e05ee669e763b839afa1b7690060cfa9d8482d419c0ca336df9352
oid sha256:813265d25709bd2d39982efbaf092c9163b124bd990fccab505b3c22134522aa
size 595585

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac9d879aa0c70967bb3a79cd7034998baf43a544c0dd4444ebddeb76e78df5ae
oid sha256:dd36195c01bf7c2a2013d5f31d2e74c2579c471385d7b45be7e35ea2f0652608
size 908162

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4e781c0278fc46142f578ae51bfeb38767e89d9c25b92023215948f99dd1d3ed
oid sha256:31d4d6dca68c4632d1f435e9179582cfe2ad7a75ee0f7625ee67b0044c914f10
size 1371512

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d608e9e3ec460d2a38f43067a7d7a2dd408e068db690806bbafb11007e175336
oid sha256:6570d3ee7b651dec797e82b31eb21fd3261c6e2639fb7c9b157f251bf98bb3bf
size 1419662

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c1e1d300866c6425c2495e550230051debdca0a7eb85874ae33c0c2de8a81cb
oid sha256:88b972677c5436b90fe85870278e3b23d6f709608f99295bddf0be3861d95d1a
size 1419662

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:132d83639e34af1b431abdcb3f09542d0389030b85752e18a3ae221ead7d24a3
oid sha256:d975f605d62c3070d6cf72f6114d98642c520e66989ed2d2845c3213e921ebf7
size 1965880

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a96710f6c691580c2363c187a75fd436f5e6be732810a1a45182ce72dc52d1e
oid sha256:ef5a2728cbd3241f45f3d8285c91a818e11b2a9fedf322f343a9461d31a6ad30
size 1380182

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a6339f008f451d030aa36a6b3fac7179e7534f7f2474d641fa0ebfbf487074e7
oid sha256:16b5f3d3f8760dabc0849217cf11edf18d19896dda475a5fc233bbfd444faf33
size 1401494

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57ebcae2b70fc28881f2b3969868d64c203ef4a9cbc9588a9e28051c5f5b6849
oid sha256:cbacb235f39adaeabd68e2fc46c51aac6ca26cdf96293a6a7eb60b5be40640ef
size 1401494

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5e2a4ce1b944feb2b3ed535943089a2d5968bf523b149885df78f7fa4bd7e835
oid sha256:e6f3e068435339a64d47673f8018b66c202f6259d68e0a97a4a30acb7505a7fd
size 1935872

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f5d456b30f89ad05ba5b852fabcffb3f8269913d83ef8c0e4e319f2243dee54d
oid sha256:7c2d7ab0692de5405b26d19a0c57d720285366ac12a8550bbabca1613cce7f0c
size 305897

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:85593d3c2fecb6842a72952c6dcbde19a70e6b26245829d279ca50bb391eb636
oid sha256:91a26adfddc0bcaf8b42249f59f1a0b9f74be0f82c7378fe4b56f3a2fa3d4bf1
size 290109

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69cd61bd8334d2109067ef0460a91b8dba4c2cb07392eb636d72d025ccb15bf9
oid sha256:6ef79c9e2e2d8bba55d7803dc8dc147b5d8babc29e906a43407a8722bbd8d939
size 498507

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0427b7729ce3cfa652a4595d04f936a947febec8f2c96ce33eed7cbaaa05613e
oid sha256:0eef025f8e8581868b02bcea37ff225afebcbb2966450fb29fb0e32ac54eccd4
size 668214

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:321bcd81b8965c8dfc08682f775508ae18e3ff711490ee8dff5fe56c20f74843
oid sha256:abb2857ffb85cc36aae90ebb674635dffee2b2c5f7ad1ea81bb8002b65d5a0f8
size 711628

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa77d3789c0ca314689125ec303a8af76554120a708a4b63395c69b7aad07f04
oid sha256:49a3661535314b139e2794fe16f6f3e0a8d45742b68ea59ba99a9113068adf2c
size 752698

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa35aa70d0fa304c776c076a1a189d32a054d3f696dac5d99018085d1108c73b
oid sha256:d76fb6c4f8bb2de687bc5f9f275389356934119c1f0db9983dcf0ec7b68c6197
size 748726

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d1a702d456b5acf279487dd810e3e33efdd1c7bd82530ceb5a32ad30ec30396c
oid sha256:be8ee89f4489c430d0ff6e9c6cf4e07379ac05abf468d47e34e084ad594b2037
size 946060

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:558aa7d42de329c49361c94c4baef16738304b21b6adbe675d77c7819ef37660
oid sha256:aa4be8ca2dd52e56c9a6af76b90ac353d217fad5fa931b21129ac5a811b5283a
size 489823

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b5baa6048e6c33e74c6d343eb7c76252ff2e534fe467b3189af12b5d64af37c
oid sha256:cb0482b768a40bc7f8a86fa23a84bab62fb82c205f3237ff60becda50cbafc90
size 489823

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e17cb191ad092e6db255ea503e49ea883ed56322fc58ed8d68710f6687376c1f
oid sha256:95b1796f4e7c905eca82ed3691427025f68e765797440b962b0114a5ab32b1d7
size 500083

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bfca5660a931e08941347f7a0aefa82c214940e8eaa6b6d89cfded621f34a490
oid sha256:2d9f13977fc865e716f1f35dfdb222a38000b224ff7394134230ed5c88119947
size 496125

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fffd2cd799953808034d7e7b89a57d4fede24db124bfb0d3938188177acbdfeb
oid sha256:007e32a06fcac853159dc5786940447281c57ba70406d38beb6f089fd037053d
size 182023

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19ada3a5d449542f103077db8d193bc2293a8f48ccee201e366473964287314c
oid sha256:26241ea5909395116e1b1a0f19cadc448886f6a6ab2b3ba76c092b67cd0148f0
size 182023

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9c32124cd708aab7da30637d85437da0af9bf2157d163c19c6fe14498698cda
oid sha256:86e4ca60a459117c5e701631fbd3c67ca66e81d177c394c1fc9ad3b66396e69a
size 661096

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7f248fd42759509c61d20f912ae74dc3a85448a9c8386370ea92492ed9031e80
oid sha256:770db1f4ec1c2d3c25767593b60cb095e49f7a6eb7abe054bbdec6e72db97f8d
size 672936

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:190fd946ddc7e1b5e9ca2172ec1de39c6288829773d9ce29fe98374256eff566
oid sha256:0b6428cae2d0c8c813925be9589c94771098cfe5a6d0ff2036104d3e36384b81
size 721900

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7cd5976c836bcd75c0cadfe968050ac60bf89b93df021ad6c1681e159c497c5
oid sha256:36c6932301fe3dc29631c28fcb8cb6b08652103bc7a36fd74a03a8189a1c77e4
size 717928

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7c536d725e1d9ebd2cb836dfe3993edcc81101534db6b7f1943c8a9443838bf4
oid sha256:d858f6dcaf3f49fb3fa18b1c8c20ee1b933e2c8ddd1a429c8d3b5b4d269fb875
size 927892

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5907da5a2f68c010d44bbbd0d780e097f9625be15b2f85e8dd1f00dd4c31ff9
oid sha256:7dc92ab65ed0fc5f9d821f52a396a6d55ea9ae37e080eac7ff9e9c14eae741e7
size 631890

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9cf14c71134a89ed6ffc83c0b7db06ed10e22b55294dc15ddf7f016427f01033
oid sha256:d66606a37cfe8eb78ccc3f548a231f770df9f46e70f6d3ba22fb8abe6216480e
size 159919

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f2b83c70dbc8ab0b3695dab3f4d2069b7ee7119e9140d7860b8c19f59a498589
oid sha256:b723b296cff04602f64a5da9928e6f9b6a03c5cc608ba9ef7d8055f23f1f4ea2
size 159919

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fc8369f5701dceea91d429a713ddcbb4ecb0ad08d3c9042688557ead5f00e9da
oid sha256:d40578a5684262cd8136705367e2c98493ea9b9fcfc123c7efa3ead14017b5b8
size 483493

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4e9fffff2d13d49613e5f9334a010ca9bcde43b3bb55a792fd97fe2c867760dc
oid sha256:60cc82b9d11c53392de91a7c4c097263c20a56f9b346278c7c9af12ef2bb5fbf
size 496123

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd3041ba5a52263f7f02d64f1911c50e346151bf529e865c1abf22583abd3e21
oid sha256:8f685b6b2a0a573953f31fad89fa37e949361db245de69c0c06ce0bbb14eacef
size 443285

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12482099b086249163085e6e3421a61f6e304f865aaf56dd15382614be5e48e7
oid sha256:834f0f3601c589893a21b957be2864df594f96b34b2cfd6018ada8319986aa21
size 441683

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bfea1ea1627eaef7b614db08bad00bda8b611c8e466c858e050c0ce2aee2eafb
oid sha256:3d81a070e7ed49f1e1a322d38a757a3505186cf5cbded99814e950e07229a46a
size 298049

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f828600699faa3a0474085cbbe88d2e0ac7c8e056c976b81a882c3a72682e527
oid sha256:b9de5bc49d888699da1880d24ccf6a9cb6c0049d7a244d1ae9ab64b7365ecd5a
size 296445

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2d4b297922065ecb79b4a1278d048b253b57601d011fc5833a32f9fc1b78e58e
oid sha256:e30ed0df4b0d0b1da1ace5831dc0a7a526e04001b25860f862345c78acff5a43
size 427485

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3fd5305445c9856fbd5d9dfaffdd7f87b9014638f33fb63fb2cb4fce9893b20b
oid sha256:030015dc1811e3dc2ae36ed770f51063a3f46deae42ead5e1523c977b438a133
size 425883

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2b7fee97097f799830df2bcb1c782c7ea9018243cbd5cd0e0f47ec299b49db79
oid sha256:6921a204892e1336cef2a308be38855f3c888e56bd6a16752d2806aa9e93c431
size 1524634

Some files were not shown because too many files have changed in this diff.
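
Each hunk above rewrites a Git LFS pointer file: the oid line carries the SHA-256 digest of the tracked binary and the size line carries its length in bytes. As a minimal sketch (not part of this commit; the script name and command-line paths are hypothetical), the following Python snippet checks a locally fetched binary against its pointer text:

# check_lfs_pointer.py -- verify that a binary file matches a Git LFS pointer
# of the form shown in the hunks above (version / oid sha256:<digest> / size <bytes>).
import hashlib
import sys


def parse_pointer(pointer_text: str) -> tuple[str, int]:
    # Each pointer line is "<key> <value>"; keep the digest and the byte count.
    fields = dict(line.split(" ", 1) for line in pointer_text.strip().splitlines())
    return fields["oid"].removeprefix("sha256:"), int(fields["size"])


def matches_pointer(binary_path: str, pointer_path: str) -> bool:
    # The pointer's oid is the SHA-256 of the file contents; size is its length in bytes.
    expected_oid, expected_size = parse_pointer(open(pointer_path, encoding="utf-8").read())
    data = open(binary_path, "rb").read()
    return len(data) == expected_size and hashlib.sha256(data).hexdigest() == expected_oid


if __name__ == "__main__":
    # Usage (hypothetical paths): python check_lfs_pointer.py model.bin model.bin.pointer
    binary, pointer = sys.argv[1], sys.argv[2]
    print("match" if matches_pointer(binary, pointer) else "mismatch")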