/* * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include "tensorrt_llm/common/arrayView.h" #include "tensorrt_llm/common/dataType.h" #include "tensorrt_llm/kernels/decodingCommon.h" #include "tensorrt_llm/kernels/kvCacheIndex.h" #include "tensorrt_llm/runtime/common.h" #include #include #ifdef ENABLE_FP8 #include #endif #ifdef ENABLE_BF16 #include #endif #include #include #include #include #include #include #include #include namespace tensorrt_llm::runtime { enum class MemoryType : std::int32_t { kGPU = 0, kCPU = 1, kPINNED = 2, kUVM = 3, kPINNEDPOOL = 4 }; template struct MemoryTypeString { }; template <> struct MemoryTypeString { static auto constexpr value = "GPU"; }; template <> struct MemoryTypeString { static auto constexpr value = "CPU"; }; template <> struct MemoryTypeString { static auto constexpr value = "PINNED"; }; template <> struct MemoryTypeString { static auto constexpr value = "UVM"; }; template <> struct MemoryTypeString { static auto constexpr value = "PINNEDPOOL"; }; //! \brief For converting a TensorRT data type to a C++ data type. template struct DataTypeTraits { }; template <> struct DataTypeTraits { using type = float; static char constexpr name[] = "float"; static auto constexpr size = sizeof(type); }; template <> struct DataTypeTraits { using type = half; static char constexpr name[] = "half"; static auto constexpr size = sizeof(type); }; template <> struct DataTypeTraits { using type = std::int8_t; static char constexpr name[] = "int8"; static auto constexpr size = sizeof(type); }; template <> struct DataTypeTraits { using type = std::int32_t; static char constexpr name[] = "int32"; static auto constexpr size = sizeof(type); }; template <> struct DataTypeTraits { using type = std::int64_t; static char constexpr name[] = "int64"; static auto constexpr size = sizeof(type); }; template <> struct DataTypeTraits { using type = std::uint32_t; static char constexpr name[] = "uint32"; static auto constexpr size = sizeof(type); }; template <> struct DataTypeTraits { using type = std::uint64_t; static char constexpr name[] = "uint64"; static auto constexpr size = sizeof(type); }; template struct DataTypeTraits { using type = bool; static char constexpr name[] = "bool"; static auto constexpr size = sizeof(type); }; template struct DataTypeTraits { using type = std::uint8_t; static char constexpr name[] = "uint8"; static auto constexpr size = sizeof(type); }; #ifdef ENABLE_BF16 template <> struct DataTypeTraits { using type = __nv_bfloat16; static char constexpr name[] = "bfloat16"; static auto constexpr size = sizeof(type); }; #endif #ifdef ENABLE_FP8 template <> struct DataTypeTraits { using type = __nv_fp8_e4m3; static char constexpr name[] = "fp8"; static auto constexpr size = sizeof(type); }; #endif template struct DataTypeTraits { using type = typename DataTypeTraits::type*; static char constexpr name[] = "*"; static auto constexpr size = sizeof(type); }; //! \brief A wrapper around `nvinfer1::DataType` that provides a support for pointer types. class BufferDataType { public: constexpr BufferDataType( // NOLINT(*-explicit-constructor) nvinfer1::DataType dataType, bool _unsigned = false, bool pointer = false) : mDataType{dataType} , mUnsigned{_unsigned} , mPointer{pointer} { } static auto constexpr kTrtPointerType = nvinfer1::DataType::kINT64; constexpr operator nvinfer1::DataType() const noexcept // NOLINT(*-explicit-constructor) { return mPointer ? kTrtPointerType : mDataType; } [[nodiscard]] constexpr nvinfer1::DataType getDataType() const noexcept { return mDataType; } [[nodiscard]] constexpr bool isPointer() const noexcept { return mPointer; } [[nodiscard]] constexpr bool isUnsigned() const { switch (mDataType) { case nvinfer1::DataType::kBOOL: [[fallthrough]]; case nvinfer1::DataType::kUINT8: return true; default: return mUnsigned; } } [[nodiscard]] constexpr std::size_t getSize() const noexcept { return tensorrt_llm::common::getDTypeSize(static_cast(*this)); } [[nodiscard]] constexpr std::size_t getSizeInBits() const noexcept { return tensorrt_llm::common::getDTypeSizeInBits(static_cast(*this)); } private: nvinfer1::DataType mDataType; bool mUnsigned; bool mPointer; }; //! \brief For converting a C++ data type to a TensorRT data type. template struct TRTDataType { }; template <> struct TRTDataType { static constexpr auto value = nvinfer1::DataType::kFLOAT; }; template <> struct TRTDataType { static constexpr auto value = nvinfer1::DataType::kHALF; }; template <> struct TRTDataType { static constexpr auto value = nvinfer1::DataType::kINT8; }; template <> struct TRTDataType { static constexpr auto value = nvinfer1::DataType::kINT32; }; template <> struct TRTDataType { static constexpr auto value = BufferDataType{nvinfer1::DataType::kINT32, true}; }; template <> struct TRTDataType { static constexpr auto value = nvinfer1::DataType::kINT64; }; template <> struct TRTDataType { static constexpr auto value = BufferDataType{nvinfer1::DataType::kINT64, true}; }; template <> struct TRTDataType { static constexpr auto value = nvinfer1::DataType::kBOOL; }; template <> struct TRTDataType { static constexpr auto value = nvinfer1::DataType::kUINT8; }; #ifdef ENABLE_BF16 template <> struct TRTDataType<__nv_bfloat16> { static constexpr auto value = nvinfer1::DataType::kBF16; }; #endif #ifdef ENABLE_FP8 template <> struct TRTDataType<__nv_fp8_e4m3> { static constexpr auto value = nvinfer1::DataType::kFP8; }; #endif template <> struct TRTDataType { static constexpr auto value = TRTDataType::value; }; template <> struct TRTDataType { static constexpr auto value = TRTDataType::value; }; template <> struct TRTDataType { static constexpr auto value = TRTDataType>::value; }; template <> struct TRTDataType { static constexpr auto value = BufferDataType::kTrtPointerType; }; template struct TRTDataType { private: static auto constexpr kUnderlyingType = BufferDataType{TRTDataType, false>::value}; public: static auto constexpr value = BufferDataType{kUnderlyingType.getDataType(), kUnderlyingType.isUnsigned(), true}; }; template using PointerElementType = typename std::remove_reference_t::element_type; template std::shared_ptr> constPointerCast(std::shared_ptr const& ptr) noexcept { return std::const_pointer_cast>(ptr); } template std::shared_ptr> constPointerCast(std::unique_ptr&& ptr) noexcept { return std::const_pointer_cast>(std::shared_ptr(std::move(ptr))); } class IBuffer { public: using UniquePtr = std::unique_ptr; using SharedPtr = std::shared_ptr; using UniqueConstPtr = std::unique_ptr; using SharedConstPtr = std::shared_ptr; using DataType = nvinfer1::DataType; //! //! \brief Returns a pointer to underlying array. //! [[nodiscard]] virtual void* data() = 0; //! //! \brief Returns a pointer to underlying array. //! [[nodiscard]] virtual void const* data() const = 0; //! //! \brief Returns a pointer to the underlying array at a given element index. //! [[nodiscard]] virtual void* data(std::size_t index) { auto* const dataPtr = this->data(); return dataPtr ? static_cast(dataPtr) + toBytes(index) : nullptr; } //! //! \brief Returns a pointer to the underlying array at a given element index. //! [[nodiscard]] virtual void const* data(std::size_t index) const { auto const* const dataPtr = this->data(); return dataPtr ? static_cast(dataPtr) + toBytes(index) : nullptr; } //! //! \brief Returns the size (in number of elements) of the buffer. //! [[nodiscard]] virtual std::size_t getSize() const = 0; //! //! \brief Returns the size (in bytes) of the buffer. //! [[nodiscard]] virtual std::size_t getSizeInBytes() const { return toBytes(getSize()); } //! //! \brief Returns the capacity of the buffer. //! [[nodiscard]] virtual std::size_t getCapacity() const = 0; //! //! \brief Returns the data type of the buffer. //! [[nodiscard]] virtual DataType getDataType() const = 0; [[nodiscard]] static char const* getDataTypeName(DataType dataType); [[nodiscard]] virtual char const* getDataTypeName() const; //! //! \brief Returns the memory type of the buffer. //! [[nodiscard]] virtual MemoryType getMemoryType() const = 0; [[nodiscard]] virtual char const* getMemoryTypeName() const; //! //! \brief Resizes the buffer. This is a no-op if the new size is smaller than or equal to the current capacity. //! virtual void resize(std::size_t newSize) = 0; //! //! \brief Releases the buffer. It will be reset to nullptr. //! virtual void release() = 0; virtual ~IBuffer() = default; //! //! \brief Not allowed to copy. //! IBuffer(IBuffer const&) = delete; //! //! \brief Not allowed to copy. //! IBuffer& operator=(IBuffer const&) = delete; //! //! \brief Creates a sliced view on the underlying `buffer`. The view will have the same data type as `buffer`. //! //! \param buffer The buffer to view. //! \param offset The offset of the view. //! \param size The size of the view. //! \return A view on the `buffer`. //! static UniquePtr slice(SharedPtr buffer, std::size_t offset, std::size_t size); template >, int> = 0> static UniqueConstPtr slice(TConstPtr&& tensor, std::size_t offset, std::size_t size) { return IBuffer::slice(constPointerCast(std::forward(tensor)), offset, size); } static UniquePtr slice(SharedPtr buffer, std::size_t offset) { auto const size = buffer->getSize() - offset; return IBuffer::slice(std::move(buffer), offset, size); } template >, int> = 0> static UniqueConstPtr slice(TConstPtr&& tensor, std::size_t offset) { return IBuffer::slice(constPointerCast(std::forward(tensor)), offset); } //! //! \brief Returns a view on the underlying `tensor` which can be independently resized. //! //! \param tensor The tensor to view. //! \return A view on the `tensor`. //! static UniquePtr view(SharedPtr tensor) { auto constexpr offset = 0; return IBuffer::slice(std::move(tensor), offset); } //! //! \brief Returns a view on the underlying `tensor` with a different size. //! //! \param tensor The tensor to view. //! \param size The size of the view. //! \return A view on the `tensor`. //! static UniquePtr view(SharedPtr tensor, std::size_t size) { auto v = IBuffer::view(std::move(tensor)); v->resize(size); return v; } template >, int> = 0> static UniqueConstPtr view(TConstPtr&& tensor, std::size_t size) { return IBuffer::view(constPointerCast(std::forward(tensor)), size); } //! //! \brief Wraps the given `data` in an `IBuffer`. The `IBuffer` will not own the underlying `data` and cannot //! be resized beyond `capacity`. //! //! \param data The data to wrap. //! \param type The data type of the `data`. //! \param size The size of the buffer. //! \param capacity The capacity of the buffer. //! \return An `IBuffer`. static UniquePtr wrap(void* data, DataType type, std::size_t size, std::size_t capacity); static UniquePtr wrap(void* data, DataType type, std::size_t size) { return wrap(data, type, size, size); } template static UniquePtr wrap(T* data, std::size_t size, std::size_t capacity) { return wrap(data, TRTDataType::value, size, capacity); } template static UniquePtr wrap(T* data, std::size_t size) { return wrap(data, size, size); } template static UniquePtr wrap(std::vector& v) { return wrap(v.data(), v.size(), v.capacity()); } //! //! \brief Determine the memory type of a pointer. //! static MemoryType memoryType(void const* data); protected: IBuffer() = default; //! //! \brief Returns an array index or size in bytes. //! [[nodiscard]] std::size_t toBytes(std::size_t size) const { return size * BufferDataType(getDataType()).getSizeInBits() / 8; } }; /// @brief Gets a typed pointer to the constant underlying data of the buffer. /// @tparam T The type of the underlying data. /// @param buffer The buffer to get a pointer to. /// @return A pointer to constant @p T. template T const* bufferCast(IBuffer const& buffer) { if (TRTDataType::type>::value != buffer.getDataType()) { throw std::bad_cast(); } return static_cast(buffer.data()); } /// @brief Gets a typed pointer to the underlying data of the buffer. /// @tparam T The type of the underlying data. /// @param buffer The buffer to get a pointer to. /// @return A pointer to @p T. template T* bufferCast(IBuffer& buffer) { if (TRTDataType::type>::value != buffer.getDataType()) { throw std::bad_cast(); } return static_cast(buffer.data()); } /// @brief Retrieves a T typed pointer to the underlying data of the buffer pointed to by the bufferPtr, or nullptr if /// the bufferPtr is null. /// @tparam T The type of the underlying data. /// @param bufferPtr A possibly null shared ptr. /// @return A pointer to T, possibly nullptr. template T* bufferCastOrNull(IBuffer::SharedPtr const& bufferPtr) { if (bufferPtr) { return bufferCast(*bufferPtr); } return static_cast(nullptr); } /// @brief Retrieves a T const typed pointer to the underlying data of the buffer pointed to by the bufferPtr, or /// nullptr if the bufferPtr is null. /// @tparam T The type of the underlying data. /// @param bufferPtr A possibly null shared ptr. /// @return A pointer to const T, possibly nullptr. template T const* bufferCastOrNull(IBuffer::SharedConstPtr const& bufferPtr) { if (bufferPtr) { return bufferCast(*bufferPtr); } return static_cast(nullptr); } /// @brief Retrieves a T typed pointer to the underlying data of the buffer pointed to by the buffer pointer /// contained in the optionalBufferPtr, or nullptr if the optional doesn't have a value. /// @tparam T The type of the underlying data. /// @param optionalBufferPtr A possibly empty optional. /// @return A pointer to T, possibly nullptr. template T* bufferCastOrNull(std::optional const& optionalBufferPtr) { if (optionalBufferPtr) { return bufferCast(*optionalBufferPtr.value()); } return static_cast(nullptr); } /// @brief Retrieves a T const typed pointer to the underlying data of the buffer pointed to by the buffer pointer /// contained in the optionalBufferPtr, or nullptr if the optional doesn't have a value. /// @tparam T The type of the underlying data. /// @param optionalBufferPtr A possibly empty optional. /// @return A pointer to const T, possibly nullptr. template T const* bufferCastOrNull(std::optional const& optionalBufferPtr) { if (optionalBufferPtr) { return bufferCast(*optionalBufferPtr.value()); } return static_cast(nullptr); } template class BufferRange : public tensorrt_llm::common::ArrayView { public: using Base = tensorrt_llm::common::ArrayView; using typename Base::size_type; BufferRange(T* data, size_type size) : Base{data, size} { } template , bool> = true> explicit BufferRange(IBuffer& buffer) : BufferRange(bufferCast(buffer), buffer.getSize()) { } template , bool> = true> explicit BufferRange(IBuffer const& buffer) : BufferRange(bufferCast(buffer), buffer.getSize()) { } }; //! \brief Utility function to print a buffer. std::ostream& operator<<(std::ostream& output, IBuffer const& buffer); } // namespace tensorrt_llm::runtime