// Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
// Synced 2026-01-14 06:27:45 +08:00 (277 lines, 11 KiB, C++).
/*
|
|
* MIT License
|
|
*
|
|
* Copyright (c) 2025 DeepSeek
|
|
*
|
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
* of this software and associated documentation files (the "Software"), to deal
|
|
* in the Software without restriction, including without limitation the rights
|
|
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
* copies of the Software, and to permit persons to whom the Software is
|
|
* furnished to do so, subject to the following conditions:
|
|
*
|
|
* The above copyright notice and this permission notice shall be included in all
|
|
* copies or substantial portions of the Software.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
* SOFTWARE.
|
|
*
|
|
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
* reference: https://github.com/deepseek-ai/FlashMLA
|
|
*/
|
|
|
|
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
|
|
|
|
#pragma once
|
|
|
|
#include <cmath>
|
|
|
|
#include <cute/tensor.hpp>
|
|
#include <cutlass/numeric_types.h>
|
|
|
|
#include "utils.h"
|
|
|
|
namespace flash
|
|
{
|
|
|
|
using namespace cute;
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Folds every row of a 2-D tensor into the matching entry of a 1-D summary tensor, using only
// the elements visible to the calling thread (this function does no cross-thread exchange).
//
// zero_init: when true, column 0 seeds each row's result and the old summary(row) is ignored;
//            when false, the old summary(row) is combined with the row as well.
// tensor:    2-D tensor; each row (mode 0) is reduced along mode 1.
// summary:   1-D tensor with one entry per row of `tensor`; overwritten with the result.
// op:        binary functor, applied as op(accumulated, element).
template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
    typename Operator>
__device__ __forceinline__ void thread_reduce_(
    Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op)
{
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row)
    {
        // Seed the running value from column 0 (optionally merged with the previous summary).
        if (zero_init)
        {
            summary(row) = tensor(row, 0);
        }
        else
        {
            summary(row) = op(summary(row), tensor(row, 0));
        }
        // Fold in the remaining columns left to right.
#pragma unroll
        for (int col = 1; col < size<1>(tensor); ++col)
        {
            summary(row) = op(summary(row), tensor(row, col));
        }
    }
}
|
|
|
|
// Element-wise cross-thread reduction: every dst(i) receives Allreduce<4>::run(src(i), op).
// Allreduce is defined in utils.h (not shown); presumably it combines the value across a group
// of 4 cooperating threads, so all threads in the group end up with the same result (verify).
// `dst` and `src` may be the same tensor: each src(i) is read before dst(i) is written.
//
// dst: destination tensor; must have the same size as `src`.
// src: per-thread partial values.
// op:  binary functor forwarded to Allreduce<4>::run.
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(
    Tensor<Engine0, Layout0>& dst, Tensor<Engine1, Layout1>& src, Operator& op)
{
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
    for (int idx = 0; idx < size(dst); ++idx)
    {
        auto const reduced = Allreduce<4>::run(src(idx), op);
        dst(idx) = reduced;
    }
}
|
|
|
|
// Row-wise reduction of `tensor` into `summary`, combined across threads.
//
// First each row of the calling thread's elements is folded into summary(row) (when zero_init
// is false the previous summary(row) participates too); then the per-thread partials are merged
// in place with quad_allreduce_ using the same functor `op`.
template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
    typename Operator>
