TensorRT-LLM/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "tensorrt_llm/common/logger.h" // TLLM_LOG_ERROR
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/samplingConfig.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

namespace tensorrt_llm::batch_manager
{
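/// @brief Lifecycle states of an LlmRequest in the batch manager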
enum LlmRequestState_t
{
REQUEST_STATE_UNKNOWN = 0,                // Request state is not known
REQUEST_STATE_CONTEXT_INIT = 1,           // Request is waiting for, or executing, the context phase
REQUEST_STATE_GENERATION_IN_PROGRESS = 2, // Request is generating new tokens
REQUEST_STATE_GENERATION_COMPLETE = 3     // Request has finished generating tokens
};
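/// @brief Holds the state of a single inference request as it moves through the batch manager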
class LlmRequest
{
public:
using SizeType = runtime::SizeType;
using TokenIdType = runtime::TokenIdType;
using RequestIdType = std::uint64_t;
using VecTokens = std::vector<TokenIdType>;
using VecLogProbs = std::vector<float>;
using BeamTokens = std::vector<VecTokens>;
using TensorPtr = runtime::ITensor::SharedPtr;
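/// @brief Construct a new request
/// @param requestId Unique identifier of the request
/// @param maxNewTokens Maximum number of tokens to generate
/// @param inputTokens The prompt tokens
/// @param samplingConfig Sampling parameters, including the beam width
/// @param isStreaming Whether generated tokens are streamed back to the client as they are produced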
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
, mSamplingConfig(samplingConfig)
, mState(REQUEST_STATE_CONTEXT_INIT)
, mIsStreaming(isStreaming)
, mEndId(endId)
, mPadId(padId)
, mBatchSlot(-1)
, mOrigPromptLen(inputTokens->size())
, mEmbeddingBias(embeddingBias)
, mBadWordsList(badWordsList)
, mStopWordsList(stopWordsList)
, mPromptEmbeddingTable(promptEmbeddingTable)
, mPromptVocabSize(promptVocabSize)
, mReturnLogProbs(returnLogProbs)
, mLogProbs(samplingConfig.beamWidth)
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
{
mMaxSentTokenPos = mPromptLen - 1;
// Scatter the input tokens to the other beams
mTokens = BeamTokens(mSamplingConfig.beamWidth, *inputTokens);
if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value())
|| (!mPromptEmbeddingTable.has_value() && mPromptVocabSize.has_value()))
{
std::string errStr
= "Prompt embedding table and prompt vocab size tensors must both be provided for requests with prompt "
"tuning enabled.";
TLLM_LOG_ERROR(errStr);
throw std::runtime_error(errStr);
}
}
/// @brief Get the total number of tokens for this request (prompt + generated)
/// @param beam The beam index
/// @return The number of tokens
SizeType getNumTokens(SizeType beam) const
{
return mTokens.at(beam).size();
}
/// @brief Get max number of tokens across all beams
/// @return The number of tokens
SizeType getMaxBeamNumTokens() const
{
SizeType maxTokens = 0;
for (SizeType beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
{
maxTokens = std::max(maxTokens, static_cast<SizeType>(mTokens.at(beam).size()));
}
return maxTokens;
}
/// @brief Get a token at a given position and beam index
/// @param beam The beam index
/// @param pos The position of the token relative to beginning of the prompt
/// @return The token id at the given position
TokenIdType getToken(SizeType beam, SizeType pos) const
{
return mTokens.at(beam).at(pos);
}
/// @brief Get the tokens at a given beam index
/// @param beam The beam index
/// @return A vector of tokens for this beam index, includes the prompt
std::vector<TokenIdType> const& getTokens(SizeType beam) const
{
return mTokens.at(beam);
}
/// @brief Get the draft tokens
/// @return shared_ptr to vector of draft tokens
std::shared_ptr<std::vector<TokenIdType>> const& getDraftTokens() const
{
return mDraftTokens;
}
/// @brief Returns true if request has draft tokens
/// @return flag
bool hasDraftTokens() const
{
return mDraftTokens && mDraftTokens->size() > 0;
}
/// @brief Get the maximum number of generated tokens across all beams
/// @return The number of generated tokens (doesn't include the prompt tokens)
SizeType getMaxNumGeneratedTokens() const
{
return getMaxBeamNumTokens() - mPromptLen;
}
/// @brief Add a new generated token to the vector of tokens
/// @param token The token to add
/// @param beam The beam to which to add the new token
void addNewToken(TokenIdType token, SizeType beam)
{
mTokens.at(beam).push_back(token);
}
/// @brief Add new generated tokens to the vector of tokens
/// @param beamTokens A vector containing the tokens to add for each beam index
/// beamTokens is expected to be of size beamWidth
void addNewTokens(const std::vector<TokenIdType>& beamTokens)
{
assert(mSamplingConfig.beamWidth == beamTokens.size());
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
{
const auto outputId = beamTokens[beam];
mTokens.at(beam).push_back(outputId);
}
}
/// @brief Sets the generated tokens for all beams. Erases all previous generated tokens.
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
void setGeneratedTokens(const BeamTokens& generatedBeamTokens)
{
assert(generatedBeamTokens.size() == mSamplingConfig.beamWidth);
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
{
auto& beamTokens = mTokens[beam];
beamTokens.resize(mPromptLen);
beamTokens.insert(beamTokens.end(), generatedBeamTokens[beam].begin(), generatedBeamTokens[beam].end());
}
}
/// @brief Pause a request by moving the generated tokens to the prompt
/// @param maxInputLen The maximum prompt length.
void pause(SizeType maxInputLen)
{
// TODO: For beamWidth > 1, we would need to support swapping to avoid
// recomputing from the start
// As a temporary solution, we currently reset the tokens to the prompt
if (mSamplingConfig.beamWidth > 1)
{
for (std::size_t beam = 0; beam < mTokens.size(); ++beam)
{
auto& beamTokens = mTokens.at(beam);
beamTokens.resize(mPromptLen);
if (mReturnLogProbs)
{
mLogProbs.at(beam).clear();
}
}
}
else
{
SizeType newPromptLen = std::min(maxInputLen, mPromptLen + getMaxNumGeneratedTokens());
for (std::size_t beam = 0; beam < mTokens.size(); ++beam)
{
auto& beamTokens = mTokens.at(beam);
beamTokens.resize(newPromptLen);
if (mReturnLogProbs)
{
auto& logProb = mLogProbs.at(beam);
logProb.resize(newPromptLen - mPromptLen);
}
}
mMaxNewTokens -= (newPromptLen - mPromptLen);
mPromptLen = newPromptLen;
}
mState = REQUEST_STATE_CONTEXT_INIT;
mBatchSlot = -1;
}
/// @brief Get the maximum position of the tokens returned to the client. Used to ensure the client does not
/// receive duplicate token positions.
/// @return The maximum position of the tokens sent to the client
SizeType getMaxSentTokenPos() const
{
return mMaxSentTokenPos;
}
/// @brief Sets the maximum position of the tokens returned to the client. Used to ensure the client does not
/// receive duplicate token positions.
/// @param pos The maximum position
void setMaxSentTokenPos(SizeType pos)
{
mMaxSentTokenPos = pos;
}
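/// @brief Get the prompt embedding table used for prompt tuning, if any
/// @return An optional tensor pointer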
std::optional<TensorPtr> getPromptEmbeddingTable() const
{
return mPromptEmbeddingTable;
}
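/// @brief Copy the prompt embedding table to GPU memory if it is not already there
/// @param manager The buffer manager used to perform the copy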
void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)
{
if (!mPromptEmbeddingTable.has_value()
|| mPromptEmbeddingTable.value()->getMemoryType() == runtime::MemoryType::kGPU)
{
return;
}
else
{
TensorPtr gpuPromptEmbeddingTable
= manager.copyFrom(*mPromptEmbeddingTable.value(), runtime::MemoryType::kGPU);
mPromptEmbeddingTable = gpuPromptEmbeddingTable;
}
}
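/// @brief Get the prompt vocab size used for prompt tuning, if any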
std::optional<SizeType> getPromptVocabSize() const
{
return mPromptVocabSize;
}
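/// @brief Get the embedding bias tensor, if any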
std::optional<TensorPtr> getEmbeddingBias() const
{
return mEmbeddingBias;
}
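/// @brief Get the bad words list tensor, if any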
std::optional<TensorPtr> getBadWordsList() const
{
return mBadWordsList;
}
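/// @brief Get the stop words list tensor, if any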
std::optional<TensorPtr> getStopWordsList() const
{
return mStopWordsList;
}
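/// @brief Returns true if log probabilities should be returned for this request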
bool returnLogProbs() const
{
return mReturnLogProbs;
}
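/// @brief Get the log probabilities of the generated tokens for all beams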
std::vector<VecLogProbs> const& getLogProbs() const
{
return mLogProbs;
}
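/// @brief Get the log probabilities of the generated tokens for a given beam
/// @param beam The beam index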
VecLogProbs const& getLogProbs(SizeType beam) const
{
return mLogProbs.at(beam);
}
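/// @brief Set the log probabilities of the generated tokens for a given beam
/// @param logProbs The log probabilities to store
/// @param beam The beam index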
void setLogProbs(VecLogProbs const& logProbs, SizeType beam)
{
mLogProbs.at(beam).resize(mPromptLen - mOrigPromptLen);
mLogProbs.at(beam).insert(mLogProbs.at(beam).end(), logProbs.begin(), logProbs.end());
}
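/// @brief Get the cumulative log probabilities, one entry per beam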
VecLogProbs const& getCumLogProbs() const
{
return mCumLogProbs;
}
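/// @brief Set the cumulative log probability for a given beam
/// @param cumLogProb The cumulative log probability
/// @param beam The beam index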
void setCumLogProb(float cumLogProb, SizeType beam)
{
mCumLogProbs.at(beam) = cumLogProb;
}
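/// @brief Get the original prompt length (mPromptLen may grow when a request is paused)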
SizeType getOrigPromptLen() const
{
return mOrigPromptLen;
}
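/// @brief Set the draft tokens for this request
/// @param draftTokens shared_ptr to the vector of draft tokens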
void setDraftTokens(const std::shared_ptr<VecTokens>& draftTokens)
{
mDraftTokens = draftTokens;
}
RequestIdType mRequestId;
SizeType mPromptLen;
SizeType mMaxNewTokens;
runtime::SamplingConfig mSamplingConfig;
LlmRequestState_t mState;
bool mIsStreaming;
std::optional<SizeType> mEndId;
std::optional<SizeType> mPadId;
SizeType mBatchSlot;
private:
SizeType mOrigPromptLen;
// Tokens [beamWidth, mPromptLen + getMaxNumGeneratedTokens()]
BeamTokens mTokens;
SizeType mMaxSentTokenPos;
std::optional<TensorPtr> mEmbeddingBias;
std::optional<TensorPtr> mBadWordsList;
std::optional<TensorPtr> mStopWordsList;
std::optional<TensorPtr> mPromptEmbeddingTable;
std::optional<SizeType> mPromptVocabSize;
bool mReturnLogProbs;
std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
VecLogProbs mCumLogProbs; // [beamSize]
std::shared_ptr<VecTokens> mDraftTokens;
};
} // namespace tensorrt_llm::batch_manager
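
// Minimal usage sketch (illustrative; not part of the original header). It only uses the
// constructor and accessors declared above, and assumes runtime::SamplingConfig is
// constructible from a beam width; the token values are arbitrary.
//
//   using namespace tensorrt_llm;
//   auto inputTokens = std::make_shared<batch_manager::LlmRequest::VecTokens>(
//       batch_manager::LlmRequest::VecTokens{1, 2, 3});
//   runtime::SamplingConfig samplingConfig{/*beamWidth=*/1};
//   batch_manager::LlmRequest request{/*requestId=*/0, /*maxNewTokens=*/16, inputTokens,
//       samplingConfig, /*isStreaming=*/false};
//   request.addNewToken(/*token=*/42, /*beam=*/0);  // append one generated token
//   auto total = request.getNumTokens(/*beam=*/0);  // prompt (3) + generated (1) tokens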