TensorRT-LLMs/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelUtils.h

/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include <cstdint>

namespace tensorrt_llm
{
namespace kernels
{
///////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to convert float2 to half, bfloat16, and e4m3.
inline __device__ uint32_t convert_float2_to_half(float a, float b)
{
    uint32_t output;
    reinterpret_cast<__half2&>(output) = __float22half2_rn(make_float2(a, b));
    return output;
}

inline __device__ uint32_t convert_float2_to_bfloat16(float a, float b)
{
    uint32_t output;
    reinterpret_cast<__nv_bfloat162&>(output) = __float22bfloat162_rn(make_float2(a, b));
    return output;
}

inline __device__ uint32_t convert_float4_to_e4m3(float a, float b, float c, float d)
{
    uint32_t output;
    reinterpret_cast<__nv_fp8x4_e4m3&>(output) = __nv_fp8x4_e4m3(make_float4(a, b, c, d));
    return output;
}
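
// Illustrative sketch (hypothetical helper, not part of the original header): the converters above
// pack the narrow values into plain 32-bit registers so that a group of converted elements can be
// written with a single vectorized store.
inline __device__ void exampleStoreFloat4AsHalf(char* dst, float4 v)
{
    // Two packed __half2 values -> one 64-bit store.
    uint2 packed;
    packed.x = convert_float2_to_half(v.x, v.y);
    packed.y = convert_float2_to_half(v.z, v.w);
    *reinterpret_cast<uint2*>(dst) = packed;
}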
///////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions for float2 mul and fma.
// On SM 10.0+ (Blackwell), these lower to the packed f32x2 instructions; otherwise they fall back
// to scalar arithmetic on the two lanes.
inline __device__ void mul(float2& c, float2 const& a, float2 const& b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
    asm volatile("mul.f32x2 %0, %1, %2;\n"
                 : "=l"(reinterpret_cast<uint64_t&>(c))
                 : "l"(reinterpret_cast<uint64_t const&>(a)), "l"(reinterpret_cast<uint64_t const&>(b)));
#else
    c.x = a.x * b.x;
    c.y = a.y * b.y;
#endif
}

inline __device__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
    asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
                 : "=l"(reinterpret_cast<uint64_t&>(d))
                 : "l"(reinterpret_cast<uint64_t const&>(a)), "l"(reinterpret_cast<uint64_t const&>(b)),
                   "l"(reinterpret_cast<uint64_t const&>(c)));
#else
    d.x = a.x * b.x + c.x;
    d.y = a.y * b.y + c.y;
#endif
}
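
// Illustrative sketch (hypothetical helper, not part of the original header): rescale-and-accumulate
// of a pair of lanes using the packed helpers above, i.e. acc = acc * oldScale + val * newScale.
inline __device__ void exampleRescaleAccumulate(float2& acc, float2 val, float oldScale, float newScale)
{
    mul(acc, acc, make_float2(oldScale, oldScale));
    fma(acc, val, make_float2(newScale, newScale), acc);
}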
////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert NumElts floats to Dtype and store them to global memory.
// The unspecialized templates fail at compile time via the dependent static_assert.
template <typename Dtype, int32_t NumElts>
inline __device__ void convertAndStoreToGmem(char* gmemPtr, float (&input)[NumElts])
{
    static_assert(sizeof(Dtype) == 0, "Not implemented.");
}

// Overload that additionally takes an output scale-factor pointer and scale; no specialization is
// provided in this header.
template <typename Dtype, int32_t NumElts>
inline __device__ void convertAndStoreToGmem(
    char* gmemPtr, char* oSfPtr, float (&input)[NumElts], float sfScale, bool isValidRow)
{
    static_assert(sizeof(Dtype) == 0, "Not implemented.");
}
template <>
inline __device__ void convertAndStoreToGmem<__half, 8>(char* gmemPtr, float (&input)[8])
{
    uint4 output;
    output.x = convert_float2_to_half(input[0], input[1]);
    output.y = convert_float2_to_half(input[2], input[3]);
    output.z = convert_float2_to_half(input[4], input[5]);
    output.w = convert_float2_to_half(input[6], input[7]);
    *reinterpret_cast<uint4*>(gmemPtr) = output;
}

template <>
inline __device__ void convertAndStoreToGmem<__nv_bfloat16, 8>(char* gmemPtr, float (&input)[8])
{
    uint4 output;
    output.x = convert_float2_to_bfloat16(input[0], input[1]);
    output.y = convert_float2_to_bfloat16(input[2], input[3]);
    output.z = convert_float2_to_bfloat16(input[4], input[5]);
    output.w = convert_float2_to_bfloat16(input[6], input[7]);
    *reinterpret_cast<uint4*>(gmemPtr) = output;
}

template <>
inline __device__ void convertAndStoreToGmem<__nv_fp8_e4m3, 8>(char* gmemPtr, float (&input)[8])
{
    uint2 output;
    output.x = convert_float4_to_e4m3(input[0], input[1], input[2], input[3]);
    output.y = convert_float4_to_e4m3(input[4], input[5], input[6], input[7]);
    *reinterpret_cast<uint2*>(gmemPtr) = output;
}
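
// Illustrative usage sketch (hypothetical helper, not part of the original header): gmemPtr is a
// byte pointer, so the caller is expected to compute the element offset in bytes. OutputType and
// elemIdx are assumptions for the example.
template <typename OutputType>
inline __device__ void exampleStoreEightElts(char* basePtr, int64_t elemIdx, float (&acc)[8])
{
    convertAndStoreToGmem<OutputType, 8>(basePtr + elemIdx * sizeof(OutputType), acc);
}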
////////////////////////////////////////////////////////////////////////////////////////////////////

// Unpack NumElts Dtype values from a uint4 register, convert them to float, and accumulate into
// output as output = output * scale0 + input * scale1.
template <typename Dtype, int32_t NumElts>
inline __device__ void convertToFloatAndAccumulate(float (&output)[NumElts], uint4 input, float scale0, float scale1)
{
    static_assert(sizeof(Dtype) == 0, "Not implemented.");
}
template <>
inline __device__ void convertToFloatAndAccumulate<__half, 8>(
    float (&output)[8], uint4 input, float scale0, float scale1)
{
    float2 scales0 = make_float2(scale0, scale0);
    float2 scales1 = make_float2(scale1, scale1);
#pragma unroll
    for (int32_t ii = 0; ii < 4; ++ii)
    {
        float2 a = __half22float2(reinterpret_cast<__half2*>(&input)[ii]);
        float2& c = reinterpret_cast<float2(&)[4]>(output)[ii];
        // FFMA2: output = input * scale1 + output * scale0.
        mul(c, c, scales0);
        fma(c, a, scales1, c);
    }
}

template <>
inline __device__ void convertToFloatAndAccumulate<__nv_bfloat16, 8>(
    float (&output)[8], uint4 input, float scale0, float scale1)
{
    float2 scales0 = make_float2(scale0, scale0);
    float2 scales1 = make_float2(scale1, scale1);
#pragma unroll
    for (int32_t ii = 0; ii < 4; ++ii)
    {
        float2 a = __bfloat1622float2(reinterpret_cast<__nv_bfloat162*>(&input)[ii]);
        float2& c = reinterpret_cast<float2(&)[4]>(output)[ii];
        // FFMA2: output = input * scale1 + output * scale0.
        mul(c, c, scales0);
        fma(c, a, scales1, c);
    }
}
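
// Illustrative sketch (hypothetical helper, not part of the original header): merging a partial
// result stored as 8 packed __half values back into a float accumulator, e.g. when combining
// per-split partial outputs in a separate reduction pass.
inline __device__ void exampleMergePartial(float (&acc)[8], uint4 packedPartial, float accScale, float partialScale)
{
    // acc = acc * accScale + partial * partialScale.
    convertToFloatAndAccumulate<__half, 8>(acc, packedPartial, accScale, partialScale);
}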
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernels
} // namespace tensorrt_llm