__device__ __forceinline__ void reduce_(
    Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op)
{
    // Step 1: per-thread partial reduction of each row.
    thread_reduce_<zero_init>(tensor, summary, op);
    // Step 2: merge the partials across the thread group, in place.
    quad_allreduce_(summary, summary, op);
}
|
|
|
|
// Row-wise maximum of `tensor` written into `max`, combined across the thread group via reduce_.
// With zero_init == false, the previous contents of `max` also take part in the maximum.
template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& max)
{
    MaxOp<float> maximum;
    reduce_<zero_init>(tensor, max, maximum);
}
|
|
|
|
// Row-wise sum of `tensor` written into `sum`, over the calling thread's elements ONLY.
// No cross-thread combine happens here (it calls thread_reduce_, not reduce_); in this file,
// Softmax::normalize_softmax_lse performs that combine on row_sum later.
// With zero_init == false, the previous contents of `sum` are added to as well.
template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& sum)
{
    SumOp<float> adder;
    thread_reduce_<zero_init>(tensor, sum, adder);
}
|
|
|
|
// Exponentiates every element of `tensor` in place, relative to its row maximum:
//   tensor(i, j) = exp2(tensor(i, j) * scale - m_i)
// where m_i = max(i) * (Scale_max ? scale : log2(e)), or m_i = 0 when max(i) is -inf.
//
// Returns `tensor` by value (a copy of the Tensor object, made after the in-place update).
template <bool Scale_max = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto scale_apply_exp2(
    Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1> const& max, float const scale)
{
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    // Factor applied to the row max. M_LOG2E is cast to float so the multiply stays in fp32
    // instead of being promoted to fp64.
    float const max_multiplier = Scale_max ? scale : float(M_LOG2E);
#pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row)
    {
        // A -inf row max means every element of the row is -inf (e.g. fully masked). Subtracting
        // it would produce (-inf) - (-inf) = NaN, so an offset of 0 is used instead.
        float max_scaled = 0.f;
        if (max(row) != -INFINITY)
        {
            max_scaled = max(row) * max_multiplier;
        }
#pragma unroll
        for (int col = 0; col < size<1>(tensor); ++col)
        {
            // exp(x - max) is evaluated as exp2(x * log2(e) - max * log2(e)); written as
            // `x * scale - max_scaled` it lets the compiler emit a single FFMA instead of a
            // separate FMUL and FADD. Defining UNFUSE_FMA (a PyTorch build macro, not a
            // FlashAttention one) forces a separately rounded multiply; see
            // https://github.com/pytorch/pytorch/issues/121558 for details.
#ifdef UNFUSE_FMA
            tensor(row, col) = exp2f(__fmul_rn(tensor(row, col), scale) - max_scaled);
#else
            tensor(row, col) = exp2f(tensor(row, col) * scale - max_scaled);
#endif
        }
    }
    return tensor;
}
|
|
|
|
// Fused per-row max / exponentiate / sum. For each row i of `tensor`:
//   1. max(i) <- maximum of the row (with zero_init == false the old max(i) participates),
//      then combined across the thread group with Allreduce<4>;
//   2. every element x <- exp2(x * scale - max(i) * scale), with an offset of 0 when max(i) is -inf;
//   3. sum(i) <- sum of the new row values, likewise combined with Allreduce<4>.
// Allreduce is defined in utils.h (not shown); presumably a 4-thread group combine (verify).
template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(
    Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1>& max, Tensor<Engine1, Layout1>& sum, float const scale)
{
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row)
    {
        // ---- Row maximum ------------------------------------------------------------
        MaxOp<float> max_op;
        if (zero_init)
        {
            max(row) = tensor(row, 0);
        }
        else
        {
            max(row) = max_op(max(row), tensor(row, 0));
        }
#pragma unroll
        for (int col = 1; col < size<1>(tensor); ++col)
        {
            max(row) = max_op(max(row), tensor(row, col));
        }
        max(row) = Allreduce<4>::run(max(row), max_op);

        // ---- Exponentiate and accumulate ------------------------------------------
        // A fully -inf row would give (-inf) - (-inf) = NaN, so fall back to a 0 offset.
        float max_scaled = 0.f;
        if (max(row) != -INFINITY)
        {
            max_scaled = max(row) * scale;
        }
        sum(row) = 0;
#pragma unroll
        for (int col = 0; col < size<1>(tensor); ++col)
        {
            // exp2(x * scale - max_scaled) lets the compiler use one FFMA ahead of exp2f.
            tensor(row, col) = exp2f(tensor(row, col) * scale - max_scaled);
            sum(row) += tensor(row, col);
        }
        SumOp<float> sum_op;
        sum(row) = Allreduce<4>::run(sum(row), sum_op);
    }
}
|
|
|
|
// Scales the output accumulator row by row: every element in row i of `acc_o` is multiplied
// by scale_o(i). Only the first size(scale_o) rows are visited.
//
// acc_o:   accumulator fragment; modified in place through a row/column view of acc_o.data().
// scale_o: one factor per row (e.g. the tensor returned by Softmax::softmax).
template <typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor0& acc_o, Tensor1& scale_o)
{
    // Reshape acc_o from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
    Tensor o_rows = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
    for (int row = 0; row < size(scale_o); ++row)
    {
#pragma unroll
        for (int col = 0; col < size<1>(o_rows); ++col)
        {
            o_rows(row, col) *= scale_o(row);
        }
    }
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Online-softmax state for kNRows rows owned by the calling thread.
//
// row_max holds the running maximum of the raw (unscaled) scores of each row; it is combined
// across the thread group as it is updated (reduce_max). row_sum holds the running sum of
// exponentials of this thread's elements only; it is combined across the thread group just
// once, in normalize_softmax_lse().
template <int kNRows>
struct Softmax
{

    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
    // Running per-row maximum (raw scores) and per-thread running sum of exponentials.
    TensorT row_max, row_sum;

    __forceinline__ __device__ Softmax() {}

    // Folds one block of attention scores into the running softmax state.
    //
    // Is_first:  true for the first block; row_max / row_sum are initialized from this block.
    // Check_inf: only used when !Is_first. A row max of -inf is treated as 0 when computing the
    //            rescale factor, so (-inf) - (-inf) = NaN cannot occur.
    // acc_s:     score accumulator, overwritten in place with
    //            exp2(s * softmax_scale_log2 - row_max * softmax_scale_log2).
    // softmax_scale_log2: score scale applied in the exp2 domain.
    //
    // Returns per-row factors exp2((previous max - new max) * softmax_scale_log2). Previously
    // accumulated outputs must be multiplied by these factors (e.g. with rescale_o). The factors
    // are all zero when Is_first, because scale_o is cleared and never written on that path.
    template <bool Is_first, bool Check_inf = false, typename Tensor0>
    __forceinline__ __device__ TensorT softmax(Tensor0& acc_s, float softmax_scale_log2)
    {
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        TensorT scale_o;
        clear(scale_o);
        if (Is_first)
        {
            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
        }
        else
        {
            // Snapshot the old maxima before updating, so the rescale factors can be derived.
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
            // Per-row rescale factor for the running sum (applied here) and for the output
            // accumulator (returned to the caller in scale_o).
#pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi)
            {
                float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                scale_o(mi) = scores_scale;
                row_sum(mi) *= scores_scale;
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
        }
        return scale_o;
    }

    // Finalizes the softmax for the output accumulator and returns the per-row log-sum-exp.
    //
    // Steps: combine row_sum across the thread group (in place, mutating row_sum); then, for each
    // row, scale acc_o by descale_v / row_sum (times rp_dropout when Is_dropout) and compute
    //   lse = row_max * softmax_scale + log(row_sum)   (using the fast __logf intrinsic).
    // Rows whose sum is 0 or NaN use inv_sum = 1 (descale_v is NOT applied to them), and their
    // lse is -inf when Split is true, +inf otherwise.
    //
    // Is_dropout:    multiply the row scale by rp_dropout.
    // Split:         selects the lse reported for empty/NaN rows (see above).
    // acc_o:         output accumulator, scaled in place.
    // softmax_scale: natural-log-domain score scale used in the lse formula.
    // descale_v:     extra per-call multiplier folded into the row scale (presumably a V
    //                dequantization factor; confirm with the caller).
    // rp_dropout:    dropout rescale factor; only read when Is_dropout.
    template <bool Is_dropout = false, bool Split = false, typename Tensor0>
    __forceinline__ __device__ TensorT normalize_softmax_lse(
        Tensor0& acc_o, float softmax_scale, float descale_v, float rp_dropout = 1.0f)
    {
        // row_sum was accumulated per thread only (see softmax()); combine it across the group now.
        SumOp<float> sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        TensorT lse = make_fragment_like(row_sum);
        // Reshape acc_o from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi)
        {
            float sum = row_sum(mi);
            // `sum != sum` is true only for NaN; empty and NaN rows are left unnormalized.
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY)
                                                 : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni)
            {
                acc_o_rowcol(mi, ni) *= scale;
            }
        }
        return lse;
    }
};
|
|
|
|
} // namespace flash
|