/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/thop/dynamicDecodeOp.h"
#include "tensorrt_llm/common/tensorConversion.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "tensorrt_llm/thop/torchAllocator.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

namespace th = torch;
namespace tr = tensorrt_llm::runtime;
namespace tcc = tensorrt_llm::common::conversion;

// NOTE(review): the explicit template arguments in this file (template headers, convert_tensor<...>,
// safeUpdate<...>, static_cast<...>, make_shared<...>, accessor<...>, jit::init<...>) were restored from
// context: the dtype checks in DynamicDecodeOp::forward, the FtDynamicDecode signatures and the types the
// values are passed to. Confirm them against dynamicDecodeOp.h and the DynamicDecodeLayer parameter structs.

namespace torch_ext
{

// Creates the decoding layer wrapper.
//
// - Allocates finished_sum_ as a pinned INT32 buffer with max_batch_size elements.
// - Verifies that vocab_size_padded is divisible by tensor_para_size.
// - Queries the properties of the current CUDA device into prop_.
// - Constructs the DynamicDecodeLayer on the current torch CUDA stream with a torch-backed allocator.
//
// pipeline_para_size is accepted but not used in this constructor.
template <typename T>
FtDynamicDecode<T>::FtDynamicDecode(const size_t max_batch_size, const size_t max_beam_width, const size_t vocab_size,
    const size_t vocab_size_padded, int const tensor_para_size, int const pipeline_para_size)
    : vocab_size_(vocab_size)
    , vocab_size_padded_(vocab_size_padded)
    , finished_sum_(tr::BufferManager::pinned(
          tr::ITensor::makeShape({static_cast<int32_t>(max_batch_size)}), nvinfer1::DataType::kINT32))
{
    // The value being checked is the padded vocabulary size, so the message names it accordingly.
    TLLM_CHECK_WITH_INFO(vocab_size_padded_ % tensor_para_size == 0,
        tensorrt_llm::common::fmtstr("vocab_size_padded (%ld) is not a multiple of tensor_para_size (%d).",
            vocab_size_padded_, tensor_para_size));

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    auto allocator = std::make_shared<tensorrt_llm::thop::TorchAllocator>(stream);

    int deviceId{0};
    tensorrt_llm::common::check_cuda_error(cudaGetDevice(&deviceId)); // Get the correct device id
    tensorrt_llm::common::check_cuda_error(cudaGetDeviceProperties(&prop_, deviceId));

    dynamic_decode_layer_ = std::make_shared<tensorrt_llm::layers::DynamicDecodeLayer<T>>(tr::DecodingMode::None(),
        max_batch_size, max_beam_width, vocab_size_, vocab_size_padded_, stream, std::move(allocator), &prop_);
}

namespace
{

// If `tensor` holds a value, copies all of its elements into `arg` as a std::vector<T>.
// The element count is the product of the tensor's dimensions. The data pointer is dereferenced
// on the host, so the tensor's storage must be host-readable.
template <typename T>
void safeInsert(th::optional<th::Tensor>& tensor, std::optional<std::vector<T>>& arg)
{
    if (!tensor.has_value())
    {
        return;
    }
    auto* const data = get_ptr<T>(tensor.value());
    auto const shape = convert_shape(tensor.value());
    // The accumulator type follows the initial value; start from size_t so the product is not
    // narrowed to int on every step.
    size_t const numElements = std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<>());
    arg = std::vector<T>(data, data + numElements);
}

// If `tensor` holds a value, stores its converted (non-owning) tensor view, typed as T, into `arg`.
template <typename T>
void safeUpdate(th::optional<th::Tensor>& tensor, std::optional<tensorrt_llm::common::Tensor>& arg)
{
    if (tensor.has_value())
    {
        arg = convert_tensor<T>(tensor.value());
    }
}

// If `tensor` holds a value, checks that it is a 1-D tensor with exactly one element and stores that
// element into `arg`. `name` is used in the error message.
template <typename T>
void safeUpdateScalar(th::optional<th::Tensor>& tensor, std::optional<T>& arg, std::string const& name)
{
    if (tensor.has_value())
    {
        auto accessor = tensor->accessor<T, 1>();
        TLLM_CHECK_WITH_INFO(accessor.size(0) == 1, name + " must be a scalar");
        arg = accessor[0];
    }
}

// If `tensor` holds a value, points `ptr` at the tensor's data; otherwise leaves `ptr` untouched.
template <typename T>
void safeUpdatePtr(th::optional<th::Tensor>& tensor, T*& ptr)
{
    if (tensor.has_value())
    {
        ptr = get_ptr<T>(tensor.value());
    }
}

} // namespace

