/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License,
 * Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a
 * copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstddef>
#include <exception>
#include <optional>
#include <vector>

using ::testing::_;
using ::testing::Invoke;

namespace tr = tensorrt_llm::runtime;
namespace te = tensorrt_llm::executor;
using namespace tensorrt_llm::common;

using te::SizeType32;
using te::FloatType;
using te::TokenIdType;
using te::RandomSeedType;

// Shorthand for "argument not set"; keeps the long positional calls below readable.
static std::nullopt_t constexpr no = std::nullopt;

// Builds an executor-side SamplingConfig from the given arguments, converts it into a runtime-side
// SamplingConfig and checks that beamWidth, numReturnSequences and draftAcceptanceThreshold are carried over.
//
// useExternalDraftTokensConfig: when true, an ExternalDraftTokensConfig constructed with ({1}, no, 0.5f) is passed
//     to the conversion and draftAcceptanceThreshold is expected to contain exactly 0.5f; when false, no
//     speculative config is passed and draftAcceptanceThreshold is expected to be unset.
// beamWidth ... beamWidthArray: forwarded positionally, in this order, to the te::SamplingConfig constructor.
//
// Every caller passes inputs that are meant to be valid, so any exception is reported as a test failure.
void test(bool const useExternalDraftTokensConfig, SizeType32 beamWidth = 1, std::optional<SizeType32> topK = no,
    std::optional<FloatType> topP = no, std::optional<FloatType> topPMin = no,
    std::optional<TokenIdType> topPResetIds = no, std::optional<FloatType> topPDecay = no,
    std::optional<RandomSeedType> randomSeed = no, std::optional<FloatType> temperature = no,
    std::optional<SizeType32> minLength = no, std::optional<FloatType> beamSearchDiversityRate = no,
    std::optional<FloatType> repetitionPenalty = no, std::optional<FloatType> presencePenalty = no,
    std::optional<FloatType> frequencyPenalty = no, std::optional<SizeType32> promptIgnoreLength = no,
    std::optional<FloatType> lengthPenalty = no, std::optional<SizeType32> earlyStopping = no,
    std::optional<SizeType32> noRepeatNgramSize = no, std::optional<SizeType32> numReturnSequences = no,
    std::optional<FloatType> minP = no, std::optional<std::vector<SizeType32>> beamWidthArray = no)
{
    // 20 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray`
    try
    {
        te::SamplingConfig execSamplingCfg(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed,
            temperature, minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty,
            promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP,
            beamWidthArray);

        std::optional<te::ExternalDraftTokensConfig> specCfg = std::nullopt;
        if (useExternalDraftTokensConfig)
        {
            specCfg = te::ExternalDraftTokensConfig({1}, no, 0.5f);
        }

        tr::SamplingConfig samplingCfg(execSamplingCfg, specCfg);

        EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth());
        EXPECT_EQ(samplingCfg.numReturnSequences, execSamplingCfg.getNumReturnSequences());

        if (useExternalDraftTokensConfig)
        {
            // ASSERT (not EXPECT): calling value() on an empty optional would throw and hide the real failure.
            ASSERT_TRUE(samplingCfg.draftAcceptanceThreshold.has_value());
            EXPECT_THAT(samplingCfg.draftAcceptanceThreshold.value(), testing::ElementsAre(0.5f));
        }
        else
        {
            EXPECT_EQ(samplingCfg.draftAcceptanceThreshold, no);
        }
    }
    catch (TllmException const& e)
    {
        // The configuration was rejected by a TensorRT-LLM check even though the inputs should be valid.
        FAIL() << "Unexpected TllmException for a valid sampling config: " << e.what();
    }
    catch (std::exception const& e)
    {
        // Some other exception escaped while building or converting the configuration.
        FAIL() << "Unexpected exception for a valid sampling config: " << e.what();
    }
}

