/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "penaltyLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::layers
{

template <typename T>
PenaltyLayer<T>::PenaltyLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
    std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
    , mDecodingMode(mode)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    initialize();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
size_t PenaltyLayer<T>::getWorkspaceSize() const noexcept
{
    return mPenaltyWorkspaceDevice->getSizeInBytes();
}

template <typename T>
void PenaltyLayer<T>::initialize()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    mCyclicStep = 0;
    mRuntimeMaxSeqLen = 0;
    mConfiguredBeamWidth = -1;

    if (!mDecodingMode.isAuto())
    {
        mConfiguredBeamWidth = mDecoderDomain.getBeamWidth();
        allocateWorkspace();
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void PenaltyLayer<T>::allocateWorkspace()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    if (mDecodingMode.isUseOccurrencePenalty())
    {
        // Sized as one int32 entry per (request, decoding token, beam, vocab entry),
        // used by the penalty kernel to count token occurrences.
        auto const workspaceSize = mDecoderDomain.getBatchSize() * mDecoderDomain.getMaxDecodingTokens()
            * mConfiguredBeamWidth * mDecoderDomain.getVocabSize();
        mPenaltyWorkspaceDevice = mBufferManager->gpu(workspaceSize, nvinfer1::DataType::kINT32);

        if (mDecodingMode.isBeamSearch())
        {
            mPenaltyWorkspacePrevDevice = mBufferManager->gpu(workspaceSize, nvinfer1::DataType::kINT32);
        }
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

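// Note on the buffer layout below: allocateBuffer() reserves a pinned host staging buffer per penalty
// (filled in setup()) and a device mirror only for the penalties enabled by the decoding mode, so
// disabled penalties cost no GPU memory. mRuntimeLogitsDevice receives the penalized logits written by
// forwardAsync(), sized for the worst case [batchSize, maxDecodingTokens, beamWidth, vocabSizePadded].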
template <typename T>
void PenaltyLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mLogitsPtrsHost = mBufferManager->pinnedPool(ITensor::makeShape({}), TRTDataType<T*>::value);
    mLogitsPtrsDevice
        = mBufferManager->gpu(ITensor::makeShape({mDecoderDomain.getBatchSize()}), TRTDataType<T*>::value);

    auto const batchSizeShape = ITensor::makeShape({mDecoderDomain.getBatchSize()});
    mTemperature = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mRepetitionPenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mPresencePenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mFrequencyPenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
    mMinLength = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);

    if (mDecodingMode.isUseTemperature())
    {
        mTemperatureDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kFLOAT);
    }
    if (mDecodingMode.isUseRepetitionPenalty())
    {
        mRepetitionPenaltyDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kFLOAT);
    }
    if (mDecodingMode.isUsePresencePenalty())
    {
        mPresencePenaltyDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kFLOAT);
    }
    if (mDecodingMode.isUseFrequencyPenalty())
    {
        mFrequencyPenaltyDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kFLOAT);
    }
    if (mDecodingMode.isUseMinLength())
    {
        mMinLengthDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kINT32);
    }

    auto const runtimeLogitsDeviceSize = mDecoderDomain.getBatchSize() * mDecoderDomain.getMaxDecodingTokens()
        * mDecoderDomain.getBeamWidth() * mDecoderDomain.getVocabSizePadded();
    mRuntimeLogitsDevice = mBufferManager->gpu(ITensor::makeShape({runtimeLogitsDeviceSize}), TRTDataType<T>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

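// setup() validates each per-request penalty value against its getLimitsPenalty() bounds and scatters
// it into the pinned host buffer at the position given by batchSlots, followed by an async copy to the
// device mirror (see FillBuffers in layerUtils.h). Illustrative example (hypothetical caller): a request
// wanting temperature-only decoding would set penaltyParams->temperature and leave the other optionals
// empty, so only mUseTemperature flips to true below.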
template <typename T>
void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, BufferConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto setupParams = std::dynamic_pointer_cast<DynamicDecodeSetupParams>(baseSetupParams);

    if (mConfiguredBeamWidth == -1)
    {
        // This code is left only for the Python runtime.
        // In the C++ runtime, the given maxBeamWidth should always equal the runtime beamWidth.
        TLLM_CHECK(mDecodingMode.isAuto());
        mConfiguredBeamWidth = beamWidth;
        mDecodingMode = mConfiguredBeamWidth == 1 ? executor::DecodingMode::TopKTopP()
                                                  : executor::DecodingMode::BeamSearch();
        allocateWorkspace();
    }

    // Setup penalties.
    FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mBufferManager};

    auto const& penaltyParams = setupParams->penaltyParams;
    TLLM_CHECK_WITH_INFO(penaltyParams, "penaltyParams for setup is not set");

    bool const useTemperature = mDecodingMode.isUseTemperature() && penaltyParams->temperature.has_value();
    bool const useRepetitionPenalty
        = mDecodingMode.isUseRepetitionPenalty() && penaltyParams->repetitionPenalty.has_value();
    bool const usePresencePenalty = mDecodingMode.isUsePresencePenalty() && penaltyParams->presencePenalty.has_value();
    bool const useFrequencyPenalty
        = mDecodingMode.isUseFrequencyPenalty() && penaltyParams->frequencyPenalty.has_value();
    bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams->minLength.has_value();

    // FIXME(nkorobov): once one of the requests enables some penalty, we always have to compute it from
    // then on. To avoid that, we would need to scan through all active requests at each iteration.
    mUseTemperature |= useTemperature;
    mUseRepetitionPenalty |= useRepetitionPenalty;
    mUsePresencePenalty |= usePresencePenalty;
    mUseFrequencyPenalty |= useFrequencyPenalty;
    mUseMinLength |= useMinLength;

    if (mUseTemperature)
    {
        fillBuffers(penaltyParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature,
            mTemperatureDevice, batchSlots, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty");
    }
    if (mUseRepetitionPenalty)
    {
        fillBuffers(penaltyParams->repetitionPenalty, DefaultDecodingParams::getRepetitionPenalty(),
            mRepetitionPenalty, mRepetitionPenaltyDevice, batchSlots,
            getLimitsPenalty(DecodingPenaltyType::Repetition), "repetition penalty");
    }
    if (mUsePresencePenalty)
    {
        fillBuffers(penaltyParams->presencePenalty, DefaultDecodingParams::getPresencePenalty(), mPresencePenalty,
            mPresencePenaltyDevice, batchSlots, getLimitsPenalty(DecodingPenaltyType::Presence), "presence penalty");
    }
    if (mUseFrequencyPenalty)
    {
        fillBuffers(penaltyParams->frequencyPenalty, DefaultDecodingParams::getFrequencyPenalty(), mFrequencyPenalty,
            mFrequencyPenaltyDevice, batchSlots, getLimitsPenalty(DecodingPenaltyType::Frequency), "frequency penalty");
    }
    if (mUseMinLength)
    {
        fillBuffers(penaltyParams->minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice,
            batchSlots, getLimitsPenalty(DecodingPenaltyType::MinLength), "min length");
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

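// forwardAsync() gathers one logits pointer per local batch entry into a cyclic pinned buffer, copies
// the pointers to the device, and launches a single batched kernel that applies temperature, the
// repetition/presence/frequency penalties, and the min-length constraint in one pass, writing the
// result to mRuntimeLogitsDevice. A penalty is passed as nullptr (and skipped by the kernel) when every
// active slot still holds its default value; for example, GET_PENALTIES(Temperature, float) expands to
//
//     (mUseTemperature
//         && !allOfBatchSlots(batchSlotsHostPtr, bufferCast<float>(*mTemperature),
//                localDecoderDomain.getBatchSize(), DefaultDecodingParams::getTemperature()))
//         ? mTemperatureDevice
//         : nullptr;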
template <typename T>
void PenaltyLayer<T>::forwardAsync(
    std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto outputs = std::dynamic_pointer_cast<BaseDecodingOutputs>(baseOutputs);
    auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);

    auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);
    auto const maxSeqLen = outputs->outputIds->getDimension<-1>();
    auto batchSlots = bufferCast<SizeType32>(*params->batchSlots);

    if (!mLogitsPtrsHost->data())
    {
        mLogitsPtrsHost = mBufferManager->pinnedPool(ITensor::makeShape({static_cast<SizeType32>(maxSeqLen),
                                                         static_cast<SizeType32>(mDecoderDomain.getBatchSize())}),
            TRTDataType<T*>::value);
        mRuntimeMaxSeqLen = maxSeqLen;
    }

    // Cyclic buffer of logits pointer rows: wrap around once mRuntimeMaxSeqLen steps have been recorded.
    mCyclicStep = mCyclicStep % mRuntimeMaxSeqLen;
    TensorPtr logitsPtrsHost = ITensor::slice(mLogitsPtrsHost, mCyclicStep, 1);
    logitsPtrsHost->squeeze(0);
    auto logitsPtrsHostData = bufferCast<T*>(*logitsPtrsHost);
    for (SizeType32 bi = 0; bi < localDecoderDomain.getBatchSize(); bi++)
    {
        if (params->logitsVec)
        {
            TLLM_CHECK_WITH_INFO(params->logitsVec->size() == localDecoderDomain.getBatchSize(),
                "Logits vector size (%lu) is not equal to the batchSize (%d)", params->logitsVec->size(),
                localDecoderDomain.getBatchSize());
            logitsPtrsHostData[bi] = bufferCastOrNull<T>(params->logitsVec.value()[bi]);
        }
        else
        {
            TensorPtr logitsForBatchIndex = ITensor::slice(params->logits.value(), ITensor::makeShape({bi}));
            auto const ptrToLogitsForBatchIndex = bufferCastOrNull<T>(logitsForBatchIndex);
            logitsPtrsHostData[bi] = ptrToLogitsForBatchIndex;
        }
    }

    auto inputLengths = bufferCastOrNull<SizeType32>(params->inputLengths);
    auto embeddingBias = bufferCastOrNull<T>(params->embeddingBias);
    auto batchSlotsHostPtr = bufferCast<SizeType32>(*params->batchSlots);

#define GET_PENALTIES(capital_name, type)                                                                              \
    (mUse##capital_name                                                                                                \
        && !allOfBatchSlots(batchSlotsHostPtr, bufferCast<type>(*m##capital_name), localDecoderDomain.getBatchSize(),  \
            DefaultDecodingParams::get##capital_name()))                                                               \
        ? m##capital_name##Device                                                                                      \
        : nullptr;

    auto temperatures = GET_PENALTIES(Temperature, float);
    auto repetitionPenalties = GET_PENALTIES(RepetitionPenalty, float);
    auto presencePenalties = GET_PENALTIES(PresencePenalty, float);
    auto frequencyPenalties = GET_PENALTIES(FrequencyPenalty, float);
    auto minLengths = GET_PENALTIES(MinLength, SizeType32);

#undef GET_PENALTIES

    auto const tokensPerStep = bufferCastOrNull<SizeType32>(params->curTokensPerStep);

    InvokeBatchApplyPenaltyParams<T> penaltyParams;
    {
        // Moving the logits ptrs to device for faster access during kernel execution.
        TensorPtr logitsPtrsDeviceSlice = ITensor::slice(mLogitsPtrsDevice, 0, localDecoderDomain.getBatchSize());
        TensorPtr logitsPtrsHostSlice = ITensor::slice(logitsPtrsHost, 0, localDecoderDomain.getBatchSize());
        mBufferManager->copy(*logitsPtrsHostSlice, *logitsPtrsDeviceSlice);
        penaltyParams.inputLogits = reinterpret_cast<T const* const*>(bufferCast<T*>(*logitsPtrsDeviceSlice));
    }
    penaltyParams.outputLogits = bufferCast<T>(*mRuntimeLogitsDevice);
    penaltyParams.biases = embeddingBias;
    penaltyParams.penaltyWorkspace = bufferCastOrNull<TokenIdType>(mPenaltyWorkspaceDevice);
    penaltyParams.penaltyWorkspacePrev = bufferCastOrNull<TokenIdType>(mPenaltyWorkspacePrevDevice);
    penaltyParams.temperatures = bufferCastOrNull<float>(temperatures);
    penaltyParams.repetitionPenalties = bufferCastOrNull<float>(repetitionPenalties);
    penaltyParams.presencePenalties = bufferCastOrNull<float>(presencePenalties);
    penaltyParams.frequencyPenalties = bufferCastOrNull<float>(frequencyPenalties);
    penaltyParams.batchSize = localDecoderDomain.getBatchSize();
    penaltyParams.beamWidth = localDecoderDomain.getBeamWidth();
    penaltyParams.maxSeqLen = maxSeqLen;
    penaltyParams.vocabSize = mDecoderDomain.getVocabSize();
    penaltyParams.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
    penaltyParams.outputIdsPtr = bufferCast<TokenIdType const*>(*outputs->outputIdsPtr);
    penaltyParams.parentIdsPtr = bufferCast<SizeType32 const*>(*outputs->parentIdsPtr);
    penaltyParams.inputLengths = inputLengths;
    penaltyParams.sequenceLengths = bufferCast<SizeType32>(*outputs->sequenceLength.value());
    penaltyParams.minLengths = bufferCastOrNull<SizeType32>(minLengths);
    penaltyParams.endIds = bufferCast<TokenIdType>(*params->endIds);
    penaltyParams.batchSlots = batchSlots;
    penaltyParams.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
    penaltyParams.tokensPerStep = tokensPerStep;
    penaltyParams.stream = getStream();
    invokeBatchApplyPenalty(penaltyParams);
    sync_check_cuda_error();

    mCyclicStep += 1;

    // Expose the penalized logits to the downstream layers via the inputs.
    auto const logitsShape = ITensor::makeShape({localDecoderDomain.getBatchSize(),
        mDecoderDomain.getMaxDecodingTokens(), localDecoderDomain.getBeamWidth(),
        mDecoderDomain.getVocabSizePadded()});
    params->logits = ITensor::view(mRuntimeLogitsDevice, logitsShape);

    if (mDecodingMode.isBeamSearch())
    {
        // Ping-pong the occurrence workspaces so the next step sees this step's counts as "previous".
        std::swap(mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class PenaltyLayer<float>;
template class PenaltyLayer<half>;

} // namespace tensorrt_llm::layers
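// Usage sketch (hypothetical, for illustration only): a decoding loop built on this layer would do
// roughly the following, assuming a DecoderDomain `domain` and a BufferManager `manager` already exist:
//
//     auto layer = std::make_shared<tensorrt_llm::layers::PenaltyLayer<float>>(
//         tensorrt_llm::executor::DecodingMode::TopKTopP(), domain, manager);
//     layer->setup(batchSize, /*beamWidth=*/1, batchSlots, setupParams); // per-request penalty values
//     layer->forwardAsync(outputs, inputs); // penalized logits exposed via inputs->logits
//
// The real call sites live in the dynamic decode layer, which chains this layer with sampling layers.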