/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <tuple>       // std::tuple (IsTuple)
#include <type_traits> // std::integral_constant, std::enable_if_t, std::is_signed_v

// Helpers below are constexpr everywhere and, when compiled by nvcc, also
// callable from device code.
#ifdef __CUDACC__ // for CUDA
#define FT_DEV_CEXPR __device__ __host__ inline constexpr
#else
#define FT_DEV_CEXPR inline constexpr
#endif

//----------------------------------------------------------------------------
// Cn: constant integer
//----------------------------------------------------------------------------
// Cn<v> is an empty type that carries the value `v` in its type. Because it
// derives from std::integral_constant it converts implicitly to `v`, so it can
// be mixed with runtime values wherever a plain integer is accepted.
template <auto value_>
struct Cn : public std::integral_constant<decltype(value_), value_>
{
};

// Shorthand object: `cn<4>` is a (const) Cn<4>.
template <auto value_>
constexpr auto cn = Cn<value_>();

//----------------------------------------------------------------------------
// Operators for Cn
//----------------------------------------------------------------------------
// When every operand is a Cn, the operation is folded at compile time and the
// result is again a Cn, so it keeps being usable as a constant expression.

template <auto value_>
FT_DEV_CEXPR auto operator+(Cn<value_>) { return cn<+value_>; }
template <auto value_>
FT_DEV_CEXPR auto operator-(Cn<value_>) { return cn<-value_>; }
template <auto value_>
FT_DEV_CEXPR auto operator!(Cn<value_>) { return cn<!value_>; }
template <auto value_>
FT_DEV_CEXPR auto operator~(Cn<value_>) { return cn<~value_>; }

