/* * MIT License * * Copyright (c) 2025 DeepSeek * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * reference: https://github.com/deepseek-ai/FlashMLA */ // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h #pragma once #include #include #include #include "utils.h" namespace flash { using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ __forceinline__ void thread_reduce_( Tensor const& tensor, Tensor& summary, Operator& op) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); #pragma unroll for (int mi = 0; mi < size<0>(tensor); mi++) { summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); #pragma unroll for (int ni = 1; ni < size<1>(tensor); ni++) { summary(mi) = op(summary(mi), tensor(mi, ni)); } } } template __device__ __forceinline__ void quad_allreduce_( Tensor& dst, Tensor& src, Operator& op) { CUTE_STATIC_ASSERT_V(size(dst) == size(src)); #pragma unroll for (int i = 0; i < size(dst); i++) { dst(i) = Allreduce<4>::run(src(i), op); } } template __device__ __forceinline__ void reduce_( Tensor const& tensor, Tensor& summary, Operator& op) { thread_reduce_(tensor, summary, op); quad_allreduce_(summary, summary, op); } template __device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor& max) { MaxOp max_op; reduce_(tensor, max, max_op); } template __device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor& sum) { SumOp sum_op; thread_reduce_(tensor, sum, sum_op); } // Apply the exp to all the elements. template __forceinline__ __device__ auto scale_apply_exp2( Tensor& tensor, Tensor const& max, float const scale) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. // If we don't have float around M_LOG2E the multiplication is done in fp64. float const max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - // max * log_2(e)) This allows the compiler to use the ffma // instruction instead of fadd and fmul separately. // The following macro will disable the use of fma. // See: https://github.com/pytorch/pytorch/issues/121558 for more details // This macro is set in PyTorch and not FlashAttention #ifdef UNFUSE_FMA tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); #else tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); #endif } } return tensor; } // Apply the exp to all the elements. template __forceinline__ __device__ void max_scale_exp2_sum( Tensor& tensor, Tensor& max, Tensor& sum, float const scale) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { MaxOp max_op; max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); #pragma unroll for (int ni = 1; ni < size<1>(tensor); ni++) { max(mi) = max_op(max(mi), tensor(mi, ni)); } max(mi) = Allreduce<4>::run(max(mi), max_op); // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. float const max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; sum(mi) = 0; #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - // max * log_2(e)) This allows the compiler to use the ffma // instruction instead of fadd and fmul separately. tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); sum(mi) += tensor(mi, ni); } SumOp sum_op; sum(mi) = Allreduce<4>::run(sum(mi), sum_op); } } template __forceinline__ __device__ void rescale_o(Tensor0& acc_o, Tensor1& scale_o) { // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); #pragma unroll for (int mi = 0; mi < size(scale_o); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Softmax { using TensorT = decltype(make_tensor(Shape>{})); TensorT row_max, row_sum; __forceinline__ __device__ Softmax(){}; template __forceinline__ __device__ TensorT softmax(Tensor0& acc_s, float softmax_scale_log2) { // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); static_assert(decltype(size<0>(scores))::value == kNRows); TensorT scale_o; clear(scale_o); if (Is_first) { flash::template reduce_max(scores, row_max); flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); flash::reduce_sum(scores, row_sum); } else { Tensor scores_max_prev = make_fragment_like(row_max); cute::copy(row_max, scores_max_prev); flash::template reduce_max(scores, row_max); // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) #pragma unroll for (int mi = 0; mi < size(row_max); ++mi) { float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); scale_o(mi) = scores_scale; row_sum(mi) *= scores_scale; } flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); // We don't do the reduce across threads here since we don't need to use the row_sum. // We do that reduce at the end when we need to normalize the softmax. flash::reduce_sum(scores, row_sum); } return scale_o; }; template __forceinline__ __device__ TensorT normalize_softmax_lse( Tensor0& acc_o, float softmax_scale, float descale_v, float rp_dropout = 1.0) { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { float sum = row_sum(mi); float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum; lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } } return lse; }; }; } // namespace flash