/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "cuda_hint.cuh"
#include "utils.h"
#ifndef GENERATE_CUBIN
#include <cassert>
#else
#include "mha_stdheaders.cuh"
#endif
#ifndef __CUDACC__
#include <cuda_runtime.h>
#endif
#include "barriers.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

inline constexpr float log2e = 1.4426950408889634; // std::log2(M_E)

// We use an optimization where exp(x - rowMax) is computed as:
/*
   bias = rowMax * log2e // shared for the whole row
   exp(x - rowMax) = exp2f(x * log2e - bias)
*/
// But this optimization is not numerically stable when (x * log2e - bias) is computed with FMA and x is too large.
// For this reason, don't give safeInitRowMax a huge absolute value.
inline constexpr float safeInitRowMax = -1e+5F;
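// Illustrative sketch (not from the source): how the exp2f rewrite above is applied per
// element once bias has been computed for the row. softmaxExp is a hypothetical helper;
// the kernels inline this computation instead.
/*
__device__ inline float softmaxExp(float x, float rowMax)
{
    float const bias = rowMax * log2e;
    // Equals expf(x - rowMax) up to rounding. With FMA contraction, x * log2e - bias loses
    // the exact cancellation at x == rowMax, which is why safeInitRowMax is kept moderate.
    return exp2f(x * log2e - bias);
}
*/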
inline constexpr int32_t kBAD_PAGE_INDEX = -1;

__constant__ constexpr float kE4M3_MAX = 448.F;

#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200
constexpr uint32_t kMAX_SMEM_SIZE = (99u << 10);
#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870
constexpr uint32_t kMAX_SMEM_SIZE = (163u << 10);
#elif __CUDA_ARCH__ == 900
constexpr uint32_t kMAX_SMEM_SIZE = (227u << 10);
#endif
#endif

__device__ inline void assertWarpConverged()
{
    // assert(__activemask() == ~0U);
}

#define DEFINE_VEC_BINARY_FUNC(func) \
    template <typename T, uint32_t size> \
    __device__ __host__ inline Vec<decltype(func(mha::declval<T>(), mha::declval<T>())), size> func( \
        Vec<T, size> const& a, Vec<T, size> const& b) \
    { \
        Vec<decltype(func(mha::declval<T>(), mha::declval<T>())), size> result; \
        _Pragma("unroll") for (uint32_t i = 0; i < size; i++) \
        { \
            result[i] = func(a[i], b[i]); \
        } \
        return result; \
    }
DEFINE_VEC_BINARY_FUNC(max)
DEFINE_VEC_BINARY_FUNC(fmaxf)
DEFINE_VEC_BINARY_FUNC(__hadd2_rn)

__device__ __host__ inline float2 addFloat2(float2 a, float2 b)
{
    return float2{a.x + b.x, a.y + b.y};
}

DEFINE_VEC_BINARY_FUNC(addFloat2)
#undef DEFINE_VEC_BINARY_FUNC

#define DEFINE_VEC_BINARY_OP(op) \
    template <typename T, uint32_t size> \
    __device__ __host__ inline Vec<decltype(mha::declval<T>() op mha::declval<T>()), size> operator op( \
        Vec<T, size> const& a, Vec<T, size> const& b) \
    { \
        Vec<decltype(mha::declval<T>() op mha::declval<T>()), size> result; \
        _Pragma("unroll") for (uint32_t i = 0; i < size; i++) \
        { \
            result[i] = a[i] op b[i]; \
        } \
        return result; \
    } \
    template <typename T, uint32_t size, typename Scalar> \
    __device__ __host__ inline Vec<decltype(mha::declval<T>() op mha::declval<Scalar>()), size> operator op( \
        Vec<T, size> const& a, Scalar const& b) \
    { \
        Vec<decltype(mha::declval<T>() op mha::declval<Scalar>()), size> result; \
        _Pragma("unroll") for (uint32_t i = 0; i < size; i++) \
        { \
            result[i] = a[i] op b; \
        } \
        return result; \
    } \
    template <typename T, uint32_t size, typename Scalar> \
    __device__ __host__ inline Vec<decltype(mha::declval<Scalar>() op mha::declval<T>()), size> operator op( \
        Scalar const& a, Vec<T, size> const& b) \
    { \
        Vec<decltype(mha::declval<Scalar>() op mha::declval<T>()), size> result; \
        _Pragma("unroll") for (uint32_t i = 0; i < size; i++) \
        { \
            result[i] = a op b[i]; \
        } \
        return result; \
    }
// Don't use DEFINE_VEC_BINARY_FUNC(operator+), as operator+(float, float) is undefined (built-in
// operators are not functions), and float would be converted to half to perform the operation,
// which results in much lower precision. It's a defect of C++ that operator+(1.F, 2.F) does not work!
DEFINE_VEC_BINARY_OP(+)
DEFINE_VEC_BINARY_OP(-)
DEFINE_VEC_BINARY_OP(*)
DEFINE_VEC_BINARY_OP(/)
DEFINE_VEC_BINARY_OP(==)
DEFINE_VEC_BINARY_OP(!=)
DEFINE_VEC_BINARY_OP(>)
DEFINE_VEC_BINARY_OP(<)
DEFINE_VEC_BINARY_OP(>=)
DEFINE_VEC_BINARY_OP(<=)
#undef DEFINE_VEC_BINARY_OP

template <uint32_t size>
HOST_DEVICE_FUNC inline bool all(Vec<bool, size> const& src)
{
    bool ret = true;
#pragma unroll
    for (uint32_t i = 0; i < size; i++)
    {
        ret = ret && src[i];
    }
    return ret;
}

template <uint32_t size>
HOST_DEVICE_FUNC inline bool any(Vec<bool, size> const& src)
{
    bool ret = false;
#pragma unroll
    for (uint32_t i = 0; i < size; i++)
    {
        ret = ret || src[i];
    }
    return ret;
}

#define DEFINE_VEC_UNARY_OP(op) \
    template <typename T, uint32_t size> \
    __device__ __host__ inline Vec<decltype(op(mha::declval<T>())), size> op(Vec<T, size> const& a) \
    { \
        Vec<decltype(op(mha::declval<T>())), size> result; \
        _Pragma("unroll") for (uint32_t i = 0; i < size; i++) \
        { \
            result[i] = op(a[i]); \
        } \
        return result; \
    }
DEFINE_VEC_UNARY_OP(expf)
DEFINE_VEC_UNARY_OP(exp2f)
DEFINE_VEC_UNARY_OP(__float2bfloat162_rn)
DEFINE_VEC_UNARY_OP(__float2half2_rn)
DEFINE_VEC_UNARY_OP(__float22half2_rn)
DEFINE_VEC_UNARY_OP(__bfloat1622float2)
DEFINE_VEC_UNARY_OP(__half22float2)
DEFINE_VEC_UNARY_OP(__frcp_rn)
#undef DEFINE_VEC_UNARY_OP

template <typename Dst, typename Src, uint32_t size>
__device__ __host__ inline Vec<Dst, size> convert(Vec<Src, size> const& src)
{
    if constexpr (mha::is_same_v<mha::decay_t<Dst>, mha::decay_t<Src>>)
    {
        return src;
    }
    Vec<Dst, size> dst;
    if constexpr (mha::is_same_v<Src, half> && mha::is_same_v<Dst, float>)
    {
        for (uint32_t i = 0; i < size - 1; i += 2)
        {
            reinterpret_cast<float2&>(dst[i]) = __half22float2(reinterpret_cast<half2 const&>(src[i]));
        }
        if constexpr (size % 2 != 0)
        {
            dst[size - 1] = Dst{src[size - 1]};
        }
    }
    else if constexpr (mha::is_same_v<Src, float> && mha::is_same_v<Dst, half>)
    {
        for (uint32_t i = 0; i < size - 1; i += 2)
        {
            reinterpret_cast<half2&>(dst[i]) = __float22half2_rn(reinterpret_cast<float2 const&>(src[i]));
        }
        if constexpr (size % 2 != 0)
        {
            dst[size - 1] = Dst{src[size - 1]};
        }
    }
    else if constexpr (mha::is_same_v<Src, __nv_bfloat16> && mha::is_same_v<Dst, float>)
    {
        for (uint32_t i = 0; i < size - 1; i += 2)
        {
            reinterpret_cast<float2&>(dst[i]) = __bfloat1622float2(reinterpret_cast<__nv_bfloat162 const&>(src[i]));
        }
        if constexpr (size % 2 != 0)
        {
            dst[size - 1] = Dst{src[size - 1]};
        }
    }
    else if constexpr (mha::is_same_v<Src, float> && mha::is_same_v<Dst, __nv_bfloat16>)
    {
        for (uint32_t i = 0; i < size - 1; i += 2)
        {
            reinterpret_cast<__nv_bfloat162&>(dst[i]) = __float22bfloat162_rn(reinterpret_cast<float2 const&>(src[i]));
        }
        if constexpr (size % 2 != 0)
        {
            dst[size - 1] = Dst{src[size - 1]};
        }
    }
    else if constexpr (mha::is_same_v<Src, __nv_fp8_e4m3> && mha::is_same_v<Dst, float>)
    {
        for (uint32_t i = 0; i < size - 1; i += 2)
        {
            reinterpret_cast<float2&>(dst[i]) = float2(reinterpret_cast<__nv_fp8x2_e4m3 const&>(src[i]));
        }
        if constexpr (size % 2 != 0)
        {
            dst[size - 1] = Dst{src[size - 1]};
        }
    }
    else if constexpr (mha::is_same_v<Src, float> && mha::is_same_v<Dst, __nv_fp8_e4m3>)
    {
        for (uint32_t i = 0; i < size - 1; i += 2)
        {
            reinterpret_cast<__nv_fp8x2_e4m3&>(dst[i]) = __nv_fp8x2_e4m3{float2{src[i], src[i + 1]}};
        }
        if constexpr (size % 2 != 0)
        {
            dst[size - 1] = Dst{src[size - 1]};
        }
    }
    else if constexpr (mha::is_same_v<Src, __nv_fp8_e4m3> && mha::is_same_v<Dst, half>)
    {
        for (uint32_t i = 0; i < size - 1; i += 2)
        {
            reinterpret_cast<half2&>(dst[i]) = half2(reinterpret_cast<__nv_fp8x2_e4m3 const&>(src[i]));
        }
        if constexpr (size % 2 != 0)
        {
            dst[size - 1] = Dst{src[size - 1]};
        }
    }
    else if constexpr (mha::is_same_v<Src, half> && mha::is_same_v<Dst, __nv_fp8_e4m3>)
    {
        for (uint32_t i = 0; i < size - 1; i += 2)
        {
            reinterpret_cast<__nv_fp8x2_e4m3&>(dst[i]) = __nv_fp8x2_e4m3{reinterpret_cast<half2 const&>(src[i])};
        }
        if constexpr (size % 2 != 0)
        {
            dst[size - 1] = Dst{src[size - 1]};
        }
    }
    // else if constexpr (mha::is_same_v<Src, __nv_fp8_e4m3> && mha::is_same_v<Dst, __nv_bfloat16>) {
    //     static_assert("not implemented");
    // }
    else if constexpr (mha::is_same_v<Src, __nv_bfloat16> && mha::is_same_v<Dst, __nv_fp8_e4m3>)
    {
        for (uint32_t i = 0; i < size - 1; i += 2)
        {
            reinterpret_cast<__nv_fp8x2_e4m3&>(dst[i])
                = __nv_fp8x2_e4m3{reinterpret_cast<__nv_bfloat162 const&>(src[i])};
        }
        if constexpr (size % 2 != 0)
        {
            dst[size - 1] = Dst{src[size - 1]};
        }
    }
    else
    {
        for (uint32_t i = 0; i < size; i++)
        {
            dst[i] = Dst{src[i]};
        }
    }
    return dst;
}
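// Illustrative sketch (assumed usage, not from the source): converting a fragment of float
// accumulators to half for the output tile; the paired path above uses __float22half2_rn
// two elements at a time.
/*
Vec<float, 8> acc = ...;
Vec<half, 8> out = convert<half>(acc);
*/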
__device__ inline uint32_t laneId()
{
    uint32_t id;
    asm("mov.u32 %0, %%laneid;\n" : "=r"(id));
    return id;
}

__device__ inline uint32_t dynamicSmemSize()
{
    uint32_t size;
    asm("mov.u32 %0, %%dynamic_smem_size;\n" : "=r"(size));
    return size;
}

__device__ inline void trap()
{
    asm volatile("trap;\n");
}

inline constexpr uint32_t warp_size = 32;

struct Warp
{
};

__device__ inline Warp this_warp()
{
    return {};
}

// @fixme: check asm code to make sure UR is used and SHFL is not generated.
template <typename T>
__device__ inline T makeWarpUniform(Warp const& warp, T const& val)
{
    T const val0 = __shfl_sync(~0U, val, 0);
    assert(val == val0);
    return val0;
}

__device__ inline uint3 getWarpIdx(uint3 ctaShapeInWarps, Warp const& warp = this_warp())
{
    assert(ctaShapeInWarps.x % 128 == 0);
    return uint3{ctaShapeInWarps.x == 1 ? 0 : makeWarpUniform(warp, threadIdx.x / warp_size),
        ctaShapeInWarps.y == 1 ? 0 : makeWarpUniform(warp, threadIdx.y),
        ctaShapeInWarps.z == 1 ? 0 : makeWarpUniform(warp, threadIdx.z)};
}

constexpr uint32_t cacheLineSize = 128;

template <uint32_t x>
__device__ __host__ inline void assertIsPowerOf2()
{
    static_assert((x & (x - 1)) == 0);
}

template <typename T>
__device__ inline bool hasBankConflict(T* p)
{
    static_assert(sizeof(T) % 4 == 0 && sizeof(T) <= 16 && alignof(T) == sizeof(T));
    constexpr uint32_t grpSize = 128 / sizeof(T);
    const uint32_t grpMask = (((1U << grpSize) - 1U) << (laneId() / grpSize * grpSize));
    uint32_t const x = reinterpret_cast<uintptr_t>(p) / sizeof(T) % grpSize;
    auto const match = __match_any_sync(grpMask, x);
    bool const conflict = __popc(match) > 1;
    if (grpSize <= 8 && conflict)
    {
        char str[grpSize * 2 + 1] = {};
        for (uint32_t i = 0; i < grpSize; i++)
        {
            str[i * 2] = __shfl_sync(grpMask, x, i, grpSize) + '0';
            str[i * 2 + 1] = ' ';
        }
        if (laneId() % grpSize == 0)
        {
            printf("bank conflict (%u): %s\n", match, str);
        }
    }
    return conflict;
}

__device__ inline float atomicMax(float* addr, float value)
{
    float old;
    old = (value >= 0) ? __int_as_float(atomicMax(reinterpret_cast<int32_t*>(addr), __float_as_int(value)))
                       : __uint_as_float(atomicMin(reinterpret_cast<uint32_t*>(addr), __float_as_uint(value)));
    return old;
}
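// Illustrative note (standard IEEE-754 reasoning, not from the source): the sign split above
// works because non-negative floats order like their bit patterns viewed as signed ints,
// while negative floats order in reverse when viewed as unsigned ints. A hypothetical use:
/*
__global__ void rowMaxKernel(float const* x, uint32_t n, float* rowMax) // hypothetical
{
    uint32_t const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        atomicMax(rowMax, x[i]); // *rowMax must start at a low value, e.g. safeInitRowMax
    }
}
*/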
__device__ inline bool isInInt32Range(uint32_t x)
{
    return x <= static_cast<uint32_t>(mha::numeric_limits<int32_t>::max());
}

// struct of arrays instead of array of structs for compact storage
template <typename Pointer, uint32_t n>
struct CompactRangeList
{
    mha::array<Pointer, n> pointerList;
    mha::array<uint32_t, n> sizeList;

    struct Range
    {
        Pointer const& data;
        uint32_t const& size;
    };

    __device__ inline Range operator[](uint32_t i) const
    {
        return Range{pointerList[i], sizeList[i]};
    }
};

// alignedForSwizzle is for the case when you need to mix TMA+LDS/LDSM, or LDGSTS/STS/STSM+GMMA.
template <typename T, uint32_t rows_, uint32_t cols_>
struct alignas(mha::min(maxArrayAlign(rows_ * cols_), cacheLineSize)) Array2D
{
    using Elem = T;
    static constexpr uint32_t rows = rows_;
    static constexpr uint32_t cols = cols_;
    static constexpr uint32_t size = rows * cols;
    static constexpr uint32_t rowBytes = sizeof(T) * cols;

    template <bool swizzle = false, bool alignedForSwizzle = false>
    __device__ inline T const& at(uint32_t r, uint32_t c) const
    {
        assert(r < rows && c < cols);
        // two different swizzle styles
#if 1
        uint32_t const c_swizzled = [&]
        {
            if constexpr (swizzle)
            {
                static_assert(rowBytes % cacheLineSize == 0 || cacheLineSize % rowBytes == 0);
                static constexpr uint32_t rowsPerSliding
                    = exactDiv(cacheLineSize, rowBytes % cacheLineSize == 0 ? cacheLineSize : rowBytes % cacheLineSize);
                constexpr uint32_t swizzleRowsRepeat = exactDiv(cacheLineSize, sizeof(Elem));
                auto const runtimeBaseOffset
                    = static_cast<uint32_t>(__cvta_generic_to_shared(this->data)) / rowBytes % rows;
                // To match TMA when the array is not aligned to the pattern boundary.
                uint32_t const baseOffset = alignedForSwizzle ? 0 : runtimeBaseOffset;
                uint32_t const xorMask = alignedForSwizzle
                    ? BoundedVal<rows>{r}
                          .template divBy<rowsPerSliding>()
                          .template mod<exactDiv(swizzleRowsRepeat, rowsPerSliding)>()
                          .get()
                    : (r + baseOffset) / rowsPerSliding % exactDiv(swizzleRowsRepeat, rowsPerSliding);
                return c ^ xorMask;
            }
            return c;
        }();
#else
        uint32_t const c_swizzled = swizzle ? (c + r / rowsPerSliding) % cols : c;
#endif
        T const& ret = (&data[0][0])[r * cols + c_swizzled];
        assert(&data[r][c_swizzled] == &ret);
        return ret;
    }

    template <bool swizzle = false, bool alignedForSwizzle = false>
    __device__ inline T& at(uint32_t r, uint32_t c)
    {
        return const_cast<T&>(static_cast<Array2D const*>(this)->at<swizzle, alignedForSwizzle>(r, c));
    }

    __device__ inline T const& operator()(uint32_t r, uint32_t c) const
    {
        return at(r, c);
    }

    __device__ inline T& operator()(uint32_t r, uint32_t c)
    {
        return at(r, c);
    }

    template <typename U>
    __device__ inline Array2D<U, rows, exactDiv(rowBytes, sizeof(U))>& as()
    {
        return reinterpret_cast<Array2D<U, rows, exactDiv(rowBytes, sizeof(U))>&>(*this);
    }

    __device__ inline void fill(T val)
    {
#pragma unroll
        for (uint32_t i = 0; i < rows * cols; i++)
        {
            (&data[0][0])[i] = val;
        }
    }

    __device__ inline static Array2D filled(T val)
    {
        Array2D ret;
        ret.fill(val);
        return ret;
    }

    T data[rows][cols];
};
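// Illustrative sketch (assumed usage, not from the source): a shared-memory tile whose
// 16-byte grains are XOR-swizzled inside each 128-byte line, so that both row-major stores
// and column-strided LDSM reads stay bank-conflict-free. LdGrain is defined further below;
// r and c are placeholders.
/*
__shared__ Array2D<LdGrain, 64, 8> tile; // 8 grains * 16B == one 128B cache line per row
LdGrain const& g = tile.at<true>(r, c);  // swizzled access; at<false> would be linear
*/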
#define DEFINE_ARRAY2D_BINARY_OP(op) \
    template <typename T, uint32_t rows, uint32_t cols> \
    __device__ __host__ inline Array2D<decltype(mha::declval<T>() op mha::declval<T>()), rows, cols> operator op( \
        Array2D<T, rows, cols> const& a, Array2D<T, rows, cols> const& b) \
    { \
        Array2D<decltype(mha::declval<T>() op mha::declval<T>()), rows, cols> result; \
        _Pragma("unroll") for (uint32_t i = 0; i < rows; i++) \
        { \
            for (uint32_t j = 0; j < cols; j++) \
            { \
                result(i, j) = a(i, j) op b(i, j); \
            } \
        } \
        return result; \
    } \
    template <typename T, uint32_t rows, uint32_t cols, typename Scalar> \
    __device__ __host__ inline Array2D<decltype(mha::declval<T>() op mha::declval<Scalar>()), rows, cols> operator op( \
        Array2D<T, rows, cols> const& a, Scalar const& b) \
    { \
        Array2D<decltype(mha::declval<T>() op mha::declval<Scalar>()), rows, cols> result; \
        _Pragma("unroll") for (uint32_t i = 0; i < rows; i++) \
        { \
            for (uint32_t j = 0; j < cols; j++) \
            { \
                result(i, j) = a(i, j) op b; \
            } \
        } \
        return result; \
    } \
    template <typename T, uint32_t rows, uint32_t cols, typename Scalar> \
    __device__ __host__ inline Array2D<decltype(mha::declval<Scalar>() op mha::declval<T>()), rows, cols> operator op( \
        Scalar const& a, Array2D<T, rows, cols> const& b) \
    { \
        Array2D<decltype(mha::declval<Scalar>() op mha::declval<T>()), rows, cols> result; \
        _Pragma("unroll") for (uint32_t i = 0; i < rows; i++) \
        { \
            for (uint32_t j = 0; j < cols; j++) \
            { \
                result(i, j) = a op b(i, j); \
            } \
        } \
        return result; \
    }
// Don't use DEFINE_VEC_BINARY_FUNC(operator+), as operator+(float, float) is undefined (built-in
// operators are not functions), and float would be converted to half to perform the operation,
// which results in much lower precision. It's a defect of C++ that operator+(1.F, 2.F) does not work!
DEFINE_ARRAY2D_BINARY_OP(+)
DEFINE_ARRAY2D_BINARY_OP(-)
DEFINE_ARRAY2D_BINARY_OP(*)
#undef DEFINE_ARRAY2D_BINARY_OP
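// Illustrative sketch (assumed usage, not from the source): elementwise math on register
// tiles, e.g. rescaling a tile of accumulators by a softmax correction factor.
/*
Array2D<float, 4, 8> acc = ...;
Array2D<float, 4, 8> rescaled = acc * 0.5F; // scalar overload; acc + acc etc. also work
*/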
using LdGrain = Vec<uint32_t, 4>;

constexpr uint32_t grainBytes = sizeof(LdGrain);

// wrapper for PTX ldmatrix
template <bool transpose, uint32_t nbMat>
__device__ inline Vec<uint32_t, nbMat> ldmatrix(LdGrain const* row)
{
    assertWarpConverged();
    uint32_t a, b, c, d;
    if constexpr (nbMat == 4)
    {
        if (transpose)
        {
            asm("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                : "=r"(a), "=r"(b), "=r"(c), "=r"(d)
                : "l"(__cvta_generic_to_shared(row))
                : "memory");
        }
        else
        {
            asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                : "=r"(a), "=r"(b), "=r"(c), "=r"(d)
                : "l"(__cvta_generic_to_shared(row))
                : "memory");
        }
#if 0
        auto checkMat = [&](uint32_t val, uint32_t idxMat)
        {
            auto const v = (Vec<uint16_t, 2> const&) val;
            uint32_t const lane = laneId();
            auto getRow = [&](uint32_t r)
            {
                assert(r < 8);
                auto const ret = __shfl_sync(~0U, reinterpret_cast<uintptr_t>(row), 8 * idxMat + r);
                return *reinterpret_cast<Vec<uint16_t, 8> const*>(ret);
            };
            auto checkEq = [](uint16_t x, uint16_t y)
            {
                if (!(x == y))
                {
                    printf("x=%u, y=%u\n", (unsigned) x, (unsigned) y);
                }
            };
            if (transpose)
            {
                checkEq(v[0], getRow(lane % 4 * 2)[lane / 4]);
                checkEq(v[1], getRow(lane % 4 * 2 + 1)[lane / 4]);
            }
            else
            {
                checkEq(v[0], getRow(lane / 4)[lane % 4 * 2]);
                checkEq(v[1], getRow(lane / 4)[lane % 4 * 2 + 1]);
            }
        };
        checkMat(a, 0);
        checkMat(b, 1);
        checkMat(c, 2);
        checkMat(d, 3);
#endif
        return Vec<uint32_t, nbMat>{a, b, c, d};
    }
    else if constexpr (nbMat == 2)
    {
        if (transpose)
        {
            asm("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
                : "=r"(a), "=r"(b)
                : "l"(__cvta_generic_to_shared(row))
                : "memory");
        }
        else
        {
            asm("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
                : "=r"(a), "=r"(b)
                : "l"(__cvta_generic_to_shared(row))
                : "memory");
        }
        return Vec<uint32_t, nbMat>{a, b};
    }
    else if constexpr (nbMat == 1)
    {
        if (transpose)
        {
            asm("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n"
                : "=r"(a)
                : "l"(__cvta_generic_to_shared(row))
                : "memory");
        }
        else
        {
            asm("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
                : "=r"(a)
                : "l"(__cvta_generic_to_shared(row))
                : "memory");
        }
        return Vec<uint32_t, nbMat>{a};
    }
    else
    {
        static_assert(nbMat == 1 || nbMat == 2 || nbMat == 4);
    }
}

template <bool transpose>
__device__ inline Vec<uint32_t, 4> ldmatrix_4x(Warp const& warp, LdGrain const* row)
{
    return ldmatrix<transpose, 4>(row);
}

template <uint32_t nbMat>
__device__ inline Vec<uint32_t, nbMat * 2> ldmatrix_16x16_trans(LdGrain const* row)
{
    uint32_t a, b, c, d;
    if constexpr (nbMat == 1)
    {
        asm("ldmatrix.sync.aligned.m16n16.x1.trans.shared::cta.b8 {%0, %1}, [%2];\n"
            : "=r"(a), "=r"(b)
            : "l"(__cvta_generic_to_shared(row))
            : "memory");
        return Vec<uint32_t, nbMat * 2>{a, b};
    }
    else if constexpr (nbMat == 2)
    {
        asm("ldmatrix.sync.aligned.m16n16.x2.trans.shared::cta.b8 {%0, %1, %2, %3}, [%4];\n"
            : "=r"(a), "=r"(b), "=r"(c), "=r"(d)
            : "l"(__cvta_generic_to_shared(row))
            : "memory");
        return Vec<uint32_t, nbMat * 2>{a, b, c, d};
    }
    else
    {
        static_assert(nbMat == 1 || nbMat == 2);
    }
}

template <bool transpose, uint32_t nbMat>
__device__ inline void stmatrix(LdGrain* row, Vec<uint32_t, nbMat> const& data)
{
#if __CUDA_ARCH__ >= 900
    assertWarpConverged();
    if constexpr (nbMat == 4)
    {
        if constexpr (transpose)
        {
            asm("stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"l"(
                    __cvta_generic_to_shared(row)),
                "r"(data[0]), "r"(data[1]), "r"(data[2]), "r"(data[3])
                : "memory");
        }
        else
        {
            asm("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"l"(
                    __cvta_generic_to_shared(row)),
                "r"(data[0]), "r"(data[1]), "r"(data[2]), "r"(data[3])
                : "memory");
        }
    }
    else if constexpr (nbMat == 2)
    {
        if constexpr (transpose)
        {
            asm("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n" ::"l"(
                    __cvta_generic_to_shared(row)),
                "r"(data[0]), "r"(data[1])
                : "memory");
        }
        else
        {
            asm("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n" ::"l"(__cvta_generic_to_shared(row)),
                "r"(data[0]), "r"(data[1])
                : "memory");
        }
    }
    else if constexpr (nbMat == 1)
    {
        if constexpr (transpose)
        {
            asm("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n" ::"l"(__cvta_generic_to_shared(row)),
                "r"(data[0])
                : "memory");
        }
        else
        {
            asm("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n" ::"l"(__cvta_generic_to_shared(row)),
                "r"(data[0])
                : "memory");
        }
    }
    else
    {
        static_assert(nbMat == 1 || nbMat == 2 || nbMat == 4);
    }
#else
    trap();
#endif
}

template <bool transpose>
__device__ inline void stmatrix_4x(Warp const& warp, LdGrain* row, Vec<uint32_t, 4> const& data)
{
    stmatrix<transpose, 4>(row, data);
}

struct None
{
};

template <bool real, typename T>
using RealTypeOrNone = mha::conditional_t<real, T, None>;

template <typename MBarrier>
struct MBarrierPair
{
    MBarrier produced;
    MBarrier consumed;

    __device__ inline void initialize(uint32_t producedCount, uint32_t consumedCount)
    {
        init(&produced, producedCount);
        init(&consumed, consumedCount);
    }
};

using CtaBarrierPair = MBarrierPair<CtaBarrier>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
template <typename MBarrier>
__device__ inline auto arrive_tx(MBarrier& bar, uint32_t txCount, uint32_t arriveCount = 1)
{
#if USE_CUSTOM_BARRIER
    return bar.arrive_tx(txCount, arriveCount);
#else
    return cuda::device::barrier_arrive_tx(bar, arriveCount, txCount);
#endif
}

template <typename MBarrier>
__device__ inline void arrive_tx_and_wait(MBarrier& bar, uint32_t txCount, uint32_t arriveCount = 1)
{
    bar.wait(arrive_tx(bar, txCount, arriveCount));
}
#endif

template <uint32_t bound0>
__device__ inline mha::tuple<uint32_t, uint32_t> carryLE(uint32_t i0, uint32_t iLast)
{
    return mha::tuple<uint32_t, uint32_t>{i0 % bound0, iLast + i0 / bound0};
}

template <uint32_t bound0, uint32_t bound1, uint32_t... bounds>
__device__ inline mha::tuple<uint32_t, uint32_t, uint32_t, decltype(bounds)...> carryLE(
    uint32_t i0, uint32_t i1, decltype(bounds)... i, uint32_t iLast)
{
    return mha::tuple_cat(mha::tuple<uint32_t>(i0 % bound0), carryLE<bound1, bounds...>(i1 + i0 / bound0, i..., iLast));
}
__device__ __host__ inline void assertClose(float a, float b, float threshold = 0.01f)
{
    assert(abs(a - b) < threshold);
}

__device__ __host__ inline void assertClose(half a, half b, float threshold = 0.01f)
{
    assertClose(__half2float(a), __half2float(b), threshold);
}

template <typename InputElem, typename CacheElem>
__device__ inline Vec<InputElem, 4> convertKCacheWordToF16(uint32_t i8data)
{
    static_assert(mha::is_same_v<InputElem, half> || mha::is_same_v<InputElem, __nv_bfloat16>, "not implemented");
    static_assert(sizeof(CacheElem) == 1);
    Vec<InputElem, 4> ret;
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
    if constexpr (mha::is_same_v<InputElem, half> && mha::is_same_v<CacheElem, __nv_fp8_e4m3>)
    {
        uint16_t(&src)[2] = reinterpret_cast<uint16_t(&)[2]>(i8data);
        uint32_t(&dst)[2] = reinterpret_cast<uint32_t(&)[2]>(ret);
        asm("{\n"
            "cvt.rn.f16x2.e4m3x2 %0, %2;\n"
            "cvt.rn.f16x2.e4m3x2 %1, %3;\n"
            "}"
            : "=r"(dst[0]), "=r"(dst[1])
            : "h"(src[0]), "h"(src[1]));
        return ret;
    }
#endif
    CacheElem const(&src)[4] = reinterpret_cast<CacheElem const(&)[4]>(i8data);
    InputElem(&dst)[4] = reinterpret_cast<InputElem(&)[4]>(ret);
#pragma unroll
    for (uint32_t i = 0; i < 4; i++)
    {
        dst[i] = InputElem(src[i]);
    }
    return ret;
}

template <typename InputElem, typename CacheElem>
__device__ inline Vec<InputElem, 4> convertVCacheWordToF16(uint32_t i8data)
{
    static_assert(mha::is_same_v<InputElem, half> || mha::is_same_v<InputElem, __nv_bfloat16>, "not implemented");
    static_assert(sizeof(CacheElem) == 1);
    Vec<InputElem, 4> ret;
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
    if constexpr (mha::is_same_v<InputElem, half> && mha::is_same_v<CacheElem, __nv_fp8_e4m3>)
    {
        uint32_t(&dst)[2] = reinterpret_cast<uint32_t(&)[2]>(ret);
        asm("{\n"
            ".reg .b32 dst0;\n"
            ".reg .b32 dst1;\n"
            ".reg .b32 src;\n"
            ".reg .b16 src0;\n"
            ".reg .b16 src1;\n"
            "prmt.b32 src, %2, 0x0, 0x3120;\n"
            "mov.b32 {src0, src1}, src;\n"
            "cvt.rn.f16x2.e4m3x2 %0, src0;\n"
            "cvt.rn.f16x2.e4m3x2 %1, src1;\n"
            "}"
            : "=r"(dst[0]), "=r"(dst[1])
            : "r"(i8data));
        return ret;
    }
#endif
    CacheElem const(&src)[2][2] = reinterpret_cast<CacheElem const(&)[2][2]>(i8data);
    InputElem(&dst)[2][2] = reinterpret_cast<InputElem(&)[2][2]>(ret);
#pragma unroll
    for (uint32_t i = 0; i < 2; i++)
    {
#pragma unroll
        for (uint32_t j = 0; j < 2; j++)
        {
            dst[i][j] = InputElem(src[j][i]);
        }
    }
    return ret;
}

struct PermuteOrder
{
    uint16_t x0 : 4;
    uint16_t x1 : 4;
    uint16_t x2 : 4;
    uint16_t x3 : 4;
};

static_assert(sizeof(PermuteOrder) == 2);

__device__ inline uint32_t prmt(uint32_t a, uint32_t b, PermuteOrder order)
{
    uint32_t d;
    uint32_t const c = reinterpret_cast<uint16_t const&>(order);
    asm("prmt.b32 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
    return d;
}

__device__ inline uint32_t movmatrix(uint32_t src)
{
    uint32_t dst;
    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(dst) : "r"(src));
    return dst;
}

__device__ inline bool warpElectSync()
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    uint32_t pred = 0;
    asm volatile(
        "{\n"
        "   .reg .b32 d;\n"
        "   .reg .pred p;\n"
        "   elect.sync d|p, 0xFFFFFFFF;\n"
        "   selp.b32 %0, 1, 0, p;\n"
        "}\n"
        : "=r"(pred));
    return pred != 0;
#else
    assert(!"not available");
    return false;
#endif
}

__device__ inline void preExit()
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("griddepcontrol.launch_dependents;\n");
#endif
}

__device__ inline void acqBulk()
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("griddepcontrol.wait;\n");
#endif
}

__device__ inline uint3 nbClusters()
{
    uint3 id;
    asm("mov.v4.u32 {%0, %1, %2, _}, %%nclusterid;\n" : "=r"(id.x), "=r"(id.y), "=r"(id.z));
    return id;
}

__device__ inline uint3 clusterId()
{
    uint3 id;
    asm("mov.v4.u32 {%0, %1, %2, _}, %%clusterid;\n" : "=r"(id.x), "=r"(id.y), "=r"(id.z));
    return id;
}

__device__ inline uint32_t clusterCtaRank()
{
    uint32_t rank;
    asm("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank));
    return rank;
}

__device__ inline uint3 clusterCtaId()
{
    uint3 id;
    asm("mov.v4.u32 {%0, %1, %2, _}, %%cluster_ctaid;\n" : "=r"(id.x), "=r"(id.y), "=r"(id.z));
    return id;
}

// src and return are both generic addresses
template <typename T>
__device__ inline T* mapa(T* src, uint32_t clusterCtaRank)
{
    uint64_t dst;
    asm volatile("mapa.u64 %0, %1, %2;\n" : "=l"(dst) : "l"(reinterpret_cast<uint64_t>(src)), "r"(clusterCtaRank));
    return reinterpret_cast<T*>(dst);
}

template <typename T>
__device__ inline T& mapa(T& src, uint32_t clusterCtaRank)
{
    return *mapa(&src, clusterCtaRank);
}

__device__ inline void clusterBarArrive()
{
    asm volatile("barrier.cluster.arrive.release.aligned;\n");
}

__device__ inline void clusterBarWait()
{
    asm volatile("barrier.cluster.wait.acquire.aligned;\n");
}

__device__ inline uint32_t clock32()
{
    uint32_t ret;
    asm volatile("mov.u32 %0, %%clock;\n" : "=r"(ret)::"memory");
    return ret;
}

template <typename MBarrier, uint32_t nbBufs>
struct BarWaiter
{
    MBarrierPair<MBarrier> (*bars)[nbBufs];
    uint32_t idx;
    uint32_t idxBuf;
    bool skipBarWait = false;

    __device__ inline BarWaiter(MBarrierPair<MBarrier> (&bars)[nbBufs], uint32_t idx)
        : bars{&bars}
        , idx{idx}
        , idxBuf{idx % nbBufs}
    {
    }

    __device__ inline bool testWait()
    {
        bool const parity = toParity(idx);
        skipBarWait = bar().produced.test_wait_parity(parity);
        return skipBarWait;
    }

    __device__ inline BarWaiter next(uint32_t step = 1)
    {
        return BarWaiter{*bars, idx + step};
    }

    __device__ inline void wait()
    {
        if (!skipBarWait)
        {
            bar().produced.wait_parity(toParity(idx));
        }
    }

    __device__ inline MBarrierPair<MBarrier>& bar()
    {
        return (*bars)[idxBuf];
    }

    __device__ inline void consumed()
    {
        bar().consumed.arrive();
    }
};
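// Illustrative sketch (assumed usage, not from the source): a consumer iterating over a
// circular buffer guarded by produced/consumed barrier pairs. The buffer count, nbIters and
// the use of CtaBarrier are placeholders; the barriers must have been initialize()d first.
/*
__shared__ CtaBarrierPair bars[2];
BarWaiter<CtaBarrier, 2> waiter{bars, 0};
for (uint32_t it = 0; it < nbIters; it++)
{
    waiter.testWait();  // opportunistic poll; remembers success
    waiter.wait();      // blocks only if testWait() did not already succeed
    // ... consume buffer waiter.idxBuf ...
    waiter.consumed();  // let the producer reuse this buffer
    waiter = waiter.next();
}
*/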
class Timer
{
public:
    __device__ inline Timer()
    {
        reset();
    }

    __device__ inline void print(char const* name = "unnamed", bool reset = false)
    {
        auto const toc = clock32();
        printf("%s: %u (block={%u, %u, %u})\n", name, toc - mTic, blockIdx.x, blockIdx.y, blockIdx.z);
        if (reset)
        {
            this->reset();
        }
    }

    __device__ inline void reset()
    {
        mTic = clock32();
    }

private:
    uint32_t mTic;
};

// [beg, end)
struct Range
{
    uint32_t beg, end;
};

constexpr bool overlap(Range a, Range b)
{
    return a.beg < b.end && b.beg < a.end;
}
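// Illustrative checks (not from the source): half-open ranges overlap iff each begins before
// the other ends.
static_assert(overlap(Range{0, 4}, Range{3, 6}));
static_assert(!overlap(Range{0, 4}, Range{4, 6}));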