/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests/kernels/sampling/samplingTest.h"

#include <algorithm>
#include <random>

namespace tensorrt_llm::tests::kernels::sampling
{

using namespace tensorrt_llm::runtime;

namespace tc = tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels;
namespace trk = tensorrt_llm::runtime::kernels;

template <typename T>
void SamplingKernelTest<T>::SetUp()
{
    mStream = std::make_shared<CudaStream>();
    mBufferManager = std::make_shared<BufferManager>(mStream);

    int device;
    cudaGetDevice(&device);
    cudaGetDeviceProperties(&mDeviceProp, device);
}

template <typename T>
void SamplingKernelTest<T>::TearDown()
{
}

template <typename T>
void SamplingKernelTest<T>::allocateBuffers(
    int32_t batchSize, int32_t vocabSize, int32_t maxSeqLen, int32_t outputLen)
{
    // Allocate GPU data
    mSeqLengthsHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
    mSeqLengthsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);

    mFinishedHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kBOOL);
    mFinishedDevice = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kBOOL);

    mOutputIdsHost = mBufferManager->pinned(ITensor::makeShape({batchSize, maxSeqLen}), nvinfer1::DataType::kINT32);
    mOutputIdsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize, maxSeqLen}), nvinfer1::DataType::kINT32);

    mProbsHost = mBufferManager->pinned(ITensor::makeShape({batchSize, vocabSize}),
        std::is_same_v<T, float> ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF);
    mProbsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize, vocabSize}),
        std::is_same_v<T, float> ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF);

    mCumLogProbsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kFLOAT);
    mOutputLogProbsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize, outputLen}), nvinfer1::DataType::kFLOAT);

    mZeroParentIdsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize, maxSeqLen}), nvinfer1::DataType::kINT32);

    mTopPIdValsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize, vocabSize}), nvinfer1::DataType::kINT32);
    mBeginOffsetsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize + 1}), nvinfer1::DataType::kINT32);
    mEndOffsetsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize + 1}), nvinfer1::DataType::kINT32);

    mLogitsHost = mBufferManager->pinned(ITensor::makeShape({batchSize, vocabSize}),
        std::is_same_v<T, float> ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF);
    mLogProbsHost = mBufferManager->pinned(ITensor::makeShape({batchSize, vocabSize}),
        std::is_same_v<T, float> ?
            nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF);

    mIdsPtrHost = mBufferManager->pinned(ITensor::makeShape({2 * batchSize}), nvinfer1::DataType::kINT64);

    mEndIdsHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
    mEndIdsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);

    mTopPsHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kFLOAT);
    mTopPsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kFLOAT);

    mTopKsHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
    mTopKsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);

    mSkipDecodeHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kBOOL);
    mSkipDecodeDevice = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kBOOL);

    mExpectedCumLogProbsHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kFLOAT);
}

template <typename T>
void SamplingKernelTest<T>::setupBuffers(int32_t batchSize, int32_t vocabSize, int32_t maxSeqLen, int32_t outputLen,
    int32_t topK, float topP, bool useSkipDecode, bool hasDiffRuntimeArgs, std::mt19937& gen,
    std::uniform_int_distribution<>& endIdsDistr)
{
    // Allocate and init curand states
    cudaMalloc(&mCurandStatesDevice, sizeof(curandState_t) * batchSize);
    tk::invokeCurandInitialize(mCurandStatesDevice, batchSize, mSeed, mStream->get());

    std::uniform_real_distribution<> skipDecodeDist(0, 1); // uniform distribution between 0 and 1
    std::uniform_real_distribution<> topPDist(0, 1);       // uniform distribution between 0 and 1
    std::uniform_int_distribution<> topKDist(1, std::min(1024, vocabSize));

    // Init by zero.
    trk::invokeFill(*mSeqLengthsDevice, int32_t{0}, *mStream);
    trk::invokeFill(*mFinishedDevice, false, *mStream);
    trk::invokeFill(*mCumLogProbsDevice, float{0.0f}, *mStream);
    trk::invokeFill(*mOutputLogProbsDevice, float{0.0f}, *mStream);
    trk::invokeFill(*mZeroParentIdsDevice, int32_t{0}, *mStream);
    trk::invokeFill(*mOutputIdsDevice, int32_t{0}, *mStream);

    std::fill_n(bufferCast<float>(*mExpectedCumLogProbsHost), batchSize, 0);

    // Init topK, topP and endIds for each request in batch
    auto skipDecodeHostPtr = bufferCast<bool>(*mSkipDecodeHost);
    auto topPsHostPtr = bufferCast<float>(*mTopPsHost);
    auto topKsHostPtr = bufferCast<int32_t>(*mTopKsHost);
    auto endIdsHostPtr = bufferCast<int32_t>(*mEndIdsHost);
    for (SizeType bi = 0; bi < batchSize; ++bi)
    {
        endIdsHostPtr[bi] = endIdsDistr(gen);
        skipDecodeHostPtr[bi] = useSkipDecode ? skipDecodeDist(gen) > 0.8 : false;
        topKsHostPtr[bi] = hasDiffRuntimeArgs ? topKDist(gen) : topK;
        topPsHostPtr[bi] = hasDiffRuntimeArgs ?
            topPDist(gen) : topP;
    }
    mMaxTopK = *std::max_element(topKsHostPtr, topKsHostPtr + batchSize);
    mMaxTopP = *std::max_element(topPsHostPtr, topPsHostPtr + batchSize);

    // Setup pointers to output ids for each request in batch
    auto idsPtrHostPtr = reinterpret_cast<int32_t**>(bufferCast<int64_t>(*mIdsPtrHost));
    auto outputIdsDevicePtr = bufferCast<int32_t>(*mOutputIdsDevice);
    auto zeroParentIdsDevicePtr = bufferCast<int32_t>(*mZeroParentIdsDevice);
    for (SizeType bi = 0; bi < batchSize; bi++)
    {
        idsPtrHostPtr[bi] = outputIdsDevicePtr + bi * maxSeqLen;
    }
    for (SizeType bi = 0; bi < batchSize; bi++)
    {
        idsPtrHostPtr[batchSize + bi] = zeroParentIdsDevicePtr + bi * maxSeqLen;
    }

    mBufferManager->copy(*mEndIdsHost, *mEndIdsDevice);
    mBufferManager->copy(*mSkipDecodeHost, *mSkipDecodeDevice);
    mBufferManager->copy(*mTopPsHost, *mTopPsDevice);
    mBufferManager->copy(*mTopKsHost, *mTopKsDevice);
}

