TensorRT-LLMs/cpp/include/tensorrt_llm/runtime/gptModelConfig.h
Kaiyu Xie f044eb8d94
Update TensorRT-LLM (#302)
* Update TensorRT-LLM

---------

Co-authored-by: wangruohui <12756472+wangruohui@users.noreply.github.com>
2023-11-07 19:51:58 +08:00

270 lines
6.9 KiB
C++

/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/runtime/common.h"
#include <NvInferRuntime.h>
namespace tensorrt_llm::runtime
{
class GptModelConfig
{
public:
enum class ModelVariant : std::int32_t
{
kGpt = 0,
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
};
constexpr explicit GptModelConfig(
SizeType vocabSize, SizeType nbLayers, SizeType nbHeads, SizeType hiddenSize, nvinfer1::DataType dtype)
: mVocabSize(vocabSize)
, mNbLayers(nbLayers)
, mNbHeads(nbHeads)
, mNbKvHeads(nbHeads)
, mHiddenSize(hiddenSize)
, mDataType(dtype)
, mUseGptAttentionPlugin(false)
, mInputPacked{false}
, mPagedKvCache{false}
, mTokensPerBlock{64}
, mQuantMode{common::QuantMode::none()}
, mMaxBatchSize(0)
, mMaxInputLen(0)
, mMaxOutputLen(0)
, mMaxNumTokens(std::nullopt)
, mComputeContextLogits(false)
, mModelVariant(ModelVariant::kGpt)
, mUseCustomAllReduce(false)
, mMaxPromptEmbeddingTableSize(0)
{
}
[[nodiscard]] SizeType constexpr getVocabSize() const noexcept
{
return mVocabSize;
}
[[nodiscard]] SizeType constexpr getVocabSizePadded(SizeType worldSize) const noexcept
{
return (mVocabSize + worldSize - 1) / worldSize * worldSize;
}
[[nodiscard]] SizeType constexpr getNbLayers(SizeType pipelineParallelism = 1) const
{
TLLM_CHECK(mNbLayers % pipelineParallelism == 0);
return mNbLayers / pipelineParallelism;
}
[[nodiscard]] SizeType constexpr getNbHeads() const noexcept
{
return mNbHeads;
}
[[nodiscard]] SizeType constexpr getNbKvHeads() const noexcept
{
return mNbKvHeads;
}
void constexpr setNbKvHeads(SizeType nbKvHeads) noexcept
{
mNbKvHeads = nbKvHeads;
}
[[nodiscard]] SizeType constexpr getHiddenSize() const noexcept
{
return mHiddenSize;
}
[[nodiscard]] SizeType constexpr getSizePerHead() const noexcept
{
return mHiddenSize / mNbHeads;
}
[[nodiscard]] nvinfer1::DataType constexpr getDataType() const noexcept
{
return mDataType;
}
[[nodiscard]] bool constexpr useGptAttentionPlugin() const noexcept
{
return mUseGptAttentionPlugin;
}
void constexpr useGptAttentionPlugin(bool useGptAttentionPlugin) noexcept
{
mUseGptAttentionPlugin = useGptAttentionPlugin;
}
[[nodiscard]] bool constexpr usePackedInput() const noexcept
{
return mInputPacked;
}
void constexpr usePackedInput(bool inputPacked) noexcept
{
mInputPacked = inputPacked;
}
[[nodiscard]] bool constexpr usePagedKvCache() const noexcept
{
return mPagedKvCache;
}
void constexpr usePagedKvCache(bool pagedKvCache) noexcept
{
mPagedKvCache = pagedKvCache;
}
[[nodiscard]] SizeType constexpr getTokensPerBlock() const noexcept
{
return mTokensPerBlock;
}
void constexpr setTokensPerBlock(SizeType TokensPerBlock) noexcept
{
mTokensPerBlock = TokensPerBlock;
}
[[nodiscard]] common::QuantMode constexpr getQuantMode() const noexcept
{
return mQuantMode;
}
void constexpr setQuantMode(common::QuantMode QuantMode) noexcept
{
mQuantMode = QuantMode;
}
[[nodiscard]] bool constexpr supportsInflightBatching() const noexcept
{
return mUseGptAttentionPlugin && mInputPacked && mPagedKvCache;
}
[[nodiscard]] SizeType constexpr getMaxBatchSize() const noexcept
{
return mMaxBatchSize;
}
void constexpr setMaxBatchSize(SizeType maxBatchSize) noexcept
{
mMaxBatchSize = maxBatchSize;
}
[[nodiscard]] SizeType constexpr getMaxInputLen() const noexcept
{
return mMaxInputLen;
}
void constexpr setMaxInputLen(SizeType maxInputLen) noexcept
{
mMaxInputLen = maxInputLen;
}
[[nodiscard]] SizeType constexpr getMaxOutputLen() const noexcept
{
return mMaxOutputLen;
}
void constexpr setMaxOutputLen(SizeType maxOutputLen) noexcept
{
mMaxOutputLen = maxOutputLen;
}
[[nodiscard]] std::optional<SizeType> constexpr getMaxNumTokens() const noexcept
{
return mMaxNumTokens;
}
void constexpr setMaxNumTokens(std::optional<SizeType> maxNumTokens) noexcept
{
mMaxNumTokens = maxNumTokens;
}
[[nodiscard]] bool constexpr usePromptTuning() const noexcept
{
return mMaxPromptEmbeddingTableSize > 0;
}
[[nodiscard]] SizeType constexpr getMaxPromptEmbeddingTableSize() const noexcept
{
return mMaxPromptEmbeddingTableSize;
}
void constexpr setMaxPromptEmbeddingTableSize(SizeType maxPromptEmbeddingTableSize) noexcept
{
mMaxPromptEmbeddingTableSize = maxPromptEmbeddingTableSize;
}
[[nodiscard]] bool constexpr computeContextLogits() const noexcept
{
return mComputeContextLogits;
}
void constexpr computeContextLogits(bool computeContextLogits) noexcept
{
mComputeContextLogits = computeContextLogits;
}
[[nodiscard]] ModelVariant getModelVariant() const
{
return mModelVariant;
}
void setModelVariant(ModelVariant modelVariant)
{
mModelVariant = modelVariant;
}
[[nodiscard]] bool constexpr useCustomAllReduce() const noexcept
{
return mUseCustomAllReduce;
}
void constexpr useCustomAllReduce(bool customAllReduce) noexcept
{
mUseCustomAllReduce = customAllReduce;
}
private:
SizeType mVocabSize;
SizeType mNbLayers;
SizeType mNbHeads;
SizeType mNbKvHeads;
SizeType mHiddenSize;
nvinfer1::DataType mDataType;
bool mUseGptAttentionPlugin;
bool mInputPacked;
bool mPagedKvCache;
SizeType mTokensPerBlock;
common::QuantMode mQuantMode;
SizeType mMaxBatchSize;
SizeType mMaxInputLen;
SizeType mMaxOutputLen;
std::optional<SizeType> mMaxNumTokens;
bool mComputeContextLogits;
ModelVariant mModelVariant;
bool mUseCustomAllReduce;
SizeType mMaxPromptEmbeddingTableSize;
};
} // namespace tensorrt_llm::runtime