/*
 * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"

#include <cfloat>      // FLT_MAX
#include <cuda_fp16.h> // half

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::runtime;

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

__device__ bool almostEqual(float a, float b, float epsilon)
{
    return fabs(a - b) < epsilon;
}

template <typename T>
__global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, T const* biases,
    TokenIdType* penaltyWorkspace, TokenIdType const* penaltyWorkspacePrev, float const* temperatures,
    float const* repetitionPenalties, float const* presencePenalties, float const* frequencyPenalties,
    SizeType32 const* promptIgnoreLengths, SizeType32 maxSeqLen, SizeType32 vocabSize, SizeType32 vocabSizePadded,
    TokenIdType const** outputIdsPtr, SizeType32 const** parentIdsPtr, SizeType32 const* inputLengths,
    SizeType32 const* sequenceLengths, SizeType32 const* minLengths, TokenIdType const* endIds,
    SizeType32 const* batchSlots, SizeType32 const* tokensPerStep, FinishedState const* finished)
{
    auto const beamWidth = static_cast<SizeType32>(gridDim.y);
    auto const maxTokensPerStep = static_cast<SizeType32>(gridDim.z);
    auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
    auto const beamIdx = static_cast<SizeType32>(blockIdx.y);
    auto const stepIdx = static_cast<SizeType32>(blockIdx.z);
    auto const batchSlot = batchSlots[batchIdx];

    FinishedState const finishState = finished != nullptr ? finished[batchSlot] : FinishedState::empty();
    if (finishState.isSkipDecoding())
    {
        return;
    }

    auto const batchBeamStepIdx = (batchIdx * beamWidth + beamIdx) * maxTokensPerStep + stepIdx;
    auto const batchSlotBeamIdx = batchSlot * beamWidth + beamIdx;
    auto const inputLen = inputLengths == nullptr ? SizeType32{0} : inputLengths[batchSlotBeamIdx];
    auto const currentStep = sequenceLengths == nullptr ? SizeType32{0} : sequenceLengths[batchSlotBeamIdx];
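    // One CUDA block processes one (batch, beam, step) tuple, as encoded in the
    // grid dimensions. Per-slot penalty parameters are fetched below; a null
    // pointer means "use the decoder default", which disables that penalty.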
    T const* biasBase = biases + batchSlot * vocabSizePadded;
    if (tokensPerStep != nullptr && stepIdx >= tokensPerStep[batchSlot])
    {
        return;
    }

    float invTemperature{layers::DefaultDecodingParams::getTemperature()};
    float repetitionPenalty{layers::DefaultDecodingParams::getRepetitionPenalty()};
    float presencePenalty{layers::DefaultDecodingParams::getPresencePenalty()};
    float frequencyPenalty{layers::DefaultDecodingParams::getFrequencyPenalty()};
    SizeType32 minLength{layers::DefaultDecodingParams::getMinLength()};
    SizeType32 promptIgnoreLength{layers::DefaultDecodingParams::getPromptIgnoreLength()};
    bool accumulateVocab{false};
    bool hasTemperature{false};
    bool hasMinLength{false};
    if (temperatures != nullptr)
    {
        float temperature = temperatures[batchSlot];
        invTemperature = 1.0f / (temperature + 1e-6f);
        hasTemperature |= (!almostEqual(temperature, layers::DefaultDecodingParams::getTemperature(), 1e-9));
    }
    if (repetitionPenalties != nullptr)
    {
        repetitionPenalty = repetitionPenalties[batchSlot];
        accumulateVocab
            |= (!almostEqual(repetitionPenalty, layers::DefaultDecodingParams::getRepetitionPenalty(), 1e-9));
    }
    if (presencePenalties != nullptr)
    {
        presencePenalty = presencePenalties[batchSlot];
        accumulateVocab |= (!almostEqual(presencePenalty, layers::DefaultDecodingParams::getPresencePenalty(), 1e-9));
    }
    if (frequencyPenalties != nullptr)
    {
        frequencyPenalty = frequencyPenalties[batchSlot];
        accumulateVocab |= (!almostEqual(frequencyPenalty, layers::DefaultDecodingParams::getFrequencyPenalty(), 1e-9));
    }
    if (minLengths != nullptr)
    {
        minLength = minLengths[batchSlot];
        hasMinLength |= (minLength > 0);
    }
    if (promptIgnoreLengths != nullptr)
    {
        promptIgnoreLength = min(promptIgnoreLengths[batchSlot], inputLen);
    }

    // Initialize or update the number of occurrences of tokens
    if (accumulateVocab)
    {
        penaltyWorkspace += batchBeamStepIdx * 2 * vocabSize;
        if (currentStep <= inputLen)
        { // Context phase
            for (auto index = static_cast<SizeType32>(threadIdx.x); index < 2 * vocabSize;
                 index += static_cast<SizeType32>(blockDim.x))
            {
                penaltyWorkspace[index] = 0;
            }
            __syncthreads();
            for (auto step = static_cast<SizeType32>(threadIdx.x); step < promptIgnoreLength;
                 step += static_cast<SizeType32>(blockDim.x))
            {
                // All beams in the context phase are identical
                auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
                if (penaltyIndex < vocabSize)
                {
                    penaltyWorkspace[penaltyIndex] = 1;
                }
            }
            for (auto step = promptIgnoreLength + static_cast<SizeType32>(threadIdx.x); step < inputLen;
                 step += static_cast<SizeType32>(blockDim.x))
            {
                // All beams in the context phase are identical
                auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
                if (penaltyIndex < vocabSize)
                {
                    atomicAdd(&penaltyWorkspace[penaltyIndex + vocabSize], 1);
                }
            }
        }
        else
        { // Generation phase
            if (beamWidth > 1)
            {
                auto parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
                penaltyWorkspacePrev
                    += ((batchIdx * beamWidth + parentBeam) * maxTokensPerStep + stepIdx) * (2 * vocabSize);
                for (auto index = static_cast<SizeType32>(threadIdx.x); index < 2 * vocabSize;
                     index += static_cast<SizeType32>(blockDim.x))
                {
                    penaltyWorkspace[index] = penaltyWorkspacePrev[index];
                }
                __syncthreads();
            }
            if (threadIdx.x == 0)
            {
                auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
                if (penaltyIndex < vocabSize)
                {
                    penaltyWorkspace[penaltyIndex + vocabSize] += 1;
                }
            }
        }
        __syncthreads();
    }

    // Apply bias and penalties
    auto const inLogitsPtr = inputLogits[batchIdx] + (beamIdx * maxTokensPerStep + stepIdx) * vocabSizePadded;
    auto outLogitsPtr = outputLogits + batchBeamStepIdx * vocabSizePadded;
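    // Workspace layout per (batch, beam, step) is 2 * vocabSize entries: the first
    // vocabSize flag tokens seen in the ignored prompt prefix, the next vocabSize
    // count occurrences in the penalized part of the sequence. With occurrence
    // count n, the loop below applies
    //   repetition: logit = logit < 0 ? logit * repetitionPenalty : logit / repetitionPenalty  (token seen anywhere)
    //   presence:   logit -= presencePenalty                                                   (n > 0)
    //   frequency:  logit -= frequencyPenalty * n                                              (n > 0)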
    T const MASK_VAL = (std::is_same<T, half>::value) ? -HALF_FLT_MAX : -FLT_MAX;
    for (auto index = static_cast<SizeType32>(threadIdx.x); index < vocabSizePadded;
         index += static_cast<SizeType32>(blockDim.x))
    {
        if (index < vocabSize)
        {
            auto logit = static_cast<float>(inLogitsPtr[index]);
            // Bias
            if (biases != nullptr)
            {
                logit += static_cast<float>(biasBase[index]);
            }
            // Temperature
            if (hasTemperature)
            {
                logit *= invTemperature;
            }
            if (accumulateVocab)
            {
                SizeType32 numOccurences = penaltyWorkspace[index + vocabSize];
                SizeType32 ifPresenceInFullSeq = numOccurences | penaltyWorkspace[index];
                if (ifPresenceInFullSeq > 0)
                {
                    // Repetition
                    if (repetitionPenalties != nullptr)
                    {
                        logit = logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty;
                    }
                }
                if (numOccurences > 0)
                {
                    // Presence
                    if (presencePenalties != nullptr)
                    {
                        logit -= presencePenalty;
                    }
                    // Frequency
                    if (frequencyPenalties != nullptr)
                    {
                        logit -= frequencyPenalty * numOccurences;
                    }
                }
            }
            // Clamp to prevent overflow when casting back to T
            if (logit > static_cast<float>(-MASK_VAL))
            {
                logit = static_cast<float>(-MASK_VAL);
            }
            else if (logit < static_cast<float>(MASK_VAL))
            {
                logit = static_cast<float>(MASK_VAL);
            }
            outLogitsPtr[index] = logit;
        }
        else
        {
            outLogitsPtr[index] = MASK_VAL;
        }
    }
    if (hasMinLength)
    {
        __syncthreads();
        // If the current generation length is too short, make sure EOS doesn't have a high probability.
        // This check is not needed when endId is already -1, as generation won't stop on EOS anyway.
        if ((threadIdx.x == 0) && (currentStep - inputLen < minLength) && endIds[batchSlot] > -1)
        {
            outLogitsPtr[endIds[batchSlot]] = MASK_VAL;
        }
    }
}

template <typename T>
void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams<T> const& params)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    dim3 block(512);
    dim3 grid(params.batchSize, params.beamWidth, params.maxTokensPerStep);
    batchApplyPenalty<T><<<grid, block, 0, params.stream>>>(params.inputLogits, params.outputLogits, params.biases,
        params.penaltyWorkspace, params.penaltyWorkspacePrev, params.temperatures, params.repetitionPenalties,
        params.presencePenalties, params.frequencyPenalties, params.promptIgnoreLengths, params.maxSeqLen,
        params.vocabSize, params.vocabSizePadded, params.outputIdsPtr, params.parentIdsPtr, params.inputLengths,
        params.sequenceLengths, params.minLengths, params.endIds, params.batchSlots, params.tokensPerStep,
        params.finished);
}

template void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams<float> const& params);
template void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams<half> const& params);

} // namespace kernels

TRTLLM_NAMESPACE_END
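// Usage sketch (illustrative; field names follow InvokeBatchApplyPenaltyParams as
// used above): the caller populates the params struct with device buffers sized
// [batchSize, beamWidth, maxTokensPerStep, vocabSizePadded] for output logits and
// 2 * vocabSize workspace entries per (batch, beam, step), then calls
//
//     tensorrt_llm::kernels::invokeBatchApplyPenalty<float>(params);
//
// on the stream carried in params.stream.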