/* * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #ifndef GENERATE_CUBIN #include #include #include #include #include #include #endif #include "mha_stdheaders.cuh" template HOST_DEVICE_FUNC constexpr inline void unused(T&& x) { static_cast(x); } #ifndef GENERATE_CUBIN inline void checkCuda(cudaError_t err) { if (err != cudaSuccess) { printf("%s\n", cudaGetErrorName(err)); throw std::runtime_error(cudaGetErrorName(err)); } } inline void checkCu(CUresult err) { if (err != CUDA_SUCCESS) { char const* str = nullptr; if (cuGetErrorName(err, &str) != CUDA_SUCCESS) { str = "A cuda driver API error happened, but we failed to query the error name\n"; } printf("%s\n", str); throw std::runtime_error(str); } } #endif HOST_DEVICE_FUNC constexpr inline uint32_t greatestPowerOf2Divisor(uint32_t x) { return x & ~(x - 1); } template HOST_DEVICE_FUNC constexpr uint32_t maxArrayAlign(uint32_t size) { return sizeof(T) * greatestPowerOf2Divisor(size); } HOST_DEVICE_FUNC constexpr inline uint32_t exactDiv(uint32_t a, uint32_t b) { assert(a % b == 0); return a / b; } template HOST_DEVICE_FUNC constexpr inline T divUp(T a, T b) { return (a + b - 1) / b; } template HOST_DEVICE_FUNC constexpr inline T roundUp(T a, T b) { return divUp(a, b) * b; } // upperBound is exclusive, i.e. range is [0, upperBound) template struct BoundedVal { template HOST_DEVICE_FUNC inline BoundedVal divBy() const { assert(value < upperBound); return {upperBound <= divisor ? 0 : value / divisor}; } template HOST_DEVICE_FUNC inline BoundedVal mod() const { assert(value < upperBound); return {upperBound <= divisor ? value : value % divisor}; } HOST_DEVICE_FUNC inline bool operator<=(uint32_t rhs) const { assert(value < upperBound); return upperBound <= rhs || value <= rhs; } HOST_DEVICE_FUNC inline uint32_t get() const { assert(value < upperBound); return upperBound == 1 ? 0 : value; } uint32_t value; }; template struct alignas(mha::max(alignof(T), mha::min(maxArrayAlign(size_), 16))) Vec { using Elem = T; static constexpr uint32_t size = size_; Elem data[size]; HOST_DEVICE_FUNC inline void fill(T val) { #pragma unroll for (uint32_t i = 0; i < size; i++) { data[i] = val; } } static HOST_DEVICE_FUNC inline Vec filled(T val) { Vec ret; ret.fill(val); return ret; } HOST_DEVICE_FUNC inline Elem const& operator[](uint32_t i) const { assert(i < size); return data[BoundedVal{i}.get()]; } HOST_DEVICE_FUNC inline Elem& operator[](uint32_t i) { assert(i < size); return data[BoundedVal{i}.get()]; } }; template struct CircIdx { public: static constexpr uint32_t nbBuffers = nbBuffers_; static_assert(nbBuffers >= 1); __device__ inline CircIdx(uint32_t init) : mIndex{init % nbBuffers} { } __device__ inline operator uint32_t() const { return mIndex; } __device__ inline CircIdx operator+(uint32_t i) const { return CircIdx{(mIndex + i) % nbBuffers}; } __device__ inline CircIdx operator-(uint32_t i) const { return CircIdx{(mIndex + (nbBuffers - 1) * i) % nbBuffers}; } __device__ inline CircIdx next() const { return *this + 1u; } __device__ inline CircIdx& operator++() { mIndex = next(); return *this; } __device__ inline CircIdx operator++(int) { CircIdx old = *this; operator++(); return old; } __device__ inline CircIdx prev() const { return *this - 1u; } __device__ inline CircIdx& operator--() { mIndex = prev(); return *this; } __device__ inline CircIdx operator--(int) { CircIdx old = *this; operator--(); return old; } private: uint32_t mIndex; }; // base is usually in constant memory, so usually only require 1 register to store the offset. template struct TinyPtr { T* base; // typically in constant memory or uniform registers uint32_t offset; // may be non-uniform template __device__ __host__ inline TinyPtr cast() const { D* const p = reinterpret_cast(base); assert(reinterpret_cast(p) % alignof(D) == 0); if constexpr (mha::is_void_v) { assert(offset == 0); return TinyPtr{p, 0}; } else if constexpr (sizeof(T) < sizeof(D)) { return TinyPtr{p, exactDiv(offset, exactDiv(sizeof(D), sizeof(T)))}; } else { return TinyPtr{p, offset * exactDiv(sizeof(T), sizeof(D))}; } } __device__ __host__ inline T& operator*() const { return base[offset]; } __device__ __host__ inline TinyPtr operator+(uint32_t i) const { return TinyPtr{base, offset + i}; } __device__ __host__ inline T& operator[](uint32_t i) const { return *(*this + i); } __device__ __host__ inline operator T*() const { return base + offset; } }; template class Segmenter { public: HOST_DEVICE_FUNC Segmenter(uint32_t offset = 0) : mNextOffset{offset} { } // offset is in bytes template HOST_DEVICE_FUNC OffsetInt newSeg(uint32_t count = 1, uint32_t alignment = alignof(T)) { mMaxAlignment = mha::max(mMaxAlignment, alignment); OffsetInt const offset = roundUp(mNextOffset, alignment); mNextOffset = offset + sizeof(T) * count; return offset; } HOST_DEVICE_FUNC OffsetInt getEndOffset() const { return mNextOffset; } HOST_DEVICE_FUNC uint32_t getMaxAlignment() const { return mMaxAlignment; } private: OffsetInt mNextOffset; uint32_t mMaxAlignment = 1; }; template using AddConst = mha::conditional_t; template class MemSegmenter { public: HOST_DEVICE_FUNC MemSegmenter(AddConst* base, uint32_t offset = 0) : mBase{static_cast*>(base)} , mSegmenter{offset} { } // to use TinyPtr, alignment must be sizeof(T) template HOST_DEVICE_FUNC TinyPtr> newSeg(uint32_t count = 1, uint32_t alignment = sizeof(T)) { assert(reinterpret_cast(mBase) % alignment == 0); OffsetInt const offset = mSegmenter.template newSeg(count, alignment); return TinyPtr>{mBase, offset}.template cast>(); } HOST_DEVICE_FUNC OffsetInt getEndOffset() const { return mSegmenter.getEndOffset(); } HOST_DEVICE_FUNC uint32_t getMaxAlignment() const { return mSegmenter.getMaxAlignment(); } private: AddConst* mBase; Segmenter mSegmenter; }; // dims in little endian template struct DimsLE { static constexpr uint32_t nbDims = nbDims_; __device__ __host__ inline uint32_t& operator[](uint32_t i) { return d[i]; } __device__ __host__ inline uint32_t const& operator[](uint32_t i) const { return d[i]; } uint32_t d[nbDims]; }; // check if val is in range [lb, ub) template constexpr bool inRange(T val, T lb, T ub) { return val >= lb && val < ub; } // val is an optimized / pre-computed value, ref is the original value template HOST_DEVICE_FUNC constexpr inline T checkedVal(T val, T ref) { assert(val == ref); return val; }