template <auto a_, auto b_>
FT_DEV_CEXPR auto operator+(Cn<a_>, Cn<b_>) { return cn<a_ + b_>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator-(Cn<a_>, Cn<b_>) { return cn<a_ - b_>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator*(Cn<a_>, Cn<b_>) { return cn<a_ * b_>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator/(Cn<a_>, Cn<b_>) { return cn<a_ / b_>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator%(Cn<a_>, Cn<b_>) { return cn<a_ % b_>; }
// The expressions below are parenthesized so that `<`, `>`, `<<`, `>>`, `&&`
// and `||` are never confused with template-argument delimiters.
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator<<(Cn<a_>, Cn<b_>) { return cn<(a_ << b_)>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator>>(Cn<a_>, Cn<b_>) { return cn<(a_ >> b_)>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator<(Cn<a_>, Cn<b_>) { return cn<(a_ < b_)>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator<=(Cn<a_>, Cn<b_>) { return cn<(a_ <= b_)>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator>(Cn<a_>, Cn<b_>) { return cn<(a_ > b_)>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator>=(Cn<a_>, Cn<b_>) { return cn<(a_ >= b_)>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator==(Cn<a_>, Cn<b_>) { return cn<(a_ == b_)>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator!=(Cn<a_>, Cn<b_>) { return cn<(a_ != b_)>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator^(Cn<a_>, Cn<b_>) { return cn<a_ ^ b_>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator&(Cn<a_>, Cn<b_>) { return cn<a_ & b_>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator&&(Cn<a_>, Cn<b_>) { return cn<(a_ && b_)>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator|(Cn<a_>, Cn<b_>) { return cn<a_ | b_>; }
template <auto a_, auto b_>
FT_DEV_CEXPR auto operator||(Cn<a_>, Cn<b_>) { return cn<(a_ || b_)>; }

// Zero short-circuits: if one operand is a compile-time zero, the result is a
// known constant even though the other operand is a runtime value, so a Cn is
// returned instead of a runtime integer. The runtime operand is not inspected,
// so e.g. `cn<0> / b` yields 0 even when b == 0.
// These overloads only participate when the condition holds; otherwise the
// built-in operator applies through Cn's implicit conversion.
// NOTE(review): the enable_if conditions below were reconstructed from these
// semantics after the file's template syntax was lost -- confirm upstream.
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator*(Cn<a_>, B_) { return cn<a_>; }
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator/(Cn<a_>, B_) { return cn<a_>; }
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator%(Cn<a_>, B_) { return cn<a_>; }
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator<<(Cn<a_>, B_) { return cn<a_>; }
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator>>(Cn<a_>, B_) { return cn<a_>; }
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> operator&(Cn<a_>, B_) { return cn<a_>; }
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<false>> operator&&(Cn<a_>, B_) { return cn<false>; }

template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<b_ == 0, Cn<b_>> operator*(A_, Cn<b_>) { return cn<b_>; }
// x % 1 == 0 for every x.
template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<b_ == 1, Cn<static_cast<decltype(b_)>(0)>> operator%(A_, Cn<b_>)
{
  return cn<static_cast<decltype(b_)>(0)>;
}
// x % -1 == 0 for every x, and returning the constant also sidesteps the
// overflow of MIN % -1. Restricted to signed divisor types: for an unsigned
// divisor, `b_ == -1` would also match the all-ones value, where x % b_ != 0.
template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<std::is_signed_v<decltype(b_)> && b_ == static_cast<decltype(b_)>(-1),
                              Cn<static_cast<decltype(b_)>(0)>>
operator%(A_, Cn<b_>)
{
  return cn<static_cast<decltype(b_)>(0)>;
}
template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<b_ == 0, Cn<b_>> operator&(A_, Cn<b_>) { return cn<b_>; }
template <class A_, auto b_>
FT_DEV_CEXPR std::enable_if_t<b_ == 0, Cn<false>> operator&&(A_, Cn<b_>) { return cn<false>; }

//----------------------------------------------------------------------------
// div_up & round_up
//----------------------------------------------------------------------------
// Constexpr absolute value (std::abs only becomes constexpr in C++23).
// a_ must not be the most negative value of a signed type (negation overflows).
template <class T_>
FT_DEV_CEXPR auto cexpr_abs(T_ a_)
{
  return a_ >= cn<0> ? +a_ : -a_;
}

// a_ / b_ with the quotient rounded away from zero: ceil() when a_ and b_ have
// the same sign, floor() otherwise (e.g. div_up(7, 2) == 4, div_up(-7, 2) == -4).
// Requires b_ != 0; a_ +- (|b_| - 1) must not overflow.
template <class T_, class U_>
FT_DEV_CEXPR auto div_up(T_ a_, U_ b_)
{
  // Push a_ |b_|-1 further from zero so that truncating division rounds outward.
  auto tmp = a_ >= cn<0> ? a_ + (cexpr_abs(b_) - cn<1>) : a_ - (cexpr_abs(b_) - cn<1>);
  return tmp / b_;
}

// Rounds a_ away from zero to a multiple of b_
// (e.g. round_up(7, 4) == 8, round_up(-7, 4) == -8). Same preconditions as div_up.
template <class T_, class U_>
FT_DEV_CEXPR auto round_up(T_ a_, U_ b_)
{
  auto tmp = a_ >= cn<0> ? a_ + (cexpr_abs(b_) - cn<1>) : a_ - (cexpr_abs(b_) - cn<1>);
  return tmp - tmp % b_;
}

// Both operands constant: fold at compile time via the generic versions above.
template <auto a_, auto b_>
FT_DEV_CEXPR auto div_up(Cn<a_>, Cn<b_>)
{
  return cn<div_up(a_, b_)>;
}

template <auto a_, auto b_>
FT_DEV_CEXPR auto round_up(Cn<a_>, Cn<b_>)
{
  return cn<round_up(a_, b_)>;
}

// Constant zero numerator: the result is a constant zero regardless of b.
template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> div_up(Cn<a_>, B_)
{
  return cn<a_>;
}

template <auto a_, class B_>
FT_DEV_CEXPR std::enable_if_t<a_ == 0, Cn<a_>> round_up(Cn<a_>, B_)
{
  return cn<a_>;
}

//----------------------------------------------------------------------------
// IsTuple: std::tuple, but not std::pair, std::array, etc.
//----------------------------------------------------------------------------
template <class T_>
struct IsTuple : public std::false_type
{
};

template <class... T_>
struct IsTuple<std::tuple<T_...>> : public std::true_type
{
};

// cv-qualified tuples are still tuples.
// NOTE(review): reconstructed as const / volatile / const volatile -- confirm upstream.
template <class T_>
struct IsTuple<const T_> : public IsTuple<T_>
{
};

template <class T_>
struct IsTuple<volatile T_> : public IsTuple<T_>
{
};

template <class T_>
struct IsTuple<const volatile T_> : public IsTuple<T_>
{
};

template <class T_>
constexpr bool IsTuple_v = IsTuple<T_>::value;

// vim: ts=2 sw=2 sts=2 et sta