/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/thop/dynamicDecodeOp.h"
#include "tensorrt_llm/common/tensorConversion.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "tensorrt_llm/thop/torchAllocator.h"

namespace th = torch;

namespace tr = tensorrt_llm::runtime;
namespace tcc = tensorrt_llm::common::conversion;

namespace torch_ext
{

// Builds the Torch-facing wrapper around a DynamicDecodeLayer<T>.
//
// vocab_size / vocab_size_padded are stored and forwarded to the layer.
// tensor_para_size is only used to validate that the padded vocabulary splits evenly.
// pipeline_para_size is accepted but not used by this constructor.
//
// The allocator and the layer are created with `new` here and released in the destructor.
template <typename T>
FtDynamicDecode<T>::FtDynamicDecode(
    const size_t vocab_size, const size_t vocab_size_padded, const int tensor_para_size, const int pipeline_para_size)
    : vocab_size_(vocab_size)
    , vocab_size_padded_(vocab_size_padded)
{
    // Reject non-positive values up front: a zero divisor below is undefined behaviour and a
    // negative one would be converted to a huge unsigned value by the mixed size_t/int modulo.
    TLLM_CHECK_WITH_INFO(tensor_para_size > 0,
        tensorrt_llm::common::fmtstr("tensor_para_size (%d) must be positive.", tensor_para_size));
    TLLM_CHECK_WITH_INFO(vocab_size_padded_ % tensor_para_size == 0,
        tensorrt_llm::common::fmtstr(
            "vocab_size (%ld) is not multiple of tensor_para_size (%d).", vocab_size_padded_, tensor_para_size));

    allocator_ = new tensorrt_llm::thop::TorchAllocator();

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // Fill the member that is handed to the layer. Previously the query wrote into a local
    // variable while the layer received a pointer to the never-initialized prop_.
    // NOTE(review): device index 0 is hard-coded; confirm this is intended when the current
    // CUDA device is not device 0.
    tensorrt_llm::common::check_cuda_error(cudaGetDeviceProperties(&prop_, 0));
    dynamic_decode_layer_ = new tensorrt_llm::layers::DynamicDecodeLayer<T>(
        vocab_size_, vocab_size_padded_, stream, nullptr, allocator_, false, &prop_);
}

// Releases the layer first, then the allocator that was passed to it at construction.
template <typename T>
FtDynamicDecode<T>::~FtDynamicDecode()
{
    delete dynamic_decode_layer_;
    delete allocator_;
}

namespace
{

// If `tensor` holds a value, copies all of its elements (product of its shape) into `arg`
// as a std::vector<T>. Leaves `arg` untouched otherwise.
template <typename T>
void safeInsert(th::optional<th::Tensor>& tensor, std::optional<std::vector<T>>& arg)
{
    using value_type = T;
    if (tensor.has_value())
    {
        auto ptr = get_ptr<value_type>(tensor.value());
        auto shape = convert_shape(tensor.value());
        // Start from size_t{1} so the product is accumulated in size_t rather than int.
        size_t const size = std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<>());
        arg = std::vector<value_type>(ptr, ptr + size);
    }
}

// If `tensor` holds a value, stores its converted (non-Torch) tensor view, typed as T, in `arg`.
template <typename T>
void safeUpdate(th::optional<th::Tensor>& tensor, std::optional<tensorrt_llm::common::Tensor>& arg)
{
    if (tensor.has_value())
    {
        arg = convert_tensor<T>(tensor.value());
    }
}

// If `tensor` holds a value, checks that it is a 1-D tensor with exactly one element and
// stores that element in `arg`. `name` is only used in the failure message.
template <typename T>
void safeUpdateScalar(th::optional<th::Tensor>& tensor, std::optional<T>& arg, std::string const& name)
{
    if (tensor.has_value())
    {
        auto accessor = tensor->accessor<T, 1>();
        TLLM_CHECK_WITH_INFO(accessor.size(0) == 1, name + " must be a scalar");
        arg = accessor[0];
    }
}

// If `tensor` holds a value, points `ptr` at its data. Leaves `ptr` untouched otherwise.
template <typename T>
void safeUpdatePtr(th::optional<th::Tensor>& tensor, T*& ptr)
{
    if (tensor.has_value())
    {
        ptr = get_ptr<T>(tensor.value());
    }
}

} // namespace