// Binds the layer to the current torch CUDA stream, copies every provided option tensor into the
// layer's SetupParams (absent options stay unset) and calls the layer's setup().
template <typename T>
void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optional<th::Tensor> runtime_top_k_opt,
    th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
    th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
    th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
    th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
    th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
    th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
    th::optional<th::Tensor> top_p_reset_ids_opt)
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    dynamic_decode_layer_->setStream(stream);

    SetupParams setupParams;

    // The element type of each copy is deduced from the destination field.
    safeInsert(temperature_opt, setupParams.temperature);
    safeInsert(repetition_penalty_opt, setupParams.repetition_penalty);
    safeInsert(presence_penalty_opt, setupParams.presence_penalty);
    safeInsert(frequency_penalty_opt, setupParams.frequency_penalty);
    safeInsert(min_length_opt, setupParams.min_length);

    safeInsert(runtime_top_k_opt, setupParams.runtime_top_k);
    safeInsert(runtime_top_p_opt, setupParams.runtime_top_p);
    safeInsert(random_seed_opt, setupParams.randomSeed);
    safeInsert(top_p_decay_opt, setupParams.top_p_decay);
    safeInsert(top_p_min_opt, setupParams.top_p_min);
    safeInsert(top_p_reset_ids_opt, setupParams.top_p_reset_ids);

    safeInsert(beam_search_diversity_rate_opt, setupParams.beam_search_diversity_rate);
    safeInsert(length_penalty_opt, setupParams.length_penalty);
    safeInsert(early_stopping_opt, setupParams.early_stopping);

    dynamic_decode_layer_->setup(batch_size, beam_width, nullptr, setupParams);
}

