/*
* SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <climits>     // INT_MIN, used by the introspection helpers below.
#include <cuda_fp16.h>
#include <stdint.h>
#include <stdlib.h>
#include <type_traits> // std::integral_constant.
#if defined(__CLANGD__)
#include <__clang_cuda_builtin_vars.h>
#include <__clang_cuda_math.h>
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
// Include warpgroup-related instructions, used by SM90.
#include <fmha/hopper/utils_warpgroup.h>
// Include GMMA-related instructions, used by SM90.
#include <fmha/hopper/utils_gmma.h>
// Include TMA-related instructions, used by SM90.
#include <fmha/hopper/utils_tma.h>
#include "fmha/numeric_types.h"
#define FP32_I2F_MAGIC_NUMBER 12582912.f
#define FP32_I2F_MAGIC_NUMBER_HEX 0x4b400000
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void* ptr);
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace introspection
{
template <int... Ns>
struct Unpack;
template <int N>
struct Unpack<N>
{
// If we simply static_assert(false), the compiler will not emit the template params upon failure.
static_assert(N < INT_MIN, "");
using Type = std::integral_constant<int, N>;
};
template <int N, int... Ns>
struct Unpack<N, Ns...>
{
using Type = Unpack<N, Ns...>;
using Unpack_first = typename Unpack<N>::Type;
using Unpack_remaining = typename Unpack<Ns...>::Type;
};
} // namespace introspection
// Example usage:
//
// Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER> foo;
//
// or
//
// Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER>{}.foo();
//
// Output by nvcc:
//
// ./src/fmha/gmem_tile_qkv_packed.h(70): error: static assertion failed with ""
// detected during:
// instantiation of class "fmha::v2::Unpack<N> [with N=1]"
// (77): here
// instantiation of class "fmha::v2::Unpack<N, Ns...> [with N=1, Ns=<2, 0>]"
// (84): here
// instantiation of class "fmha::v2::Inspect_ns<Ns...> [with Ns=<1, 2, 0>]"
// (143): here
template <int... Ns>
struct Inspect_ns
{
using Type = typename introspection::Unpack<Ns...>::Type;
};
// Can be used alongside static_assert() to figure out the conditions under which the assertion failed.
// Example:
//
// Cond_inspect_ns< (int)ROWS >= (int)ROWS_PER_LDG, ROWS, ROWS_PER_LDG> foo;
//
// Output by nvcc (when condition is not met):
//
// ./src/fmha/utils.h(163): error: static assertion failed with ""
// detected during:
// instantiation of class "Cond_inspect_ns<COND, Ns...> [with COND=false, Ns=<32, 64>]"
template <bool COND, int... Ns>
struct Cond_inspect_ns
{
static_assert(COND, "");
};
// Example:
//
// Inspect_type<Mma_tile_p>{}.foo();
//
// or
//
// Inspect_type<Mma_tile_p> foo;
//
// Output by nvcc:
//
// ./src/fmha/utils.h(189): error: class "fmha::Ampere_hmma_tile<fmha::Cta_tile_<fmha::Ampere, 64, 128, 64, 128, 256,
// 4, 1, 1>, 16>" has no member "Dummy"
// detected during:
// instantiation of class "Inspect_type<T> [with T=fmha::Ampere_hmma_tile<fmha::Cta_tile_<fmha::Ampere,
// 64, 128, 64, 128, 256, 4, 1, 1>, 16>]"
template <typename T>
struct Inspect_type
{
// Purposefully trigger error by referencing non-existent T::Dummy
using Dummy = typename T::Dummy;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Row
{
static constexpr bool COL = false;
static constexpr bool ROW = true;
};
struct Col
{
static constexpr bool COL = true;
static constexpr bool ROW = false;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int M, int N>
struct Round_up
{
enum
{
VALUE = (M + N - 1) / N * N
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N_, int H_, int W_>
struct Tile_nhw
{
enum
{
N = N_,
H = H_,
W = W_
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int M, bool = (M & (M - 1)) == 0>
struct Next_power_of_two
{
};
template <int M>
struct Next_power_of_two<M, true>
{
enum
{
VALUE = M
};
};
template <>
struct Next_power_of_two<3, false>
{
enum
{
VALUE = 4
};
};
template <>
struct Next_power_of_two<5, false>
{
enum
{
VALUE = 8
};
};
template <>
struct Next_power_of_two<6, false>
{
enum
{
VALUE = 8
};
};
template <>
struct Next_power_of_two<7, false>
{
enum
{
VALUE = 8
};
};
template <>
struct Next_power_of_two<9, false>
{
enum
{
VALUE = 16
};
};
template <>
struct Next_power_of_two<10, false>
{
enum
{
VALUE = 16
};
};
template <>
struct Next_power_of_two<11, false>
{
enum
{
VALUE = 16
};
};
template <>
struct Next_power_of_two<12, false>
{
enum
{
VALUE = 16
};
};
template <>
struct Next_power_of_two<13, false>
{
enum
{
VALUE = 16
};
};
template <>
struct Next_power_of_two<14, false>
{
enum
{
VALUE = 16
};
};
template <>
struct Next_power_of_two<15, false>
{
enum
{
VALUE = 16
};
};
template <>
struct Next_power_of_two<24, false>
{
enum
{
VALUE = 32
};
};
template <>
struct Next_power_of_two<40, false>
{
enum
{
VALUE = 64
};
};
template <>
struct Next_power_of_two<48, false>
{
enum
{
VALUE = 64
};
};
template <>
struct Next_power_of_two<72, false>
{
enum
{
VALUE = 128
};
};
template <>
struct Next_power_of_two<80, false>
{
enum
{
VALUE = 128
};
};
template <>
struct Next_power_of_two<96, false>
{
enum
{
VALUE = 128
};
};
template <>
struct Next_power_of_two<104, false>
{
enum
{
VALUE = 128
};
};
template <>
struct Next_power_of_two<112, false>
{
enum
{
VALUE = 128
};
};
template <>
struct Next_power_of_two<144, false>
{
enum
{
VALUE = 256
};
};
template <>
struct Next_power_of_two<160, false>
{
enum
{
VALUE = 256
};
};
template <>
struct Next_power_of_two<192, false>
{
enum
{
VALUE = 256
};
};
template <>
struct Next_power_of_two<576, false>
{
enum
{
VALUE = 1024
};
};
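// Example:
//
// Next_power_of_two<64>::VALUE // 64: powers of two map to themselves.
// Next_power_of_two<48>::VALUE // 64: resolved by the explicit specialization above.
//
// Only the power-of-two case is generic; any other value must match one of the explicit
// specializations above, otherwise VALUE is not defined and compilation fails.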
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N, bool = (N & (N - 1)) == 0>
struct Prev_power_of_two
{
};
template <int N>
struct Prev_power_of_two<N, true>
{
enum
{
VALUE = N
};
};
template <>
struct Prev_power_of_two<3, false>
{
enum
{
VALUE = 2
};
};
template <>
struct Prev_power_of_two<5, false>
{
enum
{
VALUE = 4
};
};
template <>
struct Prev_power_of_two<6, false>
{
enum
{
VALUE = 4
};
};
template <>
struct Prev_power_of_two<7, false>
{
enum
{
VALUE = 4
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int BYTES_PER_ROW, int SKEW>
struct Compute_skew
{
// The size of a transaction.
enum
{
BYTES_PER_TRX = 128
};
// The remainder of the row without skew.
enum
{
REMAINDER = BYTES_PER_ROW % BYTES_PER_TRX
};
// The value.
enum
{
VALUE = REMAINDER <= SKEW ? SKEW - REMAINDER : BYTES_PER_TRX + SKEW - REMAINDER
};
// Make sure the math works ;)
static_assert((BYTES_PER_ROW + VALUE) % BYTES_PER_TRX == SKEW, "");
};
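// Example (worked out by hand):
//
// Compute_skew<144, 16>::VALUE // REMAINDER = 144 % 128 = 16 <= SKEW, so VALUE = 16 - 16 = 0.
// Compute_skew<160, 16>::VALUE // REMAINDER = 160 % 128 = 32 >  SKEW, so VALUE = 128 + 16 - 32 = 112.
//
// In both cases (BYTES_PER_ROW + VALUE) % 128 == SKEW, which is what the static_assert verifies.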
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int BYTES_PER_ROW>
struct Compute_skew<BYTES_PER_ROW, 128>
{
// No skew!
enum
{
VALUE = 0
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int M, int N>
struct Div_up
{
enum
{
VALUE = (M + N - 1) / N
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int A, int B>
struct Max
{
enum
{
VALUE = A >= B ? A : B
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int A, int B, int C>
struct Max_3
{
enum
{
VALUE = Max<Max<A, B>::VALUE, C>::VALUE
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int A, int B>
struct Min
{
enum
{
VALUE = A <= B ? A : B
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int SIZE_IN_BYTES>
struct Uint_from_size_in_bytes
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Uint_from_size_in_bytes<1>
{
using Type = uint8_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Uint_from_size_in_bytes<2>
{
using Type = uint16_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Uint_from_size_in_bytes<4>
{
using Type = uint32_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Uint_from_size_in_bytes<8>
{
using Type = uint2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Uint_from_size_in_bytes<16>
{
using Type = uint4;
};
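// Example:
//
// Uint_from_size_in_bytes<4>::Type  // uint32_t.
// Uint_from_size_in_bytes<16>::Type // uint4, e.g. the register type of one 16B vectorized load.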
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int WARPS_M, int WARPS_N, int WARPS_K>
struct Warp_masks
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Warp_masks<8, 1, 1>
{
enum
{
M = 0xe0,
N = 0x00,
K = 0x00
};
};
template <>
struct Warp_masks<4, 2, 1>
{
enum
{
M = 0x60,
N = 0x80,
K = 0x00
};
};
template <>
struct Warp_masks<4, 1, 2>
{
enum
{
M = 0x60,
N = 0x00,
K = 0x80
};
};
template <>
struct Warp_masks<4, 1, 1>
{
enum
{
M = 0x60,
N = 0x00,
K = 0x00
};
};
template <>
struct Warp_masks<2, 4, 1>
{
enum
{
M = 0x20,
N = 0xc0,
K = 0x00
};
};
template <>
struct Warp_masks<2, 2, 2>
{
enum
{
M = 0x20,
N = 0x40,
K = 0x80
};
};
template <>
struct Warp_masks<2, 2, 1>
{
enum
{
M = 0x20,
N = 0x40,
K = 0x00
};
};
template <>
struct Warp_masks<2, 1, 2>
{
enum
{
M = 0x20,
N = 0x00,
K = 0x40
};
};
template <>
struct Warp_masks<2, 1, 1>
{
enum
{
M = 0x20,
N = 0x00,
K = 0x00
};
};
template <>
struct Warp_masks<1, 8, 1>
{
enum
{
M = 0x00,
N = 0xe0,
K = 0x00
};
};
template <>
struct Warp_masks<1, 4, 2>
{
enum
{
M = 0x00,
N = 0x60,
K = 0x80
};
};
template <>
struct Warp_masks<1, 4, 1>
{
enum
{
M = 0x00,
N = 0x60,
K = 0x00
};
};
template <>
struct Warp_masks<1, 2, 2>
{
enum
{
M = 0x00,
N = 0x20,
K = 0x40
};
};
template <>
struct Warp_masks<1, 2, 1>
{
enum
{
M = 0x00,
N = 0x20,
K = 0x00
};
};
template <>
struct Warp_masks<1, 1, 4>
{
enum
{
M = 0x00,
N = 0x00,
K = 0x60
};
};
template <>
struct Warp_masks<1, 1, 2>
{
enum
{
M = 0x00,
N = 0x00,
K = 0x20
};
};
template <>
struct Warp_masks<1, 1, 1>
{
enum
{
M = 0x00,
N = 0x00,
K = 0x00
};
};
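// Example (assumed use): each mask isolates the thread-index bits that identify a warp's
// position along one dimension of the CTA tile. With WARPS_M x WARPS_N x WARPS_K = 4 x 1 x 2:
//
// int warp_m = (tidx & Warp_masks<4, 1, 2>::M) / 32;  // bits 5..6 -> 0..3
// int warp_k = (tidx & Warp_masks<4, 1, 2>::K) / 128; // bit 7     -> 0..1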
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
inline __device__ __host__ T div_up(T m, T n)
{
return (m + n - 1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int clz(int x)
{
for (int i = 31; i >= 0; --i)
{
if ((1 << i) & x)
{
return 31 - i;
}
}
return 32;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int find_log_2(int x, bool round_up = false)
{
int a = 31 - clz(x);
if (round_up)
{
a += (x & (x - 1)) ? 1 : 0;
}
return a;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline void find_divisor(uint32_t& mul, uint32_t& shr, int x)
{
assert(x != 0);
if (x == 1)
{
// If dividing by 1, the reduced math doesn't work because mul_coeff would need to be 2^32,
// which doesn't fit into an unsigned int. The div() routine handles this special case
// separately.
mul = 0;
shr = 0;
}
else
{
// To express the division N/D in terms of a multiplication, what we first
// imagine is simply N*(1/D). However, 1/D will always evaluate to 0 (for D>1),
// so we need another way. There's nothing that says we have to use exactly
// the fraction 1/D; instead it could be any X/Y that reduces to 1/D (i.e.,
// Y=X*D), or at least to "close enough" to it. If we pick Y that is a power
// of two, then the N*(X/Y) can be N*X followed by a right-shift by some amount.
// The power of two we should pick should be at least 2^32, because in the
// div() routine we'll use umulhi(), which returns only the upper 32 bits --
// this being equivalent to a right-shift by 32. But we might want a higher
// power of two for better accuracy depending on the magnitude of the denominator.
// Once we've picked Y, then X [our mul_coeff value] is simply Y/D, rounding up,
// and we save shift_coeff as whatever further shift we have to do beyond
// what the umulhi() implies.
uint32_t p = 31 + find_log_2(x, true);
uint32_t m = (uint32_t) (((1ull << p) + (uint32_t) x - 1) / (uint32_t) x);
mul = m;
shr = p - 32;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void fast_divmod(int& div, int& mod, int x, int y, uint32_t mul, uint32_t shr)
{
if (y == 1)
{
div = x;
mod = 0;
}
else
{
div = __umulhi((uint32_t) x, mul) >> shr;
mod = x - div * y;
}
}
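// A minimal usage sketch (names assumed): precompute the magic pair once, e.g. on the host
// for a fixed divisor, then do division-free div/mod on the device.
//
// uint32_t mul, shr;
// find_divisor(mul, shr, num_heads);            // setup, e.g. in the kernel params
// ...
// int q, r;
// fast_divmod(q, r, idx, num_heads, mul, shr);  // q = idx / num_heads, r = idx % num_heads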
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b)
{
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t bfadd2(uint32_t a, uint32_t b)
{
uint32_t c;
// 0x3f803f80 is 1.0 in both bf16 halves; the add is emulated as fma(a, 1.0, b).
uint32_t one = 0x3f803f80;
asm volatile("fma.rn.bf16x2 %0, %1, %3, %2;\n" : "=r"(c) : "r"(a), "r"(b), "r"(one));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmax2(uint32_t a, uint32_t b)
{
uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("max.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
#else
asm volatile(
"{\n"
"\t .reg .f16x2 sela, selb;\n"
"\n"
"\t set.ge.f16x2.f16x2 sela, %1, %2;\n"
"\t set.gt.f16x2.f16x2 selb, %2, %1;\n"
"\n"
"\t mul.f16x2 %0, sela, %1;\n"
"\t fma.rn.f16x2 %0, selb, %2, %0;\n"
"}\n"
: "=r"(c)
: "r"(a), "r"(b));
#endif
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hmax4(uint2 a, uint2 b)
{
uint2 c;
c.x = hmax2(a.x, b.x);
c.y = hmax2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmax8(uint4 a, uint4 b)
{
uint4 c;
c.x = hmax2(a.x, b.x);
c.y = hmax2(a.y, b.y);
c.z = hmax2(a.z, b.z);
c.w = hmax2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b)
{
uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
#else
asm volatile(
"{\n"
"\t .reg .f16x2 sela, selb;\n"
"\n"
"\t set.le.f16x2.f16x2 sela, %1, %2;\n"
"\t set.lt.f16x2.f16x2 selb, %2, %1;\n"
"\n"
"\t mul.f16x2 %0, sela, %1;\n"
"\t fma.rn.f16x2 %0, selb, %2, %0;\n"
"}\n"
: "=r"(c)
: "r"(a), "r"(b));
#endif
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b)
{
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t bfmul2(uint32_t a, uint32_t b)
{
uint32_t c;
asm("{.reg .b32 c;\n"
" mov.b32 c, 0x80008000U;\n"
" fma.rn.bf16x2 %0,%1,%2,c;}\n"
: "=r"(c)
: "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hmul4(uint2 a, uint2 b)
{
uint2 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint4 a, uint4 b)
{
uint4 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
c.z = hmul2(a.z, b.z);
c.w = hmul2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint32_t a, uint4 b)
{
uint4 c;
c.x = hmul2(a, b.x);
c.y = hmul2(a, b.y);
c.z = hmul2(a, b.z);
c.w = hmul2(a, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Template function to support both half and bfloat16
template <typename Data_type>
inline __device__ uint32_t mul2(uint32_t a, uint32_t b)
{
return hmul2(a, b);
}
template <>
inline __device__ uint32_t mul2<bf16_t>(uint32_t a, uint32_t b)
{
return bfmul2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Template function to support both half and bfloat16
template <typename Data_type>
inline __device__ uint4 mul8(uint32_t a, uint4 b)
{
uint4 c;
c.x = hmul2(a, b.x);
c.y = hmul2(a, b.y);
c.z = hmul2(a, b.z);
c.w = hmul2(a, b.w);
return c;
}
template <>
inline __device__ uint4 mul8<bf16_t>(uint32_t a, uint4 b)
{
uint4 c;
c.x = bfmul2(a, b.x);
c.y = bfmul2(a, b.y);
c.z = bfmul2(a, b.z);
c.w = bfmul2(a, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hrelu2(uint32_t x)
{
uint32_t res;
uint32_t const zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
asm volatile(
"{\n"
"\t .reg .f16x2 sela;\n"
"\t set.gtu.u32.f16x2 sela, %1, %2;\n"
"\t and.b32 %0, sela, %1;\n"
"}\n"
: "=r"(res)
: "r"(x), "r"(zero));
#endif
return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t bfrelu2(uint32_t x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
uint32_t res;
uint32_t const zero = 0u;
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
return res;
#else
// Not implemented below SM80: return the input unchanged.
return x;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Template function to support both half and bfloat16
template <typename Data_type>
inline __device__ uint32_t relu2(uint32_t x)
{
return hrelu2(x);
}
template <>
inline __device__ uint32_t relu2<bf16_t>(uint32_t x)
{
return bfrelu2(x);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t habs2(uint32_t x)
{
uint32_t res;
asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// static inline __device__ uint32_t add_bias(uint32_t a, uint32_t bias, bool relu) {
// uint32_t c;
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// if( relu ) {
// uint32_t one = 0x3c003c00u;
// asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(c) : "r"(a), "r"(one), "r"(bias));
// } else {
// c = hadd2(a, bias);
// }
// #else
// c = hadd2(a, bias);
// if( relu ) {
// c = hrelu2(c);
// }
// #endif
// return c;
// }
////////////////////////////////////////////////////////////////////////////////////////////////////
// static inline __device__ uint2 add_bias(uint2 a, uint2 bias, bool relu) {
// uint2 dst;
// dst.x = add_bias(a.x, bias.x, relu);
// dst.y = add_bias(a.y, bias.y, relu);
// return dst;
// }
////////////////////////////////////////////////////////////////////////////////////////////////////
// static inline __device__ uint4 add_bias(uint4 a, uint4 bias, bool relu) {
// uint4 dst;
// dst.x = add_bias(a.x, bias.x, relu);
// dst.y = add_bias(a.y, bias.y, relu);
// dst.z = add_bias(a.z, bias.z, relu);
// dst.w = add_bias(a.w, bias.w, relu);
// return dst;
// }
////////////////////////////////////////////////////////////////////////////////////////////////////
// clamp float +inf/-inf
static inline __device__ float satfinite(float x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860
// bit representation of maximum value of float
uint32_t clamp_value = 0x7f7fffffu;
asm volatile("min.xorsign.abs.f32 %0, %0, %1;" : "+f"(x) : "r"(clamp_value));
return x;
#else
// bit representation of maximum and minimum value of float
uint32_t umax = 0x7f7fffffu;
uint32_t umin = 0xff7fffffu;
float out;
asm volatile("min.f32 %0, %1, %2;" : "=f"(out) : "f"(x), "r"(umax));
asm volatile("max.f32 %0, %0, %1;" : "+f"(out) : "r"(umin));
return out;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// clamp half2 +inf/-inf
static inline __device__ uint32_t satfinite_h2(uint32_t h2)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860
uint32_t out, clamp_value;
clamp_value = 0x7bff7bffu;
asm volatile("min.xorsign.abs.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(clamp_value));
return out;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
// bit representation of maximum and minimum value of half2
uint32_t umax = 0x7bff7bffu;
uint32_t umin = 0xfbfffbffu;
uint32_t out;
asm volatile("min.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(umax));
asm volatile("max.f16x2 %0, %0, %1;" : "+r"(out) : "r"(umin));
return out;
#else
// Take the absolute value of h2. It should map to |Rx| in SASS.
uint32_t p2;
asm volatile("abs.f16x2 %0, %1;" : "=r"(p2) : "r"(h2));
// Compute a mask for each fp16: 0xffff if +INF and 0x0000 otherwise.
uint32_t inf2 = 0x7c007c00u;
uint32_t mask;
asm volatile("set.eq.u32.f16x2 %0, %1, %2;" : "=r"(mask) : "r"(p2), "r"(inf2));
// Recreate the new value. 0x7bff is the max value for FP16.
p2 = (~mask & p2) | (mask & 0x7bff7bff);
// Simply re-add the sign and we're done.
return p2 | (h2 & 0x80008000);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static inline __device__ T clamp(T x, T lb, T ub)
{
return x < lb ? lb : (x > ub ? ub : x);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float custom_exp2f(float x, float scale, float scaled_max)
{
float d1, d2;
asm("fma.rz.ftz.f32 %0, %1, %2, %3;" : "=f"(d1) : "f"(x), "f"(scale), "f"(-scaled_max));
asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(d2) : "f"(d1));
return d2;
}
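// A sketch of the intended softmax-style use (names assumed): with
// scale = softmax_scale * log2(e) and scaled_max = row_max * scale,
//
// custom_exp2f(x, scale, scaled_max) == 2^(x * scale - scaled_max)
//                                    == exp(softmax_scale * (x - row_max)),
//
// i.e. a numerically-stable exponential in one FMA plus one EX2 instruction.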
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t clamp_to_zero(uint16_t x)
{
uint16_t mask;
asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
return mask & x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t float_to_half(float f)
{
uint16_t h;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
return h;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ bf16_t float_to_bf16(float f)
{
return __float2bfloat16(f);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(float a, float b)
{
uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
uint16_t lo = float_to_half(a);
uint16_t hi = float_to_half(b);
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
#endif
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_bf16_x2(float a, float b)
{
uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
// Fallback: round-to-nearest-even, done manually on the bit patterns.
uint16_t* px = reinterpret_cast<uint16_t*>(&a);
uint16_t* py = reinterpret_cast<uint16_t*>(&b);
// The upper 16 bits of each float are the truncated bf16 values.
uint16_t value = px[1];
uint16_t value2 = py[1];
// Round up if the discarded lower 16 bits are above the halfway point; on a tie,
// round up only if the truncated value is odd (round-to-nearest-even).
if (px[0] == 0x8000)
{
if ((value & 0x1) == 1)
value++;
}
else if (px[0] > 0x8000)
{
value++;
}
if (py[0] == 0x8000)
{
if ((value2 & 0x1) == 1)
value2++;
}
else if (py[0] > 0x8000)
{
value2++;
}
c = (uint32_t(value2) << 16) | value;
#endif
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Template function to support both half and bfloat16
template <typename Data_type>
inline __device__ uint32_t float2_to_16bit_2(float a, float b)
{
return float2_to_half2(a, b);
}
template <>
inline __device__ uint32_t float2_to_16bit_2<bf16_t>(float a, float b)
{
return float2_to_bf16_x2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float_to_half2(float a)
{
return float2_to_half2(a, a);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(float2 const& f)
{
return float2_to_half2(f.x, f.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float_to_bf16_2(float a)
{
return float2_to_bf16_x2(a, a);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w)
{
uint2 d;
d.x = float2_to_half2(x, y);
d.y = float2_to_half2(z, w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Template function to support both half and bfloat16
template <typename Data_type>
inline __device__ uint2 float4_to_16bit_x4(float x, float y, float z, float w)
{
uint2 d;
d.x = float2_to_half2(x, y);
d.y = float2_to_half2(z, w);
return d;
}
template <>
inline __device__ uint2 float4_to_16bit_x4<bf16_t>(float x, float y, float z, float w)
{
uint2 d;
d.x = float2_to_bf16_x2(x, y);
d.y = float2_to_bf16_x2(z, w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c)
{
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c)
{
uint32_t d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
d = hrelu2(hfma2(a, b, c));
#endif
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h0_h0(uint32_t x)
{
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" : "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float h0_to_float(uint32_t h2)
{
float f;
asm volatile(
"{\n"
".reg .f16 lo, hi;\n"
"mov.b32 {lo, hi}, %1;\n"
"cvt.f32.f16 %0, lo;\n"
"}\n"
: "=f"(f)
: "r"(h2));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h1_h1(uint32_t x)
{
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" : "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hadd(uint16_t a, uint16_t b)
{
uint16_t d;
asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd(uint32_t a, uint32_t b)
{
return hadd2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd4(uint2 a, uint2 b)
{
uint2 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd(uint2 a, uint2 b)
{
return hadd4(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd8(uint4 a, uint4 b)
{
uint4 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
c.z = hadd2(a.z, b.z);
c.w = hadd2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Template function to support both half and bfloat16
template <typename Data_type>
inline __device__ uint4 add8(uint4 a, uint4 b)
{
return hadd8(a, b);
}
template <>
inline __device__ uint4 add8<bf16_t>(uint4 a, uint4 b)
{
uint4 c;
c.x = bfadd2(a.x, b.x);
c.y = bfadd2(a.y, b.y);
c.z = bfadd2(a.z, b.z);
c.w = bfadd2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 fadd4(uint4 a, uint4 b)
{
float4 c;
c.x = reinterpret_cast<float const&>(a.x) + reinterpret_cast<float const&>(b.x);
c.y = reinterpret_cast<float const&>(a.y) + reinterpret_cast<float const&>(b.y);
c.z = reinterpret_cast<float const&>(a.z) + reinterpret_cast<float const&>(b.z);
c.w = reinterpret_cast<float const&>(a.w) + reinterpret_cast<float const&>(b.w);
return reinterpret_cast<uint4 const&>(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd(uint4 a, uint4 b)
{
return hadd8(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float half_to_float(uint16_t h)
{
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float bf16_to_float(uint16_t h)
{
float f;
asm volatile("mov.b32 %0, {0, %1};\n" : "=f"(f) : "h"(h));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float2 half2_to_float2(uint32_t x)
{
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
return make_float2(half_to_float(lo), half_to_float(hi));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float2 bf16_2_to_float2(uint32_t x)
{
float2 res;
asm volatile(
"{\n"
" .reg .b16 lo, hi;\n"
" mov.b32 {lo, hi}, %2;\n"
" mov.b32 %0, {0, lo};\n"
" mov.b32 %1, {0, hi};\n"
"}\n"
: "=f"(res.x), "=f"(res.y)
: "r"(x));
return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Template function to support both half and bfloat16
template <typename Data_type>
inline __device__ float2 convert_from_16bit_2(uint32_t x)
{
return half2_to_float2(x);
}
template <>
inline __device__ float2 convert_from_16bit_2<bf16_t>(uint32_t x)
{
return bf16_2_to_float2(x);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void half2_to_float2(float& x, float& y, uint32_t h)
{
float2 tmp = half2_to_float2(h);
x = tmp.x;
y = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c)
{
uint16_t d;
asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hmul(uint16_t a, uint16_t b)
{
uint16_t d;
asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two half2's or bf16x2's to float2, then take their dot product.
template <typename Data_type>
inline __device__ float fma2_in_float(uint32_t const a, uint32_t const b)
{
float2 af = fmha::convert_from_16bit_2<Data_type>(a);
float2 bf = fmha::convert_from_16bit_2<Data_type>(b);
return af.x * bf.x + af.y * bf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two vectors of 8 halves or bf16s to float, then take their dot product.
template <typename Data_type>
inline __device__ float fma8_in_float(uint4 const a, uint4 const b)
{
float sum;
sum = fmha::fma2_in_float<Data_type>(a.x, b.x);
sum += fmha::fma2_in_float<Data_type>(a.y, b.y);
sum += fmha::fma2_in_float<Data_type>(a.z, b.z);
sum += fmha::fma2_in_float<Data_type>(a.w, b.w);
return sum;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float sigmoid(float x)
{
return 1.f / (1.f + expf(-x));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint16_t& dst)
{
dst = uint16_t(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint32_t& dst)
{
dst = 0u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint2& dst)
{
dst = make_uint2(0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint4& dst)
{
dst = make_uint4(0u, 0u, 0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// P R E D I C A T E P A C K I N G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
enum
{
BYTES_PER_REG = 4,
PREDS_PER_BYTE = 4,
PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int LDGS>
struct Compute_number_of_pred_regs
{
enum
{
VALUE = Div_up<LDGS, PREDS_PER_REG>::VALUE
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int M, int N>
inline __device__ void pack_predicates(uint32_t (&preds)[M], uint32_t const (&p)[N])
{
// Make sure the values match.
static_assert(Compute_number_of_pred_regs<N>::VALUE == M, "");
// The number of complete steps (where we use all the predicates in a byte).
enum
{
COMPLETE_BYTES = N / PREDS_PER_BYTE
};
// Make sure we allocated enough predicate registers.
static_assert(Div_up<COMPLETE_BYTES, BYTES_PER_REG>::VALUE <= M, "");
// The remainder.
enum
{
REMAINDER = N - COMPLETE_BYTES * PREDS_PER_BYTE
};
// Make sure we got the math right and the remainder is between 0 and 3.
static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
// The mask to extract the predicates.
enum
{
COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1
};
// Run complete steps.
#pragma unroll
for (int ii = 0; ii < M; ++ii)
{
// The number of complete bytes for that register. Careful: it can be greater than 4 ;)
int const COMPLETE = (N - ii * PREDS_PER_REG) / PREDS_PER_BYTE;
// Pack the predicates in a register.
uint32_t reg = 0u;
#pragma unroll
for (int jj = 0; jj < 4; ++jj)
{
// Early exit.
if (jj >= COMPLETE)
{
break;
}
// Prepare the array of predicates.
bool tmp[PREDS_PER_BYTE];
#pragma unroll
for (int kk = 0; kk < PREDS_PER_BYTE; ++kk)
{
tmp[kk] = p[ii * PREDS_PER_REG + jj * PREDS_PER_BYTE + kk] != 0;
}
// Store the predicates.
#pragma unroll
for (int kk = 0; kk < PREDS_PER_BYTE; ++kk)
{
if (tmp[kk])
{
reg |= 1u << (jj * 8 + kk);
}
}
}
// Skip the rest of the code if we do not have a remainder.
if (COMPLETE < 4 && REMAINDER > 0)
{
// The mask to extract the predicates.
enum
{
REMAINDER_MASK = (1 << REMAINDER) - 1
};
// Prepare the array of predicates.
bool tmp[PREDS_PER_BYTE];
#pragma unroll
for (int jj = 0; jj < REMAINDER; ++jj)
{
tmp[jj] = p[COMPLETE_BYTES * PREDS_PER_BYTE + jj] != 0;
}
// Store the predicates.
#pragma unroll
for (int jj = 0; jj < REMAINDER; ++jj)
{
if (tmp[jj])
{
reg |= 1u << (COMPLETE * 8 + jj);
}
}
}
// Store the predicate register.
preds[ii] = reg;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
inline __device__ uint32_t pack_predicates(uint32_t const (&p)[N])
{
uint32_t tmp[1];
pack_predicates(tmp, p);
return tmp[0];
}
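// A usage sketch (loop bound and names assumed): pack one out-of-bounds check per LDG
// into a single register, to be consumed by the predicated LDG/LDGSTS helpers below.
//
// uint32_t p[LDGS];
// #pragma unroll
// for (int ii = 0; ii < LDGS; ++ii)
// {
//     p[ii] = row + ii * ROWS_PER_LDG < actual_rows;
// }
// uint32_t preds = pack_predicates(p); // single-register form: needs LDGS <= PREDS_PER_REG.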
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// G E N E R I C P R E D I C A T E D L D G S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N, int M, typename Functor>
inline __device__ void ldgsts_(Functor& fct, uint32_t const (&preds)[M])
{
// The number of complete bytes (where we use all the predicates in a byte).
enum
{
COMPLETE = N / PREDS_PER_BYTE
};
// Make sure we did allocate enough predicates.
static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, "");
// The remainder.
enum
{
REMAINDER = N - COMPLETE * PREDS_PER_BYTE
};
// Make sure we got the math right and the remainder is between 0 and 3.
static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
// The mask to extract the predicates.
enum
{
COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1
};
// Clear the fetch registers.
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
fct.clear(ii);
}
// Run complete steps.
bool p[PREDS_PER_BYTE];
#pragma unroll
for (int ii = 0; ii < COMPLETE; ++ii)
{
// The predicate.
uint32_t reg = preds[ii / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for (int jj = 0; jj < PREDS_PER_BYTE; ++jj)
{
uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for (int jj = 0; jj < PREDS_PER_BYTE; ++jj)
{
fct.ldgsts(ii * PREDS_PER_BYTE + jj, p[jj]);
}
}
// Skip the rest of the code if we do not have a remainder.
if (REMAINDER > 0)
{
// The mask to extract the predicates.
enum
{
REMAINDER_MASK = (1 << REMAINDER) - 1
};
// The predicate register.
uint32_t reg = preds[COMPLETE / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for (int jj = 0; jj < PREDS_PER_BYTE; ++jj)
{
uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for (int ii = 0; ii < REMAINDER; ++ii)
{
fct.ldgsts(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int M, typename Functor>
inline __device__ void ldgsts_(Functor& fct, uint32_t preds)
{
uint32_t tmp[1] = {preds};
ldgsts_<M>(fct, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint8_t& dst, void const* ptr)
{
dst = *reinterpret_cast<uint8_t const*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint16_t& dst, void const* ptr)
{
dst = *reinterpret_cast<uint16_t const*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint32_t& dst, void const* ptr)
{
dst = *reinterpret_cast<uint32_t const*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint2& dst, void const* ptr)
{
dst = *reinterpret_cast<uint2 const*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint4& dst, void const* ptr)
{
dst = *reinterpret_cast<uint4 const*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Data_type, int N>
struct Ldg_functor
{
// Ctor.
inline __device__ Ldg_functor(Data_type (&fetch)[N], void const* (&ptrs)[N])
: fetch_(fetch)
, ptrs_(ptrs)
{
}
// Clear the element.
inline __device__ void clear(int ii)
{
fmha::clear(fetch_[ii]);
}
// Trigger the loads.
inline __device__ void ldgsts(int ii, bool p)
{
if (p)
{
ldg(fetch_[ii], ptrs_[ii]);
}
}
// The fetch registers.
Data_type (&fetch_)[N];
// The pointers.
void const* (&ptrs_)[N];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Data_type, int N, int M>
inline __device__ void ldg_(Data_type (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
{
Ldg_functor<Data_type, N> fct(fetch, ptrs);
ldgsts_<N>(fct, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N, int M>
inline __device__ void ldg(uint8_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
{
ldg_<uint8_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N, int M>
inline __device__ void ldg(uint16_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
{
ldg_<uint16_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N, int M>
inline __device__ void ldg(uint32_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
{
ldg_<uint32_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N, int M>
inline __device__ void ldg(uint2 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
{
ldg_<uint2, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N, int M>
inline __device__ void ldg(uint4 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
{
ldg_<uint4, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool USE_LDGSTS>
inline __device__ void ldgdepbar()
{
if (USE_LDGSTS)
{
asm volatile("cp.async.commit_group;\n" ::);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool USE_LDGSTS, int COUNT = 0>
inline __device__ void depbar_()
{
if (USE_LDGSTS)
{
asm volatile("cp.async.wait_group %0;\n" ::"n"(COUNT));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool USE_LDGSTS, int STAGES>
inline __device__ void depbar()
{
if (USE_LDGSTS)
{
int const VALUE = Max<STAGES - 2, 0>::VALUE;
asm volatile("cp.async.wait_group %0;\n" ::"n"(VALUE));
}
}
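// A sketch of the intended multi-stage pipeline (STAGES and the pointer arrays assumed):
// commit each batch of LDGSTS as one group, then wait until at most STAGES-2 groups
// remain in flight before consuming the data.
//
// ldgsts(smem_ptrs, gmem_ptrs, preds); // issue stage i
// ldgdepbar<USE_LDGSTS>();             // cp.async.commit_group
// depbar<USE_LDGSTS, STAGES>();        // cp.async.wait_group Max<STAGES-2, 0>
// __syncthreads();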
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldgsts128(uint32_t dst, void const* src, bool p = true)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
uint32_t m = p ? 16u : 0u;
asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"r"(dst), "l"(src), "r"(m));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
struct Ldgsts_functor
{
// Ctor.
inline __device__ Ldgsts_functor(uint32_t (&smem_ptrs)[N], void const* (&gmem_ptrs)[N])
: smem_ptrs_(smem_ptrs)
, gmem_ptrs_(gmem_ptrs)
{
}
// Does nothing.
inline __device__ void clear(int ii) {}
// Trigger the load-store instruction.
inline __device__ void ldgsts(int ii, bool p)
{
ldgsts128(smem_ptrs_[ii], gmem_ptrs_[ii], p);
}
// The shared memory pointers.
uint32_t (&smem_ptrs_)[N];
// The global memory pointers.
void const* (&gmem_ptrs_)[N];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N, int M>
inline __device__ void ldgsts(uint32_t (&dst)[N], void const* (&src)[N], uint32_t (&preds)[M])
{
Ldgsts_functor<N> fct(dst, src);
ldgsts_<N>(fct, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint16_t& dst, uint32_t ptr)
{
asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint32_t& dst, uint32_t ptr)
{
asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint2& dst, uint32_t ptr)
{
asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint4& dst, uint32_t ptr)
{
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w)
: "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint32_t& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint32_t& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint2& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint2& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst.x), "=r"(dst.y)
: "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint4& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w)
: "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint4& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w)
: "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stsm(uint32_t ptr, uint32_t const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stsmt(uint32_t ptr, uint32_t const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stsm(uint32_t ptr, uint2 const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), "r"(src.x), "r"(src.y));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stsmt(uint32_t ptr, uint2 const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), "r"(src.x), "r"(src.y));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stsm(uint32_t ptr, uint4 const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), "r"(src.x),
"r"(src.y), "r"(src.z), "r"(src.w));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stsmt(uint32_t ptr, uint4 const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), "r"(src.x),
"r"(src.y), "r"(src.z), "r"(src.w));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void* ptr, float val)
{
*reinterpret_cast<float*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void* ptr, uint8_t val)
{
*reinterpret_cast<uint8_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void* ptr, uint16_t val)
{
*reinterpret_cast<uint16_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void* ptr, uint32_t val)
{
*reinterpret_cast<uint32_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void* ptr, uint2 val)
{
*reinterpret_cast<uint2*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void* ptr, uint4 val)
{
*reinterpret_cast<uint4*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint16_t val)
{
asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint32_t val)
{
asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint2 val)
{
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" : : "r"(ptr), "r"(val.x), "r"(val.y));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint4 val)
{
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
:
: "r"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Data_type, int N>
inline __device__ void sts_(uint32_t (&ptrs)[N], Data_type const (&data)[N])
{
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
sts(ptrs[ii], data[ii]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], uint16_t const (&data)[N])
{
sts_<uint16_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], uint32_t const (&data)[N])
{
sts_<uint32_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], uint2 const (&data)[N])
{
sts_<uint2, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], uint4 const (&data)[N])
{
sts_<uint4, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define __HALF2_TO_UI(var) *(reinterpret_cast<unsigned int*>(&(var)))
#define __HALF2_TO_CUI(var) *(reinterpret_cast<const unsigned int*>(&(var)))
static __device__ __inline__ void atomicAdd_half2(half2* const address, const half2 val)
{
asm volatile("{ red.global.add.noftz.f16x2 [%0],%1; }\n" ::"l"(address), "r"(__HALF2_TO_CUI(val)) : "memory");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool CAN_BE_NEGATIVE>
static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w)
{
#if defined(USE_F2I_EMULATION_TRICK)
// Make sure the float is in the proper range.
float cx, cy, cz, cw;
if (CAN_BE_NEGATIVE)
{
cx = fmha::clamp(x, -128.f, 127.f);
cy = fmha::clamp(y, -128.f, 127.f);
cz = fmha::clamp(z, -128.f, 127.f);
cw = fmha::clamp(w, -128.f, 127.f);
}
else
{
cx = fminf(x, 127.f);
cy = fminf(y, 127.f);
cz = fminf(z, 127.f);
cw = fminf(w, 127.f);
}
// Re-add the magic number.
cx += FP32_I2F_MAGIC_NUMBER;
cy += FP32_I2F_MAGIC_NUMBER;
cz += FP32_I2F_MAGIC_NUMBER;
cw += FP32_I2F_MAGIC_NUMBER;
// We need unsigned ints...
uint32_t a = reinterpret_cast<uint32_t const&>(cx);
uint32_t b = reinterpret_cast<uint32_t const&>(cy);
uint32_t c = reinterpret_cast<uint32_t const&>(cz);
uint32_t d = reinterpret_cast<uint32_t const&>(cw);
// Pack the numbers.
uint32_t dst;
asm volatile("prmt.b32 %0, %1, %2, 0x0040;\n" : "=r"(dst) : "r"(a), "r"(b));
asm volatile("prmt.b32 %0, %0, %1, 0x0410;\n" : "+r"(dst) : "r"(c));
asm volatile("prmt.b32 %0, %0, %1, 0x4210;\n" : "+r"(dst) : "r"(d));
return dst;
#else
uint32_t a;
asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x));
uint32_t b;
asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y));
uint32_t c;
asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z));
uint32_t d;
asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w));
uint32_t dst;
asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c));
asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a));
return dst;
#endif // defined(USE_F2I_EMULATION_TRICK)
}
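// Note on the emulation trick above: FP32_I2F_MAGIC_NUMBER = 12582912.f = 1.5 * 2^23, so for
// any integer-valued float v in [-128, 127], the bit pattern of v + FP32_I2F_MAGIC_NUMBER
// holds the two's-complement int8 of v in its low byte; the prmt instructions then gather
// the four low bytes into a single register.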
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void swizzle_rows(uint32_t& a, uint32_t& b, uint32_t c, uint32_t d)
{
asm volatile("prmt.b32 %0, %1, %2, 0x6420;\n" : "=r"(a) : "r"(c), "r"(d));
asm volatile("prmt.b32 %0, %1, %2, 0x7531;\n" : "=r"(b) : "r"(c), "r"(d));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
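// Emulate LDSM (ldmatrix) for a uint2 fragment: lanes 0-15 each load 16B with lds and
// warp shuffles redistribute the values into the ldmatrix register layout.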
inline __device__ void ldsm_with_lds(uint2& data, uint32_t smem)
{
int lane = threadIdx.x % 32;
data = {0, 0};
uint4 v = {0, 0, 0, 0};
uint32_t* a = reinterpret_cast<uint32_t*>(&v);
if (lane < 16)
{
fmha::lds(v, smem);
}
int src_row = lane / 4;
int src_col = lane % 4;
for (int it = 0; it < 4; it++)
{
uint32_t val = a[it];
uint32_t x = __shfl_sync(uint32_t(-1), val, src_row);
__syncwarp();
uint32_t y = __shfl_sync(uint32_t(-1), val, src_row + 8);
__syncwarp();
if (it == src_col)
{
data.x = x;
data.y = y;
}
}
}
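////////////////////////////////////////////////////////////////////////////////////////////////////
// Emulate the transposed LDSM (ldmatrix.trans): lanes 0-15 each load 16B with lds and
// 16-bit elements are redistributed across the warp with shuffles.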
inline __device__ void ldsmt_with_lds(uint2& data, uint32_t smem)
{
int lane = threadIdx.x % 32;
uint4 tmp16{0, 0, 0, 0}; // 16B
if (lane < 16)
{
fmha::lds(tmp16, smem);
}
uint16_t* tmp16c = reinterpret_cast<uint16_t*>(&tmp16); // 8x2B: we move pairs
uint16_t* t = reinterpret_cast<uint16_t*>(&data); // 4x2B
int const src_col = lane / 4; // 0 - 7
int const src_row = (lane % 4) * 2;
    // Shuffle the values to distribute them across the warp.
#pragma unroll
for (int it = 0; it < 8; it++)
{
uint16_t val, x, y;
val = tmp16c[it];
x = __shfl_sync(uint32_t(-1), val, src_row + 0);
__syncwarp();
y = __shfl_sync(uint32_t(-1), val, src_row + 1);
__syncwarp();
if (src_col == it)
{
t[0] = x;
t[1] = y;
}
val = tmp16c[it];
x = __shfl_sync(uint32_t(-1), val, src_row + 8);
__syncwarp();
y = __shfl_sync(uint32_t(-1), val, src_row + 9);
__syncwarp();
if (src_col == it)
{
t[2] = x;
t[3] = y;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
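// Binary functors used by the warp reductions below.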
template <typename T>
struct MaxOp
{
__device__ inline T operator()(T const& x, T const& y)
{
return x > y ? x : y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct SumOp
{
__device__ inline T operator()(T const& x, T const& y)
{
return x + y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
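// Butterfly (XOR-shuffle) all-reduce over THREADS contiguous lanes: every participating
// lane ends up with op applied across all THREADS values.
//
// Usage sketch:
//   MaxOp<float> op;
//   float warp_max = Allreduce<32>::run(x, op);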
template <int THREADS>
struct Allreduce
{
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template <typename T, typename Operator>
static __device__ inline T run(T x, Operator& op)
{
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Allreduce<2>
{
template <typename T, typename Operator>
static __device__ inline T run(T x, Operator& op)
{
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
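// Reduce within each quad of 4 consecutive lanes; shfl_down means only the first lane of
// each quad holds the complete result.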
template <typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator& op)
{
#pragma unroll
for (int mi = 0; mi < M; mi++)
{
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator& op)
{
float tmp[M];
#pragma unroll
for (int mi = 0; mi < M; mi++)
{
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
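// All-reduce within each quad of 4 consecutive lanes: all four lanes receive the result.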
template <typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator& op)
{
#pragma unroll
for (int mi = 0; mi < M; mi++)
{
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator& op)
{
float tmp[M];
#pragma unroll
for (int mi = 0; mi < M; mi++)
{
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_allreduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
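// Elect a single lane of the warp (elect.one.sync, sm90+): returns 1 in the elected lane
// and 0 in all others. Below sm90 this returns 0 for every lane.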
inline __device__ uint32_t elect_one_sync()
{
uint32_t pred = 0;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#if !defined(__CUDACC_RTC__)
uint32_t laneid = 0;
asm volatile(
"\n\
{\n\
.reg .b32 %rx;\n\
.reg .pred %px;\n\
elect.one.sync %rx|%px, %2;\n\
@%px mov.s32 %1, 1;\n\
mov.s32 %0, %rx;\n\
}\n"
: "+r"(laneid), "+r"(pred)
: "r"(0xFFFFFFFF));
#else
pred = threadIdx.x == 0;
#endif
#endif
return pred;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
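// Convert two floats to a packed e4m3x2 (x in the low byte, y in the high byte),
// saturating to the finite range. Requires sm89 with QMMA enabled, or sm90+.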
inline __device__ uint16_t float2_to_e4m3x2(float x, float y)
{
#if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900))
uint16_t res;
asm volatile("cvt.rn.e4m3x2.f32.satfinite %0, %2, %1;" : "=h"(res) : "f"(x), "f"(y));
return res;
#else
assert(false);
return 0;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
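// Convert four floats to four packed e4m3 bytes (x in the lowest byte). The e5m2 variant
// below is identical except for the target format.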
inline __device__ uint32_t float4_to_e4m3x4(float x, float y, float z, float w)
{
#if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900))
uint32_t res;
asm volatile(
"{\n"
".reg .b16 lo;\n"
".reg .b16 hi;\n"
"cvt.rn.e4m3x2.f32.satfinite lo, %2, %1;\n"
"cvt.rn.e4m3x2.f32.satfinite hi, %4, %3;\n"
"mov.b32 %0, {lo, hi};\n"
"}"
: "=r"(res)
: "f"(x), "f"(y), "f"(z), "f"(w));
return res;
#else
assert(false);
return 0;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t float4_to_e5m2x4(float x, float y, float z, float w)
{
#if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900))
uint32_t res;
asm volatile(
"{\n"
".reg .b16 lo;\n"
".reg .b16 hi;\n"
"cvt.rn.e5m2x2.f32.satfinite lo, %2, %1;\n"
"cvt.rn.e5m2x2.f32.satfinite hi, %4, %3;\n"
"mov.b32 %0, {lo, hi};\n"
"}"
: "=r"(res)
: "f"(x), "f"(y), "f"(z), "f"(w));
return res;
#else
assert(false);
return 0;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
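// Convert four halves (two packed half2 registers) to four fp8 bytes: e4m3 here, e5m2 in
// the variant below.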
inline __device__ uint32_t half4_to_e4m3x4(uint32_t const h2_0, uint32_t const h2_1)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
uint32_t res;
asm volatile(
"{\n"
".reg .b16 lo, hi;\n"
"cvt.satfinite.rn.e4m3x2.f16x2 lo, %1;\n"
"cvt.satfinite.rn.e4m3x2.f16x2 hi, %2;\n"
"mov.b32 %0, {lo, hi};\n"
"}\n"
: "=r"(res)
: "r"(h2_0), "r"(h2_1));
return res;
#else
assert(false);
return 0;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t half4_to_e5m2x4(uint32_t const h2_0, uint32_t const h2_1)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
uint32_t res;
asm volatile(
"{\n"
".reg .b16 lo, hi;\n"
"cvt.satfinite.rn.e5m2x2.f16x2 lo, %1;\n"
"cvt.satfinite.rn.e5m2x2.f16x2 hi, %2;\n"
"mov.b32 %0, {lo, hi};\n"
"}\n"
: "=r"(res)
: "r"(h2_0), "r"(h2_1));
return res;
#else
assert(false);
return 0;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Helpers to pack four floats into one register holding four saturated 8-bit values.
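// Usage sketch: uint32_t packed = float4_to_8bitx4<e4m3_t>(x, y, z, w);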
template <typename Dst_type>
inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, float const w)
{
return float4_to_char4<false>(x, y, z, w);
}
template <>
inline __device__ uint32_t float4_to_8bitx4<e4m3_t>(float const x, float const y, float const z, float const w)
{
return float4_to_e4m3x4(x, y, z, w);
}
template <>
inline __device__ uint32_t float4_to_8bitx4<e5m2_t>(float const x, float const y, float const z, float const w)
{
return float4_to_e5m2x4(x, y, z, w);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
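// Dispatch the half4 -> fp8x4 conversion on the destination fp8 type.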
template <typename T>
inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1);
template <>
inline __device__ uint32_t half4_to_fp8x4<fmha::e4m3_t>(uint32_t const h2_0, uint32_t const h2_1)
{
return half4_to_e4m3x4(h2_0, h2_1);
}
template <>
inline __device__ uint32_t half4_to_fp8x4<fmha::e5m2_t>(uint32_t const h2_0, uint32_t const h2_1)
{
return half4_to_e5m2x4(h2_0, h2_1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
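// The same destination-type dispatch for float inputs.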
template <typename T>
inline __device__ uint32_t float4_to_fp8x4(float const, float const, float const, float const);
template <>
inline __device__ uint32_t float4_to_fp8x4<fmha::e4m3_t>(float const x, float const y, float const z, float const w)
{
return float4_to_e4m3x4(x, y, z, w);
}
template <>
inline __device__ uint32_t float4_to_fp8x4<fmha::e5m2_t>(float const x, float const y, float const z, float const w)
{
return float4_to_e5m2x4(x, y, z, w);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void fence_view_async_shared()
{
    // Issue a shared memory fence for async operations (FENCE.VIEW.ASYNC.S).
    // The instruction is only emitted for sm90+; other architectures assert.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("fence.proxy.async.shared::cta;\n");
#else
assert(false);
#endif
}
inline __device__ void fence_view_async_global()
{
    // Issue a global memory fence for async operations (FENCE.VIEW.ASYNC.G).
    // The instruction is only emitted for sm90+; other architectures assert.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("fence.proxy.async.global::cta;\n");
#else
assert(false);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
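// Round a pointer up to the next 1024-byte boundary (no-op if already aligned).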
static inline __device__ char* align_1024(char* ptr)
{
uint64_t address_bit = reinterpret_cast<uint64_t>(ptr);
uint64_t offset = address_bit % 1024;
if (offset == 0)
{
return ptr;
}
else
{
return ptr + (1024 - offset);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
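// Atomic float max built on integer atomics: for non-negative values the IEEE-754 bit
// pattern ordering matches signed-int ordering, so atomicMax on the int view works; for
// negative values the ordering reverses on the unsigned view, hence atomicMin.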
static inline __device__ float atomicMaxFloat(float* addr, float value)
{
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*) addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*) addr, __float_as_uint(value)));
return old;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float atomicMaxFloatPos_(float* addr, float value)
{
// VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION.
float old = __int_as_float(atomicMax((int*) addr, __float_as_int(value)));
return old;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float max3Pos_(float const a, float const b, float const c)
{
// VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION.
float res;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
int32_t a_ = reinterpret_cast<int32_t const&>(a);
int32_t b_ = reinterpret_cast<int32_t const&>(b);
int32_t c_ = reinterpret_cast<int32_t const&>(c);
int32_t tmp;
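    // max.s16x2 (sm90) takes a per-half max of the bit patterns. For positive floats this
    // upper-bounds the true maximum (the halves may mix), which is acceptable for an amax
    // reduction.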
asm volatile("max.s16x2 %0, %1, %2;\n" : "=r"(tmp) : "r"(a_), "r"(b_));
asm volatile("max.s16x2 %0, %0, %1;\n" : "+r"(tmp) : "r"(tmp), "r"(c_));
res = reinterpret_cast<float const&>(tmp);
#else
res = fmaxf(a, fmaxf(b, c));
#endif
return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Fast approximate tanh.
static inline __device__ float __tanhf(float x)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
float r = x;
asm("tanh.approx.f32 %0, %0;" : "+f"(r));
return r;
#else
return tanhf(x);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha