/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): this translation unit was recovered from a copy in which the line breaks and every
// `<...>` span (template parameter lists, explicit template arguments, the system include) had been
// stripped. Those tokens were reconstructed from how each value is used in this file; confirm them
// against penaltyLayer.h and the runtime headers before merging.

#include "penaltyLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/kernels/penaltyTypes.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"

// NOTE(review): the name of the original system header was lost; these cover the std:: names used below
// (std::swap, std::shared_ptr / std::dynamic_pointer_cast, std::make_pair).
#include <algorithm>
#include <memory>
#include <utility>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::layers
{

namespace
{

//! \brief Number of penalty-workspace elements that belong to one batch slot.
//! \details Equals maxDecodingTokens * beamWidth * vocabSize * 2. The product is evaluated in size_t so it
//! cannot wrap when the individual factors are 32-bit integers (presumably SizeType32 -- confirm).
//! \param decoderDomain Domain that provides maxDecodingTokens and vocabSize.
//! \param beamWidth Beam width the workspace is configured for.
size_t penaltyWorkspaceSizePerSlot(DecoderDomain const& decoderDomain, SizeType32 beamWidth)
{
    return static_cast<size_t>(decoderDomain.getMaxDecodingTokens()) * static_cast<size_t>(beamWidth)
        * static_cast<size_t>(decoderDomain.getVocabSize()) * 2;
}

} // namespace

//! \brief Returns the workspace size computed in allocateBuffer().
template <typename T>
size_t PenaltyLayer<T>::getWorkspaceSize() const noexcept
{
    return mWorkspaceSize;
}

//! \brief Stores the decoding mode and runs initialize().
//! \param mode Decoding mode; selects which penalty buffers are allocated.
//! \param decoderDomain Maximum batch size, beam width, vocab size, ... of the decoder.
//! \param bufferManager Manager used for every host/device allocation of this layer.
template <typename T>
PenaltyLayer<T>::PenaltyLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
    std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
    , mDecodingMode(mode)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    initialize();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Allocates the per-request buffers and resets the step counters.
//! \details The penalty workspace is allocated here only when the decoding mode is not "auto"; otherwise
//! mConfiguredBeamWidth stays at -1 and the allocation is deferred to the first setup() call.
template <typename T>
void PenaltyLayer<T>::initialize()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    mCyclicStep = 0;
    mRuntimeMaxSeqLen = 0;
    mConfiguredBeamWidth = -1;

    if (!mDecodingMode.isAuto())
    {
        mConfiguredBeamWidth = mDecoderDomain.getBeamWidth();

        allocateWorkspace();
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Allocates the device workspace used by the occurrence penalties.
//! \details Nothing is allocated unless the decoding mode uses an occurrence penalty. A second workspace of
//! the same size is allocated in beam-search mode; forwardAsync() swaps the two after every step.
template <typename T>
void PenaltyLayer<T>::allocateWorkspace()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    if (mDecodingMode.isUseOccurrencePenalty())
    {
        // Evaluated in size_t: batchSize * maxDecodingTokens * beamWidth * vocabSize * 2 can exceed the
        // range of a 32-bit integer for large batches / vocabularies.
        auto const workspaceSize = static_cast<size_t>(mDecoderDomain.getBatchSize())
            * penaltyWorkspaceSizePerSlot(mDecoderDomain, mConfiguredBeamWidth);
        mPenaltyWorkspaceDevice = mBufferManager->gpu(workspaceSize, nvinfer1::DataType::kINT32);

        if (mDecodingMode.isBeamSearch())
        {
            mPenaltyWorkspacePrevDevice = mBufferManager->gpu(workspaceSize, nvinfer1::DataType::kINT32);
        }
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Allocates the pinned host buffers, the mode-dependent device buffers and computes mWorkspaceSize.
//! \details Host-side penalty buffers are always allocated with one entry per batch slot; the matching device
//! buffers are allocated only for penalties enabled by the decoding mode. mLogitsPtrsHost starts with an empty
//! shape and is reshaped lazily in forwardAsync().
template <typename T>
void PenaltyLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mLogitsPtrsHost = mBufferManager->pinnedPool(ITensor::makeShape({}), TRTDataType<T*>::value);

    auto const batchSizeShape = ITensor::makeShape({mDecoderDomain.getBatchSize()});

    mTemperature = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mRepetitionPenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mPresencePenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mFrequencyPenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mMinLength = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);
    mPromptIgnoreLength = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);

    if (mDecodingMode.isUseTemperature())
    {
        mTemperatureDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kFLOAT);
    }
    if (mDecodingMode.isUseRepetitionPenalty())
    {
        mRepetitionPenaltyDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kFLOAT);
    }
    if (mDecodingMode.isUsePresencePenalty())
    {
        mPresencePenaltyDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kFLOAT);
    }
    if (mDecodingMode.isUseFrequencyPenalty())
    {
        mFrequencyPenaltyDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kFLOAT);
    }
    if (mDecodingMode.isUseMinLength())
    {
        mMinLengthDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kINT32);
    }
    if (mDecodingMode.isUseOccurrencePenalty())
    {
        mPromptIgnoreLengthDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kINT32);
    }

    // forwardAsync() mirrors one row of logits pointers (one pointer per batch slot) into the workspace.
    auto const logitsPtrDeviceDesc = std::make_pair(batchSizeShape, TRTDataType<T*>::value);
    mWorkspaceSize = DecodingLayerWorkspace::calculateRequiredWorkspaceSize(logitsPtrDeviceDesc);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Stores the penalty parameters of newly scheduled requests and clears their penalty workspace.
