[None][feat] Support ignored prompt length for penalties via new sampling config parameter (#8127)

Signed-off-by: Xuanyu Chen <xuanyuc@nvidia.com>
This commit is contained in:
nvxuanyuc 2025-10-27 10:12:31 -07:00 committed by GitHub
parent b9b2802599
commit d1398c05e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 538 additions and 196 deletions

View File

@ -71,6 +71,7 @@ public:
std::optional<FloatType> const& repetitionPenalty = std::nullopt,
std::optional<FloatType> const& presencePenalty = std::nullopt,
std::optional<FloatType> const& frequencyPenalty = std::nullopt,
std::optional<SizeType32> const& promptIgnoreLength = std::nullopt,
std::optional<FloatType> const& lengthPenalty = std::nullopt,
std::optional<SizeType32> const& earlyStopping = std::nullopt,
std::optional<SizeType32> const& noRepeatNgramSize = std::nullopt,
@ -94,6 +95,7 @@ public:
[[nodiscard]] std::optional<FloatType> getRepetitionPenalty() const;
[[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
[[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
[[nodiscard]] std::optional<SizeType32> getPromptIgnoreLength() const;
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
[[nodiscard]] std::optional<SizeType32> getEarlyStopping() const;
[[nodiscard]] std::optional<SizeType32> getNoRepeatNgramSize() const;
@ -114,6 +116,7 @@ public:
void setRepetitionPenalty(std::optional<FloatType> const& repetitionPenalty);
void setPresencePenalty(std::optional<FloatType> const& presencePenalty);
void setFrequencyPenalty(std::optional<FloatType> const& frequencyPenalty);
void setPromptIgnoreLength(std::optional<SizeType32> const& promptIgnoreLength);
void setLengthPenalty(std::optional<FloatType> const& lengthPenalty);
void setEarlyStopping(std::optional<SizeType32> const& earlyStopping);
void setNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
@ -133,6 +136,8 @@ private:
static std::optional<FloatType> const& checkBeamSearchDiversityRate(
std::optional<FloatType> const& beamSearchDiversityRate);
static std::optional<FloatType> const& checkRepetitionPenalty(std::optional<FloatType> const& repetitionpenalty);
static std::optional<SizeType32> const& checkPromptIgnoreLength(
std::optional<SizeType32> const& promptIgnoreLength);
static std::optional<FloatType> const& checkLengthPenalty(std::optional<FloatType> const& lengthPenalty);
static std::optional<SizeType32> const& checkEarlyStopping(std::optional<SizeType32> const& earlyStopping);
static std::optional<SizeType32> const& checkNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
@ -174,6 +179,9 @@ private:
/// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can
/// have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
std::optional<FloatType> mFrequencyPenalty;
/// @brief Controls how many leading tokens to ignore from the prompt for presence and frequency penalties. A value
/// of 0 has no effect; negative values fail validation. Values > input (prompt) length will be clamped. Default is 0.
std::optional<SizeType32> mPromptIgnoreLength;
/// @brief Controls how to penalize longer sequences in beam search. Default is 0.f
std::optional<FloatType> mLengthPenalty;
/// @brief Controls whether the generation process finishes once beamWidth sentences are generated (ends with

View File

@ -56,6 +56,11 @@ public:
return 1;
}
/// @brief Default number of leading prompt tokens excluded from presence/frequency penalties.
/// 0 means no tokens are ignored, i.e. the penalties consider the full prompt.
[[nodiscard]] __host__ __device__ static constexpr runtime::SizeType32 getPromptIgnoreLength()
{
return 0;
}
[[nodiscard]] __host__ __device__ static constexpr uint64_t getSeed()
{
return 0;

View File

@ -133,6 +133,9 @@ public:
frequencyPenalty = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].frequencyPenalty; },
layers::DefaultDecodingParams::getFrequencyPenalty());
promptIgnoreLength = fuseValues<SizeType32>(
configs, [&configs](size_t ci) { return configs[ci].promptIgnoreLength; },
layers::DefaultDecodingParams::getPromptIgnoreLength());
noRepeatNgramSize = fuseValues<SizeType32>(
configs, [&configs](size_t ci) { return configs[ci].noRepeatNgramSize; },
layers::DefaultDecodingParams::getNoRepeatNgramSize());
@ -224,6 +227,7 @@ public:
SET_FROM_OPTIONAL(repetitionPenalty, RepetitionPenalty, FloatType)
SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType)
SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType)
SET_FROM_OPTIONAL(promptIgnoreLength, PromptIgnoreLength, SizeType32)
SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType)
SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType32)
SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32)
@ -342,6 +346,7 @@ public:
OptVec<FloatType> repetitionPenalty; // [1] or [batchSize]
OptVec<FloatType> presencePenalty; // [1] or [batchSize]
OptVec<FloatType> frequencyPenalty; // [1] or [batchSize]
OptVec<SizeType32> promptIgnoreLength; // [1] or [batchSize]
OptVec<SizeType32> noRepeatNgramSize; // [1] or [batchSize]
// probs
@ -377,13 +382,14 @@ public:
&& temperature == other.temperature && originalTemperature == other.originalTemperature
&& minLength == other.minLength && repetitionPenalty == other.repetitionPenalty
&& presencePenalty == other.presencePenalty && frequencyPenalty == other.frequencyPenalty
&& noRepeatNgramSize == other.noRepeatNgramSize && topK == other.topK && topP == other.topP
&& randomSeed == other.randomSeed && topPDecay == other.topPDecay && topPMin == other.topPMin
&& topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate
&& lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
&& draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
&& normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs
&& cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other.beamWidthArray;
&& promptIgnoreLength == other.promptIgnoreLength && noRepeatNgramSize == other.noRepeatNgramSize
&& topK == other.topK && topP == other.topP && randomSeed == other.randomSeed
&& topPDecay == other.topPDecay && topPMin == other.topPMin && topPResetIds == other.topPResetIds
&& beamSearchDiversityRate == other.beamSearchDiversityRate && lengthPenalty == other.lengthPenalty
&& earlyStopping == other.earlyStopping && draftAcceptanceThreshold == other.draftAcceptanceThreshold
&& topKMedusaHeads == other.topKMedusaHeads && normalizeLogProbs == other.normalizeLogProbs
&& outputLogProbs == other.outputLogProbs && cumLogProbs == other.cumLogProbs && minP == other.minP
&& beamWidthArray == other.beamWidthArray;
}
SizeType32 getNumReturnBeams() const

View File

@ -34,9 +34,9 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
OptFloat const& topPMin, std::optional<TokenIdType> const& topPResetIds, OptFloat const& topPDecay,
std::optional<RandomSeedType> const& seed, OptFloat const& temperature, OptSize32 const& minTokens,
OptFloat const& beamSearchDiversityRate, OptFloat const& repetitionPenalty, OptFloat const& presencePenalty,
OptFloat const& frequencyPenalty, OptFloat const& lengthPenalty, OptSize32 const& earlyStopping,
OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences, OptFloat const& minP,
OptVec<SizeType32> const& beamWidthArray)
OptFloat const& frequencyPenalty, OptSize32 const& promptIgnoreLength, OptFloat const& lengthPenalty,
OptSize32 const& earlyStopping, OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences,
OptFloat const& minP, OptVec<SizeType32> const& beamWidthArray)
: mBeamWidth(checkBeamWidth(beamWidth))
, mTopK(checkTopK(topK))
, mTopP(checkTopP(topP))
@ -50,6 +50,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
, mRepetitionPenalty(checkRepetitionPenalty(repetitionPenalty))
, mPresencePenalty(presencePenalty)
, mFrequencyPenalty(frequencyPenalty)
, mPromptIgnoreLength(checkPromptIgnoreLength(promptIgnoreLength))
, mLengthPenalty(checkLengthPenalty(lengthPenalty))
, mEarlyStopping(checkEarlyStopping(earlyStopping))
, mNoRepeatNgramSize(checkNoRepeatNgramSize(noRepeatNgramSize))
@ -67,9 +68,10 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const
&& mTemperature == other.mTemperature && mMinTokens == other.mMinTokens
&& mBeamSearchDiversityRate == other.mBeamSearchDiversityRate && mRepetitionPenalty == other.mRepetitionPenalty
&& mPresencePenalty == other.mPresencePenalty && mFrequencyPenalty == other.mFrequencyPenalty
&& mLengthPenalty == other.mLengthPenalty && mEarlyStopping == other.mEarlyStopping
&& mNoRepeatNgramSize == other.mNoRepeatNgramSize && mNumReturnSequences == other.mNumReturnSequences
&& mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray;
&& mPromptIgnoreLength == other.mPromptIgnoreLength && mLengthPenalty == other.mLengthPenalty
&& mEarlyStopping == other.mEarlyStopping && mNoRepeatNgramSize == other.mNoRepeatNgramSize
&& mNumReturnSequences == other.mNumReturnSequences && mMinP == other.mMinP
&& mBeamWidthArray == other.mBeamWidthArray;
}
// Getters
@ -143,6 +145,11 @@ OptFloat SamplingConfig::getFrequencyPenalty() const
return mFrequencyPenalty;
}
/// @brief Returns the number of leading prompt tokens ignored by presence/frequency penalties, if set.
OptSize32 SamplingConfig::getPromptIgnoreLength() const
{
return mPromptIgnoreLength;
}
OptFloat SamplingConfig::getLengthPenalty() const
{
return mLengthPenalty;
@ -240,6 +247,11 @@ void SamplingConfig::setFrequencyPenalty(OptFloat const& frequencyPenalty)
mFrequencyPenalty = frequencyPenalty;
}
/// @brief Sets the prompt ignore length; the value is validated (non-negative) via checkPromptIgnoreLength
/// before being stored.
void SamplingConfig::setPromptIgnoreLength(OptSize32 const& promptIgnoreLength)
{
mPromptIgnoreLength = checkPromptIgnoreLength(promptIgnoreLength);
}
void SamplingConfig::setLengthPenalty(OptFloat const& lengthPenalty)
{
mLengthPenalty = lengthPenalty; // TODO: re-enable `checkLengthPenalty` later
@ -362,6 +374,15 @@ OptFloat const& SamplingConfig::checkRepetitionPenalty(OptFloat const& repetitio
return repetitionpenalty;
}
/// @brief Validates that promptIgnoreLength, when provided, is non-negative.
/// @param promptIgnoreLength Optional number of leading prompt tokens to exclude from penalties.
/// @return The validated optional, passed through unchanged.
OptSize32 const& SamplingConfig::checkPromptIgnoreLength(OptSize32 const& promptIgnoreLength)
{
if (promptIgnoreLength)
{
TLLM_CHECK(*promptIgnoreLength >= 0);
}
return promptIgnoreLength;
}
OptFloat const& SamplingConfig::checkLengthPenalty(OptFloat const& lengthPenalty)
{
if (lengthPenalty.has_value())

View File

@ -159,6 +159,7 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is)
auto repetitionPenalty = su::deserialize<std::optional<FloatType>>(is);
auto presencePenalty = su::deserialize<std::optional<FloatType>>(is);
auto frequencyPenalty = su::deserialize<std::optional<FloatType>>(is);
auto promptIgnoreLength = su::deserialize<std::optional<SizeType32>>(is);
auto lengthPenalty = su::deserialize<std::optional<FloatType>>(is);
auto earlyStopping = su::deserialize<std::optional<SizeType32>>(is);
auto noRepeatNgramSize = su::deserialize<std::optional<SizeType32>>(is);
@ -167,8 +168,8 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is)
auto beamWidthArray = su::deserialize<std::optional<std::vector<SizeType32>>>(is);
return SamplingConfig{beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature, minLength,
beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty, earlyStopping,
noRepeatNgramSize, numReturnSequences, minP, beamWidthArray};
beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, promptIgnoreLength,
lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray};
}
void Serialization::serialize(SamplingConfig const& config, std::ostream& os)
@ -186,6 +187,7 @@ void Serialization::serialize(SamplingConfig const& config, std::ostream& os)
su::serialize(config.mRepetitionPenalty, os);
su::serialize(config.mPresencePenalty, os);
su::serialize(config.mFrequencyPenalty, os);
su::serialize(config.mPromptIgnoreLength, os);
su::serialize(config.mLengthPenalty, os);
su::serialize(config.mEarlyStopping, os);
su::serialize(config.mNoRepeatNgramSize, os);
@ -210,6 +212,7 @@ size_t Serialization::serializedSize(SamplingConfig const& config)
totalSize += su::serializedSize(config.mRepetitionPenalty);
totalSize += su::serializedSize(config.mPresencePenalty);
totalSize += su::serializedSize(config.mFrequencyPenalty);
totalSize += su::serializedSize(config.mPromptIgnoreLength);
totalSize += su::serializedSize(config.mLengthPenalty);
totalSize += su::serializedSize(config.mEarlyStopping);
totalSize += su::serializedSize(config.mNoRepeatNgramSize);

