/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

namespace tensorrt_llm
{
namespace kernels
{

///////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to convert float2 to half, bfloat16, and e4m3.
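// Each helper packs its converted results into a 32-bit word: two 16-bit values (half or
// bfloat16) per uint32_t, or four 8-bit e4m3 values per uint32_t, so that eight converted
// elements can later be written with a single vectorized store.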

inline __device__ uint32_t convert_float2_to_half(float a, float b)
{
    uint32_t output;
    reinterpret_cast<__half2&>(output) = __float22half2_rn(make_float2(a, b));
    return output;
}

inline __device__ uint32_t convert_float2_to_bfloat16(float a, float b)
{
    uint32_t output;
    reinterpret_cast<__nv_bfloat162&>(output) = __float22bfloat162_rn(make_float2(a, b));
    return output;
}

inline __device__ uint32_t convert_float4_to_e4m3(float a, float b, float c, float d)
{
    uint32_t output;
    reinterpret_cast<__nv_fp8x4_e4m3&>(output) = __nv_fp8x4_e4m3(make_float4(a, b, c, d));
    return output;
}

///////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions for float2 mul and fma.
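// On devices compiled with __CUDA_ARCH__ >= 1000 (compute capability 10.0, Blackwell), these
// wrap the packed PTX instructions mul.f32x2 / fma.rn.f32x2, which operate on two FP32 lanes
// held in a 64-bit register pair; on older architectures the same math is done per element.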

inline __device__ void mul(float2& c, float2 const& a, float2 const& b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
    asm volatile("mul.f32x2 %0, %1, %2;\n"
                 : "=l"(reinterpret_cast<uint64_t&>(c))
                 : "l"(reinterpret_cast<uint64_t const&>(a)), "l"(reinterpret_cast<uint64_t const&>(b)));
#else
    c.x = a.x * b.x;
    c.y = a.y * b.y;
#endif
}

inline __device__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
    asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
                 : "=l"(reinterpret_cast<uint64_t&>(d))
                 : "l"(reinterpret_cast<uint64_t const&>(a)), "l"(reinterpret_cast<uint64_t const&>(b)),
                   "l"(reinterpret_cast<uint64_t const&>(c)));
#else
    d.x = a.x * b.x + c.x;
    d.y = a.y * b.y + c.y;
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Primary templates: the generic versions are intentionally unusable (the static_assert always
// fails when instantiated); concrete data types are supported through explicit specializations.
template <typename Dtype, int32_t NumElts>
inline __device__ void convertAndStoreToGmem(char* gmemPtr, float (&input)[NumElts])
{
    static_assert(sizeof(Dtype) == 0, "Not implemented.");
}

template <typename Dtype, int32_t NumElts>
inline __device__ void convertAndStoreToGmem(
    char* gmemPtr, char* oSfPtr, float (&input)[NumElts], float sfScale, bool isValidRow)
{
    static_assert(sizeof(Dtype) == 0, "Not implemented.");
}
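
// The 8-element specializations below perform a single vectorized store, so gmemPtr must be
// suitably aligned: 16 bytes for the uint4 stores, 8 bytes for the uint2 store.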
template <>
inline __device__ void convertAndStoreToGmem<__half, 8>(char* gmemPtr, float (&input)[8])
{
    uint4 output;
    output.x = convert_float2_to_half(input[0], input[1]);
    output.y = convert_float2_to_half(input[2], input[3]);
    output.z = convert_float2_to_half(input[4], input[5]);
    output.w = convert_float2_to_half(input[6], input[7]);
    *reinterpret_cast<uint4*>(gmemPtr) = output;
}

template <>
inline __device__ void convertAndStoreToGmem<__nv_bfloat16, 8>(char* gmemPtr, float (&input)[8])
{
    uint4 output;
    output.x = convert_float2_to_bfloat16(input[0], input[1]);
    output.y = convert_float2_to_bfloat16(input[2], input[3]);
    output.z = convert_float2_to_bfloat16(input[4], input[5]);
    output.w = convert_float2_to_bfloat16(input[6], input[7]);
    *reinterpret_cast<uint4*>(gmemPtr) = output;
}

template <>
inline __device__ void convertAndStoreToGmem<__nv_fp8_e4m3, 8>(char* gmemPtr, float (&input)[8])
{
    uint2 output;
    output.x = convert_float4_to_e4m3(input[0], input[1], input[2], input[3]);
    output.y = convert_float4_to_e4m3(input[4], input[5], input[6], input[7]);
    *reinterpret_cast<uint2*>(gmemPtr) = output;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Dtype, int32_t NumElts>
inline __device__ void convertToFloatAndAccumulate(float (&output)[NumElts], uint4 input, float scale0, float scale1)
{
    static_assert(sizeof(Dtype) == 0, "Not implemented.");
}

template <>
inline __device__ void convertToFloatAndAccumulate<__half, 8>(
    float (&output)[8], uint4 input, float scale0, float scale1)
{
    float2 scales0 = make_float2(scale0, scale0);
    float2 scales1 = make_float2(scale1, scale1);
#pragma unroll
    for (int32_t ii = 0; ii < 4; ++ii)
    {
        float2 a = __half22float2(reinterpret_cast<__half2*>(&input)[ii]);
        float2& c = reinterpret_cast<float2(&)[4]>(output)[ii];
        // FFMA2: output = input * scale1 + output * scale0.
        mul(c, c, scales0);
        fma(c, a, scales1, c);
    }
}

template <>
inline __device__ void convertToFloatAndAccumulate<__nv_bfloat16, 8>(
    float (&output)[8], uint4 input, float scale0, float scale1)
{
    float2 scales0 = make_float2(scale0, scale0);
    float2 scales1 = make_float2(scale1, scale1);
#pragma unroll
    for (int32_t ii = 0; ii < 4; ++ii)
    {
        float2 a = __bfloat1622float2(reinterpret_cast<__nv_bfloat162*>(&input)[ii]);
        float2& c = reinterpret_cast<float2(&)[4]>(output)[ii];
        // FFMA2: output = input * scale1 + output * scale0.
        mul(c, c, scales0);
        fma(c, a, scales1, c);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
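
// Illustrative usage sketch: shows how the helpers above can be combined to rescale a running
// float accumulator by a packed vector of eight __half values and write the result back to
// global memory. The function name and parameters are hypothetical and not part of the kernels.
inline __device__ void exampleRescaleAndStoreHalf(
    char* gmemPtr, uint4 packedHalves, float accScale, float inputScale)
{
    // Running accumulator of eight floats (zero-initialized here only for illustration).
    float acc[8];
#pragma unroll
    for (int32_t ii = 0; ii < 8; ++ii)
    {
        acc[ii] = 0.f;
    }
    // acc = acc * accScale + unpack(packedHalves) * inputScale.
    convertToFloatAndAccumulate<__half, 8>(acc, packedHalves, accScale, inputScale);
    // Convert the eight floats back to __half and store them with one 16-byte vectorized write.
    convertAndStoreToGmem<__half, 8>(gmemPtr, acc);
}

////////////////////////////////////////////////////////////////////////////////////////////////////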

} // namespace kernels
} // namespace tensorrt_llm