mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Support ignored prompt length for penalties via new sampling config parameter (#8127)
Signed-off-by: Xuanyu Chen <xuanyuc@nvidia.com>
This commit is contained in:
parent
b9b2802599
commit
d1398c05e6
@ -71,6 +71,7 @@ public:
|
||||
std::optional<FloatType> const& repetitionPenalty = std::nullopt,
|
||||
std::optional<FloatType> const& presencePenalty = std::nullopt,
|
||||
std::optional<FloatType> const& frequencyPenalty = std::nullopt,
|
||||
std::optional<SizeType32> const& promptIgnoreLength = std::nullopt,
|
||||
std::optional<FloatType> const& lengthPenalty = std::nullopt,
|
||||
std::optional<SizeType32> const& earlyStopping = std::nullopt,
|
||||
std::optional<SizeType32> const& noRepeatNgramSize = std::nullopt,
|
||||
@ -94,6 +95,7 @@ public:
|
||||
[[nodiscard]] std::optional<FloatType> getRepetitionPenalty() const;
|
||||
[[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
|
||||
[[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
|
||||
[[nodiscard]] std::optional<SizeType32> getPromptIgnoreLength() const;
|
||||
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
|
||||
[[nodiscard]] std::optional<SizeType32> getEarlyStopping() const;
|
||||
[[nodiscard]] std::optional<SizeType32> getNoRepeatNgramSize() const;
|
||||
@ -114,6 +116,7 @@ public:
|
||||
void setRepetitionPenalty(std::optional<FloatType> const& repetitionPenalty);
|
||||
void setPresencePenalty(std::optional<FloatType> const& presencePenalty);
|
||||
void setFrequencyPenalty(std::optional<FloatType> const& frequencyPenalty);
|
||||
void setPromptIgnoreLength(std::optional<SizeType32> const& promptIgnoreLength);
|
||||
void setLengthPenalty(std::optional<FloatType> const& lengthPenalty);
|
||||
void setEarlyStopping(std::optional<SizeType32> const& earlyStopping);
|
||||
void setNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
|
||||
@ -133,6 +136,8 @@ private:
|
||||
static std::optional<FloatType> const& checkBeamSearchDiversityRate(
|
||||
std::optional<FloatType> const& beamSearchDiversityRate);
|
||||
static std::optional<FloatType> const& checkRepetitionPenalty(std::optional<FloatType> const& repetitionpenalty);
|
||||
static std::optional<SizeType32> const& checkPromptIgnoreLength(
|
||||
std::optional<SizeType32> const& promptIgnoreLength);
|
||||
static std::optional<FloatType> const& checkLengthPenalty(std::optional<FloatType> const& lengthPenalty);
|
||||
static std::optional<SizeType32> const& checkEarlyStopping(std::optional<SizeType32> const& earlyStopping);
|
||||
static std::optional<SizeType32> const& checkNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
|
||||
@ -174,6 +179,9 @@ private:
|
||||
/// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can
|
||||
/// have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
|
||||
std::optional<FloatType> mFrequencyPenalty;
|
||||
/// @brief Controls how many tokens to ignore from the prompt for presence and frequency penalties. A value of 0 has
|
||||
/// no effect; negative values are rejected. Values > input (prompt) length will be clamped. Default is 0.
|
||||
std::optional<SizeType32> mPromptIgnoreLength;
|
||||
/// @brief Controls how to penalize longer sequences in beam search. Default is 0.f
|
||||
std::optional<FloatType> mLengthPenalty;
|
||||
/// @brief Controls whether the generation process finishes once beamWidth sentences are generated (ends with
|
||||
|
||||
@ -56,6 +56,11 @@ public:
|
||||
return 1;
|
||||
}
|
||||
|
||||
/// @brief Default prompt-ignore length: always 0. Declared `__host__ __device__` and `constexpr`, so it can be
/// called from both host and device code and evaluated at compile time.
[[nodiscard]] __host__ __device__ static constexpr runtime::SizeType32 getPromptIgnoreLength()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
[[nodiscard]] __host__ __device__ static constexpr uint64_t getSeed()
|
||||
{
|
||||
return 0;
|
||||
|
||||
@ -133,6 +133,9 @@ public:
|
||||
frequencyPenalty = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].frequencyPenalty; },
|
||||
layers::DefaultDecodingParams::getFrequencyPenalty());
|
||||
promptIgnoreLength = fuseValues<SizeType32>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].promptIgnoreLength; },
|
||||
layers::DefaultDecodingParams::getPromptIgnoreLength());
|
||||
noRepeatNgramSize = fuseValues<SizeType32>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].noRepeatNgramSize; },
|
||||
layers::DefaultDecodingParams::getNoRepeatNgramSize());
|
||||
@ -224,6 +227,7 @@ public:
|
||||
SET_FROM_OPTIONAL(repetitionPenalty, RepetitionPenalty, FloatType)
|
||||
SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType)
|
||||
SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType)
|
||||
SET_FROM_OPTIONAL(promptIgnoreLength, PromptIgnoreLength, SizeType32)
|
||||
SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType)
|
||||
SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType32)
|
||||
SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32)
|
||||
@ -342,6 +346,7 @@ public:
|
||||
OptVec<FloatType> repetitionPenalty; // [1] or [batchSize]
|
||||
OptVec<FloatType> presencePenalty; // [1] or [batchSize]
|
||||
OptVec<FloatType> frequencyPenalty; // [1] or [batchSize]
|
||||
OptVec<SizeType32> promptIgnoreLength; // [1] or [batchSize]
|
||||
OptVec<SizeType32> noRepeatNgramSize; // [1] or [batchSize]
|
||||
|
||||
// probs
|
||||
@ -377,13 +382,14 @@ public:
|
||||
&& temperature == other.temperature && originalTemperature == other.originalTemperature
|
||||
&& minLength == other.minLength && repetitionPenalty == other.repetitionPenalty
|
||||
&& presencePenalty == other.presencePenalty && frequencyPenalty == other.frequencyPenalty
|
||||
&& noRepeatNgramSize == other.noRepeatNgramSize && topK == other.topK && topP == other.topP
|
||||
&& randomSeed == other.randomSeed && topPDecay == other.topPDecay && topPMin == other.topPMin
|
||||
&& topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate
|
||||
&& lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
|
||||
&& draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
|
||||
&& normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs
|
||||
&& cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other.beamWidthArray;
|
||||
&& promptIgnoreLength == other.promptIgnoreLength && noRepeatNgramSize == other.noRepeatNgramSize
|
||||
&& topK == other.topK && topP == other.topP && randomSeed == other.randomSeed
|
||||
&& topPDecay == other.topPDecay && topPMin == other.topPMin && topPResetIds == other.topPResetIds
|
||||
&& beamSearchDiversityRate == other.beamSearchDiversityRate && lengthPenalty == other.lengthPenalty
|
||||
&& earlyStopping == other.earlyStopping && draftAcceptanceThreshold == other.draftAcceptanceThreshold
|
||||
&& topKMedusaHeads == other.topKMedusaHeads && normalizeLogProbs == other.normalizeLogProbs
|
||||
&& outputLogProbs == other.outputLogProbs && cumLogProbs == other.cumLogProbs && minP == other.minP
|
||||
&& beamWidthArray == other.beamWidthArray;
|
||||
}
|
||||
|
||||
SizeType32 getNumReturnBeams() const
|
||||
|
||||
@ -34,9 +34,9 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
|
||||
OptFloat const& topPMin, std::optional<TokenIdType> const& topPResetIds, OptFloat const& topPDecay,
|
||||
std::optional<RandomSeedType> const& seed, OptFloat const& temperature, OptSize32 const& minTokens,
|
||||
OptFloat const& beamSearchDiversityRate, OptFloat const& repetitionPenalty, OptFloat const& presencePenalty,
|
||||
OptFloat const& frequencyPenalty, OptFloat const& lengthPenalty, OptSize32 const& earlyStopping,
|
||||
OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences, OptFloat const& minP,
|
||||
OptVec<SizeType32> const& beamWidthArray)
|
||||
OptFloat const& frequencyPenalty, OptSize32 const& promptIgnoreLength, OptFloat const& lengthPenalty,
|
||||
OptSize32 const& earlyStopping, OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences,
|
||||
OptFloat const& minP, OptVec<SizeType32> const& beamWidthArray)
|
||||
: mBeamWidth(checkBeamWidth(beamWidth))
|
||||
, mTopK(checkTopK(topK))
|
||||
, mTopP(checkTopP(topP))
|
||||
@ -50,6 +50,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
|
||||
, mRepetitionPenalty(checkRepetitionPenalty(repetitionPenalty))
|
||||
, mPresencePenalty(presencePenalty)
|
||||
, mFrequencyPenalty(frequencyPenalty)
|
||||
, mPromptIgnoreLength(checkPromptIgnoreLength(promptIgnoreLength))
|
||||
, mLengthPenalty(checkLengthPenalty(lengthPenalty))
|
||||
, mEarlyStopping(checkEarlyStopping(earlyStopping))
|
||||
, mNoRepeatNgramSize(checkNoRepeatNgramSize(noRepeatNgramSize))
|
||||
@ -67,9 +68,10 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const
|
||||
&& mTemperature == other.mTemperature && mMinTokens == other.mMinTokens
|
||||
&& mBeamSearchDiversityRate == other.mBeamSearchDiversityRate && mRepetitionPenalty == other.mRepetitionPenalty
|
||||
&& mPresencePenalty == other.mPresencePenalty && mFrequencyPenalty == other.mFrequencyPenalty
|
||||
&& mLengthPenalty == other.mLengthPenalty && mEarlyStopping == other.mEarlyStopping
|
||||
&& mNoRepeatNgramSize == other.mNoRepeatNgramSize && mNumReturnSequences == other.mNumReturnSequences
|
||||
&& mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray;
|
||||
&& mPromptIgnoreLength == other.mPromptIgnoreLength && mLengthPenalty == other.mLengthPenalty
|
||||
&& mEarlyStopping == other.mEarlyStopping && mNoRepeatNgramSize == other.mNoRepeatNgramSize
|
||||
&& mNumReturnSequences == other.mNumReturnSequences && mMinP == other.mMinP
|
||||
&& mBeamWidthArray == other.mBeamWidthArray;
|
||||
}
|
||||
|
||||
// Getters
|
||||
@ -143,6 +145,11 @@ OptFloat SamplingConfig::getFrequencyPenalty() const
|
||||
return mFrequencyPenalty;
|
||||
}
|
||||
|
||||
// Returns a copy of the stored prompt-ignore length (an optional; empty when no value is held).
OptSize32 SamplingConfig::getPromptIgnoreLength() const
|
||||
{
|
||||
return mPromptIgnoreLength;
|
||||
}
|
||||
|
||||
OptFloat SamplingConfig::getLengthPenalty() const
|
||||
{
|
||||
return mLengthPenalty;
|
||||
@ -240,6 +247,11 @@ void SamplingConfig::setFrequencyPenalty(OptFloat const& frequencyPenalty)
|
||||
mFrequencyPenalty = frequencyPenalty;
|
||||
}
|
||||
|
||||
// Stores the prompt-ignore length after passing it through checkPromptIgnoreLength, which asserts that a
// provided value is non-negative.
void SamplingConfig::setPromptIgnoreLength(OptSize32 const& promptIgnoreLength)
|
||||
{
|
||||
mPromptIgnoreLength = checkPromptIgnoreLength(promptIgnoreLength);
|
||||
}
|
||||
|
||||
void SamplingConfig::setLengthPenalty(OptFloat const& lengthPenalty)
|
||||
{
|
||||
mLengthPenalty = lengthPenalty; // TODO: re-enable `checkLengthPenalty` later
|
||||
@ -362,6 +374,15 @@ OptFloat const& SamplingConfig::checkRepetitionPenalty(OptFloat const& repetitio
|
||||
return repetitionpenalty;
|
||||
}
|
||||
|
||||
// Validates an optional prompt-ignore length and hands the same reference back so the call can be used
// inline (e.g. in a member initializer). An empty optional is accepted without any check.
OptSize32 const& SamplingConfig::checkPromptIgnoreLength(OptSize32 const& promptIgnoreLength)
|
||||
{
|
||||
if (promptIgnoreLength.has_value())
|
||||
{
|
||||
// Only non-negative lengths are allowed; 0 is valid.
TLLM_CHECK(promptIgnoreLength.value() >= 0);
|
||||
}
|
||||
return promptIgnoreLength;
|
||||
}
|
||||
|
||||
OptFloat const& SamplingConfig::checkLengthPenalty(OptFloat const& lengthPenalty)
|
||||
{
|
||||
if (lengthPenalty.has_value())
|
||||
|
||||
@ -159,6 +159,7 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is)
|
||||
auto repetitionPenalty = su::deserialize<std::optional<FloatType>>(is);
|
||||
auto presencePenalty = su::deserialize<std::optional<FloatType>>(is);
|
||||
auto frequencyPenalty = su::deserialize<std::optional<FloatType>>(is);
|
||||
auto promptIgnoreLength = su::deserialize<std::optional<SizeType32>>(is);
|
||||
auto lengthPenalty = su::deserialize<std::optional<FloatType>>(is);
|
||||
auto earlyStopping = su::deserialize<std::optional<SizeType32>>(is);
|
||||
auto noRepeatNgramSize = su::deserialize<std::optional<SizeType32>>(is);
|
||||
@ -167,8 +168,8 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is)
|
||||
auto beamWidthArray = su::deserialize<std::optional<std::vector<SizeType32>>>(is);
|
||||
|
||||
return SamplingConfig{beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature, minLength,
|
||||
beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty, earlyStopping,
|
||||
noRepeatNgramSize, numReturnSequences, minP, beamWidthArray};
|
||||
beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, promptIgnoreLength,
|
||||
lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray};
|
||||
}
|
||||
|
||||
void Serialization::serialize(SamplingConfig const& config, std::ostream& os)
|
||||
@ -186,6 +187,7 @@ void Serialization::serialize(SamplingConfig const& config, std::ostream& os)
|
||||
su::serialize(config.mRepetitionPenalty, os);
|
||||
su::serialize(config.mPresencePenalty, os);
|
||||
su::serialize(config.mFrequencyPenalty, os);
|
||||
su::serialize(config.mPromptIgnoreLength, os);
|
||||
su::serialize(config.mLengthPenalty, os);
|
||||
su::serialize(config.mEarlyStopping, os);
|
||||
su::serialize(config.mNoRepeatNgramSize, os);
|
||||
@ -210,6 +212,7 @@ size_t Serialization::serializedSize(SamplingConfig const& config)
|
||||
totalSize += su::serializedSize(config.mRepetitionPenalty);
|
||||
totalSize += su::serializedSize(config.mPresencePenalty);
|
||||
totalSize += su::serializedSize(config.mFrequencyPenalty);
|
||||
totalSize += su::serializedSize(config.mPromptIgnoreLength);
|
||||
totalSize += su::serializedSize(config.mLengthPenalty);
|
||||
totalSize += su::serializedSize(config.mEarlyStopping);
|
||||
totalSize += su::serializedSize(config.mNoRepeatNgramSize);
|
||||
|
||||
@ -39,10 +39,10 @@ template <typename T>
|
||||
__global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, T const* biases,
|
||||
TokenIdType* penaltyWorkspace, TokenIdType const* penaltyWorkspacePrev, float const* temperatures,
|
||||
float const* repetitionPenalties, float const* presencePenalties, float const* frequencyPenalties,
|
||||
SizeType32 maxSeqLen, SizeType32 vocabSize, SizeType32 vocabSizePadded, TokenIdType const** outputIdsPtr,
|
||||
SizeType32 const** parentIdsPtr, SizeType32 const* inputLengths, SizeType32 const* sequenceLengths,
|
||||
SizeType32 const* minLengths, TokenIdType const* endIds, SizeType32 const* batchSlots,
|
||||
SizeType32 const* tokensPerStep, FinishedState const* finished)
|
||||
SizeType32 const* promptIgnoreLengths, SizeType32 maxSeqLen, SizeType32 vocabSize, SizeType32 vocabSizePadded,
|
||||
TokenIdType const** outputIdsPtr, SizeType32 const** parentIdsPtr, SizeType32 const* inputLengths,
|
||||
SizeType32 const* sequenceLengths, SizeType32 const* minLengths, TokenIdType const* endIds,
|
||||
SizeType32 const* batchSlots, SizeType32 const* tokensPerStep, FinishedState const* finished)
|
||||
{
|
||||
auto const beamWidth = static_cast<SizeType32>(gridDim.y);
|
||||
auto const maxTokensPerStep = static_cast<SizeType32>(gridDim.z);
|
||||
@ -73,6 +73,7 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
|
||||
float presencePenalty{layers::DefaultDecodingParams::getPresencePenalty()};
|
||||
float frequencyPenalty{layers::DefaultDecodingParams::getFrequencyPenalty()};
|
||||
SizeType32 minLength{layers::DefaultDecodingParams::getMinLength()};
|
||||
SizeType32 promptIgnoreLength{layers::DefaultDecodingParams::getPromptIgnoreLength()};
|
||||
bool accumulateVocab{false};
|
||||
bool hasTemperature{false};
|
||||
bool hasMinLength{false};
|
||||
@ -103,27 +104,42 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
|
||||
minLength = minLengths[batchSlot];
|
||||
hasMinLength |= (minLength > 0);
|
||||
}
|
||||
if (promptIgnoreLengths != nullptr)
|
||||
{
|
||||
promptIgnoreLength = min(promptIgnoreLengths[batchSlot], inputLen);
|
||||
}
|
||||
|
||||
// Initialize or update the number of occurrences of tokens
|
||||
if (accumulateVocab)
|
||||
{
|
||||
penaltyWorkspace += batchBeamStepIdx * vocabSize;
|
||||
penaltyWorkspace += batchBeamStepIdx * 2 * vocabSize;
|
||||
if (currentStep <= inputLen)
|
||||
{ // Context phase
|
||||
for (auto index = static_cast<SizeType32>(threadIdx.x); index < vocabSize;
|
||||
for (auto index = static_cast<SizeType32>(threadIdx.x); index < 2 * vocabSize;
|
||||
index += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
penaltyWorkspace[index] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
for (auto step = static_cast<SizeType32>(threadIdx.x); step < inputLen;
|
||||
for (auto step = static_cast<SizeType32>(threadIdx.x); step < promptIgnoreLength;
|
||||
step += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
// All beams in the context phase are identical
|
||||
auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
|
||||
if (penaltyIndex < vocabSize)
|
||||
{
|
||||
atomicAdd(&penaltyWorkspace[penaltyIndex], 1);
|
||||
penaltyWorkspace[penaltyIndex] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto step = promptIgnoreLength + static_cast<SizeType32>(threadIdx.x); step < inputLen;
|
||||
step += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
// All beams in the context phase are identical
|
||||
auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
|
||||
if (penaltyIndex < vocabSize)
|
||||
{
|
||||
atomicAdd(&penaltyWorkspace[penaltyIndex + vocabSize], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -132,8 +148,9 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
|
||||
if (beamWidth > 1)
|
||||
{
|
||||
auto parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
|
||||
penaltyWorkspacePrev += ((batchIdx * beamWidth + parentBeam) * maxTokensPerStep + stepIdx) * vocabSize;
|
||||
for (auto index = static_cast<SizeType32>(threadIdx.x); index < vocabSize;
|
||||
penaltyWorkspacePrev
|
||||
+= ((batchIdx * beamWidth + parentBeam) * maxTokensPerStep + stepIdx) * (2 * vocabSize);
|
||||
for (auto index = static_cast<SizeType32>(threadIdx.x); index < 2 * vocabSize;
|
||||
index += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
penaltyWorkspace[index] = penaltyWorkspacePrev[index];
|
||||
@ -145,7 +162,7 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
|
||||
auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
|
||||
if (penaltyIndex < vocabSize)
|
||||
{
|
||||
penaltyWorkspace[penaltyIndex] += 1;
|
||||
penaltyWorkspace[penaltyIndex + vocabSize] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -174,14 +191,19 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
|
||||
}
|
||||
if (accumulateVocab)
|
||||
{
|
||||
SizeType32 numOccurences = penaltyWorkspace[index];
|
||||
if (numOccurences > 0)
|
||||
SizeType32 numOccurences = penaltyWorkspace[index + vocabSize];
|
||||
SizeType32 ifPresenceInFullSeq = numOccurences | penaltyWorkspace[index];
|
||||
if (ifPresenceInFullSeq > 0)
|
||||
{
|
||||
// Repetition
|
||||
if (repetitionPenalties != nullptr)
|
||||
{
|
||||
logit = logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty;
|
||||
}
|
||||
}
|
||||
|
||||
if (numOccurences > 0)
|
||||
{
|
||||
// Presence
|
||||
if (presencePenalties != nullptr)
|
||||
{
|
||||
@ -230,9 +252,10 @@ void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams<T> const& params)
|
||||
dim3 grid(params.batchSize, params.beamWidth, params.maxTokensPerStep);
|
||||
batchApplyPenalty<T><<<grid, block, 0, params.stream>>>(params.inputLogits, params.outputLogits, params.biases,
|
||||
params.penaltyWorkspace, params.penaltyWorkspacePrev, params.temperatures, params.repetitionPenalties,
|
||||
params.presencePenalties, params.frequencyPenalties, params.maxSeqLen, params.vocabSize, params.vocabSizePadded,
|
||||
params.outputIdsPtr, params.parentIdsPtr, params.inputLengths, params.sequenceLengths, params.minLengths,
|
||||
params.endIds, params.batchSlots, params.tokensPerStep, params.finished);
|
||||
params.presencePenalties, params.frequencyPenalties, params.promptIgnoreLengths, params.maxSeqLen,
|
||||
params.vocabSize, params.vocabSizePadded, params.outputIdsPtr, params.parentIdsPtr, params.inputLengths,
|
||||
params.sequenceLengths, params.minLengths, params.endIds, params.batchSlots, params.tokensPerStep,
|
||||
params.finished);
|
||||
}
|
||||
|
||||
template void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams<float> const& params);
|
||||
|
||||
@ -35,6 +35,7 @@ struct InvokeBatchApplyPenaltyParams
|
||||
float const* repetitionPenalties;
|
||||
float const* presencePenalties;
|
||||
float const* frequencyPenalties;
|
||||
runtime::SizeType32 const* promptIgnoreLengths;
|
||||
runtime::SizeType32 batchSize;
|
||||
runtime::SizeType32 beamWidth;
|
||||
runtime::SizeType32 maxSeqLen;
|
||||
|
||||
@ -29,11 +29,12 @@ namespace kernels
|
||||
|
||||
// Identifies a decoding penalty / penalty-related parameter; used as the selector passed to getLimitsPenalty.
// NOTE(review): this view is a diff rendering, so the enumerator list appears twice (pre-change and post-change).
enum class DecodingPenaltyType
|
||||
{
|
||||
Temperature, // the temperature penalty
|
||||
Repetition, // the repetition penalty
|
||||
Presence, // the presence penalty
|
||||
Frequency, // the frequency penalty
|
||||
MinLength, // the min length penalty
|
||||
Temperature, // the temperature penalty
|
||||
Repetition, // the repetition penalty
|
||||
Presence, // the presence penalty
|
||||
Frequency, // the frequency penalty
|
||||
MinLength, // the min length penalty
|
||||
PromptIgnoreLength, // the prompt ignore length for presence/frequency penalty
|
||||
};
|
||||
|
||||
inline std::pair<float, float> getLimitsPenalty(DecodingPenaltyType penaltyType)
|
||||
@ -49,6 +50,7 @@ inline std::pair<float, float> getLimitsPenalty(DecodingPenaltyType penaltyType)
|
||||
case DecodingPenaltyType::Presence: return std::make_pair(fltMin, fltMax);
|
||||
case DecodingPenaltyType::Frequency: return std::make_pair(fltMin, fltMax);
|
||||
case DecodingPenaltyType::MinLength: return std::make_pair(-fltEpsilon, fltMax);
|
||||
case DecodingPenaltyType::PromptIgnoreLength: return std::make_pair(-fltEpsilon, fltMax);
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(false, "Unknown penalty type %d", static_cast<int32_t>(penaltyType));
|
||||
return std::make_pair(fltMin, fltMax);
|
||||
|
||||
@ -128,11 +128,12 @@ public:
|
||||
// Setup parameters for the penalty-related settings. Per the field comments, each member holds either a
// single value or one value per entry of the setup batch.
// NOTE(review): this view is a diff rendering, so the member list appears twice (pre-change and post-change).
class PenaltySetupParams : public BaseSetupParams
|
||||
{
|
||||
public:
|
||||
OptVec<float> temperature; // [1] or [setupBatchSize]
|
||||
OptVec<runtime::SizeType32> minLength; // [1] or [setupBatchSize]
|
||||
OptVec<float> repetitionPenalty; // [1] or [setupBatchSize]
|
||||
OptVec<float> presencePenalty; // [1] or [setupBatchSize]
|
||||
OptVec<float> frequencyPenalty; // [1] or [setupBatchSize]
|
||||
OptVec<float> temperature; // [1] or [setupBatchSize]
|
||||
OptVec<runtime::SizeType32> minLength; // [1] or [setupBatchSize]
|
||||
OptVec<float> repetitionPenalty; // [1] or [setupBatchSize]
|
||||
OptVec<float> presencePenalty; // [1] or [setupBatchSize]
|
||||
OptVec<float> frequencyPenalty; // [1] or [setupBatchSize]
|
||||
OptVec<runtime::SizeType32> promptIgnoreLength; // [1] or [setupBatchSize]
|
||||
};
|
||||
|
||||
// Ban words layer
|
||||
|
||||
@ -83,7 +83,7 @@ void PenaltyLayer<T>::allocateWorkspace()
|
||||
{
|
||||
|
||||
auto const workspaceSize = mDecoderDomain.getBatchSize() * mDecoderDomain.getMaxDecodingTokens()
|
||||
* mConfiguredBeamWidth * mDecoderDomain.getVocabSize();
|
||||
* mConfiguredBeamWidth * mDecoderDomain.getVocabSize() * 2;
|
||||
mPenaltyWorkspaceDevice = mBufferManager->gpu(workspaceSize, nvinfer1::DataType::kINT32);
|
||||
|
||||
if (mDecodingMode.isBeamSearch())
|
||||
@ -107,6 +107,7 @@ void PenaltyLayer<T>::allocateBuffer()
|
||||
mPresencePenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
|
||||
mFrequencyPenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<float>::value);
|
||||
mMinLength = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);
|
||||
mPromptIgnoreLength = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<SizeType32>::value);
|
||||
|
||||
if (mDecodingMode.isUseTemperature())
|
||||
{
|
||||
@ -128,6 +129,10 @@ void PenaltyLayer<T>::allocateBuffer()
|
||||
{
|
||||
mMinLengthDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kINT32);
|
||||
}
|
||||
if (mDecodingMode.isUseOccurrencePenalty())
|
||||
{
|
||||
mPromptIgnoreLengthDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kINT32);
|
||||
}
|
||||
|
||||
auto const logitsPtrDeviceDesc = std::make_pair(batchSizeShape, TRTDataType<T*>::value);
|
||||
mWorkspaceSize = DecodingLayerWorkspace::calculateRequiredWorkspaceSize(logitsPtrDeviceDesc);
|
||||
@ -169,6 +174,8 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorCo
|
||||
bool const useFrequencyPenalty
|
||||
= mDecodingMode.isUseFrequencyPenalty() && penaltyParams->frequencyPenalty.has_value();
|
||||
bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams->minLength.has_value();
|
||||
bool const usePromptIgnoreLength
|
||||
= mDecodingMode.isUseOccurrencePenalty() && penaltyParams->promptIgnoreLength.has_value();
|
||||
// FIXME: once one of the requests has some penalty, we will always have to compute it.
|
||||
// To avoid that we need to scan through all active requests at each iteration.
|
||||
mUseTemperature |= useTemperature;
|
||||
@ -176,6 +183,7 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorCo
|
||||
mUsePresencePenalty |= usePresencePenalty;
|
||||
mUseFrequencyPenalty |= useFrequencyPenalty;
|
||||
mUseMinLength |= useMinLength;
|
||||
mUsePromptIgnoreLength |= usePromptIgnoreLength;
|
||||
|
||||
if (mUseTemperature)
|
||||
{
|
||||
@ -203,10 +211,16 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorCo
|
||||
fillBuffers(penaltyParams->minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice,
|
||||
batchSlots, getLimitsPenalty(DecodingPenaltyType::MinLength), "min length");
|
||||
}
|
||||
if (mUsePromptIgnoreLength)
|
||||
{
|
||||
fillBuffers(penaltyParams->promptIgnoreLength, DefaultDecodingParams::getPromptIgnoreLength(),
|
||||
mPromptIgnoreLength, mPromptIgnoreLengthDevice, batchSlots,
|
||||
getLimitsPenalty(DecodingPenaltyType::PromptIgnoreLength), "prompt ignore length");
|
||||
}
|
||||
|
||||
// Reset penalty workspace
|
||||
auto const workspaceSizePerBatch
|
||||
= mDecoderDomain.getMaxDecodingTokens() * mConfiguredBeamWidth * mDecoderDomain.getVocabSize();
|
||||
= mDecoderDomain.getMaxDecodingTokens() * mConfiguredBeamWidth * mDecoderDomain.getVocabSize() * 2;
|
||||
for (SizeType32 bi = 0; bi < batchSize; ++bi)
|
||||
{
|
||||
auto batchSlot = runtime::bufferCast<runtime::SizeType32>(*batchSlots)[bi];
|
||||
@ -287,6 +301,7 @@ void PenaltyLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& b
|
||||
auto presencePenalties = GET_PENALTIES(PresencePenalty, float);
|
||||
auto frequencyPenalties = GET_PENALTIES(FrequencyPenalty, float);
|
||||
auto minLengths = GET_PENALTIES(MinLength, SizeType32);
|
||||
auto promptIgnoreLengths = GET_PENALTIES(PromptIgnoreLength, SizeType32);
|
||||
|
||||
#undef GET_PENALTIES
|
||||
|
||||
@ -316,6 +331,7 @@ void PenaltyLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& b
|
||||
penaltyParams.inputLengths = inputLengths;
|
||||
penaltyParams.sequenceLengths = bufferCast<SizeType32>(*outputs->sequenceLength.value());
|
||||
penaltyParams.minLengths = bufferCastOrNull<SizeType32>(minLengths);
|
||||
penaltyParams.promptIgnoreLengths = bufferCastOrNull<SizeType32>(promptIgnoreLengths);
|
||||
penaltyParams.endIds = bufferCast<TokenIdType>(*params->endIds);
|
||||
penaltyParams.batchSlots = workspace->getDeviceBatchSlotsPtr();
|
||||
penaltyParams.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
|
||||
|
||||
@ -67,18 +67,21 @@ private:
|
||||
TensorPtr mPresencePenaltyDevice;
|
||||
TensorPtr mFrequencyPenaltyDevice;
|
||||
TensorPtr mMinLengthDevice;
|
||||
TensorPtr mPromptIgnoreLengthDevice;
|
||||
|
||||
TensorPtr mTemperature;
|
||||
TensorPtr mRepetitionPenalty;
|
||||
TensorPtr mPresencePenalty;
|
||||
TensorPtr mFrequencyPenalty;
|
||||
TensorPtr mMinLength;
|
||||
TensorPtr mPromptIgnoreLength;
|
||||
|
||||
bool mUseTemperature{false};
|
||||
bool mUseRepetitionPenalty{false};
|
||||
bool mUsePresencePenalty{false};
|
||||
bool mUseFrequencyPenalty{false};
|
||||
bool mUseMinLength{false};
|
||||
bool mUsePromptIgnoreLength{false};
|
||||
|
||||
runtime::SizeType32 mCyclicStep{0};
|
||||
runtime::SizeType32 mRuntimeMaxSeqLen{0};
|
||||
|
||||
@ -370,14 +370,14 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
|
||||
auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> nb::tuple
|
||||
{
|
||||
return nb::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty,
|
||||
config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed,
|
||||
config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, config.lengthPenalty,
|
||||
config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP,
|
||||
config.beamWidthArray);
|
||||
config.presencePenalty, config.frequencyPenalty, config.promptIgnoreLength, config.topK, config.topP,
|
||||
config.randomSeed, config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate,
|
||||
config.lengthPenalty, config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences,
|
||||
config.minP, config.beamWidthArray);
|
||||
};
|
||||
auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t)
|
||||
{
|
||||
if (t.size() != 19)
|
||||
if (t.size() != 20)
|
||||
{
|
||||
throw std::runtime_error("Invalid SamplingConfig state!");
|
||||
}
|
||||
@ -389,19 +389,20 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
|
||||
config.repetitionPenalty = nb::cast<OptVec<float>>(t[3]);
|
||||
config.presencePenalty = nb::cast<OptVec<float>>(t[4]);
|
||||
config.frequencyPenalty = nb::cast<OptVec<float>>(t[5]);
|
||||
config.topK = nb::cast<OptVec<SizeType32>>(t[6]);
|
||||
config.topP = nb::cast<OptVec<float>>(t[7]);
|
||||
config.randomSeed = nb::cast<OptVec<uint64_t>>(t[8]);
|
||||
config.topPDecay = nb::cast<OptVec<float>>(t[9]);
|
||||
config.topPMin = nb::cast<OptVec<float>>(t[10]);
|
||||
config.topPResetIds = nb::cast<OptVec<TokenIdType>>(t[11]);
|
||||
config.beamSearchDiversityRate = nb::cast<OptVec<float>>(t[12]);
|
||||
config.lengthPenalty = nb::cast<OptVec<float>>(t[13]);
|
||||
config.earlyStopping = nb::cast<OptVec<SizeType32>>(t[14]);
|
||||
config.noRepeatNgramSize = nb::cast<OptVec<SizeType32>>(t[15]);
|
||||
config.numReturnSequences = nb::cast<SizeType32>(t[16]);
|
||||
config.minP = nb::cast<OptVec<float>>(t[17]);
|
||||
config.beamWidthArray = nb::cast<OptVec<std::vector<SizeType32>>>(t[18]);
|
||||
config.promptIgnoreLength = nb::cast<OptVec<SizeType32>>(t[6]);
|
||||
config.topK = nb::cast<OptVec<SizeType32>>(t[7]);
|
||||
config.topP = nb::cast<OptVec<float>>(t[8]);
|
||||
config.randomSeed = nb::cast<OptVec<uint64_t>>(t[9]);
|
||||
config.topPDecay = nb::cast<OptVec<float>>(t[10]);
|
||||
config.topPMin = nb::cast<OptVec<float>>(t[11]);
|
||||
config.topPResetIds = nb::cast<OptVec<TokenIdType>>(t[12]);
|
||||
config.beamSearchDiversityRate = nb::cast<OptVec<float>>(t[13]);
|
||||
config.lengthPenalty = nb::cast<OptVec<float>>(t[14]);
|
||||
config.earlyStopping = nb::cast<OptVec<SizeType32>>(t[15]);
|
||||
config.noRepeatNgramSize = nb::cast<OptVec<SizeType32>>(t[16]);
|
||||
config.numReturnSequences = nb::cast<SizeType32>(t[17]);
|
||||
config.minP = nb::cast<OptVec<float>>(t[18]);
|
||||
config.beamWidthArray = nb::cast<OptVec<std::vector<SizeType32>>>(t[19]);
|
||||
|
||||
new (&self) tr::SamplingConfig(config);
|
||||
};
|
||||
@ -416,6 +417,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
|
||||
.def_rw("repetition_penalty", &tr::SamplingConfig::repetitionPenalty)
|
||||
.def_rw("presence_penalty", &tr::SamplingConfig::presencePenalty)
|
||||
.def_rw("frequency_penalty", &tr::SamplingConfig::frequencyPenalty)
|
||||
.def_rw("prompt_ignore_length", &tr::SamplingConfig::promptIgnoreLength)
|
||||
.def_rw("top_k", &tr::SamplingConfig::topK)
|
||||
.def_rw("top_p", &tr::SamplingConfig::topP)
|
||||
.def_rw("random_seed", &tr::SamplingConfig::randomSeed)
|
||||
|
||||
@ -76,12 +76,12 @@ void initRequestBindings(nb::module_& m)
|
||||
return nb::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(),
|
||||
self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(),
|
||||
self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(),
|
||||
self.getFrequencyPenalty(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(),
|
||||
self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
|
||||
self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(),
|
||||
self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
|
||||
};
|
||||
auto samplingConfigSetstate = [](tle::SamplingConfig& samplingConfig, nb::tuple const& state)
|
||||
{
|
||||
if (state.size() != 19)
|
||||
if (state.size() != 20)
|
||||
{
|
||||
throw std::runtime_error("Invalid SamplingConfig state!");
|
||||
}
|
||||
@ -98,12 +98,13 @@ void initRequestBindings(nb::module_& m)
|
||||
nb::cast<std::optional<FloatType>>(state[10]), // RepetitionPenalty
|
||||
nb::cast<std::optional<FloatType>>(state[11]), // PresencePenalty
|
||||
nb::cast<std::optional<FloatType>>(state[12]), // FrequencyPenalty
|
||||
nb::cast<std::optional<FloatType>>(state[13]), // LengthPenalty
|
||||
nb::cast<std::optional<SizeType32>>(state[14]), // EarlyStopping
|
||||
nb::cast<std::optional<SizeType32>>(state[15]), // NoRepeatNgramSize
|
||||
nb::cast<std::optional<SizeType32>>(state[16]), // NumReturnSequences
|
||||
nb::cast<std::optional<FloatType>>(state[17]), // MinP
|
||||
nb::cast<std::optional<std::vector<SizeType32>>>(state[18]) // BeamWidthArray
|
||||
nb::cast<std::optional<SizeType32>>(state[13]), // PromptIgnoreLength
|
||||
nb::cast<std::optional<FloatType>>(state[14]), // LengthPenalty
|
||||
nb::cast<std::optional<SizeType32>>(state[15]), // EarlyStopping
|
||||
nb::cast<std::optional<SizeType32>>(state[16]), // NoRepeatNgramSize
|
||||
nb::cast<std::optional<SizeType32>>(state[17]), // NumReturnSequences
|
||||
nb::cast<std::optional<FloatType>>(state[18]), // MinP
|
||||
nb::cast<std::optional<std::vector<SizeType32>>>(state[19]) // BeamWidthArray
|
||||
);
|
||||
};
|
||||
nb::class_<tle::SamplingConfig>(m, "SamplingConfig")
|
||||
@ -120,6 +121,7 @@ void initRequestBindings(nb::module_& m)
|
||||
std::optional<tle::FloatType> const&, // repetitionPenalty
|
||||
std::optional<tle::FloatType> const&, // presencePenalty
|
||||
std::optional<tle::FloatType> const&, // frequencyPenalty
|
||||
std::optional<tle::SizeType32> const&, // promptIgnoreLength
|
||||
std::optional<tle::FloatType> const&, // lengthPenalty
|
||||
std::optional<tle::SizeType32> const&, // earlyStopping
|
||||
std::optional<tle::SizeType32> const&, // noRepeatNgramSize
|
||||
@ -142,6 +144,7 @@ void initRequestBindings(nb::module_& m)
|
||||
nb::arg("repetition_penalty") = nb::none(),
|
||||
nb::arg("presence_penalty") = nb::none(),
|
||||
nb::arg("frequency_penalty") = nb::none(),
|
||||
nb::arg("prompt_ignore_length") = nb::none(),
|
||||
nb::arg("length_penalty") = nb::none(),
|
||||
nb::arg("early_stopping") = nb::none(),
|
||||
nb::arg("no_repeat_ngram_size") = nb::none(),
|
||||
@ -165,6 +168,8 @@ void initRequestBindings(nb::module_& m)
|
||||
[](tle::SamplingConfig& self, std::optional<FloatType> v) { self.setPresencePenalty(v); })
|
||||
.def_prop_rw(
|
||||
"frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty)
|
||||
.def_prop_rw("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength,
|
||||
&tle::SamplingConfig::setPromptIgnoreLength)
|
||||
.def_prop_rw("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty)
|
||||
.def_prop_rw("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping)
|
||||
.def_prop_rw("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize,
|
||||
|
||||
@ -361,14 +361,14 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> py::tuple
|
||||
{
|
||||
return py::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty,
|
||||
config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed,
|
||||
config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, config.lengthPenalty,
|
||||
config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP,
|
||||
config.beamWidthArray);
|
||||
config.presencePenalty, config.frequencyPenalty, config.promptIgnoreLength, config.topK, config.topP,
|
||||
config.randomSeed, config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate,
|
||||
config.lengthPenalty, config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences,
|
||||
config.minP, config.beamWidthArray);
|
||||
};
|
||||
auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig
|
||||
{
|
||||
if (t.size() != 19)
|
||||
if (t.size() != 20)
|
||||
{
|
||||
throw std::runtime_error("Invalid SamplingConfig state!");
|
||||
}
|
||||
@ -380,19 +380,20 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
config.repetitionPenalty = t[3].cast<OptVec<float>>();
|
||||
config.presencePenalty = t[4].cast<OptVec<float>>();
|
||||
config.frequencyPenalty = t[5].cast<OptVec<float>>();
|
||||
config.topK = t[6].cast<OptVec<SizeType32>>();
|
||||
config.topP = t[7].cast<OptVec<float>>();
|
||||
config.randomSeed = t[8].cast<OptVec<uint64_t>>();
|
||||
config.topPDecay = t[9].cast<OptVec<float>>();
|
||||
config.topPMin = t[10].cast<OptVec<float>>();
|
||||
config.topPResetIds = t[11].cast<OptVec<TokenIdType>>();
|
||||
config.beamSearchDiversityRate = t[12].cast<OptVec<float>>();
|
||||
config.lengthPenalty = t[13].cast<OptVec<float>>();
|
||||
config.earlyStopping = t[14].cast<OptVec<SizeType32>>();
|
||||
config.noRepeatNgramSize = t[15].cast<OptVec<SizeType32>>();
|
||||
config.numReturnSequences = t[16].cast<SizeType32>();
|
||||
config.minP = t[17].cast<OptVec<float>>();
|
||||
config.beamWidthArray = t[18].cast<OptVec<std::vector<SizeType32>>>();
|
||||
config.promptIgnoreLength = t[6].cast<OptVec<SizeType32>>();
|
||||
config.topK = t[7].cast<OptVec<SizeType32>>();
|
||||
config.topP = t[8].cast<OptVec<float>>();
|
||||
config.randomSeed = t[9].cast<OptVec<uint64_t>>();
|
||||
config.topPDecay = t[10].cast<OptVec<float>>();
|
||||
config.topPMin = t[11].cast<OptVec<float>>();
|
||||
config.topPResetIds = t[12].cast<OptVec<TokenIdType>>();
|
||||
config.beamSearchDiversityRate = t[13].cast<OptVec<float>>();
|
||||
config.lengthPenalty = t[14].cast<OptVec<float>>();
|
||||
config.earlyStopping = t[15].cast<OptVec<SizeType32>>();
|
||||
config.noRepeatNgramSize = t[16].cast<OptVec<SizeType32>>();
|
||||
config.numReturnSequences = t[17].cast<SizeType32>();
|
||||
config.minP = t[18].cast<OptVec<float>>();
|
||||
config.beamWidthArray = t[19].cast<OptVec<std::vector<SizeType32>>>();
|
||||
|
||||
return config;
|
||||
};
|
||||
@ -407,6 +408,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_readwrite("repetition_penalty", &tr::SamplingConfig::repetitionPenalty)
|
||||
.def_readwrite("presence_penalty", &tr::SamplingConfig::presencePenalty)
|
||||
.def_readwrite("frequency_penalty", &tr::SamplingConfig::frequencyPenalty)
|
||||
.def_readwrite("prompt_ignore_length", &tr::SamplingConfig::promptIgnoreLength)
|
||||
.def_readwrite("top_k", &tr::SamplingConfig::topK)
|
||||
.def_readwrite("top_p", &tr::SamplingConfig::topP)
|
||||
.def_readwrite("random_seed", &tr::SamplingConfig::randomSeed)
|
||||
|
||||
@ -72,12 +72,12 @@ void initRequestBindings(pybind11::module_& m)
|
||||
return py::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(),
|
||||
self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(),
|
||||
self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(),
|
||||
self.getFrequencyPenalty(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(),
|
||||
self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
|
||||
self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(),
|
||||
self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
|
||||
};
|
||||
auto samplingConfigSetstate = [](py::tuple const& state)
|
||||
{
|
||||
if (state.size() != 19)
|
||||
if (state.size() != 20)
|
||||
{
|
||||
throw std::runtime_error("Invalid SamplingConfig state!");
|
||||
}
|
||||
@ -94,12 +94,13 @@ void initRequestBindings(pybind11::module_& m)
|
||||
state[10].cast<std::optional<FloatType>>(), // RepetitionPenalty
|
||||
state[11].cast<std::optional<FloatType>>(), // PresencePenalty
|
||||
state[12].cast<std::optional<FloatType>>(), // FrequencyPenalty
|
||||
state[13].cast<std::optional<FloatType>>(), // LengthPenalty
|
||||
state[14].cast<std::optional<SizeType32>>(), // EarlyStopping
|
||||
state[15].cast<std::optional<SizeType32>>(), // NoRepeatNgramSize
|
||||
state[16].cast<std::optional<SizeType32>>(), // NumReturnSequences
|
||||
state[17].cast<std::optional<FloatType>>(), // MinP
|
||||
state[18].cast<std::optional<std::vector<SizeType32>>>() // BeamWidthArray
|
||||
state[13].cast<std::optional<SizeType32>>(), // PromptIgnoreLength
|
||||
state[14].cast<std::optional<FloatType>>(), // LengthPenalty
|
||||
state[15].cast<std::optional<SizeType32>>(), // EarlyStopping
|
||||
state[16].cast<std::optional<SizeType32>>(), // NoRepeatNgramSize
|
||||
state[17].cast<std::optional<SizeType32>>(), // NumReturnSequences
|
||||
state[18].cast<std::optional<FloatType>>(), // MinP
|
||||
state[19].cast<std::optional<std::vector<SizeType32>>>() // BeamWidthArray
|
||||
);
|
||||
};
|
||||
py::class_<tle::SamplingConfig>(m, "SamplingConfig")
|
||||
@ -116,6 +117,7 @@ void initRequestBindings(pybind11::module_& m)
|
||||
std::optional<tle::FloatType> const&, // repetitionPenalty
|
||||
std::optional<tle::FloatType> const&, // presencePenalty
|
||||
std::optional<tle::FloatType> const&, // frequencyPenalty
|
||||
std::optional<tle::SizeType32> const&, // promptIgnoreLength
|
||||
std::optional<tle::FloatType> const&, // lengthPenalty
|
||||
std::optional<tle::SizeType32> const&, // earlyStopping
|
||||
std::optional<tle::SizeType32> const&, // noRepeatNgramSize
|
||||
@ -138,6 +140,7 @@ void initRequestBindings(pybind11::module_& m)
|
||||
py::arg("repetition_penalty") = py::none(),
|
||||
py::arg("presence_penalty") = py::none(),
|
||||
py::arg("frequency_penalty") = py::none(),
|
||||
py::arg("prompt_ignore_length") = py::none(),
|
||||
py::arg("length_penalty") = py::none(),
|
||||
py::arg("early_stopping") = py::none(),
|
||||
py::arg("no_repeat_ngram_size") = py::none(),
|
||||
@ -161,6 +164,8 @@ void initRequestBindings(pybind11::module_& m)
|
||||
[](tle::SamplingConfig& self, std::optional<FloatType> v) { self.setPresencePenalty(v); })
|
||||
.def_property(
|
||||
"frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty)
|
||||
.def_property("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength,
|
||||
&tle::SamplingConfig::setPromptIgnoreLength)
|
||||
.def_property("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty)
|
||||
.def_property("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping)
|
||||
.def_property("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize,
|
||||
|
||||
@ -84,6 +84,7 @@ void GptDecoder<T>::disableLookahead(
|
||||
penaltyParams->repetitionPenalty = mSamplingConfig.repetitionPenalty;
|
||||
penaltyParams->presencePenalty = mSamplingConfig.presencePenalty;
|
||||
penaltyParams->frequencyPenalty = mSamplingConfig.frequencyPenalty;
|
||||
penaltyParams->promptIgnoreLength = mSamplingConfig.promptIgnoreLength;
|
||||
penaltyParams->temperature = mSamplingConfig.temperature;
|
||||
penaltyParams->minLength = mSamplingConfig.minLength;
|
||||
|
||||
@ -136,6 +137,7 @@ void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize
|
||||
penaltyParams->repetitionPenalty = mSamplingConfig.repetitionPenalty;
|
||||
penaltyParams->presencePenalty = mSamplingConfig.presencePenalty;
|
||||
penaltyParams->frequencyPenalty = mSamplingConfig.frequencyPenalty;
|
||||
penaltyParams->promptIgnoreLength = mSamplingConfig.promptIgnoreLength;
|
||||
penaltyParams->temperature = mSamplingConfig.temperature;
|
||||
penaltyParams->minLength = mSamplingConfig.minLength;
|
||||
|
||||
|
||||
@ -120,12 +120,12 @@ void FtDynamicDecode<T>::setup(size_t const batch_size, size_t const beam_width,
|
||||
th::optional<th::Tensor> runtime_top_k_opt, th::optional<th::Tensor> runtime_top_p_opt,
|
||||
th::optional<th::Tensor> temperature_opt, th::optional<th::Tensor> repetition_penalty_opt,
|
||||
th::optional<th::Tensor> presence_penalty_opt, th::optional<th::Tensor> frequency_penalty_opt,
|
||||
th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
|
||||
th::optional<th::Tensor> early_stopping_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
|
||||
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
|
||||
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt,
|
||||
th::optional<th::Tensor> no_repeat_ngram_size_opt, th::optional<th::Tensor> min_p_opt, bool output_log_probs,
|
||||
bool cum_log_probs)
|
||||
th::optional<th::Tensor> prompt_ignore_length_opt, th::optional<th::Tensor> min_length_opt,
|
||||
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
|
||||
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
|
||||
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
|
||||
th::optional<th::Tensor> top_p_reset_ids_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
th::optional<th::Tensor> min_p_opt, bool output_log_probs, bool cum_log_probs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
mBeamWidth = beam_width;
|
||||
@ -137,6 +137,7 @@ void FtDynamicDecode<T>::setup(size_t const batch_size, size_t const beam_width,
|
||||
safeInsert(repetition_penalty_opt, penaltyParams->repetitionPenalty);
|
||||
safeInsert(presence_penalty_opt, penaltyParams->presencePenalty);
|
||||
safeInsert(frequency_penalty_opt, penaltyParams->frequencyPenalty);
|
||||
safeInsert(prompt_ignore_length_opt, penaltyParams->promptIgnoreLength);
|
||||
safeInsert(min_length_opt, penaltyParams->minLength);
|
||||
safeInsert(no_repeat_ngram_size_opt, banWordsParams->noRepeatNgramSize);
|
||||
if (beam_width == 1)
|
||||
@ -328,10 +329,10 @@ void DynamicDecodeOp::createInstance()
|
||||
void DynamicDecodeOp::setup(int64_t const batchSize, int64_t const beamWidth, th::optional<th::Tensor> runtimeTopKOpt,
|
||||
th::optional<th::Tensor> runtimeTopPOpt, th::optional<th::Tensor> temperatureOpt,
|
||||
th::optional<th::Tensor> repetitionPenaltyOpt, th::optional<th::Tensor> presencePenaltyOpt,
|
||||
th::optional<th::Tensor> frequencyPenaltyOpt, th::optional<th::Tensor> minLengthOpt,
|
||||
th::optional<th::Tensor> lengthPenaltyOpt, th::optional<th::Tensor> earlyStoppingOpt,
|
||||
th::optional<th::Tensor> beamSearchDiversityRateOpt, th::optional<th::Tensor> randomSeedOpt,
|
||||
th::optional<th::Tensor> topPDecayOpt, th::optional<th::Tensor> topPMinOpt,
|
||||
th::optional<th::Tensor> frequencyPenaltyOpt, th::optional<th::Tensor> promptIgnoreLengthOpt,
|
||||
th::optional<th::Tensor> minLengthOpt, th::optional<th::Tensor> lengthPenaltyOpt,
|
||||
th::optional<th::Tensor> earlyStoppingOpt, th::optional<th::Tensor> beamSearchDiversityRateOpt,
|
||||
th::optional<th::Tensor> randomSeedOpt, th::optional<th::Tensor> topPDecayOpt, th::optional<th::Tensor> topPMinOpt,
|
||||
th::optional<th::Tensor> topPResetIdsOpt, th::optional<th::Tensor> noRepeatNgramSizeOpt,
|
||||
th::optional<th::Tensor> minPOpt, bool outputLogProbs, bool cumLogProbs)
|
||||
{
|
||||
@ -343,6 +344,7 @@ void DynamicDecodeOp::setup(int64_t const batchSize, int64_t const beamWidth, th
|
||||
CHECK_OPTIONAL_CPU_INPUT(repetitionPenaltyOpt, torch::kFloat);
|
||||
CHECK_OPTIONAL_CPU_INPUT(presencePenaltyOpt, torch::kFloat);
|
||||
CHECK_OPTIONAL_CPU_INPUT(frequencyPenaltyOpt, torch::kFloat);
|
||||
CHECK_OPTIONAL_CPU_INPUT(promptIgnoreLengthOpt, torch::kInt32);
|
||||
CHECK_OPTIONAL_CPU_INPUT(minLengthOpt, torch::kInt32);
|
||||
CHECK_OPTIONAL_CPU_INPUT(lengthPenaltyOpt, torch::kFloat);
|
||||
CHECK_OPTIONAL_CPU_INPUT(earlyStoppingOpt, torch::kInt32);
|
||||
@ -356,8 +358,9 @@ void DynamicDecodeOp::setup(int64_t const batchSize, int64_t const beamWidth, th
|
||||
|
||||
dynamicDecode_->setup(static_cast<tr::SizeType32>(batchSize), static_cast<tr::SizeType32>(beamWidth),
|
||||
runtimeTopKOpt, runtimeTopPOpt, temperatureOpt, repetitionPenaltyOpt, presencePenaltyOpt, frequencyPenaltyOpt,
|
||||
minLengthOpt, lengthPenaltyOpt, earlyStoppingOpt, beamSearchDiversityRateOpt, randomSeedOpt, topPDecayOpt,
|
||||
topPMinOpt, topPResetIdsOpt, noRepeatNgramSizeOpt, minPOpt, outputLogProbs, cumLogProbs);
|
||||
promptIgnoreLengthOpt, minLengthOpt, lengthPenaltyOpt, earlyStoppingOpt, beamSearchDiversityRateOpt,
|
||||
randomSeedOpt, topPDecayOpt, topPMinOpt, topPResetIdsOpt, noRepeatNgramSizeOpt, minPOpt, outputLogProbs,
|
||||
cumLogProbs);
|
||||
}
|
||||
|
||||
th::Tensor DynamicDecodeOp::forward(
|
||||
|
||||
@ -32,12 +32,13 @@ public:
|
||||
virtual void setup(size_t const batch_size, size_t const beam_width, th::optional<th::Tensor> runtime_top_k_opt,
|
||||
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
|
||||
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
|
||||
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
|
||||
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
|
||||
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
|
||||
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
|
||||
th::optional<th::Tensor> top_p_reset_ids_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
th::optional<th::Tensor> min_p_opt, bool output_log_probs, bool cum_log_probs)
|
||||
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> prompt_ignore_length_opt,
|
||||
th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
|
||||
th::optional<th::Tensor> early_stopping_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
|
||||
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
|
||||
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt,
|
||||
th::optional<th::Tensor> no_repeat_ngram_size_opt, th::optional<th::Tensor> min_p_opt, bool output_log_probs,
|
||||
bool cum_log_probs)
|
||||
= 0;
|
||||
|
||||
virtual void forward(th::Tensor const& logits, int const step, int const max_input_length,
|
||||
@ -72,12 +73,13 @@ public:
|
||||
void setup(size_t const batch_size, size_t const beam_width, th::optional<th::Tensor> runtime_top_k_opt,
|
||||
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
|
||||
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
|
||||
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
|
||||
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
|
||||
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
|
||||
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
|
||||
th::optional<th::Tensor> top_p_reset_ids_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
th::optional<th::Tensor> min_p_opt, bool output_log_probs, bool cum_log_probs) override;
|
||||
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> prompt_ignore_length_opt,
|
||||
th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
|
||||
th::optional<th::Tensor> early_stopping_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
|
||||
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
|
||||
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt,
|
||||
th::optional<th::Tensor> no_repeat_ngram_size_opt, th::optional<th::Tensor> min_p_opt, bool output_log_probs,
|
||||
bool cum_log_probs) override;
|
||||
|
||||
void forward(th::Tensor const& logits, int const step, int const max_input_length, int const max_attention_window,
|
||||
int const sink_token_length, uint64_t const ite, int const local_batch_size, th::Tensor end_id,
|
||||
@ -115,12 +117,13 @@ public:
|
||||
void setup(int64_t const batch_size, int64_t const beam_width, th::optional<th::Tensor> runtime_top_k_opt,
|
||||
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
|
||||
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
|
||||
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
|
||||
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
|
||||
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
|
||||
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
|
||||
th::optional<th::Tensor> top_p_reset_ids_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
th::optional<th::Tensor> min_p_opt, bool output_log_probs, bool cum_log_probs);
|
||||
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> prompt_ignore_length_opt,
|
||||
th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> length_penalty_opt,
|
||||
th::optional<th::Tensor> early_stopping_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
|
||||
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
|
||||
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt,
|
||||
th::optional<th::Tensor> no_repeat_ngram_size_opt, th::optional<th::Tensor> min_p_opt, bool output_log_probs,
|
||||
bool cum_log_probs);
|
||||
|
||||
th::Tensor forward(th::Tensor const& logits, int64_t const step, int64_t const max_input_length,
|
||||
int64_t const max_attention_window, int64_t const sink_token_length, int64_t const ite,
|
||||
|
||||
@ -34,17 +34,18 @@ void test(bool const isTestValid, SizeType32 beamWidth = 1, std::optional<SizeTy
|
||||
std::optional<RandomSeedType> randomSeed = no, std::optional<FloatType> temperature = no,
|
||||
std::optional<SizeType32> minLength = no, std::optional<FloatType> beamSearchDiversityRate = no,
|
||||
std::optional<FloatType> repetitionPenalty = no, std::optional<FloatType> presencePenalty = no,
|
||||
std::optional<FloatType> frequencyPenalty = no, std::optional<FloatType> lengthPenalty = no,
|
||||
std::optional<SizeType32> earlyStopping = no, std::optional<SizeType32> noRepeatNgramSize = no,
|
||||
std::optional<SizeType32> numReturnSequences = no, std::optional<FloatType> minP = no,
|
||||
std::optional<std::vector<SizeType32>> beamWidthArray = no)
|
||||
std::optional<FloatType> frequencyPenalty = no, std::optional<SizeType32> promptIgnoreLength = no,
|
||||
std::optional<FloatType> lengthPenalty = no, std::optional<SizeType32> earlyStopping = no,
|
||||
std::optional<SizeType32> noRepeatNgramSize = no, std::optional<SizeType32> numReturnSequences = no,
|
||||
std::optional<FloatType> minP = no, std::optional<std::vector<SizeType32>> beamWidthArray = no)
|
||||
{
|
||||
// 19 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray`
|
||||
// 20 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray`
|
||||
try
|
||||
{
|
||||
auto sc = SamplingConfig(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature,
|
||||
minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty,
|
||||
earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray);
|
||||
minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty,
|
||||
promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP,
|
||||
beamWidthArray);
|
||||
|
||||
// Come here if `sc` is valid
|
||||
if (!isTestValid)
|
||||
@ -102,18 +103,20 @@ TEST(SamplingConfigTest, validInputs)
|
||||
test(true, 1, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// Frequency penalty
|
||||
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// Prompt ignore length
|
||||
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1);
|
||||
// Length penalty
|
||||
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// Early stopping
|
||||
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// Early stopping
|
||||
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// No repeat ngram size
|
||||
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
|
||||
test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
|
||||
// NumReturnSequences
|
||||
test(true, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
|
||||
test(true, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
|
||||
// MinP
|
||||
test(true, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f);
|
||||
test(true, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f);
|
||||
// BeamWidthArray
|
||||
test(true, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
|
||||
test(true, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
|
||||
std::vector<SizeType32>{2, 3, 4, 5});
|
||||
}
|
||||
|
||||
@ -156,32 +159,35 @@ TEST(SamplingConfigTest, invalidInputs)
|
||||
|
||||
// Skip presence penalty, frequency penalty, no test
|
||||
|
||||
// Neg length penalty
|
||||
// Neg prompt ignore length
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, -1);
|
||||
|
||||
// Neg early stopping
|
||||
// Neg length penalty
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, -1);
|
||||
|
||||
// Neg no repeat ngram size
|
||||
// Neg early stopping
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1);
|
||||
|
||||
// Neg no repeat ngram size
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1);
|
||||
|
||||
// Neg or zero numReturnSequences
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0);
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0);
|
||||
|
||||
// numReturnSequences > beamWidth
|
||||
test(false, 2, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 4);
|
||||
test(false, 2, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 4);
|
||||
|
||||
// Neg minP
|
||||
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f);
|
||||
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f);
|
||||
|
||||
// Neg / Large minP
|
||||
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f);
|
||||
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, +2.f);
|
||||
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f);
|
||||
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, +2.f);
|
||||
|
||||
// BeamWidthArray with neg / large beamWidth
|
||||
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
|
||||
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
|
||||
std::vector<SizeType32>{2, 3, 4, -1});
|
||||
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
|
||||
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
|
||||
std::vector<SizeType32>{2, 3, 4, 65536});
|
||||
}
|
||||
|
||||
@ -265,6 +271,12 @@ TEST(SamplingConfigTest, getterSetter)
|
||||
sc.setFrequencyPenalty(0.5f);
|
||||
EXPECT_EQ(sc.getFrequencyPenalty(), 0.5f);
|
||||
}
|
||||
// Prompt ignore length
|
||||
{
|
||||
auto sc = SamplingConfig();
|
||||
sc.setPromptIgnoreLength(1);
|
||||
EXPECT_EQ(sc.getPromptIgnoreLength(), 1);
|
||||
}
|
||||
// Length penalty
|
||||
{
|
||||
auto sc = SamplingConfig();
|
||||
|
||||
@ -161,7 +161,7 @@ protected:
|
||||
mLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mBatchSize}), ptrType);
|
||||
|
||||
mPenaltyWorkspaceDevice = mBufferManager->gpu(
|
||||
ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep, mVocabSizePadded}), nvinfer1::DataType::kINT32);
|
||||
ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep, mVocabSize * 2}), nvinfer1::DataType::kINT32);
|
||||
|
||||
mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
|
||||
@ -259,8 +259,8 @@ public:
|
||||
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T**>(bufferCast<int64_t>(*mLogitsPtrs)),
|
||||
bufferCast<T>(*mOutLogitsDevice), bufferCast<T>(*mBiasDevice),
|
||||
bufferCast<int32_t>(*mPenaltyWorkspaceDevice), nullptr, bufferCast<float>(*mTemperaturesDevice), nullptr,
|
||||
nullptr, nullptr, mBatchSize, 1, 1, mVocabSize, mVocabSizePadded, nullptr, nullptr, nullptr, nullptr,
|
||||
nullptr, nullptr, bufferCast<int32_t>(*mBatchSlots), mMaxTokensPerStep,
|
||||
nullptr, nullptr, nullptr, mBatchSize, 1, 1, mVocabSize, mVocabSizePadded, nullptr, nullptr, nullptr,
|
||||
nullptr, nullptr, nullptr, bufferCast<int32_t>(*mBatchSlots), mMaxTokensPerStep,
|
||||
bufferCast<int32_t>(*mTokensPerStep), nullptr, mStream->get()};
|
||||
tk::invokeBatchApplyPenalty(penaltyParams);
|
||||
auto logitsOutHost = mBufferManager->copyFrom(*mOutLogitsDevice, MemoryType::kCPU);
|
||||
@ -382,9 +382,11 @@ struct RepetitionPenaltyTestCase
|
||||
TensorPtr repetitionPenalties;
|
||||
TensorPtr presencePenalties;
|
||||
TensorPtr frequencyPenalties;
|
||||
TensorPtr promptIgnoreLengths;
|
||||
int32_t repetitionPenaltiesSize;
|
||||
int32_t presencePenaltiesSize;
|
||||
int32_t frequencyPenaltiesSize;
|
||||
int32_t promptIgnoreLengthsSize;
|
||||
int32_t maxTokensPerStep{1};
|
||||
|
||||
RepetitionPenaltyTestCase& setBatchSize(int32_t bs)
|
||||
@ -423,6 +425,12 @@ struct RepetitionPenaltyTestCase
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Stores the tensor of prompt-ignore lengths for this test case.
// The handle is kept as passed in; toString() reads its contents as int32_t values.
// Returns a reference to this test case so that setter calls can be chained.
RepetitionPenaltyTestCase& setPromptIgnoreLengths(TensorPtr pil)
{
    this->promptIgnoreLengths = pil;
    return *this;
}
|
||||
|
||||
RepetitionPenaltyTestCase& setRepetitionPenaltiesSize(int32_t rps)
|
||||
{
|
||||
repetitionPenaltiesSize = rps;
|
||||
@ -441,6 +449,12 @@ struct RepetitionPenaltyTestCase
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Records how many entries the prompt-ignore-length tensor holds.
// Returns a reference to this test case so that setter calls can be chained.
RepetitionPenaltyTestCase& setPromptIgnoreLengthsSize(int32_t pils)
{
    this->promptIgnoreLengthsSize = pils;
    return *this;
}
|
||||
|
||||
RepetitionPenaltyTestCase& setMaxTokensPerStep(int32_t ts)
|
||||
{
|
||||
maxTokensPerStep = ts;
|
||||
@ -451,11 +465,12 @@ struct RepetitionPenaltyTestCase
|
||||
{
|
||||
return tc::fmtstr(
|
||||
"RepetitionPenaltyTestCase[batch=%d, vocab=%d, maxInputLength=%d, "
|
||||
"repetitionPenalties=%s, presencePenalties=%s, frequencyPenalties=%s]",
|
||||
"repetitionPenalties=%s, presencePenalties=%s, frequencyPenalties=%s, promptIgnoreLengths=%s]",
|
||||
batchSize, vocabSize, maxInputLength,
|
||||
tc::arr2str(bufferCast<float>(*repetitionPenalties), repetitionPenaltiesSize).c_str(),
|
||||
tc::arr2str(bufferCast<float>(*presencePenalties), presencePenaltiesSize).c_str(),
|
||||
tc::arr2str(bufferCast<float>(*frequencyPenalties), frequencyPenaltiesSize).c_str());
|
||||
tc::arr2str(bufferCast<float>(*frequencyPenalties), frequencyPenaltiesSize).c_str(),
|
||||
tc::arr2str(bufferCast<int32_t>(*promptIgnoreLengths), promptIgnoreLengthsSize).c_str());
|
||||
}
|
||||
};
|
||||
|
||||
@ -499,6 +514,7 @@ protected:
|
||||
TensorPtr mRepetitionPenaltiesDevice;
|
||||
TensorPtr mPresencePenaltiesDevice;
|
||||
TensorPtr mFrequencyPenaltiesDevice;
|
||||
TensorPtr mPromptIgnoreLengthsDevice;
|
||||
TensorPtr mBatchSlots;
|
||||
|
||||
void subsetup(RepetitionPenaltyTestCase param)
|
||||
@ -525,7 +541,7 @@ protected:
|
||||
mLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mBatchSize}), ptrType);
|
||||
|
||||
mPenaltyWorkspaceDevice = mBufferManager->gpu(
|
||||
ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSize}), nvinfer1::DataType::kINT32);
|
||||
ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSize * 2}), nvinfer1::DataType::kINT32);
|
||||
|
||||
mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
|
||||
@ -588,24 +604,29 @@ protected:
|
||||
ASSERT_EQ(param.repetitionPenaltiesSize, mMaxBatchSize) << "Invalid test configuration.";
|
||||
ASSERT_EQ(param.presencePenaltiesSize, mMaxBatchSize) << "Invalid test configuration.";
|
||||
ASSERT_EQ(param.frequencyPenaltiesSize, mMaxBatchSize) << "Invalid test configuration.";
|
||||
ASSERT_EQ(param.promptIgnoreLengthsSize, mMaxBatchSize) << "Invalid test configuration.";
|
||||
mRepetitionPenaltiesDevice
|
||||
= mBufferManager->gpu(ITensor::makeShape({param.repetitionPenaltiesSize}), nvinfer1::DataType::kFLOAT);
|
||||
mPresencePenaltiesDevice
|
||||
= mBufferManager->gpu(ITensor::makeShape({param.presencePenaltiesSize}), nvinfer1::DataType::kFLOAT);
|
||||
mFrequencyPenaltiesDevice
|
||||
= mBufferManager->gpu(ITensor::makeShape({param.frequencyPenaltiesSize}), nvinfer1::DataType::kFLOAT);
|
||||
mPromptIgnoreLengthsDevice
|
||||
= mBufferManager->gpu(ITensor::makeShape({param.promptIgnoreLengthsSize}), nvinfer1::DataType::kINT32);
|
||||
mBufferManager->copy(*param.repetitionPenalties, *mRepetitionPenaltiesDevice);
|
||||
mBufferManager->copy(*param.presencePenalties, *mPresencePenaltiesDevice);
|
||||
mBufferManager->copy(*param.frequencyPenalties, *mFrequencyPenaltiesDevice);
|
||||
mBufferManager->copy(*param.promptIgnoreLengths, *mPromptIgnoreLengthsDevice);
|
||||
}
|
||||
|
||||
void computeReference(T const* const inLogits, T* const outLogits, int32_t const* const outputIds,
|
||||
int32_t const* const sequenceLengths, float const* const repetitionPenalties,
|
||||
float const* const presencePenalties, float const* const frequencyPenalties,
|
||||
int32_t const repetitionPenaltiesSize, int32_t const presencePenaltiesSize,
|
||||
int32_t const frequencyPenaltiesSize)
|
||||
int32_t const* const promptIgnoreLengths, int32_t const repetitionPenaltiesSize,
|
||||
int32_t const presencePenaltiesSize, int32_t const frequencyPenaltiesSize,
|
||||
int32_t const promptIgnoreLengthsSize)
|
||||
{
|
||||
std::vector<bool> penalized(mVocabSize);
|
||||
std::vector<bool> repetitionPenalized(mVocabSize), presencePenalized(mVocabSize);
|
||||
auto const batchSlotsPtr = bufferCast<int32_t>(*mBatchSlots);
|
||||
auto const tokensPerStepPtr = bufferCast<int32_t>(*mTokensPerStep);
|
||||
|
||||
@ -633,21 +654,47 @@ protected:
|
||||
float presencePenalty = presencePenaltiesSize > 1 ? presencePenalties[batchSlot] : presencePenalties[0];
|
||||
float frequencyPenalty
|
||||
= frequencyPenaltiesSize > 1 ? frequencyPenalties[batchSlot] : frequencyPenalties[0];
|
||||
int32_t promptIgnoreLength
|
||||
= promptIgnoreLengthsSize > 1 ? promptIgnoreLengths[batchSlot] : promptIgnoreLengths[0];
|
||||
|
||||
std::fill(penalized.begin(), penalized.end(), false);
|
||||
std::fill(repetitionPenalized.begin(), repetitionPenalized.end(), false);
|
||||
std::fill(presencePenalized.begin(), presencePenalized.end(), false);
|
||||
size_t offset = (bi * mMaxTokensPerStep + ti) * mVocabSizePadded;
|
||||
auto const step = sequenceLengths[batchSlot];
|
||||
|
||||
// clamping to the inputLength (set to same as sequenceLength)
|
||||
promptIgnoreLength = std::min(promptIgnoreLength, step);
|
||||
|
||||
std::vector<int32_t> numOccurences(mVocabSize, 0);
|
||||
for (int32_t t = 0; t < step; ++t)
|
||||
{
|
||||
auto tokenId = outputIds[batchSlot * mSequenceLength + t];
|
||||
if (!penalized[tokenId])
|
||||
|
||||
if (!repetitionPenalized[tokenId])
|
||||
{
|
||||
auto logit = static_cast<float>(outLogits[offset + tokenId]);
|
||||
outLogits[offset + tokenId] = static_cast<T>(
|
||||
(logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty) - presencePenalty);
|
||||
penalized[tokenId] = true;
|
||||
outLogits[offset + tokenId]
|
||||
= static_cast<T>((logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty));
|
||||
repetitionPenalized[tokenId] = true;
|
||||
}
|
||||
|
||||
if (!(t < promptIgnoreLength))
|
||||
{
|
||||
presencePenalized[tokenId] = true;
|
||||
numOccurences[tokenId] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t vi = 0; vi < mVocabSize; ++vi)
|
||||
{
|
||||
if (presencePenalized[vi])
|
||||
{
|
||||
outLogits[offset + vi] -= presencePenalty;
|
||||
}
|
||||
if (numOccurences[vi] > 0)
|
||||
{
|
||||
outLogits[offset + vi] -= numOccurences[vi] * frequencyPenalty;
|
||||
}
|
||||
outLogits[offset + tokenId] -= frequencyPenalty;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -661,7 +708,8 @@ public:
|
||||
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T**>(bufferCast<int64_t>(*mLogitsPtrs)),
|
||||
bufferCast<T>(*mOutLogitsDevice), nullptr, bufferCast<int32_t>(*mPenaltyWorkspaceDevice), nullptr, nullptr,
|
||||
bufferCast<float>(*mRepetitionPenaltiesDevice), bufferCast<float>(*mPresencePenaltiesDevice),
|
||||
bufferCast<float>(*mFrequencyPenaltiesDevice), mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded,
|
||||
bufferCast<float>(*mFrequencyPenaltiesDevice), bufferCast<int32_t>(*mPromptIgnoreLengthsDevice), mBatchSize,
|
||||
1, mSequenceLength, mVocabSize, mVocabSizePadded,
|
||||
reinterpret_cast<int32_t const**>(bufferCast<int64_t>(*mIdsPtrDevice)), nullptr,
|
||||
bufferCast<int32_t>(*mContextLengthDevice), bufferCast<int32_t>(*mSeqLengthDevice), nullptr, nullptr,
|
||||
bufferCast<int32_t>(*mBatchSlots), mMaxTokensPerStep, bufferCast<int32_t>(*mTokensPerStep), nullptr,
|
||||
@ -673,8 +721,9 @@ public:
|
||||
computeReference(bufferCast<T>(*mLogitsHost), bufferCast<T>(*mLogitsRefHost),
|
||||
bufferCast<int32_t>(*mOutputIdsHost), bufferCast<int32_t>(*mSeqLengthHost),
|
||||
bufferCast<float>(*param.repetitionPenalties), bufferCast<float>(*param.presencePenalties),
|
||||
bufferCast<float>(*param.frequencyPenalties), param.repetitionPenaltiesSize, param.presencePenaltiesSize,
|
||||
param.frequencyPenaltiesSize);
|
||||
bufferCast<float>(*param.frequencyPenalties), bufferCast<int32_t>(*param.promptIgnoreLengths),
|
||||
param.repetitionPenaltiesSize, param.presencePenaltiesSize, param.frequencyPenaltiesSize,
|
||||
param.promptIgnoreLengthsSize);
|
||||
|
||||
mStream->synchronize();
|
||||
|
||||
@ -696,11 +745,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchNoPenalty)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -709,9 +761,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchNoPenalty)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne)
|
||||
@ -724,11 +778,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -737,9 +794,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne)
|
||||
@ -752,11 +811,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 2.01f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -765,9 +827,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed)
|
||||
@ -780,11 +844,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -793,9 +860,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed)
|
||||
@ -808,11 +877,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -821,9 +893,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2)
|
||||
@ -836,11 +910,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = i % 2 == 0 ? 1.0f : 0.0f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -849,9 +926,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed)
|
||||
@ -864,11 +943,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -877,9 +959,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2)
|
||||
@ -892,11 +976,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = i % 2 == 0 ? 1.0f : 0.0f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -905,9 +992,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence)
|
||||
@ -920,11 +1009,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.0f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -933,9 +1025,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency)
|
||||
@ -948,11 +1042,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.0f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -961,9 +1058,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency)
|
||||
@ -976,11 +1075,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 1.0f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -989,9 +1091,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull)
|
||||
@ -1004,11 +1108,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -1017,9 +1124,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize));
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize));
|
||||
}
|
||||
|
||||
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep)
|
||||
@ -1032,11 +1141,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep)
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr frequencyPenaltyHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
|
||||
TensorPtr promptIgnoreLengthsHost
|
||||
= BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);
|
||||
for (int32_t i = 0; i < maxBatchSize; ++i)
|
||||
{
|
||||
bufferCast<float>(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*presencePenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<float>(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f;
|
||||
bufferCast<int32_t>(*promptIgnoreLengthsHost)[i] = 0;
|
||||
}
|
||||
this->runTest(RepetitionPenaltyTestCase()
|
||||
.setBatchSize(batchSize)
|
||||
@ -1045,9 +1157,78 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep)
|
||||
.setRepetitionPenalties(repetitionPenaltyHost)
|
||||
.setPresencePenalties(presencePenaltyHost)
|
||||
.setFrequencyPenalties(frequencyPenaltyHost)
|
||||
.setPromptIgnoreLengths(promptIgnoreLengthsHost)
|
||||
.setRepetitionPenaltiesSize(maxBatchSize)
|
||||
.setPresencePenaltiesSize(maxBatchSize)
|
||||
.setFrequencyPenaltiesSize(maxBatchSize)
|
||||
.setPromptIgnoreLengthsSize(maxBatchSize)
|
||||
.setMaxTokensPerStep(4));
|
||||
}
|
||||
|
||||
// Runs repetition, presence and frequency penalties together while asking the penalty code to
// ignore the first prompt token (promptIgnoreLength == 1) in every batch slot.
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullWithPartialPromptIgnore)
{
    int32_t batchSize = 6;
    int32_t maxBatchSize = 2 * batchSize;
    int32_t const vocabSize = 4;
    int32_t const maxInputLength = 5;
    // Ignore only the first prompt token.
    int32_t const ignoredPromptTokens = 1;

    // One pinned host buffer per penalty parameter, each holding maxBatchSize entries.
    TensorPtr repetitionPenaltyHost
        = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
    TensorPtr presencePenaltyHost
        = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
    TensorPtr frequencyPenaltyHost
        = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
    TensorPtr promptIgnoreLengthsHost
        = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);

    auto const repetitionPenaltyPtr = bufferCast<float>(*repetitionPenaltyHost);
    auto const presencePenaltyPtr = bufferCast<float>(*presencePenaltyHost);
    auto const frequencyPenaltyPtr = bufferCast<float>(*frequencyPenaltyHost);
    auto const promptIgnoreLengthsPtr = bufferCast<int32_t>(*promptIgnoreLengthsHost);

    for (int32_t slot = 0; slot < maxBatchSize; ++slot)
    {
        // Every slot gets a distinct penalty value, shared by all three penalty types.
        float const slotPenalty = 0.53 + slot * 0.2f;
        repetitionPenaltyPtr[slot] = slotPenalty;
        presencePenaltyPtr[slot] = slotPenalty;
        frequencyPenaltyPtr[slot] = slotPenalty;
        promptIgnoreLengthsPtr[slot] = ignoredPromptTokens;
    }

    auto testCase = RepetitionPenaltyTestCase();
    testCase.setBatchSize(batchSize)
        .setVocabSize(vocabSize)
        .setMaxInputLength(maxInputLength)
        .setRepetitionPenalties(repetitionPenaltyHost)
        .setPresencePenalties(presencePenaltyHost)
        .setFrequencyPenalties(frequencyPenaltyHost)
        .setPromptIgnoreLengths(promptIgnoreLengthsHost)
        .setRepetitionPenaltiesSize(maxBatchSize)
        .setPresencePenaltiesSize(maxBatchSize)
        .setFrequencyPenaltiesSize(maxBatchSize)
        .setPromptIgnoreLengthsSize(maxBatchSize);
    this->runTest(testCase);
}
|
||||
|
||||
// Runs repetition, presence and frequency penalties together with four tokens per step, while
// setting promptIgnoreLength to the maximum input length so that the full prompt is ignored.
TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStepWithFullPromptIgnore)
{
    int32_t batchSize = 6;
    int32_t maxBatchSize = 2 * batchSize;
    int32_t const vocabSize = 4;
    int32_t const maxInputLength = 5;
    int32_t const maxTokensPerStep = 4;

    // One pinned host buffer per penalty parameter, each holding maxBatchSize entries.
    TensorPtr repetitionPenaltyHost
        = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
    TensorPtr presencePenaltyHost
        = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
    TensorPtr frequencyPenaltyHost
        = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
    TensorPtr promptIgnoreLengthsHost
        = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32);

    auto const repetitionPenaltyPtr = bufferCast<float>(*repetitionPenaltyHost);
    auto const presencePenaltyPtr = bufferCast<float>(*presencePenaltyHost);
    auto const frequencyPenaltyPtr = bufferCast<float>(*frequencyPenaltyHost);
    auto const promptIgnoreLengthsPtr = bufferCast<int32_t>(*promptIgnoreLengthsHost);

    for (int32_t slot = 0; slot < maxBatchSize; ++slot)
    {
        // Every slot gets a distinct penalty value, shared by all three penalty types.
        float const slotPenalty = 0.53 + slot * 0.2f;
        repetitionPenaltyPtr[slot] = slotPenalty;
        presencePenaltyPtr[slot] = slotPenalty;
        frequencyPenaltyPtr[slot] = slotPenalty;
        // Set to the max input length to ignore the full prompt.
        promptIgnoreLengthsPtr[slot] = maxInputLength;
    }

    auto testCase = RepetitionPenaltyTestCase();
    testCase.setBatchSize(batchSize)
        .setVocabSize(vocabSize)
        .setMaxInputLength(maxInputLength)
        .setRepetitionPenalties(repetitionPenaltyHost)
        .setPresencePenalties(presencePenaltyHost)
        .setFrequencyPenalties(frequencyPenaltyHost)
        .setPromptIgnoreLengths(promptIgnoreLengthsHost)
        .setRepetitionPenaltiesSize(maxBatchSize)
        .setPresencePenaltiesSize(maxBatchSize)
        .setFrequencyPenaltiesSize(maxBatchSize)
        .setPromptIgnoreLengthsSize(maxBatchSize)
        .setMaxTokensPerStep(maxTokensPerStep);
    this->runTest(testCase);
}
|
||||
|
||||
@ -1257,8 +1438,8 @@ public:
|
||||
|
||||
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T**>(bufferCast<int64_t>(*mLogitsPtrs)),
|
||||
bufferCast<T>(*mOutLogitsDevice), nullptr, bufferCast<int32_t>(*mPenaltyWorkspaceDevice), nullptr, nullptr,
|
||||
nullptr, nullptr, nullptr, mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded, nullptr, nullptr,
|
||||
bufferCast<int32_t>(*mContextLengthDevice), bufferCast<int32_t>(*mSeqLengthDevice),
|
||||
nullptr, nullptr, nullptr, nullptr, mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded, nullptr,
|
||||
nullptr, bufferCast<int32_t>(*mContextLengthDevice), bufferCast<int32_t>(*mSeqLengthDevice),
|
||||
bufferCast<int32_t>(*mMinLengthDevice), bufferCast<int32_t>(*mEndIdsDevice),
|
||||
bufferCast<int32_t>(*mBatchSlots), mMaxTokensPerStep, bufferCast<int32_t>(*mTokensPerStep), nullptr,
|
||||
mStream->get()};
|
||||
@ -1415,6 +1596,7 @@ public:
|
||||
/*repetitionPenalties=*/nullptr,
|
||||
/*presencePenalties=*/nullptr,
|
||||
/*frequencyPenalties=*/nullptr,
|
||||
/*promptIgnoreLengths=*/nullptr,
|
||||
/*batchSize=*/mBatchSize,
|
||||
/*beamWidth=*/1,
|
||||
/*maxSeqLen=*/mSequenceLength,
|
||||
|
||||
@ -40,6 +40,7 @@ struct TestSamplingParams
|
||||
std::vector<float> repetitionPenalties;
|
||||
std::vector<float> presencePenalties;
|
||||
std::vector<float> frequencyPenalties;
|
||||
std::vector<runtime::SizeType32> promptIgnoreLengths;
|
||||
std::vector<runtime::SizeType32> minLengths;
|
||||
std::vector<float> decay;
|
||||
std::vector<float> minTopP;
|
||||
|
||||
@ -37,17 +37,18 @@ void test(bool const useExternalDraftTokensConfig, SizeType32 beamWidth = 1, std
|
||||
std::optional<RandomSeedType> randomSeed = no, std::optional<FloatType> temperature = no,
|
||||
std::optional<SizeType32> minLength = no, std::optional<FloatType> beamSearchDiversityRate = no,
|
||||
std::optional<FloatType> repetitionPenalty = no, std::optional<FloatType> presencePenalty = no,
|
||||
std::optional<FloatType> frequencyPenalty = no, std::optional<FloatType> lengthPenalty = no,
|
||||
std::optional<SizeType32> earlyStopping = no, std::optional<SizeType32> noRepeatNgramSize = no,
|
||||
std::optional<SizeType32> numReturnSequences = no, std::optional<FloatType> minP = no,
|
||||
std::optional<std::vector<SizeType32>> beamWidthArray = no)
|
||||
std::optional<FloatType> frequencyPenalty = no, std::optional<SizeType32> promptIgnoreLength = no,
|
||||
std::optional<FloatType> lengthPenalty = no, std::optional<SizeType32> earlyStopping = no,
|
||||
std::optional<SizeType32> noRepeatNgramSize = no, std::optional<SizeType32> numReturnSequences = no,
|
||||
std::optional<FloatType> minP = no, std::optional<std::vector<SizeType32>> beamWidthArray = no)
|
||||
{
|
||||
// 19 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray`
|
||||
// 20 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray`
|
||||
try
|
||||
{
|
||||
te::SamplingConfig execSamplingCfg(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed,
|
||||
temperature, minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty,
|
||||
lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray);
|
||||
promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP,
|
||||
beamWidthArray);
|
||||
std::optional<te::ExternalDraftTokensConfig> specCfg = std::nullopt;
|
||||
if (useExternalDraftTokensConfig)
|
||||
{
|
||||
@ -110,18 +111,20 @@ TEST(samplingConfigTest, validInputs)
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// Frequency penalty
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// Prompt ignore length
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1);
|
||||
// Length penalty
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// Early stopping
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// Early stopping
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f);
|
||||
// No repeat ngram size
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
|
||||
test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
|
||||
// NumReturnSequences
|
||||
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
|
||||
// MinP, 18 arguments
|
||||
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f);
|
||||
test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2);
|
||||
// MinP, 19 arguments
|
||||
test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f);
|
||||
// BeamWidthArray
|
||||
test(false, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
|
||||
test(false, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no,
|
||||
std::vector<SizeType32>{2, 3, 4, 5});
|
||||
|
||||
// All parameters
|
||||
@ -139,6 +142,7 @@ TEST(samplingConfigTest, validInputs)
|
||||
te::FloatType repetitionPenalty{0.5f};
|
||||
te::FloatType presencePenalty{0.5f};
|
||||
te::FloatType frequencyPenalty{0.5f};
|
||||
te::SizeType32 promptIgnoreLength{1};
|
||||
te::FloatType lengthPenalty{0.5f};
|
||||
te::SizeType32 earlyStopping{1};
|
||||
te::SizeType32 noRepeatNgramSize{5};
|
||||
@ -148,7 +152,8 @@ TEST(samplingConfigTest, validInputs)
|
||||
|
||||
te::SamplingConfig execSamplingCfg(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed,
|
||||
temperature, minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty,
|
||||
lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray);
|
||||
promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP,
|
||||
beamWidthArray);
|
||||
te::ExternalDraftTokensConfig specCfg({1}, no, 0.5f);
|
||||
tr::SamplingConfig samplingCfg(execSamplingCfg, specCfg);
|
||||
EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth());
|
||||
@ -166,6 +171,7 @@ TEST(samplingConfigTest, validInputs)
|
||||
EXPECT_THAT(samplingCfg.repetitionPenalty.value(), testing::ElementsAre(repetitionPenalty));
|
||||
EXPECT_THAT(samplingCfg.presencePenalty.value(), testing::ElementsAre(presencePenalty));
|
||||
EXPECT_THAT(samplingCfg.frequencyPenalty.value(), testing::ElementsAre(frequencyPenalty));
|
||||
EXPECT_THAT(samplingCfg.promptIgnoreLength.value(), testing::ElementsAre(promptIgnoreLength));
|
||||
EXPECT_THAT(samplingCfg.lengthPenalty.value(), testing::ElementsAre(lengthPenalty));
|
||||
EXPECT_THAT(samplingCfg.earlyStopping.value(), testing::ElementsAre(earlyStopping));
|
||||
EXPECT_THAT(samplingCfg.noRepeatNgramSize.value(), testing::ElementsAre(noRepeatNgramSize));
|
||||
|
||||
@ -281,6 +281,7 @@ def main(args):
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
presence_penalty=args.presence_penalty,
|
||||
frequency_penalty=args.frequency_penalty,
|
||||
prompt_ignore_length=args.prompt_ignore_length,
|
||||
# stop_words_list=stop_words_list,
|
||||
# bad_words_list=bad_words_list,
|
||||
output_cum_log_probs=(args.output_cum_log_probs_npy != None),
|
||||
|
||||
@ -540,6 +540,7 @@ def main(args):
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
presence_penalty=args.presence_penalty,
|
||||
frequency_penalty=args.frequency_penalty,
|
||||
prompt_ignore_length=args.prompt_ignore_length,
|
||||
min_p=args.min_p,
|
||||
stop_words_list=stop_words_list,
|
||||
bad_words_list=bad_words_list,
|
||||
@ -639,6 +640,7 @@ def main(args):
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
presence_penalty=args.presence_penalty,
|
||||
frequency_penalty=args.frequency_penalty,
|
||||
prompt_ignore_length=args.prompt_ignore_length,
|
||||
min_p=args.min_p,
|
||||
stop_words_list=stop_words_list,
|
||||
bad_words_list=bad_words_list,
|
||||
@ -677,6 +679,7 @@ def main(args):
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
presence_penalty=args.presence_penalty,
|
||||
frequency_penalty=args.frequency_penalty,
|
||||
prompt_ignore_length=args.prompt_ignore_length,
|
||||
stop_words_list=stop_words_list,
|
||||
bad_words_list=bad_words_list,
|
||||
output_cum_log_probs=(args.output_cum_log_probs_npy
|
||||
|
||||
@ -211,6 +211,7 @@ def main(args):
|
||||
repetition_penalty = args.repetition_penalty
|
||||
presence_penalty = args.presence_penalty
|
||||
frequency_penalty = args.frequency_penalty
|
||||
prompt_ignore_length = args.prompt_ignore_length
|
||||
random_seed = args.random_seed
|
||||
torch.manual_seed(random_seed)
|
||||
|
||||
@ -353,6 +354,7 @@ def main(args):
|
||||
repetition_penalty=repetition_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
prompt_ignore_length=prompt_ignore_length,
|
||||
lora_uids=args.lora_task_uids,
|
||||
lookahead_config=args.lookahead_config,
|
||||
output_sequence_lengths=True,
|
||||
|
||||
@ -304,6 +304,7 @@ def add_common_args(parser):
|
||||
parser.add_argument('--repetition_penalty', type=float, default=1.0)
|
||||
parser.add_argument('--presence_penalty', type=float, default=0.0)
|
||||
parser.add_argument('--frequency_penalty', type=float, default=0.0)
|
||||
parser.add_argument('--prompt_ignore_length', type=int, default=0)
|
||||
parser.add_argument('--min_p', type=float, default=0.0)
|
||||
parser.add_argument('--beam_search_diversity_rate', type=float, default=0.0)
|
||||
parser.add_argument('--random_seed', type=int, default=0)
|
||||
|
||||
@ -721,6 +721,7 @@ class SamplingConfig:
|
||||
min_length: Union[int, torch.Tensor] = field(default=1)
|
||||
presence_penalty: Union[float, torch.Tensor] = field(default=0.0)
|
||||
frequency_penalty: Union[float, torch.Tensor] = field(default=0.0)
|
||||
prompt_ignore_length: Union[int, torch.Tensor] = field(default=0)
|
||||
use_beam_hyps: bool = field(default=True)
|
||||
|
||||
# None here means user didn't set it, and dynamicDecodeOp.cpp take optional value
|
||||
@ -1474,6 +1475,16 @@ class GenerationSession(object):
|
||||
scfg.frequency_penalty,
|
||||
dtype=torch.float32)
|
||||
|
||||
if isinstance(scfg.prompt_ignore_length, torch.Tensor):
|
||||
assert scfg.prompt_ignore_length.dtype == torch.int32, f"scfg.prompt_ignore_length.dtype ({scfg.prompt_ignore_length.dtype}) must be torch.int32"
|
||||
assert scfg.prompt_ignore_length.shape[
|
||||
0] == batch_size, f"scfg.prompt_ignore_length.shape[0] ({scfg.prompt_ignore_length.shape[0]}) must equal to batch_size ({batch_size})"
|
||||
self.prompt_ignore_length = scfg.prompt_ignore_length
|
||||
else:
|
||||
self.prompt_ignore_length = torch.full([batch_size],
|
||||
scfg.prompt_ignore_length,
|
||||
dtype=torch.int32)
|
||||
|
||||
if isinstance(scfg.min_length, torch.Tensor):
|
||||
assert scfg.min_length.dtype == torch.int32, f"scfg.min_length.dtype ({scfg.min_length.dtype}) must be torch.int32"
|
||||
assert scfg.min_length.shape[
|
||||
@ -1543,6 +1554,7 @@ class GenerationSession(object):
|
||||
self.repetition_penalty,
|
||||
self.presence_penalty,
|
||||
self.frequency_penalty,
|
||||
self.prompt_ignore_length,
|
||||
self.min_length,
|
||||
self.host_length_penalty,
|
||||
self.host_early_stopping,
|
||||
|
||||
@ -648,6 +648,7 @@ class ModelRunnerCpp(ModelRunnerMixin):
|
||||
"repetition_penalty",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"prompt_ignore_length",
|
||||
"length_penalty",
|
||||
"early_stopping",
|
||||
"no_repeat_ngram_size",
|
||||
|
||||
@ -165,6 +165,7 @@ class SamplingParams:
|
||||
repetition_penalty (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None.
|
||||
presence_penalty (float, optional): Used to penalize tokens already present in the sequence (irrespective of the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. None means using C++ runtime default 0.f. Defaults to None.
|
||||
frequency_penalty (float, optional): Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. None means using C++ runtime default 0.f. Defaults to None.
|
||||
prompt_ignore_length (int, optional): Controls how many tokens to ignore from the prompt for presence and frequency penalties. Values <= 0 have no effect. Values > input (prompt) length will be clamped. None means using C++ runtime default 0. Defaults to None.
|
||||
length_penalty (float, optional): Controls how to penalize longer sequences in beam search. None means using C++ runtime default 0.f. Defaults to None.
|
||||
early_stopping (int, optional): Controls whether the generation process finishes once beamWidth sentences are generated (ends with end_token). None means using C++ runtime default 1. Defaults to None.
|
||||
no_repeat_ngram_size (int, optional): Controls how many repeat ngram size are acceptable. None means using C++ runtime default 1 << 30. Defaults to None.
|
||||
@ -232,6 +233,7 @@ class SamplingParams:
|
||||
repetition_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
prompt_ignore_length: Optional[int] = None
|
||||
length_penalty: Optional[float] = None
|
||||
early_stopping: Optional[int] = None
|
||||
no_repeat_ngram_size: Optional[int] = None
|
||||
|
||||
@ -229,6 +229,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
top_p: Optional[float] = None
|
||||
user: Optional[str] = None
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
prompt_ignore_length: Optional[int] = 0
|
||||
|
||||
# doc: begin-completion-sampling-params
|
||||
use_beam_search: bool = False
|
||||
@ -283,6 +284,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
temperature=(self.temperature
|
||||
if self.temperature is not None else 1.0),
|
||||
top_p=(self.top_p if self.top_p is not None else 1.0),
|
||||
prompt_ignore_length=self.prompt_ignore_length,
|
||||
|
||||
# completion-sampling-params
|
||||
use_beam_search=self.use_beam_search,
|
||||
@ -530,6 +532,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"reasoning is shown in the model's response. Options: "
|
||||
"'low', 'medium', 'high'."),
|
||||
)
|
||||
prompt_ignore_length: Optional[int] = 0
|
||||
|
||||
# doc: begin-chat-completion-sampling-params
|
||||
best_of: Optional[int] = None
|
||||
@ -622,6 +625,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
stop=self.stop,
|
||||
temperature=(self.temperature
|
||||
if self.temperature is not None else 1.0),
|
||||
prompt_ignore_length=self.prompt_ignore_length,
|
||||
|
||||
# chat-completion-sampling-params
|
||||
best_of=self.best_of,
|
||||
|
||||
@ -12,5 +12,8 @@ methods:
|
||||
beam_width_array:
|
||||
annotation: Optional[List[int]]
|
||||
default: null
|
||||
prompt_ignore_length:
|
||||
annotation: Optional[int]
|
||||
default: null
|
||||
return_annotation: None
|
||||
properties: {}
|
||||
|
||||
@ -35,6 +35,7 @@ class TestSamplingParams(ApiStabilityTestHarness):
|
||||
"repetition_penalty",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"prompt_ignore_length",
|
||||
"length_penalty",
|
||||
"early_stopping",
|
||||
"no_repeat_ngram_size",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user