/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/beamSearchPenaltyKernels.h"
#include "tensorrt_llm/layers/baseBeamSearchLayer.h"
#include "tensorrt_llm/layers/fillBuffers.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;

namespace tensorrt_llm
{
namespace layers
{

__global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids,
    const FinishedState* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim,
    int local_batch_size, int beam_width, int max_attention_window, int sink_token_length, int max_seq_len)
{
    int time_step = threadIdx.x + blockIdx.x * blockDim.x;
    int bb_id = threadIdx.y + blockIdx.y * blockDim.y; // should be just blockIdx.y?
    const int current_step{sequence_lengths[bb_id] - 1}; // sequence_lengths is already updated, so subtract 1
    const int input_length{input_lengths == nullptr ? 0 : input_lengths[bb_id]};
    const int batch_id = bb_id / beam_width;
    const int beam_id = bb_id % beam_width;
    // Exit when the batch_beam index or the time step is out of bounds.
    // Assume that the KV cache is shared and fixed for the context part,
    // so we don't need to update the indices for the context part.
    if (bb_id >= beam_width * local_batch_size || time_step >= max_seq_len || time_step < input_length
        || time_step < (max_seq_len - max_attention_window) || finished[bb_id].isFinished())
    {
        return;
    }
    int time_step_circ = time_step;
    if (time_step_circ >= sink_token_length)
    {
        time_step_circ
            = sink_token_length + (time_step - sink_token_length) % (max_attention_window - sink_token_length);
    }

    // parent_ids are kept for all past tokens (i.e. max_seq_len entries per beam).
    const int src_beam = parent_ids[batch_id][beam_id * max_seq_len + current_step];

    // The indirection tables use the cyclic KV cache layout.
    const uint32_t tgt_offset
        = batch_id * beam_width * max_attention_window + beam_id * max_attention_window + time_step_circ;
    const uint32_t src_offset
        = batch_id * beam_width * max_attention_window + src_beam * max_attention_window + time_step_circ;

    tgt_indir_cache[tgt_offset] = (time_step == current_step) ? beam_id : src_indir_cache[src_offset];
}
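/*
 * Illustrative sketch of the cyclic slot mapping applied above. The helper below mirrors the
 * time_step_circ computation in update_indir_cache_kernel; it is not called by any kernel, and the
 * example values (sink_token_length = 4, max_attention_window = 8) are assumptions chosen only to
 * make the wrap-around behaviour concrete.
 */
namespace
{
constexpr int cyclicKvCacheSlot(int time_step, int sink_token_length, int max_attention_window)
{
    // Sink tokens keep their slots; later steps wrap into the remaining
    // (max_attention_window - sink_token_length) slots of the attention window.
    return time_step < sink_token_length
        ? time_step
        : sink_token_length + (time_step - sink_token_length) % (max_attention_window - sink_token_length);
}

static_assert(cyclicKvCacheSlot(3, 4, 8) == 3, "sink tokens are never relocated");
static_assert(cyclicKvCacheSlot(7, 4, 8) == 7, "steps inside the first window keep their slots");
static_assert(cyclicKvCacheSlot(8, 4, 8) == 4, "the next step reuses the first non-sink slot");
static_assert(cyclicKvCacheSlot(11, 4, 8) == 7, "wrap-around is modulo the non-sink part of the window");
} // namespace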
void update_indir_cache_kernelLauncher(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids,
    const FinishedState* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim,
    int local_batch_size, int beam_width, int max_seq_len, int max_attention_window, int sink_token_length,
    cudaStream_t stream)
{
    const dim3 block(32);
    // Update indirection steps in [input_length[bb_id], sequence_lengths[bb_id]], inclusive.
    const dim3 grid((max_seq_len + block.x - 1) / block.x, local_batch_size * beam_width);
    update_indir_cache_kernel<<<grid, block, 0, stream>>>(tgt_indir_cache, src_indir_cache, parent_ids, finished,
        sequence_lengths, input_lengths, batch_dim, local_batch_size, beam_width, max_attention_window,
        sink_token_length, max_seq_len);
}

template <typename T>
BaseBeamSearchLayer<T>::BaseBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
    std::shared_ptr<IAllocator> allocator, bool is_free_buffer_after_forward)
    : BaseLayer(stream, std::move(allocator), is_free_buffer_after_forward, nullptr)
    , vocab_size_(vocab_size)
    , vocab_size_padded_(vocab_size_padded)
{
}

template <typename T>
BaseBeamSearchLayer<T>::BaseBeamSearchLayer(BaseBeamSearchLayer<T> const& beam_search_layer)
    : BaseLayer(beam_search_layer)
    , vocab_size_(beam_search_layer.vocab_size_)
    , vocab_size_padded_(beam_search_layer.vocab_size_padded_)
    , topk_softmax_workspace_size_(beam_search_layer.topk_softmax_workspace_size_)
{
}

template <typename T>
BaseBeamSearchLayer<T>::~BaseBeamSearchLayer()
{
    TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
    freeBuffer();
}

template <typename T>
void BaseBeamSearchLayer<T>::freeBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    if (is_allocate_buffer_)
    {
        allocator_->free((void**) (&temperature_buf_));
        allocator_->free((void**) (&min_lengths_buf_));
        allocator_->free((void**) (&repetition_penalty_buf_));
        allocator_->free((void**) (&presence_penalty_buf_));
        allocator_->free((void**) (&frequency_penalty_buf_));
        is_allocate_buffer_ = false;
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void BaseBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    temperature_buf_ = allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false);
    min_lengths_buf_ = allocator_->reMalloc(min_lengths_buf_, sizeof(int) * batch_size, false);
    repetition_penalty_buf_ = allocator_->reMalloc(repetition_penalty_buf_, sizeof(float) * batch_size, false);
    presence_penalty_buf_ = allocator_->reMalloc(presence_penalty_buf_, sizeof(float) * batch_size, false);
    frequency_penalty_buf_ = allocator_->reMalloc(frequency_penalty_buf_, sizeof(float) * batch_size, false);
    is_allocate_buffer_ = true;
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void BaseBeamSearchLayer<T>::setupBase(size_t batch_size, SetupParams const& setupParams)
{
    allocateBuffer(batch_size);
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    // Setup penalties.
    FillBuffers const fillBuffers{batch_size, stream_};

    fillBuffers(setupParams.temperature, 1.0f, mTemperature, temperature_buf_);
    fillBuffers(setupParams.min_length, 1, mMinLength, min_lengths_buf_);

    use_repetition_penalty_ = static_cast<bool>(setupParams.repetition_penalty);
    use_presence_penalty_ = static_cast<bool>(setupParams.presence_penalty);
    use_frequency_penalty_ = static_cast<bool>(setupParams.frequency_penalty);
    if (use_repetition_penalty_)
    {
        fillBuffers(setupParams.repetition_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Repetition),
            mRepetitionPenalty, repetition_penalty_buf_);
    }
    if (use_presence_penalty_)
    {
        fillBuffers(setupParams.presence_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Presence),
            mPresencePenalty, presence_penalty_buf_);
    }
    if (use_frequency_penalty_)
    {
        fillBuffers(setupParams.frequency_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Frequency),
            mFrequencyPenalty, frequency_penalty_buf_);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
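/*
 * Exposition-only sketch of the broadcast pattern that FillBuffers applies in setupBase above:
 * each optional per-request parameter either falls back to a single default for the whole batch or
 * is copied through per request before being uploaded to the per-batch device buffers
 * (temperature_buf_, repetition_penalty_buf_, ...). This is a simplified stand-in, not the
 * FillBuffers implementation; the name broadcastSetupValue and its handling of short inputs are
 * assumptions made purely for illustration, and it assumes <vector>/<optional> are available
 * through the included layer headers.
 */
namespace
{
template <typename U>
std::vector<U> broadcastSetupValue(
    std::optional<std::vector<U>> const& perRequestValues, U defaultValue, size_t batchSize)
{
    // Start with the default for every request, then overwrite with whatever was provided.
    std::vector<U> hostValues(batchSize, defaultValue);
    if (perRequestValues)
    {
        std::copy_n(perRequestValues->begin(), std::min(perRequestValues->size(), batchSize), hostValues.begin());
    }
    return hostValues;
}
} // namespace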
template <typename T>
void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardParams const& params)
{
    TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
    Tensor& output_ids_ptr = outputs.output_ids_ptr;

    const auto batch_size = static_cast<std::int32_t>(output_ids_ptr.shape[0]);
    const auto beam_width = static_cast<std::int32_t>(output_ids_ptr.shape[1]);
    const auto max_seq_len = static_cast<std::int32_t>(output_ids_ptr.shape[2]);

    TLLM_CHECK_WITH_INFO(params.ite == 0, "Pipeline Parallelism is not supported yet !");
    const int ite{params.ite};
    Tensor const& logits = params.logits;
    const auto local_batch_size = logits.shape[0];

    const T* embedding_bias = params.embedding_bias ? params.embedding_bias->template getPtr<const T>() : nullptr;
    auto* end_ids = params.end_ids.template getPtr<const int>();

    auto* const input_lengths = params.input_lengths ? params.input_lengths->template getPtr<const int>() : nullptr;
    int* sequence_length = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<int>() : nullptr;

    invokeAddBiasApplyPenalties(logits.getPtr<T>(), output_ids_ptr.template getPtr<const int*>(),
        outputs.parent_ids_ptr.template getPtr<const int*>(), input_lengths, sequence_length, embedding_bias, ite,
        local_batch_size, batch_size, beam_width, vocab_size_, vocab_size_padded_, end_ids, temperature_buf_,
        mTemperature, repetition_penalty_buf_, presence_penalty_buf_, frequency_penalty_buf_, mRepetitionPenalty,
        mPresencePenalty, mFrequencyPenalty, use_repetition_penalty_, use_presence_penalty_, use_frequency_penalty_,
        min_lengths_buf_, max_seq_len, stream_);
    sync_check_cuda_error();

    invokeSoftMax(outputs, params);

    if (beam_width > 1)
    {
        update_indir_cache_kernelLauncher(outputs.tgt_cache_indirection.template getPtr<int>(),
            params.src_cache_indirection.template getPtr<const int>(),
            outputs.parent_ids_ptr.template getPtr<const int*>(),
            reinterpret_cast<const FinishedState*>(
                outputs.finished->template getPtr<const FinishedState::UnderlyingType>()),
            sequence_length, input_lengths, batch_size, local_batch_size, beam_width, max_seq_len,
            params.max_attention_window, params.sink_token_length, stream_);
        sync_check_cuda_error();
    }
    sync_check_cuda_error();
    if (is_free_buffer_after_forward_)
    {
        freeBuffer();
    }
    sync_check_cuda_error();
}

template class BaseBeamSearchLayer<float>;
template class BaseBeamSearchLayer<half>;

} // namespace layers
} // namespace tensorrt_llm