TensorRT-LLM/cpp/include/tensorrt_llm/batch_manager/llmRequest.h

/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include <assert.h>
#include <cstdint>
#include <memory>
#include <vector>
namespace tensorrt_llm::batch_manager
{
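/// @brief Lifecycle state of a request as it moves through the batch manager.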
enum LlmRequestState_t
{
REQUEST_STATE_UNKNOWN = 0,
REQUEST_STATE_CONTEXT_INIT = 1,
REQUEST_STATE_GENERATION_IN_PROGRESS = 2,
REQUEST_STATE_GENERATION_COMPLETE = 3
};
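/// @brief Holds the state of a single inference request: the prompt and per-beam token sequences,
/// the sampling configuration, and optional tensors such as the embedding bias, bad/stop word lists
/// and prompt-tuning embedding table.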
class LlmRequest
{
public:
using SizeType = runtime::SizeType;
using TokenIdType = runtime::TokenIdType;
using RequestIdType = std::uint64_t;
using BeamTokens = std::vector<std::vector<TokenIdType>>;
using TensorPtr = runtime::ITensor::SharedPtr;
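/// @brief Construct a request from its prompt tokens and sampling configuration. Optional inputs
/// (end/pad ids, embedding bias, bad/stop word lists, prompt-tuning table and vocab size) may be left unset.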
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> input_tokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType> promptVocabSize = std::nullopt)
: mRequestId(requestId)
, mPromptLen(input_tokens->size())
, mMaxNewTokens(maxNewTokens)
, mSamplingConfig(samplingConfig)
, mState(REQUEST_STATE_CONTEXT_INIT)
, mIsStreaming(isStreaming)
, mEndId(endId)
, mPadId(padId)
, mBatchSlot(-1)
, mEmbeddingBias(embeddingBias)
, mBadWordsList(badWordsList)
, mStopWordsList(stopWordsList)
, mPromptEmbeddingTable(promptEmbeddingTable)
, mPromptVocabSize(promptVocabSize)
{
mMaxSentTokenPos = mPromptLen - 1;
// Replicate the input tokens across all beams
mTokens = std::make_shared<BeamTokens>(mSamplingConfig.beamWidth, *input_tokens);
if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value())
|| (!mPromptEmbeddingTable.has_value() && mPromptVocabSize.has_value()))
{
std::string errStr
= "Prompt embedding table and prompt vocab size tensors must both be provided for requests with prompt "
"tuning enabled.";
TLLM_LOG_ERROR(errStr);
throw std::runtime_error(errStr);
}
}
/// @brief Get the total number of tokens for this request (prompt + generated)
/// @param beam The beam index
/// @return The number of tokens
SizeType getNumTokens(SizeType beam) const
{
return mTokens->at(beam).size();
}
/// @brief Get the maximum number of tokens across all beams
/// @return The maximum number of tokens
SizeType getMaxBeamNumTokens() const
{
SizeType maxTokens = 0;
for (SizeType beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
{
maxTokens = std::max(maxTokens, static_cast<SizeType>(mTokens->at(beam).size()));
}
return maxTokens;
}
/// @brief Get a token at a given position and beam index
/// @param beam The beam index
/// @param pos The position of the token relative to beginning of the prompt
/// @return The token at the given position
TokenIdType getToken(SizeType beam, SizeType pos) const
{
return mTokens->at(beam).at(pos);
}
/// @brief Get the tokens at a given beam index
/// @param beam The beam index
/// @return A vector of tokens for this beam index, including the prompt
std::vector<TokenIdType> getTokens(SizeType beam) const
{
return mTokens->at(beam);
}
/// @brief Get the maximum number of generated tokens across all beams
/// @return The number of generated tokens (doesn't include the prompt tokens)
SizeType getMaxNumGeneratedTokens() const
{
return getMaxBeamNumTokens() - mPromptLen;
}
/// @brief Add a new generated token to the vector of tokens for the given beam
/// @param token The token to add
/// @param beam The beam to which to add the new token
void addNewToken(TokenIdType token, SizeType beam)
{
mTokens->at(beam).push_back(token);
}
/// @brief Add new generated tokens to the vector of tokens
/// @param beamTokens A vector containing the tokens to add for each beam index
/// beamTokens is expected to be of size beamWidth
void addNewTokens(const std::vector<TokenIdType>& beamTokens)
{
assert(static_cast<std::size_t>(mSamplingConfig.beamWidth) == beamTokens.size());
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
{
const auto outputId = beamTokens[beam];
mTokens->at(beam).push_back(outputId);
}
}
/// @brief Sets the generated tokens for all beams. Erases all previous generated tokens.
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
void setGeneratedTokens(const BeamTokens& generatedBeamTokens)
{
assert(generatedBeamTokens.size() == static_cast<std::size_t>(mSamplingConfig.beamWidth));
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
{
auto& beamTokens = (*mTokens)[beam];
beamTokens.resize(mPromptLen);
beamTokens.insert(beamTokens.end(), generatedBeamTokens[beam].begin(), generatedBeamTokens[beam].end());
}
}
/// @brief Pause a request by moving the generated tokens to the prompt
/// @param maxInputLen The maximum prompt length.
void pause(SizeType maxInputLen)
{
// TODO: For beamWidth > 1, we would need to support swapping to avoid
// recomputing from the start
// As a temporary solution, we currently reset the tokens to the prompt
if (mSamplingConfig.beamWidth > 1)
{
for (auto& beamTokens : *mTokens)
{
beamTokens.resize(mPromptLen);
}
}
else
{
SizeType newPromptLen = std::min(maxInputLen, mPromptLen + getMaxNumGeneratedTokens());
for (auto& beamTokens : *mTokens)
{
beamTokens.resize(newPromptLen);
}
mMaxNewTokens -= (newPromptLen - mPromptLen);
mPromptLen = newPromptLen;
}
mState = REQUEST_STATE_CONTEXT_INIT;
mBatchSlot = -1;
}
/// @brief Get the maximum position of the tokens returned to the client. Used to ensure that duplicated
/// token positions are not returned to the client.
/// @return The maximum position of the tokens sent to the client
SizeType getMaxSentTokenPos() const
{
return mMaxSentTokenPos;
}
/// @brief Set the maximum position of the tokens returned to the client. Used to ensure that duplicated
/// token positions are not returned to the client.
/// @param pos The maximum position
void setMaxSentTokenPos(SizeType pos)
{
mMaxSentTokenPos = pos;
}
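/// @brief Get the prompt-tuning embedding table, if one was provided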
std::optional<TensorPtr> getPromptEmbeddingTable() const
{
return mPromptEmbeddingTable;
}
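/// @brief Copy the prompt embedding table to GPU memory if it is not already resident there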
void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)
{
if (!mPromptEmbeddingTable.has_value()
|| mPromptEmbeddingTable.value()->getMemoryType() == runtime::MemoryType::kGPU)
{
return;
}
else
{
TensorPtr gpuPromptEmbeddingTable
= manager.copyFrom(*mPromptEmbeddingTable.value(), runtime::MemoryType::kGPU);
mPromptEmbeddingTable = gpuPromptEmbeddingTable;
}
}
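/// @brief Get the prompt-tuning vocabulary size, if one was provided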
std::optional<SizeType> getPromptVocabSize() const
{
return mPromptVocabSize;
}
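/// @brief Get the embedding bias tensor, if one was provided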
std::optional<TensorPtr> getEmbeddingBias() const
{
return mEmbeddingBias;
}
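/// @brief Get the bad words list tensor, if one was provided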
std::optional<TensorPtr> getBadWordsList() const
{
return mBadWordsList;
}
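/// @brief Get the stop words list tensor, if one was provided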
std::optional<TensorPtr> getStopWordsList() const
{
return mStopWordsList;
}
RequestIdType mRequestId;
SizeType mPromptLen;
SizeType mMaxNewTokens;
runtime::SamplingConfig mSamplingConfig;
LlmRequestState_t mState;
bool mIsStreaming;
std::optional<SizeType> mEndId;
std::optional<SizeType> mPadId;
SizeType mBatchSlot;
private:
// Tokens [beam_size, mPromptLen + getMaxNumGeneratedTokens()]
std::shared_ptr<BeamTokens> mTokens;
SizeType mMaxSentTokenPos;
std::optional<TensorPtr> mEmbeddingBias;
std::optional<TensorPtr> mBadWordsList;
std::optional<TensorPtr> mStopWordsList;
std::optional<TensorPtr> mPromptEmbeddingTable;
std::optional<SizeType> mPromptVocabSize;
};
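// Illustrative usage sketch, not part of the API. The request id, token values, and the assumption that
// runtime::SamplingConfig can be constructed from a beam width are made up for this example only; in
// practice the batch manager creates and owns LlmRequest instances.
//
//   auto promptTokens = std::make_shared<std::vector<LlmRequest::TokenIdType>>(
//       std::vector<LlmRequest::TokenIdType>{1, 2, 3});
//   runtime::SamplingConfig samplingConfig(/*beamWidth=*/1);
//   LlmRequest request(/*requestId=*/0, /*maxNewTokens=*/16, promptTokens, samplingConfig,
//       /*isStreaming=*/false);
//   request.addNewToken(/*token=*/42, /*beam=*/0);
//   assert(request.getNumTokens(0) == 4); // 3 prompt tokens + 1 generated token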
} // namespace tensorrt_llm::batch_manager