Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Commit 8207d5fd39 (parent 6c1f7d8b91)
[feat] Add model gpt-oss (#6645)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
@@ -122,6 +122,16 @@ public:
return QuantMode(BaseType(1u) << 14);
}

static constexpr QuantMode w4a8Mxfp4Mxfp8() noexcept
{
return QuantMode(BaseType(1u) << 15);
}

static constexpr QuantMode w4a16Mxfp4() noexcept
{
return QuantMode(BaseType(1u) << 16);
}

constexpr BaseType value() const noexcept
{
return mValue;
@@ -202,6 +212,16 @@ public:
return isSet(w4a8Mxfp4Fp8());
}

constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept
{
return isSet(w4a8Mxfp4Mxfp8());
}

constexpr bool hasW4a16Mxfp4() const noexcept
{
return isSet(w4a16Mxfp4());
}

constexpr bool hasKvCacheQuant() const noexcept
{
return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache();
@@ -209,7 +229,8 @@ public:

static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken,
bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq,
bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8)
bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8,
bool useW4a8Mxfp4Mxfp8, bool useW4a16Mxfp4)
{
QuantMode quantMode{};
if (quantizeWeights)
@@ -278,25 +299,35 @@ public:
quantMode += w4a8Mxfp4Fp8();
}

if (useW4a8Mxfp4Mxfp8)
{
quantMode += w4a8Mxfp4Mxfp8();
}

if (useW4a16Mxfp4)
{
quantMode += w4a16Mxfp4();
}

return quantMode;
}

static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
{
return fromDescription(
true, true, perToken, perChannel, false, false, false, false, false, false, false, false, false, false);
return fromDescription(true, true, perToken, perChannel, false, false, false, false, false, false, false, false,
false, false, false, false);
}

static constexpr QuantMode useQServe(bool perGroup)
{
return fromDescription(
true, true, false, false, perGroup, true, false, false, false, false, true, false, false, false);
return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true, false, false,
false, false, false);
}

static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
{
return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false,
false, false, false);
false, false, false, false, false);
}

static QuantMode const fromQuantAlgo(
@@ -353,28 +384,38 @@ public:
}
else if (quantAlgo == "FP8")
{
quantMode = fromDescription(
false, false, false, false, false, false, false, false, true, false, false, false, false, false);
quantMode = fromDescription(false, false, false, false, false, false, false, false, true, false, false,
false, false, false, false, false);
}
else if (quantAlgo == "FP8_ROWWISE")
{
quantMode = fromDescription(
false, false, true, true, false, false, false, false, false, true, false, false, false, false);
quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true, false, false,
false, false, false, false);
}
else if (quantAlgo == "FP4")
{
quantMode = fromDescription(
false, false, false, false, false, false, false, false, false, false, false, true, false, false);
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
true, false, false, false, false);
}
else if (quantAlgo == "FP8_BLOCK_SCALES")
{
quantMode = fromDescription(
false, false, false, false, false, false, false, false, false, false, false, false, true, false);
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
false, true, false, false, false);
}
else if (quantAlgo == "W4A8_MXFP4_FP8")
{
quantMode = fromDescription(
false, false, false, false, false, false, false, false, false, false, false, false, false, true);
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
false, false, true, false, false);
}
else if (quantAlgo == "W4A8_MXFP4_MXFP8")
{
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
false, false, false, true, false);
}
else if (quantAlgo == "W4A16_MXFP4")
{
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
false, false, false, false, true);
}

if (kvCacheQuantAlgo == "INT8")
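A minimal usage sketch of the two new modes follows. It is illustrative only: the include path and namespace are assumptions, and everything else relies solely on the API shown in the hunk above (w4a8Mxfp4Mxfp8(), the 16-flag fromDescription(), and the new has* predicates).

#include <cassert>
#include "tensorrt_llm/common/quantization.h" // assumed header location

using tensorrt_llm::common::QuantMode; // assumed namespace

void quant_mode_example()
{
    // W4A16 MXFP4: MXFP4 weights, activations kept in 16-bit.
    // fromDescription() now takes 16 flags; only the last one (useW4a16Mxfp4) is set here.
    auto w4a16 = QuantMode::fromDescription(false, false, false, false, false, false, false, false,
        false, false, false, false, false, false, false, true);
    assert(w4a16.hasW4a16Mxfp4());
    assert(!w4a16.hasW4a8Mxfp4Mxfp8());

    // W4A8 MXFP4 weights with MXFP8 activations: bit 15, queried through the new predicate.
    auto w4a8 = QuantMode::w4a8Mxfp4Mxfp8();
    assert(w4a8.hasW4a8Mxfp4Mxfp8());
    assert(w4a8.value() == (1u << 15));
}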
@@ -50,7 +50,7 @@ def getSMVersion():
ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
@pytest.mark.parametrize('flag', [
"-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
"-softcapping-scale-bmm1 30", "-contiguous-q-kv"
"-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks"
])
@pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
@@ -117,8 +117,8 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
shell=True,
check=True)
# alibi and softcapping-scale-bmm1 are mutually exclusive.
if '-softcapping-scale-bmm1' not in flag:
# alibi doesn't work with softcapping-scale-bmm1/use-attention-sinks.
if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag:
subprocess.run(
f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
shell=True,
@@ -326,9 +326,6 @@ struct Compute
uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]);
Compute_tile_o ctile_o(0, smem_v);

// BMM2 epilogue
Tile_o_epilogue tile_o_epilogue(params);

// Mutex between two compute groups.
OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER);
// Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions).
@@ -368,6 +365,9 @@ struct Compute
sage_scale_row = head_info.bidb * params.h + head_info.bidh;
}

// BMM2 epilogue
Tile_o_epilogue tile_o_epilogue(params, head_info);

int q_step_idx = warpgroup_id;

// Compute work.
@@ -490,7 +490,7 @@ struct Compute
if (valid_run)
{
// Final step's update.
tile_o_epilogue.scale(ctile_o, p_sum);
tile_o_epilogue.scale(ctile_o, p_max, p_sum);
// Store o_tile to gmem.
gmem_o.store(ctile_o.acc_);
}
@@ -454,7 +454,7 @@ struct Softmax_base
#pragma unroll
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
{
uint32_t const scale = float_to_half2(correction_[mi]);
const uint32_t scale = float_to_half2(correction_[mi]);

// Assume only N has multiple MMAs (MMAS_M = 1).
// MMAS_N > 1 when N dimension is split.
@@ -477,9 +477,15 @@ struct Softmax_base
}

// BMM1 scale.
uint32_t const scale_bmm1_;
const uint32_t scale_bmm1_;
// BMM1 softcapping scale.
float const softcapping_scale_bmm1_;

// The sliding window size.
int const sliding_window_size_;
// The log2 attention chunk size.
int const log2_chunked_attention_size_;

// The thread idx in the warp group.
int tidx_;
// The col index for the mma thread layout.
@@ -487,15 +493,10 @@ struct Softmax_base
// The row index for the mma thread layout.
int quad_row_;

// The sliding window size.
int const sliding_window_size_;
// The log2 attention chunk size.
int const log2_chunked_attention_size_;

// The packed mask ptr.
uint32_t const* packed_mask_ptr_;
// The packed mask k-dim stride in bytes;
int64_t const params_packed_mask_stride_in_bytes_;
const int64_t params_packed_mask_stride_in_bytes_;

// Unpacked BMM1 output buffer.
float elt_[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2];
@@ -1072,20 +1073,53 @@ struct Tile_o_epilogue_base
// The MMA tile for the BMM2.
using Mma_tile_o = typename Kernel_traits::Mma_tile_o;

template <typename Params>
inline __device__ Tile_o_epilogue_base(Params const& params)
// Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
enum
{
; // nothing to construct.
EXP2F_OPTIMIZATION = Kernel_traits::EXP2F_OPTIMIZATION
};

template <typename Params, typename Block_info>
inline __device__ Tile_o_epilogue_base(Params const& params, Block_info& block_info)
{
has_attention_sink_ = params.attention_sinks != nullptr;
head_idx_ = block_info.bidh;
attention_sink_ = has_attention_sink_ ? params.attention_sinks[block_info.bidh] : 0.f;
// It is only need when the exp2f optimization is enabled, so params.scale_bmm1 is always float.
scale_bmm1_f_ = reinterpret_cast<float const&>(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1);
};

// The attention sinks.
inline __device__ void add_attention_sink(float& sum, float max)
{
if (has_attention_sink_)
{
// The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled.
if constexpr (EXP2F_OPTIMIZATION)
{
sum += exp2f(attention_sink_ * M_LOG2E - max * scale_bmm1_f_);
}
else
{
sum += expf(attention_sink_ - max);
}
}
}

// Scale ctile_o output by 1/sum
inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M])
inline __device__ void scale(
Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M])
{
// Final step's update.
#pragma unroll
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
{
global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi];
// The global sum.
float global_sum_mi = global_sum[mi];
// Add the attention sink to the global sum.
add_attention_sink(global_sum_mi, global_max[mi]);
// The scale.
float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi;

// Assume only N has multiple MMAs (MMAS_M = 1).
#pragma unroll
@@ -1096,12 +1130,21 @@ struct Tile_o_epilogue_base
{
float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi);
float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1);
reg0 *= global_sum[mi];
reg1 *= global_sum[mi];
reg0 *= scale;
reg1 *= scale;
}
}
}
}

// Whether the attention sink is enabled.
bool has_attention_sink_ = false;
// The attention sink value.
float attention_sink_ = 0.f;
// The float scale of bmm1 outputs.
float scale_bmm1_f_ = 1.f;
// The head idx.
int head_idx_ = 0;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
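For reference, the new add_attention_sink() plus the reworked scale() above implement an ordinary softmax normalization whose denominator carries one extra per-head term, exp(sink - max); the sink contributes probability mass to the denominator but produces no output value. The sketch below is a host-side illustration of that normalization only, under those assumptions: it is not the kernel code, and it leaves out the exp2f/base-2 variant, which folds the BMM1 scale and log2(e) into the exponent as shown in the hunk.

#include <cmath>
#include <vector>

// Host-side reference: softmax over one row of BMM1 outputs with an attention sink.
// 'logits' holds the row, 'sink' is the per-head attention-sink logit.
std::vector<float> softmax_with_sink(std::vector<float> const& logits, float sink)
{
    // Running max for numerical stability (tracked as global_max in the kernel).
    float row_max = -INFINITY;
    for (float x : logits)
    {
        row_max = std::fmax(row_max, x);
    }

    // Exponentials and their sum (tracked as global_sum in the kernel).
    std::vector<float> probs(logits.size());
    float sum = 0.f;
    for (size_t i = 0; i < logits.size(); ++i)
    {
        probs[i] = std::exp(logits[i] - row_max);
        sum += probs[i];
    }

    // The attention sink only enlarges the denominator; it emits no output of its own.
    sum += std::exp(sink - row_max);

    // Normalize, guarding against an all-masked row the same way scale() does.
    float inv_sum = sum == 0.f ? 1.f : 1.f / sum;
    for (float& p : probs)
    {
        p *= inv_sum;
    }
    return probs;
}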
@@ -1138,14 +1181,21 @@ struct Tile_o_epilogue<Hopper_hgmma_fp16_traits, Kernel_traits>
using Base::Tile_o_epilogue_base;

// Scale ctile_o output by 1/sum
inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M])
inline __device__ void scale(
Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M])
{
// Final step's update.
#pragma unroll
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
{
global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi];
uint32_t const scale = float_to_half2(global_sum[mi]);
// The global sum.
float global_sum_mi = global_sum[mi];
// Add the attention sink to the global sum.
this->add_attention_sink(global_sum_mi, global_max[mi]);
// The scale.
float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi;
// The scale.
const uint32_t scale_h = float_to_half2(scale);

// Assume only N has multiple MMAs (MMAS_M = 1).
#pragma unroll
@@ -1155,7 +1205,7 @@ struct Tile_o_epilogue<Hopper_hgmma_fp16_traits, Kernel_traits>
for (int ni = 0; ni < Mma_tile_o::CORES_N; ni++)
{
uint32_t& reg = ctile_o.acc_[0][mma_ni].reg(ni * Mma_tile_o::CORES_M + mi);
reg = hmul2(reg, scale);
reg = hmul2(reg, scale_h);
}
}
}
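The fp16 specialization above folds the sink into the float sum, converts the resulting 1/sum into a packed half2 (float_to_half2), and rescales two accumulator halves per register with hmul2. The standalone CUDA sketch below shows the same packed-half rescale using the standard cuda_fp16 intrinsics; float_to_half2 and hmul2 in the kernel are fmha helpers, and the intrinsic names here are stand-ins for illustration, not the kernel's API.

#include <cuda_fp16.h>

// Broadcast one float scale into both lanes of a __half2 and rescale a packed
// pair of fp16 accumulator values with a single packed multiply.
__device__ void rescale_half2(__half2& acc_pair, float scale)
{
    // Analogous to float_to_half2(scale) in the kernel.
    __half2 scale_h2 = __float2half2_rn(scale);
    // Analogous to reg = hmul2(reg, scale_h) in the kernel.
    acc_pair = __hmul2(acc_pair, scale_h2);
}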
@@ -1215,27 +1265,58 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
// The MMA tile for the BMM2.
using Mma_tile_o = typename Base::Mma_tile_o;

// Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
enum
{
EXP2F_OPTIMIZATION = Base::EXP2F_OPTIMIZATION
};

// Ctor.
template <typename Params>
inline __device__ Tile_o_epilogue(Params const& params)
: Base(params)
template <typename Params, typename Block_info>
inline __device__ Tile_o_epilogue(Params const& params, Block_info& block_info)
: Base(params, block_info)
, scale_bmm2_(*params.scale_bmm2_d)
{
}

// Add the attention sink to the global sum.
inline __device__ void add_attention_sink(float& sum, float max)
{
if (this->has_attention_sink_)
{
// The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled.
// Take the log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE) into account as the same scale has been applied to sum.
float quant_scale_in_log2 = log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE);
if constexpr (EXP2F_OPTIMIZATION)
{
sum += exp2f(this->attention_sink_ * M_LOG2E - max * this->scale_bmm1_f_ + quant_scale_in_log2);
}
else
{
sum += expf(this->attention_sink_ - max + quant_scale_in_log2);
}
}
}

// Scale ctile_o output by 1/sum
inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M])
inline __device__ void scale(
Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M])
{
// Final step's update.
#pragma unroll
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
{
// The global sum.
float global_sum_mi = global_sum[mi];
// Add the attention sink to the global sum.
add_attention_sink(global_sum_mi, global_max[mi]);
#ifdef UNIFIED_EPILOGUE_SCALE
// Descaling factor
float const scale_bmm2_f_ = reinterpret_cast<float&>(scale_bmm2_);
global_sum[mi] = global_sum[mi] == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum[mi];
// The scale.
float scale = global_sum_mi == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum_mi;
#else
global_sum[mi] = global_sum[mi] == 0.f ? 1.0f : 1.0f / global_sum[mi];
float scale = global_sum_mi == 0.f ? 1.0f : 1.0f / global_sum_mi;
#endif
// Assume only N has multiple MMAs (MMAS_M = 1).
#pragma unroll
@@ -1246,8 +1327,8 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
{
float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi);
float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1);
reg0 *= global_sum[mi];
reg1 *= global_sum[mi];
reg0 *= scale;
reg1 *= scale;
}
}
}
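Written out, the exp2f sink term in the FP8 specialization above is just the base-class term carrying the softmax quantization factor: the accumulated row sum was produced from probabilities scaled by Q = Traits_o::SOFTMAX_FP_QUANT_SCALE, so the sink contribution must carry the same Q, and multiplying a power of two by Q is the same as adding log2(Q) to the exponent. As a short worked identity (with m the running max and s the per-head sink; this restates the code, it adds no new behavior):

Q \cdot 2^{\, s \log_2 e \;-\; m \cdot \mathrm{scale\_bmm1}} \;=\; 2^{\, s \log_2 e \;-\; m \cdot \mathrm{scale\_bmm1} \;+\; \log_2 Q}

which is exactly the argument passed to exp2f in add_attention_sink() above.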
@ -29,33 +29,36 @@ using Kv_block_array = fmha::Kv_block_array;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
|
||||
int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n,
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i,
|
||||
float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -81,11 +84,11 @@ void run_sage_quant(unsigned int batch_size, unsigned int head_num, unsigned int
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type,
|
||||
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type,
|
||||
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
|
||||
void* qkv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, size_t const b, size_t const s, size_t const h, size_t const d, size_t const dv,
|
||||
int const runs, int const warps_m, int const warps_n, bool const has_alibi)
|
||||
void* qkv_d, void* vt_d, void* mask_d, void* attention_sinks_d, void* p_d, void* s_d, void* tmp_d, void* o_d,
|
||||
void* softmax_sum_d, void* cu_q_seqlens_d, const size_t b, const size_t s, const size_t h, const size_t d,
|
||||
const size_t dv, int const runs, int const warps_m, int const warps_n, bool const has_alibi)
|
||||
{
|
||||
|
||||
cudaStream_t stream = 0;
|
||||
@ -106,28 +109,28 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty
|
||||
// Softmax.
|
||||
if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16)
|
||||
{
|
||||
run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1,
|
||||
warps_n, has_alibi);
|
||||
run_softmax_fp16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_bf16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1,
|
||||
warps_n, has_alibi);
|
||||
run_softmax_bf16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1,
|
||||
warps_n, has_alibi);
|
||||
run_softmax_fp32(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_softmax,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
run_softmax_e4m3(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
|
||||
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32)
|
||||
{
|
||||
run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1, scale_softmax,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
run_softmax_int8(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1,
|
||||
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -179,7 +182,7 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params,
|
||||
// types
|
||||
Data_type data_type, Data_type acc_type,
|
||||
// sizes
|
||||
size_t const b, size_t const s, size_t const h, size_t const d, size_t const packed_mask_stride,
|
||||
const size_t b, const size_t s, const size_t h, const size_t d, const size_t packed_mask_stride,
|
||||
// device pointers
|
||||
void* qkv_d, void* packed_mask_d, void* o_d, void* p_d, void* s_d,
|
||||
// scale factors
|
||||
@ -235,17 +238,17 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline void set_params(bert::Fused_multihead_attention_params_v2& params, Launch_params const launch_params,
|
||||
static inline void set_params(bert::Fused_multihead_attention_params_v2& params, const Launch_params launch_params,
|
||||
// types
|
||||
Data_type data_type, Data_type acc_type, Data_type output_dtype,
|
||||
// attention input layout
|
||||
Attention_input_layout input_layout,
|
||||
// sizes
|
||||
size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const h_kv, size_t const d,
|
||||
size_t const dv, size_t const total, const size_t num_grouped_heads, const size_t sliding_window_size,
|
||||
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, const size_t d,
|
||||
const size_t dv, const size_t total, const size_t num_grouped_heads, const size_t sliding_window_size,
|
||||
const size_t chunked_attention_size,
|
||||
// paged kv cache block size.
|
||||
size_t const tokens_per_block,
|
||||
const size_t tokens_per_block,
|
||||
// device pointers
|
||||
void* qkv_packed_d,
|
||||
// contiguous q.
|
||||
@ -261,8 +264,10 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
|
||||
// offsets for different blocks in terms of the start address.
|
||||
int32_t* paged_block_offsets,
|
||||
// mask input.
|
||||
void* packed_mask_d, void* cu_mask_rows_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d,
|
||||
void* s_d, void* softmax_stats_d, void* scale_bmm2_d,
|
||||
void* packed_mask_d, void* cu_mask_rows_d,
|
||||
// attention sinks.
|
||||
void* attention_sinks_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, void* s_d,
|
||||
void* softmax_stats_d, void* scale_bmm2_d,
|
||||
// scale factors
|
||||
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
|
||||
// flags
|
||||
@ -329,6 +334,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
|
||||
// The N dimension has to be aligned.
|
||||
params.packed_mask_stride_in_bytes = (align_to(int64_t(s_kv), int64_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT))) / 8;
|
||||
|
||||
// Attention sinks.
|
||||
params.attention_sinks = reinterpret_cast<float*>(attention_sinks_d);
|
||||
|
||||
#if defined(STORE_P)
|
||||
params.p_ptr = p_d;
|
||||
params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type);
|
||||
@ -412,13 +420,13 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, size_t const s,
|
||||
size_t const d, Attention_mask_type const attention_mask_type, Attention_input_layout const input_layout,
|
||||
static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, const size_t s,
|
||||
const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout,
|
||||
bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma,
|
||||
bool const force_non_flash_attention, bool const force_non_warp_specialization,
|
||||
bool const force_non_granular_tiling, bool const force_fp32_acc,
|
||||
// device props
|
||||
cudaDeviceProp const props)
|
||||
const cudaDeviceProp props)
|
||||
{
|
||||
|
||||
// Set launch params to choose kernels
|
||||
@ -573,6 +581,9 @@ int main(int argc, char** argv)
|
||||
// SageAttention block sizes
|
||||
int sage_block_size_q = 0, sage_block_size_k = 0, sage_block_size_v = 0;
|
||||
|
||||
// Use attention sinks (added to the denominator of softmax)
|
||||
bool use_attention_sinks = false;
|
||||
|
||||
// Read the parameters from the command-line.
|
||||
for (int ii = 1; ii < argc; ++ii)
|
||||
{
|
||||
@ -865,13 +876,16 @@ int main(int argc, char** argv)
|
||||
{
|
||||
sage_block_size_v = strtol(argv[ii], nullptr, 10);
|
||||
}
|
||||
else if (!strcmp(argv[ii], "-use-attention-sinks"))
|
||||
{
|
||||
use_attention_sinks = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (save_softmax == true)
|
||||
{
|
||||
if (input_layout != Attention_input_layout::CONTIGUOUS_Q_KV)
|
||||
@ -1043,11 +1057,11 @@ int main(int argc, char** argv)
|
||||
force_non_granular_tiling, force_fp32_acc, props);
|
||||
|
||||
// The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D.
|
||||
size_t const qkv_size = s * b * h * (2 * d + dv);
|
||||
const size_t qkv_size = s * b * h * (2 * d + dv);
|
||||
// Allocate on the host.
|
||||
float* qkv_h = (float*) malloc(qkv_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type);
|
||||
const size_t qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type);
|
||||
// Allocate on the device.
|
||||
void *qkv_sbh3d_d = nullptr, *qkv_bsh3d_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&qkv_sbh3d_d, qkv_size_in_bytes));
|
||||
@ -1057,7 +1071,7 @@ int main(int argc, char** argv)
|
||||
// The shape is [B, 2, S, H, D].
|
||||
const size_t kv_size = b * s * h_kv * (d + dv);
|
||||
// The size in bytes.
|
||||
size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
|
||||
const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
|
||||
// Allocate on the host.
|
||||
void* contiguous_kv_h = malloc(kv_size_in_bytes);
|
||||
// Memset the buffer.
|
||||
@ -1071,13 +1085,13 @@ int main(int argc, char** argv)
|
||||
void** kv_cache_ptrs_h = nullptr;
|
||||
void* kv_cache_pool_ptr = nullptr;
|
||||
int32_t *kv_cache_block_offsets_h, *kv_cache_block_offsets_d = nullptr;
|
||||
size_t const max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block;
|
||||
size_t const num_total_blocks = b * 2 * max_blocks_per_seq;
|
||||
const size_t max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block;
|
||||
const size_t num_total_blocks = b * 2 * max_blocks_per_seq;
|
||||
kv_cache_ptrs_h = (void**) malloc(num_total_blocks * sizeof(void*));
|
||||
kv_cache_block_offsets_h = (int32_t*) malloc(num_total_blocks * sizeof(int32_t));
|
||||
size_t const paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type);
|
||||
const size_t paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type);
|
||||
FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_block_offsets_d), num_total_blocks * sizeof(int32_t)));
|
||||
size_t const kv_cache_pool_sz
|
||||
const size_t kv_cache_pool_sz
|
||||
= get_size_in_bytes(num_total_blocks * tokens_per_block * h_kv * (d + dv) / 2, data_type);
|
||||
FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_pool_ptr), kv_cache_pool_sz));
|
||||
size_t ptr_index = 0;
|
||||
@ -1104,7 +1118,7 @@ int main(int argc, char** argv)
|
||||
|
||||
// Q will always be [B, S, H, Dh] with paged kv cache.
|
||||
void* q_d;
|
||||
size_t const q_size = s * b * h * d;
|
||||
const size_t q_size = s * b * h * d;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&q_d, get_size_in_bytes(q_size, data_type)));
|
||||
|
||||
// K has [B, S, H_kv, D] with separate kv cache.
|
||||
@ -1122,11 +1136,11 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&scale_bmm2_d, sizeof(uint32_t)));
|
||||
|
||||
// The mask for dropout or any mask patterns.
|
||||
size_t const mask_size = s * b * s;
|
||||
const size_t mask_size = s * b * s;
|
||||
// Allocate on the host.
|
||||
float* mask_h = (float*) malloc(mask_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
|
||||
const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
|
||||
// Allocate on the device.
|
||||
void* mask_d = nullptr;
|
||||
if (!skip_checks)
|
||||
@ -1158,7 +1172,7 @@ int main(int argc, char** argv)
|
||||
v1 ? 1 : 2);
|
||||
|
||||
// The number of threads per CTA.
|
||||
size_t const threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
const size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
// The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension.
|
||||
size_t mmas_m = (s + 16 * warps_m - 1) / (16 * warps_m);
|
||||
// The number of mmas in the N dimension.
|
||||
@ -1182,7 +1196,7 @@ int main(int argc, char** argv)
|
||||
packed_mask_size = b * mmas_m * mmas_n * threads_per_cta;
|
||||
}
|
||||
// The size in bytes.
|
||||
size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
|
||||
const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
|
||||
// Allocate on the host.
|
||||
uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes);
|
||||
// Set it to 0 (indicates that all elements are valid).
|
||||
@ -1190,12 +1204,30 @@ int main(int argc, char** argv)
|
||||
// Allocate on the device.
|
||||
void* packed_mask_d = nullptr;
|
||||
|
||||
// The size of the attention sinks.
|
||||
const size_t attention_sinks_size_in_bytes = h * sizeof(float);
|
||||
|
||||
// The attention sinks.
|
||||
void* attention_sinks_d = nullptr;
|
||||
if (use_attention_sinks)
|
||||
{
|
||||
// Allocate on the host.
|
||||
float* attention_sinks_h = (float*) malloc(attention_sinks_size_in_bytes);
|
||||
// Randomly initialize the attention sinks.
|
||||
random_init("attention_sinks", attention_sinks_h, 1, h, 1, false, 5.f, 1.f, verbose);
|
||||
// Allocate on the device.
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&attention_sinks_d, attention_sinks_size_in_bytes));
|
||||
// Copy from the host to the device.
|
||||
FMHA_CHECK_CUDA(
|
||||
cudaMemcpy(attention_sinks_d, attention_sinks_h, attention_sinks_size_in_bytes, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
// The O matrix is packed as S * B * H * D.
|
||||
size_t const o_size = s * b * h * dv;
|
||||
const size_t o_size = s * b * h * dv;
|
||||
// Allocate on the host.
|
||||
float* o_h = (float*) malloc(o_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type);
|
||||
const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* o_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes));
|
||||
@ -1206,7 +1238,7 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * sizeof(float) * b * s * h));
|
||||
|
||||
// The size in bytes.
|
||||
size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
|
||||
const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
|
||||
// Allocate on the device.
|
||||
void* tmp_d = nullptr;
|
||||
if (data_type != acc_type)
|
||||
@ -1220,9 +1252,9 @@ int main(int argc, char** argv)
|
||||
float* softmax_sum_h = (float*) malloc(b * s * h * sizeof(float));
|
||||
|
||||
// The P matrix is stored as one big matrix of size S x B x H x S.
|
||||
size_t const p_size = s * b * h * s;
|
||||
const size_t p_size = s * b * h * s;
|
||||
// The size in bytes.
|
||||
size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
|
||||
const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
|
||||
// Allocate on the device.
|
||||
void* p_d = nullptr;
|
||||
if (!skip_checks)
|
||||
@ -1238,7 +1270,7 @@ int main(int argc, char** argv)
|
||||
#endif // defined(STORE_P)
|
||||
|
||||
// The size in bytes of the S matrix (the data type may be different from P for int8).
|
||||
size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type);
|
||||
const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* s_d = nullptr;
|
||||
if (!skip_checks)
|
||||
@ -1327,7 +1359,7 @@ int main(int argc, char** argv)
|
||||
|
||||
std::vector<uint32_t> seqlens(b, 0); // randomly draw a batch of sequence lengths >= min_s
|
||||
std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(),
|
||||
[=](uint32_t const)
|
||||
[=](const uint32_t)
|
||||
{
|
||||
if (fix_s)
|
||||
{
|
||||
@ -1415,7 +1447,7 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_packed_d, mqa_qkv_packed_size_in_bytes));
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_d, mqa_qkv_size_in_bytes));
|
||||
|
||||
size_t const o_packed_size = cu_seqlens.back() * h * dv;
|
||||
const size_t o_packed_size = cu_seqlens.back() * h * dv;
|
||||
// Allocate on the host.
|
||||
float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float));
|
||||
void* o_packed_d = nullptr;
|
||||
@ -1676,9 +1708,9 @@ int main(int argc, char** argv)
|
||||
total, num_grouped_heads, sliding_window_size, chunked_attention_size,
|
||||
// Paged kv cache.
|
||||
tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d,
|
||||
packed_mask_d, cu_mask_rows_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, softmax_stats_ptr,
|
||||
scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, use_int8_scale_max, interleaved,
|
||||
is_s_padded, has_alibi);
|
||||
packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d,
|
||||
softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
|
||||
use_int8_scale_max, interleaved, is_s_padded, has_alibi);
|
||||
|
||||
// total number of tokens is needed to set TMA desc on the host.
|
||||
launch_params.total_q_seqlen = q_seqlens[b];
|
||||
@ -1894,8 +1926,8 @@ int main(int argc, char** argv)
|
||||
ground_truth(bmm1, bmm2, data_type, acc_type, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
|
||||
qkv_sbh3d_d,
|
||||
vt_d, // WAR pass in V'
|
||||
mask_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs, warps_m, warps_n,
|
||||
has_alibi);
|
||||
mask_d, attention_sinks_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs,
|
||||
warps_m, warps_n, has_alibi);
|
||||
timer.stop();
|
||||
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
||||
FMHA_CHECK_CUDA(cudaDeviceSynchronize());
|
||||
@ -2009,7 +2041,6 @@ int main(int argc, char** argv)
|
||||
// Extract the last s_q tokens from the output.
|
||||
extract_and_transpose_output<float>(
|
||||
o_ref_trans_h.data(), o_ref_h, seqlens, q_seqlens, s, s_q, b, h, dv, is_s_padded);
|
||||
|
||||
if (verbose)
|
||||
{
|
||||
printf("\nChecking .....: O = V * S\n");
|
||||
|
||||
@ -197,6 +197,9 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
|
||||
// The stride between rows of softmax_stats_ptr
|
||||
int64_t softmax_stats_stride_in_bytes;
|
||||
|
||||
// The attention sinks (per head).
|
||||
float* attention_sinks;
|
||||
|
||||
// array of length b+1 holding prefix sum of actual q sequence lengths.
|
||||
int* cu_q_seqlens;
|
||||
// array of length b+1 holding prefix sum of actual kv sequence lengths.
|
||||
|
||||
@ -87,6 +87,8 @@ struct Fused_multihead_attention_params_v2
|
||||
fmha::Kv_block_array paged_kv_cache;
|
||||
// The mask to implement drop-out.
|
||||
void* packed_mask_ptr;
|
||||
// The attention sinks (per head).
|
||||
float* attention_sinks;
|
||||
// The O matrix (output).
|
||||
void* o_ptr;
|
||||
// The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max
|
||||
|
||||
@ -23,28 +23,30 @@ using Launch_params = bert::Fused_multihead_attention_launch_params;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
|
||||
int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n,
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i,
|
||||
float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -57,10 +59,10 @@ void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type,
|
||||
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type,
|
||||
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, void* q_d, void* kv_d, void* vt_d,
|
||||
void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, int const runs,
|
||||
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, int const runs,
|
||||
int const warps_m, int const warps_n, bool has_alibi)
|
||||
{
|
||||
|
||||
@ -84,20 +86,22 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty
|
||||
// Softmax.
|
||||
if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16)
|
||||
{
|
||||
run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
|
||||
run_softmax_fp16(
|
||||
s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
|
||||
run_softmax_fp32(
|
||||
s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax, 0.f,
|
||||
warps_n, has_alibi);
|
||||
run_softmax_e4m3(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax,
|
||||
0.f, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32)
|
||||
{
|
||||
run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1,
|
||||
run_softmax_int8(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1,
|
||||
scale_softmax, 0.f, warps_n, has_alibi);
|
||||
}
|
||||
else
|
||||
@ -148,8 +152,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_mhca& param
|
||||
// types
|
||||
Data_type data_type, Data_type acc_type,
|
||||
// sizes
|
||||
size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, size_t const d_padded,
|
||||
size_t const total,
|
||||
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, const size_t d_padded,
|
||||
const size_t total,
|
||||
// device pointers
|
||||
void* q_packed_d, void* kv_packed_d, void* cu_seqlens_q_d, void* cu_seqlens_kv_d, void* o_packed_d, void* p_d,
|
||||
void* s_d,
|
||||
@ -515,17 +519,17 @@ int main(int argc, char** argv)
|
||||
launch_params.use_tma = use_tma;
|
||||
|
||||
// The Q matrix of size S_Q x B x H x D.
|
||||
size_t const q_size = s_q * b * h * d;
|
||||
const size_t q_size = s_q * b * h * d;
|
||||
// The K and V matrices are packed into one big matrix of size S_KV x B x H x 2 x D.
|
||||
size_t const kv_size = s_kv_padded * b * h * 2 * d;
|
||||
const size_t kv_size = s_kv_padded * b * h * 2 * d;
|
||||
// Allocate on the host.
|
||||
float* q_h = (float*) malloc(q_size * sizeof(float));
|
||||
// Allocate on the host.
|
||||
float* kv_h = (float*) malloc(kv_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const q_size_in_bytes = get_size_in_bytes(q_size, data_type);
|
||||
const size_t q_size_in_bytes = get_size_in_bytes(q_size, data_type);
|
||||
// The size in bytes.
|
||||
size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
|
||||
const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* q_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&q_d, q_size_in_bytes));
|
||||
@ -534,11 +538,11 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&kv_d, kv_size_in_bytes));
|
||||
|
||||
// The mask for dropout.
|
||||
size_t const mask_size = s_q * b * s_kv_padded;
|
||||
const size_t mask_size = s_q * b * s_kv_padded;
|
||||
// Allocate on the host.
|
||||
float* mask_h = (float*) malloc(mask_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
|
||||
const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
|
||||
// Allocate on the device.
|
||||
void* mask_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&mask_d, mask_size_in_bytes));
|
||||
@ -554,28 +558,28 @@ int main(int argc, char** argv)
|
||||
v1 ? 1 : 2);
|
||||
|
||||
// The number of threads per CTA.
|
||||
size_t const threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
const size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
// The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension.
|
||||
size_t const mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m);
|
||||
const size_t mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m);
|
||||
// The number of mmas in the N dimension.
|
||||
size_t const mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n);
|
||||
const size_t mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n);
|
||||
// We do not support more than 4 MMAS in the N dimension (as each MMA needs 8 bits in the mask).
|
||||
assert(!v1 || mmas_n <= 4);
|
||||
// The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA.
|
||||
size_t const packed_mask_size = b * mmas_m * threads_per_cta;
|
||||
const size_t packed_mask_size = b * mmas_m * threads_per_cta;
|
||||
// The size in bytes.
|
||||
size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
|
||||
const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
|
||||
// Allocate on the host.
|
||||
uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes);
|
||||
// Allocate on the device.
|
||||
void* packed_mask_d = nullptr;
|
||||
|
||||
// The O matrix is packed as S_Q * B * H * D.
|
||||
size_t const o_size = s_q * b * h * d;
|
||||
const size_t o_size = s_q * b * h * d;
|
||||
// Allocate on the host.
|
||||
float* o_h = (float*) malloc(o_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type);
|
||||
const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* o_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes));
|
||||
@ -587,7 +591,7 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaMemset(softmax_max_d, 0x00, sizeof(float) * b * s_q * h));
|
||||
|
||||
// The size in bytes.
|
||||
size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
|
||||
const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
|
||||
// Allocate on the device.
|
||||
void* tmp_d = nullptr;
|
||||
if (data_type != acc_type)
|
||||
@ -599,9 +603,9 @@ int main(int argc, char** argv)
|
||||
float* o_ref_h = (float*) malloc(o_size * sizeof(float));
|
||||
|
||||
// The P matrix is stored as one big matrix of size S_Q x B x H x S_KV.
|
||||
size_t const p_size = s_q * b * h * s_kv_padded;
|
||||
const size_t p_size = s_q * b * h * s_kv_padded;
|
||||
// The size in bytes.
|
||||
size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
|
||||
const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
|
||||
// Allocate on the device.
|
||||
void* p_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&p_d, p_size_in_bytes));
|
||||
@ -614,7 +618,7 @@ int main(int argc, char** argv)
|
||||
#endif // defined(STORE_P)
|
||||
|
||||
// The size in bytes of the S matrix (the data type may be different from P for int8).
|
||||
size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type);
|
||||
const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* s_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&s_d, s_size_in_bytes));
|
||||
@ -634,9 +638,9 @@ int main(int argc, char** argv)
|
||||
|
||||
// WAR fOR MISSING CUBLAS FP8 NN SUPPORT.
|
||||
// Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V.
|
||||
size_t const v_size = s_kv_padded * b * h * d;
|
||||
const size_t v_size = s_kv_padded * b * h * d;
|
||||
// The size in bytes.
|
||||
size_t const v_size_in_bytes = get_size_in_bytes(v_size, data_type);
|
||||
const size_t v_size_in_bytes = get_size_in_bytes(v_size, data_type);
|
||||
float* vt_h = (float*) malloc(v_size * sizeof(float));
|
||||
void* vt_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&vt_d, v_size_in_bytes));
|
||||
@ -676,7 +680,7 @@ int main(int argc, char** argv)
|
||||
= [min_s, fix_s, b](int s, std::vector<uint32_t>& seqlens, std::vector<int>& cu_seqlens, void** cu_seqlens_d)
|
||||
{
|
||||
std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(),
|
||||
[=](uint32_t const)
|
||||
[=](const uint32_t)
|
||||
{
|
||||
if (fix_s)
|
||||
{
|
||||
@ -728,7 +732,7 @@ int main(int argc, char** argv)
|
||||
void* kv_packed_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&kv_packed_d, kv_packed_size_in_bytes));
|
||||
|
||||
size_t const o_packed_size = cu_seqlens_q.back() * h * d;
|
||||
const size_t o_packed_size = cu_seqlens_q.back() * h * d;
|
||||
// Allocate on the host.
|
||||
float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float));
|
||||
float* o_ref_packed_h = (float*) malloc(o_packed_size * sizeof(float));
|
||||
|
||||
@ -12,9 +12,10 @@

#include "softmax_impl.h"

void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi)
void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
{
run_softmax<fmha::bf16_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f,
softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<fmha::bf16_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer,
b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi);
}

@ -12,9 +12,10 @@

#include "softmax_impl.h"

void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi)
void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
{
run_softmax<uint16_t, uint16_t>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f,
softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<uint16_t, uint16_t>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b,
h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi);
}

@ -12,9 +12,10 @@

#include "softmax_impl.h"

void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi)
void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
{
run_softmax<fmha::fp16_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f,
softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<fmha::fp16_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer,
b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi);
}

@ -12,10 +12,10 @@

#include "softmax_impl.h"

void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
int warps_n, bool has_alibi)
{
run_softmax<fmha::e4m3_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f,
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<fmha::e4m3_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer,
b, h, 0.f, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
}

@ -10,6 +10,7 @@
* its affiliates is strictly prohibited.
*/

#include <cfloat>
#include <cstdio>
#include <fmha/numeric_types.h>
#include <fmha/utils.h>
@ -33,6 +34,8 @@ struct Softmax_params
Src_type const* src;
// Masks.
int8_t const* mask;
// Attention sinks (per head).
float const* attention_sinks;
// Softmax sum pointer.
float* softmax_sum;
// ALiBi
@ -148,7 +151,8 @@ static inline __device__ float apply_exp_(float x, float max)
////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n, float& sum_fp32)
static inline __device__ void reduce(
float (&data_fp32)[N][1], const int8_t (&mask)[N][1], int warps_n, float& sum_fp32, float const attention_sink)
{

// Apply the masks.
@ -233,7 +237,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma
}

// Normalize.
float inv_sum_fp32 = 1.f / sum_fp32;
float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
@ -244,7 +248,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma
////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n, float& sum_fp32)
static inline __device__ void reduce(
float (&data_fp32)[N][2], const int8_t (&mask)[N][2], int warps_n, float& sum_fp32, float const attention_sink)
{
// Apply the masks.
#pragma unroll
@ -401,7 +406,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma
}

// Normalize.
float inv_sum_fp32 = 1.f / sum_fp32;
float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
@ -413,7 +418,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma
////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n, float& sum_fp32)
static inline __device__ void reduce(
float (&data_fp32)[N][4], const int8_t (&mask)[N][4], int warps_n, float& sum_fp32, float const attention_sink)
{

// Apply the masks.
@ -824,7 +830,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&ma
}

// Normalize.
float inv_sum_fp32 = 1.f / sum_fp32;
float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
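The normalization change in the three reduce() overloads above is the heart of the attention-sink support: a per-head sink acts as one extra logit that contributes exp(sink - rowMax) to the softmax denominator but emits no probability of its own. A minimal host-side sketch of that math, assuming a hypothetical softmaxWithSink helper that is not part of the kernels in this diff:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference softmax for one row of scores with an optional attention sink.
// Mirrors inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)).
std::vector<float> softmaxWithSink(std::vector<float> const& scores, float const* sink)
{
    float rowMax = *std::max_element(scores.begin(), scores.end());
    std::vector<float> probs(scores.size());
    float sum = 0.f;
    for (std::size_t i = 0; i < scores.size(); ++i)
    {
        probs[i] = std::exp(scores[i] - rowMax);
        sum += probs[i];
    }
    if (sink != nullptr)
    {
        // The sink only enlarges the denominator; no output element is produced for it,
        // so the row sums to slightly less than 1 whenever a sink is present.
        sum += std::exp(*sink - rowMax);
    }
    for (float& p : probs)
    {
        p /= sum;
    }
    return probs;
}

Passing sink == nullptr recovers the original normalization, matching the attention_sink = -FLT_MAX default in the kernel below, where the exponential of such a large negative value underflows to zero.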
@ -994,9 +1000,16 @@ static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params)
}
}

// The attention sink value.
float attention_sink = -FLT_MAX;
if (params.attention_sinks != nullptr)
{
attention_sink = params.attention_sinks[hi];
}

// Do the reduction.
float sum_fp32 = 0.f;
reduce(data_fp32, mask_, params.warps_n, sum_fp32);
reduce(data_fp32, mask_, params.warps_n, sum_fp32, attention_sink);
if (threadIdx.x == 0)
{
int sum_s = params.cu_q_seqlens[bi];
@ -1025,9 +1038,9 @@ static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params)
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Dst_type, typename Src_type>
void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum, void* cu_q_seqlens, int s_inner,
int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
bool has_alibi)
void run_softmax(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum,
void* cu_q_seqlens, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax,
float softcapping_scale_bmm1, int warps_n, bool has_alibi)
{

Softmax_params<Dst_type, Src_type> params;
@ -1039,6 +1052,7 @@ void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum
params.softmax_sum = reinterpret_cast<float*>(softmax_sum);
params.cu_q_seqlens = reinterpret_cast<int*>(cu_q_seqlens);
params.mask = reinterpret_cast<int8_t const*>(mask);
params.attention_sinks = reinterpret_cast<float const*>(attention_sinks);
params.has_alibi = has_alibi;

// The dimensions and precomputed values.

@ -12,10 +12,10 @@

#include "softmax_impl.h"

void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1,
int warps_n, bool has_alibi)
void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax,
float softcapping_scale_bmm1, int warps_n, bool has_alibi)
{
run_softmax<int8_t, int32_t>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, scale_bmm1,
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
run_softmax<int8_t, int32_t>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h,
scale_bmm1, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
}

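For callers of these standalone softmax entry points the only change is the extra attention_sinks pointer, which may be null. A hedged caller sketch; the device buffers (dst_d, src_d, mask_d, softmax_sum_d, cu_q_seqlens_d) and the problem sizes are assumed to be set up by the surrounding test harness and are not defined in this diff:

// Either keep the original behavior ...
run_softmax_bf16(dst_d, src_d, mask_d, /*attention_sinks=*/nullptr, softmax_sum_d, cu_q_seqlens_d,
    s_inner, s_outer, b, h, /*softcapping_scale_bmm1=*/0.f, warps_n, /*has_alibi=*/false);

// ... or supply one sink logit per head (h floats on the device).
std::vector<float> sinks_h(h, 2.0f);
float* attention_sinks_d = nullptr;
cudaMalloc(&attention_sinks_d, h * sizeof(float));
cudaMemcpy(attention_sinks_d, sinks_h.data(), h * sizeof(float), cudaMemcpyHostToDevice);
run_softmax_bf16(dst_d, src_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d,
    s_inner, s_outer, b, h, /*softcapping_scale_bmm1=*/0.f, warps_n, /*has_alibi=*/false);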
@ -1379,6 +1379,19 @@ __device__ inline ThrdRegRowMax mergeRowMax(
return mergedRowMax;
}

__device__ inline void addAttentionSinks(
ThrdRegRowMax& globalRowSum, ThrdRegRowMax const globalRowMax, float const* attentionSinks)
{
for (uint32_t i = 0; i < globalRowSum.size; i++)
{
uint32_t srcOffset = warp_size * i + laneId();
if (srcOffset < headGrpSize)
{
globalRowSum[i] += expf(attentionSinks[srcOffset] - globalRowMax[i]);
}
}
}

#ifdef NDEBUG
__device__ __forceinline__
#else
@ -1405,6 +1418,7 @@ CUBIN_EXPORT __global__
#if SPEC_DEC
MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32)].
#endif
float const* attentionSinks, // [headGrpSize]
#ifdef NDEBUG
KVCacheList<usePagedKVCache> const& cacheList,
#if BEAM_WIDTH > 1
@ -2371,6 +2385,12 @@ CUBIN_EXPORT __global__
float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F);
if (seqIterInit < nbSeqIters)
{ // otherwise rcpRowSum will be NAN.
// The attention sinks are moved to the multi-block reduction part if the multi-block is enabled.
if (!isMultiBlock && attentionSinks != nullptr)
{
// Attention sinks are per head.
addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp);
}
ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum);
#if LOW_PREC_OUTPUT
voScale *= rcpOutScale[0];
@ -2559,6 +2579,11 @@ CUBIN_EXPORT __global__
assert(std::isfinite(mergedRowSum[0]));
}
}
if (attentionSinks != nullptr)
{
// Attention sinks are per head.
addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp);
}
__syncthreads();
rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum));
GemmOutRegTile const mergedOutTile = toFp16(sumAcc);
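The two call sites above make the invariant explicit: the sink term is folded in exactly once, and always against the row max that the final normalization uses — directly in the main path when a single CTA handles the whole sequence, or only after the per-block partial sums are merged when the multi-block path is taken. A small standalone check of why both orders give the same denominator (toy values and hypothetical variable names, not kernel code):

#include <cmath>
#include <cstdio>

int main()
{
    float const sink = 2.5f;
    float const logits[4] = {1.0f, 3.0f, -2.0f, 0.5f};

    // Single pass over all logits.
    float maxAll = logits[0];
    for (float x : logits) maxAll = std::fmax(maxAll, x);
    float sumAll = 0.f;
    for (float x : logits) sumAll += std::exp(x - maxAll);
    float denomSingle = sumAll + std::exp(sink - maxAll);

    // Two "blocks" of two logits each, merged afterwards.
    float max0 = std::fmax(logits[0], logits[1]);
    float max1 = std::fmax(logits[2], logits[3]);
    float sum0 = std::exp(logits[0] - max0) + std::exp(logits[1] - max0);
    float sum1 = std::exp(logits[2] - max1) + std::exp(logits[3] - max1);
    float mergedMax = std::fmax(max0, max1);
    float mergedSum = sum0 * std::exp(max0 - mergedMax) + sum1 * std::exp(max1 - mergedMax);
    float denomMulti = mergedSum + std::exp(sink - mergedMax); // sink added once, at the very end

    std::printf("single=%f multi=%f\n", denomSingle, denomMulti); // identical up to rounding
    return 0;
}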
@ -2615,6 +2640,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32))] uint2 (each bit represents mask for one col
// position).
#endif
float const* attentionSinks, // [headGrpSize]
KVCacheList<usePagedKVCache> const cacheList,
#if BEAM_WIDTH > 1
BeamSearchParams const beamSearchParams,
@ -2640,7 +2666,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
#if SPEC_DEC
mask,
#endif
cacheList,
attentionSinks, cacheList,
#if BEAM_WIDTH > 1
beamSearchParams,
#endif
@ -2667,6 +2693,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#else
InputHead const* q,
#endif
float const* attentionSinks, // [headGrpSize]
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
@ -2760,7 +2787,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SPEC_DEC
mask,
#endif
cacheList,
attentionSinks, cacheList,
#if BEAM_WIDTH > 1
beamSearchParams,
#endif
@ -2788,7 +2815,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SPEC_DEC
mask,
#endif
cacheList,
attentionSinks, cacheList,
#if BEAM_WIDTH > 1
beamSearchParams,
#endif

@ -101,6 +101,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
#else
InputHead const* q,
#endif
float const* attentionSinks, // [headGrpSize]
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
@ -140,6 +141,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#else
InputHead const* q,
#endif
float const* attentionSinks, // [headGrpSize]
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,

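Both launch entry points now accept the same optional pointer. Although the comment says [headGrpSize], the kernels index it as attentionSinks + headGrpSize * idxHeadGrp, so the buffer covers one float per query head across all K-head groups. A hedged host-side setup sketch; everything other than the sink buffer itself (the remaining launchMHA arguments, nbKHeads, headGrpSize) is assumed to already exist in the caller, and error checking is omitted:

uint32_t const nbQHeads = nbKHeads * headGrpSize;
std::vector<float> sinksHost(nbQHeads);
for (uint32_t i = 0; i < nbQHeads; ++i)
{
    sinksHost[i] = 2.f + float(i % 4); // arbitrary per-head sink logits, in the style of the tests further below
}
float* attentionSinks = nullptr;
cudaMalloc(&attentionSinks, nbQHeads * sizeof(float));
cudaMemcpy(attentionSinks, sinksHost.data(), nbQHeads * sizeof(float), cudaMemcpyHostToDevice);
// ... then forwarded as the attentionSinks argument of launchMHA / launchHopperF8MHA, or nullptr to disable sinks.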
@ -428,6 +428,7 @@ __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src);
|
||||
__device__ void storeGemm0AccToShm(
|
||||
uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc);
|
||||
__device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec);
|
||||
__device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound);
|
||||
#else
|
||||
__device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
|
||||
__device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd);
|
||||
@ -453,7 +454,8 @@ __device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec
|
||||
template <bool dstIsStrided = false, typename DstHead>
|
||||
__device__ void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst,
|
||||
SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar,
|
||||
ShmQWiseVec const& accColSum, uint32_t nbKHeads = 0 /* only for final result in spec dec. */);
|
||||
ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec,
|
||||
uint32_t nbKHeads = 0 /* only for final result in spec dec. */);
|
||||
#else
|
||||
__device__ void transposeVTile(
|
||||
uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src);
|
||||
@ -651,6 +653,7 @@ CUBIN_EXPORT __global__
|
||||
#else
|
||||
IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads],
|
||||
#endif
|
||||
float const* attentionSinks, // [headGrpSize]
|
||||
KVCacheList<usePagedKVCache> const cacheList,
|
||||
#if USE_BEAM_SEARCH
|
||||
BeamSearchParams const beamSearchParams,
|
||||
@ -1252,7 +1255,7 @@ CUBIN_EXPORT __global__
|
||||
IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast<IOHead>();
|
||||
#if SWAP_AB
|
||||
finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
|
||||
smem.gemm1WarpGrpBar, smem.gemm1AccColSum);
|
||||
smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, nullptr);
|
||||
#else
|
||||
finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
|
||||
smem.gemm1AccColSum, 1, ctaNbValidTokens);
|
||||
@ -1262,9 +1265,16 @@ CUBIN_EXPORT __global__
|
||||
{
|
||||
uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp);
|
||||
OutputHead* const dst = &output[outOffset];
|
||||
ShmQWiseVec const* attentionSinksVec = nullptr;
|
||||
if (attentionSinks != nullptr)
|
||||
{
|
||||
attentionSinksVec
|
||||
= reinterpret_cast<ShmQWiseVec const*>(attentionSinks + headGrpSize * idxHeadGrp);
|
||||
}
|
||||
#if SWAP_AB
|
||||
finalizeAndWriteOut_sync<SPEC_DEC>(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc,
|
||||
xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, nbKHeads);
|
||||
xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, attentionSinksVec,
|
||||
nbKHeads);
|
||||
#else
|
||||
finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
|
||||
smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens);
|
||||
@ -1585,6 +1595,17 @@ CUBIN_EXPORT __global__
|
||||
}
|
||||
unused(bar.consumed.arrive());
|
||||
}
|
||||
// Add the attention sinks.
|
||||
if (attentionSinks != nullptr)
|
||||
{
|
||||
for (uint32_t i = 0; i < headsPerWarp; i++)
|
||||
{
|
||||
uint32_t const idxHead = wid + nbMathWarps * i;
|
||||
float sink = expf(
|
||||
attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - states[i].max);
|
||||
states[i].sum += sink;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp);
|
||||
auto const dst = &output[outOffset];
|
||||
@ -2029,6 +2050,22 @@ __device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smem
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound)
|
||||
{
|
||||
RegColWiseVec ret;
|
||||
constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols);
|
||||
auto const idx = laneId() % nbThrdsPerInstNBase;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++)
|
||||
{
|
||||
static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols));
|
||||
ret[i] = reinterpret_cast<
|
||||
Vec<Vec<float, GmmaAccCoreMat::cols>, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>(
|
||||
gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd)
|
||||
{
|
||||
uint32_t const idxInQuad = laneId() % 4;
|
||||
@ -2878,12 +2915,19 @@ __device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRa
|
||||
template <bool dstIsStrided, typename DstHead>
|
||||
__device__ inline void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst,
|
||||
SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar,
|
||||
ShmQWiseVec const& accColSum, uint32_t nbKHeads)
|
||||
ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads)
|
||||
{
|
||||
// @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of mufu.rcp
|
||||
// static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of
|
||||
// mufu.rcp");
|
||||
auto const regColSum = loadShmColWiseVecWithDup(accColSum);
|
||||
auto regColSum = loadShmColWiseVecWithDup(accColSum);
|
||||
if (attentionSinksVec != nullptr)
|
||||
{
|
||||
auto const regAccColMax = loadShmColWiseVecWithDup(accColMax);
|
||||
auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1);
|
||||
auto regColSinks = expf(regAttentionSinks - regAccColMax);
|
||||
regColSum = regColSum + regColSinks;
|
||||
}
|
||||
auto const regOutScale = __frcp_rn(regColSum) * xvoScale;
|
||||
rescaleAcc(acc, regOutScale);
|
||||
|
||||
@ -3175,6 +3219,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#else
|
||||
InputHead const* q,
|
||||
#endif
|
||||
float const* attentionSinks, // [headGrpSize]
|
||||
#if USE_PAGED_KV_CACHE
|
||||
#if PAGED_KV_CACHE_LAYOUT == 1
|
||||
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
|
||||
@ -3286,7 +3331,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#else
|
||||
q,
|
||||
#endif
|
||||
cacheList,
|
||||
attentionSinks, cacheList,
|
||||
#if USE_BEAM_SEARCH
|
||||
beamSearchParams,
|
||||
#endif
|
||||
@ -3322,7 +3367,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#else
|
||||
q,
|
||||
#endif
|
||||
cacheList,
|
||||
attentionSinks, cacheList,
|
||||
#if USE_BEAM_SEARCH
|
||||
beamSearchParams,
|
||||
#endif
|
||||
|
||||
@ -1859,12 +1859,13 @@ CUtensorMap makeTensorMapForQ(
|
||||
#endif // IS_MLA
|
||||
|
||||
void launchMLA(cudaDeviceProp const& prop,
|
||||
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
|
||||
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
|
||||
float qScale, OutputHead* output, InputHead const* q,
|
||||
float* attentionSinks, // [headGrpSize], not supported.
|
||||
#if USE_PAGED_KV_CACHE
|
||||
GMemCacheHead* pool, // global pool of pages
|
||||
GMemCacheHead* pool, // global pool of pages
|
||||
KVCachePageIndex const*
|
||||
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
|
||||
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
|
||||
#else
|
||||
GMemKVCacheHead* kvCacheData,
|
||||
#endif
|
||||
|
||||
@ -45,7 +45,7 @@ using Vector = Matrix<Type, Size, 1>;
|
||||
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize)
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
|
||||
{
|
||||
uint32_t const nbTiles = divUp(seqLen, tileSize);
|
||||
auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval();
|
||||
@ -113,6 +113,16 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
|
||||
}
|
||||
rowSum += tileRowSum;
|
||||
}
|
||||
|
||||
// Add the attention sinks.
|
||||
if (attentionSinks != nullptr)
|
||||
{
|
||||
for (uint32_t i = 0; i < headGrpSize; i++)
|
||||
{
|
||||
rowSum[i] += expf(attentionSinks[i] - rowMax[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> out
|
||||
= gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array());
|
||||
std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); });
|
||||
@ -123,7 +133,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
|
||||
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
|
||||
refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q, \
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, \
|
||||
float qScale, float kvScale, float xScale, uint32_t slidingWinSize)
|
||||
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
|
||||
|
||||
INSTANTIATE_refFlashAttention(CacheElem, 64, false, false);
|
||||
INSTANTIATE_refFlashAttention(CacheElem, 64, false, true);
|
||||
@ -143,7 +153,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
#else
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize)
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
|
||||
{
|
||||
#endif
|
||||
float const rcpXScale = 1.f / xScale;
|
||||
@ -184,7 +194,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
|
||||
Eigen::Matrix<float, headGrpSize, Eigen::Dynamic, Eigen::RowMajor> x
|
||||
= (gemm0Acc.colwise() - rowMax).array().exp().eval();
|
||||
Eigen::Vector<float, headGrpSize> const rowSum = x.rowwise().sum().eval();
|
||||
Eigen::Vector<float, headGrpSize> rowSum = x.rowwise().sum().eval();
|
||||
|
||||
std::for_each(x.data(), x.data() + x.size(), [&](float& e) { e = float(MathElem(e * rcpXScale)); });
|
||||
|
||||
@ -200,6 +210,18 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add the attention sinks.
|
||||
#if !SPEC_DEC
|
||||
if (attentionSinks != nullptr)
|
||||
{
|
||||
for (uint32_t i = 0; i < headGrpSize; i++)
|
||||
{
|
||||
rowSum[i] += expf(attentionSinks[i] - rowMax[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> out
|
||||
= gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array());
|
||||
std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); });
|
||||
@ -217,7 +239,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
|
||||
refAttention<prec, isPaged, useBeamSearch>(IOHead const* q, CacheSeq<isPaged, useBeamSearch> const& k, \
|
||||
CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale, float kvScale, float xScale, \
|
||||
uint32_t slidingWinSize)
|
||||
uint32_t slidingWinSize, float* attentionSinks)
|
||||
#endif
|
||||
INSTANTIATE_refAttention(InputElem, false, false);
|
||||
INSTANTIATE_refAttention(InputElem, false, true);
|
||||
|
||||
@ -83,7 +83,7 @@ struct CacheSeq<true, true>
|
||||
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize);
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
|
||||
|
||||
template <typename MathElem, bool isPaged, bool useBeamSearch>
|
||||
#if SPEC_DEC
|
||||
@ -93,7 +93,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
#else
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize);
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
|
||||
#endif
|
||||
|
||||
template <uint32_t ropeStyle>
|
||||
|
||||
@ -130,7 +130,7 @@ template <uint32_t nbKHeads>
|
||||
#endif
|
||||
#endif
|
||||
void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false,
|
||||
bool saveData = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
|
||||
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
|
||||
{
|
||||
#if IS_MLA
|
||||
if (nbKHeads != 1)
|
||||
@ -613,6 +613,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate the attention sinks (per head)
|
||||
auto attentionSinks = ManagedMemBuf<float>(nbQHeads);
|
||||
// The attention sinks ptr.
|
||||
float* attentionSinksPtr = hasAttentionSinks ? reinterpret_cast<float*>(attentionSinks.get()) : nullptr;
|
||||
// Initialize the attention sinks (use large values to detect the potential bugs).
|
||||
for (uint32_t i = 0; i < nbQHeads; i++)
|
||||
{
|
||||
// Range: [2, 5]
|
||||
attentionSinks.get()[i] = 2.f + float(i % 4);
|
||||
}
|
||||
|
||||
if (verbose)
|
||||
{
|
||||
printf("migrating data to gpu\n");
|
||||
@ -640,6 +651,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
#if BEAM_WIDTH > 1
|
||||
cacheIndir.prefetch(dev, stream);
|
||||
#endif
|
||||
attentionSinks.prefetch(dev, stream);
|
||||
};
|
||||
prefetchToDevice(device);
|
||||
checkCuda(cudaMemsetAsync(semaphores.get(), 0, 4 * nbSemaphores, stream));
|
||||
@ -720,6 +732,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
&qHeads[0][0][0],
|
||||
#endif
|
||||
#endif
|
||||
attentionSinksPtr,
|
||||
#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE
|
||||
cacheKHeads.get(), cacheVHeads.get(),
|
||||
#else
|
||||
@ -1028,10 +1041,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
hostMask, qSeqLen, q_len);
|
||||
#else
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refOutput;
|
||||
auto const refAttentionSinks
|
||||
= hasAttentionSinks ? attentionSinksPtr + headGrpSize * idxKHead : nullptr;
|
||||
if (useQGMMA)
|
||||
{
|
||||
refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
|
||||
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
|
||||
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize,
|
||||
refAttentionSinks);
|
||||
// refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
|
||||
// vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
|
||||
}
|
||||
@ -1039,8 +1055,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
{
|
||||
// refOutput = refFlashAttention<InputElem, 64>(&qHeads[req][b][headGrpSize * idxKHead],
|
||||
// kCacheSeq, vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale);
|
||||
refOutput = refAttention<InputElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
|
||||
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
|
||||
refOutput
|
||||
= refAttention<InputElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, vCacheSeq,
|
||||
seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks);
|
||||
}
|
||||
#endif
|
||||
if (lowPrecOutput)
|
||||
@ -1196,11 +1213,23 @@ TEST(RefCheck, llama_V2_70b)
|
||||
runTest<2>(2, 514, false, true);
|
||||
runTest<1>(1, 4096, false, true);
|
||||
#if SLIDING_WINDOW
|
||||
runTest<2>(2, 4096, false, true, false, false, ~0, 256);
|
||||
runTest<2>(2, 400, false, true, false, false, ~0U, 256);
|
||||
runTest<2>(2, 4096, false, true, false, false, false, ~0, 256);
|
||||
runTest<2>(2, 400, false, true, false, false, false, ~0U, 256);
|
||||
#endif
|
||||
runTest<8>(120, 367, false, true);
|
||||
// runTest<8>(1792, 2048, false, true);
|
||||
runTest<8>(1792, 2048, false, true);
|
||||
}
|
||||
|
||||
TEST(RefCheck, attention_sinks)
|
||||
{
|
||||
auto runAttentionSinksTest = [](uint32_t batchSize, uint32_t seqLen)
|
||||
{ runTest<8>(batchSize, seqLen, false, true, false, false, /*hasAttentionSinks*/ true); };
|
||||
|
||||
runAttentionSinksTest(2, 2);
|
||||
runAttentionSinksTest(2, 15);
|
||||
runAttentionSinksTest(2, 256);
|
||||
runAttentionSinksTest(2, 514);
|
||||
runAttentionSinksTest(1, 4096);
|
||||
}
|
||||
|
||||
TEST(Perf, tracing_long)
|
||||
@ -1264,7 +1293,7 @@ TEST(Perf, mlperf_gptj)
|
||||
#ifndef NDEBUG
|
||||
GTEST_SKIP() << "Skipping perf tests for debug build";
|
||||
#endif
|
||||
runTest<32>(396, 800 + 224, true, false, false, false, 800);
|
||||
runTest<32>(396, 800 + 224, true, false, false, false, false, 800);
|
||||
}
|
||||
|
||||
TEST(Perf, mlperf_llama)
|
||||
|
||||
@ -53,6 +53,7 @@ using namespace CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
|
||||
using CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput;
|
||||
using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::CutlassMoeFCRunner;
|
||||
using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
|
||||
using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams;
|
||||
using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation;
|
||||
|
||||
static BufferManager::CudaStreamPtr streamPtr;
|
||||
@ -984,7 +985,7 @@ public:
|
||||
mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
|
||||
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
|
||||
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
|
||||
mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
|
||||
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
|
||||
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
|
||||
mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
|
||||
mFinalOutput + mFinalOutputSize * mBufferIndex,
|
||||
@ -996,7 +997,7 @@ public:
|
||||
mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
|
||||
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
|
||||
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
|
||||
mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
|
||||
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
|
||||
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
|
||||
mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
|
||||
mFinalOutput + mFinalOutputSize * mBufferIndex,
|
||||
|
||||
@ -55,6 +55,7 @@ struct FusedQKVMaskedAttentionDispatchParams
|
||||
T const* qkv_bias;
|
||||
T const* relative_attention_bias;
|
||||
bool const* attention_mask;
|
||||
float const* attention_sinks;
|
||||
float const* logn_scaling_ptr;
|
||||
int const* cache_indir;
|
||||
void* context_buf;
|
||||
@ -71,6 +72,7 @@ struct FusedQKVMaskedAttentionDispatchParams
|
||||
RotaryScalingType rotary_embedding_scale_type;
|
||||
float rotary_embedding_scale;
|
||||
float const* rotary_embedding_inv_freq_cache;
|
||||
float2 const* rotary_embedding_cos_sin_cache;
|
||||
float rotary_embedding_short_m_scale;
|
||||
float rotary_embedding_long_m_scale;
|
||||
int rotary_embedding_max_positions;
|
||||
@ -225,6 +227,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
|
||||
xqaParams.output = generationsParams.context_buf;
|
||||
xqaParams.qkv = generationsParams.attention_input;
|
||||
xqaParams.cache_indir = generationsParams.cache_indir;
|
||||
xqaParams.attention_sinks = generationsParams.attention_sinks;
|
||||
xqaParams.kv_scale_orig_quant = generationsParams.kv_scale_orig_quant;
|
||||
xqaParams.kv_scale_quant_orig = generationsParams.kv_scale_quant_orig;
|
||||
xqaParams.host_past_key_value_lengths = generationsParams.host_past_key_value_lengths;
|
||||
@ -596,6 +599,7 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
|
||||
params.rotary_embedding_scale_type = input_params.rotary_embedding_scale_type;
|
||||
params.rotary_embedding_scale = input_params.rotary_embedding_scale;
|
||||
params.rotary_embedding_inv_freq_cache = input_params.rotary_embedding_inv_freq_cache;
|
||||
params.rotary_embedding_cos_sin_cache = input_params.rotary_embedding_cos_sin_cache;
|
||||
params.rotary_embedding_short_m_scale = input_params.rotary_embedding_short_m_scale;
|
||||
params.rotary_embedding_long_m_scale = input_params.rotary_embedding_long_m_scale;
|
||||
params.rotary_embedding_max_positions = input_params.rotary_embedding_max_positions;
|
||||
@ -620,6 +624,9 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
|
||||
params.attention_mask = input_params.attention_mask;
|
||||
params.attention_mask_stride = input_params.attention_mask_stride;
|
||||
|
||||
// Attention sinks.
|
||||
params.attention_sinks = input_params.attention_sinks;
|
||||
|
||||
// The slope of linear position bias per head, e.g., ALiBi.
|
||||
if (input_params.linear_bias_slopes != nullptr)
|
||||
{
|
||||
@ -1691,6 +1698,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
fmhaParams.outputPtr
|
||||
= mCpSize > 1 ? gatherOutBuffer : params.context_buf; // only use [totalLength, h / cpSize, Dh]
|
||||
fmhaParams.outputSfPtr = params.context_buf_sf;
|
||||
fmhaParams.attentionSinksPtr = params.attention_sinks;
|
||||
fmhaParams.packedMaskPtr = params.attention_packed_mask;
|
||||
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
|
||||
{
|
||||
@ -2220,6 +2228,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
|
||||
dispatch_params.relative_attention_bias_stride = relative_attention_bias_stride;
|
||||
dispatch_params.attention_mask = params.attention_mask;
|
||||
dispatch_params.attention_mask_stride = params.attention_mask_stride;
|
||||
dispatch_params.attention_sinks = params.attention_sinks;
|
||||
dispatch_params.max_distance = max_distance;
|
||||
dispatch_params.cache_indir = params.cache_indir;
|
||||
dispatch_params.context_buf = mCpSize > 1 ? mhaOutput : params.context_buf; //
|
||||
@ -2267,6 +2276,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
|
||||
dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType;
|
||||
dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale;
|
||||
dispatch_params.rotary_embedding_inv_freq_cache = params.rotary_inv_freq;
|
||||
dispatch_params.rotary_embedding_cos_sin_cache = params.rotary_cos_sin;
|
||||
dispatch_params.rotary_embedding_short_m_scale = mRotaryEmbeddingShortMscale;
|
||||
dispatch_params.rotary_embedding_long_m_scale = mRotaryEmbeddingLongMscale;
|
||||
dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions;
|
||||
|
||||
@ -65,6 +65,8 @@ public:
|
||||
T const* qkv_bias = nullptr;
|
||||
// Attention mask input, which has shape of [batch_size, attention_mask_stride].
|
||||
bool const* attention_mask = nullptr;
|
||||
// Attention sinks with shape of [num_heads_q] float.
|
||||
float const* attention_sinks = nullptr;
|
||||
// Rotary inv_freq cache buffer to avoid re-computing.
|
||||
float const* rotary_inv_freq = nullptr;
|
||||
// Rotary cos sin cache buffer to avoid re-computing.
|
||||
|
||||
@ -27,6 +27,78 @@
|
||||
namespace cutlass::gemm::collective::detail
|
||||
{
|
||||
|
||||
using namespace cute;
|
||||
|
||||
typedef uint32_t __nv_fp4x8_storage_t;
|
||||
typedef uint32_t __nv_bf16x2_storage_t;
|
||||
typedef cutlass::uint128_t __nv_bf16x8_storage_t;
|
||||
|
||||
constexpr int int4_group_size = 128;
|
||||
constexpr int mxfp4_group_size = 32;
|
||||
|
||||
inline __device__ unsigned prmt(unsigned hi, unsigned lo, unsigned select_code)
|
||||
{
|
||||
unsigned res = 0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
"prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(res)
|
||||
: "r"(lo), "r"(hi), "r"(select_code));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_bf16(unsigned const index)
|
||||
{
|
||||
const __nv_fp8x4_storage_t h4b_lut = 0x03020100U; // 7654
|
||||
const __nv_fp8x4_storage_t l4b_lut = 0xFFFEFC00U; // 3210
|
||||
|
||||
__nv_fp8x4_storage_t lut_res = prmt(h4b_lut, l4b_lut, index);
|
||||
|
||||
return lut_res;
|
||||
}
|
||||
|
||||
__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(const __nv_fp4x8_storage_t fp4x8)
|
||||
{
|
||||
__nv_bf16x8_storage_t bf16x8_raw = {0, 0};
|
||||
__nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw);
|
||||
|
||||
unsigned zero_padding = 0x00000000U;
|
||||
|
||||
unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U;
|
||||
unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U);
|
||||
|
||||
__nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7654
|
||||
__nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3210
|
||||
|
||||
bf16x2_raw[0] = prmt(zero_padding, l4b_2to9_bits, 0x1707U) >> 2U; // 1 0
|
||||
bf16x2_raw[1] = prmt(zero_padding, l4b_2to9_bits, 0x3727U) >> 2U; // 3 2
|
||||
bf16x2_raw[2] = prmt(h4b_2to9_bits, zero_padding, 0x5040U) >> 2U; // 5 4
|
||||
bf16x2_raw[3] = prmt(h4b_2to9_bits, zero_padding, 0x7060U) >> 2U; // 7 6
|
||||
|
||||
__nv_bf16x2_storage_t bf16x2_0to1_bits;
|
||||
|
||||
__nv_fp8x4_storage_t h_fp8x2_0to1_bits = (fp4x8 & 0x0000C0C0U); // 3 1
|
||||
__nv_fp8x4_storage_t l_fp8x2_0to1_bits = (fp4x8 & 0x00000C0CU) << 4U; // 2 0
|
||||
|
||||
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x4707U); // 1 0
|
||||
bf16x2_raw[0] = bf16x2_raw[0] | bf16x2_0to1_bits;
|
||||
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x5717U); // 3 2
|
||||
bf16x2_raw[1] = bf16x2_raw[1] | bf16x2_0to1_bits;
|
||||
|
||||
h_fp8x2_0to1_bits = (fp4x8 & 0xC0C00000U); // 7 5
|
||||
l_fp8x2_0to1_bits = (fp4x8 & 0x0C0C0000U) << 4U; // 6 4
|
||||
|
||||
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x6020U); // 5 4
|
||||
bf16x2_raw[2] = bf16x2_raw[2] | bf16x2_0to1_bits;
|
||||
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x7030U); // 7 6
|
||||
bf16x2_raw[3] = bf16x2_raw[3] | bf16x2_0to1_bits;
|
||||
|
||||
return bf16x8_raw;
|
||||
}
|
||||
|
||||
template <class Collective>
|
||||
struct MixedGroupedGemmInputUtils
|
||||
{
|
||||
@ -46,6 +118,7 @@ private:
|
||||
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
|
||||
static constexpr auto ModeHasScales = Collective::ModeHasScales;
|
||||
static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;
|
||||
static constexpr auto UseFP4ToBF16LookupTable = Collective::UseFP4ToBF16LookupTable;
|
||||
|
||||
public:
|
||||
static constexpr auto elements_per_smem_scale()
|
||||
@ -239,6 +312,27 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// The core converter uses a lookup table to converts i4 -> 8 bit value.
|
||||
template <class EngineIn, class LayoutIn, class EngineOut,
|
||||
class LayoutOut>
|
||||
CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( // Accept mutable temporaries
|
||||
Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>&& dst)
|
||||
{
|
||||
fp4tobf16_lookup_table_convert(src, dst);
|
||||
}
|
||||
|
||||
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut>
|
||||
CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert(
|
||||
Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>& dst)
|
||||
{
|
||||
|
||||
// View the input as reg
|
||||
auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0);
|
||||
auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0);
|
||||
|
||||
dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_);
|
||||
}
|
||||
|
||||
/// Utilities to dequantize A.
|
||||
template <class Layout>
|
||||
CUTLASS_DEVICE static void static_check_scale(Layout const& tensor)
|
||||
@ -253,7 +347,6 @@ public:
|
||||
static_check_scale(flatten(Layout{}));
|
||||
}
|
||||
|
||||
// dequantize_A_kblock is here!!!
|
||||
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
|
||||
CUTLASS_DEVICE static void dequantize_A_kblock(Tensor<EngineIn, LayoutIn> const& tCrA_load,
|
||||
Tensor<EngineOut, LayoutOut>& tCrA_mma, cute::tuple<Ts...>& partitioned_extra_info, int const k_block)
|
||||
@ -288,8 +381,6 @@ public:
|
||||
}
|
||||
else if constexpr (UseScaleLookupTable)
|
||||
{
|
||||
// this path
|
||||
|
||||
constexpr int num_elements = decltype(size(src))::value;
|
||||
static_assert(is_same_v<RealSwappedElementA, cutlass::int4b_t>,
|
||||
"Lookup table only supports int4 being the quant type now.");
|
||||
@ -424,7 +515,6 @@ public:
|
||||
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
|
||||
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
|
||||
using SrcType = typename EngineIn::value_type;
|
||||
using DstType = typename EngineOut::value_type;
|
||||
|
||||
Tensor src = tCrA_load(_, _, k_block);
|
||||
Tensor dst = tCrA_mma(_, _, k_block);
|
||||
@ -441,7 +531,14 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i)
|
||||
{
|
||||
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||
if constexpr (UseFP4ToBF16LookupTable)
|
||||
{
|
||||
fp4tobf16_lookup_table_convert(src_vm(_, i), dst_vm(_, i));
|
||||
}
|
||||
else
|
||||
{
|
||||
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -30,37 +30,12 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#define GROUP_SIZE 128
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
template <int N>
|
||||
CUTE_HOST_DEVICE void warpgroup_wait_()
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
||||
cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N);
|
||||
asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group<N> without CUTE_ARCH_MMA_SM90A_ENABLED");
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void warpgroup_wait_dispatch(int onthefly_count)
|
||||
{
|
||||
switch (onthefly_count)
|
||||
{
|
||||
case 0: warpgroup_wait_<0>(); break;
|
||||
case 4: warpgroup_wait_<4>(); break;
|
||||
case 8: warpgroup_wait_<8>(); break;
|
||||
case 12: warpgroup_wait_<12>(); break;
|
||||
default: assert(false && "Invalid onthefly_count value");
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
@ -91,7 +66,7 @@ public:
|
||||
private:
|
||||
template <class T>
|
||||
friend struct detail::MixedGroupedGemmInputUtils;
|
||||
using CollectiveType = CollectiveMma<DispatchPolicy, TileShape_, ElementAOptionalTuple, StrideA_,
|
||||
using CollectiveType = CollectiveMmaArrayMixedInput<DispatchPolicy, TileShape_, ElementAOptionalTuple, StrideA_,
|
||||
ElementBOptionalTuple, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_>;
|
||||
using Utils = detail::MixedGroupedGemmInputUtils<CollectiveType>;
|
||||
@ -146,6 +121,11 @@ public:
|
||||
static_assert(cutlass::gemm::detail::is_mn_major<NonVoidStrideScale>(),
|
||||
"Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled].");
|
||||
|
||||
static constexpr bool IsMXFP4 = cute::is_same_v<ElementA, cutlass::float_e2m1_t>;
|
||||
// Group size 128 for int4 weights
|
||||
// Group size 32 for mxfp4 weights
|
||||
static constexpr int ScalingGroupSize = IsMXFP4 ? detail::mxfp4_group_size : detail::int4_group_size;
|
||||
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
@ -268,6 +248,8 @@ public:
|
||||
|| KernelConversionMode == ConversionMode::ConvertAndScaleWithZero;
|
||||
static constexpr bool UseScaleLookupTable
|
||||
= KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v<ElementScale>;
|
||||
static constexpr bool UseFP4ToBF16LookupTable = KernelConversionMode == ConversionMode::ConvertAndScale
|
||||
&& cute::is_same_v<ElementA, cutlass::float_e2m1_t> && cute::is_same_v<ElementB, cutlass::bfloat16_t>;
|
||||
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{});
|
||||
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{});
|
||||
static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB);
|
||||
@ -705,7 +687,7 @@ public:
|
||||
{
|
||||
// The real scale_k that actually works
|
||||
// auto scale_k = K / mainloop_params.chunk_size;
|
||||
auto scale_k = K / GROUP_SIZE;
|
||||
auto scale_k = K / ScalingGroupSize;
|
||||
|
||||
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l)
|
||||
Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
|
||||
@ -872,7 +854,6 @@ public:
|
||||
}
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
|
||||
{
|
||||
// zero copy
|
||||
auto tZgZ = get<2>(extra_input_partitions);
|
||||
auto tZsZ = get<3>(extra_input_partitions);
|
||||
if (cute::elect_one_sync())
|
||||
@ -979,7 +960,8 @@ public:
|
||||
return make_tensor_like<RealSwappedElementA>(tCsA(_, _, _, Int<0>{}));
|
||||
}
|
||||
}();
|
||||
Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
// tCrB is just a view of the tensor tCsB
|
||||
Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
//
|
||||
@ -1013,8 +995,8 @@ public:
|
||||
|
||||
multiply_add<ElementAccumulator> fma;
|
||||
|
||||
constexpr int NumMMAsPerChunk = GROUP_SIZE / cute::get<0, 1>(tCsB.shape())();
|
||||
constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / GROUP_SIZE;
|
||||
constexpr int NumMMAsPerChunk = ScalingGroupSize / cute::get<0, 1>(tCsB.shape())();
|
||||
constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / ScalingGroupSize;
|
||||
cute::array<decltype(make_fragment_like(accum)), NumChunksPerTileK> intermediate_array;
|
||||
|
||||
constexpr int K_BLOCK_MAX = size<2>(tCrA_load);
|
||||
@ -1045,8 +1027,6 @@ public:
|
||||
// src: tCrA_load, dst: tCrA_mma
|
||||
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id)
|
||||
@ -1079,10 +1059,11 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
warpgroup_wait<0>();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_)
|
||||
{
|
||||
warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk);
|
||||
warpgroup_fence_operand(intermediate_array[chunk_id_]);
|
||||
|
||||
// Apply the group-wise scaling
|
||||
@ -1129,7 +1110,6 @@ public:
|
||||
Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info,
|
||||
copy_partitions_extra_info, 1, smem_pipe_read.index());
|
||||
|
||||
warpgroup_wait<K_WAIT_MAX>();
|
||||
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
|
||||
}
|
||||
}
|
||||
@ -1169,8 +1149,6 @@ public:
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
warpgroup_commit_batch();
|
||||
|
||||
warpgroup_wait<K_WAIT_MAX>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage,
|
||||
// so we can release prior barrier
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
pipeline.consumer_release(
|
||||
@ -1187,10 +1165,11 @@ public:
|
||||
{
|
||||
// The last k_block
|
||||
|
||||
warpgroup_wait<0>();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_)
|
||||
{
|
||||
warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk);
|
||||
warpgroup_fence_operand(intermediate_array[chunk_id_]);
|
||||
|
||||
// Apply the group-wise scaling
|
||||
@ -1257,7 +1236,6 @@ public:
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
warpgroup_commit_batch();
|
||||
|
||||
warpgroup_wait<K_WAIT_MAX>();
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
// release prior barrier
|
||||
@ -1318,7 +1296,7 @@ public:
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
// warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count)
|
||||
{
|
||||
@ -1462,7 +1440,7 @@ public:
|
||||
{
|
||||
NonVoidElementScale const* ptr_S = nullptr;
|
||||
// auto scale_k = K / mainloop_params.chunk_size;
|
||||
auto scale_k = K / GROUP_SIZE;
|
||||
auto scale_k = K / ScalingGroupSize;
|
||||
Tensor tensor_scale = make_tensor(
|
||||
detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]);
|
||||
cute::detail::fill_tma_gmem_shape_stride(
|
||||
@ -1472,7 +1450,7 @@ public:
|
||||
{
|
||||
ElementZero const* ptr_Z = nullptr;
|
||||
// auto scale_k = K / mainloop_params.chunk_size;
|
||||
auto scale_k = K / GROUP_SIZE;
|
||||
auto scale_k = K / ScalingGroupSize;
|
||||
Tensor tensor_zero = make_tensor(
|
||||
detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]);
|
||||
cute::detail::fill_tma_gmem_shape_stride(
|
||||
|
||||
@ -256,9 +256,9 @@ public:
|
||||
constexpr int SF_VEC_SIZE = 16;
|
||||
using PackedVec = PackedVec<DType>;
|
||||
PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&val);
|
||||
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt, token_id,
|
||||
m_access_id_in_token, std::nullopt, m_params.hidden_dim,
|
||||
reinterpret_cast<uint32_t*>(m_params.scale_out), m_params.layout);
|
||||
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt, token_id, m_access_id_in_token,
|
||||
std::nullopt, m_params.hidden_dim / SF_VEC_SIZE, reinterpret_cast<uint32_t*>(m_params.scale_out),
|
||||
m_params.layout);
|
||||
reinterpret_cast<uint32_t*>(m_params.quant_out)[m_access_id]
|
||||
= cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, m_scale_factor, sf_out);
|
||||
}
|
||||
|
||||
@ -132,7 +132,7 @@ struct AllReduceFusionParams
|
||||
float rms_eps;
|
||||
float* scale_factor;
|
||||
bool use_oneshot;
|
||||
FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED;
|
||||
QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED;
|
||||
cudaStream_t stream;
|
||||
AllReduceFusionPattern pattern;
|
||||
bool trigger_completion_at_end = true;
|
||||
|
||||
@ -99,15 +99,15 @@ __device__ struct __attribute__((aligned(32))) LamportFlags
|
||||
uint32_t* offset_access_ptr;
|
||||
uint32_t* buffer_flags;
|
||||
|
||||
__device__ explicit LamportFlags(uint32_t* buffer_flags)
|
||||
__device__ explicit LamportFlags(uint32_t* buffer_flags, uint32_t buffer_size)
|
||||
: offset_access_ptr(&buffer_flags[4])
|
||||
, buffer_flags(buffer_flags)
|
||||
, buffer_size(buffer_size)
|
||||
{
|
||||
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
|
||||
buffer_size = flag.z;
|
||||
input_offset = flag.x * (buffer_size << 1U);
|
||||
clear_offset = flag.y * (buffer_size << 1U);
|
||||
num_tokens_prev = flag.w;
|
||||
num_tokens_prev = flag.z;
|
||||
}
|
||||
|
||||
__device__ void cta_arrive()
|
||||
@ -135,7 +135,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags
|
||||
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
|
||||
buffer_flags[0] = (flag.x + 1) % 3;
|
||||
buffer_flags[1] = (flag.y + 1) % 3;
|
||||
buffer_flags[3] = num_tokens;
|
||||
buffer_flags[2] = num_tokens;
|
||||
*(offset_access_ptr) = 0;
|
||||
}
|
||||
}
|
||||
@ -144,7 +144,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags
|
||||
|
||||
template <int WORLD_SIZE, typename T>
|
||||
__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
|
||||
int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results)
|
||||
int buffer_M, int token_dim, int rank, uint32_t buffer_size, uint32_t* buffer_flags, bool wait_for_results)
|
||||
{
|
||||
int elt = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
if (elt >= token_dim)
|
||||
@ -155,7 +155,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
LamportFlags flags(buffer_flags);
|
||||
LamportFlags flags(buffer_flags, buffer_size);
|
||||
|
||||
// Capture the number of tokens in previous iteration so that we can properly clear the buffer
|
||||
// The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up
|
||||
@ -217,15 +217,17 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
|
||||
// Similarly clear broadcast buffer here
|
||||
for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
|
||||
if (elt < token_dim)
|
||||
{
|
||||
uint32_t clr_token_idx = token + clr_tok * gridDim.x;
|
||||
if (clr_token_idx < buffer_M)
|
||||
// Similarly clear broadcast buffer here
|
||||
for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
|
||||
{
|
||||
input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
|
||||
= fromFloat<T>(-0.f);
|
||||
uint32_t clr_token_idx = token + clr_tok * gridDim.x;
|
||||
if (clr_token_idx < buffer_M)
|
||||
{
|
||||
input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
|
||||
= fromFloat<T>(-0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -240,20 +242,24 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
// blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
{
uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
uint64_t elt_load_offset = blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
if (elt_load_offset < token_dim)
{
uint64_t current_pos = blockIdx.x * token_dim + elt_load_offset;

void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
// We have 2 assumptions here:
// 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
// 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
float2 val = loadfloat2(lamport_ptr);
while (isNegZero(*(T*) &val))
{
val = loadfloat2(lamport_ptr);
}
if (output_ptr)
{
*((float2*) &output_ptr[current_pos]) = val;
void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
// We have 2 assumptions here:
// 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
// 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
float2 val = loadfloat2(lamport_ptr);
while (isNegZero(*(T*) &val))
{
val = loadfloat2(lamport_ptr);
}
if (output_ptr)
{
*((float2*) &output_ptr[current_pos]) = val;
}
}
}

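The Lamport wait above spins on an 8-byte load until the value is no longer the -0.0f sentinel written by the clear passes. A self-contained sketch of that pattern for float data follows; is_neg_zero, load_volatile_float2, and wait_for_lamport_value are illustrative stand-ins for the repository's isNegZero/loadfloat2 helpers, and the volatile vector load is one possible way to make the spin loop observe the remote write.

// Illustrative re-implementations; -0.0f has only the sign bit set, so it can be
// distinguished from every value a reduction would actually produce, including +0.0f.
__device__ inline bool is_neg_zero(float v)
{
    return __float_as_uint(v) == 0x80000000u;
}

__device__ inline float2 load_volatile_float2(void const* ptr)
{
    float2 out;
    // Single 8B volatile load so the poll sees the producer's write as one unit.
    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];" : "=f"(out.x), "=f"(out.y) : "l"(ptr));
    return out;
}

__device__ inline float2 wait_for_lamport_value(void const* ptr)
{
    float2 val = load_volatile_float2(ptr);
    // Spin until the first element stops being the sentinel; the producer's 8B
    // write is assumed to be atomic, so both lanes arrive together.
    while (is_neg_zero(val.x))
    {
        val = load_volatile_float2(ptr);
    }
    return val;
}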
@ -263,10 +269,11 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
}

#define LAUNCH_ALL_REDUCE_KERNEL(WORLD_SIZE, T) \
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel<WORLD_SIZE, T>, \
reinterpret_cast<T*>(params.output), reinterpret_cast<T*>(params.input), \
reinterpret_cast<T**>(params.buffer_ptrs_dev), (T*) params.multicast_ptr, params.num_tokens, params.buffer_M, \
params.token_dim, params.rank, reinterpret_cast<uint32_t*>(params.buffer_flags), params.wait_for_results));
TLLM_CUDA_CHECK( \
cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel<WORLD_SIZE, T>, reinterpret_cast<T*>(params.output), \
reinterpret_cast<T*>(params.input), reinterpret_cast<T**>(params.buffer_ptrs_dev), \
(T*) params.multicast_ptr, params.num_tokens, params.buffer_M, params.token_dim, params.rank, \
params.buffer_size, reinterpret_cast<uint32_t*>(params.buffer_flags), params.wait_for_results));

void twoshot_allreduce_op(AllReduceParams const& params)
{
@ -369,20 +376,33 @@ inline __device__ T add(T a, T b)
}

#define FINAL_MASK 0xffffffff
#define WARP_SIZE 32

template <typename T>
__inline__ __device__ T warpReduceSum(T val)
{
// Get the actual number of active threads in this warp
int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1)));
unsigned int mask = (1U << active_warp_size) - 1;

#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
for (int offset = 16; offset > 0; offset >>= 1)
{
if (offset < active_warp_size)
{
val = add<T>(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE));
}
}
return val;
}

inline __device__ float block_reduce_sum(float val)
{
__shared__ float smem[32];
int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32;
__shared__ float smem[WARP_SIZE];
int lane_id = threadIdx.x % WARP_SIZE;
int warp_id = threadIdx.x / WARP_SIZE;
int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps

val = warpReduceSum(val);
if (lane_id == 0)
{
@ -391,6 +411,7 @@ inline __device__ float block_reduce_sum(float val)
__syncthreads();
val = lane_id < warp_num ? smem[lane_id] : 0.f;
val = warpReduceSum(val);

return val;
}

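warpReduceSum now derives an active-lane mask so that blocks whose size is not a multiple of 32 (such as the 120-thread configuration added for the 2880 hidden dimension) still reduce, and block_reduce_sum rounds its warp count up to include the partial warp. The sketch below shows one way to reduce such a partial warp; it deliberately uses __shfl_down_sync instead of the xor butterfly so no lane ever consumes a value from a lane that does not exist, and it falls back to the full 0xffffffff mask for complete warps so the mask shift never reaches 32 bits. The name and these choices are illustrative, not the repository's implementation.

#define SKETCH_WARP_SIZE 32

// Sums the values of the lanes in the caller's warp; after the loop, lane 0 of
// each warp holds the warp's sum (which is all block_reduce_sum-style callers use).
// Assumption: a partial warp can only occur at the tail of a 1-D block.
__device__ inline float warp_reduce_sum_sketch(float val)
{
    int lane = threadIdx.x & (SKETCH_WARP_SIZE - 1);
    int lane_base = threadIdx.x & ~(SKETCH_WARP_SIZE - 1);
    int active = min(SKETCH_WARP_SIZE, static_cast<int>(blockDim.x) - lane_base);
    // Mask of the lanes that actually exist; use the full mask for a complete warp.
    unsigned int mask = (active == SKETCH_WARP_SIZE) ? 0xffffffffu : ((1u << active) - 1u);

    for (int offset = SKETCH_WARP_SIZE / 2; offset > 0; offset >>= 1)
    {
        float other = __shfl_down_sync(mask, val, offset, SKETCH_WARP_SIZE);
        if (lane + offset < active) // discard values shuffled in from non-existent lanes
        {
            val += other;
        }
    }
    return val;
}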
@ -410,7 +431,7 @@ __device__ float4 loadfloat4(void const* ptr)
template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>
__global__ void __launch_bounds__(128, 1)
RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon,
T_IN const* residual, int batch_size, uint32_t* buffer_flags)
T_IN const* residual, int batch_size, uint32_t buffer_size, uint32_t* buffer_flags)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
static bool const LAMPORT = true;
@ -433,7 +454,7 @@ __global__ void __launch_bounds__(128, 1)

int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];

LamportFlags flags(buffer_flags);
LamportFlags flags(buffer_flags, buffer_size);
T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size];

cudaTriggerProgrammaticLaunchCompletion();
@ -598,16 +619,15 @@ __global__ void __launch_bounds__(128, 1)
#endif
}

template <typename T, int H_DIM>
template <typename T, int H_DIM, int NUM_THREADS>
void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T const* gamma, double epsilon,
T const* residual, uint32_t* buffer_flags, int batch, cudaStream_t stream)
T const* residual, uint32_t buffer_size, uint32_t* buffer_flags, int batch, cudaStream_t stream)
{

// input to rmsnorm is the buffer in the twoshot ar
// We should use prenorm output to determine the actual used size
float _epsilon{static_cast<float>(epsilon)};

static constexpr int NUM_THREADS = 128;
static constexpr int CGA_THREADS = NUM_THREADS;
constexpr int iters = H_DIM / CGA_THREADS;

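twoshot_rmsnorm is now templated on NUM_THREADS as well, so the block size can follow the hidden dimension: the dispatch table below pairs 2880 with 120 threads and every other supported size with 128, keeping H_DIM / NUM_THREADS integral. A hypothetical direct instantiation, assuming the declarations above are in scope (the wrapper name is invented here):

#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

// Instantiates the templated host function above for the 2880-wide case;
// 2880 / 120 == 24 iterations per thread. Normally reached via the dispatch
// macros rather than called directly.
void launch_rmsnorm_2880_bf16(__nv_bfloat16* prenorm_output, __nv_bfloat16* normed_output,
    __nv_bfloat16 const* input, __nv_bfloat16 const* gamma, double epsilon, __nv_bfloat16 const* residual,
    uint32_t buffer_size, uint32_t* buffer_flags, int batch, cudaStream_t stream)
{
    twoshot_rmsnorm<__nv_bfloat16, /*H_DIM=*/2880, /*NUM_THREADS=*/120>(
        prenorm_output, normed_output, input, gamma, epsilon, residual, buffer_size, buffer_flags, batch, stream);
}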
@ -628,28 +648,34 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons
&RMSNorm<H_DIM, NUM_THREADS, 1, T, T>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
config.dynamicSmemBytes = shmem_size;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &RMSNorm<H_DIM, NUM_THREADS, 1, T, T>, prenorm_output, normed_output,
input, gamma, _epsilon, residual, batch, buffer_flags));
input, gamma, _epsilon, residual, batch, buffer_size, buffer_flags));
}

#define LAUNCH_RMSNORM_KERNEL(T, H_DIM) \
twoshot_rmsnorm<T, H_DIM>(static_cast<T*>(params.residual_output), static_cast<T*>(params.output), \
#define LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS) \
twoshot_rmsnorm<T, H_DIM, NUM_THREADS>(static_cast<T*>(params.residual_output), static_cast<T*>(params.output), \
static_cast<T const*>(params.input), static_cast<T const*>(params.gamma), params.epsilon, \
static_cast<T const*>(params.residual), params.buffer_flags, params.batch, params.stream)
static_cast<T const*>(params.residual), params.buffer_size, params.buffer_flags, params.batch, params.stream)

void twoshot_rmsnorm_op(RMSNormParams const& params)
{
auto dtype = params.dtype;

#define CASE_DISPATCH_RMSNORM(T, H_DIM, NUM_THREADS) \
case H_DIM: LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS); break;

#define TYPE_DISPATCH_RMSNORM(T) \
CASE_DISPATCH_RMSNORM(T, 2048, 128) \
CASE_DISPATCH_RMSNORM(T, 2880, 120) \
CASE_DISPATCH_RMSNORM(T, 4096, 128) \
CASE_DISPATCH_RMSNORM(T, 5120, 128) \
CASE_DISPATCH_RMSNORM(T, 7168, 128) \
CASE_DISPATCH_RMSNORM(T, 8192, 128)

if (dtype == nvinfer1::DataType::kFLOAT)
{
switch (params.hidden_dim)
{
case 2048: LAUNCH_RMSNORM_KERNEL(float, 2048); break;
case 4096: LAUNCH_RMSNORM_KERNEL(float, 4096); break;
// Llama-4 Hidden Dimension
case 5120: LAUNCH_RMSNORM_KERNEL(float, 5120); break;
// DeepSeek Hidden Dimension
case 7168: LAUNCH_RMSNORM_KERNEL(float, 7168); break;
case 8192: LAUNCH_RMSNORM_KERNEL(float, 8192); break;
TYPE_DISPATCH_RMSNORM(float);
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
}
}
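The per-dtype switch bodies are now generated by the CASE_DISPATCH_RMSNORM / TYPE_DISPATCH_RMSNORM macros above. For reference, a hand expansion of the float branch (illustrative; the preprocessor output is the authority):

// Hand-expanded form of TYPE_DISPATCH_RMSNORM(float) inside the switch above.
switch (params.hidden_dim)
{
case 2048: LAUNCH_RMSNORM_KERNEL(float, 2048, 128); break;
case 2880: LAUNCH_RMSNORM_KERNEL(float, 2880, 120); break;
case 4096: LAUNCH_RMSNORM_KERNEL(float, 4096, 128); break;
case 5120: LAUNCH_RMSNORM_KERNEL(float, 5120, 128); break;
case 7168: LAUNCH_RMSNORM_KERNEL(float, 7168, 128); break;
case 8192: LAUNCH_RMSNORM_KERNEL(float, 8192, 128); break;
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
}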
@ -657,13 +683,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
{
switch (params.hidden_dim)
{
case 2048: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 2048); break;
case 4096: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 4096); break;
// Llama-4 Hidden Dimension
case 5120: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 5120); break;
// DeepSeek Hidden Dimension
case 7168: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 7168); break;
case 8192: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 8192); break;
TYPE_DISPATCH_RMSNORM(__nv_bfloat16);
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
}
}
@ -671,13 +691,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
{
switch (params.hidden_dim)
{
case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break;
case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break;
// Llama-4 Hidden Dimension
case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break;
// DeepSeek Hidden Dimension
case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break;
case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break;
TYPE_DISPATCH_RMSNORM(__nv_half);
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
}
}
@ -685,6 +699,8 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
{
TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype.");
}
#undef TYPE_DISPATCH_RMSNORM
#undef CASE_DISPATCH_RMSNORM
}

} // namespace tensorrt_llm::kernels::mnnvl

@ -30,6 +30,7 @@ struct AllReduceParams
int buffer_M;
int num_tokens;
int token_dim;
uint32_t buffer_size;
void** buffer_ptrs_dev;
void* multicast_ptr;
void* buffer_flags;
@ -50,6 +51,7 @@ struct RMSNormParams
void const* gamma;
double epsilon;
void* residual;
uint32_t buffer_size;
uint32_t* buffer_flags;
int batch;
int hidden_dim;

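Both parameter structs gain a buffer_size field next to buffer_flags, so callers now pass the Lamport buffer size explicitly. A hypothetical host-side sketch of filling RMSNormParams, assuming the header above and TensorRT's NvInfer headers are included; every local name and value is a placeholder, while the field names follow the struct and the params.* accesses in the dispatch code:

#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical host-side setup; the pointer arguments are assumed to be valid
// device allocations prepared by the caller.
void run_twoshot_rmsnorm_example(void* prenorm_out, void* normed_out, void* ar_buffer_input, void* gamma,
    void* residual, uint32_t lamport_buffer_size, uint32_t* buffer_flags, int num_tokens, cudaStream_t stream)
{
    RMSNormParams params{};
    params.residual_output = prenorm_out;
    params.output = normed_out;
    params.input = ar_buffer_input;
    params.gamma = gamma;
    params.epsilon = 1e-5;
    params.residual = residual;
    params.buffer_size = lamport_buffer_size; // field added in this change
    params.buffer_flags = buffer_flags;
    params.batch = num_tokens;
    params.hidden_dim = 2880;
    params.dtype = nvinfer1::DataType::kHALF;
    params.stream = stream;
    twoshot_rmsnorm_op(params);
}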
@ -150,8 +150,8 @@ __device__ __forceinline__ void fused_op(
constexpr int SF_VEC_SIZE = 16;
using PackedVec = PackedVec<DType>;
PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&norm_val);
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt /* batchIdx */,
token_id, access_id_in_token, std::nullopt /* numRows */, params.hidden_dim,
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt /* batchIdx */, token_id,
access_id_in_token, std::nullopt /* numRows */, params.hidden_dim / SF_VEC_SIZE,
reinterpret_cast<uint32_t*>(params.scale_out), params.layout);
reinterpret_cast<uint32_t*>(params.quant_out)[access_id]
= cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, *params.scale_factor, sf_out);

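The quantization path now hands the renamed cvt_quant_get_sf_out_offset helper the per-token scale-factor count (params.hidden_dim / SF_VEC_SIZE) rather than the raw hidden dimension. The arithmetic, spelled out for one assumed hidden size:

// With 16-element scale-factor groups, a token of hidden_dim elements owns
// hidden_dim / 16 scale factors; that per-token count is the value passed above.
constexpr int kSfVecSize = 16;
constexpr int kHiddenDim = 2880; // example hidden size, not taken from this hunk
static_assert(kHiddenDim / kSfVecSize == 180, "one scale factor per 16-element group");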
@ -55,7 +55,7 @@ struct AllReduceFusionParams
void* rms_gamma;
float rms_eps;
float* scale_factor;
FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED;
QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED;
cudaStream_t stream;
};

[The remaining hunks in this commit touch Git LFS pointer files for binary artifacts: two pointers are removed outright (@ -1,3 +0,0 @@) and several dozen others are updated in place (@ -1,3 +1,3 @@) with new oid sha256 values and, in many cases, new sizes. The affected file names are not preserved in this rendering.]