//! \param batchSize Number of requests being set up.
//! \param beamWidth Runtime beam width; only consulted when the layer was created in "auto" mode.
//! \param batchSlots Host tensor with the slot index of every request being set up.
//! \param baseSetupParams Setup parameters; must be castable to DynamicDecodeSetupParams with penaltyParams set.
//! \param workspace Not used by this function.
template <typename T>
void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(PenaltyLayer_setup);

    auto setupParams = std::dynamic_pointer_cast<DynamicDecodeSetupParams>(baseSetupParams);
    // The cast yields nullptr for a foreign parameter type; fail with a message instead of dereferencing it.
    TLLM_CHECK_WITH_INFO(setupParams != nullptr, "setup params of PenaltyLayer have an unexpected type");

    if (mConfiguredBeamWidth == -1)
    {
        // This code is left only for Python runtime
        // In C++ runtime given maxBeamWidth should always be equal to the runtime beamWidth
        TLLM_CHECK(mDecodingMode.isAuto());
        mConfiguredBeamWidth = beamWidth;
        mDecodingMode
            = mConfiguredBeamWidth == 1 ? executor::DecodingMode::TopKTopP() : executor::DecodingMode::BeamSearch();
        allocateWorkspace();
    }

    // Setup penalties.
    FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mBufferManager};

    auto const& penaltyParams = setupParams->penaltyParams;
    TLLM_CHECK_WITH_INFO(penaltyParams, "penaltyParams for setup is not set");

    bool const useTemperature = mDecodingMode.isUseTemperature() && penaltyParams->temperature.has_value();
    bool const useRepetitionPenalty
        = mDecodingMode.isUseRepetitionPenalty() && penaltyParams->repetitionPenalty.has_value();
    bool const usePresencePenalty = mDecodingMode.isUsePresencePenalty() && penaltyParams->presencePenalty.has_value();
    bool const useFrequencyPenalty
        = mDecodingMode.isUseFrequencyPenalty() && penaltyParams->frequencyPenalty.has_value();
    bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams->minLength.has_value();
    bool const usePromptIgnoreLength
        = mDecodingMode.isUseOccurrencePenalty() && penaltyParams->promptIgnoreLength.has_value();

    // FIXME: once one of the requests has some penalty, we will always have to compute it.
    // To avoid that we need to scan through all active requests at each iteration.
    mUseTemperature |= useTemperature;
    mUseRepetitionPenalty |= useRepetitionPenalty;
    mUsePresencePenalty |= usePresencePenalty;
    mUseFrequencyPenalty |= useFrequencyPenalty;
    mUseMinLength |= useMinLength;
    mUsePromptIgnoreLength |= usePromptIgnoreLength;

    if (mUseTemperature)
    {
        fillBuffers(penaltyParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature,
            mTemperatureDevice, batchSlots, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty");
    }
    if (mUseRepetitionPenalty)
    {
        fillBuffers(penaltyParams->repetitionPenalty, DefaultDecodingParams::getRepetitionPenalty(),
            mRepetitionPenalty, mRepetitionPenaltyDevice, batchSlots,
            getLimitsPenalty(DecodingPenaltyType::Repetition), "repetition penalty");
    }
    if (mUsePresencePenalty)
    {
        fillBuffers(penaltyParams->presencePenalty, DefaultDecodingParams::getPresencePenalty(), mPresencePenalty,
            mPresencePenaltyDevice, batchSlots, getLimitsPenalty(DecodingPenaltyType::Presence), "presence penalty");
    }
    if (mUseFrequencyPenalty)
    {
        fillBuffers(penaltyParams->frequencyPenalty, DefaultDecodingParams::getFrequencyPenalty(), mFrequencyPenalty,
            mFrequencyPenaltyDevice, batchSlots, getLimitsPenalty(DecodingPenaltyType::Frequency),
            "frequency penalty");
    }
    if (mUseMinLength)
    {
        fillBuffers(penaltyParams->minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice,
            batchSlots, getLimitsPenalty(DecodingPenaltyType::MinLength), "min length");
    }
    if (mUsePromptIgnoreLength)
    {
        fillBuffers(penaltyParams->promptIgnoreLength, DefaultDecodingParams::getPromptIgnoreLength(),
            mPromptIgnoreLength, mPromptIgnoreLengthDevice, batchSlots,
            getLimitsPenalty(DecodingPenaltyType::PromptIgnoreLength), "prompt ignore length");
    }

    // Reset penalty workspace
    if (batchSize > 0 && (mPenaltyWorkspaceDevice || mPenaltyWorkspacePrevDevice))
    {
        // Same per-slot size as in allocateWorkspace(); size_t keeps `slot * sizePerSlot` from overflowing.
        auto const workspaceSizePerSlot = penaltyWorkspaceSizePerSlot(mDecoderDomain, mConfiguredBeamWidth);
        auto const* const batchSlotsHost = runtime::bufferCast<SizeType32>(*batchSlots);

        // Zeroes the region of `penaltyWorkspace` that belongs to `slot`; does nothing for an unallocated buffer.
        auto const zeroSlot = [&](auto const& penaltyWorkspace, SizeType32 slot)
        {
            if (penaltyWorkspace)
            {
                auto deviceSlice = runtime::IBuffer::slice(
                    penaltyWorkspace, static_cast<size_t>(slot) * workspaceSizePerSlot, workspaceSizePerSlot);
                mBufferManager->setZero(*deviceSlice);
            }
        };

        for (SizeType32 bi = 0; bi < batchSize; ++bi)
        {
            auto const batchSlot = batchSlotsHost[bi];
            zeroSlot(mPenaltyWorkspaceDevice, batchSlot);
            zeroSlot(mPenaltyWorkspacePrevDevice, batchSlot);
        }
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Applies the configured penalties to the logits of the current step.
//! \details Collects one logits pointer per request, mirrors the pointers into the workspace, optionally runs
//! bias + softmax (beam width > 1 only) and then invokeBatchApplyPenalty(). On return params->logits is
//! replaced by a view of the workspace's runtime logits buffer that the kernel wrote to.
//! \param baseOutputs Decoding outputs (output ids, parent ids, sequence lengths).
//! \param baseInputs Decoding inputs; must be castable to DecodingInputs.
//! \param workspace Workspace that provides the device logits buffer and the device batch slots.
template <typename T>
void PenaltyLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& baseOutputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(PenaltyLayer_forwardAsync);

    auto outputs = std::dynamic_pointer_cast<BaseDecodingOutputs>(baseOutputs);
    auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
    // Both casts yield nullptr on a type mismatch; fail early instead of dereferencing a null pointer.
    TLLM_CHECK_WITH_INFO(outputs != nullptr, "outputs of PenaltyLayer have an unexpected type");
    TLLM_CHECK_WITH_INFO(params != nullptr, "inputs of PenaltyLayer have an unexpected type");

    auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);
    auto const localBatchSize = localDecoderDomain.getBatchSize();
    auto const maxSeqLen = outputs->outputIds->getDimension<-1>();

    // First call: give the host pointer table its real shape [maxSeqLen, maxBatchSize].
    if (mLogitsPtrsHost->data() == nullptr)
    {
        mLogitsPtrsHost->reshape(
            ITensor::makeShape({static_cast<int32_t>(maxSeqLen), static_cast<int32_t>(mDecoderDomain.getBatchSize())}));
        mRuntimeMaxSeqLen = maxSeqLen;
    }

    // Guard the modulo below against a zero divisor.
    TLLM_CHECK_WITH_INFO(mRuntimeMaxSeqLen > 0, "runtime max sequence length must be positive");

    // Every step uses its own row of the table; the row index wraps around after mRuntimeMaxSeqLen steps.
    mCyclicStep = mCyclicStep % mRuntimeMaxSeqLen;

    TensorPtr logitsPtrsHost = ITensor::slice(mLogitsPtrsHost, mCyclicStep, 1);
    logitsPtrsHost->squeeze(0);
    auto logitsPtrsHostData = bufferCast<T const*>(*logitsPtrsHost);

    if (params->logitsVec)
    {
        // Loop-invariant check, done once instead of once per request.
        TLLM_CHECK_WITH_INFO(params->logitsVec->size() == static_cast<size_t>(localBatchSize),
            "Logits vector size (%lu) is not equal to the batchSize (%d)", params->logitsVec->size(),
            localBatchSize);
        for (SizeType32 bi = 0; bi < localBatchSize; bi++)
        {
            logitsPtrsHostData[bi] = bufferCastOrNull<T>(params->logitsVec.value()[bi]);
        }
    }
    else
    {
        for (SizeType32 bi = 0; bi < localBatchSize; bi++)
        {
            TensorConstPtr logitsForBatchIndex = ITensor::slice(params->logits.value(), ITensor::makeShape({bi}));
            logitsPtrsHostData[bi] = bufferCastOrNull<T>(logitsForBatchIndex);
        }
    }

    auto const* inputLengths = bufferCastOrNull<SizeType32>(params->inputLengths);
    auto embeddingBias = bufferCastOrNull<T>(params->embeddingBias);
    auto const* batchSlotsHostPtr = bufferCast<SizeType32>(*params->batchSlots);

// Yields the device buffer of a penalty, or nullptr when the penalty was never requested or when every
// active slot holds the default value, so that the kernel receives a null pointer for it.
#define GET_PENALTIES(capital_name, type)                                                                              \
    ((mUse##capital_name                                                                                               \
         && !allOfBatchSlots(batchSlotsHostPtr, bufferCast<type>(*m##capital_name), localBatchSize,                    \
             DefaultDecodingParams::get##capital_name()))                                                              \
            ? m##capital_name##Device                                                                                  \
            : nullptr)

    auto temperatures = GET_PENALTIES(Temperature, float);
    auto repetitionPenalties = GET_PENALTIES(RepetitionPenalty, float);
    auto presencePenalties = GET_PENALTIES(PresencePenalty, float);
    auto frequencyPenalties = GET_PENALTIES(FrequencyPenalty, float);
    auto minLengths = GET_PENALTIES(MinLength, SizeType32);
    auto promptIgnoreLengths = GET_PENALTIES(PromptIgnoreLength, SizeType32);

#undef GET_PENALTIES

    auto* const tokensPerStep = bufferCastOrNull<SizeType32>(params->curTokensPerStep);

    InvokeBatchApplyPenaltyParams<T> penaltyParams{};

    // Copy this step's logits pointers into the workspace.
    TensorPtr logitsPtrsHostSlice = ITensor::slice(logitsPtrsHost, 0, localBatchSize);
    auto [logitsPtrsDeviceSlice] = workspace->mirrorInWorkspace(logitsPtrsHostSlice);
    auto runtimeLogits = workspace->getDeviceRuntimeLogits();

    // NOTE(review): the pointer table is read through an int64 view and reinterpreted as T const* const*
    // (reconstructed -- confirm that TRTDataType<T*> maps to a 64-bit integer type).
    penaltyParams.inputLogits = reinterpret_cast<T const* const*>(bufferCast<int64_t>(*logitsPtrsDeviceSlice));
    penaltyParams.outputLogits = bufferCast<T>(*runtimeLogits);
    penaltyParams.biases = embeddingBias;
    penaltyParams.penaltyWorkspace = bufferCastOrNull<TokenIdType>(mPenaltyWorkspaceDevice);
    penaltyParams.penaltyWorkspacePrev = bufferCastOrNull<TokenIdType>(mPenaltyWorkspacePrevDevice);
    penaltyParams.temperatures = bufferCastOrNull<float>(temperatures);
    penaltyParams.repetitionPenalties = bufferCastOrNull<float>(repetitionPenalties);
    penaltyParams.presencePenalties = bufferCastOrNull<float>(presencePenalties);
    penaltyParams.frequencyPenalties = bufferCastOrNull<float>(frequencyPenalties);
    penaltyParams.batchSize = localBatchSize;
    penaltyParams.beamWidth = localDecoderDomain.getBeamWidth();
    penaltyParams.maxSeqLen = maxSeqLen;
    penaltyParams.vocabSize = mDecoderDomain.getVocabSize();
    penaltyParams.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
    penaltyParams.outputIdsPtr = bufferCast<TokenIdType const*>(*outputs->outputIdsPtr);
    penaltyParams.parentIdsPtr = bufferCast<SizeType32 const*>(*outputs->parentIdsPtr);
    penaltyParams.inputLengths = inputLengths;
    penaltyParams.sequenceLengths = bufferCast<SizeType32>(*outputs->sequenceLength.value());
    penaltyParams.minLengths = bufferCastOrNull<SizeType32>(minLengths);
    penaltyParams.promptIgnoreLengths = bufferCastOrNull<SizeType32>(promptIgnoreLengths);
    penaltyParams.endIds = bufferCast<TokenIdType>(*params->endIds);
    penaltyParams.batchSlots = workspace->getDeviceBatchSlotsPtr();
    penaltyParams.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
    penaltyParams.tokensPerStep = tokensPerStep;
    penaltyParams.finished = (params->finished)
        ? reinterpret_cast<FinishedState const*>(
            bufferCast<FinishedState::UnderlyingType>(*params->finished.value()))
        : nullptr;
    penaltyParams.stream = getStream();

    if (penaltyParams.beamWidth > 1)
    {
        // Convert logits into logProbs before penalties, only necessary in Beam-Search.
        BiasSoftmaxParams<T> biasSoftmaxParams;
        biasSoftmaxParams.logitsPtrs = const_cast<T**>(penaltyParams.inputLogits);
        biasSoftmaxParams.bias = penaltyParams.biases;
        biasSoftmaxParams.endIds = penaltyParams.endIds;
        biasSoftmaxParams.batchSlots = penaltyParams.batchSlots;
        biasSoftmaxParams.batchSize = penaltyParams.batchSize;
        biasSoftmaxParams.maxBatchSize = mDecoderDomain.getBatchSize();
        biasSoftmaxParams.maxBeamWidth = penaltyParams.beamWidth;
        biasSoftmaxParams.vocabSize = penaltyParams.vocabSize;
        biasSoftmaxParams.vocabSizePadded = penaltyParams.vocabSizePadded;
        biasSoftmaxParams.skipSoftMax = false;
        biasSoftmaxParams.batchSlotsLogits = penaltyParams.batchSlots != nullptr;
        biasSoftmaxParams.checkParams();

        invokeAddBiasSoftMax(biasSoftmaxParams, penaltyParams.stream);
    }

    invokeBatchApplyPenalty(penaltyParams);
    sync_check_cuda_error(penaltyParams.stream);

    mCyclicStep += 1;

    // Expose the penalized logits to the caller / following layers through params->logits.
    auto const logitsShape = ITensor::makeShape({localBatchSize, mDecoderDomain.getMaxDecodingTokens(),
        localDecoderDomain.getBeamWidth(), mDecoderDomain.getVocabSizePadded()});
    params->logits = ITensor::view(runtimeLogits, logitsShape);

    // In beam search the workspace written in this step becomes the "previous" workspace of the next step.
    if (mDecodingMode.isBeamSearch())
    {
        std::swap(mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

// NOTE(review): the two instantiation types were stripped; float and half are assumed -- confirm.
template class PenaltyLayer<float>;
template class PenaltyLayer<half>;

} // namespace tensorrt_llm::layers