/* * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/layers/defaultDecodingParams.h" #include "tensorrt_llm/runtime/common.h" #include #include #include #include namespace tensorrt_llm::runtime { class SamplingConfig { private: using FloatType = float; template using OptVec = std::optional>; template static OptVec fuseValues( std::vector const& configs, std::function(size_t ci)> accessor, T defaultValue) { std::vector values; bool atLeastOneHasValue{false}; for (size_t ci = 0; ci < configs.size(); ++ci) { auto const& configValue = accessor(ci); if (configValue.has_value()) { atLeastOneHasValue = true; break; } } if (atLeastOneHasValue) { for (size_t ci = 0; ci < configs.size(); ++ci) { auto value = defaultValue; auto const& configValue = accessor(ci); if (configValue.has_value()) { TLLM_CHECK(configValue.value().size() == 1); value = configValue.value().front(); } values.push_back(value); } return std::make_optional>(values); } else { return std::nullopt; } } template bool validateVec(std::string name, OptVec const& vec, T min, std::optional max = std::nullopt) { bool valid{true}; if (vec) { valid = std::all_of(vec->begin(), vec->end(), [min, max](T elem) { return min < elem && ((max.has_value() && elem <= max.value()) || (!max.has_value())); }); if (!valid) { std::stringstream ss; ss << "Incorrect sampling param. " << name << " is out of range ("; ss << min << ", "; if (max.has_value()) { ss << max.value(); } else { ss << "inf"; } ss << "]"; TLLM_LOG_WARNING(valid, ss.str()); } } return valid; } public: explicit SamplingConfig(SizeType32 beamWidth = 1) : beamWidth{beamWidth} { } explicit SamplingConfig(std::vector const& configs) { TLLM_CHECK(configs.size() > 0); beamWidth = configs.front().beamWidth; numReturnSequences = configs.front().numReturnSequences; normalizeLogProbs = configs.front().normalizeLogProbs; temperature = fuseValues( configs, [&configs](size_t ci) { return configs[ci].temperature; }, layers::DefaultDecodingParams::getTemperature()); originalTemperature = fuseValues( configs, [&configs](size_t ci) { return configs[ci].originalTemperature; }, layers::DefaultDecodingParams::getTemperature()); minLength = fuseValues( configs, [&configs](size_t ci) { return configs[ci].minLength; }, layers::DefaultDecodingParams::getMinLength()); repetitionPenalty = fuseValues( configs, [&configs](size_t ci) { return configs[ci].repetitionPenalty; }, layers::DefaultDecodingParams::getRepetitionPenalty()); presencePenalty = fuseValues( configs, [&configs](size_t ci) { return configs[ci].presencePenalty; }, layers::DefaultDecodingParams::getPresencePenalty()); frequencyPenalty = fuseValues( configs, [&configs](size_t ci) { return configs[ci].frequencyPenalty; }, layers::DefaultDecodingParams::getFrequencyPenalty()); promptIgnoreLength = fuseValues( configs, [&configs](size_t ci) { return configs[ci].promptIgnoreLength; }, layers::DefaultDecodingParams::getPromptIgnoreLength()); noRepeatNgramSize = fuseValues( configs, [&configs](size_t ci) { return configs[ci].noRepeatNgramSize; }, layers::DefaultDecodingParams::getNoRepeatNgramSize()); topK = fuseValues( configs, [&configs](size_t ci) { return configs[ci].topK; }, layers::DefaultDecodingParams::getTopK()); topP = fuseValues( configs, [&configs](size_t ci) { return configs[ci].topP; }, layers::DefaultDecodingParams::getTopP()); // Generate a random seed for each samplingConfig with randomSeed == std::nullopt randomSeed = std::vector(configs.size()); for (size_t ci = 0; ci < configs.size(); ++ci) { auto const& configValue = configs[ci].randomSeed; if (configValue) { TLLM_CHECK(configValue->size() == 1); randomSeed->at(ci) = configValue->front(); } else { randomSeed->at(ci) = layers::DefaultDecodingParams::generateRandomSeed(); } } topPDecay = fuseValues( configs, [&configs](size_t ci) { return configs[ci].topPDecay; }, layers::DefaultDecodingParams::getTopPDecay()); topPMin = fuseValues( configs, [&configs](size_t ci) { return configs[ci].topPMin; }, layers::DefaultDecodingParams::getTopPMin()); topPResetIds = fuseValues( configs, [&configs](size_t ci) { return configs[ci].topPResetIds; }, layers::DefaultDecodingParams::getTopPResetId()); beamSearchDiversityRate = fuseValues( configs, [&configs](size_t ci) { return configs[ci].beamSearchDiversityRate; }, layers::DefaultDecodingParams::getBeamSearchDiversity()); lengthPenalty = fuseValues( configs, [&configs](size_t ci) { return configs[ci].lengthPenalty; }, layers::DefaultDecodingParams::getLengthPenalty()); earlyStopping = fuseValues( configs, [&configs](size_t ci) { return configs[ci].earlyStopping; }, layers::DefaultDecodingParams::getEarlyStopping()); topKMedusaHeads = fuseValues>( configs, [&configs](size_t ci) { return configs[ci].topKMedusaHeads; }, layers::DefaultDecodingParams::getTopKMedusaHeads()); outputLogProbs = fuseValues( configs, [&configs](size_t ci) { return configs[ci].outputLogProbs; }, false); cumLogProbs = fuseValues( configs, [&configs](size_t ci) { return configs[ci].cumLogProbs; }, false); beamWidthArray = fuseValues>( configs, [&configs](size_t ci) { return configs[ci].beamWidthArray; }, layers::DefaultDecodingParams::getBeamWidthArray()); // Only used for tests. draftAcceptanceThreshold = fuseValues( configs, [&configs](size_t ci) { return configs[ci].draftAcceptanceThreshold; }, 0); minP = fuseValues( configs, [&configs](size_t ci) { return configs[ci].minP; }, layers::DefaultDecodingParams::getMinP()); } explicit SamplingConfig(executor::SamplingConfig const& samplingConfig, std::optional const& externalDraftTokensConfig = std::nullopt) : beamWidth{samplingConfig.getBeamWidth()} , numReturnSequences(samplingConfig.getNumReturnSequences()) { if (externalDraftTokensConfig && externalDraftTokensConfig.value().getAcceptanceThreshold()) { draftAcceptanceThreshold = std::vector{externalDraftTokensConfig.value().getAcceptanceThreshold().value()}; } #define SET_FROM_OPTIONAL(varName, VarName, VarType) \ \ if (samplingConfig.get##VarName()) \ { \ varName = std::vector{samplingConfig.get##VarName().value()}; \ } SET_FROM_OPTIONAL(topK, TopK, SizeType32) SET_FROM_OPTIONAL(topP, TopP, FloatType) SET_FROM_OPTIONAL(topPMin, TopPMin, FloatType) SET_FROM_OPTIONAL(topPResetIds, TopPResetIds, TokenIdType) SET_FROM_OPTIONAL(topPDecay, TopPDecay, FloatType) SET_FROM_OPTIONAL(randomSeed, Seed, uint64_t) SET_FROM_OPTIONAL(temperature, Temperature, FloatType) SET_FROM_OPTIONAL(originalTemperature, Temperature, FloatType) SET_FROM_OPTIONAL(minLength, MinTokens, SizeType32) SET_FROM_OPTIONAL(beamSearchDiversityRate, BeamSearchDiversityRate, FloatType) SET_FROM_OPTIONAL(repetitionPenalty, RepetitionPenalty, FloatType) SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType) SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType) SET_FROM_OPTIONAL(promptIgnoreLength, PromptIgnoreLength, SizeType32) SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType) SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType32) SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32) SET_FROM_OPTIONAL(minP, MinP, FloatType) SET_FROM_OPTIONAL(beamWidthArray, BeamWidthArray, std::vector) #undef SET_FROM_OPTIONAL } bool validate() { auto constexpr fltEpsilon = std::numeric_limits::epsilon(); bool valid{true}; valid &= (beamWidth > 0); if (!valid) { TLLM_LOG_WARNING( "Requested beam width %d is incorrect. Must be > 0. To de-activate beam searching set beamWidth to 1.", beamWidth); } if (numReturnSequences) { valid &= (numReturnSequences.value() > 0); if (!valid) { TLLM_LOG_WARNING( "Requested numReturnSequences %d is incorrect. Must be > 0.", numReturnSequences.value()); } valid &= (beamWidth == 1 || numReturnSequences.value() <= beamWidth); if (!valid) { TLLM_LOG_WARNING( "Requested numReturnSequences %d is incorrect. In beam search, numReturnSequences should not " "exceed the beam width %d.", numReturnSequences.value(), beamWidth); } } valid &= validateVec("topK", topK, -1); valid &= validateVec("topP", topP, -fltEpsilon, {1.f}); valid &= validateVec("topPMin", topPMin, 0.f, {1.f}); valid &= validateVec("topPResetIds", topPResetIds, -1); valid &= validateVec("topPDecay", topPDecay, 0.f, {1.f}); valid &= validateVec("temperature", temperature, -fltEpsilon); valid &= validateVec("minLength", minLength, -1); valid &= validateVec("beamSearchDiversityRate", beamSearchDiversityRate, -fltEpsilon); valid &= validateVec("repetitionPenalty", repetitionPenalty, 0.f); // TODO: checking `lengthPenalty`leads to a failure in // `test_openai_chat_example`, debug and re-enable it later. // valid &= validateVec("lengthPenalty", lengthPenalty, 0.f); valid &= validateVec("noRepeatNgramSize", noRepeatNgramSize, 0); valid &= validateVec("minP", minP, -fltEpsilon, {1.f}); // TODO: check `beamWidthArray` // Detect greedy sampling and overwrite params. if (temperature) { // Keep original temperature for Eagle. bool saveOriginalTemperature{false}; if (!originalTemperature) { saveOriginalTemperature = true; originalTemperature = std::vector(temperature->size()); } for (size_t ti = 0; ti < temperature->size(); ++ti) { if (temperature->at(ti) == 0.f) { if (saveOriginalTemperature) { originalTemperature->at(ti) = 0.f; } temperature->at(ti) = 1.0f; if (topK) { topK->at(ti) = 1; } if (topP) { topP->at(ti) = 1.f; } } else if (saveOriginalTemperature) { originalTemperature->at(ti) = temperature->at(ti); } } } return valid; } template bool useDefaultValues(OptVec const& vec, T defaultValue) { bool useDefault{true}; if (vec) { useDefault = std::all_of(vec->begin(), vec->end(), [defaultValue](T elem) { return elem == defaultValue; }); } return useDefault; } public: SizeType32 beamWidth; std::optional numReturnSequences; // penalties, [1] for one request, [batchSize] for one batch, the same for other parameters below OptVec temperature; // [1] or [batchSize] OptVec originalTemperature; // [1] or [batchSize] OptVec minLength; // [1] or [batchSize] OptVec repetitionPenalty; // [1] or [batchSize] OptVec presencePenalty; // [1] or [batchSize] OptVec frequencyPenalty; // [1] or [batchSize] OptVec promptIgnoreLength; // [1] or [batchSize] OptVec noRepeatNgramSize; // [1] or [batchSize] // probs OptVec outputLogProbs; OptVec cumLogProbs; // sampling layers OptVec topK; // [1] or [batchSize] OptVec topP; // [1] or [batchSize] OptVec randomSeed; // [1] or [batchSize] OptVec topPDecay; // [1] or [batchSize], between [0, 1] OptVec topPMin; // [1] or [batchSize], between [0, 1] OptVec topPResetIds; // [1] or [batchSize] OptVec minP; // [1] or [batchSize] // beam search layer OptVec beamSearchDiversityRate; // [1] or [batchSize] OptVec lengthPenalty; // [1] or [batchSize] OptVec earlyStopping; // [1] or [batchSize] OptVec> beamWidthArray; // [maxBeamWidthArrayLength] or [batchSize, maxBeamWidthArrayLength] // speculative decoding, only the first value is used (in gptDecoderBatched.cpp) OptVec draftAcceptanceThreshold; // [1] or [batchSize] // medusa params OptVec> topKMedusaHeads; // [batchSize, maxMedusaHeads] std::optional normalizeLogProbs; bool operator==(SamplingConfig const& other) const { return beamWidth == other.beamWidth && numReturnSequences == other.numReturnSequences && temperature == other.temperature && originalTemperature == other.originalTemperature && minLength == other.minLength && repetitionPenalty == other.repetitionPenalty && presencePenalty == other.presencePenalty && frequencyPenalty == other.frequencyPenalty && promptIgnoreLength == other.promptIgnoreLength && noRepeatNgramSize == other.noRepeatNgramSize && topK == other.topK && topP == other.topP && randomSeed == other.randomSeed && topPDecay == other.topPDecay && topPMin == other.topPMin && topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate && lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping && draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads && normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs && cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other.beamWidthArray; } SizeType32 getNumReturnBeams() const { if (numReturnSequences && beamWidth > 1) { return std::min(numReturnSequences.value(), beamWidth); } return beamWidth; } // Get the maximum beam width of a whole SamplingConfig SizeType32 getMaxBeamWidth() const noexcept { SizeType32 maxBeamWidth = this->beamWidth; // For non-Variable-Beam-Width-Search auto const& beamWidthArray = this->beamWidthArray; if (beamWidthArray.has_value()) { for (size_t indexSC = 0; indexSC < beamWidthArray->size(); ++indexSC) { auto const& array = beamWidthArray.value()[indexSC]; auto arrayMax = *std::max_element(array.begin(), array.end()); maxBeamWidth = std::max(maxBeamWidth, arrayMax); } } return maxBeamWidth; } }; } // namespace tensorrt_llm::runtime