/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests/unit_tests/layers/baseSamplingLayerTest.h"

namespace tensorrt_llm::tests::layers::sampling
{

using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::layers;
using namespace tensorrt_llm::common;

namespace tk = tensorrt_llm::kernels;
namespace trk = tensorrt_llm::runtime::kernels;

template <typename T>
void BaseSamplingLayerTest<T>::setup(uint64_t seed, TestSamplingParams const& params)
{
    auto const dataType = TRTDataType<T>::value;
    auto const ptrType = TRTDataType<T*>::value;

    // clang-format off
    // logits = (-0.9163, -1.2040, -1.6094, -2.3026) -> prob = (0.4, 0.3, 0.2, 0.1)
    std::vector<T> testLogits = {
        -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, // step 0
        -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 1
        -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // step 2
        -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX  // step 3
    };
    // clang-format on

    if (params.beamWidth == 1)
    {
        mTestLogitsInit = testLogits;
    }
    else
    {
        // Replicate each step's logits across all beams.
        for (int step = 0; step < mMaxSeqLen; ++step)
        {
            auto const& logitsBegin = testLogits.begin() + mVocabSize * step;
            auto const& logitsEnd = testLogits.begin() + mVocabSize * (step + 1);
            for (int bm = 0; bm < params.beamWidth; ++bm)
            {
                mTestLogitsInit.insert(mTestLogitsInit.end(), logitsBegin, logitsEnd);
            }
        }
    }

    if (mComputeProbs)
    {
        computeProb(mTestLogitsInit.data(), mTestLogitsInit.data(),
            BaseSamplingLayerTest<T>::mMaxOutputLen * params.beamWidth, mVocabSize);
    }

    mSeqLengthsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize()}), nvinfer1::DataType::kINT32);
    mContextLengthDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize()}), nvinfer1::DataType::kINT32);
    mFinishedDevice = params.isExternalDraftTokensLayerTest
        ? mBufferManager->gpu(ITensor::makeShape({mMaxTokensPerEngineStep, maxBatchSize()}),
            TRTDataType<tk::FinishedState::UnderlyingType>::value)
        : mBufferManager->gpu(
            ITensor::makeShape({maxBatchSize()}), TRTDataType<tk::FinishedState::UnderlyingType>::value);
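    // Output and bookkeeping buffers. They are sized with maxBatchSize() rather than mBatchSize because
    // the layer addresses requests indirectly through batch slots, which this test spreads out by
    // kDoubleBatchIdx to exercise non-contiguous slot assignment.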
    mOutputIdsDevice = mBufferManager->gpu(
        ITensor::makeShape({maxBatchSize(), mBeamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32);
    mEndIdsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize()}), nvinfer1::DataType::kINT32);
    mIdsPtrHost = mBufferManager->pinned(ITensor::makeShape({maxBatchSize()}), ptrType);
    mCumLogProbsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize()}), nvinfer1::DataType::kFLOAT);
    mOutputLogProbsDevice
        = mBufferManager->gpu(ITensor::makeShape({maxBatchSize(), mMaxSeqLen}), nvinfer1::DataType::kFLOAT);
    mBatchSlots
        = mBufferManager->pinned(ITensor::makeShape({mBatchSize + mBatchSizeBadPad}), nvinfer1::DataType::kINT32);
    mCurandStatesDevice
        = mBufferManager->gpu(ITensor::makeShape({maxBatchSize(), sizeof(curandState_t)}), nvinfer1::DataType::kINT8);

    auto const workspaceSize = mSamplingLayer->getWorkspaceSize();

    trk::invokeFill(*mSeqLengthsDevice, int32_t{0}, *mStream);
    trk::invokeFill(*mContextLengthDevice, int32_t{0}, *mStream);
    trk::invokeFill(*mFinishedDevice, uint8_t{0}, *mStream);
    trk::invokeFill(*mOutputIdsDevice, int32_t{0}, *mStream);
    trk::invokeFill(*mCumLogProbsDevice, float{0.0f}, *mStream);
    trk::invokeFill(*mOutputLogProbsDevice, float{0.0f}, *mStream);
    trk::invokeFill(*mEndIdsDevice, int32_t{mEndId}, *mStream);

    tk::invokeCurandInitialize(reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice)), nullptr,
        maxBatchSize(), seed, mStream->get());

    auto batchSlotsPtr = bufferCast<SizeType32>(*mBatchSlots);
    for (SizeType32 bi = 0; bi < mBatchSize; ++bi)
    {
        batchSlotsPtr[bi] = kDoubleBatchIdx * bi;
    }
    // Poison the padding slots so that any out-of-batch access is easy to spot.
    for (SizeType32 bi = 0; bi < mBatchSizeBadPad; ++bi)
    {
        batchSlotsPtr[mBatchSize + bi] = 0xbaadf00d;
    }

    auto idsPtrHostPtr = BufferRange<void*>(*mIdsPtrHost);
    auto outputIdsDevicePtr = bufferCast<int32_t>(*mOutputIdsDevice);
    for (SizeType32 bi = 0; bi < maxBatchSize(); bi++)
    {
        idsPtrHostPtr[bi] = outputIdsDevicePtr + bi * mMaxSeqLen;
    }

    std::shared_ptr<BaseSetupParams> setupParams;
    if (params.isExternalDraftTokensLayerTest)
    {
        auto externalDraftTokensSetupParams = std::make_shared<ExternalDraftTokensSetupParams>();
        externalDraftTokensSetupParams->randomSeed = std::make_optional<std::vector<uint64_t>>({seed});
        externalDraftTokensSetupParams->runtimeTopK
            = params.topKs.size() ? std::make_optional<std::vector<SizeType32>>(params.topKs) : std::nullopt;
        externalDraftTokensSetupParams->runtimeTopP
            = params.topPs.size() ? std::make_optional<std::vector<float>>(params.topPs) : std::nullopt;
        setupParams = externalDraftTokensSetupParams;
    }
    else if (mBeamWidth == 1)
    {
        auto samplingSetupParams = std::make_shared<SamplingSetupParams>();
        samplingSetupParams->randomSeed = std::make_optional<std::vector<uint64_t>>({seed});
        samplingSetupParams->runtimeTopK
            = params.topKs.size() ? std::make_optional<std::vector<SizeType32>>(params.topKs) : std::nullopt;
        samplingSetupParams->runtimeTopP
            = params.topPs.size() ? std::make_optional<std::vector<float>>(params.topPs) : std::nullopt;
        samplingSetupParams->topPDecay
            = params.decay.size() ? std::make_optional<std::vector<float>>(params.decay) : std::nullopt;
        samplingSetupParams->topPMin
            = params.minTopP.size() ? std::make_optional<std::vector<float>>(params.minTopP) : std::nullopt;
        samplingSetupParams->topPResetIds = params.topPResetIds.size()
            ? std::make_optional<std::vector<TokenIdType>>(params.topPResetIds)
            : std::nullopt;
        setupParams = samplingSetupParams;
    }
    else // Beam Search
    {
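        // Beam search keeps additional state that single-beam sampling does not need: KV-cache
        // indirections, parent ids for reconstructing beams, and the CBA (candidate beam array)
        // buffers that hold up to 2 * beamWidth finished hypotheses per request.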
        auto samplingSetupParams = std::make_shared<BeamSearchSetupParams>();
        setupParams = samplingSetupParams;

        mSrcCacheIndirection = mBufferManager->gpu(
            ITensor::makeShape({maxBatchSize(), mBeamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32);
        mTgtCacheIndirection = mBufferManager->gpu(
            ITensor::makeShape({maxBatchSize(), mBeamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32);
        mParentIds = mBufferManager->gpu(
            ITensor::makeShape({maxBatchSize(), mBeamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32);

        auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
        auto constexpr nvSizeType = TRTDataType<SizeType32>::value;
        auto constexpr nvFloatType = TRTDataType<float>::value;
        auto constexpr nvBoolType = TRTDataType<bool>::value;

        mOutputIdsCBA
            = mBufferManager->gpu(ITensor::makeShape({maxBatchSize(), 2 * mBeamWidth, mMaxSeqLen}), nvTokenIdType);
        mLogProbsCBA
            = mBufferManager->gpu(ITensor::makeShape({maxBatchSize(), 2 * mBeamWidth, mMaxSeqLen}), nvFloatType);
        mSequenceLengthsCBA = mBufferManager->gpu(ITensor::makeShape({maxBatchSize(), 2 * mBeamWidth}), nvSizeType);
        mCumLogProbsCBA = mBufferManager->gpu(ITensor::makeShape({maxBatchSize(), 2 * mBeamWidth}), nvFloatType);
        mNormedScoresCBA = mBufferManager->gpu(ITensor::makeShape({maxBatchSize(), 2 * mBeamWidth}), nvFloatType);
        mNumBeamsCBA = mBufferManager->gpu(ITensor::makeShape({maxBatchSize()}), nvSizeType);
        mMinNormedScoresCBA = mBufferManager->gpu(ITensor::makeShape({maxBatchSize()}), nvFloatType);
        mBatchDones = mBufferManager->gpu(ITensor::makeShape({maxBatchSize()}), nvBoolType);
        mOutputIdsPtr = mBufferManager->pinned(ITensor::makeShape({maxBatchSize()}), ptrType);
        mParentIdsPtr = mBufferManager->pinned(ITensor::makeShape({maxBatchSize()}), ptrType);

        trk::invokeFill(*mSrcCacheIndirection, int32_t{0}, *mStream);
        trk::invokeFill(*mTgtCacheIndirection, int32_t{0}, *mStream);
        trk::invokeFill(*mParentIds, int32_t{0}, *mStream);
        trk::invokeFill(*mOutputIdsCBA, int32_t{0}, *mStream);
        trk::invokeFill(*mLogProbsCBA, float{0}, *mStream);
        trk::invokeFill(*mSequenceLengthsCBA, int32_t{0}, *mStream);
        trk::invokeFill(*mCumLogProbsCBA, float{0}, *mStream);
        trk::invokeFill(*mNormedScoresCBA, float{0}, *mStream);
        trk::invokeFill(*mNumBeamsCBA, int32_t{0}, *mStream);
        trk::invokeFill(*mMinNormedScoresCBA, float{0}, *mStream);
        trk::invokeFill(*mBatchDones, bool{0}, *mStream);

        auto outputIdsPtr = bufferCast<int32_t*>(*mOutputIdsPtr);
        auto parentIdsPtr = bufferCast<int32_t*>(*mParentIdsPtr);
        for (SizeType32 bi = 0; bi < maxBatchSize(); bi++)
        {
            outputIdsPtr[bi] = outputIdsDevicePtr + bi * mMaxSeqLen;
            parentIdsPtr[bi] = outputIdsDevicePtr + bi * mMaxSeqLen;
        }
    }

    mDecodingWorkspace->setDeviceBatchSlots(mBatchSlots);
    mDecodingWorkspace->getDeviceRuntimeLogits()->reshape(ITensor::makeShape({mBatchSize, mBeamWidth, mVocabSize}));
    mSamplingLayer->setup(mBatchSize, mBeamWidth, mBatchSlots, setupParams, mDecodingWorkspace);

    mStream->synchronize();
}
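// Builds the input container passed to forwardAsync(). Beam search uses the generic DecodingInputs;
// the single-beam path needs SamplingInputs so that the curand states and the probs-computed flag
// can be attached.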
template <typename T>
std::shared_ptr<DecodingInputs> BaseSamplingLayerTest<T>::createInputTensors(int32_t step)
{
    constexpr int32_t ite = 0;
    auto decodeInputTensors = (mBeamWidth > 1)
        ? std::make_shared<DecodingInputs>(mEndIdsDevice, mBatchSlots, step, ite, mBatchSize)
        : std::make_shared<SamplingInputs>(mEndIdsDevice, mBatchSlots, step, ite, mBatchSize);
    decodeInputTensors->logits = mDecodingWorkspace->getDeviceRuntimeLogits();
    decodeInputTensors->inputLengths = mContextLengthDevice;
    decodeInputTensors->finished = mFinishedDevice;
    if (mBeamWidth > 1)
    {
        decodeInputTensors->srcCacheIndirection = mSrcCacheIndirection;
    }
    else
    {
        auto samplingInputTensors = std::dynamic_pointer_cast<SamplingInputs>(decodeInputTensors);
        samplingInputTensors->probsComputed = mComputeProbs;
        samplingInputTensors->curandStates
            = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice));
    }
    return decodeInputTensors;
}

template <typename T>
std::shared_ptr<BaseDecodingOutputs> BaseSamplingLayerTest<T>::createOutputTensors()
{
    // TODO: check log probs and cum_log_probs
    auto decodeOutputs = (mBeamWidth > 1) ? std::make_shared<BeamSearchOutputs>(mOutputIdsDevice)
                                          : std::make_shared<BaseDecodingOutputs>(mOutputIdsDevice);
    decodeOutputs->outputIdsPtr = mIdsPtrHost;
    decodeOutputs->outputIdsPtrHost = mIdsPtrHost;
    decodeOutputs->sequenceLength = mSeqLengthsDevice;
    decodeOutputs->finished = mFinishedDevice;
    decodeOutputs->outputLogProbs = mOutputLogProbsDevice;
    decodeOutputs->cumLogProbs = mCumLogProbsDevice;
    if (mBeamWidth > 1)
    {
        auto beamSearchOutputs = std::dynamic_pointer_cast<BeamSearchOutputs>(decodeOutputs);
        beamSearchOutputs->tgtCacheIndirection = mTgtCacheIndirection;
        beamSearchOutputs->parentIds = mParentIds;
        beamSearchOutputs->parentIdsPtr = mParentIdsPtr;
        beamSearchOutputs->beamHypotheses = std::make_unique<tk::BeamHypotheses>();
        beamSearchOutputs->beamHypotheses->outputIdsCBA = bufferCast<int32_t>(*mOutputIdsCBA);
        beamSearchOutputs->beamHypotheses->logProbsCBA = bufferCast<float>(*mLogProbsCBA);
        beamSearchOutputs->beamHypotheses->sequenceLengthsCBA = bufferCast<int32_t>(*mSequenceLengthsCBA);
        beamSearchOutputs->beamHypotheses->cumLogProbsCBA = bufferCast<float>(*mCumLogProbsCBA);
        beamSearchOutputs->beamHypotheses->normedScoresCBA = bufferCast<float>(*mNormedScoresCBA);
        beamSearchOutputs->beamHypotheses->numBeamsCBA = bufferCast<int32_t>(*mNumBeamsCBA);
        beamSearchOutputs->beamHypotheses->minNormedScoresCBA = bufferCast<float>(*mMinNormedScoresCBA);
        beamSearchOutputs->beamHypotheses->batchDones = bufferCast<bool>(*mBatchDones);
    }
    return decodeOutputs;
}

template <typename T>
void BaseSamplingLayerTest<T>::batchCopy(int32_t step)
{
    // Copy the reference logits for this step into every batch slot of the runtime logits buffer.
    auto const logitsHost = ITensor::wrap(mTestLogitsInit.data() + step * mBeamWidth * mVocabSize,
        TRTDataType<T>::value, ITensor::makeShape({mBeamWidth, mVocabSize}));
    for (int32_t bi = 0; bi < mBatchSize; ++bi)
    {
        auto logitsDeviceView = ITensor::slice(mDecodingWorkspace->getDeviceRuntimeLogits(), bi, 1);
        mBufferManager->copy(*logitsHost, *logitsDeviceView);
    }
}

template <typename T>
bool BaseSamplingLayerTest<T>::checkResult(int32_t const* outputIds, std::vector<std::set<int32_t>> const& expectedIds)
{
    assert(expectedIds.size() == mMaxSeqLen * batchBeam());
    int failures = 0;
    auto* const batchSlotsPtr = bufferCast<SizeType32>(*mBatchSlots);
    for (int32_t i = 0; i < mMaxSeqLen * mBatchSize; ++i)
    {
        int32_t s = i / mBatchSize;
        int32_t b = i % mBatchSize;
        auto const batchSlot = batchSlotsPtr[b];
        std::set<int32_t> expts = expectedIds.at(i);
        auto const outputId = outputIds[batchSlot * mMaxSeqLen + s];
        if (expts.count(outputId) == 0)
        {
            if (failures < 10)
            {
                std::stringstream ss;
                ss << " - Fail "
                   << " (step=" << s << ", batch=" << b << ") "
                   << "actual=" << outputId << ", expected";
                for (auto const& expt : expts)
                {
                    ss << " " << expt;
                }
                TLLM_LOG_DEBUG("%s", ss.str().c_str());
            }
            ++failures;
        }
    }
    TLLM_LOG_DEBUG(
        "check...%6s : failures: %d / %d", failures == 0 ? "....OK" : "FAILED", failures, mMaxSeqLen * batchBeam());
    return failures == 0;
}
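// Drives the full decode loop: per seed, (re)configure the layer, run decode steps from mMaxInputLen
// to mMaxOutputLen, and compare each generated token against the set of ids accepted at that step.
//
// Illustrative call from a derived fixture (the parameter values and expected-id sets below are
// hypothetical, not taken from this file):
//
//   TestSamplingParams params;
//   params.batchSize = 6;
//   params.topKs = {1};
//   this->runTest({{4}, {4}, {4}, {4}, {4}, {4}}, params);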
"....OK" : "FAILED", failures, mMaxSeqLen * batchBeam()); return failures == 0; } template void BaseSamplingLayerTest::runTest( std::vector> const& expectedOutputIds, TestSamplingParams const& params, int32_t endId) { mBatchSize = params.batchSize; if (params.beamWidth > 1) { mBeamWidth = params.beamWidth; mMaxSeed = 1; mComputeProbs = true; } initLayer(params); auto const decoderDomain = tensorrt_llm::layers::DecoderDomain(maxBatchSize(), mBeamWidth, mVocabSize, mVocabSizePadded); mDecodingWorkspace = std::make_unique( mBufferManager, decoderDomain, TRTDataType::value, mSamplingLayer->getWorkspaceSize()); mEndId = endId; for (uint64_t seed = 0; seed < mMaxSeed; ++seed) { setup(seed, params); int32_t step = mMaxInputLen; auto inputTensors = createInputTensors(step); auto outputTensors = createOutputTensors(); for (step = mMaxInputLen; step < mMaxOutputLen; ++step) { // Reset by the test value since the sampling layer internally updates the logit buffer. batchCopy(step); if (params.isExternalDraftTokensLayerTest) { inputTensors = createInputTensors(step); } else { inputTensors->step = step; } mDecodingWorkspace->setDeviceBatchSlots(mBatchSlots); mSamplingLayer->forwardAsync(outputTensors, inputTensors, mDecodingWorkspace); mStream->synchronize(); } auto const outputIdsHost = mBufferManager->copyFrom(*mOutputIdsDevice, tensorrt_llm::runtime::MemoryType::kCPU); mStream->synchronize(); bool passed = checkResult(bufferCast(*outputIdsHost), expectedOutputIds); EXPECT_TRUE(passed) << "Output ids check failed at seed " << seed; if (!passed) { std::stringstream ss; ss << "Actual output ids:" << std::endl << *outputIdsHost; TLLM_LOG_DEBUG(ss.str()); } } } template class BaseSamplingLayerTest; template class BaseSamplingLayerTest; } // namespace tensorrt_llm::tests::layers::sampling