Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00

Commit 291290851a: Merge remote-tracking branch 'origin/main' into feat/b300_cu13
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
@@ -542,7 +542,8 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const
         std::nullopt, // kvCacheRetentionConfig
         std::nullopt, // logitsPostProcessorName
         std::nullopt, // logitsPostProcessor
-        encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
+        encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
+        std::nullopt); // cacheSaltID
     request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
     return request;
 }
@@ -837,7 +837,8 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
         std::nullopt, // kvCacheRetentionConfig
         std::nullopt, // logitsPostProcessorName
         std::nullopt, // logitsPostProcessor
-        encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
+        encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
+        std::nullopt); // cacheSaltID
 }

 void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
@@ -69,6 +69,7 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken;
 using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
 using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
 using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
+using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;

 // Type alias for multimodal hash key (hash array + start offset)
 using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
@@ -115,6 +116,7 @@ struct BlockKey
     // Extra keys for multimodal data (similar to VLLM's approach)
     // Each extra key is a pair of (mm_hash, start_offset_in_block)
     std::vector<MmKey> extraKeys;
+    std::optional<CacheSaltIDType> cacheSaltID = std::nullopt;

     BlockKey() = default;

@@ -129,24 +131,25 @@ struct BlockKey
     }

     explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
-        std::vector<MmKey> extraKeys = {})
+        std::vector<MmKey> extraKeys = {}, std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
         : usesExtraIds{usesExtraIds}
         , loraTaskId{loraTaskId}
         , uniqueTokens{std::move(uniqueTokens)}
         , extraKeys{std::move(extraKeys)}
+        , cacheSaltID{cacheSaltID}
     {
     }

     bool operator==(BlockKey const& other) const noexcept
     {
         return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
-            && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys);
+            && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
     }

     int partialMatch(BlockKey const& other) const noexcept
     {
         SizeType32 numMatched{0};
-        if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys)
+        if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID)
         {
             auto [matchEnd, otherMatchEnd] = std::mismatch(
                 uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end());
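Taken together, these hunks make the salt part of the block identity: both the full match in operator== and the prefix match in partialMatch now require cacheSaltID to agree, so two requests with identical prompts but different salts can never reuse each other's cache blocks. A minimal standalone sketch of that matching rule, using a simplified stand-in for BlockKey (the real struct above also compares unique tokens with extra IDs, the LoRA task ID, and multimodal extra keys):

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Simplified stand-in for the BlockKey fields that drive reuse matching.
struct MiniBlockKey
{
    std::vector<int32_t> tokens;
    std::optional<uint64_t> cacheSaltID;

    bool operator==(MiniBlockKey const& other) const
    {
        return tokens == other.tokens && cacheSaltID == other.cacheSaltID;
    }
};

int main()
{
    MiniBlockKey unsalted{{100, 101, 102, 103}, std::nullopt};
    MiniBlockKey saltA{{100, 101, 102, 103}, 12345};
    MiniBlockKey saltB{{100, 101, 102, 103}, 67890};

    std::cout << (unsalted == saltA) << '\n';                                  // 0: salted request cannot reuse unsalted blocks
    std::cout << (saltA == saltB) << '\n';                                     // 0: different salts stay isolated from each other
    std::cout << (saltA == MiniBlockKey{{100, 101, 102, 103}, 12345}) << '\n'; // 1: same salt still allows reuse
}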
@ -100,8 +100,8 @@ public:
|
||||
RequestIdType, TensorPtr&, BeamTokens const&, TStream const&, std::optional<RequestIdType>)>;
|
||||
using RequestPtr = std::shared_ptr<GenericLlmRequest>;
|
||||
using MillisecondsType = std::chrono::milliseconds;
|
||||
using CacheSaltIDType = runtime::CacheSaltIDType;
|
||||
|
||||
// 49 parameters, 56 items in initialization list
|
||||
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const& inputTokens,
|
||||
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
|
||||
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
|
||||
@ -137,7 +137,8 @@ public:
|
||||
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
|
||||
std::optional<SizeType32> languageAdapterUid = std::nullopt,
|
||||
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
|
||||
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(inputTokens->size())
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
@ -194,6 +195,7 @@ public:
|
||||
, mGuidedDecodingParams(std::move(guidedDecodingParams))
|
||||
, mLanguageAdapterUid(languageAdapterUid)
|
||||
, mAllottedTimeMs(allottedTimeMs)
|
||||
, mCacheSaltID(cacheSaltID)
|
||||
{
|
||||
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
|
||||
{
|
||||
@ -203,7 +205,6 @@ public:
|
||||
initialize(*inputTokens, returnLogProbs);
|
||||
}
|
||||
|
||||
// 32 parameters, 39 items in initialization list
|
||||
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, VecTokens const& inputTokens,
|
||||
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
|
||||
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
|
||||
@ -221,7 +222,8 @@ public:
|
||||
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
|
||||
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
|
||||
std::optional<SizeType32> languageAdapterUid = std::nullopt,
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
|
||||
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(inputTokens.size())
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
@ -261,6 +263,7 @@ public:
|
||||
, mContextPhaseParams(contextPhaseParams)
|
||||
, mNumReturnSequences(numReturnSequences)
|
||||
, mLanguageAdapterUid(languageAdapterUid)
|
||||
, mCacheSaltID(cacheSaltID)
|
||||
{
|
||||
if (mEncoderTokens.has_value())
|
||||
{
|
||||
@ -269,7 +272,6 @@ public:
|
||||
initialize(inputTokens, returnLogProbs);
|
||||
}
|
||||
|
||||
// 29 items in initialization list
|
||||
GenericLlmRequest(RequestIdType requestId, executor::Request const& req)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(req.getInputTokenIds().size())
|
||||
@ -300,6 +302,7 @@ public:
|
||||
, mGuidedDecodingParams(req.getGuidedDecodingParams())
|
||||
, mLanguageAdapterUid(req.getLanguageAdapterUid())
|
||||
, mAllottedTimeMs(req.getAllottedTimeMs())
|
||||
, mCacheSaltID(req.getCacheSaltID())
|
||||
{
|
||||
if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY)
|
||||
{
|
||||
@ -1764,6 +1767,11 @@ public:
|
||||
return mLanguageAdapterUid;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const
|
||||
{
|
||||
return mCacheSaltID;
|
||||
}
|
||||
|
||||
std::vector<SizeType32> getLanguageAdapterRouting(
|
||||
SizeType32 const reqNumLanguages, SizeType32 const inputLength) const
|
||||
{
|
||||
@ -2042,6 +2050,9 @@ protected:
|
||||
|
||||
bool mUseDraftModel{false};
|
||||
|
||||
// Cache salt id for each request.
|
||||
std::optional<CacheSaltIDType> mCacheSaltID{std::nullopt};
|
||||
|
||||
private:
|
||||
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
|
||||
{
|
||||
@ -2222,7 +2233,8 @@ public:
|
||||
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
|
||||
std::optional<SizeType32> languageAdapterUid = std::nullopt,
|
||||
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
|
||||
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
|
||||
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
|
||||
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
|
||||
std::move(promptEmbeddingTable), promptVocabSize, std::move(multimodalHashes),
|
||||
@ -2234,7 +2246,8 @@ public:
|
||||
std::move(encoderInputTokens), returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures),
|
||||
std::move(encoderOutputLength), std::move(crossAttentionMask), llmRequestType,
|
||||
std::move(inputTokenExtraIds), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks),
|
||||
returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
|
||||
returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams,
|
||||
cacheSaltID)
|
||||
{
|
||||
}
|
||||
|
||||
@ -2272,7 +2285,8 @@ public:
|
||||
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
|
||||
std::optional<SizeType32> languageAdapterUid = std::nullopt,
|
||||
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
|
||||
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
|
||||
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
|
||||
samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
|
||||
std::move(stopWordsList),
|
||||
@ -2302,7 +2316,7 @@ public:
|
||||
inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
|
||||
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
|
||||
numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics,
|
||||
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
|
||||
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID)
|
||||
{
|
||||
}
|
||||
|
||||
@ -2324,14 +2338,15 @@ public:
|
||||
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
|
||||
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
|
||||
std::optional<SizeType32> languageAdapterUid = std::nullopt,
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
|
||||
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
|
||||
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId,
|
||||
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
|
||||
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
|
||||
lookaheadConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens),
|
||||
std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),
|
||||
applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority,
|
||||
numReturnSequences, languageAdapterUid, contextPhaseParams)
|
||||
numReturnSequences, languageAdapterUid, contextPhaseParams, cacheSaltID)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@@ -670,7 +670,7 @@ public:
     /// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut
     /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
     /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
-    // 34 parameters
+    /// @param cacheSaltID Salt ID for KV cache blocks, limiting KV cache reuse to requests with the same salt ID.
     Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
         SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
         std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@@ -697,7 +697,8 @@ public:
         std::optional<EagleConfig> eagleConfig = std::nullopt, std::optional<Tensor> skipCrossAttnBlocks = std::nullopt,
         std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
-        std::optional<MillisecondsType> allottedTimeMs = std::nullopt);
+        std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);

     /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
     static auto constexpr kBatchedPostProcessorName = "batched";
@@ -745,6 +746,7 @@ public:
     [[nodiscard]] std::optional<GuidedDecodingParams> getGuidedDecodingParams() const;
     [[nodiscard]] std::optional<SizeType32> getLanguageAdapterUid() const;
     [[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
+    [[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
     [[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;

     void setStreaming(bool streaming);
@@ -780,6 +782,7 @@ public:
     void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams);
     void setLanguageAdapterUid(SizeType32 languageAdapterUid);
     void setAllottedTimeMs(MillisecondsType allottedTimeMs);
+    void setCacheSaltID(CacheSaltIDType cacheSaltID);

 private:
     friend class Serialization;
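For API users, the net effect is one new optional constructor argument plus a getter/setter pair on executor::Request. A short usage sketch (illustrative only: the helper function, token IDs, maxTokens value, and include path are assumptions, and executor setup is omitted):

#include <optional>
#include <utility>
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Build a request and, when a salt is given, restrict KV cache reuse to requests
// carrying the same salt; std::nullopt keeps the default shared-reuse behaviour.
texec::Request makeSaltedRequest(texec::VecTokens tokens, std::optional<texec::CacheSaltIDType> salt)
{
    texec::Request request(std::move(tokens), /*maxTokens=*/64);
    if (salt)
    {
        request.setCacheSaltID(*salt);
    }
    return request;
}

// Two tenants with identical prompts but different salts will not share cache blocks:
//   auto requestA = makeSaltedRequest({100, 101, 102, 103}, 12345);
//   auto requestB = makeSaltedRequest({100, 101, 102, 103}, 67890);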
@@ -58,6 +58,7 @@ using RandomSeedType = std::uint64_t;
 using VecLogProbs = std::vector<FloatType>;
 using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
 using MillisecondsType = std::chrono::milliseconds;
+using CacheSaltIDType = std::uint64_t;
 using LogitsPostProcessor
     = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>;
 using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
@@ -44,6 +44,7 @@ using TokenIdType = std::int32_t;
 using LoraTaskIdType = std::uint64_t;
 using TokenExtraIdType = std::uint64_t;
 using VecTokenExtraIds = std::vector<TokenExtraIdType>;
+using CacheSaltIDType = std::uint64_t;

 struct UniqueToken
 {
@@ -131,7 +131,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
         // Check if this multimodal content overlaps with the current block
         if (endTokenIdx > startPos && startTokenIdx < startPos + length)
         {
-            SizeType32 mmStartInBlock = (startPos >= startTokenIdx) ? 0 : startTokenIdx - startPos;
+            uint64_t mmStartInBlock = (startPos >= startTokenIdx) ? 0 : static_cast<uint64_t>(startTokenIdx - startPos);
             extraKeys.emplace_back(mmHashArray, mmStartInBlock);
         }
     }
@@ -151,7 +151,7 @@ std::vector<BlockKey> buildBlockKeys(
         currentTokenIdx += uniqueTokens.size();

         blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(),
-            std::move(uniqueTokens), std::move(extraKeys));
+            std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID());
     }
     return blockKeys;
 }
@@ -167,6 +167,16 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no
     // Constants provide very good distribution - each input bit affects each output bit with ~50% probability.
     size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9);

+    if (parentHash == 0 && blockKey.cacheSaltID)
+    {
+        // Only hashing the cache salt ID for the first block in the sequence
+        uint64_t c = blockKey.cacheSaltID.value();
+        c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
+        c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb);
+        c = c ^ (c >> 31);
+        seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    }
+
     for (auto const& uniqueToken : blockKey.uniqueTokens)
     {
         uint32_t a = static_cast<uint32_t>(uniqueToken.tokenId);
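The salt is mixed into the hash only when parentHash == 0, i.e. for the root block of a sequence; because every child block's hash chains on its parent's hash, a different salt still changes the hashes of all blocks for that prompt. A standalone sketch of that chaining effect, reusing the same splitmix64-style mixing (a simplified model: the real hasher above also folds in token IDs, extra token IDs, the LoRA task ID, and multimodal keys):

#include <cstdint>
#include <iostream>
#include <optional>

// Seed a block hash from the parent's hash, mixing the salt only at the root (parentHash == 0).
uint64_t blockSeed(uint64_t numTokens, uint64_t parentHash, std::optional<uint64_t> cacheSaltID)
{
    uint64_t seed = numTokens ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9);
    if (parentHash == 0 && cacheSaltID)
    {
        uint64_t c = *cacheSaltID;
        c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
        c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb);
        c = c ^ (c >> 31);
        seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
    return seed;
}

int main()
{
    // Root blocks of the same 4-token prompt, unsalted vs. salted with 12345.
    uint64_t rootPlain = blockSeed(4, 0, std::nullopt);
    uint64_t rootSalted = blockSeed(4, 0, 12345);

    // Child blocks inherit the divergence through the parent hash, so no block of the sequence matches.
    uint64_t childPlain = blockSeed(4, rootPlain, std::nullopt);
    uint64_t childSalted = blockSeed(4, rootSalted, 12345);

    std::cout << (rootPlain != rootSalted) << ' ' << (childPlain != childSalted) << '\n'; // prints: 1 1
}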
@ -25,7 +25,6 @@
|
||||
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
// 36 parameters
|
||||
Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, SamplingConfig const& samplingConfig,
|
||||
OutputConfig const& outputConfig, std::optional<SizeType32> const& endId, std::optional<SizeType32> const& padId,
|
||||
std::optional<std::vector<SizeType32>> positionIds, std::optional<std::list<VecTokens>> badWords,
|
||||
@ -41,7 +40,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
|
||||
std::optional<SizeType32> encoderOutputLength, std::optional<Tensor> crossAttentionMask,
|
||||
SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig, std::optional<Tensor> skipCrossAttnBlocks,
|
||||
std::optional<GuidedDecodingParams> guidedDecodingParams, std::optional<SizeType32> languageAdapterUid,
|
||||
std::optional<MillisecondsType> allottedTimeMs)
|
||||
std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID)
|
||||
: mImpl(std::make_unique<Impl>(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId,
|
||||
padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
|
||||
std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
|
||||
@ -50,7 +49,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
|
||||
std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type,
|
||||
std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask,
|
||||
numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid,
|
||||
allottedTimeMs))
|
||||
allottedTimeMs, cacheSaltID))
|
||||
{
|
||||
}
|
||||
|
||||
@ -249,6 +248,11 @@ std::optional<SizeType32> Request::getLanguageAdapterUid() const
|
||||
return mImpl->getLanguageAdapterUid();
|
||||
}
|
||||
|
||||
std::optional<CacheSaltIDType> Request::getCacheSaltID() const
|
||||
{
|
||||
return mImpl->getCacheSaltID();
|
||||
}
|
||||
|
||||
void Request::setStreaming(bool streaming)
|
||||
{
|
||||
mImpl->setStreaming(streaming);
|
||||
@ -413,4 +417,9 @@ void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid)
|
||||
{
|
||||
return mImpl->setLanguageAdapterUid(languageAdapterUid);
|
||||
}
|
||||
|
||||
void Request::setCacheSaltID(CacheSaltIDType cacheSaltID)
|
||||
{
|
||||
return mImpl->setCacheSaltID(cacheSaltID);
|
||||
}
|
||||
} // namespace tensorrt_llm::executor
|
||||
|
||||
@ -32,7 +32,6 @@ class Request::Impl
|
||||
{
|
||||
|
||||
public:
|
||||
// 36 parameters, 36 items in initialization list
|
||||
Impl(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming, SamplingConfig const& samplingConfig,
|
||||
OutputConfig outputConfig, std::optional<TokenIdType> const& endId, std::optional<TokenIdType> const& padId,
|
||||
std::optional<std::vector<SizeType32>> positionIds, std::optional<std::list<VecTokens>> badWords,
|
||||
@ -48,7 +47,8 @@ public:
|
||||
std::optional<Tensor> encoderInputFeatures, std::optional<SizeType32> encoderOutputLength,
|
||||
std::optional<Tensor> crossAttentionMask, SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig,
|
||||
std::optional<Tensor> skipCrossAttnBlocks, std::optional<GuidedDecodingParams> guidedDecodingParams,
|
||||
std::optional<SizeType32> languageAdapterUid, std::optional<MillisecondsType> allottedTimeMs)
|
||||
std::optional<SizeType32> languageAdapterUid, std::optional<MillisecondsType> allottedTimeMs,
|
||||
std::optional<CacheSaltIDType> cacheSaltID)
|
||||
: mInputTokenIds(std::move(inputTokenIds))
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
, mStreaming(streaming)
|
||||
@ -85,6 +85,7 @@ public:
|
||||
, mGuidedDecodingParams(std::move(guidedDecodingParams))
|
||||
, mLanguageAdapterUid(languageAdapterUid)
|
||||
, mAllottedTimeMs(allottedTimeMs)
|
||||
, mCacheSaltID(cacheSaltID)
|
||||
{
|
||||
validate();
|
||||
}
|
||||
@ -296,6 +297,11 @@ public:
|
||||
return mLanguageAdapterUid;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const
|
||||
{
|
||||
return mCacheSaltID;
|
||||
}
|
||||
|
||||
void setStreaming(bool streaming)
|
||||
{
|
||||
mStreaming = streaming;
|
||||
@ -470,6 +476,11 @@ public:
|
||||
mLanguageAdapterUid = languageAdapterUid;
|
||||
}
|
||||
|
||||
void setCacheSaltID(CacheSaltIDType cacheSaltID)
|
||||
{
|
||||
mCacheSaltID = cacheSaltID;
|
||||
}
|
||||
|
||||
private:
|
||||
void validate()
|
||||
{
|
||||
@ -543,6 +554,7 @@ private:
|
||||
lambda(mGuidedDecodingParams);
|
||||
lambda(mLanguageAdapterUid);
|
||||
lambda(mAllottedTimeMs ? std::make_optional(mAllottedTimeMs->count()) : std::nullopt);
|
||||
lambda(mCacheSaltID);
|
||||
}
|
||||
|
||||
VecTokens mInputTokenIds;
|
||||
@ -581,6 +593,7 @@ private:
|
||||
std::optional<GuidedDecodingParams> mGuidedDecodingParams;
|
||||
std::optional<SizeType32> mLanguageAdapterUid;
|
||||
std::optional<MillisecondsType> mAllottedTimeMs;
|
||||
std::optional<CacheSaltIDType> mCacheSaltID;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::executor
|
||||
|
||||
@@ -711,8 +711,8 @@ Request Serialization::deserializeRequest(std::istream& is)
     auto allottedTimeMs = allottedTimeInt
         ? std::optional<std::chrono::milliseconds>(std::chrono::milliseconds(*allottedTimeInt))
         : std::nullopt;
+    auto cacheSaltID = su::deserialize<std::optional<CacheSaltIDType>>(is);

     // 35 parameters
     return Request(std::move(inputTokenIds), maxNewTokens, streaming, samplingConfig, outputConfig, endId, padId,
         std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
         std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
@@ -721,7 +721,7 @@ Request Serialization::deserializeRequest(std::istream& is)
         std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, requestType,
         std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength,
         std::move(crossAttentionMask), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks),
-        std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs);
+        std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, cacheSaltID);
 }

 void Serialization::serialize(Request const& request, std::ostream& os)
@ -189,6 +189,7 @@ void initBindings(nb::module_& m)
|
||||
.def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType)
|
||||
.def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId)
|
||||
.def_prop_ro("is_child", &GenLlmReq::isChild)
|
||||
.def_prop_ro("cache_salt_id", &GenLlmReq::getCacheSaltID)
|
||||
.def_prop_ro("multimodal_hashes",
|
||||
[](GenLlmReq& self)
|
||||
{
|
||||
@ -287,7 +288,8 @@ void initBindings(nb::module_& m)
|
||||
std::optional<executor::GuidedDecodingParams> guided_decoding_params,
|
||||
std::optional<tb::LlmRequest::SizeType32> language_adapter_uid,
|
||||
std::optional<tb::LlmRequest::MillisecondsType> allotted_time_ms,
|
||||
std::optional<executor::ContextPhaseParams> context_phase_params)
|
||||
std::optional<executor::ContextPhaseParams> context_phase_params,
|
||||
std::optional<tb::LlmRequest::CacheSaltIDType> cache_salt_id)
|
||||
{
|
||||
auto makeOptionalTensor = [](std::optional<at::Tensor> const& atTensor, bool unsqueeze = false)
|
||||
{
|
||||
@ -316,7 +318,6 @@ void initBindings(nb::module_& m)
|
||||
auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask);
|
||||
auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks);
|
||||
|
||||
// 49 parameters
|
||||
new (self) tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
|
||||
end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr, stop_words_list_tensor_ptr,
|
||||
position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size, multimodal_hashes,
|
||||
@ -328,7 +329,8 @@ void initBindings(nb::module_& m)
|
||||
encoder_input_tokens, return_encoder_output, client_id, priority, encoder_input_features_tensor_ptr,
|
||||
encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids,
|
||||
num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics,
|
||||
guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params};
|
||||
guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params,
|
||||
cache_salt_id};
|
||||
},
|
||||
nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"),
|
||||
nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt,
|
||||
@ -353,7 +355,7 @@ void initBindings(nb::module_& m)
|
||||
nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt,
|
||||
nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt,
|
||||
nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt,
|
||||
nb::arg("context_phase_params") = std::nullopt)
|
||||
nb::arg("context_phase_params") = std::nullopt, nb::arg("cache_salt_id") = std::nullopt)
|
||||
.def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, nb::arg("vocab_size"))
|
||||
.def(nb::init<tb::LlmRequest const&>())
|
||||
.def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"),
|
||||
|
||||
@ -76,7 +76,6 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
|
||||
? std::make_shared<std::vector<TokenIdType>>(*mEncoderTokens.value().get())
|
||||
: nullptr;
|
||||
auto const optEncoderInputTokens = std::optional<std::shared_ptr<std::vector<TokenIdType>>>(encoderInputTokens);
|
||||
// 49 parameters
|
||||
return std::make_shared<tb::LlmRequest>( //
|
||||
mRequestId, //
|
||||
mMaxNewTokens, //
|
||||
@ -126,6 +125,7 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
|
||||
mGuidedDecodingParams, //
|
||||
mLanguageAdapterUid, //
|
||||
mAllottedTimeMs, //
|
||||
mContextPhaseParams //
|
||||
mContextPhaseParams, //
|
||||
mCacheSaltID //
|
||||
);
|
||||
}
|
||||
|
||||
@ -51,7 +51,6 @@ public:
|
||||
using VecTokenExtraIds = Base::VecTokenExtraIds;
|
||||
using LogitsPostProcessor = Base::LogitsPostProcessor;
|
||||
|
||||
// 49 parameters
|
||||
LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
|
||||
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
|
||||
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
|
||||
@ -85,7 +84,8 @@ public:
|
||||
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
|
||||
std::optional<SizeType32> languageAdapterUid = std::nullopt,
|
||||
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
|
||||
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
|
||||
: Base(requestId, //
|
||||
maxNewTokens, //
|
||||
std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //
|
||||
@ -146,7 +146,8 @@ public:
|
||||
guidedDecodingParams, //
|
||||
languageAdapterUid, //
|
||||
allottedTimeMs, //
|
||||
contextPhaseParams //
|
||||
contextPhaseParams, //
|
||||
cacheSaltID //
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
@ -573,11 +573,11 @@ void initRequestBindings(nb::module_& m)
|
||||
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
|
||||
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
|
||||
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
|
||||
self.getGuidedDecodingParams());
|
||||
self.getGuidedDecodingParams(), self.getCacheSaltID());
|
||||
};
|
||||
auto requestSetstate = [](tle::Request& self, nb::tuple const& state)
|
||||
{
|
||||
if (state.size() != 33)
|
||||
if (state.size() != 34)
|
||||
{
|
||||
throw std::runtime_error("Invalid Request state!");
|
||||
}
|
||||
@ -601,7 +601,8 @@ void initRequestBindings(nb::module_& m)
|
||||
nb::cast<std::optional<tle::Tensor>>(state[27]), nb::cast<std::optional<SizeType32>>(state[28]),
|
||||
nb::cast<std::optional<tle::Tensor>>(state[29]), 1, nb::cast<std::optional<tle::EagleConfig>>(state[30]),
|
||||
nb::cast<std::optional<tle::Tensor>>(state[31]),
|
||||
nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]));
|
||||
nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]),
|
||||
nb::cast<std::optional<tle::CacheSaltIDType>>(state[33]));
|
||||
};
|
||||
|
||||
nb::class_<tle::Request> request(m, "Request", nb::dynamic_attr());
|
||||
@ -641,7 +642,8 @@ void initRequestBindings(nb::module_& m)
|
||||
std::optional<tle::Tensor>, // skipCrossAttnBlocks
|
||||
std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
|
||||
std::optional<tle::SizeType32>, // languageAdapterUid
|
||||
std::optional<tle::MillisecondsType> // allottedTimeMs
|
||||
std::optional<tle::MillisecondsType>, // allottedTimeMs
|
||||
std::optional<tle::CacheSaltIDType> // cacheSaltID
|
||||
>(),
|
||||
// clang-format off
|
||||
nb::arg("input_token_ids"),
|
||||
@ -680,8 +682,9 @@ void initRequestBindings(nb::module_& m)
|
||||
nb::arg("skip_cross_attn_blocks") = nb::none(),
|
||||
nb::arg("guided_decoding_params") = nb::none(),
|
||||
nb::arg("language_adapter_uid") = nb::none(),
|
||||
nb::arg("allotted_time_ms") = nb::none()
|
||||
) // clang-format on
|
||||
nb::arg("allotted_time_ms") = nb::none(),
|
||||
nb::arg("cache_salt_id") = nb::none()
|
||||
) // clang-format on
|
||||
.def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds)
|
||||
.def_prop_ro("max_tokens", &tle::Request::getMaxTokens)
|
||||
.def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
|
||||
@ -723,6 +726,7 @@ void initRequestBindings(nb::module_& m)
|
||||
.def_prop_rw(
|
||||
"guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams)
|
||||
.def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
|
||||
.def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
|
||||
.def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
|
||||
.def("__getstate__", requestGetstate)
|
||||
.def("__setstate__", requestSetstate);
|
||||
|
||||
@ -196,6 +196,7 @@ void initBindings(pybind11::module_& m)
|
||||
.def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
|
||||
.def_property_readonly("parent_request_id", &GenLlmReq::getParentRequestId)
|
||||
.def_property_readonly("is_child", &GenLlmReq::isChild)
|
||||
.def_property_readonly("cache_salt_id", &GenLlmReq::getCacheSaltID)
|
||||
.def_property_readonly("multimodal_hashes",
|
||||
[](GenLlmReq& self)
|
||||
{
|
||||
@ -293,7 +294,8 @@ void initBindings(pybind11::module_& m)
|
||||
std::optional<executor::GuidedDecodingParams> guided_decoding_params,
|
||||
std::optional<tb::LlmRequest::SizeType32> language_adapter_uid,
|
||||
std::optional<tb::LlmRequest::MillisecondsType> allotted_time_ms,
|
||||
std::optional<executor::ContextPhaseParams> context_phase_params)
|
||||
std::optional<executor::ContextPhaseParams> context_phase_params,
|
||||
std::optional<tb::LlmRequest::CacheSaltIDType> cache_salt_id)
|
||||
{
|
||||
auto makeOptionalTensor = [](std::optional<at::Tensor> const& atTensor, bool unsqueeze = false)
|
||||
{
|
||||
@ -322,7 +324,6 @@ void initBindings(pybind11::module_& m)
|
||||
auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask);
|
||||
auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks);
|
||||
|
||||
// 49 parameters
|
||||
return tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
|
||||
end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr,
|
||||
stop_words_list_tensor_ptr, position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size,
|
||||
@ -335,7 +336,7 @@ void initBindings(pybind11::module_& m)
|
||||
encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr,
|
||||
llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config,
|
||||
skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params,
|
||||
language_adapter_uid, allotted_time_ms, context_phase_params};
|
||||
language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id};
|
||||
}),
|
||||
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
|
||||
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
|
||||
@ -361,7 +362,7 @@ void initBindings(pybind11::module_& m)
|
||||
py::arg("eagle_config") = std::nullopt, py::arg("skip_cross_attn_blocks") = std::nullopt,
|
||||
py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt,
|
||||
py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt,
|
||||
py::arg("context_phase_params") = std::nullopt)
|
||||
py::arg("context_phase_params") = std::nullopt, py::arg("cache_salt_id") = std::nullopt)
|
||||
.def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, py::arg("vocab_size"))
|
||||
.def(py::init<tb::LlmRequest const&>())
|
||||
.def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"),
|
||||
|
||||
@ -75,7 +75,6 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
|
||||
? std::make_shared<std::vector<TokenIdType>>(*mEncoderTokens.value().get())
|
||||
: nullptr;
|
||||
auto const optEncoderInputTokens = std::optional<std::shared_ptr<std::vector<TokenIdType>>>(encoderInputTokens);
|
||||
// 49 parameters
|
||||
return std::make_shared<tb::LlmRequest>( //
|
||||
mRequestId, //
|
||||
mMaxNewTokens, //
|
||||
@ -125,6 +124,7 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
|
||||
mGuidedDecodingParams, //
|
||||
mLanguageAdapterUid, //
|
||||
mAllottedTimeMs, //
|
||||
mContextPhaseParams //
|
||||
mContextPhaseParams, //
|
||||
mCacheSaltID //
|
||||
);
|
||||
}
|
||||
|
||||
@ -49,8 +49,8 @@ public:
|
||||
using VecTokens = Base::VecTokens;
|
||||
using VecTokenExtraIds = Base::VecTokenExtraIds;
|
||||
using LogitsPostProcessor = Base::LogitsPostProcessor;
|
||||
using CacheSaltIDType = Base::CacheSaltIDType;
|
||||
|
||||
// 49 parameters
|
||||
LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
|
||||
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
|
||||
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
|
||||
@ -84,7 +84,8 @@ public:
|
||||
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
|
||||
std::optional<SizeType32> languageAdapterUid = std::nullopt,
|
||||
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
|
||||
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
|
||||
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
|
||||
: Base(requestId, //
|
||||
maxNewTokens, //
|
||||
std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //
|
||||
@ -145,7 +146,8 @@ public:
|
||||
guidedDecodingParams, //
|
||||
languageAdapterUid, //
|
||||
allottedTimeMs, //
|
||||
contextPhaseParams //
|
||||
contextPhaseParams, //
|
||||
cacheSaltID //
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
@ -526,11 +526,11 @@ void initRequestBindings(pybind11::module_& m)
|
||||
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
|
||||
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
|
||||
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
|
||||
self.getGuidedDecodingParams());
|
||||
self.getGuidedDecodingParams(), self.getCacheSaltID());
|
||||
};
|
||||
auto requestSetstate = [](py::tuple const& state)
|
||||
{
|
||||
if (state.size() != 33)
|
||||
if (state.size() != 34)
|
||||
{
|
||||
throw std::runtime_error("Invalid Request state!");
|
||||
}
|
||||
@ -550,7 +550,8 @@ void initRequestBindings(pybind11::module_& m)
|
||||
state[25].cast<tle::RequestType>(), state[26].cast<std::optional<tle::ContextPhaseParams>>(),
|
||||
state[27].cast<std::optional<tle::Tensor>>(), state[28].cast<std::optional<SizeType32>>(),
|
||||
state[29].cast<std::optional<tle::Tensor>>(), 1, state[30].cast<std::optional<tle::EagleConfig>>(),
|
||||
state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>());
|
||||
state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>(),
|
||||
state[33].cast<std::optional<tle::CacheSaltIDType>>());
|
||||
};
|
||||
|
||||
py::class_<tle::Request> request(m, "Request", pybind11::dynamic_attr());
|
||||
@ -590,7 +591,8 @@ void initRequestBindings(pybind11::module_& m)
|
||||
std::optional<tle::Tensor>, // skipCrossAttnBlocks
|
||||
std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
|
||||
std::optional<tle::SizeType32>, // languageAdapterUid
|
||||
std::optional<tle::MillisecondsType> // allottedTimeMs
|
||||
std::optional<tle::MillisecondsType>, // allottedTimeMs
|
||||
std::optional<tle::CacheSaltIDType> // cacheSaltID
|
||||
>(),
|
||||
// clang-format off
|
||||
py::arg("input_token_ids"),
|
||||
@ -630,8 +632,9 @@ void initRequestBindings(pybind11::module_& m)
|
||||
py::arg("skip_cross_attn_blocks") = py::none(),
|
||||
py::arg("guided_decoding_params") = py::none(),
|
||||
py::arg("language_adapter_uid") = py::none(),
|
||||
py::arg("allotted_time_ms") = py::none()
|
||||
) // clang-format on
|
||||
py::arg("allotted_time_ms") = py::none(),
|
||||
py::arg("cache_salt_id") = py::none()
|
||||
) // clang-format on
|
||||
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
|
||||
.def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
|
||||
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
|
||||
@ -675,6 +678,7 @@ void initRequestBindings(pybind11::module_& m)
|
||||
.def_property(
|
||||
"guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams)
|
||||
.def_property("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
|
||||
.def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
|
||||
.def_property(
|
||||
"context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
|
||||
.def(py::pickle(requestGetstate, requestSetstate));
|
||||
|
||||
@ -1686,6 +1686,207 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
|
||||
}
|
||||
|
||||
TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
|
||||
{
|
||||
// Test that cache_salt_id prevents KV cache reuse between requests with same tokens
|
||||
// but different cache_salt_id values.
|
||||
using VecTokenExtraIds = LlmRequest::VecTokenExtraIds;
|
||||
using CacheSaltIDType = LlmRequest::CacheSaltIDType;
|
||||
|
||||
auto constexpr numLayers = 12;
|
||||
auto constexpr numKvHeads = 6;
|
||||
auto constexpr sizePerHead = 16;
|
||||
auto constexpr tokensPerBlock = 4;
|
||||
auto constexpr maxBlocksPerSeq = 4;
|
||||
auto constexpr blocksInPrimaryPool = 16;
|
||||
auto constexpr blocksInSecondaryPool = 0;
|
||||
auto constexpr maxNumSequences = 8;
|
||||
auto const stream = std::make_shared<tr::CudaStream>();
|
||||
auto constexpr onboardBlocks = true;
|
||||
auto constexpr numReturnSequences = 1;
|
||||
auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq;
|
||||
auto constexpr beamWidth = 1;
|
||||
|
||||
auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
|
||||
|
||||
BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
|
||||
maxNumSequences, stream, maxAttentionWindow, beamWidth,
|
||||
std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0,
|
||||
onboardBlocks);
|
||||
blockManager.allocatePools(false);
|
||||
|
||||
EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock);
|
||||
EXPECT_EQ(blockManager.getMaxNumBlocks(), blocksInPrimaryPool);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
|
||||
|
||||
SizeType32 constexpr maxNewTokens{0};
|
||||
tr::SamplingConfig const samplingConfig{beamWidth};
|
||||
bool constexpr isStreaming{false};
|
||||
|
||||
// Create shared input tokens
|
||||
auto inputTokens = std::make_shared<VecTokens>(VecTokens{100, 101, 102, 103, 104, 105, 106, 107, 108});
|
||||
auto const inputLength = static_cast<SizeType32>(inputTokens->size());
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Test Case 1: Request without cache_salt_id
|
||||
LlmRequest::RequestIdType requestId{0};
|
||||
auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
|
||||
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
|
||||
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
|
||||
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt); // No cache_salt_id
|
||||
|
||||
GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
|
||||
// Add first request and get blocks 0, 1, 2
|
||||
auto constexpr beamIdx = 0;
|
||||
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
|
||||
// Add generated tokens
|
||||
llmRequest0->addNewToken(3, beamIdx);
|
||||
llmRequest0->addNewToken(4, beamIdx);
|
||||
auto numTokens = llmRequest0->getNumTokens(beamIdx);
|
||||
auto numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
|
||||
EXPECT_EQ(numBlocks, 3);
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
|
||||
|
||||
// Release blocks to make them available for reuse
|
||||
blockManager.releaseBlocks(seq0, llmRequest0);
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Test Case 2: Request with same tokens but with cache_salt_id = 12345
|
||||
requestId = 1;
|
||||
CacheSaltIDType cacheSaltId1{12345};
|
||||
auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
|
||||
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
|
||||
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
|
||||
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
cacheSaltId1); // With cache_salt_id = 12345
|
||||
|
||||
GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
|
||||
// Should NOT reuse blocks despite same tokens, because cache_salt_id is different
|
||||
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); // No reuse, starts from scratch
|
||||
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5}));
|
||||
|
||||
llmRequest1->addNewToken(3, beamIdx);
|
||||
llmRequest1->addNewToken(4, beamIdx);
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
|
||||
|
||||
// Release blocks
|
||||
blockManager.releaseBlocks(seq1, llmRequest1);
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Test Case 3: Request with same tokens and same cache_salt_id = 12345
|
||||
requestId = 2;
|
||||
auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
|
||||
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
|
||||
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
|
||||
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
cacheSaltId1); // Same cache_salt_id = 12345
|
||||
|
||||
GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
|
||||
// SHOULD reuse blocks because both tokens and cache_salt_id match
|
||||
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
|
||||
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
|
||||
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 3,4
|
||||
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 6}));
|
||||
|
||||
llmRequest2->addNewToken(3, beamIdx);
|
||||
llmRequest2->addNewToken(4, beamIdx);
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
|
||||
|
||||
// Release blocks
|
||||
blockManager.releaseBlocks(seq2, llmRequest2);
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Test Case 4: Request with same tokens but different cache_salt_id = 67890
|
||||
requestId = 3;
|
||||
CacheSaltIDType cacheSaltId2{67890};
|
||||
auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
|
||||
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
|
||||
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
|
||||
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
cacheSaltId2); // Different cache_salt_id = 67890
|
||||
|
||||
GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
|
||||
// Should NOT reuse blocks from any previous request because cache_salt_id is different
|
||||
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
|
||||
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
|
||||
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 0); // No reuse
|
||||
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9}));
|
||||
|
||||
llmRequest3->addNewToken(5, beamIdx);
|
||||
llmRequest3->addNewToken(6, beamIdx);
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Test Case 5: Request without cache_salt_id again
|
||||
requestId = 4;
|
||||
auto llmRequest4 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
|
||||
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
|
||||
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
|
||||
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
|
||||
std::nullopt); // No cache_salt_id
|
||||
|
||||
GenerationRequest seq4{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
|
||||
// Should reuse blocks from request0 (blocks 0,1) because both have no cache_salt_id
|
||||
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
|
||||
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
|
||||
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 0,1
|
||||
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 10}));
|
||||
|
||||
llmRequest4->addNewToken(7, beamIdx);
|
||||
numTokens = llmRequest4->getNumTokens(beamIdx);
|
||||
numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks * 2);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks * 2);
|
||||
|
||||
// Clean up
|
||||
blockManager.releaseBlocks(seq3, llmRequest3);
|
||||
blockManager.releaseBlocks(seq4, llmRequest4);
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
|
||||
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
|
||||
}
|
||||
|
||||
TEST_F(KVCacheManagerTest, KVCacheManagerPerRequestStatsTest)
|
||||
{
|
||||
auto constexpr numLayers = 12;
|
||||
|
||||
@ -45,6 +45,7 @@ from tensorrt_llm.functional import PositionEmbeddingType
|
||||
from tensorrt_llm.llmapi.utils import enable_llm_debug
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
from tensorrt_llm.quantization.utils.fp8_utils import (
|
||||
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
|
||||
|
||||
@ -468,10 +469,13 @@ class Deepseekv3MoE(nn.Module):
|
||||
layer_idx=layer_idx,
|
||||
# DS-R1 W4A8 is only supported through custom quantization script from
|
||||
# examples/quantization/quantize_mixed_precision_moe.py
|
||||
weight_loading_mode=(MoEWeightLoadingMode.W4A8_CUSTOM
|
||||
if model_config.quant_config.quant_mode.
|
||||
is_int4_weight_only_per_group() else
|
||||
MoEWeightLoadingMode.VANILLA))
|
||||
weight_loading_mode=(
|
||||
MoEWeightLoadingMode.W4A8_CUSTOM
|
||||
if self._get_experts_quant_config(
|
||||
model_config,
|
||||
layer_idx).layer_quant_mode.is_int4_weight_only_per_group()
|
||||
else MoEWeightLoadingMode.VANILLA),
|
||||
)
|
||||
|
||||
self.mapping = model_config.mapping
|
||||
|
||||
@ -536,6 +540,13 @@ class Deepseekv3MoE(nn.Module):
|
||||
|
||||
return shared_tp_size, shared_output_scale
|
||||
|
||||
@staticmethod
|
||||
def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig:
|
||||
if getattr(model_config, "quant_config_dict", None) is None:
|
||||
return model_config.quant_config
|
||||
return model_config.quant_config_dict.get(
|
||||
f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)
|
||||
|
||||
def compute_routed_output(self, hidden_states, hidden_states_fp4,
|
||||
all_rank_num_tokens, all_rank_max_num_tokens,
|
||||
do_finalize):
|
||||
@ -657,6 +668,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
quant_config = self._get_decoder_layer_quant_config(
|
||||
model_config, layer_idx)
|
||||
self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
|
||||
assert (
|
||||
quant_config.quant_algo
|
||||
is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous"
|
||||
|
||||
has_tp = mapping.has_tp()
|
||||
self.allreduce = AllReduce(mapping=model_config.mapping,
|
||||
|
||||
@ -40,14 +40,42 @@ class CUDAGraphRunner:
self.max_beam_width = engine.max_beam_width
self.spec_config = engine.spec_config

self.max_possible_draft_len = (self.spec_config.max_draft_len
if self.enable_spec_decode else 0)

self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
self.static_inputs: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = {}
self.graph_outputs: Dict[Tuple[int, int],
Callable[[], Optional[torch.Tensor]]] = {}
self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
self.memory_pool = engine._cuda_graph_mem_pool
self.padding_dummy_request: Optional["Request"] = None

self.shared_static_tensors: Dict[str, torch.Tensor] = {}
if self.enabled:
self._create_shared_static_tensors()

def _create_shared_static_tensors(self):
"""Allocates static tensors sized for the largest possible batch."""
engine = self._get_engine()

token_per_request = self.max_possible_draft_len + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, engine.max_num_tokens)

self.shared_static_tensors = {
"input_ids":
torch.ones((max_total_tokens, ), device="cuda", dtype=torch.int32),
"position_ids":
torch.zeros((1, max_total_tokens), device="cuda",
dtype=torch.int32),
}
if engine.use_mrope:
self.shared_static_tensors["mrope_position_deltas"] = torch.zeros(
(self.max_supported_batch_size, 1),
device="cuda",
dtype=torch.int32)

@property
def enable_spec_decode(self):
return self._get_engine().is_spec_decode
@ -139,38 +167,32 @@ class CUDAGraphRunner:
def capture(self, batch_size: int, forward_fn: Callable,
initial_inputs: Dict[str, Any]):
"""Captures the forward pass for a given batch size."""
engine = self._get_engine()
key = (batch_size, self.draft_len)
spec_metadata = initial_inputs.get("spec_metadata", None)
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
token_per_request = self.max_possible_draft_len + 1
num_tokens_for_capture = (batch_size * self.max_beam_width *
token_per_request)

static_tensors = {
sliced_static_tensors = {
"input_ids":
torch.ones((batch_size * self.max_beam_width * token_per_request, ),
device="cuda",
dtype=torch.int32),
self.shared_static_tensors["input_ids"][:num_tokens_for_capture],
"position_ids":
torch.zeros((
1,
batch_size * self.max_beam_width * token_per_request,
),
device="cuda",
dtype=torch.int32),
self.shared_static_tensors["position_ids"]
[:, :num_tokens_for_capture],
}
if engine.use_mrope:
static_tensors["mrope_position_deltas"] = torch.zeros(
(batch_size, 1), device="cuda", dtype=torch.int32)
self.static_inputs[key] = static_tensors
if "mrope_position_deltas" in self.shared_static_tensors:
sliced_static_tensors["mrope_position_deltas"] = \
self.shared_static_tensors["mrope_position_deltas"][:batch_size]

# Use the sliced tensors for capture
capture_inputs = initial_inputs.copy()
capture_inputs.update(static_tensors)
capture_inputs.update(sliced_static_tensors)

self.graph_metadata[key] = {
"attn_metadata": initial_inputs["attn_metadata"],
"spec_metadata": spec_metadata,
"spec_metadata": initial_inputs.get("spec_metadata", None),
}

# We have to do warm up runs to initialize PyTorch's

@ -198,7 +220,7 @@ class CUDAGraphRunner:
assert current_inputs.get(
"spec_metadata") is stored_meta["spec_metadata"]

static_tensors = self.static_inputs[key]
static_tensors = self.shared_static_tensors

input_ids = current_inputs["input_ids"]
seqlen = input_ids.shape[0]

@ -301,7 +323,6 @@ class CUDAGraphRunner:
for graph in self.graphs.values():
graph.reset()
self.graphs.clear()
self.static_inputs.clear()
self.graph_outputs.clear()
self.graph_metadata.clear()
self.padding_dummy_request = None
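The idea behind this refactor (allocate one maximal static buffer, capture every batch size against a view of it) can be shown with a small sketch in plain PyTorch; the sizes are made up and CPU tensors are used for simplicity, so this is not code from the diff:

import torch

# Assumed sizes for illustration only.
max_total_tokens = 4096
shared_input_ids = torch.ones((max_total_tokens, ), dtype=torch.int32)

def view_for_capture(num_tokens: int) -> torch.Tensor:
    # Each graph is captured against a slice (a view) of the shared buffer,
    # so replay only needs to rewrite that slice in place, never reallocate.
    return shared_input_ids[:num_tokens]

small = view_for_capture(128)
small.copy_(torch.arange(128, dtype=torch.int32))
assert shared_input_ids[:128].equal(small)  # the view aliases the shared buffer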
@ -546,6 +546,7 @@ def executor_request_to_llm_request(
priority=0.5,
llm_request_type=llm_request_type,
context_phase_params=executor_request.context_phase_params,
cache_salt_id=executor_request.cache_salt_id,
py_multimodal_data=getattr(executor_request, "py_multimodal_data",
None))
if child_req_ids:
@ -426,7 +426,6 @@ class PyTorchModelEngine(ModelEngine):
# the model engine.
self.attn_metadata = None
self.iter_states = {}
self._cuda_graphs = {}
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None

self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
@ -124,6 +124,7 @@ class GenerationExecutor(ABC):
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
cache_salt_id: Optional[int] = None,
) -> GenerationResult:
"""Generate output for the given prompt token ids in the asynchronous mode.
Asynchronous generation accepts single prompt only.

@ -147,7 +148,8 @@ class GenerationExecutor(ABC):
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params)
scheduling_params=scheduling_params,
cache_salt_id=cache_salt_id)
result = self.submit(request)
# release memory in time
if hasattr(request, "multimodal_params"):
@ -97,6 +97,7 @@ class GenerationRequest:
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
cache_salt_id: Optional[int] = None,
):
if isinstance(prompt_token_ids, list):
self.prompt_token_ids = prompt_token_ids

@ -122,6 +123,7 @@ class GenerationRequest:
self.id: Optional[int] = None
self.disaggregated_params = disaggregated_params
self.scheduling_params = scheduling_params
self.cache_salt_id = cache_salt_id

def set_id(self, id):
assert self.id is None, f"Request ID is already set: {self.id}"
@ -572,7 +572,8 @@ class GenerationExecutorWorker(GenerationExecutor):
request.sampling_params.logits_processor,
kv_cache_retention_config=request.kv_cache_retention_config,
context_phase_params=context_phase_params,
type=request_type)
type=request_type,
cache_salt_id=request.cache_salt_id)
executor_request.py_lora_path = py_lora_path

if self._is_pytorch_backend and request.multimodal_params is not None:
@ -11,7 +11,8 @@ from .utils import (ALL_SUPPORTED_AUDIO_MODELS, ALL_SUPPORTED_IMAGE_MODELS,
add_multimodal_placeholders, apply_chat_template,
async_load_audio, async_load_image, async_load_video,
convert_image_mode, default_multimodal_input_loader,
encode_base64_content_from_url, load_image, load_video)
encode_base64_content_from_url, get_cache_salt_id,
load_image, load_video)

__all__ = [
"ALL_SUPPORTED_MULTIMODAL_MODELS",

@ -44,4 +45,5 @@ __all__ = [
"encode_base64_content_from_url",
"load_image",
"load_video",
"get_cache_salt_id",
]
@ -17,6 +17,7 @@ from torchvision.transforms import ToTensor
from transformers import AutoProcessor, ProcessorMixin
from transformers.utils import logging

from tensorrt_llm.inputs.multimodal import default_hasher
from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY,
MultimodalPlaceholderPlacement)
from tensorrt_llm.llmapi.llm_utils import ModelLoader

@ -610,3 +611,14 @@ def default_multimodal_input_loader(
inputs.append(input)

return inputs


def get_cache_salt_id(cache_salt: str) -> int:
b = cache_salt.encode("utf-8")
h = default_hasher(b).digest(length=8)
cache_salt_id = int.from_bytes(h, "little", signed=False)
if cache_salt_id < 0 or cache_salt_id >= (1 << 64):
raise ValueError(
f"cache_salt_id must be in [0, 2**64 - 1], got {cache_salt_id}.")

return cache_salt_id
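A standalone sketch of what this mapping does, using hashlib.blake2b as a stand-in for default_hasher (an assumption; the real hasher may differ). It only shows that equal salts map to equal 64-bit IDs and different salts almost surely do not:

import hashlib

def salt_to_id(cache_salt: str) -> int:
    # 8-byte digest interpreted as an unsigned little-endian integer,
    # mirroring the shape of get_cache_salt_id above.
    digest = hashlib.blake2b(cache_salt.encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "little", signed=False)

assert salt_to_id("user_123") == salt_to_id("user_123")   # same salt, same ID
assert salt_to_id("user_123") != salt_to_id("user_456")   # different salt, different ID
assert 0 <= salt_to_id("user_123") < (1 << 64)            # fits in 64 bits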
@ -27,7 +27,8 @@ from ..executor.postproc_worker import PostprocParams
from ..executor.utils import (create_mpi_comm_session,
get_spawn_proxy_process_env)
from ..inputs import (PromptInputs, create_input_processor,
create_input_processor_with_hash, prompt_inputs)
create_input_processor_with_hash, get_cache_salt_id,
prompt_inputs)
from ..logger import logger
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
@ -325,6 +326,7 @@ class BaseLLM:
disaggregated_params: Optional[DisaggregatedParams] = None,
_postproc_params: Optional[PostprocParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
cache_salt: Optional[str] = None,
) -> RequestOutput:
"""Generate output for the given prompt in the asynchronous mode.
Asynchronous generation accepts single prompt only.

@ -339,7 +341,7 @@ class BaseLLM:
kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, optional): Configuration for the request's retention in the KV Cache. Defaults to None.
disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Disaggregated parameters. Defaults to None.
scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, optional): Scheduling parameters. Defaults to None.
cache_salt (str, optional): If specified, KV cache will be salted with the provided string to limit the kv cache reuse to the requests with the same string. Defaults to None.
Returns:
tensorrt_llm.llmapi.RequestOutput: The output data of the completion request to the LLM.
"""

@ -349,7 +351,8 @@ class BaseLLM:
raise RuntimeError("LLM is shutting down")

sampling_params = self._prepare_sampling_params(sampling_params)

cache_salt_id = get_cache_salt_id(
cache_salt) if cache_salt is not None else None
# With pytorch backend, py_executor has logic to handle max_tokens of 1,
# so set to 1 to avoid allocating unnecessary KV cache blocks for single request
# TODO: Also support for trt backend

@ -444,6 +447,7 @@ class BaseLLM:
postproc_params=_postproc_params,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params,
cache_salt_id=cache_salt_id,
)

return RequestOutput._from_generation_result(result, prompt,
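A hedged usage sketch of the new LLM API parameter; the import path follows the public LLM API, but the model path and result handling are placeholders and may differ:

from tensorrt_llm import LLM  # assumed import path for the LLM API

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder model

# Requests sharing a cache_salt may reuse each other's KV cache blocks;
# a different (or absent) salt keeps their cached blocks separate.
future_a = llm.generate_async("What is the capital of France?", cache_salt="user_123")
future_b = llm.generate_async("What is the capital of France?", cache_salt="user_456")
print(future_a.result().outputs[0].text)
print(future_b.result().outputs[0].text)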
@ -19,7 +19,8 @@ from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
from openai_harmony import ReasoningEffort
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (BaseModel, ConfigDict, Field, field_validator,
model_validator)
from typing_extensions import Annotated, Required, TypeAlias, TypedDict

from tensorrt_llm.executor.request import LoRARequest

@ -592,6 +593,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=("Parameters for disaggregated serving"),
)

cache_salt: Optional[str] = Field(
default=None,
description=
("If specified, KV cache will be salted with the provided string "
"to limit KV cache reuse to requests having the same string."
))

# doc: end-chat-completion-extra-params

def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:

@ -671,6 +679,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
raise ValueError("suffix is not supported")
return data

@field_validator("cache_salt")
@classmethod
def check_cache_salt_support(cls, v):
if v is not None:
if not isinstance(v, str) or not v.strip():
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return v


ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
ResponseReasoningItem,
@ -462,7 +462,8 @@ class OpenAIServer:
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
streaming=request.stream,
lora_request=request.lora_request,
disaggregated_params=disaggregated_params
disaggregated_params=disaggregated_params,
cache_salt=request.cache_salt,
)
asyncio.create_task(self.await_disconnected(raw_request, promise))
if not self.postproc_worker_enabled:
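On the client side, the new field rides in the OpenAI-compatible request body as an extra field. A minimal hedged sketch; the endpoint URL and served model name are placeholders:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")  # placeholder endpoint
resp = client.chat.completions.create(
    model="TinyLlama-1.1B-Chat",  # placeholder served model name
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    extra_body={"cache_salt": "user_123"},  # scopes KV cache reuse to this salt
)
print(resp.choices[0].message.content)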
@ -1557,6 +1557,18 @@ def test_openai_misc_example(llm_root, llm_venv, backend: str):
])


def test_openai_cache_salt(llm_root, llm_venv):
example_root = Path(os.path.join(llm_root, "examples", "serve"))
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd([
"-m", "pip", "install", "-r",
os.path.join(example_root, "requirements.txt")
])
llm_venv.run_cmd(
["-m", "pytest",
str(test_root / "_test_openai_cache_salt.py")])


@pytest.mark.parametrize("backend", ["pytorch", "trt"])
def test_openai_completions_example(llm_root, llm_venv, backend: str):
test_root = unittest_path() / "llmapi" / "apps"
@ -186,6 +186,7 @@ def create_mock_engine(batch_size: int):
_cuda_graph_batch_sizes=[batch_size],
_max_cuda_graph_batch_size=batch_size,
max_beam_width=1,
max_num_tokens=8192,
is_spec_decode=False,
spec_config=None,
_cuda_graph_mem_pool=None,
@ -196,6 +196,9 @@ methods:
annotation: Optional[tensorrt_llm.scheduling_params.SchedulingParams]
default: null
status: prototype
cache_salt:
annotation: Optional[str]
default: null
return_annotation: tensorrt_llm.llmapi.llm.RequestOutput
get_kv_cache_events:
parameters:
231 tests/unittest/llmapi/apps/_test_openai_cache_salt.py Normal file
@ -0,0 +1,231 @@
"""Test cache_salt functionality in OpenAI API to ensure it prevents cache reuse"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from ..test_llm import get_model_path
|
||||
from .openai_server import RemoteOpenAIServer
|
||||
|
||||
pytestmark = pytest.mark.threadleak(enabled=False)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"])
|
||||
def model_name() -> str:
|
||||
return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def temp_extra_llm_api_options_file():
|
||||
"""Create temporary config file with KV cache enabled for testing"""
|
||||
temp_dir = tempfile.gettempdir()
|
||||
temp_file_path = os.path.join(temp_dir, "cache_salt_test_options.yaml")
|
||||
try:
|
||||
extra_llm_api_options_dict = {
|
||||
# Enable KV cache reuse
|
||||
"kv_cache_config": {
|
||||
"enable_block_reuse": True,
|
||||
},
|
||||
# Enable performance metrics to get cache hit rate
|
||||
"return_perf_metrics": True,
|
||||
"enable_iter_perf_stats": True,
|
||||
"enable_iter_req_stats": True,
|
||||
# Disable CUDA graph for compatibility
|
||||
"cuda_graph_config": None,
|
||||
}
|
||||
|
||||
with open(temp_file_path, 'w') as f:
|
||||
yaml.dump(extra_llm_api_options_dict, f)
|
||||
|
||||
yield temp_file_path
|
||||
finally:
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(model_name: str,
|
||||
temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer:
|
||||
model_path = get_model_path(model_name)
|
||||
args = []
|
||||
args.extend(["--backend", "pytorch"])
|
||||
args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file])
|
||||
with RemoteOpenAIServer(model_path, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server: RemoteOpenAIServer) -> openai.OpenAI:
|
||||
return server.get_client()
|
||||
|
||||
|
||||
def get_cache_hit_rate(client: openai.OpenAI) -> float:
"""Get cache hit rate from the metrics endpoint"""
import httpx

# Get the base URL from the OpenAI client (it includes /v1)
# We need to go up one level to access /metrics
base_url = str(client.base_url).rstrip('/')
if base_url.endswith('/v1'):
base_url = base_url[:-3]  # Remove /v1

# Make a direct HTTP request to the metrics endpoint
with httpx.Client() as http_client:
response = http_client.get(f"{base_url}/metrics", timeout=5.0)

# Check if metrics endpoint is available
if response.status_code != 200:
raise RuntimeError(
f"Metrics endpoint returned status {response.status_code}")

metrics = response.json()

# Validate that we have metrics data
if not isinstance(metrics, list) or len(metrics) == 0:
raise ValueError("No metrics data available")

# Get the most recent stats
latest_stats = metrics[-1]

# Extract KV cache statistics
kv_cache_stats = latest_stats.get("kvCacheStats", {})
if not kv_cache_stats:
raise ValueError("No KV cache statistics available in metrics")

try:
print(f"kv_cache_stats reused: {kv_cache_stats['reusedBlocks']}")
print(f"kv_cache_stats missed: {kv_cache_stats['missedBlocks']}")
print(f"kv_cache_stats hit rate: {kv_cache_stats['cacheHitRate']}")
return kv_cache_stats["cacheHitRate"]
except Exception as e:
print(f"Warning: Could not get cache metrics: {e}")
return 0.0
def test_cache_salt_prevents_reuse_chat(client: openai.OpenAI, model_name: str):
"""Test that different cache_salt values prevent KV cache reuse in chat completions"""

# Common messages that will be used across all requests
messages = [{
"role": "system",
"content": "You are a helpful assistant. Keep responses brief."
}, {
"role":
"user",
"content":
"What is the capital of France? Answer in one sentence."
}]

# Test configuration
max_tokens = 30
temperature = 0.0  # Deterministic for testing

# Track responses for comparison
responses = []

# Test Case 1: First request without cache_salt (baseline)
print("\n=== Test Case 1: First request without cache_salt ===")
response1 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
responses.append(response1.choices[0].message.content)
print(f"Response 1: {response1.choices[0].message.content[:100]}...")

# Display initial cache metrics
initial_hit_rate = get_cache_hit_rate(client)
print(f"Initial cache hit rate: {initial_hit_rate:.2%}")

# Test Case 2: Same messages without cache_salt (should reuse cache)
print(
"\n=== Test Case 2: Same messages without cache_salt (should reuse) ==="
)
response2 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
responses.append(response2.choices[0].message.content)
print(f"Response 2: {response2.choices[0].message.content[:100]}...")

# Cache hit rate should increase after reusing the same prompt without a salt
hit_rate_after_reuse = get_cache_hit_rate(client)
print(f"Cache hit rate after reuse: {hit_rate_after_reuse:.2%}")
assert hit_rate_after_reuse >= initial_hit_rate, \
"Cache hit rate should increase when reusing cache without salt"

# Test Case 3: Same messages with cache_salt="user_123" (should NOT reuse)
print(
"\n=== Test Case 3: Same messages with cache_salt='user_123' (no reuse) ==="
)
response3 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
extra_body={"cache_salt": "user_123"})
responses.append(response3.choices[0].message.content)
print(f"Response 3: {response3.choices[0].message.content[:100]}...")

# Record metrics after request with different salt
hit_rate_after_salt1 = get_cache_hit_rate(client)
print(f"Cache hit rate after salt 'user_123': {hit_rate_after_salt1:.2%}")
assert hit_rate_after_salt1 < hit_rate_after_reuse, \
"Cache hit rate should decrease when using a different salt"

# Test Case 4: Same messages with same cache_salt="user_123" (should reuse)
print(
"\n=== Test Case 4: Same messages with same cache_salt='user_123' (should reuse) ==="
)
response4 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
extra_body={"cache_salt": "user_123"}  # Same salt should enable reuse
)
responses.append(response4.choices[0].message.content)
print(f"Response 4: {response4.choices[0].message.content[:100]}...")

# Cache hit rate should increase again when using same salt
hit_rate_after_salt1_reuse = get_cache_hit_rate(client)
print(
f"Cache hit rate after reusing salt 'user_123': {hit_rate_after_salt1_reuse:.2%}"
)
assert hit_rate_after_salt1_reuse >= hit_rate_after_salt1, \
"Cache hit rate should increase when reusing same salt"

# Test Case 5: Same messages with different cache_salt="user_456" (should NOT reuse)
print(
"\n=== Test Case 5: Same messages with cache_salt='user_456' (no reuse) ==="
)
response5 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
extra_body={"cache_salt": "user_456"})
responses.append(response5.choices[0].message.content)
print(f"Response 5: {response5.choices[0].message.content[:100]}...")

# Cache hit rate should decrease when using a different salt
hit_rate_after_salt2 = get_cache_hit_rate(client)
print(f"Cache hit rate after salt 'user_456': {hit_rate_after_salt2:.2%}")
assert hit_rate_after_salt2 < hit_rate_after_salt1_reuse, \
"Cache hit rate should decrease when using a different salt"

# Test empty string (should be rejected)
with pytest.raises(Exception) as exc_info:
client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
extra_body={"cache_salt": ""}  # Empty string should be rejected
)
print(f"Empty string rejected as expected: {exc_info.value}")