/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/beamSearchTopkKernels.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/CUDAContext.h> // for at::cuda::getCurrentCUDAStream()

namespace th = torch;
namespace tl = tensorrt_llm;

namespace torch_ext
{

// This should be similar to gatherTree in cpp/tensorrt_llm/runtime/gptSession.cpp.
th::Tensor gatherTree(th::Tensor& sequence_lengths, th::Tensor& output_ids, th::Tensor& parent_ids,
    th::Tensor& end_ids, th::Tensor& tiled_input_lengths, th::optional<th::Tensor> cum_log_probs_opt,
    th::optional<th::Tensor> beam_hyps_output_ids_tgt, th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt,
    th::optional<th::Tensor> beam_hyps_cum_log_probs, th::optional<th::Tensor> beam_hyps_normed_scores,
    th::optional<th::Tensor> beam_hyps_log_probs, th::optional<th::Tensor> beam_hyps_min_normed_scores,
    th::optional<th::Tensor> beam_hyps_num_beams, th::optional<th::Tensor> beam_hyps_is_done,
    th::optional<th::Tensor> finished, th::Tensor& length_penalty, int64_t batch_size, int64_t beam_width,
    int64_t max_seq_len, bool use_beam_hyps)
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    th::Tensor final_output_ids = torch::zeros(
        {batch_size, beam_width, max_seq_len}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

    if (use_beam_hyps && beam_width > 1)
    {
        // Fill the output buffer with end_ids first; finalize overwrites it with the selected hypotheses.
        tl::kernels::invokeInitializeOutput(get_ptr<int32_t>(final_output_ids), get_ptr<int32_t>(end_ids),
            batch_size * beam_width, max_seq_len, stream);

        tl::kernels::BeamHypotheses beamHypotheses;
        beamHypotheses.sequence_lengths_src = get_ptr<int32_t>(sequence_lengths);
        beamHypotheses.parent_ids_src = get_ptr<int32_t>(parent_ids);
        beamHypotheses.output_ids_src = get_ptr<int32_t>(output_ids);
        beamHypotheses.log_probs_src = nullptr;
        beamHypotheses.max_seq_len = max_seq_len;
        beamHypotheses.length_penalty = get_val<float>(length_penalty, 0);

        beamHypotheses.output_ids_tgt = get_ptr<int32_t>(beam_hyps_output_ids_tgt.value());
        beamHypotheses.sequence_lengths_tgt = get_ptr<int32_t>(beam_hyps_sequence_lengths_tgt.value());
        beamHypotheses.cum_log_probs = get_ptr<float>(beam_hyps_cum_log_probs.value());
        beamHypotheses.normed_scores = get_ptr<float>(beam_hyps_normed_scores.value());
        beamHypotheses.log_probs = get_ptr<float>(beam_hyps_log_probs.value());
        beamHypotheses.min_normed_scores = get_ptr<float>(beam_hyps_min_normed_scores.value());
        beamHypotheses.num_beams = get_ptr<int32_t>(beam_hyps_num_beams.value());
        beamHypotheses.is_done = get_ptr<bool>(beam_hyps_is_done.value());
        beamHypotheses.input_lengths = get_ptr<int32_t>(tiled_input_lengths);

        // Beams that are still unfinished at the last step are inserted as hypotheses as well,
        // so that finalize always has beam_width candidates per request to choose from.
        tl::kernels::invokeInsertUnfinishedPath(beamHypotheses, get_ptr<tl::kernels::FinishedState>(finished.value()),
            get_ptr<float>(cum_log_probs_opt.value()), batch_size, beam_width, stream);
        sync_check_cuda_error();

        tl::kernels::invokeFinalize(get_ptr<int32_t>(final_output_ids), get_ptr<int32_t>(sequence_lengths),
            cum_log_probs_opt.has_value() ? get_ptr<float>(cum_log_probs_opt.value()) : nullptr,
            nullptr, // output_logs
            beamHypotheses.output_ids_tgt, beamHypotheses.sequence_lengths_tgt, beamHypotheses.normed_scores,
            beamHypotheses.cum_log_probs, beamHypotheses.log_probs, beamHypotheses.num_beams,
            get_ptr<int32_t>(tiled_input_lengths), beam_width, max_seq_len, batch_size, stream);
        sync_check_cuda_error();
    }
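    // Without beam hypotheses, the beams are reconstructed directly from the per-step
    // parent ids: invokeGatherTree walks the parent pointers backwards from the final
    // step so that each output row holds one complete, contiguous beam.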
    else if (!use_beam_hyps && beam_width > 1)
    {
        // Temporary workspace holding the gathered beam indices.
        th::Tensor workspace = torch::zeros(batch_size * beam_width * max_seq_len * sizeof(int32_t),
            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));

        // For sampling, this is equivalent to all parent ids being 0.
        tl::kernels::gatherTreeParam param;
        param.beams = get_ptr<int32_t>(workspace);
        // Remove prompt length if possible.
        param.sequence_lengths = get_ptr<int32_t>(sequence_lengths);
        // Add 1 to the sequence length here because the sequence_length of time step t is t - 1.
        param.max_sequence_length_final_step = 1;
        // Response input lengths (used to slice the ids during postprocessing) are needed for
        // interactive generation. This feature is not supported yet, so set it to nullptr temporarily.
        param.response_input_lengths = nullptr;
        param.max_seq_len = max_seq_len;
        param.batch_size = batch_size;
        param.beam_width = beam_width;
        param.step_ids = get_ptr<int32_t>(output_ids);
        param.parent_ids = beam_width == 1 ? nullptr : get_ptr<int32_t>(parent_ids);
        param.end_tokens = get_ptr<int32_t>(end_ids);
        param.input_lengths = get_ptr<int32_t>(tiled_input_lengths);

        param.stream = stream;
        param.output_ids = get_ptr<int32_t>(final_output_ids);
        param.cum_log_probs = cum_log_probs_opt.has_value() ? get_ptr<float>(cum_log_probs_opt.value()) : nullptr;
        param.length_penalty = get_val<float>(length_penalty, 0);

        // NOTE: need to remove all prompt virtual tokens
        tl::kernels::invokeGatherTree(param);
        sync_check_cuda_error();
    }
    else
    {
        // beam_width == 1: the ids are already in their final layout, just copy them over.
        cudaMemcpyAsync(get_ptr<int32_t>(final_output_ids), get_ptr<int32_t>(output_ids),
            sizeof(int32_t) * batch_size * beam_width * max_seq_len, cudaMemcpyDeviceToDevice, stream);
        sync_check_cuda_error();
    }

    return final_output_ids;
}

} // namespace torch_ext

static auto gather_tree = torch::RegisterOperators("tensorrt_llm::gather_tree", &torch_ext::gatherTree);
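
// Usage sketch (illustrative, not part of this file's tested surface): once the extension
// library is loaded, the registered op is reachable from Python through torch.ops. The
// library name below is an assumption and depends on the build setup; the returned tensor
// has shape (batch_size, beam_width, max_seq_len) and dtype int32, as built above.
//
//   import torch
//   torch.ops.load_library("libth_common.so")  # assumed library name
//   final_ids = torch.ops.tensorrt_llm.gather_tree(
//       sequence_lengths, output_ids, parent_ids, end_ids, tiled_input_lengths,
//       cum_log_probs,           # cum_log_probs_opt
//       None, None, None, None,  # beam_hyps_{output_ids_tgt, sequence_lengths_tgt, cum_log_probs, normed_scores}
//       None, None, None, None,  # beam_hyps_{log_probs, min_normed_scores, num_beams, is_done}
//       None,                    # finished
//       length_penalty, batch_size, beam_width, max_seq_len, False)  # use_beam_hyps=False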