/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "banWordsLayer.h"
#include "tensorrt_llm/kernels/banBadWords.h"
#include "tensorrt_llm/kernels/banRepeatNgram.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"

using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::layers
{

// NOTE(review): this file reached us with every template-argument list stripped
// (e.g. `bufferCast(*logits)`, `std::shared_ptr bufferManager`). The arguments
// below are reconstructed from the surrounding call patterns; cast targets that
// cannot be derived from this file alone are flagged inline — confirm against
// banWordsLayer.h and the kernel headers.

//! \brief Decoding layer that masks out banned tokens (bad-words lists and
//! repeated n-grams) by setting their logits to -inf before sampling.
template <typename T>
BanWordsLayer<T>::BanWordsLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
    std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
    , mDecodingMode(mode)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Allocates per-slot no-repeat-ngram-size buffers: a pinned host
//! staging buffer always, plus a device copy only when the decoding mode
//! actually uses the no-repeat-ngram feature.
template <typename T>
void BanWordsLayer<T>::allocateBuffer()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    if (mDecodingMode.isUseNoRepeatNgramSize())
    {
        // One entry per batch slot; the buffers hold SizeType32 ngram sizes.
        mNoRepeatNgramSizeDevice = mBufferManager->gpu(
            ITensor::makeShape({mDecoderDomain.getBatchSize()}), TRTDataType<SizeType32>::value);
    }
    mNoRepeatNgramSize = mBufferManager->pinnedPool(
        ITensor::makeShape({mDecoderDomain.getBatchSize()}), TRTDataType<SizeType32>::value);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Copies the per-request no-repeat-ngram sizes into the device buffer
//! for the given batch slots, filling unset entries with the default value.
//!
//! The sticky `mUseNoRepeatNgramSize |= ...` keeps the feature enabled once any
//! request has asked for it, since slots are refreshed independently.
template <typename T>
void BanWordsLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // NOTE(review): cast target reconstructed — the params type carrying a
    // `banWordsParams` member; verify against the setup-params class hierarchy.
    auto setupParams = std::dynamic_pointer_cast<DynamicDecodeSetupParams>(baseSetupParams);
    auto const& banWordsParams = setupParams->banWordsParams;
    TLLM_CHECK_WITH_INFO(banWordsParams, "banWordsParams for setup is not set");

    bool const useNoRepeatNgramSize
        = mDecodingMode.isUseNoRepeatNgramSize() && banWordsParams->noRepeatNgramSize.has_value();
    FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mBufferManager};
    mUseNoRepeatNgramSize |= useNoRepeatNgramSize;
    if (mUseNoRepeatNgramSize)
    {
        fillBuffers(banWordsParams->noRepeatNgramSize, DefaultDecodingParams::getNoRepeatNgramSize(),
            mNoRepeatNgramSize, mNoRepeatNgramSizeDevice, batchSlots,
            std::make_pair(0.f, std::numeric_limits<float>::max()), "no_repeat_ngram_size");
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Launches the ban-repeat-ngram kernel, which sets to -inf the logits
//! of tokens that would complete an n-gram already present in the output.
//! No-op unless \p useNoRepeatNgramSize is set.
template <typename T>
void BanWordsLayer<T>::banRepeatNGrams(TensorPtr const& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
    std::shared_ptr<DecodingInputs> const& inputs, BufferConstPtr const& batchSlots, BufferPtr noRepeatNgramSizeDevice,
    DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, bool useNoRepeatNgramSize)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    if (useNoRepeatNgramSize)
    {
        // auto const maxStep = inputs->step;
        // TODO Should we use step? but current inputs->step is always 0.
        auto const maxStep = maxSeqLen;

        // Temporary variables to store dereferenced inputs.
        // NOTE(review): element types of the pointer tensors (outputIdsPtr /
        // parentIdsPtr) reconstructed from the kernel's expected signature.
        auto logitsPtr = bufferCast<T>(*logits);
        auto outputIdsPtr = bufferCast<TokenIdType const*>(*outputs->outputIdsPtr);
        auto finishedPtr
            = reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished));
        auto parentIdsPtr = bufferCast<SizeType32 const*>(*outputs->parentIdsPtr);
        auto batchSlotsPtr = bufferCast<SizeType32>(*batchSlots);
        auto sequenceLengthPtr = bufferCast<SizeType32>(*outputs->sequenceLength.value());
        auto noRepeatNgramSizeDevicePtr = bufferCastOrNull<SizeType32>(noRepeatNgramSizeDevice);

        // Call to invokeBanRepeatNgram with dereferenced inputs.
        invokeBanRepeatNgram(logitsPtr, outputIdsPtr, finishedPtr, parentIdsPtr, batchSlotsPtr, sequenceLengthPtr,
            decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(), maxSeqLen, noRepeatNgramSizeDevicePtr,
            decoderDomain.getVocabSizePadded(), maxStep, getStream());
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Launches the ban-bad-words kernel, which sets to -inf the logits of
//! tokens that would complete any phrase in the per-request bad-words lists.
//! No-op when the longest bad-words list is empty.
template <typename T>
void BanWordsLayer<T>::banBadWords(TensorPtr const& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
    std::shared_ptr<DecodingInputs> const& inputs, BufferConstPtr const& batchSlots, DecoderDomain const& decoderDomain,
    SizeType32 maxSeqLen)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const maxBadWordsLength = inputs->banWordsInputs->maxBadWordsLen;
    if (maxBadWordsLength != 0)
    {
        // Temporary variables to store dereferenced inputs.
        // NOTE(review): cast element types reconstructed — confirm against
        // invokeBanBadWords' declaration in kernels/banBadWords.h.
        auto badWordsPtr = bufferCast<TokenIdType const*>(*inputs->banWordsInputs->badWordsPtr.value());
        auto badWordsLens = bufferCast<SizeType32>(*inputs->banWordsInputs->badWordsLengths.value());
        auto logitsPtr = bufferCast<T>(*logits);
        auto outputIdsPtr = bufferCast<TokenIdType const*>(*outputs->outputIdsPtr);
        // Parent ids are only meaningful for beam search; pass nullptr otherwise.
        auto parentIdsPtr
            = decoderDomain.getBeamWidth() > 1 ? bufferCast<SizeType32 const*>(*outputs->parentIdsPtr) : nullptr;
        auto sequenceLengthPtr = bufferCast<SizeType32>(*outputs->sequenceLength.value());
        auto batchSlotsPtr = bufferCast<SizeType32>(*batchSlots);

        // Call to invokeBanBadWords with dereferenced inputs.
        invokeBanBadWords(logitsPtr, outputIdsPtr, parentIdsPtr, batchSlotsPtr, decoderDomain.getBatchSize(),
            decoderDomain.getBeamWidth(), badWordsPtr, badWordsLens, maxBadWordsLength,
            decoderDomain.getVocabSizePadded(), sequenceLengthPtr, maxSeqLen, getStream());
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Applies both banning passes (repeat n-grams, then bad words) to the
//! runtime logits for the current step. Work is enqueued asynchronously on the
//! layer's CUDA stream.
template <typename T>
void BanWordsLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& baseOutputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputs = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
    auto outputs = std::dynamic_pointer_cast<BaseDecodingOutputs>(baseOutputs);

    TLLM_CHECK_WITH_INFO(inputs->banWordsInputs, "banWordsInputs for forward is not set");

    auto const localDecoderDomain = getLocalDecoderDomain(inputs, mDecoderDomain);
    auto const maxSeqLen = outputs->outputIds->getDimension<-1>();

    banRepeatNGrams(workspace->getDeviceRuntimeLogits(), outputs, inputs, workspace->getDeviceBatchSlots(),
        mNoRepeatNgramSizeDevice, localDecoderDomain, maxSeqLen, mUseNoRepeatNgramSize);
    banBadWords(workspace->getDeviceRuntimeLogits(), outputs, inputs, workspace->getDeviceBatchSlots(),
        localDecoderDomain, maxSeqLen);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

// Explicit instantiations for the supported logit types.
// NOTE(review): the stripped source had exactly two instantiations; float/half
// is the conventional pair for this layer family — confirm against the header.
template class BanWordsLayer<float>;
template class BanWordsLayer<half>;

} // namespace tensorrt_llm::layers