TensorRT-LLM/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/namedTensor.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <algorithm>
#include <array>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace tensorrt_llm::batch_manager
{
namespace inference_request
{
// Input tensors
auto constexpr kInputIdsTensorName = "input_ids";
auto constexpr kDraftInputIdsTensorName = "draft_input_ids";
auto constexpr kDraftLogitsTensorName = "draft_logits";
auto constexpr kMaxNewTokensTensorName = "request_output_len";
auto constexpr kBeamWidthTensorName = "beam_width";
auto constexpr kEndIdTensorName = "end_id";
auto constexpr kPadIdTensorName = "pad_id";
auto constexpr kBadWordsListTensorName = "bad_words_list";
auto constexpr kStopWordsListTensorName = "stop_words_list";
auto constexpr kEmbeddingBiasTensorName = "embedding_bias";
auto constexpr kTemperatureTensorName = "temperature";
auto constexpr kRuntimeTopKTensorName = "runtime_top_k";
auto constexpr kRuntimeTopPTensorName = "runtime_top_p";
auto constexpr kLengthPenaltyTensorName = "len_penalty";
auto constexpr kEarlyStoppingTensorName = "early_stopping";
auto constexpr kRepetitionPenaltyTensorName = "repetition_penalty";
auto constexpr kMinLengthTensorName = "min_length";
auto constexpr kPresencePenaltyTensorName = "presence_penalty";
auto constexpr kFrequencyPenaltyTensorName = "frequency_penalty";
auto constexpr kRandomSeedTensorName = "random_seed";
auto constexpr kReturnLogProbsTensorName = "return_log_probs";
auto constexpr kReturnContextLogitsTensorName = "return_context_logits";
auto constexpr kReturnGenerationLogitsTensorName = "return_generation_logits";
auto constexpr kPromptEmbeddingTableName = "prompt_embedding_table";
auto constexpr kPromptVocabSizeName = "prompt_vocab_size";
// Weights for a LoRA adapter, shape [ num_lora_modules_layers, D x Hi + Ho x D ],
// where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv)
// and model layer. Each of the in / out tensors is first flattened and then concatenated in the format above.
// D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
auto constexpr kLoraWeights = "lora_weights";
// Module identifier (same size as the first dimension of lora_weights).
// See LoraModule::ModuleType for the module id mapping:
//
// "attn_qkv": 0 # compbined qkv adapter
// "attn_q": 1 # q adapter
// "attn_k": 2 # k adapter
// "attn_v": 3 # v adapter
// "attn_dense": 4 # adapter for the dense layer in attention
// "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
// "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
// "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
// "cross_attn_qkv": 8 # for enc-dec adapter for cross attention in decoder
// "cross_attn_q": 9 # for enc-dec adapter for cross attention in decoder
// "cross_attn_k": 10 # for enc-dec adapter for cross attention in decoder
// "cross_attn_v": 11 # for enc-dec adapter for cross attention in decoder
// "cross_attn_dense": 12 # for enc-dec adapter for cross attention in decoder
//
// last dim holds [ module_id, layer_idx, adapter_size (D / R value) ]
auto constexpr kLoraConfig = "lora_config"; // [num_lora_modules_layers, 3]
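// Illustrative example (sizes are assumptions made up for clarity, not part of this header): a request carrying
// a single LoRA adapter for the fused "attn_qkv" module (module_id 0) on layer 0, with adapter size D = 8,
// Hi = 4096 and Ho = 3 * 4096, would supply
//   lora_weights of shape [ 1, 8 * 4096 + 3 * 4096 * 8 ]  // in-weights flattened, then out-weights, concatenated
//   lora_config  of shape [ 1, 3 ] holding { 0 /* module_id (attn_qkv) */, 0 /* layer_idx */, 8 /* adapter_size D */ }
// The layout follows the comments above; the concrete dimensions are invented for illustration only.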
// Obsolete names for backward compatibility
auto constexpr kInputLengthsTensorName = "input_lengths";
// Output tensors
auto constexpr kOutputIdsTensorName = "output_ids";
auto constexpr kSequenceLengthTensorName = "sequence_length";
auto constexpr kLogProbsTensorName = "output_log_probs";
auto constexpr kCumLogProbsTensorName = "cum_log_probs";
auto constexpr kContextLogitsName = "context_logits";
auto constexpr kGenerationLogitsName = "generation_logits";
} // namespace inference_request
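// Illustrative sketch (assumes inputIds and maxNewTokens are pre-built runtime::ITensor::SharedPtr values;
// the variable names are invented for the example): the constants above are the keys of a request's tensor map,
// e.g.
//
//   InferenceRequest::TensorMap inputTensors{
//       {inference_request::kInputIdsTensorName, inputIds},
//       {inference_request::kMaxNewTokensTensorName, maxNewTokens},
//   };
//   InferenceRequest request{inputTensors, 42 /* requestId */};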
template <typename TTensor, typename TNamedTensor, typename TStream = runtime::BufferManager::CudaStreamPtr>
class GenericInferenceRequest
{
public:
using TensorPtr = TTensor;
using NamedTensorType = TNamedTensor;
using TensorMap = std::unordered_map<std::string, TTensor>;
using LogitsPostProcessor = typename GenericLlmRequest<TensorPtr, TStream>::LogitsPostProcessor;
explicit GenericInferenceRequest(
uint64_t requestId, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: mRequestId{requestId}
, mIsStreaming{false}
, mlogitsPostProcessor(logitsPostProcessor)
{
}
GenericInferenceRequest(uint64_t requestId, TensorMap&& tensorMap,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: mRequestId{requestId}
, mIsStreaming{false}
, mInputTensors{std::move(tensorMap)}
, mlogitsPostProcessor(logitsPostProcessor)
{
for (auto const& [name, tensor] : mInputTensors)
{
validateTensorName(name);
}
}
GenericInferenceRequest(uint64_t requestId, TensorMap const& tensorMap,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: GenericInferenceRequest(requestId, TensorMap{tensorMap}, logitsPostProcessor)
{
}
void setIsStreaming(bool isStreaming)
{
mIsStreaming = isStreaming;
}
[[nodiscard]] bool isStreaming() const
{
return mIsStreaming;
}
[[nodiscard]] uint64_t getRequestId() const
{
return mRequestId;
}
TensorMap const& getInputTensors() const
{
return mInputTensors;
}
void setLogitsPostProcessor(std::optional<LogitsPostProcessor> cb)
{
mlogitsPostProcessor = cb;
}
std::optional<LogitsPostProcessor> getLogitsPostProcessor()
{
return mlogitsPostProcessor;
}
static std::array constexpr kTensorNames = {
inference_request::kInputIdsTensorName,
inference_request::kDraftInputIdsTensorName,
inference_request::kDraftLogitsTensorName,
inference_request::kMaxNewTokensTensorName,
inference_request::kBeamWidthTensorName,
inference_request::kEndIdTensorName,
inference_request::kPadIdTensorName,
inference_request::kBadWordsListTensorName,
inference_request::kStopWordsListTensorName,
inference_request::kEmbeddingBiasTensorName,
inference_request::kTemperatureTensorName,
inference_request::kRuntimeTopKTensorName,
inference_request::kRuntimeTopPTensorName,
inference_request::kLengthPenaltyTensorName,
inference_request::kEarlyStoppingTensorName,
inference_request::kRepetitionPenaltyTensorName,
inference_request::kMinLengthTensorName,
inference_request::kPresencePenaltyTensorName,
inference_request::kFrequencyPenaltyTensorName,
inference_request::kRandomSeedTensorName,
inference_request::kReturnLogProbsTensorName,
inference_request::kReturnContextLogitsTensorName,
inference_request::kReturnGenerationLogitsTensorName,
inference_request::kPromptEmbeddingTableName,
inference_request::kPromptVocabSizeName,
// obsolete names for backward compatibility
inference_request::kInputLengthsTensorName,
inference_request::kLoraWeights,
inference_request::kLoraConfig,
};
#define TENSOR_GETTER_SETTER(funcName, tensorName) \
\
[[nodiscard]] bool has##funcName() const \
{ \
return mInputTensors.find(tensorName) != mInputTensors.end(); \
} \
\
[[nodiscard]] TensorPtr const& get##funcName() const \
{ \
auto it = mInputTensors.find(tensorName); \
TLLM_CHECK_WITH_INFO(it != mInputTensors.end(), "Undefined tensor: %s", tensorName); \
return it->second; \
} \
\
[[nodiscard]] TensorPtr get##funcName##Unchecked() const \
{ \
auto it = mInputTensors.find(tensorName); \
return it != mInputTensors.end() ? it->second : TensorPtr{}; \
} \
\
[[nodiscard]] NamedTensorType get##funcName##Named() const \
{ \
auto it = mInputTensors.find(tensorName); \
return it != mInputTensors.end() ? NamedTensorType{it->second, tensorName} : NamedTensor{tensorName}; \
} \
\
void set##funcName(TensorPtr const& tensor) \
{ \
if constexpr (std::is_same_v<TensorPtr, tensorrt_llm::runtime::ITensor::SharedPtr>) \
{ \
TLLM_CHECK_WITH_INFO(tensor, "Cannot set nullptr when calling %s", __FUNCTION__); \
} \
mInputTensors[tensorName] = tensor; \
}
TENSOR_GETTER_SETTER(InputIds, inference_request::kInputIdsTensorName)
TENSOR_GETTER_SETTER(DraftInputIds, inference_request::kDraftInputIdsTensorName)
TENSOR_GETTER_SETTER(DraftLogits, inference_request::kDraftLogitsTensorName)
TENSOR_GETTER_SETTER(MaxNewTokens, inference_request::kMaxNewTokensTensorName)
TENSOR_GETTER_SETTER(BeamWidth, inference_request::kBeamWidthTensorName)
TENSOR_GETTER_SETTER(EndId, inference_request::kEndIdTensorName)
TENSOR_GETTER_SETTER(PadId, inference_request::kPadIdTensorName)
TENSOR_GETTER_SETTER(BadWordsList, inference_request::kBadWordsListTensorName)
TENSOR_GETTER_SETTER(StopWordsList, inference_request::kStopWordsListTensorName)
TENSOR_GETTER_SETTER(EmbeddingBias, inference_request::kEmbeddingBiasTensorName)
TENSOR_GETTER_SETTER(Temperature, inference_request::kTemperatureTensorName)
TENSOR_GETTER_SETTER(RuntimeTopK, inference_request::kRuntimeTopKTensorName)
TENSOR_GETTER_SETTER(RuntimeTopP, inference_request::kRuntimeTopPTensorName)
TENSOR_GETTER_SETTER(LengthPenalty, inference_request::kLengthPenaltyTensorName)
TENSOR_GETTER_SETTER(EarlyStopping, inference_request::kEarlyStoppingTensorName)
TENSOR_GETTER_SETTER(RepetitionPenalty, inference_request::kRepetitionPenaltyTensorName)
TENSOR_GETTER_SETTER(MinLength, inference_request::kMinLengthTensorName)
TENSOR_GETTER_SETTER(PresencePenalty, inference_request::kPresencePenaltyTensorName)
TENSOR_GETTER_SETTER(FrequencyPenalty, inference_request::kFrequencyPenaltyTensorName)
TENSOR_GETTER_SETTER(RandomSeed, inference_request::kRandomSeedTensorName)
TENSOR_GETTER_SETTER(ReturnLogProbs, inference_request::kReturnLogProbsTensorName)
TENSOR_GETTER_SETTER(ReturnContextLogits, inference_request::kReturnContextLogitsTensorName)
TENSOR_GETTER_SETTER(ReturnGenerationLogits, inference_request::kReturnGenerationLogitsTensorName)
TENSOR_GETTER_SETTER(PromptEmbeddingTable, inference_request::kPromptEmbeddingTableName)
TENSOR_GETTER_SETTER(PromptVocabSize, inference_request::kPromptVocabSizeName)
TENSOR_GETTER_SETTER(LoraWeights, inference_request::kLoraWeights)
TENSOR_GETTER_SETTER(LoraConfig, inference_request::kLoraConfig)
#undef TENSOR_GETTER_SETTER
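// For each tensor above, TENSOR_GETTER_SETTER generates has<Name>(), get<Name>(), get<Name>Unchecked(),
// get<Name>Named() and set<Name>() accessors. Illustrative usage sketch (assumes a populated request object):
//
//   if (request.hasTemperature())
//   {
//       auto const& temperature = request.getTemperature(); // would throw via TLLM_CHECK_WITH_INFO if absent
//   }
//   auto topK = request.getRuntimeTopKUnchecked(); // empty TensorPtr if "runtime_top_k" was not set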
protected:
static void validateTensorName(std::string const& tensorName)
{
TLLM_CHECK_WITH_INFO(std::find(kTensorNames.begin(), kTensorNames.end(), tensorName) != kTensorNames.end(),
"Invalid tensor name: %s", tensorName.c_str());
}
uint64_t mRequestId;
bool mIsStreaming;
TensorMap mInputTensors;
std::optional<LogitsPostProcessor> mlogitsPostProcessor;
};

class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>
{
public:
using Base = GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>;
using TensorPtr = Base::TensorPtr;
using TensorMap = Base::TensorMap;
explicit InferenceRequest(uint64_t requestId)
: Base(requestId)
{
}
InferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
: Base(requestId, inputTensors)
{
}
InferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
: Base(requestId, std::move(inputTensors))
{
}
[[deprecated("Use direct tensor access instead")]] [[nodiscard]] TensorPtr const& getInputTensor(
std::string const& inputTensorName) const
{
auto it = Base::mInputTensors.find(inputTensorName);
TLLM_CHECK_WITH_INFO(it != Base::mInputTensors.end(), "Invalid input tensor name: %s", inputTensorName.c_str());
return it->second;
}
[[deprecated("Use direct tensor access instead")]] void emplaceInputTensor(
std::string const& inputTensorName, TensorPtr inputTensor)
{
validateTensorName(inputTensorName);
Base::mInputTensors[inputTensorName] = std::move(inputTensor);
}
[[nodiscard]] std::vector<int64_t> serialize() const;
static std::shared_ptr<InferenceRequest> deserialize(std::vector<int64_t> const& packed);
static std::shared_ptr<InferenceRequest> deserialize(int64_t const* packed_ptr);
};
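// Illustrative usage sketch (assumes buildInt32Tensor(...) is a user-provided helper returning an
// ITensor::SharedPtr; it is not part of this header):
//
//   auto request = std::make_shared<InferenceRequest>(42 /* requestId */);
//   request->setInputIds(buildInt32Tensor({1, promptLength}, promptTokenIds));
//   request->setMaxNewTokens(buildInt32Tensor({1, 1}, {maxNewTokens}));
//   request->setIsStreaming(true);
//
//   auto packed = request->serialize();                    // std::vector<int64_t>
//   auto restored = InferenceRequest::deserialize(packed); // reconstructs the request id and tensor map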
} // namespace tensorrt_llm::batch_manager