View File

@ -39,10 +39,10 @@ template <typename T>
__global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, T const* biases,
TokenIdType* penaltyWorkspace, TokenIdType const* penaltyWorkspacePrev, float const* temperatures,
float const* repetitionPenalties, float const* presencePenalties, float const* frequencyPenalties,
SizeType32 maxSeqLen, SizeType32 vocabSize, SizeType32 vocabSizePadded, TokenIdType const** outputIdsPtr,
SizeType32 const** parentIdsPtr, SizeType32 const* inputLengths, SizeType32 const* sequenceLengths,
SizeType32 const* minLengths, TokenIdType const* endIds, SizeType32 const* batchSlots,
SizeType32 const* tokensPerStep, FinishedState const* finished)
SizeType32 const* promptIgnoreLengths, SizeType32 maxSeqLen, SizeType32 vocabSize, SizeType32 vocabSizePadded,
TokenIdType const** outputIdsPtr, SizeType32 const** parentIdsPtr, SizeType32 const* inputLengths,
SizeType32 const* sequenceLengths, SizeType32 const* minLengths, TokenIdType const* endIds,
SizeType32 const* batchSlots, SizeType32 const* tokensPerStep, FinishedState const* finished)
{
auto const beamWidth = static_cast<SizeType32>(gridDim.y);
auto const maxTokensPerStep = static_cast<SizeType32>(gridDim.z);
@ -73,6 +73,7 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
float presencePenalty{layers::DefaultDecodingParams::getPresencePenalty()};
float frequencyPenalty{layers::DefaultDecodingParams::getFrequencyPenalty()};
SizeType32 minLength{layers::DefaultDecodingParams::getMinLength()};
SizeType32 promptIgnoreLength{layers::DefaultDecodingParams::getPromptIgnoreLength()};
bool accumulateVocab{false};
bool hasTemperature{false};
bool hasMinLength{false};
@ -103,27 +104,42 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
minLength = minLengths[batchSlot];
hasMinLength |= (minLength > 0);
}
if (promptIgnoreLengths != nullptr)
{
promptIgnoreLength = min(promptIgnoreLengths[batchSlot], inputLen);
}
// Initialize or update the number of occurrences of tokens
if (accumulateVocab)
{
penaltyWorkspace += batchBeamStepIdx * vocabSize;
penaltyWorkspace += batchBeamStepIdx * 2 * vocabSize;
if (currentStep <= inputLen)
{ // Context phase
for (auto index = static_cast<SizeType32>(threadIdx.x); index < vocabSize;
for (auto index = static_cast<SizeType32>(threadIdx.x); index < 2 * vocabSize;
index += static_cast<SizeType32>(blockDim.x))
{
penaltyWorkspace[index] = 0;
}
__syncthreads();
for (auto step = static_cast<SizeType32>(threadIdx.x); step < inputLen;
for (auto step = static_cast<SizeType32>(threadIdx.x); step < promptIgnoreLength;
step += static_cast<SizeType32>(blockDim.x))
{
// All beams in the context phase are identical
auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
if (penaltyIndex < vocabSize)
{
atomicAdd(&penaltyWorkspace[penaltyIndex], 1);
penaltyWorkspace[penaltyIndex] = 1;
}
}
for (auto step = promptIgnoreLength + static_cast<SizeType32>(threadIdx.x); step < inputLen;
step += static_cast<SizeType32>(blockDim.x))
{
// All beams in the context phase are identical
auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
if (penaltyIndex < vocabSize)
{
atomicAdd(&penaltyWorkspace[penaltyIndex + vocabSize], 1);
}
}
}
@ -132,8 +148,9 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
if (beamWidth > 1)
{
auto parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
penaltyWorkspacePrev += ((batchIdx * beamWidth + parentBeam) * maxTokensPerStep + stepIdx) * vocabSize;
for (auto index = static_cast<SizeType32>(threadIdx.x); index < vocabSize;
penaltyWorkspacePrev
+= ((batchIdx * beamWidth + parentBeam) * maxTokensPerStep + stepIdx) * (2 * vocabSize);
for (auto index = static_cast<SizeType32>(threadIdx.x); index < 2 * vocabSize;
index += static_cast<SizeType32>(blockDim.x))
{
penaltyWorkspace[index] = penaltyWorkspacePrev[index];
@ -145,7 +162,7 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
if (penaltyIndex < vocabSize)
{
penaltyWorkspace[penaltyIndex] += 1;
penaltyWorkspace[penaltyIndex + vocabSize] += 1;
}
}
}
@ -174,14 +191,19 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
}
if (accumulateVocab)
{
SizeType32 numOccurences = penaltyWorkspace[index];
if (numOccurences > 0)
SizeType32 numOccurences = penaltyWorkspace[index + vocabSize];
SizeType32 ifPresenceInFullSeq = numOccurences | penaltyWorkspace[index];
if (ifPresenceInFullSeq > 0)
{
// Repetition
if (repetitionPenalties != nullptr)
{
logit = logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty;
}
}
if (numOccurences > 0)
{
// Presence
if (presencePenalties != nullptr)
{
@ -230,9 +252,10 @@ void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams<T> const& params)
dim3 grid(params.batchSize, params.beamWidth, params.maxTokensPerStep);
batchApplyPenalty<T><<<grid, block, 0, params.stream>>>(params.inputLogits, params.outputLogits, params.biases,
params.penaltyWorkspace, params.penaltyWorkspacePrev, params.temperatures, params.repetitionPenalties,
params.presencePenalties, params.frequencyPenalties, params.maxSeqLen, params.vocabSize, params.vocabSizePadded,
params.outputIdsPtr, params.parentIdsPtr, params.inputLengths, params.sequenceLengths, params.minLengths,
params.endIds, params.batchSlots, params.tokensPerStep, params.finished);
params.presencePenalties, params.frequencyPenalties, params.promptIgnoreLengths, params.maxSeqLen,
params.vocabSize, params.vocabSizePadded, params.outputIdsPtr, params.parentIdsPtr, params.inputLengths,
params.sequenceLengths, params.minLengths, params.endIds, params.batchSlots, params.tokensPerStep,
params.finished);
}
template void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams<float> const& params);

View File

@ -35,6 +35,7 @@ struct InvokeBatchApplyPenaltyParams
float const* repetitionPenalties;
float const* presencePenalties;
float const* frequencyPenalties;
runtime::SizeType32 const* promptIgnoreLengths;
runtime::SizeType32 batchSize;
runtime::SizeType32 beamWidth;
runtime::SizeType32 maxSeqLen;

View File

@ -29,11 +29,12 @@ namespace kernels
enum class DecodingPenaltyType
{
Temperature, // the temperature penalty
Repetition, // the repetition penalty
Presence, // the presence penalty
Frequency, // the frequency penalty
MinLength, // the min length penalty
Temperature, // the temperature penalty
Repetition, // the repetition penalty
Presence, // the presence penalty
Frequency, // the frequency penalty
MinLength, // the min length penalty
PromptIgnoreLength, // the prompt ignore length for presence/frequency penalty
};
inline std::pair<float, float> getLimitsPenalty(DecodingPenaltyType penaltyType)
@ -49,6 +50,7 @@ inline std::pair<float, float> getLimitsPenalty(DecodingPenaltyType penaltyType)
case DecodingPenaltyType::Presence: return std::make_pair(fltMin, fltMax);
case DecodingPenaltyType::Frequency: return std::make_pair(fltMin, fltMax);
case DecodingPenaltyType::MinLength: return std::make_pair(-fltEpsilon, fltMax);
case DecodingPenaltyType::PromptIgnoreLength: return std::make_pair(-fltEpsilon, fltMax);
}
TLLM_CHECK_WITH_INFO(false, "Unknown penalty type %d", static_cast<int32_t>(penaltyType));
return std::make_pair(fltMin, fltMax);

View File

@ -128,11 +128,12 @@ public:
class PenaltySetupParams : public BaseSetupParams
{
public:
OptVec<float> temperature; // [1] or [setupBatchSize]
OptVec<runtime::SizeType32> minLength; // [1] or [setupBatchSize]
OptVec<float> repetitionPenalty; // [1] or [setupBatchSize]
OptVec<float> presencePenalty; // [1] or [setupBatchSize]
OptVec<float> frequencyPenalty; // [1] or [setupBatchSize]
OptVec<float> temperature; // [1] or [setupBatchSize]
OptVec<runtime::SizeType32> minLength; // [1] or [setupBatchSize]
OptVec<float> repetitionPenalty; // [1] or [setupBatchSize]
OptVec<float> presencePenalty; // [1] or [setupBatchSize]
OptVec<float> frequencyPenalty; // [1] or [setupBatchSize]
OptVec<runtime::SizeType32> promptIgnoreLength; // [1] or [setupBatchSize]
};
// Ban words layer

View File

@ -83,7 +83,7 @@ void PenaltyLayer<T>::allocateWorkspace()
{
auto const workspaceSize = mDecoderDomain.getBatchSize() * mDecoderDomain.getMaxDecodingTokens()
* mConfiguredBeamWidth * mDecoderDomain.getVocabSize();
* mConfiguredBeamWidth * mDecoderDomain.getVocabSize() * 2;
mPenaltyWorkspaceDevice = mBufferManager->gpu(workspaceSize, nvinfer1::DataType::kINT32);
if (mDecodingMode.isBeamSearch())
@ -107,6 +107,7 @@ void PenaltyLayer<T>::allocateBuffer()
mPresencePenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
mFrequencyPenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
mMinLength = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);
mPromptIgnoreLength = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);
if (mDecodingMode.isUseTemperature())
{
@ -128,6 +129,10 @@ void PenaltyLayer<T>::allocateBuffer()
{
mMinLengthDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kINT32);
}
if (mDecodingMode.isUseOccurrencePenalty())
{
mPromptIgnoreLengthDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kINT32);
}
auto const logitsPtrDeviceDesc = std::make_pair(batchSizeShape, TRTDataType<T*>::value);
mWorkspaceSize = DecodingLayerWorkspace::calculateRequiredWorkspaceSize(logitsPtrDeviceDesc);
@ -169,6 +174,8 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorCo
bool const useFrequencyPenalty
= mDecodingMode.isUseFrequencyPenalty() && penaltyParams->frequencyPenalty.has_value();
bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams->minLength.has_value();
bool const usePromptIgnoreLength
= mDecodingMode.isUseOccurrencePenalty() && penaltyParams->promptIgnoreLength.has_value();
// FIXME: once one of the requests has some penalty, we will always have to compute it.
// To avoid that we need to scan through all active requests at each iteration.
mUseTemperature |= useTemperature;
@ -176,6 +183,7 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorCo
mUsePresencePenalty |= usePresencePenalty;
mUseFrequencyPenalty |= useFrequencyPenalty;
mUseMinLength |= useMinLength;
mUsePromptIgnoreLength |= usePromptIgnoreLength;
if (mUseTemperature)
{
@ -203,10 +211,16 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorCo
fillBuffers(penaltyParams->minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice,
batchSlots, getLimitsPenalty(DecodingPenaltyType::MinLength), "min length");
}
if (mUsePromptIgnoreLength)
{
fillBuffers(penaltyParams->promptIgnoreLength, DefaultDecodingParams::getPromptIgnoreLength(),
mPromptIgnoreLength, mPromptIgnoreLengthDevice, batchSlots,
getLimitsPenalty(DecodingPenaltyType::PromptIgnoreLength), "prompt ignore length");
}
// Reset penalty workspace
auto const workspaceSizePerBatch
= mDecoderDomain.getMaxDecodingTokens() * mConfiguredBeamWidth * mDecoderDomain.getVocabSize();
= mDecoderDomain.getMaxDecodingTokens() * mConfiguredBeamWidth * mDecoderDomain.getVocabSize() * 2;
for (SizeType32 bi = 0; bi < batchSize; ++bi)
{
auto batchSlot = runtime::bufferCast<runtime::SizeType32>(*batchSlots)[bi];
@ -287,6 +301,7 @@ void PenaltyLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& b
auto presencePenalties = GET_PENALTIES(PresencePenalty, float);
auto frequencyPenalties = GET_PENALTIES(FrequencyPenalty, float);
auto minLengths = GET_PENALTIES(MinLength, SizeType32);
auto promptIgnoreLengths = GET_PENALTIES(PromptIgnoreLength, SizeType32);
#undef GET_PENALTIES
@ -316,6 +331,7 @@ void PenaltyLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& b
penaltyParams.inputLengths = inputLengths;
penaltyParams.sequenceLengths = bufferCast<SizeType32>(*outputs->sequenceLength.value());
penaltyParams.minLengths = bufferCastOrNull<SizeType32>(minLengths);
penaltyParams.promptIgnoreLengths = bufferCastOrNull<SizeType32>(promptIgnoreLengths);
penaltyParams.endIds = bufferCast<TokenIdType>(*params->endIds);
penaltyParams.batchSlots = workspace->getDeviceBatchSlotsPtr();
penaltyParams.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();

View File

@ -67,18 +67,21 @@ private:
TensorPtr mPresencePenaltyDevice;
TensorPtr mFrequencyPenaltyDevice;
TensorPtr mMinLengthDevice;
TensorPtr mPromptIgnoreLengthDevice;
TensorPtr mTemperature;
TensorPtr mRepetitionPenalty;
TensorPtr mPresencePenalty;
TensorPtr mFrequencyPenalty;
TensorPtr mMinLength;
TensorPtr mPromptIgnoreLength;
bool mUseTemperature{false};
bool mUseRepetitionPenalty{false};
bool mUsePresencePenalty{false};
bool mUseFrequencyPenalty{false};
bool mUseMinLength{false};
bool mUsePromptIgnoreLength{false};
runtime::SizeType32 mCyclicStep{0};
runtime::SizeType32 mRuntimeMaxSeqLen{0};

View File

@ -370,14 +370,14 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> nb::tuple
{
return nb::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty,
config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed,
config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, config.lengthPenalty,
config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP,
config.beamWidthArray);
config.presencePenalty, config.frequencyPenalty, config.promptIgnoreLength, config.topK, config.topP,
config.randomSeed, config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate,
config.lengthPenalty, config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences,
config.minP, config.beamWidthArray);
};
auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t)
{
if (t.size() != 19)
if (t.size() != 20)
{
throw std::runtime_error("Invalid SamplingConfig state!");
}
@ -389,19 +389,20 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
config.repetitionPenalty = nb::cast<OptVec<float>>(t[3]);
config.presencePenalty = nb::cast<OptVec<float>>(t[4]);
config.frequencyPenalty = nb::cast<OptVec<float>>(t[5]);
config.topK = nb::cast<OptVec<SizeType32>>(t[6]);
config.topP = nb::cast<OptVec<float>>(t[7]);
config.randomSeed = nb::cast<OptVec<uint64_t>>(t[8]);
config.topPDecay = nb::cast<OptVec<float>>(t[9]);
config.topPMin = nb::cast<OptVec<float>>(t[10]);
config.topPResetIds = nb::cast<OptVec<TokenIdType>>(t[11]);
config.beamSearchDiversityRate = nb::cast<OptVec<float>>(t[12]);
config.lengthPenalty = nb::cast<OptVec<float>>(t[13]);
config.earlyStopping = nb::cast<OptVec<SizeType32>>(t[14]);
config.noRepeatNgramSize = nb::cast<OptVec<SizeType32>>(t[15]);
config.numReturnSequences = nb::cast<SizeType32>(t[16]);
config.minP = nb::cast<OptVec<float>>(t[17]);
config.beamWidthArray = nb::cast<OptVec<std::vector<SizeType32>>>(t[18]);
config.promptIgnoreLength = nb::cast<OptVec<SizeType32>>(t[6]);
config.topK = nb::cast<OptVec<SizeType32>>(t[7]);
config.topP = nb::cast<OptVec<float>>(t[8]);
config.randomSeed = nb::cast<OptVec<uint64_t>>(t[9]);
config.topPDecay = nb::cast<OptVec<float>>(t[10]);
config.topPMin = nb::cast<OptVec<float>>(t[11]);
config.topPResetIds = nb::cast<OptVec<TokenIdType>>(t[12]);
config.beamSearchDiversityRate = nb::cast<OptVec<float>>(t[13]);
config.lengthPenalty = nb::cast<OptVec<float>>(t[14]);
config.earlyStopping = nb::cast<OptVec<SizeType32>>(t[15]);
config.noRepeatNgramSize = nb::cast<OptVec<SizeType32>>(t[16]);
config.numReturnSequences = nb::cast<SizeType32>(t[17]);
config.minP = nb::cast<OptVec<float>>(t[18]);
config.beamWidthArray = nb::cast<OptVec<std::vector<SizeType32>>>(t[19]);
new (&self) tr::SamplingConfig(config);
};
@ -416,6 +417,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.def_rw("repetition_penalty", &tr::SamplingConfig::repetitionPenalty)
.def_rw("presence_penalty", &tr::SamplingConfig::presencePenalty)
.def_rw("frequency_penalty", &tr::SamplingConfig::frequencyPenalty)
.def_rw("prompt_ignore_length", &tr::SamplingConfig::promptIgnoreLength)
.def_rw("top_k", &tr::SamplingConfig::topK)
.def_rw("top_p", &tr::SamplingConfig::topP)
.def_rw("random_seed", &tr::SamplingConfig::randomSeed)

