/* * SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #if defined(__CLANGD__) #include <__clang_cuda_builtin_vars.h> #include <__clang_cuda_math.h> #endif #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #include #endif // include warpgroup related instructions, used by SM90. #include // include gmma related instructions, used by SM90. #include // include tma related instructions, used by SM90. #include #include "fmha/numeric_types.h" #define FP32_I2F_MAGIC_NUMBER 12582912.f #define FP32_I2F_MAGIC_NUMBER_HEX 0x4b400000 extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void* ptr); //////////////////////////////////////////////////////////////////////////////////////////////////// namespace introspection { template struct Unpack; template struct Unpack { // if we simply static_assert(false) then compiler will not emit template params upon failure static_assert(N < INT_MIN, ""); using Type = std::integral_constant; }; template struct Unpack { using Type = Unpack; using Unpack_first = typename Unpack::Type; using Unpack_remaining = typename Unpack::Type; }; } // namespace introspection // Example usage: // // Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER> foo; // // or // // Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER>{}.foo(); // // Output by nvcc: // // ./src/fmha/gmem_tile_qkv_packed.h(70): error: static assertion failed with "" // detected during: // instantiation of class "fmha::v2::Unpack [with N=1]" // (77): here // instantiation of class "fmha::v2::Unpack [with N=1, Ns=<2, 0>]" // (84): here // instantiation of class "fmha::v2::Inspect_ns [with Ns=<1, 2, 0>]" // (143): here template struct Inspect_ns { using Type = typename introspection::Unpack::Type; }; // Can be used alongside with static_assert() to figure out the conditions when assertion failed // Example: // // Cond_inspect_ns< (int)ROWS >= (int)ROWS_PER_LDG, ROWS, ROWS_PER_LDG> foo; // // Output by nvcc (when condition is not met): // // ./src/fmha/utils.h(163): error: static assertion failed with "" // detected during: // instantiation of class "Cond_inspect_ns [with COND=false, Ns=<32, 64>]" template struct Cond_inspect_ns { static_assert(COND, ""); }; // Example: // // Inspect_type{}.foo(); // // or // // Inspect_type foo; // // Output by nvcc: // // ./src/fmha/utils.h(189): error: class "fmha::Ampere_hmma_tile, 16>" has no member "Dummy" // detected during: // instantiation of class "Inspect_type [with T=fmha::Ampere_hmma_tile, 16>]" template struct Inspect_type { // Purposefully trigger error by referencing non-existent T::Dummy using Dummy = typename T::Dummy; }; //////////////////////////////////////////////////////////////////////////////////////////////////// namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Row { static constexpr bool COL = false; static constexpr bool ROW = true; }; struct Col { static constexpr bool COL = true; static constexpr bool ROW = false; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Round_up { enum { VALUE = (M + N - 1) / N * N }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Tile_nhw { enum { N = N_, H = H_, W = W_ }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Next_power_of_two { }; template struct Next_power_of_two { enum { VALUE = M }; }; template <> struct Next_power_of_two<3, false> { enum { VALUE = 4 }; }; template <> struct Next_power_of_two<5, false> { enum { VALUE = 8 }; }; template <> struct Next_power_of_two<6, false> { enum { VALUE = 8 }; }; template <> struct Next_power_of_two<7, false> { enum { VALUE = 8 }; }; template <> struct Next_power_of_two<9, false> { enum { VALUE = 16 }; }; template <> struct Next_power_of_two<10, false> { enum { VALUE = 16 }; }; template <> struct Next_power_of_two<11, false> { enum { VALUE = 16 }; }; template <> struct Next_power_of_two<12, false> { enum { VALUE = 16 }; }; template <> struct Next_power_of_two<13, false> { enum { VALUE = 16 }; }; template <> struct Next_power_of_two<14, false> { enum { VALUE = 16 }; }; template <> struct Next_power_of_two<15, false> { enum { VALUE = 16 }; }; template <> struct Next_power_of_two<24, false> { enum { VALUE = 32 }; }; template <> struct Next_power_of_two<40, false> { enum { VALUE = 64 }; }; template <> struct Next_power_of_two<48, false> { enum { VALUE = 64 }; }; template <> struct Next_power_of_two<72, false> { enum { VALUE = 128 }; }; template <> struct Next_power_of_two<80, false> { enum { VALUE = 128 }; }; template <> struct Next_power_of_two<96, false> { enum { VALUE = 128 }; }; template <> struct Next_power_of_two<104, false> { enum { VALUE = 128 }; }; template <> struct Next_power_of_two<112, false> { enum { VALUE = 128 }; }; template <> struct Next_power_of_two<144, false> { enum { VALUE = 256 }; }; template <> struct Next_power_of_two<160, false> { enum { VALUE = 256 }; }; template <> struct Next_power_of_two<192, false> { enum { VALUE = 256 }; }; template <> struct Next_power_of_two<576, false> { enum { VALUE = 1024 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Prev_power_of_two { }; template struct Prev_power_of_two { enum { VALUE = N }; }; template <> struct Prev_power_of_two<3, false> { enum { VALUE = 2 }; }; template <> struct Prev_power_of_two<5, false> { enum { VALUE = 4 }; }; template <> struct Prev_power_of_two<6, false> { enum { VALUE = 4 }; }; template <> struct Prev_power_of_two<7, false> { enum { VALUE = 4 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Compute_skew { // The size of a transaction. enum { BYTES_PER_TRX = 128 }; // The remainder of the row without skew. enum { REMAINDER = BYTES_PER_ROW % BYTES_PER_TRX }; // The value. enum { VALUE = REMAINDER <= SKEW ? SKEW - REMAINDER : BYTES_PER_TRX + SKEW - REMAINDER }; // Make sure the math works ;) static_assert((BYTES_PER_ROW + VALUE) % BYTES_PER_TRX == SKEW, ""); }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Compute_skew { // No skew! enum { VALUE = 0 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Div_up { enum { VALUE = (M + N - 1) / N }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Max { enum { VALUE = A >= B ? A : B }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Max_3 { enum { VALUE = Max::VALUE, C>::VALUE }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Min { enum { VALUE = A <= B ? A : B }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Uint_from_size_in_bytes { }; //////////////////////////////////////////////////////////////////////////////////////////////////// template <> struct Uint_from_size_in_bytes<1> { using Type = uint8_t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template <> struct Uint_from_size_in_bytes<2> { using Type = uint16_t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template <> struct Uint_from_size_in_bytes<4> { using Type = uint32_t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template <> struct Uint_from_size_in_bytes<8> { using Type = uint2; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template <> struct Uint_from_size_in_bytes<16> { using Type = uint4; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Warp_masks { }; //////////////////////////////////////////////////////////////////////////////////////////////////// template <> struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; }; template <> struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; }; template <> struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; }; template <> struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; }; template <> struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; }; template <> struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; }; template <> struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; }; template <> struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; }; template <> struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; }; template <> struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; }; template <> struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; }; template <> struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; }; template <> struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; }; template <> struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; }; template <> struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; }; template <> struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; }; template <> struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ __host__ T div_up(T m, T n) { return (m + n - 1) / n; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline int clz(int x) { for (int i = 31; i >= 0; --i) { if ((1 << i) & x) { return 31 - i; } } return 32; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline int find_log_2(int x, bool round_up = false) { int a = 31 - clz(x); if (round_up) { a += (x & (x - 1)) ? 1 : 0; } return a; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline void find_divisor(uint32_t& mul, uint32_t& shr, int x) { assert(x != 0); if (x == 1) { // If dividing by 1, reduced math doesn't work because mul_coeff would need to be 2^32, // which doesn't fit into unsigned int. the div() routine handles this special case // separately. mul = 0; shr = 0; } else { // To express the division N/D in terms of a multiplication, what we first // imagine is simply N*(1/D). However, 1/D will always evaluate to 0 (for D>1), // so we need another way. There's nothing that says we have to use exactly // the fraction 1/D; instead it could be any X/Y that reduces to 1/D (i.e., // Y=X*D), or at least to "close enough" to it. If we pick Y that is a power // of two, then the N*(X/Y) can be N*X followed by a right-shift by some amount. // The power of two we should pick should be at least 2^32, because in the // div() routine we'll use umulhi(), which returns only the upper 32 bits -- // this being equivalent to a right-shift by 32. But we might want a higher // power of two for better accuracy depending on the magnitude of the denominator. // Once we've picked Y, then X [our mul_coeff value] is simply Y/D, rounding up, // and we save shift_coeff as whatever further shift we have to do beyond // what the umulhi() implies. uint32_t p = 31 + find_log_2(x, true); uint32_t m = (uint32_t) (((1ull << p) + (uint32_t) x - 1) / (uint32_t) x); mul = m; shr = p - 32; } } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void fast_divmod(int& div, int& mod, int x, int y, uint32_t mul, uint32_t shr) { if (y == 1) { div = x; mod = 0; } else { div = __umulhi((uint32_t) x, mul) >> shr; mod = x - div * y; } } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t bfadd2(uint32_t a, uint32_t b) { uint32_t c; uint32_t one = 0x3f803f80; ; asm volatile("fma.rn.bf16x2 %0, %1, %3, %2;\n" : "=r"(c) : "r"(a), "r"(b), "r"(one)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hmax2(uint32_t a, uint32_t b) { uint32_t c; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("max.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); #else asm volatile( "{\n" "\t .reg .f16x2 sela, selb;\n" "\n" "\t set.ge.f16x2.f16x2 sela, %1, %2;\n" "\t set.gt.f16x2.f16x2 selb, %2, %1;\n" "\n" "\t mul.f16x2 %0, sela, %1;\n" "\t fma.rn.f16x2 %0, selb, %2, %0;\n" "}\n" : "=r"(c) : "r"(a), "r"(b)); #endif return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hmax4(uint2 a, uint2 b) { uint2 c; c.x = hmax2(a.x, b.x); c.y = hmax2(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hmax8(uint4 a, uint4 b) { uint4 c; c.x = hmax2(a.x, b.x); c.y = hmax2(a.y, b.y); c.z = hmax2(a.z, b.z); c.w = hmax2(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { uint32_t c; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); #else asm volatile( "{\n" "\t .reg .f16x2 sela, selb;\n" "\n" "\t set.le.f16x2.f16x2 sela, %1, %2;\n" "\t set.lt.f16x2.f16x2 selb, %2, %1;\n" "\n" "\t mul.f16x2 %0, sela, %1;\n" "\t fma.rn.f16x2 %0, selb, %2, %0;\n" "}\n" : "=r"(c) : "r"(a), "r"(b)); #endif return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) { uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t bfmul2(uint32_t a, uint32_t b) { uint32_t c; asm("{.reg .b32 c;\n" " mov.b32 c, 0x80008000U;\n" " fma.rn.bf16x2 %0,%1,%2,c;}\n" : "=r"(c) : "r"(a), "r"(b)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hmul4(uint2 a, uint2 b) { uint2 c; c.x = hmul2(a.x, b.x); c.y = hmul2(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hmul8(uint4 a, uint4 b) { uint4 c; c.x = hmul2(a.x, b.x); c.y = hmul2(a.y, b.y); c.z = hmul2(a.z, b.z); c.w = hmul2(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { uint4 c; c.x = hmul2(a, b.x); c.y = hmul2(a, b.y); c.z = hmul2(a, b.z); c.w = hmul2(a, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// // Template function to support both half and bfloat16 template inline __device__ uint32_t mul2(uint32_t a, uint32_t b) { return hmul2(a, b); } template <> inline __device__ uint32_t mul2(uint32_t a, uint32_t b) { return bfmul2(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// // Template function to support both half and bfloat16 template inline __device__ uint4 mul8(uint32_t a, uint4 b) { uint4 c; c.x = hmul2(a, b.x); c.y = hmul2(a, b.y); c.z = hmul2(a, b.z); c.w = hmul2(a, b.w); return c; } template <> inline __device__ uint4 mul8(uint32_t a, uint4 b) { uint4 c; c.x = bfmul2(a, b.x); c.y = bfmul2(a, b.y); c.z = bfmul2(a, b.z); c.w = bfmul2(a, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hrelu2(uint32_t x) { uint32_t res; uint32_t const zero = 0u; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); #else asm volatile( "{\n" "\t .reg .f16x2 sela;\n" "\t set.gtu.u32.f16x2 sela, %1, %2;\n" "\t and.b32 %0, sela, %1;\n" "}\n" : "=r"(res) : "r"(x), "r"(zero)); #endif return res; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t bfrelu2(uint32_t x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 uint32_t res; uint32_t const zero = 0u; asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); return res; #endif // not implemented yet return x; } //////////////////////////////////////////////////////////////////////////////////////////////////// // Template function to support both half and bfloat16 template inline __device__ uint32_t relu2(uint32_t x) { return hrelu2(x); } template <> inline __device__ uint32_t relu2(uint32_t x) { return bfrelu2(x); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t habs2(uint32_t x) { uint32_t res; asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); return res; } //////////////////////////////////////////////////////////////////////////////////////////////////// // static inline __device__ uint32_t add_bias(uint32_t a, uint32_t bias, bool relu) { // uint32_t c; // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // if( relu ) { // uint32_t one = 0x3c003c00u; // asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(c) : "r"(a), "r"(one), "r"(bias)); // } else { // c = hadd2(a, bias); // } // #else // c = hadd2(a, bias); // if( relu ) { // c = hrelu2(c); // } // #endif // return c; // } //////////////////////////////////////////////////////////////////////////////////////////////////// // static inline __device__ uint2 add_bias(uint2 a, uint2 bias, bool relu) { // uint2 dst; // dst.x = add_bias(a.x, bias.x, relu); // dst.y = add_bias(a.y, bias.y, relu); // return dst; // } //////////////////////////////////////////////////////////////////////////////////////////////////// // static inline __device__ uint4 add_bias(uint4 a, uint4 bias, bool relu) { // uint4 dst; // dst.x = add_bias(a.x, bias.x, relu); // dst.y = add_bias(a.y, bias.y, relu); // dst.z = add_bias(a.z, bias.z, relu); // dst.w = add_bias(a.w, bias.w, relu); // return dst; // } //////////////////////////////////////////////////////////////////////////////////////////////////// // clamp float +inf/-inf static inline __device__ float satfinite(float x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860 // bit representation of maximum value of float uint32_t clamp_value = 0x7f7fffffu; asm volatile("min.xorsign.abs.f32 %0, %0, %1;" : "+f"(x) : "r"(clamp_value)); return x; #else // bit representation of maximum and minimum value of float uint32_t umax = 0x7f7fffffu; uint32_t umin = 0xff7fffffu; float out; asm volatile("min.f32 %0, %1, %2;" : "=f"(out) : "f"(x), "r"(umax)); asm volatile("max.f32 %0, %0, %1;" : "+f"(out) : "r"(umin)); return out; #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// // clamp half2 +inf/-inf static inline __device__ uint32_t satfinite_h2(uint32_t h2) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860 uint32_t out, clamp_value; clamp_value = 0x7bff7bffu; asm volatile("min.xorsign.abs.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(clamp_value)); return out; #elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800 // bit representation of maximum and minimum value of half2 uint32_t umax = 0x7bff7bffu; uint32_t umin = 0xfbfffbffu; uint32_t out; asm volatile("min.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(umax)); asm volatile("max.f16x2 %0, %0, %1;" : "+r"(out) : "r"(umin)); return out; #else // Take the absolute value of h2. It should map to |Rx| in SASS. uint32_t p2; asm volatile("abs.f16x2 %0, %1;" : "=r"(p2) : "r"(h2)); // Compute a mask for each fp16: 0xffff if +INF and 0x0000 otherwise. uint32_t inf2 = 0x7c007c00u; uint32_t mask; asm volatile("set.eq.u32.f16x2 %0, %1, %2;" : "=r"(mask) : "r"(p2), "r"(inf2)); // Recreate the new value. 0x7bff is the max value for FP16. p2 = (~mask & p2) | (mask & 0x7bff7bff); // Simply re-add the sign and we're done. return p2 | (h2 & 0x80008000); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// template static inline __device__ T clamp(T x, T lb, T ub) { return x < lb ? lb : (x > ub ? ub : x); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float custom_exp2f(float x, float scale, float scaled_max) { float d1, d2; asm("fma.rz.ftz.f32 %0, %1, %2, %3;" : "=f"(d1) : "f"(x), "f"(scale), "f"(-scaled_max)); asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(d2) : "f"(d1)); return d2; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t clamp_to_zero(uint16_t x) { uint16_t mask; asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); return mask & x; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t float_to_half(float f) { uint16_t h; asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ bf16_t float_to_bf16(float f) { return __float2bfloat16(f); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float2_to_half2(float a, float b) { uint32_t c; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); #else uint16_t lo = float_to_half(a); uint16_t hi = float_to_half(b); asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); #endif return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float2_to_bf16_x2(float a, float b) { uint32_t c; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); #else uint16_t* px = reinterpret_cast(&a); uint16_t* py = reinterpret_cast(&b); uint16_t value = px[1]; uint16_t value2 = py[1]; if (px[0] == 0x8000) { if ((value & 0x1) == 1) value++; } else if (px[0] > 0x8000) { value++; } if (py[0] == 0x8000) { if ((value2 & 0x1) == 1) value2++; } else if (py[0] > 0x8000) { value2++; } uint32_t high = reinterpret_cast(value2); c = (high << 16) | value; #endif return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// // Template function to support both half and bfloat16 template inline __device__ uint32_t float2_to_16bit_2(float a, float b) { return float2_to_half2(a, b); } template <> inline __device__ uint32_t float2_to_16bit_2(float a, float b) { return float2_to_bf16_x2(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float_to_half2(float a) { return float2_to_half2(a, a); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float2_to_half2(float2 const& f) { return float2_to_half2(f.x, f.y); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float_to_bf16_2(float a) { return float2_to_bf16_x2(a, a); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) { uint2 d; d.x = float2_to_half2(x, y); d.y = float2_to_half2(z, w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// // Template function to support both half and bfloat16 template inline __device__ uint2 float4_to_16bit_x4(float x, float y, float z, float w) { uint2 d; d.x = float2_to_half2(x, y); d.y = float2_to_half2(z, w); return d; } template <> inline __device__ uint2 float4_to_16bit_x4(float x, float y, float z, float w) { uint2 d; d.x = float2_to_bf16_x2(x, y); d.y = float2_to_bf16_x2(z, w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); #else d = hrelu2(hfma2(a, b, c)); #endif return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t h0_h0(uint32_t x) { uint32_t y; asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" : "=r"(y) : "r"(x)); return y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float h0_to_float(uint32_t h2) { float f; asm volatile( "{\n" ".reg .f16 lo, hi;\n" "mov.b32 {lo, hi}, %1;\n" "cvt.f32.f16 %0, lo;\n" "}\n" : "=f"(f) : "r"(h2)); return f; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t h1_h1(uint32_t x) { uint32_t y; asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" : "=r"(y) : "r"(x)); return y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { uint16_t d; asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { return hadd2(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hadd4(uint2 a, uint2 b) { uint2 c; c.x = hadd2(a.x, b.x); c.y = hadd2(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hadd(uint2 a, uint2 b) { return hadd4(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hadd8(uint4 a, uint4 b) { uint4 c; c.x = hadd2(a.x, b.x); c.y = hadd2(a.y, b.y); c.z = hadd2(a.z, b.z); c.w = hadd2(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// // Template function to support both half and bfloat16 template inline __device__ uint4 add8(uint4 a, uint4 b) { return hadd8(a, b); } template <> inline __device__ uint4 add8(uint4 a, uint4 b) { uint4 c; c.x = bfadd2(a.x, b.x); c.y = bfadd2(a.y, b.y); c.z = bfadd2(a.z, b.z); c.w = bfadd2(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 fadd4(uint4 a, uint4 b) { float4 c; c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); return reinterpret_cast(c); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hadd(uint4 a, uint4 b) { return hadd8(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float half_to_float(uint16_t h) { float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float bf16_to_float(uint16_t h) { float f; asm volatile("mov.b32 %0, {0, %1};\n" : "=f"(f) : "h"(h)); return f; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float2 half2_to_float2(uint32_t x) { uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); return make_float2(half_to_float(lo), half_to_float(hi)); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float2 bf16_2_to_float2(uint32_t x) { float2 res; asm volatile( "{\n" " .reg .b16 lo, hi;\n" " mov.b32 {lo, hi}, %2;\n" " mov.b32 %0, {0, lo};\n" " mov.b32 %1, {0, hi};\n" "}\n" : "=f"(res.x), "=f"(res.y) : "r"(x)); return res; } //////////////////////////////////////////////////////////////////////////////////////////////////// // Template function to support both half and bfloat16 template inline __device__ float2 convert_from_16bit_2(uint32_t x) { return half2_to_float2(x); } template <> inline __device__ float2 convert_from_16bit_2(uint32_t x) { return bf16_2_to_float2(x); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ void half2_to_float2(float& x, float& y, uint32_t h) { float2 tmp = half2_to_float2(h); x = tmp.x; y = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { uint16_t d; asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { uint16_t d; asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// // Converted two half2's or bf162's into float, then take their dot product. template inline __device__ float fma2_in_float(uint32_t const a, uint32_t const b) { float2 af = fmha::convert_from_16bit_2(a); float2 bf = fmha::convert_from_16bit_2(b); return af.x * bf.x + af.y * bf.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// // Converted two vectors of 8 half's or bf16's into float, then take their dot product. template inline __device__ float fma8_in_float(uint4 const a, uint4 const b) { float sum; sum = fmha::fma2_in_float(a.x, b.x); sum += fmha::fma2_in_float(a.y, b.y); sum += fmha::fma2_in_float(a.z, b.z); sum += fmha::fma2_in_float(a.w, b.w); return sum; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint16_t& dst) { dst = uint16_t(0); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint32_t& dst) { dst = 0u; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint2& dst) { dst = make_uint2(0u, 0u); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint4& dst) { dst = make_uint4(0u, 0u, 0u, 0u); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // P R E D I C A T E P A C K I N G // //////////////////////////////////////////////////////////////////////////////////////////////////// enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Compute_number_of_pred_regs { enum { VALUE = Div_up::VALUE }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void pack_predicates(uint32_t (&preds)[M], uint32_t const (&p)[N]) { // Make sure the values match. static_assert(Compute_number_of_pred_regs::VALUE == M, ""); // The number of complete steps (where we use all the predicates in a byte). enum { COMPLETE_BYTES = N / PREDS_PER_BYTE }; // Make sure we allocated enough predicate registers. static_assert(Div_up::VALUE <= M, ""); // The remainder. enum { REMAINDER = N - COMPLETE_BYTES * PREDS_PER_BYTE }; // Make sure we got the math right and the remainder is between 0 and 3. static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); // The mask to extract the predicates. enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; // Run complete steps. #pragma unroll for (int ii = 0; ii < M; ++ii) { // The number of complete bytes for that register. Be careful it can be > than 4 ;) int const COMPLETE = (N - ii * PREDS_PER_REG) / PREDS_PER_BYTE; // Pack the predicates in a register. uint32_t reg = 0u; #pragma unroll for (int jj = 0; jj < 4; ++jj) { // Early exit. if (jj >= COMPLETE) { break; } // Prepare the array of predicates. bool tmp[PREDS_PER_BYTE]; #pragma unroll for (int kk = 0; kk < PREDS_PER_BYTE; ++kk) { tmp[kk] = p[ii * PREDS_PER_REG + jj * PREDS_PER_BYTE + kk] != 0; } // Store the predicates. #pragma unroll for (int kk = 0; kk < PREDS_PER_BYTE; ++kk) { if (tmp[kk]) { reg |= 1u << (jj * 8 + kk); } } } // Skip the rest of the code if we do not have a remainder. if (COMPLETE < 4 && REMAINDER > 0) { // The mask to extract the predicates. enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; // Prepare the array of predicates. bool tmp[PREDS_PER_BYTE]; #pragma unroll for (int jj = 0; jj < REMAINDER; ++jj) { tmp[jj] = p[COMPLETE_BYTES * PREDS_PER_BYTE + jj] != 0; } // Store the predicates. #pragma unroll for (int jj = 0; jj < REMAINDER; ++jj) { if (tmp[jj]) { reg |= 1u << (COMPLETE * 8 + jj); } } } // Store the predicate register. preds[ii] = reg; } } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ uint32_t pack_predicates(uint32_t const (&p)[N]) { uint32_t tmp[1]; pack_predicates(tmp, p); return tmp[0]; } //////////////////////////////////////////////////////////////////////////////////////////////////// // // G E N E R I C P R E D I C A T E D L D G S T S // //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldgsts_(Functor& fct, uint32_t const (&preds)[M]) { // The number of complete bytes (where we use all the predicates in a byte). enum { COMPLETE = N / PREDS_PER_BYTE }; // Make sure we did allocate enough predicates. static_assert(Div_up::VALUE <= M, ""); // The remainder. enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; // Make sure we got the math right and the remainder is between 0 and 3. static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); // The mask to extract the predicates. enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; // Clear the fetch registers. #pragma unroll for (int ii = 0; ii < N; ++ii) { fct.clear(ii); } // Run complete steps. bool p[PREDS_PER_BYTE]; #pragma unroll for (int ii = 0; ii < COMPLETE; ++ii) { // The predicate. uint32_t reg = preds[ii / BYTES_PER_REG]; // Extract the predicates. #pragma unroll for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); p[jj] = (reg & mask) != 0u; } // Issue the loads. #pragma unroll for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { fct.ldgsts(ii * PREDS_PER_BYTE + jj, p[jj]); } } // Skip the rest of the code if we do not have a remainder. if (REMAINDER > 0) { // The mask to extract the predicates. enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; // The predicate register. uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; // Extract the predicates. #pragma unroll for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); p[jj] = (reg & mask) != 0u; } // Issue the loads. #pragma unroll for (int ii = 0; ii < REMAINDER; ++ii) { fct.ldgsts(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldgsts_(Functor& fct, uint32_t preds) { uint32_t tmp[1] = {preds}; ldgsts_(fct, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // L D G // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint8_t& dst, void const* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint16_t& dst, void const* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint32_t& dst, void const* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint2& dst, void const* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint4& dst, void const* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Ldg_functor { // Ctor. inline __device__ Ldg_functor(Data_type (&fetch)[N], void const* (&ptrs)[N]) : fetch_(fetch) , ptrs_(ptrs) { } // Clear the element. inline __device__ void clear(int ii) { fmha::clear(fetch_[ii]); } // Trigger the loads. inline __device__ void ldgsts(int ii, bool p) { if (p) { ldg(fetch_[ii], ptrs_[ii]); } } // The fetch registers. Data_type (&fetch_)[N]; // The pointers. void const* (&ptrs_)[N]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg_(Data_type (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { Ldg_functor fct(fetch, ptrs); ldgsts_(fct, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint8_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint16_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint32_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint2 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint4 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldgdepbar() { if (USE_LDGSTS) { asm volatile("cp.async.commit_group;\n" ::); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void depbar_() { if (USE_LDGSTS) { asm volatile("cp.async.wait_group %0;\n" ::"n"(COUNT)); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void depbar() { if (USE_LDGSTS) { int const VALUE = Max::VALUE; asm volatile("cp.async.wait_group %0;\n" ::"n"(VALUE)); } } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldgsts128(uint32_t dst, void const* src, bool p = true) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 uint32_t m = p ? 16u : 0u; asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"r"(dst), "l"(src), "r"(m)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Ldgsts_functor { // Ctor. inline __device__ Ldgsts_functor(uint32_t (&smem_ptrs)[N], void const* (&gmem_ptrs)[N]) : smem_ptrs_(smem_ptrs) , gmem_ptrs_(gmem_ptrs) { } // Does nothing. inline __device__ void clear(int ii) {} // Trigger the load-store instruction. inline __device__ void ldgsts(int ii, bool p) { ldgsts128(smem_ptrs_[ii], gmem_ptrs_[ii], p); } // The shared memory pointers. uint32_t (&smem_ptrs_)[N]; // The global memory pointers. void const* (&gmem_ptrs_)[N]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldgsts(uint32_t (&dst)[N], void const* (&src)[N], uint32_t (&preds)[M]) { Ldgsts_functor fct(dst, src); ldgsts_(fct, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // L D S // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint16_t& dst, uint32_t ptr) { asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint32_t& dst, uint32_t ptr) { asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint2& dst, uint32_t ptr) { asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint4& dst, uint32_t ptr) { asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // L D S M // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm(uint32_t& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsmt(uint32_t& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm(uint2& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsmt(uint2& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm(uint4& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsmt(uint4& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// // // S T S M // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stsm(uint32_t ptr, uint32_t const& src) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stsmt(uint32_t ptr, uint32_t const& src) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 asm volatile("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stsm(uint32_t ptr, uint2 const& src) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 asm volatile("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), "r"(src.x), "r"(src.y)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stsmt(uint32_t ptr, uint2 const& src) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 asm volatile("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), "r"(src.x), "r"(src.y)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stsm(uint32_t ptr, uint4 const& src) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), "r"(src.x), "r"(src.y), "r"(src.z), "r"(src.w)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stsmt(uint32_t ptr, uint4 const& src) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 asm volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), "r"(src.x), "r"(src.y), "r"(src.z), "r"(src.w)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// // // S T G // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, float val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint8_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint16_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint32_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint2 val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint4 val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// // // S T S // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint16_t val) { asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint32_t val) { asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint2 val) { asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" : : "r"(ptr), "r"(val.x), "r"(val.y)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint4 val) { asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" : : "r"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void sts_(uint32_t (&ptrs)[N], Data_type const (&data)[N]) { #pragma unroll for (int ii = 0; ii < N; ++ii) { sts(ptrs[ii], data[ii]); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void sts(uint32_t (&ptrs)[N], uint16_t const (&data)[N]) { sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void sts(uint32_t (&ptrs)[N], uint32_t const (&data)[N]) { sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void sts(uint32_t (&ptrs)[N], uint2 const (&data)[N]) { sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void sts(uint32_t (&ptrs)[N], uint4 const (&data)[N]) { sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// #define __HALF2_TO_UI(var) *(reinterpret_cast(&(var))) #define __HALF2_TO_CUI(var) *(reinterpret_cast(&(var))) static __device__ __inline__ void atomicAdd_half2(half2* const address, const half2 val) { asm volatile("{ red.global.add.noftz.f16x2 [%0],%1; }\n" ::"l"(address), "r"(__HALF2_TO_CUI(val)) : "memory"); } //////////////////////////////////////////////////////////////////////////////////////////////////// template static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w) { #if defined(USE_F2I_EMULATION_TRICK) // Make sure the float is in the proper range. float cx, cy, cz, cw; if (CAN_BE_NEGATIVE) { cx = fmha::clamp(x, -128.f, 127.f); cy = fmha::clamp(y, -128.f, 127.f); cz = fmha::clamp(z, -128.f, 127.f); cw = fmha::clamp(w, -128.f, 127.f); } else { cx = fminf(x, 127.f); cy = fminf(y, 127.f); cz = fminf(z, 127.f); cw = fminf(w, 127.f); } // Re-add the magic number. cx += FP32_I2F_MAGIC_NUMBER; cy += FP32_I2F_MAGIC_NUMBER; cz += FP32_I2F_MAGIC_NUMBER; cw += FP32_I2F_MAGIC_NUMBER; // We need unsigned ints... uint32_t a = reinterpret_cast(cx); uint32_t b = reinterpret_cast(cy); uint32_t c = reinterpret_cast(cz); uint32_t d = reinterpret_cast(cw); // Pack the numbers. uint32_t dst; asm volatile("prmt.b32 %0, %1, %2, 0x0040;\n" : "=r"(dst) : "r"(a), "r"(b)); asm volatile("prmt.b32 %0, %0, %1, 0x0410;\n" : "+r"(dst) : "r"(c)); asm volatile("prmt.b32 %0, %0, %1, 0x4210;\n" : "+r"(dst) : "r"(d)); return dst; #else uint32_t a; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x)); uint32_t b; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y)); uint32_t c; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z)); uint32_t d; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w)); uint32_t dst; asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c)); asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a)); return dst; #endif // defined(USE_F2I_EMULATION_TRICK) } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ void swizzle_rows(uint32_t& a, uint32_t& b, uint32_t c, uint32_t d) { asm volatile("prmt.b32 %0, %1, %2, 0x6420;\n" : "=r"(a) : "r"(c), "r"(d)); asm volatile("prmt.b32 %0, %1, %2, 0x7531;\n" : "=r"(b) : "r"(c), "r"(d)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm_with_lds(uint2& data, uint32_t smem) { int lane = threadIdx.x % 32; data = {0, 0}; uint4 v = {0, 0, 0, 0}; uint32_t* a = reinterpret_cast(&v); if (lane < 16) { fmha::lds(v, smem); } int src_row = lane / 4; int src_col = lane % 4; for (int it = 0; it < 4; it++) { uint32_t val = a[it]; uint32_t x = __shfl_sync(uint32_t(-1), val, src_row); __syncwarp(); uint32_t y = __shfl_sync(uint32_t(-1), val, src_row + 8); __syncwarp(); if (it == src_col) { data.x = x; data.y = y; } } } inline __device__ void ldsmt_with_lds(uint2& data, uint32_t smem) { int lane = threadIdx.x % 32; uint4 tmp16{0, 0, 0, 0}; // 16B if (lane < 16) { fmha::lds(tmp16, smem); } uint16_t* tmp16c = reinterpret_cast(&tmp16); // 8x2B: we move pairs uint16_t* t = reinterpret_cast(&data); // 4x2B int const src_col = lane / 4; // 0 - 7 int const src_row = (lane % 4) * 2; // we have to shuffle the values to distribute them in the warp #pragma unroll for (int it = 0; it < 8; it++) { uint16_t val, x, y; val = tmp16c[it]; x = __shfl_sync(uint32_t(-1), val, src_row + 0); __syncwarp(); y = __shfl_sync(uint32_t(-1), val, src_row + 1); __syncwarp(); if (src_col == it) { t[0] = x; t[1] = y; } val = tmp16c[it]; x = __shfl_sync(uint32_t(-1), val, src_row + 8); __syncwarp(); y = __shfl_sync(uint32_t(-1), val, src_row + 9); __syncwarp(); if (src_col == it) { t[2] = x; t[3] = y; } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct MaxOp { __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct SumOp { __device__ inline T operator()(T const& x, T const& y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Allreduce { static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); template static __device__ inline T run(T x, Operator& op) { constexpr int OFFSET = THREADS / 2; x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); return Allreduce::run(x, op); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template <> struct Allreduce<2> { template static __device__ inline T run(T x, Operator& op) { x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator& op) { #pragma unroll for (int mi = 0; mi < M; mi++) { dst[mi] = src[mi]; dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator& op) { float tmp[M]; #pragma unroll for (int mi = 0; mi < M; mi++) { tmp[mi] = op(src[mi].x, src[mi].y); } quad_reduce(dst, tmp, op); } //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator& op) { #pragma unroll for (int mi = 0; mi < M; mi++) { dst[mi] = src[mi]; dst[mi] = Allreduce<4>::run(dst[mi], op); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator& op) { float tmp[M]; #pragma unroll for (int mi = 0; mi < M; mi++) { tmp[mi] = op(src[mi].x, src[mi].y); } quad_allreduce(dst, tmp, op); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint32_t elect_one_sync() { uint32_t pred = 0; #if __CUDA_ARCH__ >= 900 #if !defined(__CUDACC_RTC__) uint32_t laneid = 0; asm volatile( "\n\ {\n\ .reg .b32 %rx;\n\ .reg .pred %px;\n\ elect.one.sync %rx|%px, %2;\n\ @%px mov.s32 %1, 1;\n\ mov.s32 %0, %rx;\n\ }\n" : "+r"(laneid), "+r"(pred) : "r"(0xFFFFFFFF)); #else pred = threadIdx.x == 0; #endif #endif return pred; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint16_t float2_to_e4m3x2(float x, float y) { #if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900)) uint16_t res; asm volatile("cvt.rn.e4m3x2.f32.satfinite %0, %2, %1;" : "=h"(res) : "f"(x), "f"(y)); return res; #else assert(false); return 0; #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint32_t float4_to_e4m3x4(float x, float y, float z, float w) { #if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900)) uint32_t res; asm volatile( "{\n" ".reg .b16 lo;\n" ".reg .b16 hi;\n" "cvt.rn.e4m3x2.f32.satfinite lo, %2, %1;\n" "cvt.rn.e4m3x2.f32.satfinite hi, %4, %3;\n" "mov.b32 %0, {lo, hi};\n" "}" : "=r"(res) : "f"(x), "f"(y), "f"(z), "f"(w)); return res; #else assert(false); return 0; #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint32_t float4_to_e5m2x4(float x, float y, float z, float w) { #if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900)) uint32_t res; asm volatile( "{\n" ".reg .b16 lo;\n" ".reg .b16 hi;\n" "cvt.rn.e5m2x2.f32.satfinite lo, %2, %1;\n" "cvt.rn.e5m2x2.f32.satfinite hi, %4, %3;\n" "mov.b32 %0, {lo, hi};\n" "}" : "=r"(res) : "f"(x), "f"(y), "f"(z), "f"(w)); return res; #else assert(false); return 0; #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint32_t half4_to_e4m3x4(uint32_t const h2_0, uint32_t const h2_1) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)) uint32_t res; asm volatile( "{\n" ".reg .b16 lo, hi;\n" "cvt.satfinite.rn.e4m3x2.f16x2 lo, %1;\n" "cvt.satfinite.rn.e4m3x2.f16x2 hi, %2;\n" "mov.b32 %0, {lo, hi};\n" "}\n" : "=r"(res) : "r"(h2_0), "r"(h2_1)); return res; #else assert(false); return 0; #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint32_t half4_to_e5m2x4(uint32_t const h2_0, uint32_t const h2_1) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)) uint32_t res; asm volatile( "{\n" ".reg .b16 lo, hi;\n" "cvt.satfinite.rn.e5m2x2.f16x2 lo, %1;\n" "cvt.satfinite.rn.e5m2x2.f16x2 hi, %2;\n" "mov.b32 %0, {lo, hi};\n" "}\n" : "=r"(res) : "r"(h2_0), "r"(h2_1)); return res; #else assert(false); return 0; #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// // Helpers to pack float4 into a destination register with 4 8bit values template inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, float const w) { return float4_to_char4(x, y, z, w); }; template <> inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, float const w) { return float4_to_e4m3x4(x, y, z, w); }; template <> inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, float const w) { return float4_to_e5m2x4(x, y, z, w); }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1); template <> inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1) { return half4_to_e4m3x4(h2_0, h2_1); } template <> inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1) { return half4_to_e5m2x4(h2_0, h2_1); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ uint32_t float4_to_fp8x4(float const, float const, float const, float const); template <> inline __device__ uint32_t float4_to_fp8x4(float const x, float const y, float const z, float const w) { return float4_to_e4m3x4(x, y, z, w); } template <> inline __device__ uint32_t float4_to_fp8x4(float const x, float const y, float const z, float const w) { return float4_to_e5m2x4(x, y, z, w); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void fence_view_async_shared() { // Issue a shared memory fence for async operations (FENCE.VIEW.ASYNC.S) // only compiles on sm90+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 asm volatile("fence.proxy.async.shared::cta;\n"); #else assert(false); #endif } inline __device__ void fence_view_async_global() { // Issue a global memory fence for async operations (FENCE.VIEW.ASYNC.G) // only compiles on sm90+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 asm volatile("fence.proxy.async.global::cta;\n"); #else assert(false); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ char* align_1024(char* ptr) { uint64_t address_bit = reinterpret_cast(ptr); uint64_t offset = address_bit % 1024; if (offset == 0) { return ptr; } else { return ptr + (1024 - offset); } } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float atomicMaxFloat(float* addr, float value) { float old; old = (value >= 0) ? __int_as_float(atomicMax((int*) addr, __float_as_int(value))) : __uint_as_float(atomicMin((unsigned int*) addr, __float_as_uint(value))); return old; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float atomicMaxFloatPos_(float* addr, float value) { // VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION. float old = __int_as_float(atomicMax((int*) addr, __float_as_int(value))); return old; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float max3Pos_(float const a, float const b, float const c) { // VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION. float res; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 int32_t a_ = reinterpret_cast(a); int32_t b_ = reinterpret_cast(b); int32_t c_ = reinterpret_cast(c); int32_t tmp; asm volatile("max.s16x2 %0, %1, %2;\n" : "=r"(tmp) : "r"(a_), "r"(b_)); asm volatile("max.s16x2 %0, %0, %1;\n" : "+r"(tmp) : "r"(tmp), "r"(c_)); res = reinterpret_cast(tmp); #else res = fmaxf(a, fmaxf(b, c)); #endif return res; } //////////////////////////////////////////////////////////////////////////////////////////////////// // Fast approximate tanh. static inline __device__ float __tanhf(float x) { #if (__CUDA_ARCH__ >= 750) float r = x; asm("tanh.approx.f32 %0, %0;" : "+f"(r)); return r; #else return tanhf(x); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace fmha