/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "topKSamplingLayer.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"

#include <algorithm>
#include <vector>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::layers
{

template <typename T>
TopKSamplingLayer<T>::TopKSamplingLayer(
    DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
    : BaseLayer(decoderDomain, bufferManager)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    allocateBuffer(mDecoderDomain.getBatchSize());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void TopKSamplingLayer<T>::allocateBuffer(SizeType32 const batchSize)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mWorkspaceSize = getTopKWorkspaceSize<T>(batchSize, 1, TOP_K_MAX, mDecoderDomain.getVocabSizePadded());

    auto const batchSizeShape = ITensor::makeShape({batchSize});
    mRuntimeTopKDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mRuntimeTopPDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
    mSkipDecodeDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
    mRuntimeTopKHost = mBufferManager->cpu(batchSizeShape, TRTDataType<SizeType32>::value);
    mSkipDecodeHost = mBufferManager->cpu(batchSizeShape, TRTDataType<bool>::value);
    mSetupWorkspaceSize = batchSize * sizeof(SizeType32);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void TopKSamplingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots,
    std::shared_ptr<BaseSetupParams> const& baseSetupParams,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto setupParams = std::dynamic_pointer_cast<SamplingSetupParams>(baseSetupParams);

    mNormalizeLogProbs = setupParams->normalizeLogProbs.value_or(false);

    auto runtimeTopK = setupParams->runtimeTopK.value_or(std::vector{DefaultDecodingParams::getTopK()});
    auto runtimeTopP = setupParams->runtimeTopP.value_or(std::vector{DefaultDecodingParams::getTopP()});

    auto const paramsSize = expandMatchElements(batchSize, runtimeTopK, runtimeTopP);
    TLLM_CHECK_WITH_INFO(paramsSize != 0,
        fmtstr("TopKSamplingLayer got parameter with unexpected size, want 1 or batchSize(%d), got "
               "runtimeTopK.size() = %zu, runtimeTopP.size() = %zu",
            batchSize, runtimeTopK.size(), runtimeTopP.size()));

    for (size_t i = 0; i < paramsSize; ++i)
    {
        auto& topK = runtimeTopK[i];
        auto& topP = runtimeTopP[i];
        clampTopK(topK);
        clampTopP(topP);
        regularizeTopKTopP(topK, topP);
    }

    // Update parameters on both device and host, so we can
    // - determine whether we can skip the kernel launch by examining mSkipDecodeHost
    // - select the best kernel by examining mRuntimeTopKHost
    // without consulting device memory, which would otherwise require an expensive synchronization.
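    // As read from the two call sites below, the setup kernel is expected to broadcast the scalar
    // values (runtimeTopK.front() / runtimeTopP.front()) to every batch slot when the corresponding
    // pointer argument is nullptr, and to read one value per slot from the staged arrays otherwise.
    // This is an inferred contract of invokeSetupTopKRuntimeArgs, not a documented guarantee.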
    SizeType32* topKsPtr = nullptr;
    float* topPsPtr = nullptr;
    if (paramsSize > 1)
    {
        auto initWorkspaceSizes = getTopKInitWorkspaceSizes(batchSize);
        auto workspacePtr = workspace->getRawWorkspaceDevicePtr();
        calcAlignedPointers(workspacePtr, initWorkspaceSizes)(topKsPtr, topPsPtr);
        DecodingLayerWorkspace::copyToWorkspace(
            *mBufferManager, runtimeTopK, IBuffer::wrap(topKsPtr, initWorkspaceSizes[0] / sizeof(*topKsPtr)));
        DecodingLayerWorkspace::copyToWorkspace(
            *mBufferManager, runtimeTopP, IBuffer::wrap(topPsPtr, initWorkspaceSizes[1] / sizeof(*topPsPtr)));
    }
    auto const* batchSlotsDevicePtr = workspace->getDeviceBatchSlotsPtr();
    auto* skipDecodeDevicePtr = bufferCastOrNull<bool>(mSkipDecodeDevice);
    invokeSetupTopKRuntimeArgs(batchSize,                                              //
        {topKsPtr, runtimeTopK.front(), bufferCast<SizeType32>(*mRuntimeTopKDevice)},  //
        {topPsPtr, runtimeTopP.front(), bufferCast<float>(*mRuntimeTopPDevice)},       //
        skipDecodeDevicePtr, batchSlotsDevicePtr, true, getStream());

    auto const* batchSlotsHostPtr = bufferCast<SizeType32>(*batchSlots);
    auto* skipDecodeHostPtr = bufferCastOrNull<bool>(mSkipDecodeHost);
    topKsPtr = paramsSize > 1 ? runtimeTopK.data() : nullptr;
    invokeSetupTopKRuntimeArgs(batchSize,                                                  //
        {topKsPtr, runtimeTopK.front(), bufferCast<SizeType32>(*mRuntimeTopKHost)}, {},    //
        skipDecodeHostPtr, batchSlotsHostPtr, false);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void TopKSamplingLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
    std::shared_ptr<BaseDecodingInputs> const& baseInputs,
    std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto inputs = std::dynamic_pointer_cast<SamplingInputs>(baseInputs);

    auto const batchSize = inputs->logits.value()->getDimension<0>();

    auto const* batchSlotsHost = bufferCast<SizeType32>(*inputs->batchSlots);
    auto* skipDecodeHostPtr = bufferCastOrNull<bool>(mSkipDecodeHost);
    auto const skip = allOfBatchSlots(batchSlotsHost, skipDecodeHostPtr, batchSize, true);
    if (skip)
    {
        return;
    }

    auto logits = bufferCastOrNull<T>(inputs->logits);
    auto const* endIds = bufferCastOrNull<TokenIdType>(inputs->endIds);
    auto const probsComputed = inputs->probsComputed;

    FinishedState const* finishedInput = (inputs->finished)
        ? reinterpret_cast<FinishedState const*>(
              bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
        : nullptr;
    FinishedState* finishedOutput = (outputs->finished)
        ? reinterpret_cast<FinishedState*>(
              bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
        : nullptr;
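    // Assemble the arguments for the batched top-K sampling kernel. maxTopK is the maximum runtime
    // top-K over the active batch slots, computed from the host-side copy filled in setup() so no
    // device read (and hence no synchronization) is needed; the per-slot top-K/top-P values are read
    // by the kernel from the device buffers.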
    auto* runtimeTopKHostPtr = bufferCast<SizeType32>(*mRuntimeTopKHost);

    TopKSamplingKernelParams<T> params{};
    params.logProbs = logits;
    params.outputIdsPtrs = bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr);
    params.workspace = workspace->getRawWorkspaceDevicePtr();
    params.maxTopP = 1.0f;
    params.topPs = bufferCastOrNull<float>(mRuntimeTopPDevice);
    params.maxTopK = maxOfBatchSlots(batchSlotsHost, runtimeTopKHostPtr, batchSize);
    params.topKs = bufferCastOrNull<SizeType32>(mRuntimeTopKDevice);
    params.sequenceLengths = bufferCastOrNull<SizeType32>(outputs->sequenceLength);
    params.endIds = endIds;
    params.batchSlots = workspace->getDeviceBatchSlotsPtr();
    params.finishedInput = finishedInput;
    params.finishedOutput = finishedOutput;
    params.skipDecode = bufferCastOrNull<bool>(mSkipDecodeDevice);
    params.cumLogProbs = bufferCastOrNull<float>(outputs->cumLogProbs);
    params.outputLogProbs = bufferCastOrNull<float>(outputs->outputLogProbsTiled);
    params.curandState = inputs->curandStates;
    params.batchSize = batchSize;
    params.maxBatchSize = mDecoderDomain.getBatchSize();
    params.maxTokensPerStep = 1;
    params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
    params.normalizeLogProbs = mNormalizeLogProbs;
    params.logitsHasProbs = probsComputed;

    invokeBatchTopKSampling(params, getStream());

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
size_t TopKSamplingLayer<T>::getWorkspaceSize() const noexcept
{
    return std::max(mWorkspaceSize, mSetupWorkspaceSize);
}

template class TopKSamplingLayer<float>;
template class TopKSamplingLayer<half>;

} // namespace tensorrt_llm::layers