template <typename T>
void SamplingKernelTest<T>::verifyCurrentStep(int32_t batchSize, int32_t vocabSize, int32_t maxSeqLen, int32_t step,
    bool greedySearch, bool useSkipDecode, bool hasDiffRuntimeArgs, std::vector<bool>& refFinished,
    std::vector<int32_t>& refSeqLength, const std::vector<bool>& finishedCurrentStep)
{
    const auto outputIdsHostPtr = bufferCast<int32_t>(*mOutputIdsHost);
    const auto seqLengthsHostPtr = bufferCast<int32_t>(*mSeqLengthsHost);
    const auto finishedHostPtr = bufferCast<bool>(*mFinishedHost);
    const auto logProbsHostPtr = bufferCast<T>(*mLogProbsHost);
    const auto endIdsHostPtr = bufferCast<int32_t>(*mEndIdsHost);
    const auto skipDecodeHostPtr = bufferCast<bool>(*mSkipDecodeHost);

    auto expectedCumLogProbsHostPtr = bufferCast<float>(*mExpectedCumLogProbsHost);

    for (SizeType bi = 0; bi < batchSize; ++bi)
    {
        // Set reference finished state to true if we finished before or at the current step
        bool finishedThisStep
            = finishedCurrentStep[bi] || outputIdsHostPtr[bi * maxSeqLen + step] == endIdsHostPtr[bi];
        refFinished[bi] = refFinished[bi] || finishedThisStep;
        if (refFinished[bi] == false)
        {
            // Increase reference seq len excluding the EOS token
            refSeqLength[bi]++;
        }
        // If decoding for this batch is skipped, ignore the cumLogProb calculation
        if (!skipDecodeHostPtr[bi])
        {
            // Check seq len correctness
            EXPECT_EQ(seqLengthsHostPtr[bi], refSeqLength[bi]);
            // Only in greedy search can we guarantee the selected token and the stop condition
            if (greedySearch)
            {
                EXPECT_EQ(finishedHostPtr[bi], refFinished[bi]);
            }

            int idx = bi * vocabSize + outputIdsHostPtr[bi * maxSeqLen + step];
            // Compute reference cumLogProb by summing all logProbs up to the stop token
            expectedCumLogProbsHostPtr[bi] += step < refSeqLength[bi] || finishedThisStep ?
                (float) logProbsHostPtr[idx] : 0.0f;

            // If the sequence has just finished at this step
            if (finishedHostPtr[bi] && step < seqLengthsHostPtr[bi])
            {
                // Check that the finishing token is endId
                EXPECT_EQ(outputIdsHostPtr[bi * maxSeqLen + step], endIdsHostPtr[bi])
                    << "step: " << step << " b: " << bi << " hasDiffRuntimeArgs: " << hasDiffRuntimeArgs
                    << " useSkipDecode: " << useSkipDecode;
            }
            // TODO(nkorobov): check correctness with K>1
        }
    }
}

