/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests/unit_tests/kernels/sampling/samplingTest.h"

namespace tk = tensorrt_llm::kernels;

using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::tests::kernels::sampling;

namespace
{

/// Typed fixture that plugs the AIR top-p sampling kernel (tk::invokeBatchAirTopPSampling) into the shared
/// SamplingKernelTest harness. This class only reports the kernel's workspace size and launches the kernel
/// on the harness-owned buffers. The tests drive it through the inherited runTest(), which presumably
/// performs the reference comparison (see samplingTest.h).
template <typename T>
class AirTopPSamplingKernelTest : public SamplingKernelTest<T>
{

protected:
    const int32_t endId = 0;
    using SamplingKernelTest<T>::mSeed;
    using SamplingKernelTest<T>::mStream;
    using SamplingKernelTest<T>::mBufferManager;

private:
    /// Workspace size required by the AIR top-p kernel for the given batch size, vocab size and
    /// determinism mode.
    size_t getWorkspaceSize(SamplingKernelTestParam const& params) override
    {
        return tensorrt_llm::kernels::getAirTopPWorkspaceSize<T>(
            params.batchSize, params.vocabSize, params.isDeterministicTopP);
    }

    /// Fills a TopPSamplingKernelParams from the fixture's buffers and launches batched AIR top-p sampling
    /// on the fixture's stream.
    ///
    /// @param params           test shape (batch size, vocab size, determinism flag).
    /// @param workspaceDevice  device workspace sized by getWorkspaceSize().
    void callTestedFunction(
        SamplingKernelTestParam const& params, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override
    {
        // Calculate the number of blocks based on the number of multiprocessors, batchSize and vocabSize.
        int dev{0};
        int smCnt{0};
        TLLM_CUDA_CHECK(cudaGetDevice(&dev));
        TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCnt, cudaDevAttrMultiProcessorCount, dev));

        // NOTE(review): presumably mirrors the maxBatchSize the base fixture uses when allocating its
        // per-slot buffers (batch slots index into [0, maxBatchSize)) -- confirm against samplingTest.h.
        auto const maxBatchSize = 2 * params.batchSize;

        int const blockNum = tk::calcAirTopPBlockNum<T, int, float>(
            params.batchSize, params.vocabSize, smCnt, params.isDeterministicTopP);

        tk::TopPSamplingKernelParams<T> kernelParams;
        kernelParams.probs = bufferCast<T>(*this->mProbsDevice);
        kernelParams.outputIdsPtrs = bufferCast<int*>(*this->mIdsPtrHost);
        kernelParams.workspace = workspaceDevice->data();
        kernelParams.topPs = bufferCast<float>(*this->mTopPsDevice);
        kernelParams.sequenceLength = bufferCast<int>(*this->mSeqLengthsDevice);
        kernelParams.endIds = bufferCast<int>(*this->mEndIdsDevice);
        kernelParams.batchSlots = bufferCast<int>(*this->mBatchSlots);
        // Input and output finished states both point at mFinishedDevice, so the kernel reads and writes
        // the same buffer.
        kernelParams.finishedInput = reinterpret_cast<tk::FinishedState*>(
            bufferCast<tk::FinishedState::UnderlyingType>(*this->mFinishedDevice));
        kernelParams.finishedOutput = reinterpret_cast<tk::FinishedState*>(
            bufferCast<tk::FinishedState::UnderlyingType>(*this->mFinishedDevice));
        kernelParams.skipDecode = bufferCast<bool>(*this->mSkipDecodeDevice);
        kernelParams.cumLogProbs = bufferCast<float>(*this->mCumLogProbsDevice);
        kernelParams.outputLogProbs = bufferCast<float>(*this->mOutputLogProbsDevice);
        // The curand states are stored as a raw byte buffer and reinterpreted as curandState_t.
        kernelParams.curandState = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*this->mCurandStatesDevice));
        kernelParams.batchSize = params.batchSize;
        kernelParams.maxBatchSize = maxBatchSize;
        kernelParams.vocabSizePadded = params.vocabSize;
        kernelParams.blockNum = blockNum;
        kernelParams.isDeterministic = params.isDeterministicTopP;

        // Perform batched TopP sampling
        tk::invokeBatchAirTopPSampling(kernelParams, this->mStream->get());
    }
};

TYPED_TEST_SUITE(AirTopPSamplingKernelTest, FloatAndHalfTypes);

// Non-deterministic mode: tiny vocab (4) with low, high and full ("Ancestral", topP = 1.0) thresholds.
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessSmallP)
{
    this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f));
}

TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeP)
{
    this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f));
}

TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessAncestral)
{
    this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f));
}

// Non-deterministic mode: large vocab (51200) with low and high thresholds.
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabSmallP)
{
    this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f));
}

TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabLargeP)
{
    this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f));
}

// Deterministic mode (setDeterministicTopP(true)): same shapes and thresholds as above.
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessSmallP)
{
    this->runTest(
        SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f).setDeterministicTopP(true));
}

TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeP)
{
    this->runTest(
        SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f).setDeterministicTopP(true));
}

TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessAncestral)
{
    this->runTest(
        SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f).setDeterministicTopP(true));
}

TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabSmallP)
{
    this->runTest(
        SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f).setDeterministicTopP(
            true));
}

TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabLargeP)
{
    this->runTest(
        SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f).setDeterministicTopP(
            true));
}

} // end of namespace