TensorRT-LLMs/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/thop/dynamicDecodeOp.h"
#include "tensorrt_llm/common/tensorConversion.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "tensorrt_llm/thop/torchAllocator.h"
namespace th = torch;
namespace tr = tensorrt_llm::runtime;
namespace tcc = tensorrt_llm::common::conversion;
namespace torch_ext
{
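// FtDynamicDecode wraps tensorrt_llm::layers::DynamicDecodeLayer<T> for use from PyTorch.
// The constructor checks that the padded vocabulary size is divisible by the tensor-parallel
// degree, creates a Torch-backed allocator on the current CUDA stream, queries the device
// properties, and instantiates the decoding layer.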
template <typename T>
FtDynamicDecode<T>::FtDynamicDecode(const size_t max_batch_size, const size_t max_beam_width, const size_t vocab_size,
const size_t vocab_size_padded, int const tensor_para_size, int const pipeline_para_size)
: vocab_size_(vocab_size)
, vocab_size_padded_(vocab_size_padded)
, finished_sum_(tr::BufferManager::pinned(
tr::ITensor::makeShape({static_cast<int32_t>(max_batch_size)}), nvinfer1::DataType::kINT32))
{
TLLM_CHECK_WITH_INFO(vocab_size_padded_ % tensor_para_size == 0,
tensorrt_llm::common::fmtstr(
"vocab_size_padded (%zu) is not a multiple of tensor_para_size (%d).", vocab_size_padded_, tensor_para_size));
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto allocator = std::make_shared<tensorrt_llm::thop::TorchAllocator>(stream);
int deviceId;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&deviceId)); // Get the correct device id
tensorrt_llm::common::check_cuda_error(cudaGetDeviceProperties(&prop_, deviceId));
dynamic_decode_layer_ = std::make_shared<tensorrt_llm::layers::DynamicDecodeLayer<T>>(tr::DecodingMode::None(),
max_batch_size, max_beam_width, vocab_size_, vocab_size_padded_, stream, std::move(allocator), &prop_);
}
namespace
{
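// If the optional tensor is present, copy its (host-side) contents into a std::vector and
// assign it to the optional setup argument.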
template <typename T>
void safeInsert(th::optional<th::Tensor>& tensor, std::optional<std::vector<T>>& arg)
{
using value_type = T;
if (tensor.has_value())
{
auto ptr = get_ptr<value_type>(tensor.value());
auto shape = convert_shape(tensor.value());
size_t const size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
arg = std::vector<value_type>(ptr, ptr + size);
}
}
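// If the optional tensor is present, wrap it as a non-owning tc::Tensor view and assign it to
// the optional argument.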
template <typename T>
void safeUpdate(th::optional<th::Tensor>& tensor, std::optional<tc::Tensor>& arg)
{
if (tensor.has_value())
{
arg = convert_tensor<T>(tensor.value());
}
}
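// If the optional tensor is present, check that it holds exactly one element and copy out that
// scalar value.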
template <typename T>
void safeUpdateScalar(th::optional<th::Tensor>& tensor, std::optional<T>& arg, std::string const& name)
{
if (tensor.has_value())
{
auto accessor = tensor->accessor<T, 1>();
TLLM_CHECK_WITH_INFO(accessor.size(0) == 1, name + " must be a scalar");
arg = accessor[0];
}
}
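// If the optional tensor is present, point ptr at its underlying data buffer.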
template <typename T>
void safeUpdatePtr(th::optional<th::Tensor>& tensor, T*& ptr)
{
if (tensor.has_value())
{
ptr = get_ptr<T>(tensor.value());
}
}
} // namespace
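// Configure per-request sampling and beam-search parameters for the wrapped DynamicDecodeLayer.
// Present optional tensors are copied element-wise into SetupParams (see safeInsert above) and
// forwarded to DynamicDecodeLayer::setup on the current CUDA stream.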
template <typename T>
void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optional<th::Tensor> runtime_top_k_opt,
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
th::optional<th::Tensor> top_p_reset_ids_opt)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
dynamic_decode_layer_->setStream(stream);
SetupParams setupParams;
safeInsert(temperature_opt, setupParams.temperature);
safeInsert(repetition_penalty_opt, setupParams.repetition_penalty);
safeInsert(presence_penalty_opt, setupParams.presence_penalty);
safeInsert(frequency_penalty_opt, setupParams.frequency_penalty);
safeInsert(min_length_opt, setupParams.min_length);
safeInsert(runtime_top_k_opt, setupParams.runtime_top_k);
safeInsert(runtime_top_p_opt, setupParams.runtime_top_p);
safeInsert(random_seed_opt, setupParams.randomSeed);
safeInsert(top_p_decay_opt, setupParams.top_p_decay);
safeInsert(top_p_min_opt, setupParams.top_p_min);
safeInsert(top_p_reset_ids_opt, setupParams.top_p_reset_ids);
safeInsert(beam_search_diversity_rate_opt, setupParams.beam_search_diversity_rate);
safeInsert(length_penalty_opt, setupParams.length_penalty);
safeInsert(early_stopping_opt, setupParams.early_stopping);
dynamic_decode_layer_->setup(batch_size, beam_width, nullptr, setupParams);
}
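// Run one decoding step: build ForwardParams/OutputParams from the torch tensors, invoke the
// decoding layer, and, when both a sequence-length limit and a finished-state output are given,
// reduce the per-slot finished counters on the host to decide whether generation should stop.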
template <typename T>
void FtDynamicDecode<T>::forward(th::Tensor& logits, int step, int max_input_length, int max_attention_window,
int sink_token_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_ptrs_opt,
th::optional<th::Tensor> stop_words_lens_opt, int32_t max_stop_words_len,
th::optional<th::Tensor> bad_words_list_ptrs_opt, th::optional<th::Tensor> bad_words_lens_opt,
int32_t max_bad_words_len, th::optional<th::Tensor> no_repeat_ngram_size_opt,
th::optional<th::Tensor> src_cache_indirection_opt, th::Tensor& output_token_ids, th::Tensor& newTokens,
th::Tensor& should_stop, th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> output_log_probs_tiled_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt, th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_num_beams_opt, th::optional<th::Tensor> beam_hyps_is_done_opt,
bool use_beam_hyps)
{
typename tensorrt_llm::layers::DynamicDecodeLayer<T>::ForwardParams forwardParams{step, static_cast<int>(ite),
max_input_length, max_attention_window, sink_token_length, local_batch_size, convert_tensor<int>(end_id)};
forwardParams.logits = convert_tensor<float>(logits);
safeUpdate<T>(embedding_bias_opt, forwardParams.embedding_bias);
safeUpdate<int>(input_lengths_opt, forwardParams.input_lengths);
safeUpdate<int>(sequence_limit_length_opt, forwardParams.sequence_limit_length);
safeUpdate<uint64_t>(stop_words_list_ptrs_opt, forwardParams.stop_words_ptr);
safeUpdate<int>(stop_words_lens_opt, forwardParams.stop_words_lengths);
forwardParams.max_stop_words_len = max_stop_words_len;
safeUpdate<uint64_t>(bad_words_list_ptrs_opt, forwardParams.bad_words_ptr);
safeUpdate<int>(bad_words_lens_opt, forwardParams.bad_words_lengths);
forwardParams.max_bad_words_len = max_bad_words_len;
safeUpdate<int>(no_repeat_ngram_size_opt, forwardParams.no_repeat_ngram_size);
safeUpdate<int>(src_cache_indirection_opt, forwardParams.src_cache_indirection);
auto const& output_ids_converted = convert_tensor<int>(output_token_ids);
typename tensorrt_llm::layers::DynamicDecodeLayer<T>::OutputParams outputParams{output_ids_converted};
outputParams.newTokens = convert_tensor<int>(newTokens);
safeUpdate<uint8_t>(finished_input, forwardParams.finished);
safeUpdate<uint8_t>(finished_output, outputParams.finished);
safeUpdate<int>(sequence_lengths_opt, outputParams.sequence_length);
safeUpdate<float>(cum_log_probs_opt, outputParams.cum_log_probs);
safeUpdate<float>(output_log_probs_opt, outputParams.output_log_probs);
safeUpdate<float>(output_log_probs_tiled_opt, outputParams.output_log_probs_tiled);
safeUpdate<int>(parent_ids_opt, outputParams.parent_ids);
safeUpdate<int>(tgt_cache_indirection_opt, outputParams.tgt_cache_indirection);
std::int32_t* finished_sum_host = nullptr;
if (forwardParams.sequence_limit_length && outputParams.finished.has_value())
{
outputParams.finished_sum = tcc::toTllmTensor(*finished_sum_);
finished_sum_host = tr::bufferCast<std::int32_t>(*finished_sum_);
for (int32_t bi = 0; bi < local_batch_size; ++bi)
{
finished_sum_host[bi] = 0;
}
}
if (use_beam_hyps)
{
outputParams.beamHypotheses = std::make_shared<tensorrt_llm::kernels::BeamHypotheses>();
safeUpdatePtr<bool>(beam_hyps_is_done_opt, outputParams.beamHypotheses->is_done);
safeUpdatePtr<float>(beam_hyps_cum_log_probs_opt, outputParams.beamHypotheses->cum_log_probs);
safeUpdatePtr<float>(beam_hyps_log_probs_opt, outputParams.beamHypotheses->log_probs);
safeUpdatePtr<float>(beam_hyps_min_normed_scores_opt, outputParams.beamHypotheses->min_normed_scores);
safeUpdatePtr<float>(beam_hyps_normed_scores_opt, outputParams.beamHypotheses->normed_scores);
safeUpdatePtr<int>(beam_hyps_num_beams_opt, outputParams.beamHypotheses->num_beams);
safeUpdatePtr<int>(beam_hyps_output_ids_tgt_opt, outputParams.beamHypotheses->output_ids_tgt);
safeUpdatePtr<int>(beam_hyps_sequence_lengths_tgt_opt, outputParams.beamHypotheses->sequence_lengths_tgt);
// TODO: move the assignment below into onlineBeamSearchLayer.cu
safeUpdatePtr<int32_t const>(input_lengths_opt, outputParams.beamHypotheses->input_lengths);
}
dynamic_decode_layer_->forward(outputParams, forwardParams);
if (finished_sum_host)
{
TLLM_CUDA_CHECK(::cudaStreamSynchronize(dynamic_decode_layer_->getStream()));
int32_t finished_sum = 0;
for (int32_t bi = 0; bi < local_batch_size; ++bi)
{
finished_sum += finished_sum_host[bi];
}
auto const numToFinish = static_cast<int32_t>(outputParams.finished->size());
auto should_stop_accessor = should_stop.accessor<bool, 1>();
should_stop_accessor[0] = numToFinish == finished_sum;
}
}
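// DynamicDecodeOp is the TorchScript-visible wrapper: it stores the sizing and parallelism
// configuration and owns a typed FtDynamicDecode instance matching the requested scalar type.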
DynamicDecodeOp::DynamicDecodeOp(const int64_t max_batch_size, const int64_t max_beam_width, const int64_t vocab_size,
const int64_t vocab_size_padded, const int64_t tensor_para_size, const int64_t pipeline_para_size,
at::ScalarType scalar_type)
: max_batch_size_(static_cast<size_t>(max_batch_size))
, max_beam_width_(static_cast<size_t>(max_beam_width))
, vocab_size_(static_cast<size_t>(vocab_size))
, vocab_size_padded_(static_cast<size_t>(vocab_size_padded))
, tensor_para_size_(static_cast<int>(tensor_para_size))
, pipeline_para_size_(static_cast<int>(pipeline_para_size))
, scalar_type_(scalar_type)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
createInstance();
}
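// (Re)create the typed FtDynamicDecode instance; only float32 and float16 logits are supported.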
void DynamicDecodeOp::createInstance()
{
dynamic_decode_.reset();
switch (scalar_type_)
{
case at::ScalarType::Float:
dynamic_decode_ = std::make_unique<FtDynamicDecode<float>>(
max_batch_size_, max_beam_width_, vocab_size_, vocab_size_padded_, tensor_para_size_, pipeline_para_size_);
break;
case at::ScalarType::Half:
dynamic_decode_ = std::make_unique<FtDynamicDecode<half>>(
max_batch_size_, max_beam_width_, vocab_size_, vocab_size_padded_, tensor_para_size_, pipeline_para_size_);
break;
default: throw std::runtime_error("DynamicDecodeOp only supports float32 and float16 logits.");
}
}
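// Validate the dtypes (and, for most arguments, CPU placement) of the optional setup tensors,
// then delegate to FtDynamicDecode::setup.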
void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional<th::Tensor> runtime_top_k_opt,
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
th::optional<th::Tensor> top_p_reset_ids_opt)
{
// TODO: Revise DynamicDecodeLayer and make the decode arguments consistent.
CHECK_OPTIONAL_CPU_INPUT(runtime_top_k_opt, torch::kInt32);
CHECK_OPTIONAL_CPU_INPUT(runtime_top_p_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(temperature_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(repetition_penalty_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(presence_penalty_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(frequency_penalty_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(min_length_opt, torch::kInt32);
CHECK_OPTIONAL_CPU_INPUT(length_penalty_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(early_stopping_opt, torch::kInt32);
CHECK_OPTIONAL_CPU_INPUT(beam_search_diversity_rate_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(random_seed_opt, torch::kInt64);
CHECK_OPTIONAL_INPUT(top_p_decay_opt, torch::kFloat);
CHECK_OPTIONAL_INPUT(top_p_min_opt, torch::kFloat);
CHECK_OPTIONAL_INPUT(top_p_reset_ids_opt, torch::kInt32);
// TODO: add a parameter "return_normed_score" to return normed_cum_log_probs / cum_log_probs
dynamic_decode_->setup(static_cast<size_t>(batch_size), static_cast<size_t>(beam_width), runtime_top_k_opt,
runtime_top_p_opt, temperature_opt, repetition_penalty_opt, presence_penalty_opt, frequency_penalty_opt,
min_length_opt, length_penalty_opt, early_stopping_opt, beam_search_diversity_rate_opt, random_seed_opt,
top_p_decay_opt, top_p_min_opt, top_p_reset_ids_opt);
}
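// Validate the per-step input/output tensors, allocate the host-side should_stop flag, and
// delegate one decoding step to FtDynamicDecode::forward. Returns should_stop as a 1-element
// bool tensor.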
th::Tensor DynamicDecodeOp::forward( // BS: batch_size, BM: beam_width, mSL: max_seq_length
th::Tensor logits, // [BS, BM, vocab_size_padded], T
int64_t step, //
int64_t max_input_length, //
int64_t max_attention_window, //
int64_t sink_token_length, //
int64_t ite, //
int64_t local_batch_size, //
th::Tensor end_id, // [BS*BM], int
th::optional<th::Tensor> embedding_bias_opt, // [vocab_size_padded], T
th::optional<th::Tensor> input_lengths_opt, // [BS*BM], int, length of input contexts
th::optional<th::Tensor> sequence_limit_length_opt, // [BS, 1], int
th::optional<th::Tensor> stop_words_list_ptrs_opt, // [BS][2, stop_words_length], int64
th::optional<th::Tensor> stop_words_lens_opt, // [BS], int
int64_t max_stop_words_len, //
th::optional<th::Tensor> bad_words_list_ptrs_opt, // [BS][2, bad_words_length], int64
th::optional<th::Tensor> bad_words_lens_opt, // [BS], int
int64_t max_bad_words_len, //
th::optional<th::Tensor> no_repeat_ngram_size_opt, // [BS], int
th::optional<th::Tensor> src_cache_indirection_opt, // [local_BS, BM, mSL], int
th::Tensor output_token_ids, // [BS, BM, mSL], int ? [mSL, BS, BM]
th::Tensor newTokens, // [BS, BM, 1], int
th::optional<th::Tensor> finished_input, // [BS, BM], uint8
th::optional<th::Tensor> finished_output, // [BS, BM], uint8
th::optional<th::Tensor> sequence_lengths_opt, // [BS*BM], int, length of the current sequences
th::optional<th::Tensor> cum_log_probs_opt, // [BS, BM], float
th::optional<th::Tensor> output_log_probs_opt, // [BS, BM, mSL], float ? [mSL, BS, BM]
th::optional<th::Tensor> output_log_probs_tiled_opt, // [mSL, BS, BM], float ? [BS, BM, mSL]
th::optional<th::Tensor> parent_ids_opt, // [BS, BM, mSL], int ? [mSL, BS, BM]
th::optional<th::Tensor> tgt_cache_indirection_opt, // [local_BS, BM, memory_length], int
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt, // [BS, BM*2, mSL], int
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt, // [BS, BM*2], int
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, // [BS, BM*2], float
th::optional<th::Tensor> beam_hyps_normed_scores_opt, // [BS, BM*2], float
th::optional<th::Tensor> beam_hyps_log_probs_opt, // [BS, BM*2, mSL], float
th::optional<th::Tensor> beam_hyps_min_normed_scores_opt, // [BS], float
th::optional<th::Tensor> beam_hyps_num_beams_opt, // [BS], int
th::optional<th::Tensor> beam_hyps_is_done_opt, // [BS], bool
bool use_beam_hyps //
)
{
CHECK_INPUT(logits, scalar_type_);
TLLM_CHECK_WITH_INFO(logits.dim() == 3,
"logits is of shape (batch_size, beam_width, vocab_size_padded), but got dim=%d shape=%s", (int) logits.dim(),
tensorrt_llm::common::vec2str(convert_shape(logits)).c_str());
TLLM_CHECK_WITH_INFO(static_cast<size_t>(logits.size(2)) == vocab_size_padded_,
"logits is of shape (batch_size, beam_width, vocab_size_padded (%zu)), but got the last dim=%zu.",
vocab_size_padded_, static_cast<size_t>(logits.size(2)));
CHECK_INPUT(end_id, torch::kInt32);
CHECK_OPTIONAL_INPUT(embedding_bias_opt, scalar_type_);
CHECK_OPTIONAL_INPUT(input_lengths_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(sequence_limit_length_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(stop_words_list_ptrs_opt, torch::kInt64);
CHECK_OPTIONAL_INPUT(stop_words_lens_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(bad_words_list_ptrs_opt, torch::kInt64);
CHECK_OPTIONAL_INPUT(bad_words_lens_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(no_repeat_ngram_size_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(src_cache_indirection_opt, torch::kInt32);
CHECK_INPUT(output_token_ids, torch::kInt32);
CHECK_INPUT(newTokens, torch::kInt32);
CHECK_OPTIONAL_INPUT(finished_input, torch::kUInt8);
CHECK_OPTIONAL_INPUT(finished_output, torch::kUInt8);
CHECK_OPTIONAL_INPUT(sequence_lengths_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(cum_log_probs_opt, torch::kFloat32);
CHECK_OPTIONAL_INPUT(output_log_probs_opt, torch::kFloat32);
CHECK_OPTIONAL_INPUT(output_log_probs_tiled_opt, torch::kFloat32);
CHECK_OPTIONAL_INPUT(parent_ids_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(tgt_cache_indirection_opt, torch::kInt32);
th::Tensor should_stop = torch::zeros({1}, torch::dtype(torch::kBool).requires_grad(false));
dynamic_decode_->forward(
// Inputs
logits, static_cast<int>(step), static_cast<int>(max_input_length), static_cast<int>(max_attention_window),
static_cast<int>(sink_token_length), static_cast<uint32_t>(ite), static_cast<int>(local_batch_size), end_id,
embedding_bias_opt, input_lengths_opt, sequence_limit_length_opt, stop_words_list_ptrs_opt, stop_words_lens_opt,
static_cast<int32_t>(max_stop_words_len), bad_words_list_ptrs_opt, bad_words_lens_opt,
static_cast<int32_t>(max_bad_words_len), no_repeat_ngram_size_opt, src_cache_indirection_opt,
// Outputs
output_token_ids, newTokens, should_stop, finished_input, finished_output, sequence_lengths_opt,
cum_log_probs_opt, output_log_probs_opt, output_log_probs_tiled_opt, parent_ids_opt, tgt_cache_indirection_opt,
beam_hyps_output_ids_tgt_opt, beam_hyps_sequence_lengths_tgt_opt, beam_hyps_cum_log_probs_opt,
beam_hyps_normed_scores_opt, beam_hyps_log_probs_opt, beam_hyps_min_normed_scores_opt, beam_hyps_num_beams_opt,
beam_hyps_is_done_opt, use_beam_hyps);
return should_stop;
}
} // namespace torch_ext
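// Register DynamicDecodeOp as a TorchScript custom class under the "trtllm" namespace. Once the
// extension library is loaded, Python code can typically construct it as
// torch.classes.trtllm.DynamicDecodeOp(...) and call its setup()/forward() methods; the exact
// loading mechanism depends on how the bindings are packaged.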
static auto trtllmDynamicDecodeTHS
= torch::jit::class_<torch_ext::DynamicDecodeOp>("trtllm", "DynamicDecodeOp")
.def(torch::jit::init<int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, at::ScalarType>())
.def("setup", &torch_ext::DynamicDecodeOp::setup)
.def("forward", &torch_ext::DynamicDecodeOp::forward);