View File

@ -76,12 +76,12 @@ void initRequestBindings(nb::module_& m)
return nb::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(),
self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(),
self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(),
self.getFrequencyPenalty(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(),
self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(),
self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
};
auto samplingConfigSetstate = [](tle::SamplingConfig& samplingConfig, nb::tuple const& state)
{
if (state.size() != 19)
if (state.size() != 20)
{
throw std::runtime_error("Invalid SamplingConfig state!");
}
@ -98,12 +98,13 @@ void initRequestBindings(nb::module_& m)
nb::cast<std::optional<FloatType>>(state[10]), // RepetitionPenalty
nb::cast<std::optional<FloatType>>(state[11]), // PresencePenalty
nb::cast<std::optional<FloatType>>(state[12]), // FrequencyPenalty
nb::cast<std::optional<FloatType>>(state[13]), // LengthPenalty
nb::cast<std::optional<SizeType32>>(state[14]), // EarlyStopping
nb::cast<std::optional<SizeType32>>(state[15]), // NoRepeatNgramSize
nb::cast<std::optional<SizeType32>>(state[16]), // NumReturnSequences
nb::cast<std::optional<FloatType>>(state[17]), // MinP
nb::cast<std::optional<std::vector<SizeType32>>>(state[18]) // BeamWidthArray
nb::cast<std::optional<SizeType32>>(state[13]), // PromptIgnoreLength
nb::cast<std::optional<FloatType>>(state[14]), // LengthPenalty
nb::cast<std::optional<SizeType32>>(state[15]), // EarlyStopping
nb::cast<std::optional<SizeType32>>(state[16]), // NoRepeatNgramSize
nb::cast<std::optional<SizeType32>>(state[17]), // NumReturnSequences
nb::cast<std::optional<FloatType>>(state[18]), // MinP
nb::cast<std::optional<std::vector<SizeType32>>>(state[19]) // BeamWidthArray
);
};
nb::class_<tle::SamplingConfig>(m, "SamplingConfig")
@ -120,6 +121,7 @@ void initRequestBindings(nb::module_& m)
std::optional<tle::FloatType> const&, // repetitionPenalty
std::optional<tle::FloatType> const&, // presencePenalty
std::optional<tle::FloatType> const&, // frequencyPenalty
std::optional<tle::SizeType32> const&, // promptIgnoreLength
std::optional<tle::FloatType> const&, // lengthPenalty
std::optional<tle::SizeType32> const&, // earlyStopping
std::optional<tle::SizeType32> const&, // noRepeatNgramSize
@ -142,6 +144,7 @@ void initRequestBindings(nb::module_& m)
nb::arg("repetition_penalty") = nb::none(),
nb::arg("presence_penalty") = nb::none(),
nb::arg("frequency_penalty") = nb::none(),
nb::arg("prompt_ignore_length") = nb::none(),
nb::arg("length_penalty") = nb::none(),
nb::arg("early_stopping") = nb::none(),
nb::arg("no_repeat_ngram_size") = nb::none(),
@ -165,6 +168,8 @@ void initRequestBindings(nb::module_& m)
[](tle::SamplingConfig& self, std::optional<FloatType> v) { self.setPresencePenalty(v); })
.def_prop_rw(
"frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty)
.def_prop_rw("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength,
&tle::SamplingConfig::setPromptIgnoreLength)
.def_prop_rw("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty)
.def_prop_rw("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping)
.def_prop_rw("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize,

View File

@ -361,14 +361,14 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> py::tuple
{
return py::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty,
config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed,
config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, config.lengthPenalty,
config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP,
config.beamWidthArray);
config.presencePenalty, config.frequencyPenalty, config.promptIgnoreLength, config.topK, config.topP,
config.randomSeed, config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate,
config.lengthPenalty, config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences,
config.minP, config.beamWidthArray);
};
auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig
{
if (t.size() != 19)
if (t.size() != 20)
{
throw std::runtime_error("Invalid SamplingConfig state!");
}
@ -380,19 +380,20 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
config.repetitionPenalty = t[3].cast<OptVec<float>>();
config.presencePenalty = t[4].cast<OptVec<float>>();
config.frequencyPenalty = t[5].cast<OptVec<float>>();
config.topK = t[6].cast<OptVec<SizeType32>>();
config.topP = t[7].cast<OptVec<float>>();
config.randomSeed = t[8].cast<OptVec<uint64_t>>();
config.topPDecay = t[9].cast<OptVec<float>>();
config.topPMin = t[10].cast<OptVec<float>>();
config.topPResetIds = t[11].cast<OptVec<TokenIdType>>();
config.beamSearchDiversityRate = t[12].cast<OptVec<float>>();
config.lengthPenalty = t[13].cast<OptVec<float>>();
config.earlyStopping = t[14].cast<OptVec<SizeType32>>();
config.noRepeatNgramSize = t[15].cast<OptVec<SizeType32>>();
config.numReturnSequences = t[16].cast<SizeType32>();
config.minP = t[17].cast<OptVec<float>>();
config.beamWidthArray = t[18].cast<OptVec<std::vector<SizeType32>>>();
config.promptIgnoreLength = t[6].cast<OptVec<SizeType32>>();
config.topK = t[7].cast<OptVec<SizeType32>>();
config.topP = t[8].cast<OptVec<float>>();
config.randomSeed = t[9].cast<OptVec<uint64_t>>();
config.topPDecay = t[10].cast<OptVec<float>>();
config.topPMin = t[11].cast<OptVec<float>>();
config.topPResetIds = t[12].cast<OptVec<TokenIdType>>();
config.beamSearchDiversityRate = t[13].cast<OptVec<float>>();
config.lengthPenalty = t[14].cast<OptVec<float>>();
config.earlyStopping = t[15].cast<OptVec<SizeType32>>();
config.noRepeatNgramSize = t[16].cast<OptVec<SizeType32>>();
config.numReturnSequences = t[17].cast<SizeType32>();
config.minP = t[18].cast<OptVec<float>>();
config.beamWidthArray = t[19].cast<OptVec<std::vector<SizeType32>>>();
return config;
};
@ -407,6 +408,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("repetition_penalty", &tr::SamplingConfig::repetitionPenalty)
.def_readwrite("presence_penalty", &tr::SamplingConfig::presencePenalty)
.def_readwrite("frequency_penalty", &tr::SamplingConfig::frequencyPenalty)
.def_readwrite("prompt_ignore_length", &tr::SamplingConfig::promptIgnoreLength)
.def_readwrite("top_k", &tr::SamplingConfig::topK)
.def_readwrite("top_p", &tr::SamplingConfig::topP)
.def_readwrite("random_seed", &tr::SamplingConfig::randomSeed)

View File

@ -72,12 +72,12 @@ void initRequestBindings(pybind11::module_& m)
return py::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(),
self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(),
self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(),
self.getFrequencyPenalty(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(),
self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(),
self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
};
auto samplingConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 19)
if (state.size() != 20)
{
throw std::runtime_error("Invalid SamplingConfig state!");
}
@ -94,12 +94,13 @@ void initRequestBindings(pybind11::module_& m)
state[10].cast<std::optional<FloatType>>(), // RepetitionPenalty
state[11].cast<std::optional<FloatType>>(), // PresencePenalty
state[12].cast<std::optional<FloatType>>(), // FrequencyPenalty
state[13].cast<std::optional<FloatType>>(), // LengthPenalty
state[14].cast<std::optional<SizeType32>>(), // EarlyStopping
state[15].cast<std::optional<SizeType32>>(), // NoRepeatNgramSize
state[16].cast<std::optional<SizeType32>>(), // NumReturnSequences
state[17].cast<std::optional<FloatType>>(), // MinP
state[18].cast<std::optional<std::vector<SizeType32>>>() // BeamWidthArray
state[13].cast<std::optional<SizeType32>>(), // PromptIgnoreLength
state[14].cast<std::optional<FloatType>>(), // LengthPenalty
state[15].cast<std::optional<SizeType32>>(), // EarlyStopping
state[16].cast<std::optional<SizeType32>>(), // NoRepeatNgramSize
state[17].cast<std::optional<SizeType32>>(), // NumReturnSequences
state[18].cast<std::optional<FloatType>>(), // MinP
state[19].cast<std::optional<std::vector<SizeType32>>>() // BeamWidthArray
);
};
py::class_<tle::SamplingConfig>(m, "SamplingConfig")
@ -116,6 +117,7 @@ void initRequestBindings(pybind11::module_& m)
std::optional<tle::FloatType> const&, // repetitionPenalty
std::optional<tle::FloatType> const&, // presencePenalty
std::optional<tle::FloatType> const&, // frequencyPenalty
std::optional<tle::SizeType32> const&, // promptIgnoreLength
std::optional<tle::FloatType> const&, // lengthPenalty
std::optional<tle::SizeType32> const&, // earlyStopping
std::optional<tle::SizeType32> const&, // noRepeatNgramSize
@ -138,6 +140,7 @@ void initRequestBindings(pybind11::module_& m)
py::arg("repetition_penalty") = py::none(),
py::arg("presence_penalty") = py::none(),
py::arg("frequency_penalty") = py::none(),
py::arg("prompt_ignore_length") = py::none(),
py::arg("length_penalty") = py::none(),
py::arg("early_stopping") = py::none(),
py::arg("no_repeat_ngram_size") = py::none(),
@ -161,6 +164,8 @@ void initRequestBindings(pybind11::module_& m)
[](tle::SamplingConfig& self, std::optional<FloatType> v) { self.setPresencePenalty(v); })
.def_property(
"frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty)
.def_property("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength,
&tle::SamplingConfig::setPromptIgnoreLength)
.def_property("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty)
.def_property("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping)
.def_property("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize,

View File

@ -84,6 +84,7 @@ void GptDecoder<T>::disableLookahead(
penaltyParams->repetitionPenalty = mSamplingConfig.repetitionPenalty;
penaltyParams->presencePenalty = mSamplingConfig.presencePenalty;
penaltyParams->frequencyPenalty = mSamplingConfig.frequencyPenalty;
penaltyParams->promptIgnoreLength = mSamplingConfig.promptIgnoreLength;
penaltyParams->temperature = mSamplingConfig.temperature;
penaltyParams->minLength = mSamplingConfig.minLength;
@ -136,6 +137,7 @@ void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize
penaltyParams->repetitionPenalty = mSamplingConfig.repetitionPenalty;
penaltyParams->presencePenalty = mSamplingConfig.presencePenalty;
penaltyParams->frequencyPenalty = mSamplingConfig.frequencyPenalty;
penaltyParams->promptIgnoreLength = mSamplingConfig.promptIgnoreLength;
penaltyParams->temperature = mSamplingConfig.temperature;
penaltyParams->minLength = mSamplingConfig.minLength;

View File

@ -120,12 +120,12 @@ void FtDynamicDecode<T>::setup(size_t const batch_size, size_t const beam_width,
th::optional<th::Tensor> runtime_top_k_opt, th::optional<th::Tensor> runtime_top_p_opt,
th::optional<th::Tensor> temperature_opt, th::optional<th::Tensor> repetition_penalty_opt,
th::optional<th::Tensor> presence_penalty_opt, th::optional<th::Tensor> frequency_penalty_opt,
th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
th::optional<th::Tensor> early_stopping_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt,
th::optional<th::Tensor> no_repeat_ngram_size_opt, th::optional<th::Tensor> min_p_opt, bool output_log_probs,
bool cum_log_probs)
th::optional<th::Tensor> prompt_ignore_length_opt, th::optional<th::Tensor> min_length_opt,
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
th::optional<th::Tensor> top_p_reset_ids_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
th::optional<th::Tensor> min_p_opt, bool output_log_probs, bool cum_log_probs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
mBeamWidth = beam_width;
@ -137,6 +137,7 @@ void FtDynamicDecode<T>::setup(size_t const batch_size, size_t const beam_width,
safeInsert(repetition_penalty_opt, penaltyParams->repetitionPenalty);
safeInsert(presence_penalty_opt, penaltyParams->presencePenalty);
safeInsert(frequency_penalty_opt, penaltyParams->frequencyPenalty);
safeInsert(prompt_ignore_length_opt, penaltyParams->promptIgnoreLength);
safeInsert(min_length_opt, penaltyParams->minLength);
safeInsert(no_repeat_ngram_size_opt, banWordsParams->noRepeatNgramSize);
if (beam_width == 1)
@ -328,10 +329,10 @@ void DynamicDecodeOp::createInstance()
void DynamicDecodeOp::setup(int64_t const batchSize, int64_t const beamWidth, th::optional<th::Tensor> runtimeTopKOpt,
th::optional<th::Tensor> runtimeTopPOpt, th::optional<th::Tensor> temperatureOpt,
th::optional<th::Tensor> repetitionPenaltyOpt, th::optional<th::Tensor> presencePenaltyOpt,
th::optional<th::Tensor> frequencyPenaltyOpt, th::optional<th::Tensor> minLengthOpt,
th::optional<th::Tensor> lengthPenaltyOpt, th::optional<th::Tensor> earlyStoppingOpt,
th::optional<th::Tensor> beamSearchDiversityRateOpt, th::optional<th::Tensor> randomSeedOpt,
th::optional<th::Tensor> topPDecayOpt, th::optional<th::Tensor> topPMinOpt,
th::optional<th::Tensor> frequencyPenaltyOpt, th::optional<th::Tensor> promptIgnoreLengthOpt,
th::optional<th::Tensor> minLengthOpt, th::optional<th::Tensor> lengthPenaltyOpt,
th::optional<th::Tensor> earlyStoppingOpt, th::optional<th::Tensor> beamSearchDiversityRateOpt,
th::optional<th::Tensor> randomSeedOpt, th::optional<th::Tensor> topPDecayOpt, th::optional<th::Tensor> topPMinOpt,
th::optional<th::Tensor> topPResetIdsOpt, th::optional<th::Tensor> noRepeatNgramSizeOpt,
th::optional<th::Tensor> minPOpt, bool outputLogProbs, bool cumLogProbs)
{
@ -343,6 +344,7 @@ void DynamicDecodeOp::setup(int64_t const batchSize, int64_t const beamWidth, th
CHECK_OPTIONAL_CPU_INPUT(repetitionPenaltyOpt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(presencePenaltyOpt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(frequencyPenaltyOpt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(promptIgnoreLengthOpt, torch::kInt32);
CHECK_OPTIONAL_CPU_INPUT(minLengthOpt, torch::kInt32);
CHECK_OPTIONAL_CPU_INPUT(lengthPenaltyOpt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(earlyStoppingOpt, torch::kInt32);
@ -356,8 +358,9 @@ void DynamicDecodeOp::setup(int64_t const batchSize, int64_t const beamWidth, th
dynamicDecode_->setup(static_cast<tr::SizeType32>(batchSize), static_cast<tr::SizeType32>(beamWidth),
runtimeTopKOpt, runtimeTopPOpt, temperatureOpt, repetitionPenaltyOpt, presencePenaltyOpt, frequencyPenaltyOpt,
minLengthOpt, lengthPenaltyOpt, earlyStoppingOpt, beamSearchDiversityRateOpt, randomSeedOpt, topPDecayOpt,
topPMinOpt, topPResetIdsOpt, noRepeatNgramSizeOpt, minPOpt, outputLogProbs, cumLogProbs);
promptIgnoreLengthOpt, minLengthOpt, lengthPenaltyOpt, earlyStoppingOpt, beamSearchDiversityRateOpt,
randomSeedOpt, topPDecayOpt, topPMinOpt, topPResetIdsOpt, noRepeatNgramSizeOpt, minPOpt, outputLogProbs,
cumLogProbs);
}
th::Tensor DynamicDecodeOp::forward(

View File

@ -32,12 +32,13 @@ public:
virtual void setup(size_t const batch_size, size_t const beam_width, th::optional<th::Tensor> runtime_top_k_opt,
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
th::optional<th::Tensor> top_p_reset_ids_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
th::optional<th::Tensor> min_p_opt, bool output_log_probs, bool cum_log_probs)
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> prompt_ignore_length_opt,
th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
th::optional<th::Tensor> early_stopping_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt,
th::optional<th::Tensor> no_repeat_ngram_size_opt, th::optional<th::Tensor> min_p_opt, bool output_log_probs,
bool cum_log_probs)
= 0;
virtual void forward(th::Tensor const& logits, int const step, int const max_input_length,
@ -72,12 +73,13 @@ public:
void setup(size_t const batch_size, size_t const beam_width, th::optional<th::Tensor> runtime_top_k_opt,
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
th::optional<th::Tensor> top_p_reset_ids_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
th::optional<th::Tensor> min_p_opt, bool output_log_probs, bool cum_log_probs) override;
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> prompt_ignore_length_opt,
th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
th::optional<th::Tensor> early_stopping_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt,
th::optional<th::Tensor> no_repeat_ngram_size_opt, th::optional<th::Tensor> min_p_opt, bool output_log_probs,
bool cum_log_probs) override;
void forward(th::Tensor const& logits, int const step, int const max_input_length, int const max_attention_window,
int const sink_token_length, uint64_t const ite, int const local_batch_size, th::Tensor end_id,
@ -115,12 +117,13 @@ public:
void setup(int64_t const batch_size, int64_t const beam_width, th::optional<th::Tensor> runtime_top_k_opt,
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
th::optional<th::Tensor> top_p_reset_ids_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
th::optional<th::Tensor> min_p_opt, bool output_log_probs, bool cum_log_probs);
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> prompt_ignore_length_opt,
th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
th::optional<th::Tensor> early_stopping_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt,
th::optional<th::Tensor> no_repeat_ngram_size_opt, th::optional<th::Tensor> min_p_opt, bool output_log_probs,
bool cum_log_probs);
th::Tensor forward(th::Tensor const& logits, int64_t const step, int64_t const max_input_length,
int64_t const max_attention_window, int64_t const sink_token_length, int64_t const ite,

View File

@ -34,17 +34,18 @@ void test(bool const isTestValid, SizeType32 beamWidth = 1, std::optional<SizeTy
std::optional<RandomSeedType> randomSeed = no, std::optional<FloatType> temperature = no,
std::optional<SizeType32> minLength = no, std::optional<FloatType> beamSearchDiversityRate = no,
std::optional<FloatType> repetitionPenalty = no, std::optional<FloatType> presencePenalty = no,
std::optional<FloatType> frequencyPenalty = no, std::optional<FloatType> lengthPenalty = no,
std::optional<SizeType32> earlyStopping = no, std::optional<SizeType32> noRepeatNgramSize = no,
std::optional<SizeType32> numReturnSequences = no, std::optional<FloatType> minP = no,
std::optional<std::vector<SizeType32>> beamWidthArray = no)
std::optional<FloatType> frequencyPenalty = no, std::optional<SizeType32> promptIgnoreLength = no,
std::optional<FloatType> lengthPenalty = no, std::optional<SizeType32> earlyStopping = no,
std::optional<SizeType32> noRepeatNgramSize = no, std::optional<SizeType32> numReturnSequences = no,
std::optional<FloatType> minP = no, std::optional<std::vector<SizeType32>> beamWidthArray = no)
{
// 19 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray`
// 20 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray`
try
{
auto sc = SamplingConfig(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature,
minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty,
earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray);
minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty,
promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP,
beamWidthArray);
// Come here if `sc` is valid
if (!isTestValid)
@ -102,18 +103,20 @@ TEST(SamplingConfigTest, validInputs)
test(true, 1, no, no, no, no, no, no, no, no, no, no, 1.f);
// Frequency penalty
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, 1.f);
// Prompt ignore length
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1);
// Length penalty
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
// Early stopping
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
// Early stopping
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
// No repeat ngram size
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
// NumReturnSequences
test(true, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
test(true, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
// MinP
test(true, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f);
test(true, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f);
// BeamWidthArray
test(true, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
test(true, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
std::vector<SizeType32>{2, 3, 4, 5});
}
@ -156,32 +159,35 @@ TEST(SamplingConfigTest, invalidInputs)
// Skip presence penalty, frequency penalty, no test
// Neg length penalty
// Neg prompt ignore length
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, -1);
// Neg early stopping
// Neg length penalty
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, -1);
// Neg no repeat ngram size
// Neg early stopping
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1);
// Neg no repeat ngram size
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1);
// Neg or zero numReturnSequences
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0);
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0);
// numReturnSequences > beamWidth
test(false, 2, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 4);
test(false, 2, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 4);
// Neg minP
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f);
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f);
// Neg / Large minP
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f);
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, +2.f);
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f);
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, +2.f);
// BeamWidthArray with neg / large beamWidth
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
std::vector<SizeType32>{2, 3, 4, -1});
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
std::vector<SizeType32>{2, 3, 4, 65536});
}
@ -265,6 +271,12 @@ TEST(SamplingConfigTest, getterSetter)
sc.setFrequencyPenalty(0.5f);
EXPECT_EQ(sc.getFrequencyPenalty(), 0.5f);
}
// Prompt ignore length
{
auto sc = SamplingConfig();
sc.setPromptIgnoreLength(1);
EXPECT_EQ(sc.getPromptIgnoreLength(), 1);
}
// Length penalty
{
auto sc = SamplingConfig();

View File

@ -161,7 +161,7 @@ protected:
mLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mBatchSize}), ptrType);
mPenaltyWorkspaceDevice = mBufferManager->gpu(
ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep, mVocabSizePadded}), nvinfer1::DataType::kINT32);
ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep, mVocabSize * 2}), nvinfer1::DataType::kINT32);
mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
@ -259,8 +259,8 @@ public:
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T**>(bufferCast<int64_t>(*mLogitsPtrs)),
bufferCast<T>(*mOutLogitsDevice), bufferCast<T>(*mBiasDevice),
bufferCast<int32_t>(*mPenaltyWorkspaceDevice), nullptr, bufferCast<float>(*mTemperaturesDevice), nullptr,
nullptr, nullptr, mBatchSize, 1, 1, mVocabSize, mVocabSizePadded, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, bufferCast<int32_t>(*mBatchSlots), mMaxTokensPerStep,
nullptr, nullptr, nullptr, mBatchSize, 1, 1, mVocabSize, mVocabSizePadded, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, bufferCast<int32_t>(*mBatchSlots), mMaxTokensPerStep,
bufferCast<int32_t>(*mTokensPerStep), nullptr, mStream->get()};
tk::invokeBatchApplyPenalty(penaltyParams);
auto logitsOutHost = mBufferManager->copyFrom(*mOutLogitsDevice, MemoryType::kCPU);
@ -382,9 +382,11 @@ struct RepetitionPenaltyTestCase
TensorPtr repetitionPenalties;
TensorPtr presencePenalties;
TensorPtr frequencyPenalties;
TensorPtr promptIgnoreLengths;
int32_t repetitionPenaltiesSize;
int32_t presencePenaltiesSize;
int32_t frequencyPenaltiesSize;
int32_t promptIgnoreLengthsSize;
int32_t maxTokensPerStep{1};
RepetitionPenaltyTestCase& setBatchSize(int32_t bs)
@ -423,6 +425,12 @@ struct RepetitionPenaltyTestCase
return *this;
}
RepetitionPenaltyTestCase& setPromptIgnoreLengths(TensorPtr pil)
{
promptIgnoreLengths = pil;
return *this;
}
RepetitionPenaltyTestCase& setRepetitionPenaltiesSize(int32_t rps)
{
repetitionPenaltiesSize = rps;
@ -441,6 +449,12 @@ struct RepetitionPenaltyTestCase
return *this;
}
RepetitionPenaltyTestCase& setPromptIgnoreLengthsSize(int32_t pils)
{
promptIgnoreLengthsSize = pils;
return *this;
}
RepetitionPenaltyTestCase& setMaxTokensPerStep(int32_t ts)
{
maxTokensPerStep = ts;
@ -451,11 +465,12 @@ struct RepetitionPenaltyTestCase
{
return tc::fmtstr(
"RepetitionPenaltyTestCase[batch=%d, vocab=%d, maxInputLength=%d, "
"repetitionPenalties=%s, presencePenalties=%s, frequencyPenalties=%s]",
"repetitionPenalties=%s, presencePenalties=%s, frequencyPenalties=%s, promptIgnoreLengths=%s]",
batchSize, vocabSize, maxInputLength,
tc::arr2str(bufferCast<float>(*repetitionPenalties), repetitionPenaltiesSize).c_str(),
tc::arr2str(bufferCast<float>(*presencePenalties), presencePenaltiesSize).c_str(),
tc::arr2str(bufferCast<float>(*frequencyPenalties), frequencyPenaltiesSize).c_str());
tc::arr2str(bufferCast<float>(*frequencyPenalties), frequencyPenaltiesSize).c_str(),
tc::arr2str(bufferCast<int32_t>(*promptIgnoreLengths), promptIgnoreLengthsSize).c_str());
}
};
@ -499,6 +514,7 @@ protected:
TensorPtr mRepetitionPenaltiesDevice;
TensorPtr mPresencePenaltiesDevice;
TensorPtr mFrequencyPenaltiesDevice;
TensorPtr mPromptIgnoreLengthsDevice;
TensorPtr mBatchSlots;
void subsetup(RepetitionPenaltyTestCase param)
@ -525,7 +541,7 @@ protected:
mLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mBatchSize}), ptrType);
mPenaltyWorkspaceDevice = mBufferManager->gpu(
ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSize}), nvinfer1::DataType::kINT32);
ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSize * 2}), nvinfer1::DataType::kINT32);
mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
@ -588,24 +604,29 @@ protected:
ASSERT_EQ(param.repetitionPenaltiesSize, mMaxBatchSize) << "Invalid test configuration.";
ASSERT_EQ(param.presencePenaltiesSize, mMaxBatchSize) << "Invalid test configuration.";
ASSERT_EQ(param.frequencyPenaltiesSize, mMaxBatchSize) << "Invalid test configuration.";
ASSERT_EQ(param.promptIgnoreLengthsSize, mMaxBatchSize) << "Invalid test configuration.";
mRepetitionPenaltiesDevice
= mBufferManager->gpu(ITensor::makeShape({param.repetitionPenaltiesSize}), nvinfer1::DataType::kFLOAT);
mPresencePenaltiesDevice
= mBufferManager->gpu(ITensor::makeShape({param.presencePenaltiesSize}), nvinfer1::DataType::kFLOAT);
mFrequencyPenaltiesDevice
= mBufferManager->gpu(ITensor::makeShape({param.frequencyPenaltiesSize}), nvinfer1::DataType::kFLOAT);
mPromptIgnoreLengthsDevice
= mBufferManager->gpu(ITensor::makeShape({param.promptIgnoreLengthsSize}), nvinfer1::DataType::kINT32);
mBufferManager->copy(*param.repetitionPenalties, *mRepetitionPenaltiesDevice);
mBufferManager->copy(*param.presencePenalties, *mPresencePenaltiesDevice);
mBufferManager->copy(*param.frequencyPenalties, *mFrequencyPenaltiesDevice);
mBufferManager->copy(*param.promptIgnoreLengths, *mPromptIgnoreLengthsDevice);
}
void computeReference(T const* const inLogits, T* const outLogits, int32_t const* const outputIds,
int32_t const* const sequenceLengths, float const* const repetitionPenalties,
float const* const presencePenalties, float const* const frequencyPenalties,
int32_t const repetitionPenaltiesSize, int32_t const presencePenaltiesSize,
int32_t const frequencyPenaltiesSize)
int32_t const* const promptIgnoreLengths, int32_t const repetitionPenaltiesSize,
int32_t const presencePenaltiesSize, int32_t const frequencyPenaltiesSize,
int32_t const promptIgnoreLengthsSize)
{
std::vector<bool> penalized(mVocabSize);
std::vector<bool> repetitionPenalized(mVocabSize), presencePenalized(mVocabSize);
auto const batchSlotsPtr = bufferCast<int32_t>(*mBatchSlots);
auto const tokensPerStepPtr = bufferCast<int32_t>(*mTokensPerStep);
@ -633,21 +654,47 @@ protected:
float presencePenalty = presencePenaltiesSize > 1 ? presencePenalties[batchSlot] : presencePenalties[0];
float frequencyPenalty
= frequencyPenaltiesSize > 1 ? frequencyPenalties[batchSlot] : frequencyPenalties[0];
int32_t promptIgnoreLength
= promptIgnoreLengthsSize > 1 ? promptIgnoreLengths[batchSlot] : promptIgnoreLengths[0];
std::fill(penalized.begin(), penalized.end(), false);
std::fill(repetitionPenalized.begin(), repetitionPenalized.end(), false);
std::fill(presencePenalized.begin(), presencePenalized.end(), false);
size_t offset = (bi * mMaxTokensPerStep + ti) * mVocabSizePadded;
auto const step = sequenceLengths[batchSlot];
// clamping to the inputLength (set to same as sequenceLength)
promptIgnoreLength = std::min(promptIgnoreLength, step);
std::vector<int32_t> numOccurences(mVocabSize, 0);
for (int32_t t = 0; t < step; ++t)
{
auto tokenId = outputIds[batchSlot * mSequenceLength + t];
if (!penalized[tokenId])
if (!repetitionPenalized[tokenId])
{
auto logit = static_cast<float>(outLogits[offset + tokenId]);
outLogits[offset + tokenId] = static_cast<T>(
(logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty) - presencePenalty);
penalized[tokenId] = true;
outLogits[offset + tokenId]
= static_cast<T>((logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty));
repetitionPenalized[tokenId] = true;
}
if (!(t < promptIgnoreLength))
{
presencePenalized[tokenId] = true;
numOccurences[tokenId] += 1;
}
}
for (int32_t vi = 0; vi < mVocabSize; ++vi)
{
if (presencePenalized[vi])
{
outLogits[offset + vi] -= presencePenalty;
}
if (numOccurences[vi] > 0)
{
outLogits[offset + vi] -= numOccurences[vi] * frequencyPenalty;
}
outLogits[offset + tokenId] -= frequencyPenalty;
}
}
}
@ -661,7 +708,8 @@ public:
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T**>(bufferCast<int64_t>(*mLogitsPtrs)),
bufferCast<T>(*mOutLogitsDevice), nullptr, bufferCast<int32_t>(*mPenaltyWorkspaceDevice), nullptr, nullptr,
bufferCast<float>(*mRepetitionPenaltiesDevice), bufferCast<float>(*mPresencePenaltiesDevice),
bufferCast<float>(*mFrequencyPenaltiesDevice), mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded,
bufferCast<float>(*mFrequencyPenaltiesDevice), bufferCast<int32_t>(*mPromptIgnoreLengthsDevice), mBatchSize,
1, mSequenceLength, mVocabSize, mVocabSizePadded,
reinterpret_cast<int32_t const**>(bufferCast<int64_t>(*mIdsPtrDevice)), nullptr,
bufferCast<int32_t>(*mContextLengthDevice), bufferCast<int32_t>(*mSeqLengthDevice), nullptr, nullptr,
bufferCast<int32_t>(*mBatchSlots), mMaxTokensPerStep, bufferCast<int32_t>(*mTokensPerStep), nullptr,
@ -673,8 +721,9 @@ public:
computeReference(bufferCast<T>(*mLogitsHost), bufferCast<T>(*mLogitsRefHost),
bufferCast<int32_t>(*mOutputIdsHost), bufferCast<int32_t>(*mSeqLengthHost),
bufferCast<float>(*param.repetitionPenalties), bufferCast<float>(*param.presencePenalties),
bufferCast<float>(*param.frequencyPenalties), param.repetitionPenaltiesSize, param.presencePenaltiesSize,
param.frequencyPenaltiesSize);
bufferCast<float>(*param.frequencyPenalties), bufferCast<int32_t>(*param.promptIgnoreLengths),
param.repetitionPenaltiesSize, param.presencePenaltiesSize, param.frequencyPenaltiesSize,
param.promptIgnoreLengthsSize);
mStream->synchronize();
@ -696,11 +745,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchNoPenalty)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -709,9 +761,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchNoPenalty)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne)
@ -724,11 +778,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -737,9 +794,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne)
@ -752,11 +811,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 2.01f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -765,9 +827,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed)
@ -780,11 +844,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -793,9 +860,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed)
@ -808,11 +877,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -821,9 +893,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2)
@ -836,11 +910,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
bufferCast<float>(*presencePenaltyHost)[i] = i % 2 == 0 ? 1.0f : 0.0f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -849,9 +926,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed)
@ -864,11 +943,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -877,9 +959,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2)
@ -892,11 +976,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
bufferCast<float>(*frequencyPenaltyHost)[i] = i % 2 == 0 ? 1.0f : 0.0f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -905,9 +992,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence)
@ -920,11 +1009,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -933,9 +1025,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency)
@ -948,11 +1042,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -961,9 +1058,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency)
@ -976,11 +1075,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -989,9 +1091,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull)
@ -1004,11 +1108,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -1017,9 +1124,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize));
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize));
}
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep)
@ -1032,11 +1141,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep)
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr frequencyPenaltyHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
TensorPtr promptIgnoreLengthsHost
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
for (int32_t i = 0; i < maxBatchSize; ++i)
{
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
}
this->runTest(RepetitionPenaltyTestCase()
.setBatchSize(batchSize)
@ -1045,9 +1157,78 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep)
.setRepetitionPenalties(repetitionPenaltyHost)
.setPresencePenalties(presencePenaltyHost)
.setFrequencyPenalties(frequencyPenaltyHost)
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
.setRepetitionPenaltiesSize(maxBatchSize)
.setPresencePenaltiesSize(maxBatchSize)
.setFrequencyPenaltiesSize(maxBatchSize)
.setPromptIgnoreLengthsSize(maxBatchSize)
.setMaxTokensPerStep(4));
}
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullWithPartialPromptIgnore)
{
    // Exercises repetition, presence and frequency penalties together while a
    // promptIgnoreLength of 1 excludes the first prompt token from the
    // presence/frequency occurrence counting.
    int32_t const batchSize = 6;
    int32_t const maxBatchSize = 2 * batchSize;
    auto const perSlotShape = ITensor::makeShape({maxBatchSize});
    TensorPtr repetitionPenaltyHost = BufferManager::pinned(perSlotShape, nvinfer1::DataType::kFLOAT);
    TensorPtr presencePenaltyHost = BufferManager::pinned(perSlotShape, nvinfer1::DataType::kFLOAT);
    TensorPtr frequencyPenaltyHost = BufferManager::pinned(perSlotShape, nvinfer1::DataType::kFLOAT);
    TensorPtr promptIgnoreLengthsHost = BufferManager::pinned(perSlotShape, nvinfer1::DataType::kINT32);
    auto* repetitionValues = bufferCast<float>(*repetitionPenaltyHost);
    auto* presenceValues = bufferCast<float>(*presencePenaltyHost);
    auto* frequencyValues = bufferCast<float>(*frequencyPenaltyHost);
    auto* ignoreLengths = bufferCast<int32_t>(*promptIgnoreLengthsHost);
    for (int32_t slot = 0; slot < maxBatchSize; ++slot)
    {
        // Distinct per-slot penalties so each batch slot takes a different path.
        repetitionValues[slot] = 0.53 + slot * 0.2f;
        presenceValues[slot] = 0.53 + slot * 0.2f;
        frequencyValues[slot] = 0.53 + slot * 0.2f;
        ignoreLengths[slot] = 1; // set to 1 to ignore first prompt token
    }
    auto testCase = RepetitionPenaltyTestCase()
                        .setBatchSize(batchSize)
                        .setVocabSize(4)
                        .setMaxInputLength(5)
                        .setRepetitionPenalties(repetitionPenaltyHost)
                        .setPresencePenalties(presencePenaltyHost)
                        .setFrequencyPenalties(frequencyPenaltyHost)
                        .setPromptIgnoreLengths(promptIgnoreLengthsHost)
                        .setRepetitionPenaltiesSize(maxBatchSize)
                        .setPresencePenaltiesSize(maxBatchSize)
                        .setFrequencyPenaltiesSize(maxBatchSize)
                        .setPromptIgnoreLengthsSize(maxBatchSize);
    this->runTest(testCase);
}
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStepWithFullPromptIgnore)
{
    // Exercises all three penalties with multiple tokens per step while the
    // promptIgnoreLength equals the max input length, i.e. the whole prompt is
    // excluded from the presence/frequency occurrence counting.
    int32_t const batchSize = 6;
    int32_t const maxBatchSize = 2 * batchSize;
    auto const perSlotShape = ITensor::makeShape({maxBatchSize});
    TensorPtr repetitionPenaltyHost = BufferManager::pinned(perSlotShape, nvinfer1::DataType::kFLOAT);
    TensorPtr presencePenaltyHost = BufferManager::pinned(perSlotShape, nvinfer1::DataType::kFLOAT);
    TensorPtr frequencyPenaltyHost = BufferManager::pinned(perSlotShape, nvinfer1::DataType::kFLOAT);
    TensorPtr promptIgnoreLengthsHost = BufferManager::pinned(perSlotShape, nvinfer1::DataType::kINT32);
    auto* repetitionValues = bufferCast<float>(*repetitionPenaltyHost);
    auto* presenceValues = bufferCast<float>(*presencePenaltyHost);
    auto* frequencyValues = bufferCast<float>(*frequencyPenaltyHost);
    auto* ignoreLengths = bufferCast<int32_t>(*promptIgnoreLengthsHost);
    for (int32_t slot = 0; slot < maxBatchSize; ++slot)
    {
        // Distinct per-slot penalties so each batch slot takes a different path.
        repetitionValues[slot] = 0.53 + slot * 0.2f;
        presenceValues[slot] = 0.53 + slot * 0.2f;
        frequencyValues[slot] = 0.53 + slot * 0.2f;
        ignoreLengths[slot] = 5; // set to max input length to ignore full prompt
    }
    auto testCase = RepetitionPenaltyTestCase()
                        .setBatchSize(batchSize)
                        .setVocabSize(4)
                        .setMaxInputLength(5)
                        .setRepetitionPenalties(repetitionPenaltyHost)
                        .setPresencePenalties(presencePenaltyHost)
                        .setFrequencyPenalties(frequencyPenaltyHost)
                        .setPromptIgnoreLengths(promptIgnoreLengthsHost)
                        .setRepetitionPenaltiesSize(maxBatchSize)
                        .setPresencePenaltiesSize(maxBatchSize)
                        .setFrequencyPenaltiesSize(maxBatchSize)
                        .setPromptIgnoreLengthsSize(maxBatchSize)
                        .setMaxTokensPerStep(4);
    this->runTest(testCase);
}
@ -1257,8 +1438,8 @@ public:
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T**>(bufferCast<int64_t>(*mLogitsPtrs)),
bufferCast<T>(*mOutLogitsDevice), nullptr, bufferCast<int32_t>(*mPenaltyWorkspaceDevice), nullptr, nullptr,
nullptr, nullptr, nullptr, mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded, nullptr, nullptr,
bufferCast<int32_t>(*mContextLengthDevice), bufferCast<int32_t>(*mSeqLengthDevice),
nullptr, nullptr, nullptr, nullptr, mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded, nullptr,
nullptr, bufferCast<int32_t>(*mContextLengthDevice), bufferCast<int32_t>(*mSeqLengthDevice),
bufferCast<int32_t>(*mMinLengthDevice), bufferCast<int32_t>(*mEndIdsDevice),
bufferCast<int32_t>(*mBatchSlots), mMaxTokensPerStep, bufferCast<int32_t>(*mTokensPerStep), nullptr,
mStream->get()};
@ -1415,6 +1596,7 @@ public:
/*repetitionPenalties=*/nullptr,
/*presencePenalties=*/nullptr,
/*frequencyPenalties=*/nullptr,
/*promptIgnoreLengths=*/nullptr,
/*batchSize=*/mBatchSize,
/*beamWidth=*/1,
/*maxSeqLen=*/mSequenceLength,

