/* * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement * * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual * property and proprietary rights in and to this material, related * documentation and any modifications thereto. Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. */ #pragma once #ifndef GENERATE_CUBIN #include #include #include #include #include #include #include #endif #ifndef __CUDACC__ #include #endif #define HOST_DEVICE_FUNC __host__ __device__ #define DEVICE_FUNC __device__ namespace mha { #ifndef GENERATE_CUBIN template using numeric_limits = std::numeric_limits; using std::max; using std::min; #else using uint8_t = unsigned char; using int8_t = signed char; using uint16_t = unsigned short; using uint32_t = unsigned int; using int32_t = int; using uint64_t = unsigned long long; using uintptr_t = uint64_t; static_assert(sizeof(uint8_t) == 1); static_assert(sizeof(int8_t) == 1); static_assert(sizeof(uint16_t) == 2); static_assert(sizeof(uint32_t) == 4); static_assert(sizeof(int32_t) == 4); static_assert(sizeof(uint64_t) == 8); template class numeric_limits; template <> class numeric_limits { public: static constexpr int32_t max() noexcept { return 0x7FFFFFFF; } }; template <> class numeric_limits { public: static constexpr float lowest() noexcept { return -3.40282347E+38F; } }; template DEVICE_FUNC constexpr T const& max(T const& a, T const& b) { return a > b ? a : b; } template DEVICE_FUNC constexpr T const& min(T const& a, T const& b) { return a < b ? a : b; } #endif #ifndef GENERATE_CUBIN template using conditional_t = std::conditional_t; template using enable_if_t = typename std::enable_if::type; #else // https://en.cppreference.com/w/cpp/types/conditional template struct conditional { using type = T; }; template struct conditional { using type = F; }; template using conditional_t = typename conditional::type; template struct enable_if { }; template struct enable_if { typedef T type; }; template using enable_if_t = typename enable_if::type; #endif #ifndef GENERATE_CUBIN using byte = std::byte; #else // https://en.cppreference.com/w/cpp/types/byte enum class byte : unsigned char { }; #endif #ifndef GENERATE_CUBIN using std::declval; #else // https://en.cppreference.com/w/cpp/types/add_reference namespace detail { template struct type_identity { using type = T; }; // or use std::type_identity (since C++20) template // Note that `cv void&` is a substitution failure DEVICE_FUNC auto try_add_lvalue_reference(int) -> type_identity; template // Handle T = cv void case DEVICE_FUNC auto try_add_lvalue_reference(...) -> type_identity; template DEVICE_FUNC auto try_add_rvalue_reference(int) -> type_identity; template DEVICE_FUNC auto try_add_rvalue_reference(...) -> type_identity; } // namespace detail template struct add_lvalue_reference : decltype(detail::try_add_lvalue_reference(0)) { }; template struct add_rvalue_reference : decltype(detail::try_add_rvalue_reference(0)) { }; // https://en.cppreference.com/w/cpp/utility/declval template DEVICE_FUNC typename add_rvalue_reference::type declval() noexcept { static_assert(false, "declval not allowed in an evaluated context"); } #endif #ifndef GENERATE_CUBIN template using array = std::array; #else // https://en.cppreference.com/w/cpp/container/array template struct array; #endif #ifndef GENERATE_CUBIN template using is_same = std::is_same; using std::is_same_v; #else // https://en.cppreference.com/w/cpp/types/integral_constant template struct integral_constant { static constexpr T value = v; using value_type = T; using type = integral_constant; // using injected-class-name DEVICE_FUNC constexpr operator value_type() const noexcept { return value; } DEVICE_FUNC constexpr value_type operator()() const noexcept { return value; } // since c++14 }; using false_type = integral_constant; using true_type = integral_constant; // https://en.cppreference.com/w/cpp/types/is_same template struct is_same : false_type { }; template struct is_same : true_type { }; template inline constexpr bool is_same_v = is_same::value; #endif #ifndef GENERATE_CUBIN using std::move; using std::forward; using std::is_empty; #else // /usr/include/c++/11/type_traits template struct is_empty : public integral_constant { }; template struct remove_reference { typedef T type; }; template struct remove_reference { typedef T type; }; template struct remove_reference { typedef T type; }; template constexpr typename remove_reference::type&& move(T&& arg) { return static_cast::type&&>(arg); } template constexpr T&& forward(typename remove_reference::type& param) { return static_cast(param); } #endif // https://gcc.gnu.org/onlinedocs/libstdc++/libstdc++-api-4.5/a01066_source.html namespace libstdcpp { // Adds a const reference to a non-reference type. template struct __add_c_ref { typedef _Tp const& type; }; template struct __add_c_ref<_Tp&> { typedef _Tp& type; }; // Adds a reference to a non-reference type. template struct __add_ref { typedef _Tp& type; }; template struct __add_ref<_Tp&> { typedef _Tp& type; }; template struct _Head_base; template struct _Head_base<_Idx, _Head, true> : public _Head { DEVICE_FUNC _Head_base() : _Head() { } DEVICE_FUNC _Head_base(_Head const& __h) : _Head(__h) { } template DEVICE_FUNC _Head_base(_UHead&& __h) : _Head(forward<_UHead>(__h)) { } DEVICE_FUNC _Head& _M_head() { return *this; } DEVICE_FUNC _Head const& _M_head() const { return *this; } DEVICE_FUNC void _M_swap_impl(_Head&) { /* no-op */ } }; template struct _Head_base<_Idx, _Head, false> { DEVICE_FUNC _Head_base() : _M_head_impl() { } DEVICE_FUNC _Head_base(_Head const& __h) : _M_head_impl(__h) { } template DEVICE_FUNC _Head_base(_UHead&& __h) : _M_head_impl(forward<_UHead>(__h)) { } DEVICE_FUNC _Head& _M_head() { return _M_head_impl; } DEVICE_FUNC _Head const& _M_head() const { return _M_head_impl; } DEVICE_FUNC void _M_swap_impl(_Head& __h) { using std::swap; swap(__h, _M_head_impl); } _Head _M_head_impl; }; /** * Contains the actual implementation of the @c tuple template, stored * as a recursive inheritance hierarchy from the first element (most * derived class) to the last (least derived class). The @c Idx * parameter gives the 0-based index of the element stored at this * point in the hierarchy; we use it to implement a constant-time * get() operation. */ template struct _Tuple_impl; /** * Zero-element tuple implementation. This is the basis case for the * inheritance recursion. */ template struct _Tuple_impl<_Idx> { protected: DEVICE_FUNC void _M_swap_impl(_Tuple_impl&) { /* no-op */ } }; /** * Recursive tuple implementation. Here we store the @c Head element * and derive from a @c Tuple_impl containing the remaining elements * (which contains the @c Tail). */ template struct _Tuple_impl<_Idx, _Head, _Tail...> : public _Tuple_impl<_Idx + 1, _Tail...>, private _Head_base<_Idx, _Head, is_empty<_Head>::value> { typedef _Tuple_impl<_Idx + 1, _Tail...> _Inherited; typedef _Head_base<_Idx, _Head, is_empty<_Head>::value> _Base; DEVICE_FUNC _Head& _M_head() { return _Base::_M_head(); } DEVICE_FUNC _Head const& _M_head() const { return _Base::_M_head(); } DEVICE_FUNC _Inherited& _M_tail() { return *this; } DEVICE_FUNC _Inherited const& _M_tail() const { return *this; } DEVICE_FUNC _Tuple_impl() : _Inherited() , _Base() { } explicit DEVICE_FUNC _Tuple_impl(_Head const& __head, _Tail const&... __tail) : _Inherited(__tail...) , _Base(__head) { } template explicit DEVICE_FUNC _Tuple_impl(_UHead&& __head, _UTail&&... __tail) : _Inherited(forward<_UTail>(__tail)...) , _Base(forward<_UHead>(__head)) { } DEVICE_FUNC _Tuple_impl(_Tuple_impl const& __in) : _Inherited(__in._M_tail()) , _Base(__in._M_head()) { } DEVICE_FUNC _Tuple_impl(_Tuple_impl&& __in) : _Inherited(move(__in._M_tail())) , _Base(forward<_Head>(__in._M_head())) { } template DEVICE_FUNC _Tuple_impl(_Tuple_impl<_Idx, _UElements...> const& __in) : _Inherited(__in._M_tail()) , _Base(__in._M_head()) { } template DEVICE_FUNC _Tuple_impl(_Tuple_impl<_Idx, _UElements...>&& __in) : _Inherited(move(__in._M_tail())) , _Base(move(__in._M_head())) { } DEVICE_FUNC _Tuple_impl& operator=(_Tuple_impl const& __in) { _M_head() = __in._M_head(); _M_tail() = __in._M_tail(); return *this; } DEVICE_FUNC _Tuple_impl& operator=(_Tuple_impl&& __in) { _M_head() = move(__in._M_head()); _M_tail() = move(__in._M_tail()); return *this; } template DEVICE_FUNC _Tuple_impl& operator=(_Tuple_impl<_Idx, _UElements...> const& __in) { _M_head() = __in._M_head(); _M_tail() = __in._M_tail(); return *this; } template DEVICE_FUNC _Tuple_impl& operator=(_Tuple_impl<_Idx, _UElements...>&& __in) { _M_head() = move(__in._M_head()); _M_tail() = move(__in._M_tail()); return *this; } protected: DEVICE_FUNC void _M_swap_impl(_Tuple_impl& __in) { _Base::_M_swap_impl(__in._M_head()); _Inherited::_M_swap_impl(__in._M_tail()); } }; /// tuple template class tuple : public _Tuple_impl<0, _Elements...> { typedef _Tuple_impl<0, _Elements...> _Inherited; public: DEVICE_FUNC tuple() : _Inherited() { } explicit DEVICE_FUNC tuple(_Elements const&... __elements) : _Inherited(__elements...) { } template explicit DEVICE_FUNC tuple(_UElements&&... __elements) : _Inherited(forward<_UElements>(__elements)...) { } DEVICE_FUNC tuple(tuple const& __in) : _Inherited(static_cast<_Inherited const&>(__in)) { } DEVICE_FUNC tuple(tuple&& __in) : _Inherited(static_cast<_Inherited&&>(__in)) { } template DEVICE_FUNC tuple(tuple<_UElements...> const& __in) : _Inherited(static_cast<_Tuple_impl<0, _UElements...> const&>(__in)) { } template DEVICE_FUNC tuple(tuple<_UElements...>&& __in) : _Inherited(static_cast<_Tuple_impl<0, _UElements...>&&>(__in)) { } // XXX http://gcc.gnu.org/ml/libstdc++/2008-02/msg00047.html template DEVICE_FUNC tuple(tuple<_UElements...>& __in) : _Inherited(static_cast<_Tuple_impl<0, _UElements...> const&>(__in)) { } DEVICE_FUNC tuple& operator=(tuple const& __in) { static_cast<_Inherited&>(*this) = __in; return *this; } DEVICE_FUNC tuple& operator=(tuple&& __in) { static_cast<_Inherited&>(*this) = move(__in); return *this; } template DEVICE_FUNC tuple& operator=(tuple<_UElements...> const& __in) { static_cast<_Inherited&>(*this) = __in; return *this; } template DEVICE_FUNC tuple& operator=(tuple<_UElements...>&& __in) { static_cast<_Inherited&>(*this) = move(__in); return *this; } void DEVICE_FUNC swap(tuple& __in) { _Inherited::_M_swap_impl(__in); } }; template <> class tuple<> { public: DEVICE_FUNC void swap(tuple&) { /* no-op */ } }; /// Gives the type of the ith element of a given tuple type. template struct tuple_element; /** * Recursive case for tuple_element: strip off the first element in * the tuple and retrieve the (i-1)th element of the remaining tuple. */ template struct tuple_element<__i, tuple<_Head, _Tail...>> : tuple_element<__i - 1, tuple<_Tail...>> { }; /** * Basis case for tuple_element: The first element is the one we're seeking. */ template struct tuple_element<0, tuple<_Head, _Tail...>> { typedef _Head type; }; /// Finds the size of a given tuple type. template struct tuple_size; /// class tuple_size template struct tuple_size> { static const size_t value = sizeof...(_Elements); }; template const size_t tuple_size>::value; template DEVICE_FUNC inline typename __add_ref<_Head>::type __get_helper(_Tuple_impl<__i, _Head, _Tail...>& __t) { return __t._M_head(); } template DEVICE_FUNC inline typename __add_c_ref<_Head>::type __get_helper(_Tuple_impl<__i, _Head, _Tail...> const& __t) { return __t._M_head(); } // Return a reference (const reference) to the ith element of a tuple. // Any const or non-const ref elements are returned with their original type. template DEVICE_FUNC inline typename __add_ref>::type>::type get( tuple<_Elements...>& __t) { return __get_helper<__i>(__t); } template DEVICE_FUNC inline typename __add_c_ref>::type>::type get( tuple<_Elements...> const& __t) { return __get_helper<__i>(__t); } // This class helps construct the various comparison operations on tuples template struct __tuple_compare; template struct __tuple_compare<0, __i, __j, _Tp, _Up> { DEVICE_FUNC static bool __eq(_Tp const& __t, _Up const& __u) { return (get<__i>(__t) == get<__i>(__u) && __tuple_compare<0, __i + 1, __j, _Tp, _Up>::__eq(__t, __u)); } DEVICE_FUNC static bool __less(_Tp const& __t, _Up const& __u) { return ((get<__i>(__t) < get<__i>(__u)) || !(get<__i>(__u) < get<__i>(__t)) && __tuple_compare<0, __i + 1, __j, _Tp, _Up>::__less(__t, __u)); } }; template struct __tuple_compare<0, __i, __i, _Tp, _Up> { static bool __eq(_Tp const&, _Up const&) { return true; } static bool __less(_Tp const&, _Up const&) { return false; } }; template DEVICE_FUNC bool operator==(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { typedef tuple<_TElements...> _Tp; typedef tuple<_UElements...> _Up; return (__tuple_compare::value - tuple_size<_Up>::value, 0, tuple_size<_Tp>::value, _Tp, _Up>::__eq( __t, __u)); } template DEVICE_FUNC bool operator<(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { typedef tuple<_TElements...> _Tp; typedef tuple<_UElements...> _Up; return ( __tuple_compare::value - tuple_size<_Up>::value, 0, tuple_size<_Tp>::value, _Tp, _Up>::__less( __t, __u)); } template DEVICE_FUNC inline bool operator!=(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { return !(__t == __u); } template DEVICE_FUNC inline bool operator>(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { return __u < __t; } template DEVICE_FUNC inline bool operator<=(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { return !(__u < __t); } template DEVICE_FUNC inline bool operator>=(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { return !(__t < __u); } template struct __index_holder { }; template struct __index_holder_impl; template struct __index_holder_impl<__i, __index_holder<_Indexes...>, _IdxHolder, _Elements...> { typedef typename __index_holder_impl<__i + 1, __index_holder<_Indexes..., __i>, _Elements...>::type type; }; template struct __index_holder_impl<__i, __index_holder<_Indexes...>> { typedef __index_holder<_Indexes...> type; }; template struct __make_index_holder : __index_holder_impl<0, __index_holder<>, _Elements...> { }; template DEVICE_FUNC inline tuple<_TElements..., _UElements...> __tuple_cat_helper(tuple<_TElements...> const& __t, __index_holder<_TIdx...> const&, tuple<_UElements...> const& __u, __index_holder<_UIdx...> const&) { return tuple<_TElements..., _UElements...>(get<_TIdx>(__t)..., get<_UIdx>(__u)...); } template DEVICE_FUNC inline tuple<_TElements..., _UElements...> __tuple_cat_helper(tuple<_TElements...>&& __t, __index_holder<_TIdx...> const&, tuple<_UElements...> const& __u, __index_holder<_UIdx...> const&) { return tuple<_TElements..., _UElements...>(move(get<_TIdx>(__t))..., get<_UIdx>(__u)...); } template DEVICE_FUNC inline tuple<_TElements..., _UElements...> __tuple_cat_helper(tuple<_TElements...> const& __t, __index_holder<_TIdx...> const&, tuple<_UElements...>&& __u, __index_holder<_UIdx...> const&) { return tuple<_TElements..., _UElements...>(get<_TIdx>(__t)..., move(get<_UIdx>(__u))...); } template DEVICE_FUNC inline tuple<_TElements..., _UElements...> __tuple_cat_helper(tuple<_TElements...>&& __t, __index_holder<_TIdx...> const&, tuple<_UElements...>&& __u, __index_holder<_UIdx...> const&) { return tuple<_TElements..., _UElements...>(move(get<_TIdx>(__t))..., move(get<_UIdx>(__u))...); } template DEVICE_FUNC inline tuple<_TElements..., _UElements...> tuple_cat( tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { return __tuple_cat_helper(__t, typename __make_index_holder<_TElements...>::type(), __u, typename __make_index_holder<_UElements...>::type()); } template DEVICE_FUNC inline tuple<_TElements..., _UElements...> tuple_cat( tuple<_TElements...>&& __t, tuple<_UElements...> const& __u) { return __tuple_cat_helper(move(__t), typename __make_index_holder<_TElements...>::type(), __u, typename __make_index_holder<_UElements...>::type()); } template DEVICE_FUNC inline tuple<_TElements..., _UElements...> tuple_cat( tuple<_TElements...> const& __t, tuple<_UElements...>&& __u) { return __tuple_cat_helper(__t, typename __make_index_holder<_TElements...>::type(), move(__u), typename __make_index_holder<_UElements...>::type()); } template DEVICE_FUNC inline tuple<_TElements..., _UElements...> tuple_cat(tuple<_TElements...>&& __t, tuple<_UElements...>&& __u) { return __tuple_cat_helper(move(__t), typename __make_index_holder<_TElements...>::type(), move(__u), typename __make_index_holder<_UElements...>::type()); } template DEVICE_FUNC inline tuple<_Elements&...> tie(_Elements&... __args) { return tuple<_Elements&...>(__args...); } template DEVICE_FUNC inline void swap(tuple<_Elements...>& __x, tuple<_Elements...>& __y) { __x.swap(__y); } // A class (and instance) which can be used in 'tie' when an element // of a tuple is not required struct _Swallow_assign { template DEVICE_FUNC _Swallow_assign& operator=(_Tp const&) { return *this; } }; // TODO: Put this in some kind of shared file. namespace { _Swallow_assign ignore; }; // anonymous namespace } // namespace libstdcpp template using tuple = libstdcpp::tuple; using libstdcpp::tie; using libstdcpp::tuple_cat; #ifndef GENERATE_CUBIN template using remove_cv = std::remove_cv; template using remove_cv_t = typename std::remove_cv::type; template using decay = std::decay; template using decay_t = std::decay_t; #else // https://en.cppreference.com/w/cpp/types/is_array template struct is_array : false_type { }; template struct is_array : true_type { }; template struct is_array : true_type { }; // https://en.cppreference.com/w/cpp/types/remove_extent template struct remove_extent { using type = T; }; template struct remove_extent { using type = T; }; template struct remove_extent { using type = T; }; // https://en.cppreference.com/w/cpp/types/is_function template struct is_function : false_type { }; // specialization for regular functions template struct is_function : true_type { }; // specialization for variadic functions such as printf template struct is_function : true_type { }; // specialization for function types that have cv-qualifiers template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; // specialization for function types that have ref-qualifiers template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; // specializations for noexcept versions of all the above (C++17 and later) template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; template struct is_function : true_type { }; // https://en.cppreference.com/w/cpp/types/remove_cv template struct remove_cv { typedef T type; }; template struct remove_cv { typedef T type; }; template struct remove_cv { typedef T type; }; template struct remove_cv { typedef T type; }; template struct remove_const { typedef T type; }; template struct remove_const { typedef T type; }; template struct remove_volatile { typedef T type; }; template struct remove_volatile { typedef T type; }; template using remove_cv_t = typename remove_cv::type; // https://en.cppreference.com/w/cpp/types/add_pointer namespace detail { template auto try_add_pointer(int) -> type_identity::type*>; // usual case template auto try_add_pointer(...) -> type_identity; // unusual case (cannot form std::remove_reference::type*) } // namespace detail template struct add_pointer : decltype(detail::try_add_pointer(0)) { }; // https://en.cppreference.com/w/cpp/types/decay template struct decay { private: typedef typename remove_reference::type U; public: typedef typename conditional::value, typename add_pointer::type>::type, typename conditional::value, typename add_pointer::type, typename remove_cv::type>::type>::type type; }; template using decay_t = typename decay::type; #endif #ifndef GENERATE_CUBIN template using is_void = std::is_void; template inline constexpr bool is_void_v = std::is_void_v; #else template using is_void = is_same, void>; template inline constexpr bool is_void_v = is_void::value; #endif } // namespace mha #if GENERATE_CUBIN using uint8_t = mha::uint8_t; using int8_t = mha::int8_t; using uint16_t = mha::uint16_t; using int32_t = mha::int32_t; using uint32_t = mha::uint32_t; using uint64_t = mha::uint64_t; using uintptr_t = mha::uintptr_t; #endif