template <typename T>
void SamplingKernelTest<T>::runTest(
    const SamplingKernelTestParam& param, bool hasDiffRuntimeArgs, bool useSkipDecode)
{
    const auto batchSize = param.batchSize;
    const auto vocabSize = param.vocabSize;
    const auto outputLen = param.outputLen;
    const auto maxSeqLen = outputLen;
    const auto topK = param.topK;
    const auto topP = param.topP;

    const bool greedySearch = topK == 1 && hasDiffRuntimeArgs == false && useSkipDecode == false;

    std::mt19937 gen(42);
    std::uniform_real_distribution<> finishedDist(0, 1); // uniform distribution between 0 and 1
    std::uniform_int_distribution<> endIdsDistr(
        0, vocabSize - 1); // -1 because uniform_int_distribution generates a closed interval

    // Allocate buffers
    allocateBuffers(batchSize, vocabSize, maxSeqLen, outputLen);

    // Setup buffers
    setupBuffers(
        batchSize, vocabSize, maxSeqLen, outputLen, topK, topP, useSkipDecode, hasDiffRuntimeArgs, gen, endIdsDistr);

    // Allocate internal state holders for the reference implementation
    std::vector<int32_t> refSeqLength(batchSize);
    std::vector<bool> refFinished(batchSize);

    // Retrieve the workspace size of the sampling kernel
    const auto workspaceSize = getWorkspaceSize(param);
    TensorPtr workspaceDevice
        = mBufferManager->gpu(ITensor::makeShape({static_cast<SizeType>(workspaceSize)}), nvinfer1::DataType::kINT8);

    for (size_t step = 0; step < outputLen; ++step)
    {
        // Init logits randomly
        auto logitsHostPtr = bufferCast<T>(*mLogitsHost);
        auto endIdsHostPtr = bufferCast<int32_t>(*mEndIdsHost);
        initRandom(logitsHostPtr, batchSize * vocabSize, -3.0f, 3.0f);

        std::vector<bool> finishedCurrentStep(batchSize);
        // Only in greedy search can we guarantee the selected token and the stop condition
        if (greedySearch)
        {
            for (SizeType bi = 0; bi < batchSize; ++bi)
            {
                // Randomly decide if the sequence finishes at the current step
                finishedCurrentStep[bi] = refFinished[bi] == false ? finishedDist(gen) < 0.1 : false;
                if (finishedCurrentStep[bi])
                {
                    // Set the logit of the endId for the finished request to a value above all others.
                    // NOTE that we can guarantee finish only in greedy search
                    logitsHostPtr[bi * vocabSize + endIdsHostPtr[bi]] = 4.0f;
                }
            }
        }

        // Compute probabilities for each token
        computeProb(bufferCast<T>(*mProbsHost), bufferCast<T>(*mLogitsHost), batchSize, vocabSize);
        mBufferManager->copy(*mProbsHost, *mProbsDevice);

        // Call the tested sampling function
        callTestedFunction(param, hasDiffRuntimeArgs, workspaceSize, workspaceDevice);

        mBufferManager->copy(*mOutputIdsDevice, *mOutputIdsHost);
        mBufferManager->copy(*mSeqLengthsDevice, *mSeqLengthsHost);
        mBufferManager->copy(*mFinishedDevice, *mFinishedHost);

        // Synchronize to get valid data on Host
        mStream->synchronize();

        // Compute reference.
        computeLogProb(bufferCast<T>(*mLogProbsHost), bufferCast<T>(*mLogitsHost), batchSize, vocabSize);
        verifyCurrentStep(batchSize, vocabSize, maxSeqLen, step, greedySearch, useSkipDecode, hasDiffRuntimeArgs,
            refFinished, refSeqLength, finishedCurrentStep);
    }

    const auto cumLogProbsHost = mBufferManager->copyFrom(*mCumLogProbsDevice, MemoryType::kCPU);

    // Check that the cumulative log probability is correct
    const bool passed = checkResult(param.toString(), bufferCast<float>(*cumLogProbsHost),
        bufferCast<float>(*mExpectedCumLogProbsHost), batchSize);
    EXPECT_TRUE(passed);

    cudaFree(mCurandStatesDevice);
}

template <typename T>
void SamplingKernelTest<T>::runTest(const SamplingKernelTestParam& param)
{
    runTest(param, false, false); // Single params, do not skip decoders
    runTest(param, true, false);  // Different params, do not skip decoders
    runTest(param, false, true);  // Single params, skip some decoders
    runTest(param, true, true);   // Different params, skip some decoders
}

template class SamplingKernelTest<float>;
template class SamplingKernelTest<half>;

} // namespace tensorrt_llm::tests::kernels::sampling