TensorRT-LLMs/cpp/tensorrt_llm/kernels/mlaKernels.h
Bo Li 515dd0d78f
feat: Add support for FP8 MLA on Hopper and Blackwell. (#3190)
* fp8 kv + bf16 ctx MLA + fp8 gen MLA

Use BF16 for context MLA.
mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together.

Allow mSM==90 for mFP8GenerationMLA==true.
For FMHA, dataTypeKv should be FP8.

For FP8 MLA generation, the output is still in BF16.

Refine debug info for FMHA kernel metadata.

Use inputType, outputType, SM together to hash kernel list.

Add FP8 MLA generation FMHA kernel.

Special workaround (WAR) for NUM_COMPUTE_GROUPS in the MLA generation kernel.

Move the implementation of fused_multihead_attention_v2.h into a separate .cpp file, and print debug info if checkIfKernelExist fails.

Refine debug info in fused_multihead_attention_v2.cpp

Correct FP8 MLA metadata.

New kernel provided by Yuxin, which outputs BF16.

The smem size was not set correctly, which led to illegal memory access.

Yuxin fixed an error in the FMHA MLA kernel: previously the BF16 output was not written correctly; some parts were written repeatedly while others were left untouched.

There are two bmm1 scales that should be set correctly.

New kernel generated by Yuxin.

Modifications to common/attentionOp for FP8 MLA on Hopper using FMHA.

Not necessary: if mFP8GenerationMLA is set, is_fp8_out is false, so mFP8ContextFMHA is also false.

Skip a check in fmhaDispatcher.

Modifications in fmhaRunner:
- Debug dump.
- if (!isFP8GenerationMLA) skips a lot of flag setting.
- TMA descriptor modification for qo (by Yuxin).

Cleanup debug output.

Clean up o tma descriptor modifications.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Resolve conflicts.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Apply the patch of FP8 FlashMLA and resolve conflicts.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Fix compilation error.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Fix compile error.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* pick blackwell support

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

* Add copyright notice to fused_multihead_attention_v2.cpp.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Add license.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Add missing license.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Exclude building flashMLA kernels under sm90.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Revert "Exclude building flashMLA kernels under sm90."

    This reverts commit f0c859d459.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Use macro to skip compiling FlashMLA for non sm90 targets.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

---------

Signed-off-by: Bo Li <bobboli0202@gmail.com>
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: Dylan Chen <ziqingc@nvidia.com>
Co-authored-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
2025-04-07 15:14:13 +08:00
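The commit message above states two rules for the FP8 MLA flags. First, `mFP8GenerationMLA` and `mFP8ContextFMHA` must not be enabled together. Second, `mSM == 90` is accepted when `mFP8GenerationMLA` is true. Below is a minimal sketch of such a check. It assumes hypothetical names (`MlaFlagsSketch`, `isFp8GenerationMlaSm`, `validateMlaFlags`) and assumes SM 100 stands for the Blackwell target named in the commit title. The real AttentionOp validation may differ.

```cpp
#include <cassert>
#include <string>

// Hypothetical mirror of the flags named in the commit message;
// these are not the real AttentionOp members.
struct MlaFlagsSketch
{
    bool fp8GenerationMla = false; // stands in for mFP8GenerationMLA
    bool fp8ContextFmha = false;   // stands in for mFP8ContextFMHA
    int sm = 0;                    // stands in for mSM, e.g. 90 for Hopper
};

// Assumption: Hopper (SM 90) and Blackwell (SM 100) are the accepted targets.
inline bool isFp8GenerationMlaSm(int sm)
{
    return sm == 90 || sm == 100;
}

// Returns an empty string if the flag combination is accepted, else a reason.
inline std::string validateMlaFlags(MlaFlagsSketch const& f)
{
    if (f.fp8GenerationMla && f.fp8ContextFmha)
    {
        return "FP8 generation MLA and FP8 context FMHA must not be enabled together";
    }
    if (f.fp8GenerationMla && !isFp8GenerationMlaSm(f.sm))
    {
        return "FP8 generation MLA requires SM 90 or SM 100";
    }
    return "";
}
```

Returning a reason string instead of asserting keeps the sketch easy to test. A production op would more likely raise an error when it is configured.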
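The commit hashes the kernel list on input type, output type, and SM together. Output type matters because the FP8 generation MLA kernel reads FP8 data but still writes BF16. A key built only from input type and SM would presumably collide with a kernel that has FP8 output. The sketch below packs the three fields into one 64-bit key. `DataTypeSketch` and `kernelListKey` are hypothetical names, not the FMHA runner's real types.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical data-type tags; the real code uses its own data-type enum.
enum class DataTypeSketch : std::uint8_t
{
    kFP16 = 0,
    kBF16 = 1,
    kE4M3 = 2,
};

// Packs (inputType, outputType, sm) into disjoint bit ranges:
// bits 40..47 = input type, bits 32..39 = output type, bits 0..31 = SM.
constexpr std::uint64_t kernelListKey(DataTypeSketch inputType, DataTypeSketch outputType, std::uint32_t sm)
{
    return (static_cast<std::uint64_t>(inputType) << 40) | (static_cast<std::uint64_t>(outputType) << 32) | sm;
}

// FP8-in/BF16-out and FP8-in/FP8-out kernels on the same SM get different keys.
static_assert(kernelListKey(DataTypeSketch::kE4M3, DataTypeSketch::kBF16, 90)
        != kernelListKey(DataTypeSketch::kE4M3, DataTypeSketch::kE4M3, 90),
    "output type must participate in the key");
```

Packing into disjoint bit ranges guarantees distinct keys for distinct triples. A hash-combine function would not make that guarantee.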


/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include <assert.h>
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <tuple> // std::make_tuple in MlaMetaParams::data()

namespace tensorrt_llm
{
namespace kernels
{

enum class KvCacheDataType;

struct MlaMetaParams
{
    int32_t q_lora_rank = 0;
    int32_t kv_lora_rank = 0;
    int32_t qk_nope_head_dim = 0;
    int32_t qk_rope_head_dim = 0;
    int32_t v_head_dim = 0;
    int32_t predicted_tokens_per_seq = 1;
    int32_t num_layers = 0;

    auto data() const
    {
        return std::make_tuple(q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
            predicted_tokens_per_seq, num_layers);
    }
};

template <typename T>
struct MlaParams
{
    T const* latent_cache; // cKV + k_pe
    T* attention_input_buf; // [b, s, 3, h, d_h + r]
    void* quant_attention_input_buf;
    T* context_buf;
    T* q_pe; // [b, h, d_r], strided
    float2 const* cos_sin_cache; // [s, rope]
    int32_t batch_size;
    int32_t acc_q_len;
    int32_t head_num; // h
    void* workspace;
    int32_t const* cache_seq_lens;
    int* seqQOffset;
    uint32_t* fmha_tile_counter;
    int32_t max_input_seq_len;
    int* cu_q_seqlens;
    int* cu_kv_seqlens;
    int32_t q_pe_ld;
    int32_t q_pe_stride;
    MlaMetaParams meta;
    int const* block_ids_per_seq;
    KvCacheDataType cache_type;
    // Scales for mla quantization
    float* bmm1_scale;
    float* bmm2_scale;
    float const* quant_scale_o;
    float const* quant_scale_q;
    float const* quant_scale_kv;
    float const* dequant_scale_q;
    float const* dequant_scale_kv;
    float host_bmm1_scale;
};

template <typename T, typename KVCacheBuffer>
void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream);

template <typename T, typename KVCacheBuffer>
void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
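`MlaMetaParams::data()` returns all seven fields as one `std::tuple`, and `std::tuple` already provides member-wise `==` and lexicographic `<`. The sketch below shows how that can back equality and ordering. `MlaMetaParamsSketch` is a standalone copy of the struct so the snippet compiles without TensorRT-LLM headers. The two operators are illustrative additions that the original header does not declare.

```cpp
#include <cassert>
#include <cstdint>
#include <tuple>

// Standalone copy of MlaMetaParams, for illustration only.
struct MlaMetaParamsSketch
{
    int32_t q_lora_rank = 0;
    int32_t kv_lora_rank = 0;
    int32_t qk_nope_head_dim = 0;
    int32_t qk_rope_head_dim = 0;
    int32_t v_head_dim = 0;
    int32_t predicted_tokens_per_seq = 1;
    int32_t num_layers = 0;

    auto data() const
    {
        return std::make_tuple(q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
            predicted_tokens_per_seq, num_layers);
    }
};

// Hypothetical operators: std::tuple compares member by member, in field order.
inline bool operator==(MlaMetaParamsSketch const& a, MlaMetaParamsSketch const& b)
{
    return a.data() == b.data();
}

inline bool operator<(MlaMetaParamsSketch const& a, MlaMetaParamsSketch const& b)
{
    return a.data() < b.data();
}
```

With `operator<` defined this way, the struct could also serve as a `std::map` key, for example to cache state per model configuration.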
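The layout comments in `MlaParams` describe `attention_input_buf` as `[b, s, 3, h, d_h + r]` and `latent_cache` as `cKV + k_pe`. The sketch below shows the corresponding offset arithmetic under several assumptions:

- Both buffers are dense and row-major, exactly as the comments state.
- `d_h + r` means the non-rope head dimension plus the rope dimension.
- Each latent-cache row holds `kv_lora_rank` cKV values followed by `qk_rope_head_dim` k_pe values.

The function names are hypothetical. The real kernels may use padding, packed tokens, or paged blocks (`block_ids_per_seq`).

```cpp
#include <cassert>
#include <cstddef>

// Offset of element (b, s, qkv, h, d) in a dense row-major buffer shaped
// [B, S, 3, H, D], where D = d_h + r and qkv is 0 for Q, 1 for K, 2 for V.
constexpr std::size_t qkvOffset(std::size_t b, std::size_t s, std::size_t qkv, std::size_t h, std::size_t d,
    std::size_t S, std::size_t H, std::size_t D)
{
    return (((b * S + s) * 3 + qkv) * H + h) * D + d;
}

// Latent-cache row assumed to be [cKV (kvLoraRank values) | k_pe (ropeDim values)],
// with consecutive tokens one row (kvLoraRank + ropeDim values) apart.
constexpr std::size_t latentCkvOffset(std::size_t token, std::size_t i, std::size_t kvLoraRank, std::size_t ropeDim)
{
    return token * (kvLoraRank + ropeDim) + i;
}

constexpr std::size_t latentKpeOffset(std::size_t token, std::size_t j, std::size_t kvLoraRank, std::size_t ropeDim)
{
    return token * (kvLoraRank + ropeDim) + kvLoraRank + j;
}
```

For example, with `kv_lora_rank = 512` and `qk_rope_head_dim = 64`, each latent row spans 576 values. The k_pe slice of token 2 then starts at 2 * 576 + 512 = 1664.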
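The names `cu_q_seqlens` and `cu_kv_seqlens` suggest the cumulative-sequence-length arrays common in packed (variable-length) FMHA kernels. The sketch below assumes the usual exclusive prefix sum with `batch_size + 1` entries and `cu[0] == 0`. `cumulativeSeqLens` is a hypothetical helper, not a TensorRT-LLM function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Builds batch_size + 1 cumulative offsets from per-sequence lengths, so that
// sequence i occupies tokens [cu[i], cu[i + 1]) of a packed token buffer.
inline std::vector<int> cumulativeSeqLens(std::vector<int32_t> const& seqLens)
{
    std::vector<int> cu(seqLens.size() + 1, 0);
    for (std::size_t i = 0; i < seqLens.size(); ++i)
    {
        cu[i + 1] = cu[i] + seqLens[i];
    }
    return cu;
}
```

With lengths `{3, 1, 4}`, the offsets are `{0, 3, 4, 8}`, and the last entry equals the total token count.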
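`cos_sin_cache` stores rotary embeddings as `float2` values indexed by position. The sketch below applies one cache row to a rope vector such as `q_pe` or `k_pe`, under several assumptions:

- `x` holds the cosine and `y` the sine.
- Each row holds `rope_dim / 2` pairs.
- Elements `(2i, 2i + 1)` are rotated together (interleaved convention).
- Frequencies are plain RoPE `base^(-2i / rope_dim)`, without YaRN scaling.

`Float2Sketch` stands in for CUDA's `float2` so the code builds without the CUDA toolkit. All names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Host-side stand-in for CUDA's float2.
struct Float2Sketch
{
    float x; // cos (assumption)
    float y; // sin (assumption)
};

// cos/sin for position `pos` and pair index `i`, using plain RoPE
// frequencies base^(-2i / ropeDim); YaRN-style scaling is ignored.
inline Float2Sketch makeCosSin(int pos, int i, int ropeDim, float base)
{
    float const invFreq = std::pow(base, -2.0f * static_cast<float>(i) / static_cast<float>(ropeDim));
    float const angle = static_cast<float>(pos) * invFreq;
    return {std::cos(angle), std::sin(angle)};
}

// Rotates the pair (a, b) by the angle whose cos/sin are stored in cs.
inline Float2Sketch rotatePair(float a, float b, Float2Sketch cs)
{
    return {a * cs.x - b * cs.y, a * cs.y + b * cs.x};
}

// Applies RoPE in place to a vector of length ropeDim, pairing elements
// (2i, 2i + 1). cosSinRow holds ropeDim / 2 entries for one position.
inline void applyRopeInterleaved(float* v, Float2Sketch const* cosSinRow, std::size_t ropeDim)
{
    for (std::size_t i = 0; i < ropeDim / 2; ++i)
    {
        Float2Sketch const r = rotatePair(v[2 * i], v[2 * i + 1], cosSinRow[i]);
        v[2 * i] = r.x;
        v[2 * i + 1] = r.y;
    }
}
```

A NeoX-style layout would instead pair element `i` with `i + rope_dim / 2`. Only the index expressions in `applyRopeInterleaved` would change.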
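`MlaParams` has several scale fields: `dequant_scale_q`, `dequant_scale_kv`, `bmm1_scale`, and `host_bmm1_scale`. The commit message also notes that "there are two bmm1 scales". Together these suggest a BMM1 scale that folds FP8 dequantization into the softmax scale. The sketch below rests on two assumptions, neither confirmed by this header:

- The combined scale is `dequant_q * dequant_kv / sqrt(qk_nope_head_dim + qk_rope_head_dim)`, ignoring any YaRN mscale.
- The second value is that scale multiplied by log2(e), for kernels that evaluate softmax with `exp2`.

`computeBmm1Scales` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical pair of BMM1 scales.
struct Bmm1ScalesSketch
{
    float linear; // multiplies the raw FP8 Q.K^T accumulator
    float log2e;  // same scale pre-multiplied by log2(e), for exp2-based softmax
};

// qkHeadDim is assumed to be qk_nope_head_dim + qk_rope_head_dim.
inline Bmm1ScalesSketch computeBmm1Scales(float dequantScaleQ, float dequantScaleKv, int qkHeadDim)
{
    float const softmaxScale = 1.0f / std::sqrt(static_cast<float>(qkHeadDim));
    float const linear = dequantScaleQ * dequantScaleKv * softmaxScale;
    // exp(x) == exp2(x * log2(e)), so the log2e variant lets the kernel call exp2.
    constexpr float kLog2e = 1.4426950408889634f;
    return {linear, linear * kLog2e};
}
```

Folding the dequantization factors into one scalar lets the kernel scale the FP8 Q.K^T accumulator with a single multiply before the softmax.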