TensorRT-LLMs/cpp/include/tensorrt_llm/runtime/loraCachePageManagerConfig.h
Kaiyu Xie bf0a5afc92
Update TensorRT-LLM (#1598)
* Update TensorRT-LLM
2024-05-14 16:43:41 +08:00

171 lines
4.8 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <NvInferRuntime.h>
#include <ostream>
#include <sstream>
#include <type_traits>
namespace tensorrt_llm::runtime
{
/**
 * Configuration for LoraCachePageManager
 *
 * Plain value type: it stores a memory type, a data type, the page-layout sizes, the number of copy
 * streams and a test-only zero-initialization flag, and exposes a constexpr getter/setter pair for
 * each of them. No validation is performed on any value.
 *
 * See LoraCache docs for description of pages, slots, and page blocks.
 */
class LoraCachePageManagerConfig
{
public:
    /**
     * Builds a config from explicit values.
     *
     * @param memType          value stored as the memory type
     * @param dType            value stored as the data type
     * @param totalNumPages    number of cache pages in the cache
     * @param maxPagesPerBlock number of pages to allocate in one block
     * @param slotsPerPage     number of slots per page
     * @param pageWidth        value stored as the page width
     * @param numCopyStreams   number of streams used to copy pages to device cache
     *
     * The init-to-zero flag always starts out false; use setInitToZero() to change it.
     */
    explicit constexpr LoraCachePageManagerConfig(runtime::MemoryType memType, nvinfer1::DataType dType,
        SizeType32 totalNumPages, SizeType32 maxPagesPerBlock, SizeType32 slotsPerPage, SizeType32 pageWidth,
        SizeType32 numCopyStreams)
        : mMemoryType(memType)
        , mDataType(dType)
        , mTotalNumPages(totalNumPages)
        , mMaxPagesPerBlock(maxPagesPerBlock)
        , mSlotsPerPage(slotsPerPage)
        , mPageWidth(pageWidth)
        // Store the caller's value. Without this initializer the argument was silently dropped and the
        // member kept its in-class default. Listed before mInitToZero to match declaration order.
        , mNumCopyStreams(numCopyStreams)
        , mInitToZero(false)
    {
    }

    /// @return the stored memory type
    [[nodiscard]] runtime::MemoryType constexpr getMemoryType() const noexcept
    {
        return mMemoryType;
    }

    /// Replaces the stored memory type.
    void constexpr setMemoryType(runtime::MemoryType const& memoryType) noexcept
    {
        mMemoryType = memoryType;
    }

    /// @return the stored data type
    [[nodiscard]] nvinfer1::DataType constexpr getDataType() const noexcept
    {
        return mDataType;
    }

    /// Replaces the stored data type.
    void constexpr setDataType(nvinfer1::DataType const& dtype) noexcept
    {
        mDataType = dtype;
    }

    /// @return the number of cache pages in the cache
    [[nodiscard]] SizeType32 constexpr getTotalNumPages() const noexcept
    {
        return mTotalNumPages;
    }

    /// Replaces the number of cache pages.
    /// NOTE: the name lacks the trailing 's' of getTotalNumPages(); it is kept as-is for existing callers.
    void constexpr setTotalNumPage(SizeType32 const& totalNumPages) noexcept
    {
        mTotalNumPages = totalNumPages;
    }

    /// @return the number of pages to allocate in one block
    [[nodiscard]] SizeType32 constexpr getMaxPagesPerBlock() const noexcept
    {
        return mMaxPagesPerBlock;
    }

    /// Replaces the number of pages to allocate in one block.
    void constexpr setMaxPagesPerBlock(SizeType32 const& maxPagesPerBlock) noexcept
    {
        mMaxPagesPerBlock = maxPagesPerBlock;
    }

    /// @return the number of slots per page
    [[nodiscard]] SizeType32 constexpr getSlotsPerPage() const noexcept
    {
        return mSlotsPerPage;
    }

    /// Replaces the number of slots per page.
    void constexpr setSlotsPerPage(SizeType32 const& slotsPerPage) noexcept
    {
        mSlotsPerPage = slotsPerPage;
    }

    /// @return the stored page width
    [[nodiscard]] SizeType32 constexpr getPageWidth() const noexcept
    {
        return mPageWidth;
    }

    /// Replaces the stored page width.
    void constexpr setPageWidth(SizeType32 const& pageWidth) noexcept
    {
        mPageWidth = pageWidth;
    }

    /// @return the init-to-zero flag (for testing)
    [[nodiscard]] bool constexpr getInitToZero() const noexcept
    {
        return mInitToZero;
    }

    /// Replaces the init-to-zero flag (for testing).
    void constexpr setInitToZero(bool initToZero) noexcept
    {
        mInitToZero = initToZero;
    }

    /// @return the number of streams used to copy pages to device cache
    [[nodiscard]] SizeType32 constexpr getNumCopyStreams() const noexcept
    {
        return mNumCopyStreams;
    }

    /// Replaces the number of streams used to copy pages to device cache.
    void constexpr setNumCopyStreams(SizeType32 numCopyStreams) noexcept
    {
        mNumCopyStreams = numCopyStreams;
    }

private:
    runtime::MemoryType mMemoryType;
    nvinfer1::DataType mDataType;

    /*
     * Number of cache pages in the cache.
     * Generally corresponds to the number of opt sized LoRAs that can be stored in the cache
     */
    SizeType32 mTotalNumPages;
    // number of pages to allocate in one block
    SizeType32 mMaxPagesPerBlock;
    // number of slots per page, where a slot corresponds to an adapterSize=1, 1-layer, 1-module set of weights
    SizeType32 mSlotsPerPage;
    SizeType32 mPageWidth;

    // number of streams used to copy pages to device cache
    SizeType32 mNumCopyStreams = 1;

    bool mInitToZero; // for testing
};
inline std::ostream& operator<<(std::ostream& os, LoraCachePageManagerConfig const& c)
{
os << "{"
<< "memoryType=" << static_cast<typename std::underlying_type<runtime::MemoryType>::type>(c.getMemoryType())
<< " dataType=" << static_cast<typename std::underlying_type<nvinfer1::DataType>::type>(c.getDataType())
<< " totalNumPages=" << c.getTotalNumPages() << " maxPagesPerBlock=" << c.getMaxPagesPerBlock()
<< " slotsPerPage=" << c.getSlotsPerPage() << " pageWidth=" << c.getPageWidth()
<< " initToZero=" << c.getInitToZero() << "}";
return os;
}
/**
 * Renders a LoraCachePageManagerConfig as a string.
 *
 * The text is exactly what streaming the config with operator<< produces.
 *
 * @param c config to describe
 * @return  the formatted description
 */
inline std::string to_string(LoraCachePageManagerConfig const& c)
{
    std::ostringstream formatted;
    formatted << c;
    return formatted.str();
}
} // namespace tensorrt_llm::runtime