mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-27 22:23:25 +08:00
2929 lines
74 KiB
C++
2929 lines
74 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include <assert.h>
|
|
#include <cuda_fp16.h>
|
|
#include <stdint.h>
|
|
#include <stdlib.h>
|
|
|
|
#if defined(__CLANGD__)
|
|
#include <__clang_cuda_builtin_vars.h>
|
|
#include <__clang_cuda_math.h>
|
|
#endif
|
|
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
#include <cuda_bf16.h>
|
|
#endif
|
|
|
|
// include warpgroup related instructions, used by SM90.
|
|
#include <fmha/hopper/utils_warpgroup.h>
|
|
// include gmma related instructions, used by SM90.
|
|
#include <fmha/hopper/utils_gmma.h>
|
|
// include tma related instructions, used by SM90.
|
|
#include <fmha/hopper/utils_tma.h>
|
|
|
|
#include "fmha/numeric_types.h"
|
|
|
|
#define FP32_I2F_MAGIC_NUMBER 12582912.f
|
|
#define FP32_I2F_MAGIC_NUMBER_HEX 0x4b400000
|
|
|
|
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void* ptr);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace introspection
|
|
{
|
|
|
|
template <int... Ns>
|
|
struct Unpack;
|
|
|
|
template <int N>
|
|
struct Unpack<N>
|
|
{
|
|
// if we simply static_assert(false) then compiler will not emit template params upon failure
|
|
static_assert(N < INT_MIN, "");
|
|
using Type = std::integral_constant<int, N>;
|
|
};
|
|
|
|
template <int N, int... Ns>
|
|
struct Unpack<N, Ns...>
|
|
{
|
|
using Type = Unpack<N, Ns...>;
|
|
using Unpack_first = typename Unpack<N>::Type;
|
|
using Unpack_remaining = typename Unpack<Ns...>::Type;
|
|
};
|
|
|
|
} // namespace introspection
|
|
|
|
// Example usage:
|
|
//
|
|
// Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER> foo;
|
|
//
|
|
// or
|
|
//
|
|
// Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER>{}.foo();
|
|
//
|
|
// Output by nvcc:
|
|
//
|
|
// ./src/fmha/gmem_tile_qkv_packed.h(70): error: static assertion failed with ""
|
|
// detected during:
|
|
// instantiation of class "fmha::v2::Unpack<N> [with N=1]"
|
|
// (77): here
|
|
// instantiation of class "fmha::v2::Unpack<N, Ns...> [with N=1, Ns=<2, 0>]"
|
|
// (84): here
|
|
// instantiation of class "fmha::v2::Inspect_ns<Ns...> [with Ns=<1, 2, 0>]"
|
|
// (143): here
|
|
template <int... Ns>
|
|
struct Inspect_ns
|
|
{
|
|
using Type = typename introspection::Unpack<Ns...>::Type;
|
|
};
|
|
|
|
// Can be used alongside with static_assert() to figure out the conditions when assertion failed
|
|
// Example:
|
|
//
|
|
// Cond_inspect_ns< (int)ROWS >= (int)ROWS_PER_LDG, ROWS, ROWS_PER_LDG> foo;
|
|
//
|
|
// Output by nvcc (when condition is not met):
|
|
//
|
|
// ./src/fmha/utils.h(163): error: static assertion failed with ""
|
|
// detected during:
|
|
// instantiation of class "Cond_inspect_ns<COND, Ns...> [with COND=false, Ns=<32, 64>]"
|
|
template <bool COND, int... Ns>
|
|
struct Cond_inspect_ns
|
|
{
|
|
static_assert(COND, "");
|
|
};
|
|
|
|
// Example:
|
|
//
|
|
// Inspect_type<Mma_tile_p>{}.foo();
|
|
//
|
|
// or
|
|
//
|
|
// Inspect_type<Mma_tile_p> foo;
|
|
//
|
|
// Output by nvcc:
|
|
//
|
|
// ./src/fmha/utils.h(189): error: class "fmha::Ampere_hmma_tile<fmha::Cta_tile_<fmha::Ampere, 64, 128, 64, 128, 256,
|
|
// 4, 1, 1>, 16>" has no member "Dummy"
|
|
// detected during:
|
|
// instantiation of class "Inspect_type<T> [with T=fmha::Ampere_hmma_tile<fmha::Cta_tile_<fmha::Ampere,
|
|
// 64, 128, 64, 128, 256, 4, 1, 1>, 16>]"
|
|
template <typename T>
|
|
struct Inspect_type
|
|
{
|
|
// Purposefully trigger error by referencing non-existent T::Dummy
|
|
using Dummy = typename T::Dummy;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace fmha
|
|
{
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
struct Row
|
|
{
|
|
static constexpr bool COL = false;
|
|
static constexpr bool ROW = true;
|
|
};
|
|
|
|
struct Col
|
|
{
|
|
static constexpr bool COL = true;
|
|
static constexpr bool ROW = false;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int M, int N>
|
|
struct Round_up
|
|
{
|
|
enum
|
|
{
|
|
VALUE = (M + N - 1) / N * N
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N_, int H_, int W_>
|
|
struct Tile_nhw
|
|
{
|
|
enum
|
|
{
|
|
N = N_,
|
|
H = H_,
|
|
W = W_
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int M, bool = (M & (M - 1)) == 0>
|
|
struct Next_power_of_two
|
|
{
|
|
};
|
|
|
|
template <int M>
|
|
struct Next_power_of_two<M, true>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = M
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<3, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 4
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<5, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 8
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<6, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 8
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<7, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 8
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<9, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 16
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<10, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 16
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<11, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 16
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<12, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 16
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<13, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 16
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<14, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 16
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<15, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 16
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<24, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 32
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<40, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 64
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<48, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 64
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<72, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 128
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<80, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 128
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<96, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 128
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<104, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 128
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<112, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 128
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<144, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 256
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<160, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 256
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<192, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 256
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Next_power_of_two<576, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 1024
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N, bool = (N & (N - 1)) == 0>
|
|
struct Prev_power_of_two
|
|
{
|
|
};
|
|
|
|
template <int N>
|
|
struct Prev_power_of_two<N, true>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = N
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Prev_power_of_two<3, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 2
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Prev_power_of_two<5, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 4
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Prev_power_of_two<6, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 4
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Prev_power_of_two<7, false>
|
|
{
|
|
enum
|
|
{
|
|
VALUE = 4
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int BYTES_PER_ROW, int SKEW>
|
|
struct Compute_skew
|
|
{
|
|
// The size of a transaction.
|
|
enum
|
|
{
|
|
BYTES_PER_TRX = 128
|
|
};
|
|
|
|
// The remainder of the row without skew.
|
|
enum
|
|
{
|
|
REMAINDER = BYTES_PER_ROW % BYTES_PER_TRX
|
|
};
|
|
|
|
// The value.
|
|
enum
|
|
{
|
|
VALUE = REMAINDER <= SKEW ? SKEW - REMAINDER : BYTES_PER_TRX + SKEW - REMAINDER
|
|
};
|
|
|
|
// Make sure the math works ;)
|
|
static_assert((BYTES_PER_ROW + VALUE) % BYTES_PER_TRX == SKEW, "");
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int BYTES_PER_ROW>
|
|
struct Compute_skew<BYTES_PER_ROW, 128>
|
|
{
|
|
// No skew!
|
|
enum
|
|
{
|
|
VALUE = 0
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int M, int N>
|
|
struct Div_up
|
|
{
|
|
enum
|
|
{
|
|
VALUE = (M + N - 1) / N
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int A, int B>
|
|
struct Max
|
|
{
|
|
enum
|
|
{
|
|
VALUE = A >= B ? A : B
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int A, int B, int C>
|
|
struct Max_3
|
|
{
|
|
enum
|
|
{
|
|
VALUE = Max<Max<A, B>::VALUE, C>::VALUE
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int A, int B>
|
|
struct Min
|
|
{
|
|
enum
|
|
{
|
|
VALUE = A <= B ? A : B
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int SIZE_IN_BYTES>
|
|
struct Uint_from_size_in_bytes
|
|
{
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <>
|
|
struct Uint_from_size_in_bytes<1>
|
|
{
|
|
using Type = uint8_t;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <>
|
|
struct Uint_from_size_in_bytes<2>
|
|
{
|
|
using Type = uint16_t;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <>
|
|
struct Uint_from_size_in_bytes<4>
|
|
{
|
|
using Type = uint32_t;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <>
|
|
struct Uint_from_size_in_bytes<8>
|
|
{
|
|
using Type = uint2;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <>
|
|
struct Uint_from_size_in_bytes<16>
|
|
{
|
|
using Type = uint4;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int WARPS_M, int WARPS_N, int WARPS_K>
|
|
struct Warp_masks
|
|
{
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <>
|
|
struct Warp_masks<8, 1, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0xe0,
|
|
N = 0x00,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<4, 2, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x60,
|
|
N = 0x80,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<4, 1, 2>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x60,
|
|
N = 0x00,
|
|
K = 0x80
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<4, 1, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x60,
|
|
N = 0x00,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<2, 4, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x20,
|
|
N = 0xc0,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<2, 2, 2>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x20,
|
|
N = 0x40,
|
|
K = 0x80
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<2, 2, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x20,
|
|
N = 0x40,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<2, 1, 2>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x20,
|
|
N = 0x00,
|
|
K = 0x40
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<2, 1, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x20,
|
|
N = 0x00,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<1, 8, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x00,
|
|
N = 0xe0,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<1, 4, 2>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x00,
|
|
N = 0x60,
|
|
K = 0x80
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<1, 4, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x00,
|
|
N = 0x60,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<1, 2, 2>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x00,
|
|
N = 0x20,
|
|
K = 0x40
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<1, 2, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x00,
|
|
N = 0x20,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<1, 1, 4>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x00,
|
|
N = 0x00,
|
|
K = 0x60
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<1, 1, 2>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x00,
|
|
N = 0x00,
|
|
K = 0x20
|
|
};
|
|
};
|
|
|
|
template <>
|
|
struct Warp_masks<1, 1, 1>
|
|
{
|
|
enum
|
|
{
|
|
M = 0x00,
|
|
N = 0x00,
|
|
K = 0x00
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T>
|
|
inline __device__ __host__ T div_up(T m, T n)
|
|
{
|
|
return (m + n - 1) / n;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline int clz(int x)
|
|
{
|
|
for (int i = 31; i >= 0; --i)
|
|
{
|
|
if ((1 << i) & x)
|
|
{
|
|
return 31 - i;
|
|
}
|
|
}
|
|
return 32;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline int find_log_2(int x, bool round_up = false)
|
|
{
|
|
int a = 31 - clz(x);
|
|
if (round_up)
|
|
{
|
|
a += (x & (x - 1)) ? 1 : 0;
|
|
}
|
|
return a;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline void find_divisor(uint32_t& mul, uint32_t& shr, int x)
|
|
{
|
|
assert(x != 0);
|
|
if (x == 1)
|
|
{
|
|
// If dividing by 1, reduced math doesn't work because mul_coeff would need to be 2^32,
|
|
// which doesn't fit into unsigned int. the div() routine handles this special case
|
|
// separately.
|
|
mul = 0;
|
|
shr = 0;
|
|
}
|
|
else
|
|
{
|
|
// To express the division N/D in terms of a multiplication, what we first
|
|
// imagine is simply N*(1/D). However, 1/D will always evaluate to 0 (for D>1),
|
|
// so we need another way. There's nothing that says we have to use exactly
|
|
// the fraction 1/D; instead it could be any X/Y that reduces to 1/D (i.e.,
|
|
// Y=X*D), or at least to "close enough" to it. If we pick Y that is a power
|
|
// of two, then the N*(X/Y) can be N*X followed by a right-shift by some amount.
|
|
// The power of two we should pick should be at least 2^32, because in the
|
|
// div() routine we'll use umulhi(), which returns only the upper 32 bits --
|
|
// this being equivalent to a right-shift by 32. But we might want a higher
|
|
// power of two for better accuracy depending on the magnitude of the denominator.
|
|
// Once we've picked Y, then X [our mul_coeff value] is simply Y/D, rounding up,
|
|
// and we save shift_coeff as whatever further shift we have to do beyond
|
|
// what the umulhi() implies.
|
|
uint32_t p = 31 + find_log_2(x, true);
|
|
uint32_t m = (uint32_t) (((1ull << p) + (uint32_t) x - 1) / (uint32_t) x);
|
|
|
|
mul = m;
|
|
shr = p - 32;
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void fast_divmod(int& div, int& mod, int x, int y, uint32_t mul, uint32_t shr)
|
|
{
|
|
if (y == 1)
|
|
{
|
|
div = x;
|
|
mod = 0;
|
|
}
|
|
else
|
|
{
|
|
div = __umulhi((uint32_t) x, mul) >> shr;
|
|
mod = x - div * y;
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b)
|
|
{
|
|
uint32_t c;
|
|
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t bfadd2(uint32_t a, uint32_t b)
|
|
{
|
|
uint32_t c;
|
|
uint32_t one = 0x3f803f80;
|
|
;
|
|
asm volatile("fma.rn.bf16x2 %0, %1, %3, %2;\n" : "=r"(c) : "r"(a), "r"(b), "r"(one));
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t hmax2(uint32_t a, uint32_t b)
|
|
{
|
|
uint32_t c;
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
asm volatile("max.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
|
|
#else
|
|
asm volatile(
|
|
"{\n"
|
|
"\t .reg .f16x2 sela, selb;\n"
|
|
"\n"
|
|
"\t set.ge.f16x2.f16x2 sela, %1, %2;\n"
|
|
"\t set.gt.f16x2.f16x2 selb, %2, %1;\n"
|
|
"\n"
|
|
"\t mul.f16x2 %0, sela, %1;\n"
|
|
"\t fma.rn.f16x2 %0, selb, %2, %0;\n"
|
|
"}\n"
|
|
: "=r"(c)
|
|
: "r"(a), "r"(b));
|
|
#endif
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint2 hmax4(uint2 a, uint2 b)
|
|
{
|
|
uint2 c;
|
|
c.x = hmax2(a.x, b.x);
|
|
c.y = hmax2(a.y, b.y);
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint4 hmax8(uint4 a, uint4 b)
|
|
{
|
|
uint4 c;
|
|
c.x = hmax2(a.x, b.x);
|
|
c.y = hmax2(a.y, b.y);
|
|
c.z = hmax2(a.z, b.z);
|
|
c.w = hmax2(a.w, b.w);
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b)
|
|
{
|
|
uint32_t c;
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
|
|
#else
|
|
asm volatile(
|
|
"{\n"
|
|
"\t .reg .f16x2 sela, selb;\n"
|
|
"\n"
|
|
"\t set.le.f16x2.f16x2 sela, %1, %2;\n"
|
|
"\t set.lt.f16x2.f16x2 selb, %2, %1;\n"
|
|
"\n"
|
|
"\t mul.f16x2 %0, sela, %1;\n"
|
|
"\t fma.rn.f16x2 %0, selb, %2, %0;\n"
|
|
"}\n"
|
|
: "=r"(c)
|
|
: "r"(a), "r"(b));
|
|
#endif
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b)
|
|
{
|
|
uint32_t c;
|
|
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t bfmul2(uint32_t a, uint32_t b)
|
|
{
|
|
uint32_t c;
|
|
asm("{.reg .b32 c;\n"
|
|
" mov.b32 c, 0x80008000U;\n"
|
|
" fma.rn.bf16x2 %0,%1,%2,c;}\n"
|
|
: "=r"(c)
|
|
: "r"(a), "r"(b));
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint2 hmul4(uint2 a, uint2 b)
|
|
{
|
|
uint2 c;
|
|
c.x = hmul2(a.x, b.x);
|
|
c.y = hmul2(a.y, b.y);
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint4 hmul8(uint4 a, uint4 b)
|
|
{
|
|
uint4 c;
|
|
c.x = hmul2(a.x, b.x);
|
|
c.y = hmul2(a.y, b.y);
|
|
c.z = hmul2(a.z, b.z);
|
|
c.w = hmul2(a.w, b.w);
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint4 hmul8(uint32_t a, uint4 b)
|
|
{
|
|
uint4 c;
|
|
c.x = hmul2(a, b.x);
|
|
c.y = hmul2(a, b.y);
|
|
c.z = hmul2(a, b.z);
|
|
c.w = hmul2(a, b.w);
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Template function to support both half and bfloat16
|
|
template <typename Data_type>
|
|
inline __device__ uint32_t mul2(uint32_t a, uint32_t b)
|
|
{
|
|
return hmul2(a, b);
|
|
}
|
|
|
|
template <>
|
|
inline __device__ uint32_t mul2<bf16_t>(uint32_t a, uint32_t b)
|
|
{
|
|
return bfmul2(a, b);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Template function to support both half and bfloat16
|
|
template <typename Data_type>
|
|
inline __device__ uint4 mul8(uint32_t a, uint4 b)
|
|
{
|
|
uint4 c;
|
|
c.x = hmul2(a, b.x);
|
|
c.y = hmul2(a, b.y);
|
|
c.z = hmul2(a, b.z);
|
|
c.w = hmul2(a, b.w);
|
|
return c;
|
|
}
|
|
|
|
template <>
|
|
inline __device__ uint4 mul8<bf16_t>(uint32_t a, uint4 b)
|
|
{
|
|
uint4 c;
|
|
c.x = bfmul2(a, b.x);
|
|
c.y = bfmul2(a, b.y);
|
|
c.z = bfmul2(a, b.z);
|
|
c.w = bfmul2(a, b.w);
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t hrelu2(uint32_t x)
|
|
{
|
|
uint32_t res;
|
|
uint32_t const zero = 0u;
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
|
|
#else
|
|
asm volatile(
|
|
"{\n"
|
|
"\t .reg .f16x2 sela;\n"
|
|
"\t set.gtu.u32.f16x2 sela, %1, %2;\n"
|
|
"\t and.b32 %0, sela, %1;\n"
|
|
"}\n"
|
|
: "=r"(res)
|
|
: "r"(x), "r"(zero));
|
|
#endif
|
|
return res;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t bfrelu2(uint32_t x)
|
|
{
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
uint32_t res;
|
|
uint32_t const zero = 0u;
|
|
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
|
|
return res;
|
|
#endif
|
|
// not implemented yet
|
|
return x;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Template function to support both half and bfloat16
|
|
template <typename Data_type>
|
|
inline __device__ uint32_t relu2(uint32_t x)
|
|
{
|
|
return hrelu2(x);
|
|
}
|
|
|
|
template <>
|
|
inline __device__ uint32_t relu2<bf16_t>(uint32_t x)
|
|
{
|
|
return bfrelu2(x);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t habs2(uint32_t x)
|
|
{
|
|
uint32_t res;
|
|
asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
|
|
return res;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// static inline __device__ uint32_t add_bias(uint32_t a, uint32_t bias, bool relu) {
|
|
// uint32_t c;
|
|
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
// if( relu ) {
|
|
// uint32_t one = 0x3c003c00u;
|
|
// asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(c) : "r"(a), "r"(one), "r"(bias));
|
|
// } else {
|
|
// c = hadd2(a, bias);
|
|
// }
|
|
// #else
|
|
// c = hadd2(a, bias);
|
|
// if( relu ) {
|
|
// c = hrelu2(c);
|
|
// }
|
|
// #endif
|
|
// return c;
|
|
// }
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// static inline __device__ uint2 add_bias(uint2 a, uint2 bias, bool relu) {
|
|
// uint2 dst;
|
|
// dst.x = add_bias(a.x, bias.x, relu);
|
|
// dst.y = add_bias(a.y, bias.y, relu);
|
|
// return dst;
|
|
// }
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// static inline __device__ uint4 add_bias(uint4 a, uint4 bias, bool relu) {
|
|
// uint4 dst;
|
|
// dst.x = add_bias(a.x, bias.x, relu);
|
|
// dst.y = add_bias(a.y, bias.y, relu);
|
|
// dst.z = add_bias(a.z, bias.z, relu);
|
|
// dst.w = add_bias(a.w, bias.w, relu);
|
|
// return dst;
|
|
// }
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// clamp float +inf/-inf
|
|
static inline __device__ float satfinite(float x)
|
|
{
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860
|
|
// bit representation of maximum value of float
|
|
uint32_t clamp_value = 0x7f7fffffu;
|
|
asm volatile("min.xorsign.abs.f32 %0, %0, %1;" : "+f"(x) : "r"(clamp_value));
|
|
return x;
|
|
#else
|
|
// bit representation of maximum and minimum value of float
|
|
uint32_t umax = 0x7f7fffffu;
|
|
uint32_t umin = 0xff7fffffu;
|
|
float out;
|
|
asm volatile("min.f32 %0, %1, %2;" : "=f"(out) : "f"(x), "r"(umax));
|
|
asm volatile("max.f32 %0, %0, %1;" : "+f"(out) : "r"(umin));
|
|
return out;
|
|
#endif
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// clamp half2 +inf/-inf
|
|
static inline __device__ uint32_t satfinite_h2(uint32_t h2)
|
|
{
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860
|
|
uint32_t out, clamp_value;
|
|
clamp_value = 0x7bff7bffu;
|
|
asm volatile("min.xorsign.abs.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(clamp_value));
|
|
return out;
|
|
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
|
|
// bit representation of maximum and minimum value of half2
|
|
uint32_t umax = 0x7bff7bffu;
|
|
uint32_t umin = 0xfbfffbffu;
|
|
uint32_t out;
|
|
asm volatile("min.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(umax));
|
|
asm volatile("max.f16x2 %0, %0, %1;" : "+r"(out) : "r"(umin));
|
|
return out;
|
|
#else
|
|
// Take the absolute value of h2. It should map to |Rx| in SASS.
|
|
uint32_t p2;
|
|
asm volatile("abs.f16x2 %0, %1;" : "=r"(p2) : "r"(h2));
|
|
|
|
// Compute a mask for each fp16: 0xffff if +INF and 0x0000 otherwise.
|
|
uint32_t inf2 = 0x7c007c00u;
|
|
uint32_t mask;
|
|
asm volatile("set.eq.u32.f16x2 %0, %1, %2;" : "=r"(mask) : "r"(p2), "r"(inf2));
|
|
|
|
// Recreate the new value. 0x7bff is the max value for FP16.
|
|
p2 = (~mask & p2) | (mask & 0x7bff7bff);
|
|
|
|
// Simply re-add the sign and we're done.
|
|
return p2 | (h2 & 0x80008000);
|
|
#endif
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T>
|
|
static inline __device__ T clamp(T x, T lb, T ub)
|
|
{
|
|
return x < lb ? lb : (x > ub ? ub : x);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ float custom_exp2f(float x, float scale, float scaled_max)
|
|
{
|
|
float d1, d2;
|
|
asm("fma.rz.ftz.f32 %0, %1, %2, %3;" : "=f"(d1) : "f"(x), "f"(scale), "f"(-scaled_max));
|
|
asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(d2) : "f"(d1));
|
|
return d2;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint16_t clamp_to_zero(uint16_t x)
|
|
{
|
|
uint16_t mask;
|
|
asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
|
|
return mask & x;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint16_t float_to_half(float f)
|
|
{
|
|
uint16_t h;
|
|
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
|
|
return h;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ bf16_t float_to_bf16(float f)
|
|
{
|
|
return __float2bfloat16(f);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t float2_to_half2(float a, float b)
|
|
{
|
|
uint32_t c;
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
|
|
#else
|
|
uint16_t lo = float_to_half(a);
|
|
uint16_t hi = float_to_half(b);
|
|
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
|
|
#endif
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t float2_to_bf16_x2(float a, float b)
|
|
{
|
|
uint32_t c;
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
|
|
#else
|
|
uint16_t* px = reinterpret_cast<uint16_t*>(&a);
|
|
uint16_t* py = reinterpret_cast<uint16_t*>(&b);
|
|
uint16_t value = px[1];
|
|
uint16_t value2 = py[1];
|
|
|
|
if (px[0] == 0x8000)
|
|
{
|
|
if ((value & 0x1) == 1)
|
|
value++;
|
|
}
|
|
else if (px[0] > 0x8000)
|
|
{
|
|
value++;
|
|
}
|
|
|
|
if (py[0] == 0x8000)
|
|
{
|
|
if ((value2 & 0x1) == 1)
|
|
value2++;
|
|
}
|
|
else if (py[0] > 0x8000)
|
|
{
|
|
value2++;
|
|
}
|
|
|
|
uint32_t high = reinterpret_cast<uint32_t&>(value2);
|
|
c = (high << 16) | value;
|
|
#endif
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Template function to support both half and bfloat16
|
|
template <typename Data_type>
|
|
inline __device__ uint32_t float2_to_16bit_2(float a, float b)
|
|
{
|
|
return float2_to_half2(a, b);
|
|
}
|
|
|
|
template <>
|
|
inline __device__ uint32_t float2_to_16bit_2<bf16_t>(float a, float b)
|
|
{
|
|
return float2_to_bf16_x2(a, b);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t float_to_half2(float a)
|
|
{
|
|
return float2_to_half2(a, a);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t float2_to_half2(float2 const& f)
|
|
{
|
|
return float2_to_half2(f.x, f.y);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t float_to_bf16_2(float a)
|
|
{
|
|
return float2_to_bf16_x2(a, a);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w)
|
|
{
|
|
uint2 d;
|
|
d.x = float2_to_half2(x, y);
|
|
d.y = float2_to_half2(z, w);
|
|
return d;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Template function to support both half and bfloat16
|
|
template <typename Data_type>
|
|
inline __device__ uint2 float4_to_16bit_x4(float x, float y, float z, float w)
|
|
{
|
|
uint2 d;
|
|
d.x = float2_to_half2(x, y);
|
|
d.y = float2_to_half2(z, w);
|
|
return d;
|
|
}
|
|
|
|
template <>
|
|
inline __device__ uint2 float4_to_16bit_x4<bf16_t>(float x, float y, float z, float w)
|
|
{
|
|
uint2 d;
|
|
d.x = float2_to_bf16_x2(x, y);
|
|
d.y = float2_to_bf16_x2(z, w);
|
|
return d;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c)
|
|
{
|
|
uint32_t d;
|
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
|
return d;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c)
|
|
{
|
|
uint32_t d;
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
|
#else
|
|
d = hrelu2(hfma2(a, b, c));
|
|
#endif
|
|
return d;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t h0_h0(uint32_t x)
|
|
{
|
|
uint32_t y;
|
|
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" : "=r"(y) : "r"(x));
|
|
return y;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ float h0_to_float(uint32_t h2)
|
|
{
|
|
float f;
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .f16 lo, hi;\n"
|
|
"mov.b32 {lo, hi}, %1;\n"
|
|
"cvt.f32.f16 %0, lo;\n"
|
|
"}\n"
|
|
: "=f"(f)
|
|
: "r"(h2));
|
|
return f;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t h1_h1(uint32_t x)
|
|
{
|
|
uint32_t y;
|
|
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" : "=r"(y) : "r"(x));
|
|
return y;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint16_t hadd(uint16_t a, uint16_t b)
|
|
{
|
|
uint16_t d;
|
|
asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
|
|
return d;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint32_t hadd(uint32_t a, uint32_t b)
|
|
{
|
|
return hadd2(a, b);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint2 hadd4(uint2 a, uint2 b)
|
|
{
|
|
uint2 c;
|
|
c.x = hadd2(a.x, b.x);
|
|
c.y = hadd2(a.y, b.y);
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint2 hadd(uint2 a, uint2 b)
|
|
{
|
|
return hadd4(a, b);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint4 hadd8(uint4 a, uint4 b)
|
|
{
|
|
uint4 c;
|
|
c.x = hadd2(a.x, b.x);
|
|
c.y = hadd2(a.y, b.y);
|
|
c.z = hadd2(a.z, b.z);
|
|
c.w = hadd2(a.w, b.w);
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Template function to support both half and bfloat16
|
|
template <typename Data_type>
|
|
inline __device__ uint4 add8(uint4 a, uint4 b)
|
|
{
|
|
return hadd8(a, b);
|
|
}
|
|
|
|
template <>
|
|
inline __device__ uint4 add8<bf16_t>(uint4 a, uint4 b)
|
|
{
|
|
uint4 c;
|
|
c.x = bfadd2(a.x, b.x);
|
|
c.y = bfadd2(a.y, b.y);
|
|
c.z = bfadd2(a.z, b.z);
|
|
c.w = bfadd2(a.w, b.w);
|
|
return c;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint4 fadd4(uint4 a, uint4 b)
|
|
{
|
|
float4 c;
|
|
c.x = reinterpret_cast<float const&>(a.x) + reinterpret_cast<float const&>(b.x);
|
|
c.y = reinterpret_cast<float const&>(a.y) + reinterpret_cast<float const&>(b.y);
|
|
c.z = reinterpret_cast<float const&>(a.z) + reinterpret_cast<float const&>(b.z);
|
|
c.w = reinterpret_cast<float const&>(a.w) + reinterpret_cast<float const&>(b.w);
|
|
return reinterpret_cast<uint4 const&>(c);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint4 hadd(uint4 a, uint4 b)
|
|
{
|
|
return hadd8(a, b);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ float half_to_float(uint16_t h)
|
|
{
|
|
float f;
|
|
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
|
|
return f;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ float bf16_to_float(uint16_t h)
|
|
{
|
|
float f;
|
|
asm volatile("mov.b32 %0, {0, %1};\n" : "=f"(f) : "h"(h));
|
|
return f;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ float2 half2_to_float2(uint32_t x)
|
|
{
|
|
uint16_t lo, hi;
|
|
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
|
|
return make_float2(half_to_float(lo), half_to_float(hi));
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ float2 bf16_2_to_float2(uint32_t x)
|
|
{
|
|
float2 res;
|
|
asm volatile(
|
|
"{\n"
|
|
" .reg .b16 lo, hi;\n"
|
|
" mov.b32 {lo, hi}, %2;\n"
|
|
" mov.b32 %0, {0, lo};\n"
|
|
" mov.b32 %1, {0, hi};\n"
|
|
"}\n"
|
|
: "=f"(res.x), "=f"(res.y)
|
|
: "r"(x));
|
|
return res;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Template function to support both half and bfloat16
|
|
template <typename Data_type>
|
|
inline __device__ float2 convert_from_16bit_2(uint32_t x)
|
|
{
|
|
return half2_to_float2(x);
|
|
}
|
|
|
|
template <>
|
|
inline __device__ float2 convert_from_16bit_2<bf16_t>(uint32_t x)
|
|
{
|
|
return bf16_2_to_float2(x);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ void half2_to_float2(float& x, float& y, uint32_t h)
|
|
{
|
|
float2 tmp = half2_to_float2(h);
|
|
x = tmp.x;
|
|
y = tmp.y;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c)
|
|
{
|
|
uint16_t d;
|
|
asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
|
|
return d;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ uint16_t hmul(uint16_t a, uint16_t b)
|
|
{
|
|
uint16_t d;
|
|
asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
|
|
return d;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Converted two half2's or bf162's into float, then take their dot product.
|
|
template <typename Data_type>
|
|
inline __device__ float fma2_in_float(uint32_t const a, uint32_t const b)
|
|
{
|
|
float2 af = fmha::convert_from_16bit_2<Data_type>(a);
|
|
float2 bf = fmha::convert_from_16bit_2<Data_type>(b);
|
|
return af.x * bf.x + af.y * bf.y;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
|
|
template <typename Data_type>
|
|
inline __device__ float fma8_in_float(uint4 const a, uint4 const b)
|
|
{
|
|
float sum;
|
|
sum = fmha::fma2_in_float<Data_type>(a.x, b.x);
|
|
sum += fmha::fma2_in_float<Data_type>(a.y, b.y);
|
|
sum += fmha::fma2_in_float<Data_type>(a.z, b.z);
|
|
sum += fmha::fma2_in_float<Data_type>(a.w, b.w);
|
|
return sum;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static inline __device__ float sigmoid(float x)
|
|
{
|
|
return 1.f / (1.f + expf(-x));
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void clear(uint16_t& dst)
|
|
{
|
|
dst = uint16_t(0);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void clear(uint32_t& dst)
|
|
{
|
|
dst = 0u;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void clear(uint2& dst)
|
|
{
|
|
dst = make_uint2(0u, 0u);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void clear(uint4& dst)
|
|
{
|
|
dst = make_uint4(0u, 0u, 0u, 0u);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// P R E D I C A T E P A C K I N G
|
|
//
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
enum
|
|
{
|
|
BYTES_PER_REG = 4,
|
|
PREDS_PER_BYTE = 4,
|
|
PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int LDGS>
|
|
struct Compute_number_of_pred_regs
|
|
{
|
|
enum
|
|
{
|
|
VALUE = Div_up<LDGS, PREDS_PER_REG>::VALUE
|
|
};
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int M, int N>
|
|
inline __device__ void pack_predicates(uint32_t (&preds)[M], uint32_t const (&p)[N])
|
|
{
|
|
|
|
// Make sure the values match.
|
|
static_assert(Compute_number_of_pred_regs<N>::VALUE == M, "");
|
|
|
|
// The number of complete steps (where we use all the predicates in a byte).
|
|
enum
|
|
{
|
|
COMPLETE_BYTES = N / PREDS_PER_BYTE
|
|
};
|
|
|
|
// Make sure we allocated enough predicate registers.
|
|
static_assert(Div_up<COMPLETE_BYTES, BYTES_PER_REG>::VALUE <= M, "");
|
|
|
|
// The remainder.
|
|
enum
|
|
{
|
|
REMAINDER = N - COMPLETE_BYTES * PREDS_PER_BYTE
|
|
};
|
|
|
|
// Make sure we got the math right and the remainder is between 0 and 3.
|
|
static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
|
|
|
|
// The mask to extract the predicates.
|
|
enum
|
|
{
|
|
COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1
|
|
};
|
|
|
|
// Run complete steps.
|
|
#pragma unroll
|
|
for (int ii = 0; ii < M; ++ii)
|
|
{
|
|
|
|
// The number of complete bytes for that register. Be careful it can be > than 4 ;)
|
|
int const COMPLETE = (N - ii * PREDS_PER_REG) / PREDS_PER_BYTE;
|
|
|
|
// Pack the predicates in a register.
|
|
uint32_t reg = 0u;
|
|
#pragma unroll
|
|
for (int jj = 0; jj < 4; ++jj)
|
|
{
|
|
|
|
// Early exit.
|
|
if (jj >= COMPLETE)
|
|
{
|
|
break;
|
|
}
|
|
|
|
// Prepare the array of predicates.
|
|
bool tmp[PREDS_PER_BYTE];
|
|
#pragma unroll
|
|
for (int kk = 0; kk < PREDS_PER_BYTE; ++kk)
|
|
{
|
|
tmp[kk] = p[ii * PREDS_PER_REG + jj * PREDS_PER_BYTE + kk] != 0;
|
|
}
|
|
|
|
// Store the predicates.
|
|
#pragma unroll
|
|
for (int kk = 0; kk < PREDS_PER_BYTE; ++kk)
|
|
{
|
|
if (tmp[kk])
|
|
{
|
|
reg |= 1u << (jj * 8 + kk);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Skip the rest of the code if we do not have a remainder.
|
|
if (COMPLETE < 4 && REMAINDER > 0)
|
|
{
|
|
|
|
// The mask to extract the predicates.
|
|
enum
|
|
{
|
|
REMAINDER_MASK = (1 << REMAINDER) - 1
|
|
};
|
|
|
|
// Prepare the array of predicates.
|
|
bool tmp[PREDS_PER_BYTE];
|
|
#pragma unroll
|
|
for (int jj = 0; jj < REMAINDER; ++jj)
|
|
{
|
|
tmp[jj] = p[COMPLETE_BYTES * PREDS_PER_BYTE + jj] != 0;
|
|
}
|
|
|
|
// Store the predicates.
|
|
#pragma unroll
|
|
for (int jj = 0; jj < REMAINDER; ++jj)
|
|
{
|
|
if (tmp[jj])
|
|
{
|
|
reg |= 1u << (COMPLETE * 8 + jj);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Store the predicate register.
|
|
preds[ii] = reg;
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N>
|
|
inline __device__ uint32_t pack_predicates(uint32_t const (&p)[N])
|
|
{
|
|
uint32_t tmp[1];
|
|
pack_predicates(tmp, p);
|
|
return tmp[0];
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// G E N E R I C P R E D I C A T E D L D G S T S
|
|
//
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N, int M, typename Functor>
|
|
inline __device__ void ldgsts_(Functor& fct, uint32_t const (&preds)[M])
|
|
{
|
|
|
|
// The number of complete bytes (where we use all the predicates in a byte).
|
|
enum
|
|
{
|
|
COMPLETE = N / PREDS_PER_BYTE
|
|
};
|
|
|
|
// Make sure we did allocate enough predicates.
|
|
static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, "");
|
|
|
|
// The remainder.
|
|
enum
|
|
{
|
|
REMAINDER = N - COMPLETE * PREDS_PER_BYTE
|
|
};
|
|
|
|
// Make sure we got the math right and the remainder is between 0 and 3.
|
|
static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
|
|
|
|
// The mask to extract the predicates.
|
|
enum
|
|
{
|
|
COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1
|
|
};
|
|
|
|
// Clear the fetch registers.
|
|
#pragma unroll
|
|
for (int ii = 0; ii < N; ++ii)
|
|
{
|
|
fct.clear(ii);
|
|
}
|
|
|
|
// Run complete steps.
|
|
bool p[PREDS_PER_BYTE];
|
|
#pragma unroll
|
|
for (int ii = 0; ii < COMPLETE; ++ii)
|
|
{
|
|
|
|
// The predicate.
|
|
uint32_t reg = preds[ii / BYTES_PER_REG];
|
|
|
|
// Extract the predicates.
|
|
#pragma unroll
|
|
for (int jj = 0; jj < PREDS_PER_BYTE; ++jj)
|
|
{
|
|
uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
|
|
p[jj] = (reg & mask) != 0u;
|
|
}
|
|
|
|
// Issue the loads.
|
|
#pragma unroll
|
|
for (int jj = 0; jj < PREDS_PER_BYTE; ++jj)
|
|
{
|
|
fct.ldgsts(ii * PREDS_PER_BYTE + jj, p[jj]);
|
|
}
|
|
}
|
|
|
|
// Skip the rest of the code if we do not have a remainder.
|
|
if (REMAINDER > 0)
|
|
{
|
|
|
|
// The mask to extract the predicates.
|
|
enum
|
|
{
|
|
REMAINDER_MASK = (1 << REMAINDER) - 1
|
|
};
|
|
|
|
// The predicate register.
|
|
uint32_t reg = preds[COMPLETE / BYTES_PER_REG];
|
|
|
|
// Extract the predicates.
|
|
#pragma unroll
|
|
for (int jj = 0; jj < PREDS_PER_BYTE; ++jj)
|
|
{
|
|
uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
|
|
p[jj] = (reg & mask) != 0u;
|
|
}
|
|
|
|
// Issue the loads.
|
|
#pragma unroll
|
|
for (int ii = 0; ii < REMAINDER; ++ii)
|
|
{
|
|
fct.ldgsts(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
|
|
}
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int M, typename Functor>
|
|
inline __device__ void ldgsts_(Functor& fct, uint32_t preds)
|
|
{
|
|
uint32_t tmp[1] = {preds};
|
|
ldgsts_<M>(fct, tmp);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// L D G
|
|
//
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void ldg(uint8_t& dst, void const* ptr)
|
|
{
|
|
dst = *reinterpret_cast<uint8_t const*>(ptr);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void ldg(uint16_t& dst, void const* ptr)
|
|
{
|
|
dst = *reinterpret_cast<uint16_t const*>(ptr);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void ldg(uint32_t& dst, void const* ptr)
|
|
{
|
|
dst = *reinterpret_cast<uint32_t const*>(ptr);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void ldg(uint2& dst, void const* ptr)
|
|
{
|
|
dst = *reinterpret_cast<uint2 const*>(ptr);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void ldg(uint4& dst, void const* ptr)
|
|
{
|
|
dst = *reinterpret_cast<uint4 const*>(ptr);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename Data_type, int N>
|
|
struct Ldg_functor
|
|
{
|
|
// Ctor.
|
|
inline __device__ Ldg_functor(Data_type (&fetch)[N], void const* (&ptrs)[N])
|
|
: fetch_(fetch)
|
|
, ptrs_(ptrs)
|
|
{
|
|
}
|
|
|
|
// Clear the element.
|
|
inline __device__ void clear(int ii)
|
|
{
|
|
fmha::clear(fetch_[ii]);
|
|
}
|
|
|
|
// Trigger the loads.
|
|
inline __device__ void ldgsts(int ii, bool p)
|
|
{
|
|
if (p)
|
|
{
|
|
ldg(fetch_[ii], ptrs_[ii]);
|
|
}
|
|
}
|
|
|
|
// The fetch registers.
|
|
Data_type (&fetch_)[N];
|
|
// The pointers.
|
|
void const* (&ptrs_)[N];
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename Data_type, int N, int M>
|
|
inline __device__ void ldg_(Data_type (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
|
|
{
|
|
Ldg_functor<Data_type, N> fct(fetch, ptrs);
|
|
ldgsts_<N>(fct, preds);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N, int M>
|
|
inline __device__ void ldg(uint8_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
|
|
{
|
|
ldg_<uint8_t, N>(fetch, ptrs, preds);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N, int M>
|
|
inline __device__ void ldg(uint16_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
|
|
{
|
|
ldg_<uint16_t, N>(fetch, ptrs, preds);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N, int M>
|
|
inline __device__ void ldg(uint32_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
|
|
{
|
|
ldg_<uint32_t, N>(fetch, ptrs, preds);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N, int M>
|
|
inline __device__ void ldg(uint2 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
|
|
{
|
|
ldg_<uint2, N>(fetch, ptrs, preds);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N, int M>
|
|
inline __device__ void ldg(uint4 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M])
|
|
{
|
|
ldg_<uint4, N>(fetch, ptrs, preds);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <bool USE_LDGSTS>
|
|
inline __device__ void ldgdepbar()
|
|
{
|
|
if (USE_LDGSTS)
|
|
{
|
|
asm volatile("cp.async.commit_group;\n" ::);
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <bool USE_LDGSTS, int COUNT = 0>
|
|
inline __device__ void depbar_()
|
|
{
|
|
if (USE_LDGSTS)
|
|
{
|
|
asm volatile("cp.async.wait_group %0;\n" ::"n"(COUNT));
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <bool USE_LDGSTS, int STAGES>
|
|
inline __device__ void depbar()
|
|
{
|
|
if (USE_LDGSTS)
|
|
{
|
|
int const VALUE = Max<STAGES - 2, 0>::VALUE;
|
|
asm volatile("cp.async.wait_group %0;\n" ::"n"(VALUE));
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void ldgsts128(uint32_t dst, void const* src, bool p = true)
|
|
{
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
uint32_t m = p ? 16u : 0u;
|
|
asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"r"(dst), "l"(src), "r"(m));
|
|
#endif
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N>
|
|
struct Ldgsts_functor
|
|
{
|
|
// Ctor.
|
|
inline __device__ Ldgsts_functor(uint32_t (&smem_ptrs)[N], void const* (&gmem_ptrs)[N])
|
|
: smem_ptrs_(smem_ptrs)
|
|
, gmem_ptrs_(gmem_ptrs)
|
|
{
|
|
}
|
|
|
|
// Does nothing.
|
|
inline __device__ void clear(int ii) {}
|
|
|
|
// Trigger the load-store instruction.
|
|
inline __device__ void ldgsts(int ii, bool p)
|
|
{
|
|
ldgsts128(smem_ptrs_[ii], gmem_ptrs_[ii], p);
|
|
}
|
|
|
|
// The shared memory pointers.
|
|
uint32_t (&smem_ptrs_)[N];
|
|
// The global memory pointers.
|
|
void const* (&gmem_ptrs_)[N];
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <int N, int M>
|
|
inline __device__ void ldgsts(uint32_t (&dst)[N], void const* (&src)[N], uint32_t (&preds)[M])
|
|
{
|
|
Ldgsts_functor<N> fct(dst, src);
|
|
ldgsts_<N>(fct, preds);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// L D S
|
|
//
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void lds(uint16_t& dst, uint32_t ptr)
|
|
{
|
|
asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr));
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void lds(uint32_t& dst, uint32_t ptr)
|
|
{
|
|
asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr));
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void lds(uint2& dst, uint32_t ptr)
|
|
{
|
|
asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
inline __device__ void lds(uint4& dst, uint32_t ptr)
|
|
{
|
|
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
|
|
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w)
|
|
: "r"(ptr));
|
|
}
|
|
|
|

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm(uint32_t& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsmt(uint32_t& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm(uint2& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsmt(uint2& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
        : "=r"(dst.x), "=r"(dst.y)
        : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm(uint4& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w)
        : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsmt(uint4& dst, uint32_t ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w)
        : "r"(ptr));
#endif
}
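
// Note on ldmatrix addressing (sketch; see the PTX ISA for the authoritative description): for the
// .x4 variant every lane of the warp supplies the shared memory address of one 8x8 b16 matrix row
// (lanes 0-7 address the first 8x8 tile, lanes 8-15 the second, and so on), and each lane receives
// four 32-bit fragments laid out for the MMA instructions. The .trans variants transpose each 8x8
// tile while loading, which is what the *t helpers above expose.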

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stsm(uint32_t ptr, uint32_t const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stsmt(uint32_t ptr, uint32_t const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stsm(uint32_t ptr, uint2 const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), "r"(src.x), "r"(src.y));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stsmt(uint32_t ptr, uint2 const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), "r"(src.x), "r"(src.y));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stsm(uint32_t ptr, uint4 const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), "r"(src.x),
        "r"(src.y), "r"(src.z), "r"(src.w));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stsmt(uint32_t ptr, uint4 const& src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), "r"(src.x),
        "r"(src.y), "r"(src.z), "r"(src.w));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void* ptr, float val)
{
    *reinterpret_cast<float*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void* ptr, uint8_t val)
{
    *reinterpret_cast<uint8_t*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void* ptr, uint16_t val)
{
    *reinterpret_cast<uint16_t*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void* ptr, uint32_t val)
{
    *reinterpret_cast<uint32_t*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void* ptr, uint2 val)
{
    *reinterpret_cast<uint2*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void* ptr, uint4 val)
{
    *reinterpret_cast<uint4*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint16_t val)
{
    asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint32_t val)
{
    asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint2 val)
{
    asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" : : "r"(ptr), "r"(val.x), "r"(val.y));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint4 val)
{
    asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
        :
        : "r"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Data_type, int N>
inline __device__ void sts_(uint32_t (&ptrs)[N], Data_type const (&data)[N])
{
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        sts(ptrs[ii], data[ii]);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], uint16_t const (&data)[N])
{
    sts_<uint16_t, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], uint32_t const (&data)[N])
{
    sts_<uint32_t, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], uint2 const (&data)[N])
{
    sts_<uint2, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], uint4 const (&data)[N])
{
    sts_<uint4, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#define __HALF2_TO_UI(var) *(reinterpret_cast<unsigned int*>(&(var)))
#define __HALF2_TO_CUI(var) *(reinterpret_cast<const unsigned int*>(&(var)))

static __device__ __inline__ void atomicAdd_half2(half2* const address, const half2 val)
{
    asm volatile("{ red.global.add.noftz.f16x2 [%0],%1; }\n" ::"l"(address), "r"(__HALF2_TO_CUI(val)) : "memory");
}
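
// Note: red.global.add.noftz.f16x2 is a fire-and-forget reduction rather than a read-modify-write
// that returns the old value, so unlike CUDA's atomicAdd() this helper returns nothing. Both
// halves are updated as one packed f16x2 word, so the address must be 4-byte aligned.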

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool CAN_BE_NEGATIVE>
static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w)
{
#if defined(USE_F2I_EMULATION_TRICK)
    // Make sure the float is in the proper range.
    float cx, cy, cz, cw;
    if (CAN_BE_NEGATIVE)
    {
        cx = fmha::clamp(x, -128.f, 127.f);
        cy = fmha::clamp(y, -128.f, 127.f);
        cz = fmha::clamp(z, -128.f, 127.f);
        cw = fmha::clamp(w, -128.f, 127.f);
    }
    else
    {
        cx = fminf(x, 127.f);
        cy = fminf(y, 127.f);
        cz = fminf(z, 127.f);
        cw = fminf(w, 127.f);
    }

    // Re-add the magic number.
    cx += FP32_I2F_MAGIC_NUMBER;
    cy += FP32_I2F_MAGIC_NUMBER;
    cz += FP32_I2F_MAGIC_NUMBER;
    cw += FP32_I2F_MAGIC_NUMBER;

    // We need unsigned ints...
    uint32_t a = reinterpret_cast<uint32_t const&>(cx);
    uint32_t b = reinterpret_cast<uint32_t const&>(cy);
    uint32_t c = reinterpret_cast<uint32_t const&>(cz);
    uint32_t d = reinterpret_cast<uint32_t const&>(cw);

    // Pack the numbers.
    uint32_t dst;
    asm volatile("prmt.b32 %0, %1, %2, 0x0040;\n" : "=r"(dst) : "r"(a), "r"(b));
    asm volatile("prmt.b32 %0, %0, %1, 0x0410;\n" : "+r"(dst) : "r"(c));
    asm volatile("prmt.b32 %0, %0, %1, 0x4210;\n" : "+r"(dst) : "r"(d));
    return dst;
#else
    uint32_t a;
    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x));
    uint32_t b;
    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y));
    uint32_t c;
    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z));
    uint32_t d;
    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w));

    uint32_t dst;
    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c));
    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a));
    return dst;
#endif // defined(USE_F2I_EMULATION_TRICK)
}
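
// How the F2I emulation trick works (sketch): FP32_I2F_MAGIC_NUMBER is 1.5 * 2^23, bit pattern
// 0x4B400000. Adding it to a float holding a small integer n forces the sum to be represented with
// an exponent of 23, so n lands verbatim in the low mantissa bits (in two's complement for
// negative n). For example, 5.0f + 12582912.f = 12582917.f, whose bit pattern is 0x4B400005, i.e.
// the low byte is 0x05. The three prmt.b32 instructions then gather the low byte of each of the
// four sums into one packed 32-bit result.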

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void swizzle_rows(uint32_t& a, uint32_t& b, uint32_t c, uint32_t d)
{
    asm volatile("prmt.b32 %0, %1, %2, 0x6420;\n" : "=r"(a) : "r"(c), "r"(d));
    asm volatile("prmt.b32 %0, %1, %2, 0x7531;\n" : "=r"(b) : "r"(c), "r"(d));
}
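
// Byte mapping of the prmt selectors (sketch): 0x6420 writes a = {c.b0, c.b2, d.b0, d.b2}, the
// even bytes of the {c, d} pair, and 0x7531 writes b = {c.b1, c.b3, d.b1, d.b3}, the odd bytes,
// which effectively de-interleaves two rows of packed 8-bit values.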

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm_with_lds(uint2& data, uint32_t smem)
{
    int lane = threadIdx.x % 32;
    data = {0, 0};
    uint4 v = {0, 0, 0, 0};
    uint32_t* a = reinterpret_cast<uint32_t*>(&v);
    if (lane < 16)
    {
        fmha::lds(v, smem);
    }
    int src_row = lane / 4;
    int src_col = lane % 4;
    for (int it = 0; it < 4; it++)
    {
        uint32_t val = a[it];
        uint32_t x = __shfl_sync(uint32_t(-1), val, src_row);
        __syncwarp();
        uint32_t y = __shfl_sync(uint32_t(-1), val, src_row + 8);
        __syncwarp();
        if (it == src_col)
        {
            data.x = x;
            data.y = y;
        }
    }
}

inline __device__ void ldsmt_with_lds(uint2& data, uint32_t smem)
{
    int lane = threadIdx.x % 32;

    uint4 tmp16{0, 0, 0, 0}; // 16B

    if (lane < 16)
    {
        fmha::lds(tmp16, smem);
    }

    uint16_t* tmp16c = reinterpret_cast<uint16_t*>(&tmp16); // 8x2B: we move pairs

    uint16_t* t = reinterpret_cast<uint16_t*>(&data); // 4x2B

    int const src_col = lane / 4; // 0 - 7
    int const src_row = (lane % 4) * 2;

    // We have to shuffle the values to distribute them in the warp.
#pragma unroll
    for (int it = 0; it < 8; it++)
    {
        uint16_t val, x, y;
        val = tmp16c[it];
        x = __shfl_sync(uint32_t(-1), val, src_row + 0);
        __syncwarp();
        y = __shfl_sync(uint32_t(-1), val, src_row + 1);
        __syncwarp();

        if (src_col == it)
        {
            t[0] = x;
            t[1] = y;
        }
        val = tmp16c[it];
        x = __shfl_sync(uint32_t(-1), val, src_row + 8);
        __syncwarp();
        y = __shfl_sync(uint32_t(-1), val, src_row + 9);
        __syncwarp();

        if (src_col == it)
        {
            t[2] = x;
            t[3] = y;
        }
    }
}
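
// The two helpers above emulate the uint2 (.x2) flavors of ldsm()/ldsmt(): lanes 0-15 perform a
// plain 16B shared memory load, and warp shuffles then scatter the pieces to the lanes that the
// real ldmatrix instruction would have written, so the resulting register layout matches the
// hardware path. They are fallbacks, not replacements for the ldmatrix hot path.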

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct MaxOp
{
    __device__ inline T operator()(T const& x, T const& y)
    {
        return x > y ? x : y;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct SumOp
{
    __device__ inline T operator()(T const& x, T const& y)
    {
        return x + y;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int THREADS>
struct Allreduce
{
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);

    template <typename T, typename Operator>
    static __device__ inline T run(T x, Operator& op)
    {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Allreduce<2>
{
    template <typename T, typename Operator>
    static __device__ inline T run(T x, Operator& op)
    {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
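
// Illustrative usage (sketch): butterfly all-reduce of a per-thread value across groups of 4
// consecutive lanes, e.g. to compute a row-wise maximum that every lane of the quad ends up
// holding.
//
//   MaxOp<float> max_op;
//   float v = ...;                     // this thread's partial maximum
//   v = Allreduce<4>::run(v, max_op);  // all 4 lanes of the quad now hold the same maximum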

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator& op)
{
#pragma unroll
    for (int mi = 0; mi < M; mi++)
    {
        dst[mi] = src[mi];
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator& op)
{
    float tmp[M];
#pragma unroll
    for (int mi = 0; mi < M; mi++)
    {
        tmp[mi] = op(src[mi].x, src[mi].y);
    }
    quad_reduce(dst, tmp, op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator& op)
{
#pragma unroll
    for (int mi = 0; mi < M; mi++)
    {
        dst[mi] = src[mi];
        dst[mi] = Allreduce<4>::run(dst[mi], op);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator& op)
{
    float tmp[M];
#pragma unroll
    for (int mi = 0; mi < M; mi++)
    {
        tmp[mi] = op(src[mi].x, src[mi].y);
    }
    quad_allreduce(dst, tmp, op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t elect_one_sync()
{
    uint32_t pred = 0;
#if __CUDA_ARCH__ >= 900
#if !defined(__CUDACC_RTC__)
    uint32_t laneid = 0;
    asm volatile(
        "\n\
    {\n\
        .reg .b32 %rx;\n\
        .reg .pred %px;\n\
        elect.one.sync %rx|%px, %2;\n\
        @%px mov.s32 %1, 1;\n\
        mov.s32 %0, %rx;\n\
    }\n"
        : "+r"(laneid), "+r"(pred)
        : "r"(0xFFFFFFFF));
#else
    pred = threadIdx.x == 0;
#endif
#endif
    return pred;
}
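
// Illustrative usage (sketch): have exactly one lane of a fully active warp issue a per-warp
// operation, such as a TMA copy or an mbarrier arrive, instead of the usual lane-0 check.
//
//   if (elect_one_sync())
//   {
//       // ... executed by a single elected lane of the warp ...
//   }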

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint16_t float2_to_e4m3x2(float x, float y)
{
#if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900))
    uint16_t res;
    asm volatile("cvt.rn.e4m3x2.f32.satfinite %0, %2, %1;" : "=h"(res) : "f"(x), "f"(y));
    return res;
#else
    assert(false);
    return 0;
#endif
}
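
// Note on operand order (sketch; per the PTX ISA the f32-to-fp8x2 cvt packs its first source into
// the upper byte of the destination): passing "%2, %1", i.e. y first and x second, keeps x in the
// low byte and y in the high byte, which is the little-endian layout the packing helpers below
// assume.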

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t float4_to_e4m3x4(float x, float y, float z, float w)
{
#if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900))
    uint32_t res;
    asm volatile(
        "{\n"
        ".reg .b16 lo;\n"
        ".reg .b16 hi;\n"
        "cvt.rn.e4m3x2.f32.satfinite lo, %2, %1;\n"
        "cvt.rn.e4m3x2.f32.satfinite hi, %4, %3;\n"
        "mov.b32 %0, {lo, hi};\n"
        "}"
        : "=r"(res)
        : "f"(x), "f"(y), "f"(z), "f"(w));
    return res;
#else
    assert(false);
    return 0;
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t float4_to_e5m2x4(float x, float y, float z, float w)
{
#if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900))
    uint32_t res;
    asm volatile(
        "{\n"
        ".reg .b16 lo;\n"
        ".reg .b16 hi;\n"
        "cvt.rn.e5m2x2.f32.satfinite lo, %2, %1;\n"
        "cvt.rn.e5m2x2.f32.satfinite hi, %4, %3;\n"
        "mov.b32 %0, {lo, hi};\n"
        "}"
        : "=r"(res)
        : "f"(x), "f"(y), "f"(z), "f"(w));
    return res;
#else
    assert(false);
    return 0;
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t half4_to_e4m3x4(uint32_t const h2_0, uint32_t const h2_1)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
    uint32_t res;
    asm volatile(
        "{\n"
        ".reg .b16 lo, hi;\n"
        "cvt.satfinite.rn.e4m3x2.f16x2 lo, %1;\n"
        "cvt.satfinite.rn.e4m3x2.f16x2 hi, %2;\n"
        "mov.b32 %0, {lo, hi};\n"
        "}\n"
        : "=r"(res)
        : "r"(h2_0), "r"(h2_1));
    return res;
#else
    assert(false);
    return 0;
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t half4_to_e5m2x4(uint32_t const h2_0, uint32_t const h2_1)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
    uint32_t res;
    asm volatile(
        "{\n"
        ".reg .b16 lo, hi;\n"
        "cvt.satfinite.rn.e5m2x2.f16x2 lo, %1;\n"
        "cvt.satfinite.rn.e5m2x2.f16x2 hi, %2;\n"
        "mov.b32 %0, {lo, hi};\n"
        "}\n"
        : "=r"(res)
        : "r"(h2_0), "r"(h2_1));
    return res;
#else
    assert(false);
    return 0;
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helpers to pack a float4 into a destination register holding four 8-bit values.
template <typename Dst_type>
inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, float const w)
{
    return float4_to_char4<false>(x, y, z, w);
};

template <>
inline __device__ uint32_t float4_to_8bitx4<e4m3_t>(float const x, float const y, float const z, float const w)
{
    return float4_to_e4m3x4(x, y, z, w);
};

template <>
inline __device__ uint32_t float4_to_8bitx4<e5m2_t>(float const x, float const y, float const z, float const w)
{
    return float4_to_e5m2x4(x, y, z, w);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1);

template <>
inline __device__ uint32_t half4_to_fp8x4<fmha::e4m3_t>(uint32_t const h2_0, uint32_t const h2_1)
{
    return half4_to_e4m3x4(h2_0, h2_1);
}

template <>
inline __device__ uint32_t half4_to_fp8x4<fmha::e5m2_t>(uint32_t const h2_0, uint32_t const h2_1)
{
    return half4_to_e5m2x4(h2_0, h2_1);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
inline __device__ uint32_t float4_to_fp8x4(float const, float const, float const, float const);

template <>
inline __device__ uint32_t float4_to_fp8x4<fmha::e4m3_t>(float const x, float const y, float const z, float const w)
{
    return float4_to_e4m3x4(x, y, z, w);
}

template <>
inline __device__ uint32_t float4_to_fp8x4<fmha::e5m2_t>(float const x, float const y, float const z, float const w)
{
    return float4_to_e5m2x4(x, y, z, w);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void fence_view_async_shared()
{
    // Issue a shared memory fence for async operations (FENCE.VIEW.ASYNC.S).
    // Only compiles on sm90+.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("fence.proxy.async.shared::cta;\n");
#else
    assert(false);
#endif
}

inline __device__ void fence_view_async_global()
{
    // Issue a global memory fence for async operations (FENCE.VIEW.ASYNC.G).
    // Only compiles on sm90+.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("fence.proxy.async.global::cta;\n");
#else
    assert(false);
#endif
}
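
// Illustrative usage (sketch): these fences order generic-proxy accesses (regular loads/stores)
// against async-proxy accesses. A typical pattern is to call fence_view_async_shared() after
// initializing an mbarrier or after staging data in shared memory that a TMA store will read,
// and before the async operation is issued.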

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ char* align_1024(char* ptr)
{
    uint64_t address_bit = reinterpret_cast<uint64_t>(ptr);
    uint64_t offset = address_bit % 1024;
    if (offset == 0)
    {
        return ptr;
    }
    else
    {
        return ptr + (1024 - offset);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float atomicMaxFloat(float* addr, float value)
{
    float old;
    old = (value >= 0) ? __int_as_float(atomicMax((int*) addr, __float_as_int(value)))
                       : __uint_as_float(atomicMin((unsigned int*) addr, __float_as_uint(value)));
    return old;
}
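
// Why this works (sketch): for non-negative floats the IEEE-754 bit pattern orders like a signed
// int, so atomicMax on the int view yields the float max; for negative values the unsigned bit
// pattern grows with magnitude, so atomicMin on the unsigned view keeps the larger (less negative)
// float. The two branches also interact correctly when the stored value and the update have
// different signs; NaN inputs, as usual for this trick, are not handled.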

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float atomicMaxFloatPos_(float* addr, float value)
{
    // VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION.
    float old = __int_as_float(atomicMax((int*) addr, __float_as_int(value)));
    return old;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float max3Pos_(float const a, float const b, float const c)
{
    // VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION.
    float res;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    int32_t a_ = reinterpret_cast<int32_t const&>(a);
    int32_t b_ = reinterpret_cast<int32_t const&>(b);
    int32_t c_ = reinterpret_cast<int32_t const&>(c);
    int32_t tmp;
    asm volatile("max.s16x2 %0, %1, %2;\n" : "=r"(tmp) : "r"(a_), "r"(b_));
    asm volatile("max.s16x2 %0, %0, %1;\n" : "+r"(tmp) : "r"(tmp), "r"(c_));
    res = reinterpret_cast<float const&>(tmp);
#else
    res = fmaxf(a, fmaxf(b, c));
#endif
    return res;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Fast approximate tanh.
static inline __device__ float __tanhf(float x)
{
#if (__CUDA_ARCH__ >= 750)
    float r = x;
    asm("tanh.approx.f32 %0, %0;" : "+f"(r));
    return r;
#else
    return tanhf(x);
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace fmha