Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
* FP8 KV + BF16 context MLA + FP8 generation MLA
Use BF16 for context MLA.
mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together.
Allow mSM==90 for mFP8GenerationMLA==true.
For FMHA, dataTypeKv should be FP8.
For FP8 MLA generation, the output is still in BF16.
Refine debug info for FMHA kernel metadata.
Use inputType, outputType, and SM together to hash the kernel list (see the sketch after these notes).
Add FP8 MLA generation FMHA kernel.
Special workaround (WAR) for NUM_COMPUTE_GROUPS in the MLA generation kernel.
Split the implementation of fused_multihead_attention_v2.h into a .cpp file, and print debug info if checkIfKernelExist fails.
Refine debug info in fused_multihead_attention_v2.cpp.
Correct FP8 MLA metadata.
New kernel provided by Yuxin, which outputs BF16.
The shared memory (smem) size was not set correctly, which led to illegal memory access.
Yuxin fixed the error in the FMHA MLA kernel: previously the BF16 output was not written correctly; some parts were written repeatedly while others were left untouched.
There are two bmm1 scales that must be set correctly.
New kernel generated by Yuxin.
Modifications to common/attentionOp for FP8 MLA on Hopper using FMHA.
Not necessary: when mFP8GenerationMLA is set, is_fp8_out is false, so mFP8ContextFMHA is false.
Skip a check in fmhaDispatcher.
Modifications in fmhaRunner:
- Debug dump.
- Guard a large block of flag setting with if (!isFP8GenerationMLA), so it is skipped for FP8 MLA generation.
- TMA descriptor modification for qo (by Yuxin).
Clean up debug output.
Clean up the o TMA descriptor modifications.
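As a rough illustration of the hashing note above, here is a minimal sketch of keying a kernel list on (inputType, outputType, SM). The DataType enum, KernelKey struct, KernelKeyHash functor, and KernelMeta name are hypothetical and for illustration only; they are not TensorRT-LLM's actual types.

#include <cstdint>
#include <functional>

// Hypothetical data type enum; the real TensorRT-LLM enum differs.
enum class DataType : uint32_t
{
    kFP16,
    kBF16,
    kFP8_E4M3
};

// Key the kernel list on input type, output type, and SM version together, so that
// e.g. an FP8-in/BF16-out sm90 MLA kernel is looked up separately from FP8-in/FP8-out.
struct KernelKey
{
    DataType inputType;
    DataType outputType;
    uint32_t sm; // e.g. 90 for Hopper

    bool operator==(KernelKey const& other) const
    {
        return inputType == other.inputType && outputType == other.outputType && sm == other.sm;
    }
};

struct KernelKeyHash
{
    size_t operator()(KernelKey const& k) const
    {
        // Pack the three fields into one 64-bit value before hashing.
        uint64_t const packed = (static_cast<uint64_t>(k.inputType) << 40)
            | (static_cast<uint64_t>(k.outputType) << 32) | static_cast<uint64_t>(k.sm);
        return std::hash<uint64_t>{}(packed);
    }
};

// Usage (KernelMeta is also hypothetical):
//   std::unordered_map<KernelKey, KernelMeta, KernelKeyHash> kernelList;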
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Resolve conflicts.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Apply the patch of FP8 FlashMLA and resolve conflicts.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Fix compilation error.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Fix compile error.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Pick Blackwell support.
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
* Add copyright notice to fused_multihead_attention_v2.cpp.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Add license.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Add missing license.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Exclude building flashMLA kernels under sm90.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Revert "Exclude building flashMLA kernels under sm90."
This reverts commit f0c859d459.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Use a macro to skip compiling FlashMLA for non-sm90 targets.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
---------
Signed-off-by: Bo Li <bobboli0202@gmail.com>
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: Dylan Chen <ziqingc@nvidia.com>
Co-authored-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
277 lines · 11 KiB · C++
/*
 * MIT License
 *
 * Copyright (c) 2025 DeepSeek
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * reference: https://github.com/deepseek-ai/FlashMLA
 */

// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h

#pragma once

#include <cmath>

#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

#include "utils.h"

namespace flash
{

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
    typename Operator>
__device__ __forceinline__ void thread_reduce_(
    Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op)
{
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++)
    {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++)
        {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(
    Tensor<Engine0, Layout0>& dst, Tensor<Engine1, Layout1>& src, Operator& op)
{
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
    for (int i = 0; i < size(dst); i++)
    {
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}

template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
    typename Operator>
__device__ __forceinline__ void reduce_(
    Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op)
{
    thread_reduce_<zero_init>(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& max)
{
    MaxOp<float> max_op;
    reduce_<zero_init>(tensor, max, max_op);
}

template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& sum)
{
    SumOp<float> sum_op;
    thread_reduce_<zero_init>(tensor, sum, sum_op);
}

// Apply the exp to all the elements.
template <bool Scale_max = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto scale_apply_exp2(
    Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1> const& max, float const scale)
{
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi)
    {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        float const max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni)
        {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            // The following macro will disable the use of fma.
            // See: https://github.com/pytorch/pytorch/issues/121558 for more details
            // This macro is set in PyTorch and not FlashAttention
#ifdef UNFUSE_FMA
            tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
#else
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
        }
    }
    return tensor;
}
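
// Note on scale_apply_exp2 above: the scale argument is expected to already
// carry the log2(e) factor (Softmax::softmax below passes softmax_scale_log2),
// so the inner loop computes
//   exp(x * softmax_scale - max * softmax_scale)
//     = exp2(x * softmax_scale * log2(e) - max * softmax_scale * log2(e))
//     = exp2(x * scale - max_scaled),
// letting the compiler fuse the multiply and subtract into a single FFMA.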

// Apply the exp to all the elements.
template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(
    Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1>& max, Tensor<Engine1, Layout1>& sum, float const scale)
{
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi)
    {
        MaxOp<float> max_op;
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++)
        {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        float const max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
        sum(mi) = 0;
#pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni)
        {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            sum(mi) += tensor(mi, ni);
        }
        SumOp<float> sum_op;
        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
    }
}

template <typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor0& acc_o, Tensor1& scale_o)
{
    // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
    for (int mi = 0; mi < size(scale_o); ++mi)
    {
#pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni)
        {
            acc_o_rowcol(mi, ni) *= scale_o(mi);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int kNRows>
struct Softmax
{

    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
    TensorT row_max, row_sum;

    __forceinline__ __device__ Softmax(){};

    template <bool Is_first, bool Check_inf = false, typename Tensor0>
    __forceinline__ __device__ TensorT softmax(Tensor0& acc_s, float softmax_scale_log2)
    {
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        TensorT scale_o;
        clear(scale_o);
        if (Is_first)
        {
            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
        }
        else
        {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
#pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi)
            {
                float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                scale_o(mi) = scores_scale;
                row_sum(mi) *= scores_scale;
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
        }
        return scale_o;
    };

    template <bool Is_dropout = false, bool Split = false, typename Tensor0>
    __forceinline__ __device__ TensorT normalize_softmax_lse(
        Tensor0& acc_o, float softmax_scale, float descale_v, float rp_dropout = 1.0)
    {
        SumOp<float> sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        TensorT lse = make_fragment_like(row_sum);
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi)
        {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY)
                                                 : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni)
            {
                acc_o_rowcol(mi, ni) *= scale;
            }
        }
        return lse;
    };
};

} // namespace flash
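For context, the Softmax struct above implements the standard online-softmax recurrence used by flash-attention style kernels: softmax() maintains the running row max and returns a per-row rescale factor, rescale_o() applies that factor to the partial output accumulator, and normalize_softmax_lse() folds in the final 1/sum (and, on the FP8 path, descale_v). Below is a minimal, host-side C++ illustration of that recurrence on scalars, assuming finite scores; the function name and plain std::vector types are illustrative only and are not part of this file or of FlashMLA.

// Minimal host-side sketch of the online-softmax recurrence (illustrative only,
// not part of FlashMLA): process scores tile by tile, keeping a running max m,
// a running sum of exponentials l, and a partial weighted output acc that is
// rescaled whenever the running max grows.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

float online_softmax_weighted_sum(
    std::vector<std::vector<float>> const& score_tiles, std::vector<std::vector<float>> const& value_tiles)
{
    float m = -INFINITY; // running row max (cf. row_max)
    float l = 0.f;       // running sum of exponentials (cf. row_sum)
    float acc = 0.f;     // partial output accumulator (cf. acc_o)

    for (std::size_t t = 0; t < score_tiles.size(); ++t)
    {
        // Update the running max over this tile (cf. reduce_max with zero_init=false).
        float m_new = m;
        for (float s : score_tiles[t])
        {
            m_new = std::max(m_new, s);
        }
        // Rescale the previous sum and accumulator by exp(m_old - m_new), which is
        // what softmax() returns in scale_o and rescale_o() applies to acc_o.
        float const scale = (m == -INFINITY) ? 1.f : std::exp(m - m_new);
        l *= scale;
        acc *= scale;
        m = m_new;
        // Accumulate this tile's exponentials and weighted values
        // (cf. scale_apply_exp2 + reduce_sum and the P*V matmul in the caller).
        for (std::size_t i = 0; i < score_tiles[t].size(); ++i)
        {
            float const p = std::exp(score_tiles[t][i] - m);
            l += p;
            acc += p * value_tiles[t][i];
        }
    }
    // Final normalization (cf. normalize_softmax_lse, without the FP8 descale_v factor).
    return (l == 0.f) ? 0.f : acc / l;
}

Rescaling by exp(m_old - m_new) keeps the terms accumulated under the previous max consistent with the new one, which is why the softmax-weighted sum can be computed in a single pass over tiles without materializing all scores.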