View File

@ -40,6 +40,7 @@ struct TestSamplingParams
std::vector<float> repetitionPenalties;
std::vector<float> presencePenalties;
std::vector<float> frequencyPenalties;
std::vector<runtime::SizeType32> promptIgnoreLengths;
std::vector<runtime::SizeType32> minLengths;
std::vector<float> decay;
std::vector<float> minTopP;

View File

@ -37,17 +37,18 @@ void test(bool const useExternalDraftTokensConfig, SizeType32 beamWidth = 1, std
std::optional<RandomSeedType> randomSeed = no, std::optional<FloatType> temperature = no,
std::optional<SizeType32> minLength = no, std::optional<FloatType> beamSearchDiversityRate = no,
std::optional<FloatType> repetitionPenalty = no, std::optional<FloatType> presencePenalty = no,
std::optional<FloatType> frequencyPenalty = no, std::optional<FloatType> lengthPenalty = no,
std::optional<SizeType32> earlyStopping = no, std::optional<SizeType32> noRepeatNgramSize = no,
std::optional<SizeType32> numReturnSequences = no, std::optional<FloatType> minP = no,
std::optional<std::vector<SizeType32>> beamWidthArray = no)
std::optional<FloatType> frequencyPenalty = no, std::optional<SizeType32> promptIgnoreLength = no,
std::optional<FloatType> lengthPenalty = no, std::optional<SizeType32> earlyStopping = no,
std::optional<SizeType32> noRepeatNgramSize = no, std::optional<SizeType32> numReturnSequences = no,
std::optional<FloatType> minP = no, std::optional<std::vector<SizeType32>> beamWidthArray = no)
{
// 19 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray`
// 20 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray`
try
{
te::SamplingConfig execSamplingCfg(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed,
temperature, minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty,
lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray);
promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP,
beamWidthArray);
std::optional<te::ExternalDraftTokensConfig> specCfg = std::nullopt;
if (useExternalDraftTokensConfig)
{
@ -110,18 +111,20 @@ TEST(samplingConfigTest, validInputs)
test(false, 1, no, no, no, no, no, no, no, no, no, no, 1.f);
// Frequency penalty
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, 1.f);
// Prompt ignore length
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1);
// Length penalty
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
// Early stopping
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
// Early stopping
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
// No repeat ngram size
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
// NumReturnSequences
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
// MinP, 18 arguments
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f);
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
// MinP, 19 arguments
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f);
// BeamWidthArray
test(false, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
test(false, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
std::vector<SizeType32>{2, 3, 4, 5});
// All parameters
@ -139,6 +142,7 @@ TEST(samplingConfigTest, validInputs)
te::FloatType repetitionPenalty{0.5f};
te::FloatType presencePenalty{0.5f};
te::FloatType frequencyPenalty{0.5f};
te::SizeType32 promptIgnoreLength{1};
te::FloatType lengthPenalty{0.5f};
te::SizeType32 earlyStopping{1};
te::SizeType32 noRepeatNgramSize{5};
@ -148,7 +152,8 @@ TEST(samplingConfigTest, validInputs)
te::SamplingConfig execSamplingCfg(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed,
temperature, minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty,
lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray);
promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP,
beamWidthArray);
te::ExternalDraftTokensConfig specCfg({1}, no, 0.5f);
tr::SamplingConfig samplingCfg(execSamplingCfg, specCfg);
EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth());
@ -166,6 +171,7 @@ TEST(samplingConfigTest, validInputs)
EXPECT_THAT(samplingCfg.repetitionPenalty.value(), testing::ElementsAre(repetitionPenalty));
EXPECT_THAT(samplingCfg.presencePenalty.value(), testing::ElementsAre(presencePenalty));
EXPECT_THAT(samplingCfg.frequencyPenalty.value(), testing::ElementsAre(frequencyPenalty));
EXPECT_THAT(samplingCfg.promptIgnoreLength.value(), testing::ElementsAre(promptIgnoreLength));
EXPECT_THAT(samplingCfg.lengthPenalty.value(), testing::ElementsAre(lengthPenalty));
EXPECT_THAT(samplingCfg.earlyStopping.value(), testing::ElementsAre(earlyStopping));
EXPECT_THAT(samplingCfg.noRepeatNgramSize.value(), testing::ElementsAre(noRepeatNgramSize));

