Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
* fp8 kv + bf16 ctx MLA + fp8 gen MLA
Use BF16 for context MLA.
mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together (a consistency-check sketch follows these notes).
Allow mSM==90 for mFP8GenerationMLA==true.
For FMHA, dataTypeKv should be FP8.
For FP8 MLA generation, the output is still in BF16.
Refine debug info for FMHA kernel metadata.
Use inputType, outputType, and SM together to hash the kernel list (a lookup-key sketch follows the commit trailers).
Add FP8 MLA generation FMHA kernel.
Special workaround (WAR) for NUM_COMPUTE_GROUPS in the MLA generation kernel.
Move the implementation of fused_multihead_attention_v2.h into a .cpp file and print some debug info if checkIfKernelExist fails.
Refine debug info in fused_multihead_attention_v2.cpp.
Correct FP8 MLA metadata.
New kernel provided by Yuxin, which outputs BF16.
The smem size was not set correctly, which would lead to illegal memory access.
Yuxin fixed the error in the FMHA MLA kernel: previously the BF16 output wasn't written correctly; some parts were written repeatedly while others were left untouched.
There are two bmm1 scales that should be set correctly.
New kernel generated by Yuxin.
Modifications to common/attentionOp for FP8 MLA on Hopper using FMHA.
Not necessary: if mFP8GenerationMLA is set, is_fp8_out is false, so mFP8ContextFMHA is false.
Skip a check in fmhaDispatcher.
Modifications in fmhaRunner:
- Debug dump.
- Guard a lot of flag setting with if (!isFP8GenerationMLA).
- TMA descriptor modification for qo (by Yuxin).
Clean up debug output.
Clean up the o TMA descriptor modifications.
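For reference, a minimal consistency-check sketch of the constraints listed above. It is not the actual attentionOp code; the free function and its parameters are hypothetical, and only the flag names (mFP8GenerationMLA, mFP8ContextFMHA, mSM) come from these notes.

#include "tensorrt_llm/common/assert.h"

// Hypothetical helper illustrating the notes above; not part of attentionOp.
inline void validateFP8MLAConfig(bool fp8GenerationMLA, bool fp8ContextFMHA, int sm)
{
    // FP8 generation MLA and FP8 context FMHA must not be enabled together:
    // the context MLA phase stays in BF16, only the generation path uses FP8.
    TLLM_CHECK_WITH_INFO(!(fp8GenerationMLA && fp8ContextFMHA),
        "mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together.");

    if (fp8GenerationMLA)
    {
        // SM 90 (Hopper) is an allowed target for FP8 generation MLA; the KV cache
        // is read as FP8 while the attention output stays in BF16.
        TLLM_CHECK_WITH_INFO(sm >= 90, "FP8 generation MLA expects SM 90 or newer.");
    }
}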
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Resolve conflicts.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Apply the patch of FP8 FlashMLA and resolve conflicts.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Fix compilation error.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Fix compile error.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Pick Blackwell support.
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
* Add copyright notice to fused_multihead_attention_v2.cpp.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Add license.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Add missing license.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Exclude building flashMLA kernels under sm90.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Revert "Exclude building flashMLA kernels under sm90."
This reverts commit f0c859d459.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Use macro to skip compiling FlashMLA for non sm90 targets.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
---------
Signed-off-by: Bo Li <bobboli0202@gmail.com>
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: Dylan Chen <ziqingc@nvidia.com>
Co-authored-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
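To illustrate the kernel-list hashing change mentioned above (inputType, outputType, and SM combined into one lookup key), here is a small hypothetical sketch; hashKernelKey and its packing scheme are assumptions, not the actual fused_multihead_attention_v2 code.

#include <cstdint>

// Hypothetical lookup key: pack SM, output type, and input type into disjoint
// bit ranges so that kernels differing only in output type (e.g. FP8 in,
// BF16 out for FP8 MLA generation) map to distinct entries.
inline uint64_t hashKernelKey(uint64_t inputType, uint64_t outputType, uint64_t sm)
{
    return (sm << 32) | (outputType << 16) | inputType;
}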
121 lines · 3.5 KiB · C++
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/assert.h"
#include <limits.h>
#include <stdint.h>
#include <string>

namespace tensorrt_llm
{
namespace kernels
{
enum Data_type
{
    DATA_TYPE_BOOL,
    DATA_TYPE_FP16,
    DATA_TYPE_FP32,
    DATA_TYPE_INT4,
    DATA_TYPE_INT8,
    DATA_TYPE_INT32,
    DATA_TYPE_BF16,
    DATA_TYPE_E2M1,
    DATA_TYPE_E4M3,
    DATA_TYPE_E5M2
};

static inline std::string data_type_to_string(Data_type dtype)
{
    switch (dtype)
    {
    case DATA_TYPE_BOOL: return "bool";
    case DATA_TYPE_FP16: return "fp16";
    case DATA_TYPE_FP32: return "fp32";
    case DATA_TYPE_INT4: return "int4";
    case DATA_TYPE_INT8: return "int8";
    case DATA_TYPE_INT32: return "int32";
    case DATA_TYPE_BF16: return "bf16";
    case DATA_TYPE_E2M1: return "e2m1";
    case DATA_TYPE_E4M3: return "e4m3";
    case DATA_TYPE_E5M2: return "e5m2";
    default: return std::to_string(static_cast<int>(dtype)) + " (unknown)";
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline size_t get_size_in_bits(Data_type dtype)
{
    switch (dtype)
    {
    case DATA_TYPE_FP32: return 32;
    case DATA_TYPE_FP16: return 16;
    case DATA_TYPE_INT32: return 32;
    case DATA_TYPE_INT8: return 8;
    case DATA_TYPE_BF16: return 16;
    case DATA_TYPE_E2M1: return 4;
    case DATA_TYPE_E4M3: return 8;
    case DATA_TYPE_E5M2: return 8;
    default: TLLM_CHECK_WITH_INFO(false, "FMHA Data Type is not supported."); return 0;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline size_t get_size_in_bytes(size_t n, Data_type dtype)
{
    switch (dtype)
    {
    case DATA_TYPE_FP32: return n * 4;
    case DATA_TYPE_FP16: return n * 2;
    case DATA_TYPE_INT32: return n * 4;
    case DATA_TYPE_INT8: return n;
    case DATA_TYPE_BF16: return n * 2;
    case DATA_TYPE_E2M1: TLLM_CHECK_WITH_INFO(n % 2 == 0, "Not supported."); return n / 2;
    case DATA_TYPE_E4M3: return n;
    case DATA_TYPE_E5M2: return n;
    default: TLLM_CHECK_WITH_INFO(false, "FMHA Data Type is not supported."); return 0;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline size_t get_size_in_bytes(Data_type dtype)
{
    return get_size_in_bytes(1, dtype);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr int32_t kSM_70 = 70;
constexpr int32_t kSM_72 = 72;
constexpr int32_t kSM_75 = 75;
constexpr int32_t kSM_80 = 80;
constexpr int32_t kSM_86 = 86;
constexpr int32_t kSM_89 = 89;
constexpr int32_t kSM_90 = 90;
constexpr int32_t kSM_100 = 100;
constexpr int32_t kSM_120 = 120;

////////////////////////////////////////////////////////////////////////////////////////////////////

static constexpr int kIdxScaleSoftmaxPtr = 0;
static constexpr int kIdxScaleSoftmaxLog2Ptr = 1;

} // namespace kernels
} // namespace tensorrt_llm
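A small, hypothetical usage example of the helpers above (assuming this header is included; the snippet is not part of the file): it compares the storage cost of a KV-cache block stored as BF16 versus FP8 (E4M3).

#include <cstdio>
// #include the header above (path omitted here)

int main()
{
    using namespace tensorrt_llm::kernels;

    // 1024 elements stored as BF16 (2 bytes each) vs. FP8 E4M3 (1 byte each).
    size_t const numElts = 1024;
    size_t const bf16Bytes = get_size_in_bytes(numElts, DATA_TYPE_BF16); // 2048
    size_t const fp8Bytes = get_size_in_bytes(numElts, DATA_TYPE_E4M3);  // 1024

    std::printf("%s: %zu bytes, %s: %zu bytes\n",
        data_type_to_string(DATA_TYPE_BF16).c_str(), bf16Bytes,
        data_type_to_string(DATA_TYPE_E4M3).c_str(), fp8Bytes);
    return 0;
}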