// Runs one decoding step through the layer.
//
// Wraps the torch tensors into the layer's ForwardParams / OutputParams (optional tensors are only set
// when provided) and calls the layer's forward(). When both sequence_limit_length_opt and finished_output
// are provided, the first local_batch_size entries of the pinned finished_sum_ buffer are zeroed before the
// call; afterwards the layer's stream is synchronized, those entries are summed on the host, and
// should_stop[0] is set to whether the sum equals the number of elements of the `finished` output.
// Otherwise should_stop is left untouched.
template <typename T>
void FtDynamicDecode<T>::forward(th::Tensor& logits, int step, int max_input_length, int max_attention_window,
    int sink_token_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
    th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
    th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_ptrs_opt,
    th::optional<th::Tensor> stop_words_lens_opt, int32_t max_stop_words_len,
    th::optional<th::Tensor> bad_words_list_ptrs_opt, th::optional<th::Tensor> bad_words_lens_opt,
    int32_t max_bad_words_len, th::optional<th::Tensor> no_repeat_ngram_size_opt,
    th::optional<th::Tensor> src_cache_indirection_opt, th::Tensor& output_token_ids, th::Tensor& newTokens,
    th::Tensor& should_stop, th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
    th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
    th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> output_log_probs_tiled_opt,
    th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
    th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt, th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
    th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
    th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
    th::optional<th::Tensor> beam_hyps_num_beams_opt, th::optional<th::Tensor> beam_hyps_is_done_opt,
    bool use_beam_hyps)
{
    // ---- Inputs ----
    typename tensorrt_llm::layers::DynamicDecodeLayer<T>::ForwardParams forwardParams{step, static_cast<int>(ite),
        max_input_length, max_attention_window, sink_token_length, local_batch_size, convert_tensor<int>(end_id)};

    // NOTE(review): logits is checked against scalar_type_ by the caller, hence T here — confirm.
    forwardParams.logits = convert_tensor<T>(logits);

    safeUpdate<T>(embedding_bias_opt, forwardParams.embedding_bias);
    safeUpdate<int>(input_lengths_opt, forwardParams.input_lengths);
    safeUpdate<int>(sequence_limit_length_opt, forwardParams.sequence_limit_length);
    safeUpdate<uint64_t>(stop_words_list_ptrs_opt, forwardParams.stop_words_ptr);
    safeUpdate<int>(stop_words_lens_opt, forwardParams.stop_words_lengths);
    forwardParams.max_stop_words_len = max_stop_words_len;
    safeUpdate<uint64_t>(bad_words_list_ptrs_opt, forwardParams.bad_words_ptr);
    safeUpdate<int>(bad_words_lens_opt, forwardParams.bad_words_lengths);
    forwardParams.max_bad_words_len = max_bad_words_len;
    safeUpdate<int>(no_repeat_ngram_size_opt, forwardParams.no_repeat_ngram_size);
    safeUpdate<int>(src_cache_indirection_opt, forwardParams.src_cache_indirection);

    // ---- Outputs ----
    auto const outputIds = convert_tensor<int>(output_token_ids);
    typename tensorrt_llm::layers::DynamicDecodeLayer<T>::OutputParams outputParams{outputIds};
    // convert_tensor returns a temporary, which is already moved from; no std::move needed.
    outputParams.newTokens = convert_tensor<int>(newTokens);
    safeUpdate<uint8_t>(finished_input, forwardParams.finished);
    safeUpdate<uint8_t>(finished_output, outputParams.finished);
    safeUpdate<int>(sequence_lengths_opt, outputParams.sequence_length);
    safeUpdate<float>(cum_log_probs_opt, outputParams.cum_log_probs);
    safeUpdate<float>(output_log_probs_opt, outputParams.output_log_probs);
    safeUpdate<float>(output_log_probs_tiled_opt, outputParams.output_log_probs_tiled);
    safeUpdate<int>(parent_ids_opt, outputParams.parent_ids);
    safeUpdate<int>(tgt_cache_indirection_opt, outputParams.tgt_cache_indirection);

    // Host view of the pinned per-entry "finished" counters; stays null when stop detection is not requested.
    std::int32_t* finished_sum_host = nullptr;
    if (forwardParams.sequence_limit_length && outputParams.finished.has_value())
    {
        outputParams.finished_sum = tcc::toTllmTensor(*finished_sum_);
        finished_sum_host = tr::bufferCast<std::int32_t>(*finished_sum_);
        // NOTE(review): finished_sum_ holds max_batch_size entries (see constructor); local_batch_size is
        // assumed not to exceed that — it is not checked here.
        std::fill_n(finished_sum_host, local_batch_size, 0);
    }

    if (use_beam_hyps)
    {
        outputParams.beamHypotheses = std::make_shared<tensorrt_llm::kernels::BeamHypotheses>();
        auto& beamHyps = *outputParams.beamHypotheses;
        // Pointer element types are deduced from the destination members.
        safeUpdatePtr(beam_hyps_is_done_opt, beamHyps.is_done);
        safeUpdatePtr(beam_hyps_cum_log_probs_opt, beamHyps.cum_log_probs);
        safeUpdatePtr(beam_hyps_log_probs_opt, beamHyps.log_probs);
        safeUpdatePtr(beam_hyps_min_normed_scores_opt, beamHyps.min_normed_scores);
        safeUpdatePtr(beam_hyps_normed_scores_opt, beamHyps.normed_scores);
        safeUpdatePtr(beam_hyps_num_beams_opt, beamHyps.num_beams);
        safeUpdatePtr(beam_hyps_output_ids_tgt_opt, beamHyps.output_ids_tgt);
        safeUpdatePtr(beam_hyps_sequence_lengths_tgt_opt, beamHyps.sequence_lengths_tgt);
        // TODO: move the assignment below into onlineBeamSearchLayer.cu
        safeUpdatePtr(input_lengths_opt, beamHyps.input_lengths);
    }

    dynamic_decode_layer_->forward(outputParams, forwardParams);

    if (finished_sum_host != nullptr)
    {
        // The counters are written on the layer's stream; wait for it before reading them on the host.
        TLLM_CUDA_CHECK(::cudaStreamSynchronize(dynamic_decode_layer_->getStream()));

        // 64-bit accumulator: avoids overflow and the signed/unsigned comparison below.
        std::int64_t finishedTotal = 0;
        for (int32_t bi = 0; bi < local_batch_size; ++bi)
        {
            finishedTotal += finished_sum_host[bi];
        }

        auto const numToFinish = outputParams.finished->size();
        auto should_stop_accessor = should_stop.accessor<bool, 1>();
        should_stop_accessor[0] = (static_cast<std::int64_t>(numToFinish) == finishedTotal);
    }
}