View File

@ -281,6 +281,7 @@ def main(args):
repetition_penalty=args.repetition_penalty,
presence_penalty=args.presence_penalty,
frequency_penalty=args.frequency_penalty,
prompt_ignore_length=args.prompt_ignore_length,
# stop_words_list=stop_words_list,
# bad_words_list=bad_words_list,
output_cum_log_probs=(args.output_cum_log_probs_npy != None),

View File

@ -540,6 +540,7 @@ def main(args):
repetition_penalty=args.repetition_penalty,
presence_penalty=args.presence_penalty,
frequency_penalty=args.frequency_penalty,
prompt_ignore_length=args.prompt_ignore_length,
min_p=args.min_p,
stop_words_list=stop_words_list,
bad_words_list=bad_words_list,
@ -639,6 +640,7 @@ def main(args):
repetition_penalty=args.repetition_penalty,
presence_penalty=args.presence_penalty,
frequency_penalty=args.frequency_penalty,
prompt_ignore_length=args.prompt_ignore_length,
min_p=args.min_p,
stop_words_list=stop_words_list,
bad_words_list=bad_words_list,
@ -677,6 +679,7 @@ def main(args):
repetition_penalty=args.repetition_penalty,
presence_penalty=args.presence_penalty,
frequency_penalty=args.frequency_penalty,
prompt_ignore_length=args.prompt_ignore_length,
stop_words_list=stop_words_list,
bad_words_list=bad_words_list,
output_cum_log_probs=(args.output_cum_log_probs_npy

View File

@ -211,6 +211,7 @@ def main(args):
repetition_penalty = args.repetition_penalty
presence_penalty = args.presence_penalty
frequency_penalty = args.frequency_penalty
prompt_ignore_length = args.prompt_ignore_length
random_seed = args.random_seed
torch.manual_seed(random_seed)
@ -353,6 +354,7 @@ def main(args):
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
prompt_ignore_length=prompt_ignore_length,
lora_uids=args.lora_task_uids,
lookahead_config=args.lookahead_config,
output_sequence_lengths=True,

View File

@ -304,6 +304,7 @@ def add_common_args(parser):
parser.add_argument('--repetition_penalty', type=float, default=1.0)
parser.add_argument('--presence_penalty', type=float, default=0.0)
parser.add_argument('--frequency_penalty', type=float, default=0.0)
parser.add_argument('--prompt_ignore_length', type=int, default=0)
parser.add_argument('--min_p', type=float, default=0.0)
parser.add_argument('--beam_search_diversity_rate', type=float, default=0.0)
parser.add_argument('--random_seed', type=int, default=0)

View File

@ -721,6 +721,7 @@ class SamplingConfig:
min_length: Union[int, torch.Tensor] = field(default=1)
presence_penalty: Union[float, torch.Tensor] = field(default=0.0)
frequency_penalty: Union[float, torch.Tensor] = field(default=0.0)
prompt_ignore_length: Union[int, torch.Tensor] = field(default=0)
use_beam_hyps: bool = field(default=True)
# None here means user didn't set it, and dynamicDecodeOp.cpp take optional value
@ -1474,6 +1475,16 @@ class GenerationSession(object):
scfg.frequency_penalty,
dtype=torch.float32)
if isinstance(scfg.prompt_ignore_length, torch.Tensor):
assert scfg.prompt_ignore_length.dtype == torch.int32, f"scfg.prompt_ignore_length.dtype ({scfg.prompt_ignore_length.dtype}) must be torch.int32"
assert scfg.prompt_ignore_length.shape[
0] == batch_size, f"scfg.prompt_ignore_length.shape[0] ({scfg.prompt_ignore_length.shape[0]}) must equal to batch_size ({batch_size})"
self.prompt_ignore_length = scfg.prompt_ignore_length
else:
self.prompt_ignore_length = torch.full([batch_size],
scfg.prompt_ignore_length,
dtype=torch.int32)
if isinstance(scfg.min_length, torch.Tensor):
assert scfg.min_length.dtype == torch.int32, f"scfg.min_length.dtype ({scfg.min_length.dtype}) must be torch.int32"
assert scfg.min_length.shape[
@ -1543,6 +1554,7 @@ class GenerationSession(object):
self.repetition_penalty,
self.presence_penalty,
self.frequency_penalty,
self.prompt_ignore_length,
self.min_length,
self.host_length_penalty,
self.host_early_stopping,

View File

@ -648,6 +648,7 @@ class ModelRunnerCpp(ModelRunnerMixin):
"repetition_penalty",
"presence_penalty",
"frequency_penalty",
"prompt_ignore_length",
"length_penalty",
"early_stopping",
"no_repeat_ngram_size",

View File

@ -165,6 +165,7 @@ class SamplingParams:
repetition_penalty (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None.
presence_penalty (float, optional): Used to penalize tokens already present in the sequence (irrespective of the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. None means using C++ runtime default 0.f. Defaults to None.
frequency_penalty (float, optional): Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. None means using C++ runtime default 0.f. Defaults to None.
prompt_ignore_length (int, optional): Controls how many tokens to ignore from the prompt for presence and frequency penalties. Values <= 0 have no effect. Values > input (prompt) length will be clamped. None means using C++ runtime default 0. Defaults to None.
length_penalty (float, optional): Controls how to penalize longer sequences in beam search. None means using C++ runtime default 0.f. Defaults to None.
early_stopping (int, optional): Controls whether the generation process finishes once beamWidth sentences are generated (ends with end_token). None means using C++ runtime default 1. Defaults to None.
no_repeat_ngram_size (int, optional): Controls how many repeat ngram size are acceptable. None means using C++ runtime default 1 << 30. Defaults to None.
@ -232,6 +233,7 @@ class SamplingParams:
repetition_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
prompt_ignore_length: Optional[int] = None
length_penalty: Optional[float] = None
early_stopping: Optional[int] = None
no_repeat_ngram_size: Optional[int] = None

View File

@ -229,6 +229,7 @@ class CompletionRequest(OpenAIBaseModel):
top_p: Optional[float] = None
user: Optional[str] = None
lora_request: Optional[LoRARequest] = None
prompt_ignore_length: Optional[int] = 0
# doc: begin-completion-sampling-params
use_beam_search: bool = False
@ -283,6 +284,7 @@ class CompletionRequest(OpenAIBaseModel):
temperature=(self.temperature
if self.temperature is not None else 1.0),
top_p=(self.top_p if self.top_p is not None else 1.0),
prompt_ignore_length=self.prompt_ignore_length,
# completion-sampling-params
use_beam_search=self.use_beam_search,
@ -530,6 +532,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
"reasoning is shown in the model's response. Options: "
"'low', 'medium', 'high'."),
)
prompt_ignore_length: Optional[int] = 0
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
@ -622,6 +625,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop=self.stop,
temperature=(self.temperature
if self.temperature is not None else 1.0),
prompt_ignore_length=self.prompt_ignore_length,
# chat-completion-sampling-params
best_of=self.best_of,

View File

@ -12,5 +12,8 @@ methods:
beam_width_array:
annotation: Optional[List[int]]
default: null
prompt_ignore_length:
annotation: Optional[int]
default: null
return_annotation: None
properties: {}

View File

@ -35,6 +35,7 @@ class TestSamplingParams(ApiStabilityTestHarness):
"repetition_penalty",
"presence_penalty",
"frequency_penalty",
"prompt_ignore_length",
"length_penalty",
"early_stopping",
"no_repeat_ngram_size",