// Rebinds the layer to the current Torch CUDA stream, gathers the provided optional runtime
// parameters into a SetupParams object and forwards it to DynamicDecodeLayer::setup.
// Optionals that are not set leave the corresponding SetupParams field at its default.
template <typename T>
void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optional<th::Tensor> runtime_top_k_opt,
    th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
    th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
    th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
    th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
    th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
    th::optional<th::Tensor> top_p_reset_ids_opt)
{
    // unused: length_penalty_opt, beam_search_diversity_rate_opt

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    dynamic_decode_layer_->setStream(stream);

    SetupParams setupParams;

    // Per-request vectors.
    safeInsert(runtime_top_k_opt, setupParams.runtime_top_k);
    safeInsert(runtime_top_p_opt, setupParams.runtime_top_p);
    safeInsert(temperature_opt, setupParams.temperature);
    safeInsert(repetition_penalty_opt, setupParams.repetition_penalty);
    safeInsert(presence_penalty_opt, setupParams.presence_penalty);
    safeInsert(min_length_opt, setupParams.min_length);
    safeInsert(random_seed_opt, setupParams.random_seed);
    safeInsert(top_p_decay_opt, setupParams.top_p_decay);
    safeInsert(top_p_min_opt, setupParams.top_p_min);
    safeInsert(top_p_reset_ids_opt, setupParams.top_p_reset_ids);

    // Single-element tensors.
    safeUpdateScalar(
        beam_search_diversity_rate_opt, setupParams.beam_search_diversity_rate, "beam_search_diversity_rate");
    safeUpdateScalar(length_penalty_opt, setupParams.length_penalty, "length_penalty");

    dynamic_decode_layer_->setup(batch_size, beam_width, setupParams);
}