// Stores the configuration and immediately creates the typed decoder instance for `scalar_type`.
DynamicDecodeOp::DynamicDecodeOp(const int64_t max_batch_size, const int64_t max_beam_width,
    const int64_t vocab_size, const int64_t vocab_size_padded, const int64_t tensor_para_size,
    const int64_t pipeline_para_size, at::ScalarType scalar_type)
    : max_batch_size_(static_cast<size_t>(max_batch_size))
    , max_beam_width_(static_cast<size_t>(max_beam_width))
    , vocab_size_(static_cast<size_t>(vocab_size))
    , vocab_size_padded_(static_cast<size_t>(vocab_size_padded))
    , tensor_para_size_(static_cast<int>(tensor_para_size))
    , pipeline_para_size_(static_cast<int>(pipeline_para_size))
    , scalar_type_(scalar_type)
{
    TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    createInstance();
}

// (Re)creates dynamic_decode_ for scalar_type_. Only Float and Half are supported; any other
// scalar type throws std::runtime_error.
void DynamicDecodeOp::createInstance()
{
    dynamic_decode_.reset();
    switch (scalar_type_)
    {
    case at::ScalarType::Float:
        dynamic_decode_ = std::make_unique<FtDynamicDecode<float>>(max_batch_size_, max_beam_width_, vocab_size_,
            vocab_size_padded_, tensor_para_size_, pipeline_para_size_);
        break;
    case at::ScalarType::Half:
        dynamic_decode_ = std::make_unique<FtDynamicDecode<half>>(max_batch_size_, max_beam_width_, vocab_size_,
            vocab_size_padded_, tensor_para_size_, pipeline_para_size_);
        break;
    default: throw std::runtime_error("Wrong tensor type.");
    }
}

