/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/runtime/samplingConfig.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

namespace tensorrt_llm::batch_manager
{

enum LlmRequestState_t
{
    REQUEST_STATE_UNKNOWN = 0,
    REQUEST_STATE_CONTEXT_INIT = 1,
    REQUEST_STATE_GENERATION_IN_PROGRESS = 2,
    REQUEST_STATE_GENERATION_COMPLETE = 3
};

class LlmRequest
{
public:
    using SizeType = runtime::SizeType;
    using TokenIdType = runtime::TokenIdType;
    using RequestIdType = std::uint64_t;
    using BeamTokens = std::vector<std::vector<TokenIdType>>;

    LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> input_tokens,
        runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
        std::optional<SizeType> padId = std::nullopt)
        : mRequestId(requestId)
        , mPromptLen(input_tokens->size())
        , mMaxNewTokens(maxNewTokens)
        , mSamplingConfig(samplingConfig)
        , mState(REQUEST_STATE_CONTEXT_INIT)
        , mIsStreaming(isStreaming)
        , mEndId(endId)
        , mPadId(padId)
        , mBatchSlot(-1)
    {
        mMaxSentTokenPos = mPromptLen - 1;
        // Scatter the input tokens to the other beams
        mTokens = std::make_shared<BeamTokens>(mSamplingConfig.beamWidth, *input_tokens);
    }

    /// @brief Get total number of tokens for this request (prompt + generated)
    /// @param beam The beam index
    /// @return The number of tokens
    SizeType getNumTokens(SizeType beam) const
    {
        return mTokens->at(beam).size();
    }

    /// @brief Get max number of tokens across all beams
    /// @return The number of tokens
    SizeType getMaxBeamNumTokens() const
    {
        SizeType maxTokens = 0;
        for (SizeType beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
        {
            maxTokens = std::max(maxTokens, static_cast<SizeType>(mTokens->at(beam).size()));
        }
        return maxTokens;
    }

    /// @brief Get a token at a given position and beam index
    /// @param beam The beam index
    /// @param pos The position of the token relative to beginning of the prompt
    /// @return The token index
    TokenIdType getToken(SizeType beam, SizeType pos) const
    {
        return mTokens->at(beam).at(pos);
    }

    /// @brief Get the tokens at a given beam index
    /// @param beam The beam index
    /// @return A vector of tokens for this beam index, includes the prompt
    std::vector<TokenIdType> getTokens(SizeType beam) const
    {
        return mTokens->at(beam);
    }

    /// @brief Get the maximum number of generated tokens among all beams
    /// @return The number of generated tokens (doesn't include the prompt tokens)
    SizeType getMaxNumGeneratedTokens() const
    {
        return getMaxBeamNumTokens() - mPromptLen;
    }

    /// @brief Add new generated tokens to the vector of tokens
    /// @param beamTokens A vector containing the tokens to add for each beam index,
    ///        expected to be of size beamWidth
    void addNewTokens(const std::vector<TokenIdType>& beamTokens)
    {
        assert(mSamplingConfig.beamWidth == beamTokens.size());
        for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
        {
            const auto outputId = beamTokens[beam];
            mTokens->at(beam).push_back(outputId);
        }
    }

    /// @brief Sets the generated tokens for all beams. Erases all previous generated tokens.
    /// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
    void setGeneratedTokens(const BeamTokens& generatedBeamTokens)
    {
        assert(generatedBeamTokens.size() == mSamplingConfig.beamWidth);

        for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
        {
            auto& beamTokens = (*mTokens)[beam];
            beamTokens.resize(mPromptLen);
            beamTokens.insert(beamTokens.end(), generatedBeamTokens[beam].begin(), generatedBeamTokens[beam].end());
        }
    }

    /// @brief Pause a request by moving the generated tokens to the prompt
    /// @param maxInputLen The maximum prompt length
    void pause(SizeType maxInputLen)
    {
        // TODO: For beamWidth > 1, we would need to support swapping to avoid
        // recomputing from the start.
        // As a temporary solution, we currently reset the tokens to the prompt.
        if (mSamplingConfig.beamWidth > 1)
        {
            for (auto& beamTokens : *mTokens)
            {
                beamTokens.resize(mPromptLen);
            }
        }
        else
        {
            SizeType newPromptLen = std::min(maxInputLen, mPromptLen + getMaxNumGeneratedTokens());
            for (auto& beamTokens : *mTokens)
            {
                beamTokens.resize(newPromptLen);
            }
            mMaxNewTokens -= (newPromptLen - mPromptLen);
            mPromptLen = newPromptLen;
        }
        mState = REQUEST_STATE_CONTEXT_INIT;
        mBatchSlot = -1;
    }

    /// @brief Get the maximum position of the tokens returned to the client. Used to ensure
    /// we don't return duplicated token positions to the client.
    /// @return The maximum position of the tokens sent to the client
    SizeType getMaxSentTokenPos() const
    {
        return mMaxSentTokenPos;
    }

    /// @brief Sets the maximum position of the tokens returned to the client. Used to ensure
    /// we don't return duplicated token positions to the client.
    /// @param pos The maximum position
    void setMaxSentTokenPos(SizeType pos)
    {
        mMaxSentTokenPos = pos;
    }

    RequestIdType mRequestId;
    SizeType mPromptLen;
    SizeType mMaxNewTokens;
    runtime::SamplingConfig mSamplingConfig;
    LlmRequestState_t mState;
    bool mIsStreaming;
    std::optional<SizeType> mEndId;
    std::optional<SizeType> mPadId;
    SizeType mBatchSlot;

private:
    // Tokens [beamWidth, mPromptLen + getMaxNumGeneratedTokens()]
    std::shared_ptr<BeamTokens> mTokens;
    SizeType mMaxSentTokenPos;
};

} // namespace tensorrt_llm::batch_manager
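
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the batch manager API). A
// minimal example of the request lifecycle, assuming this header is included
// as tensorrt_llm/batch_manager/llmRequest.h and that runtime::SamplingConfig
// is constructible from a beam width (as declared in
// tensorrt_llm/runtime/samplingConfig.h). The token IDs and request ID below
// are arbitrary placeholders.
//
//   #include "tensorrt_llm/batch_manager/llmRequest.h"
//
//   void exampleRequestLifecycle()
//   {
//       using namespace tensorrt_llm;
//       using Request = batch_manager::LlmRequest;
//
//       // Prompt of three tokens, greedy sampling (beamWidth == 1), no streaming.
//       auto promptTokens = std::make_shared<std::vector<Request::TokenIdType>>(
//           std::vector<Request::TokenIdType>{101, 2009, 2003});
//       runtime::SamplingConfig samplingConfig{/* beamWidth */ 1};
//       Request request{/* requestId */ 42, /* maxNewTokens */ 16, promptTokens, samplingConfig,
//           /* isStreaming */ false};
//
//       // Each generation step appends one token per beam.
//       request.addNewTokens({1012});
//       assert(request.getNumTokens(0) == 4);            // 3 prompt + 1 generated
//       assert(request.getMaxNumGeneratedTokens() == 1);
//
//       // Pausing folds the generated tokens into the prompt (beamWidth == 1 path)
//       // and returns the request to REQUEST_STATE_CONTEXT_INIT.
//       request.pause(/* maxInputLen */ 512);
//       assert(request.mPromptLen == 4 && request.mMaxNewTokens == 15);
//   }
// ---------------------------------------------------------------------------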