// Converts the Torch tensors into the layer's ForwardParams / OutputParams, runs one decoding
// step, and (only when both sequence_limit_length and finished are supplied) writes into
// should_stop[0] whether the number of finished entries equals the size of `finished`.
template <typename T>
void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
    int step, int max_input_length, uint ite, int local_batch_size, th::Tensor end_id,
    th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
    th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
    th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
    th::optional<th::Tensor> src_cache_indirection_opt,
    // Outputs
    th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
    th::optional<th::Tensor> finished_opt, th::optional<th::Tensor> sequence_lengths_opt,
    th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
    th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
    th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt, th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
    th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
    th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
    th::optional<th::Tensor> beam_hyps_num_beams_opt, th::optional<th::Tensor> beam_hyps_is_done_opt,
    bool use_beam_hyps)
{
    // ---- Inputs -------------------------------------------------------------------------
    auto const logits_converted = convert_tensor<T>(logits);
    auto const end_ids_converted = convert_tensor<int>(end_id);
    typename tensorrt_llm::layers::DynamicDecodeLayer<T>::ForwardParams forwardParams{
        step, static_cast<int>(ite), max_input_length, local_batch_size, logits_converted, end_ids_converted};

    safeUpdate<int>(src_cache_indirection_opt, forwardParams.src_cache_indirection);
    safeUpdate<int>(sequence_limit_length_opt, forwardParams.sequence_limit_length);
    safeUpdate<T>(embedding_bias_opt, forwardParams.embedding_bias);
    safeUpdate<int>(input_lengths_opt, forwardParams.input_lengths);
    safeUpdate<int>(bad_words_list_opt, forwardParams.bad_words_list);
    safeUpdate<int>(stop_words_list_opt, forwardParams.stop_words_list);
    safeUpdate<int>(no_repeat_ngram_size_opt, forwardParams.no_repeat_ngram_size);

    // ---- Outputs ------------------------------------------------------------------------
    auto const output_ids_converted = convert_tensor<int>(output_token_ids);
    typename tensorrt_llm::layers::DynamicDecodeLayer<T>::OutputParams outputParams{output_ids_converted};
    outputParams.newTokens = convert_tensor<int>(newTokens);
    safeUpdate<bool>(finished_opt, outputParams.finished);

    // Host-visible counter the layer fills with the number of finished entries. It is only
    // requested when both a sequence length limit and a finished buffer were supplied.
    tr::BufferManager::ITensorPtr finished_sum;
    std::int32_t* finished_sum_host = nullptr;
    if (forwardParams.sequence_limit_length && outputParams.finished.has_value())
    {
        finished_sum = tr::BufferManager::pinned(tr::ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
        outputParams.finished_sum = tcc::toTllmTensor(*finished_sum);
        finished_sum_host = tr::bufferCast<std::int32_t>(*finished_sum);
        *finished_sum_host = 0;
    }

    safeUpdate<int>(sequence_lengths_opt, outputParams.sequence_length);
    safeUpdate<int>(parent_ids_opt, outputParams.parent_ids);
    safeUpdate<float>(cum_log_probs_opt, outputParams.cum_log_probs);
    safeUpdate<float>(output_log_probs_opt, outputParams.output_log_probs);
    safeUpdate<int>(tgt_cache_indirection_opt, outputParams.tgt_cache_indirection);

    if (use_beam_hyps)
    {
        // The element type of each pointer is deduced from the BeamHypotheses member it binds to.
        outputParams.beamHypotheses = std::make_shared<tensorrt_llm::kernels::BeamHypotheses>();
        safeUpdatePtr(beam_hyps_output_ids_tgt_opt, outputParams.beamHypotheses->output_ids_tgt);
        safeUpdatePtr(beam_hyps_sequence_lengths_tgt_opt, outputParams.beamHypotheses->sequence_lengths_tgt);
        safeUpdatePtr(beam_hyps_cum_log_probs_opt, outputParams.beamHypotheses->cum_log_probs);
        safeUpdatePtr(beam_hyps_normed_scores_opt, outputParams.beamHypotheses->normed_scores);
        safeUpdatePtr(beam_hyps_log_probs_opt, outputParams.beamHypotheses->log_probs);
        safeUpdatePtr(beam_hyps_min_normed_scores_opt, outputParams.beamHypotheses->min_normed_scores);
        safeUpdatePtr(beam_hyps_num_beams_opt, outputParams.beamHypotheses->num_beams);
        safeUpdatePtr(beam_hyps_is_done_opt, outputParams.beamHypotheses->is_done);
        safeUpdatePtr(input_lengths_opt, outputParams.beamHypotheses->input_lengths);
    }

    dynamic_decode_layer_->forward(outputParams, forwardParams);

    if (finished_sum_host)
    {
        auto const numToFinish = outputParams.finished->size();
        // Wait for the layer's stream before reading the counter on the host.
        TLLM_CUDA_CHECK(::cudaStreamSynchronize(dynamic_decode_layer_->getStream()));
        auto should_stop_accessor = should_stop.accessor<bool, 1>();
        should_stop_accessor[0] = numToFinish == *finished_sum_host;
    }
}

// TorchScript-visible op. Stores the configuration narrowed to the types FtDynamicDecode
// expects and immediately creates the typed decoder instance for `scalar_type`.
DynamicDecodeOp::DynamicDecodeOp(const int64_t vocab_size, const int64_t vocab_size_padded,
    const int64_t tensor_para_size, const int64_t pipeline_para_size, at::ScalarType scalar_type)
    : vocab_size_(static_cast<size_t>(vocab_size))
    , vocab_size_padded_(static_cast<size_t>(vocab_size_padded))
    , tensor_para_size_(static_cast<int>(tensor_para_size))
    , pipeline_para_size_(static_cast<int>(pipeline_para_size))
    , scalar_type_(scalar_type)
{
    TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    createInstance();
}

DynamicDecodeOp::~DynamicDecodeOp()
{
    // Do nothing.
}

// (Re)creates dynamic_decode_ for the configured scalar type.
// Throws std::runtime_error for any scalar type other than Float or Half.
void DynamicDecodeOp::createInstance()
{
    dynamic_decode_.reset();
    switch (scalar_type_)
    {
    case at::ScalarType::Float:
        dynamic_decode_ = std::make_unique<FtDynamicDecode<float>>(
            vocab_size_, vocab_size_padded_, tensor_para_size_, pipeline_para_size_);
        break;
    case at::ScalarType::Half:
        dynamic_decode_ = std::make_unique<FtDynamicDecode<half>>(
            vocab_size_, vocab_size_padded_, tensor_para_size_, pipeline_para_size_);
        break;
    default: throw std::runtime_error("Wrong tensor type.");
    }
}

// Validates the dtype/placement of every optional runtime parameter and forwards them to the
// typed decoder's setup().
void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional<th::Tensor> runtime_top_k_opt,
    th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
    th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
    th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
    th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
    th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
    th::optional<th::Tensor> top_p_reset_ids_opt)
{
    // TODO: Revise DynamicDecodeLayer and make the decode arguments consistent.
    CHECK_OPTIONAL_CPU_INPUT(runtime_top_k_opt, torch::kInt32);
    CHECK_OPTIONAL_CPU_INPUT(runtime_top_p_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(temperature_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(repetition_penalty_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(presence_penalty_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(min_length_opt, torch::kInt32);
    CHECK_OPTIONAL_CPU_INPUT(length_penalty_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(beam_search_diversity_rate_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(random_seed_opt, torch::kInt64);
    CHECK_OPTIONAL_INPUT(top_p_decay_opt, torch::kFloat);
    CHECK_OPTIONAL_INPUT(top_p_min_opt, torch::kFloat);
    CHECK_OPTIONAL_INPUT(top_p_reset_ids_opt, torch::kInt32);

    dynamic_decode_->setup(static_cast<size_t>(batch_size), static_cast<size_t>(beam_width), runtime_top_k_opt,
        runtime_top_p_opt, temperature_opt, repetition_penalty_opt, presence_penalty_opt, min_length_opt,
        length_penalty_opt, beam_search_diversity_rate_opt, random_seed_opt, top_p_decay_opt, top_p_min_opt,
        top_p_reset_ids_opt);
}

// Validates the tensors for one decoding step, runs the typed decoder and returns a
// one-element bool tensor (`should_stop`). It is created as zeros and is only written by the
// decoder when both sequence_limit_length and finished are supplied.
th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length, int64_t ite,
    int64_t local_batch_size, th::Tensor end_id, th::optional<th::Tensor> embedding_bias_opt,
    th::optional<th::Tensor> input_lengths_opt, // length of input contexts.
    th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
    th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
    th::optional<th::Tensor> src_cache_indirection_opt,
    // output buffers.
    th::Tensor output_token_ids, th::Tensor newTokens, th::optional<th::Tensor> finished_opt,
    th::optional<th::Tensor> seuqence_lengths_opt, // length of the current sequences.
    th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
    th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
    th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt, th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
    th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
    th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
    th::optional<th::Tensor> beam_hyps_num_beams_opt, th::optional<th::Tensor> beam_hyps_is_done_opt,
    bool use_beam_hyps)
{
    // Input Arguments:
    //     logits: [batch_size, beam_width, vocab_size_padded], T
    //     end_id: [batch_size], int
    //     embedding_bias: [vocab_size_padded], T, optional
    //     input_lengths: [batch_size * beam_width], int, optional
    //     sequence_limit_length: [batch_size], int, optional
    //     stop_words_list: [batch_size, 2, stop_words_length], int, optional
    //     bad_words_list: [2, stop_words_length], int, optional
    //     src_cache_indirection: [local_batch_size, beam_width, memory_length], int, optional
    //     output_token_ids: [max_seq_length, batch_size, beam_width], int
    //     finished: [batch_size * beam_width], bool, optional
    //     sequence_lengths: [batch_size * beam_width], int, optional
    //     cum_log_probs: [batch_size * beam_width], float, optional
    //     output_log_probs: [gen_length, batch_size, beam_width], float, optional
    //     parent_ids: [gen_length, batch_size, beam_width], int, optional
    //     tgt_cache_indirection: [local_batch_size, beam_width, memory_length], int, optional

    CHECK_INPUT(logits, scalar_type_);

    TLLM_CHECK_WITH_INFO(logits.dim() == 3,
        tensorrt_llm::common::fmtstr("logits is of shape (batch_size, beam_width, vocab_size_padded), "
                                     "but got dim=%d shape=%s",
            static_cast<int>(logits.dim()), tensorrt_llm::common::vec2str(convert_shape(logits)).c_str())
            .c_str());
    TLLM_CHECK_WITH_INFO(static_cast<size_t>(logits.size(2)) == vocab_size_padded_,
        tensorrt_llm::common::fmtstr("logits is of shape (batch_size, beam_width, "
                                     "vocab_size(%ld)), but got the last dim=%ld.",
            vocab_size_padded_, static_cast<size_t>(logits.size(2))));

    CHECK_INPUT(end_id, torch::kInt32);

    CHECK_OPTIONAL_INPUT(input_lengths_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(sequence_limit_length_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(stop_words_list_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(bad_words_list_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(no_repeat_ngram_size_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(src_cache_indirection_opt, torch::kInt32);

    CHECK_INPUT(output_token_ids, torch::kInt32);
    CHECK_OPTIONAL_INPUT(finished_opt, torch::kBool);
    CHECK_OPTIONAL_INPUT(seuqence_lengths_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(cum_log_probs_opt, torch::kFloat32);
    CHECK_OPTIONAL_INPUT(output_log_probs_opt, torch::kFloat32);
    CHECK_OPTIONAL_INPUT(parent_ids_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(tgt_cache_indirection_opt, torch::kInt32);

    th::Tensor should_stop = torch::zeros({1}, torch::dtype(torch::kBool).requires_grad(false));

    dynamic_decode_->forward(
        // Inputs
        logits, static_cast<int>(step), static_cast<int>(max_input_length), static_cast<uint32_t>(ite),
        static_cast<int>(local_batch_size), end_id, embedding_bias_opt, input_lengths_opt, sequence_limit_length_opt,
        stop_words_list_opt, bad_words_list_opt, no_repeat_ngram_size_opt, src_cache_indirection_opt,
        // Outputs
        output_token_ids, newTokens, should_stop, finished_opt, seuqence_lengths_opt, cum_log_probs_opt,
        output_log_probs_opt, parent_ids_opt, tgt_cache_indirection_opt, beam_hyps_output_ids_tgt_opt,
        beam_hyps_sequence_lengths_tgt_opt, beam_hyps_cum_log_probs_opt, beam_hyps_normed_scores_opt,
        beam_hyps_log_probs_opt, beam_hyps_min_normed_scores_opt, beam_hyps_num_beams_opt, beam_hyps_is_done_opt,
        use_beam_hyps);

    return should_stop;
}

} // namespace torch_ext

// Registers DynamicDecodeOp as the TorchScript custom class FasterTransformer.DynamicDecodeOp.
static auto fasterTransformerGptContextDecoderTHS
    = torch::jit::class_<torch_ext::DynamicDecodeOp>("FasterTransformer", "DynamicDecodeOp")
          .def(torch::jit::init<int64_t, int64_t, int64_t, int64_t, at::ScalarType>())
          .def("setup", &torch_ext::DynamicDecodeOp::setup)
          .def("forward", &torch_ext::DynamicDecodeOp::forward);