// Validates the dtype of every provided option tensor and forwards all of them to the typed decoder's setup().
void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional<th::Tensor> runtime_top_k_opt,
    th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
    th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
    th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
    th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
    th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
    th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
    th::optional<th::Tensor> top_p_reset_ids_opt)
{
    // TODO: Revise DynamicDecodeLayer and make the decode arguments consistent.
    CHECK_OPTIONAL_CPU_INPUT(runtime_top_k_opt, torch::kInt32);
    CHECK_OPTIONAL_CPU_INPUT(runtime_top_p_opt, torch::kFloat);

    CHECK_OPTIONAL_CPU_INPUT(temperature_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(repetition_penalty_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(presence_penalty_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(frequency_penalty_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(min_length_opt, torch::kInt32);
    CHECK_OPTIONAL_CPU_INPUT(length_penalty_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(early_stopping_opt, torch::kInt32);
    CHECK_OPTIONAL_CPU_INPUT(beam_search_diversity_rate_opt, torch::kFloat);
    CHECK_OPTIONAL_CPU_INPUT(random_seed_opt, torch::kInt64);
    // NOTE(review): the three options below use CHECK_OPTIONAL_INPUT while all others use the CPU variant,
    // yet FtDynamicDecode::setup copies every option through safeInsert, which dereferences the data pointer
    // on the host. If CHECK_OPTIONAL_INPUT requires a CUDA tensor this is an invalid host read — confirm
    // against thUtils.h before changing.
    CHECK_OPTIONAL_INPUT(top_p_decay_opt, torch::kFloat);
    CHECK_OPTIONAL_INPUT(top_p_min_opt, torch::kFloat);
    CHECK_OPTIONAL_INPUT(top_p_reset_ids_opt, torch::kInt32);

    // TODO: add a parameter "return_normed_score" to return normed_cum_log_probs / cum_log_probs
    dynamic_decode_->setup(static_cast<size_t>(batch_size), static_cast<size_t>(beam_width), runtime_top_k_opt,
        runtime_top_p_opt, temperature_opt, repetition_penalty_opt, presence_penalty_opt, frequency_penalty_opt,
        min_length_opt, length_penalty_opt, early_stopping_opt, beam_search_diversity_rate_opt, random_seed_opt,
        top_p_decay_opt, top_p_min_opt, top_p_reset_ids_opt);
}

// Validates the inputs, runs one decoding step on the typed decoder and returns a one-element bool tensor
// (`should_stop`). The tensor is created as zeros and is only written by the decoder when it performs its
// finished-count check. The beam_hyps_* tensors are forwarded without dtype validation here.
th::Tensor DynamicDecodeOp::forward(
    // BS: batch_size, BM: beam_width, mSL: max_seq_length
    th::Tensor logits,                                           // [BS, BM, vocab_size_padded], T
    int64_t step,                                                //
    int64_t max_input_length,                                    //
    int64_t max_attention_window,                                //
    int64_t sink_token_length,                                   //
    int64_t ite,                                                 //
    int64_t local_batch_size,                                    //
    th::Tensor end_id,                                           // [BS*BM], int
    th::optional<th::Tensor> embedding_bias_opt,                 // [vocab_size_padded], T
    th::optional<th::Tensor> input_lengths_opt,                  // [BS*BM], int, length of input contexts
    th::optional<th::Tensor> sequence_limit_length_opt,          // [BS, 1], int
    th::optional<th::Tensor> stop_words_list_ptrs_opt,           // [BS][2, stop_words_length], int64
    th::optional<th::Tensor> stop_words_lens_opt,                // [BS], int
    int64_t max_stop_words_len,                                  //
    th::optional<th::Tensor> bad_words_list_ptrs_opt,            // [BS][2, bad_words_length], int64
    th::optional<th::Tensor> bad_words_lens_opt,                 // [BS], int
    int64_t max_bad_words_len,                                   //
    th::optional<th::Tensor> no_repeat_ngram_size_opt,           // [BS], int
    th::optional<th::Tensor> src_cache_indirection_opt,          // [local_BS, BM, mSL], int
    th::Tensor output_token_ids,                                 // [BS, BM, mSL], int ? [mSL, BS, BM]
    th::Tensor newTokens,                                        // [BS, BM, 1], int
    th::optional<th::Tensor> finished_input,                     // [BS, BM], uint8
    th::optional<th::Tensor> finished_output,                    // [BS, BM], uint8
    th::optional<th::Tensor> seuqence_lengths_opt,               // [BS*BM], int, length of the current sequences
    th::optional<th::Tensor> cum_log_probs_opt,                  // [BS, BM], float
    th::optional<th::Tensor> output_log_probs_opt,               // [BS, BM, mSL], float ? [mSL, BS, BM]
    th::optional<th::Tensor> output_log_probs_tiled_opt,         // [mSL, BS, BM], float ? [BS, BM, mSL]
    th::optional<th::Tensor> parent_ids_opt,                     // [BS, BM, mSL], int ? [mSL, BS, BM]
    th::optional<th::Tensor> tgt_cache_indirection_opt,          // [local_BS, BM, memory_length], int
    th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,       // [BS, BM*2, mSL], int
    th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt, // [BS, BM*2], int
    th::optional<th::Tensor> beam_hyps_cum_log_probs_opt,        // [BS, BM*2], float
    th::optional<th::Tensor> beam_hyps_normed_scores_opt,        // [BS, BM*2], float
    th::optional<th::Tensor> beam_hyps_log_probs_opt,            // [BS, BM*2, mSL], float
    th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,    // [BS], float
    th::optional<th::Tensor> beam_hyps_num_beams_opt,            // [BS], int
    th::optional<th::Tensor> beam_hyps_is_done_opt,              // [BS], bool
    bool use_beam_hyps                                           //
)
{
    // ---- Input validation ----
    CHECK_INPUT(logits, scalar_type_);
    TLLM_CHECK_WITH_INFO(logits.dim() == 3,
        "logits is of shape (batch_size, beam_width, vocab_size_padded), but got dim=%d shape=%s",
        static_cast<int>(logits.dim()), tensorrt_llm::common::vec2str(convert_shape(logits)).c_str());
    TLLM_CHECK_WITH_INFO(static_cast<size_t>(logits.size(2)) == vocab_size_padded_,
        "logits is of shape (batch_size, beam_width, vocab_size(%ld)), but got the last dim=%ld.", vocab_size_padded_,
        static_cast<size_t>(logits.size(2)));
    CHECK_INPUT(end_id, torch::kInt32);

    CHECK_OPTIONAL_INPUT(embedding_bias_opt, scalar_type_);
    CHECK_OPTIONAL_INPUT(input_lengths_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(sequence_limit_length_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(stop_words_list_ptrs_opt, torch::kInt64);
    CHECK_OPTIONAL_INPUT(stop_words_lens_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(bad_words_list_ptrs_opt, torch::kInt64);
    CHECK_OPTIONAL_INPUT(bad_words_lens_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(no_repeat_ngram_size_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(src_cache_indirection_opt, torch::kInt32);

    // ---- Output validation ----
    CHECK_INPUT(output_token_ids, torch::kInt32);
    CHECK_INPUT(newTokens, torch::kInt32);
    CHECK_OPTIONAL_INPUT(finished_input, torch::kUInt8);
    CHECK_OPTIONAL_INPUT(finished_output, torch::kUInt8);
    CHECK_OPTIONAL_INPUT(seuqence_lengths_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(cum_log_probs_opt, torch::kFloat32);
    CHECK_OPTIONAL_INPUT(output_log_probs_opt, torch::kFloat32);
    CHECK_OPTIONAL_INPUT(output_log_probs_tiled_opt, torch::kFloat32);
    CHECK_OPTIONAL_INPUT(parent_ids_opt, torch::kInt32);
    CHECK_OPTIONAL_INPUT(tgt_cache_indirection_opt, torch::kInt32);

    // One-element bool result, initialised to false (zeros).
    th::Tensor should_stop = torch::zeros({1}, torch::dtype(torch::kBool).requires_grad(false));

    dynamic_decode_->forward(
        // Inputs
        logits, static_cast<int>(step), static_cast<int>(max_input_length), static_cast<int>(max_attention_window),
        static_cast<int>(sink_token_length), static_cast<uint64_t>(ite), static_cast<int>(local_batch_size), end_id,
        embedding_bias_opt, input_lengths_opt, sequence_limit_length_opt, stop_words_list_ptrs_opt,
        stop_words_lens_opt, static_cast<int32_t>(max_stop_words_len), bad_words_list_ptrs_opt, bad_words_lens_opt,
        static_cast<int32_t>(max_bad_words_len), no_repeat_ngram_size_opt, src_cache_indirection_opt,
        // Outputs
        output_token_ids, newTokens, should_stop, finished_input, finished_output, seuqence_lengths_opt,
        cum_log_probs_opt, output_log_probs_opt, output_log_probs_tiled_opt, parent_ids_opt,
        tgt_cache_indirection_opt, beam_hyps_output_ids_tgt_opt, beam_hyps_sequence_lengths_tgt_opt,
        beam_hyps_cum_log_probs_opt, beam_hyps_normed_scores_opt, beam_hyps_log_probs_opt,
        beam_hyps_min_normed_scores_opt, beam_hyps_num_beams_opt, beam_hyps_is_done_opt, use_beam_hyps);

    return should_stop;
}

} // namespace torch_ext

// Registers DynamicDecodeOp as the TorchScript custom class "trtllm.DynamicDecodeOp" with its
// constructor and the setup/forward methods.
static auto trtllmGptContextDecoderTHS
    = torch::jit::class_<torch_ext::DynamicDecodeOp>("trtllm", "DynamicDecodeOp")
          .def(torch::jit::init<int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, at::ScalarType>())
          .def("setup", &torch_ext::DynamicDecodeOp::setup)
          .def("forward", &torch_ext::DynamicDecodeOp::forward);