// Exercises each SamplingConfig parameter on its own, then all of them together.
TEST(samplingConfigTest, validInputs)
{
    // Auto
    test(false, 1);
    // Use ExternalDraftTokensConfig
    test(true, 1);
    // TopK
    test(false, 1, 2);
    // TopP
    test(false, 1, no, 0.5f);
    // TopPMin
    test(false, 1, no, no, 0.5f);
    // TopP reset ids
    test(false, 1, no, no, no, 0);
    // TopP decay
    test(false, 1, no, no, no, no, 0.5f);
    // Seed
    test(false, 1, no, no, no, no, no, 65536);
    // Temperature
    test(false, 1, no, no, no, no, no, no, 0.5f);
    // Min token
    test(false, 1, no, no, no, no, no, no, no, 64);
    // Beam diversity rate
    test(false, 2, no, no, no, no, no, no, no, no, 0.5f);
    // Repetition penalty
    test(false, 1, no, no, no, no, no, no, no, no, no, 1.f);
    // Presence penalty
    test(false, 1, no, no, no, no, no, no, no, no, no, no, 1.f);
    // Frequency penalty
    test(false, 1, no, no, no, no, no, no, no, no, no, no, no, 1.f);
    // Prompt ignore length
    test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1);
    // Length penalty
    test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
    // Early stopping (integer-valued, so an integer literal is passed)
    test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 1);
    // No repeat ngram size
    test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
    // NumReturnSequences
    test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
    // MinP, 19 arguments
    test(false, 1, no, 0.9f, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f);
    // BeamWidthArray
    test(false, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
        std::vector<SizeType32>{2, 3, 4, 5});

    // All parameters
    {
        te::SizeType32 beamWidth{5};
        te::SizeType32 topK{1};
        te::FloatType topP{0.5f};
        te::FloatType topPMin{0.5f};
        te::SizeType32 topPResetIds{1};
        te::FloatType topPDecay{0.5f};
        te::RandomSeedType randomSeed{65536};
        te::FloatType temperature{0.5f};
        te::SizeType32 minLength{64};
        te::FloatType beamSearchDiversityRate{0.5f};
        te::FloatType repetitionPenalty{0.5f};
        te::FloatType presencePenalty{0.5f};
        te::FloatType frequencyPenalty{0.5f};
        te::SizeType32 promptIgnoreLength{1};
        te::FloatType lengthPenalty{0.5f};
        te::SizeType32 earlyStopping{1};
        te::SizeType32 noRepeatNgramSize{5};
        te::SizeType32 numReturnSequences{1};
        te::FloatType minP{0.5f};
        std::vector<te::SizeType32> beamWidthArray{2, 3, 4, 5};

        te::SamplingConfig execSamplingCfg(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed,
            temperature, minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty,
            promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP,
            beamWidthArray);
        te::ExternalDraftTokensConfig specCfg({1}, no, 0.5f);
        tr::SamplingConfig samplingCfg(execSamplingCfg, specCfg);

        EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth());
        EXPECT_EQ(samplingCfg.numReturnSequences, execSamplingCfg.getNumReturnSequences());
        EXPECT_THAT(samplingCfg.draftAcceptanceThreshold.value(), testing::ElementsAre(0.5f));
        EXPECT_THAT(samplingCfg.topK.value(), testing::ElementsAre(topK));
        EXPECT_THAT(samplingCfg.topP.value(), testing::ElementsAre(topP));
        EXPECT_THAT(samplingCfg.topPMin.value(), testing::ElementsAre(topPMin));
        EXPECT_THAT(samplingCfg.topPResetIds.value(), testing::ElementsAre(topPResetIds));
        EXPECT_THAT(samplingCfg.topPDecay.value(), testing::ElementsAre(topPDecay));
        EXPECT_THAT(samplingCfg.randomSeed.value(), testing::ElementsAre(randomSeed));
        EXPECT_THAT(samplingCfg.temperature.value(), testing::ElementsAre(temperature));
        EXPECT_THAT(samplingCfg.minLength.value(), testing::ElementsAre(minLength));
        EXPECT_THAT(samplingCfg.beamSearchDiversityRate.value(), testing::ElementsAre(beamSearchDiversityRate));
        EXPECT_THAT(samplingCfg.repetitionPenalty.value(), testing::ElementsAre(repetitionPenalty));
        EXPECT_THAT(samplingCfg.presencePenalty.value(), testing::ElementsAre(presencePenalty));
        EXPECT_THAT(samplingCfg.frequencyPenalty.value(), testing::ElementsAre(frequencyPenalty));
        EXPECT_THAT(samplingCfg.promptIgnoreLength.value(), testing::ElementsAre(promptIgnoreLength));
        EXPECT_THAT(samplingCfg.lengthPenalty.value(), testing::ElementsAre(lengthPenalty));
        EXPECT_THAT(samplingCfg.earlyStopping.value(), testing::ElementsAre(earlyStopping));
        EXPECT_THAT(samplingCfg.noRepeatNgramSize.value(), testing::ElementsAre(noRepeatNgramSize));
        EXPECT_THAT(samplingCfg.minP.value(), testing::ElementsAre(minP));

        // The array is read back from the first entry of the runtime-side container. Guard every access so a
        // missing or mis-sized result is reported as a failure instead of reading out of bounds.
        ASSERT_TRUE(samplingCfg.beamWidthArray.has_value());
        ASSERT_FALSE(samplingCfg.beamWidthArray.value().empty());
        auto const& beamWidthArrayReturn = samplingCfg.beamWidthArray.value()[0];
        ASSERT_EQ(beamWidthArrayReturn.size(), beamWidthArray.size());
        for (std::size_t i = 0; i < beamWidthArray.size(); ++i)
        {
            EXPECT_EQ(beamWidthArrayReturn[i], beamWidthArray[i]);
        }
    }
}