Merge remote-tracking branch 'origin/main' into feat/b300_cu13

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Xiwen Yu 2025-09-07 10:28:24 +08:00
commit 291290851a
36 changed files with 688 additions and 92 deletions


@ -542,7 +542,8 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const
std::nullopt, // kvCacheRetentionConfig
std::nullopt, // logitsPostProcessorName
std::nullopt, // logitsPostProcessor
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
std::nullopt); // cacheSaltID
request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
return request;
}


@ -837,7 +837,8 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
std::nullopt, // kvCacheRetentionConfig
std::nullopt, // logitsPostProcessorName
std::nullopt, // logitsPostProcessor
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
std::nullopt); // cacheSaltID
}
void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,


@ -69,6 +69,7 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken;
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;
// Type alias for multimodal hash key (hash array + start offset)
using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
@ -115,6 +116,7 @@ struct BlockKey
// Extra keys for multimodal data (similar to VLLM's approach)
// Each extra key is a pair of (mm_hash, start_offset_in_block)
std::vector<MmKey> extraKeys;
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt;
BlockKey() = default;
@ -129,24 +131,25 @@ struct BlockKey
}
explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
std::vector<MmKey> extraKeys = {})
std::vector<MmKey> extraKeys = {}, std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: usesExtraIds{usesExtraIds}
, loraTaskId{loraTaskId}
, uniqueTokens{std::move(uniqueTokens)}
, extraKeys{std::move(extraKeys)}
, cacheSaltID{cacheSaltID}
{
}
bool operator==(BlockKey const& other) const noexcept
{
return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys);
&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
}
int partialMatch(BlockKey const& other) const noexcept
{
SizeType32 numMatched{0};
if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys)
if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID)
{
auto [matchEnd, otherMatchEnd] = std::mismatch(
uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end());


@ -100,8 +100,8 @@ public:
RequestIdType, TensorPtr&, BeamTokens const&, TStream const&, std::optional<RequestIdType>)>;
using RequestPtr = std::shared_ptr<GenericLlmRequest>;
using MillisecondsType = std::chrono::milliseconds;
using CacheSaltIDType = runtime::CacheSaltIDType;
// 49 parameters, 56 items in initialization list
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const& inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@ -137,7 +137,8 @@ public:
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
@ -194,6 +195,7 @@ public:
, mGuidedDecodingParams(std::move(guidedDecodingParams))
, mLanguageAdapterUid(languageAdapterUid)
, mAllottedTimeMs(allottedTimeMs)
, mCacheSaltID(cacheSaltID)
{
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
{
@ -203,7 +205,6 @@ public:
initialize(*inputTokens, returnLogProbs);
}
// 32 parameters, 39 items in initialization list
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, VecTokens const& inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@ -221,7 +222,8 @@ public:
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens.size())
, mMaxNewTokens(maxNewTokens)
@ -261,6 +263,7 @@ public:
, mContextPhaseParams(contextPhaseParams)
, mNumReturnSequences(numReturnSequences)
, mLanguageAdapterUid(languageAdapterUid)
, mCacheSaltID(cacheSaltID)
{
if (mEncoderTokens.has_value())
{
@ -269,7 +272,6 @@ public:
initialize(inputTokens, returnLogProbs);
}
// 29 items in initialization list
GenericLlmRequest(RequestIdType requestId, executor::Request const& req)
: mRequestId(requestId)
, mPromptLen(req.getInputTokenIds().size())
@ -300,6 +302,7 @@ public:
, mGuidedDecodingParams(req.getGuidedDecodingParams())
, mLanguageAdapterUid(req.getLanguageAdapterUid())
, mAllottedTimeMs(req.getAllottedTimeMs())
, mCacheSaltID(req.getCacheSaltID())
{
if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY)
{
@ -1764,6 +1767,11 @@ public:
return mLanguageAdapterUid;
}
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const
{
return mCacheSaltID;
}
std::vector<SizeType32> getLanguageAdapterRouting(
SizeType32 const reqNumLanguages, SizeType32 const inputLength) const
{
@ -2042,6 +2050,9 @@ protected:
bool mUseDraftModel{false};
// Cache salt id for each request.
std::optional<CacheSaltIDType> mCacheSaltID{std::nullopt};
private:
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
{
@ -2222,7 +2233,8 @@ public:
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
std::move(promptEmbeddingTable), promptVocabSize, std::move(multimodalHashes),
@ -2234,7 +2246,8 @@ public:
std::move(encoderInputTokens), returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures),
std::move(encoderOutputLength), std::move(crossAttentionMask), llmRequestType,
std::move(inputTokenExtraIds), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks),
returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams,
cacheSaltID)
{
}
@ -2272,7 +2285,8 @@ public:
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
std::move(stopWordsList),
@ -2302,7 +2316,7 @@ public:
inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics,
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID)
{
}
@ -2324,14 +2338,15 @@ public:
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
lookaheadConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens),
std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),
applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority,
numReturnSequences, languageAdapterUid, contextPhaseParams)
numReturnSequences, languageAdapterUid, contextPhaseParams, cacheSaltID)
{
}


@ -670,7 +670,7 @@ public:
/// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut
/// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
/// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
// 34 parameters
/// @param cacheSaltID Salt ID for KV cache blocks, limiting KV cache reuse to requests that use the same salt.
Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@ -697,7 +697,8 @@ public:
std::optional<EagleConfig> eagleConfig = std::nullopt, std::optional<Tensor> skipCrossAttnBlocks = std::nullopt,
std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt);
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
@ -745,6 +746,7 @@ public:
[[nodiscard]] std::optional<GuidedDecodingParams> getGuidedDecodingParams() const;
[[nodiscard]] std::optional<SizeType32> getLanguageAdapterUid() const;
[[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
[[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;
void setStreaming(bool streaming);
@ -780,6 +782,7 @@ public:
void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams);
void setLanguageAdapterUid(SizeType32 languageAdapterUid);
void setAllottedTimeMs(MillisecondsType allottedTimeMs);
void setCacheSaltID(CacheSaltIDType cacheSaltID);
private:
friend class Serialization;


@ -58,6 +58,7 @@ using RandomSeedType = std::uint64_t;
using VecLogProbs = std::vector<FloatType>;
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using MillisecondsType = std::chrono::milliseconds;
using CacheSaltIDType = std::uint64_t;
using LogitsPostProcessor
= std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>;
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;


@ -44,6 +44,7 @@ using TokenIdType = std::int32_t;
using LoraTaskIdType = std::uint64_t;
using TokenExtraIdType = std::uint64_t;
using VecTokenExtraIds = std::vector<TokenExtraIdType>;
using CacheSaltIDType = std::uint64_t;
struct UniqueToken
{


@ -131,7 +131,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
// Check if this multimodal content overlaps with the current block
if (endTokenIdx > startPos && startTokenIdx < startPos + length)
{
SizeType32 mmStartInBlock = (startPos >= startTokenIdx) ? 0 : startTokenIdx - startPos;
uint64_t mmStartInBlock = (startPos >= startTokenIdx) ? 0 : static_cast<uint64_t>(startTokenIdx - startPos);
extraKeys.emplace_back(mmHashArray, mmStartInBlock);
}
}
@ -151,7 +151,7 @@ std::vector<BlockKey> buildBlockKeys(
currentTokenIdx += uniqueTokens.size();
blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(),
std::move(uniqueTokens), std::move(extraKeys));
std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID());
}
return blockKeys;
}
@ -167,6 +167,16 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no
// Constants provide very good distribution - each input bit affects each output bit with ~50% probability.
size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9);
if (parentHash == 0 && blockKey.cacheSaltID)
{
// Only hashing the cache salt ID for the first block in the sequence
uint64_t c = blockKey.cacheSaltID.value();
c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb);
c = c ^ (c >> 31);
seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
for (auto const& uniqueToken : blockKey.uniqueTokens)
{
uint32_t a = static_cast<uint32_t>(uniqueToken.tokenId);

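For reference, the salt handling above is a splitmix64-style finalizer over the salt followed by a boost-style hash_combine into the running block hash, applied only when parentHash == 0 (the first block of a sequence). A minimal Python sketch of the same arithmetic, with the constants taken from the hunk above and 64-bit wraparound made explicit (the helper name is ours, not part of the codebase):

MASK64 = (1 << 64) - 1

def mix_cache_salt(seed: int, cache_salt_id: int) -> int:
    # splitmix64-style finalizer over the salt
    c = cache_salt_id & MASK64
    c = ((c ^ (c >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    c = ((c ^ (c >> 27)) * 0x94D049BB133111EB) & MASK64
    c ^= c >> 31
    # boost-style hash_combine into the running block-hash seed
    return (seed & MASK64) ^ ((c + 0x9E3779B9 + ((seed << 6) & MASK64) + (seed >> 2)) & MASK64)

Sequences with different salts therefore hash their first block differently, and the equality check on cacheSaltID keeps their blocks from matching during reuse lookup.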

@ -25,7 +25,6 @@
namespace tensorrt_llm::executor
{
// 36 parameters
Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, SamplingConfig const& samplingConfig,
OutputConfig const& outputConfig, std::optional<SizeType32> const& endId, std::optional<SizeType32> const& padId,
std::optional<std::vector<SizeType32>> positionIds, std::optional<std::list<VecTokens>> badWords,
@ -41,7 +40,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
std::optional<SizeType32> encoderOutputLength, std::optional<Tensor> crossAttentionMask,
SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig, std::optional<Tensor> skipCrossAttnBlocks,
std::optional<GuidedDecodingParams> guidedDecodingParams, std::optional<SizeType32> languageAdapterUid,
std::optional<MillisecondsType> allottedTimeMs)
std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID)
: mImpl(std::make_unique<Impl>(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId,
padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
@ -50,7 +49,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type,
std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask,
numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid,
allottedTimeMs))
allottedTimeMs, cacheSaltID))
{
}
@ -249,6 +248,11 @@ std::optional<SizeType32> Request::getLanguageAdapterUid() const
return mImpl->getLanguageAdapterUid();
}
std::optional<CacheSaltIDType> Request::getCacheSaltID() const
{
return mImpl->getCacheSaltID();
}
void Request::setStreaming(bool streaming)
{
mImpl->setStreaming(streaming);
@ -413,4 +417,9 @@ void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid)
{
return mImpl->setLanguageAdapterUid(languageAdapterUid);
}
void Request::setCacheSaltID(CacheSaltIDType cacheSaltID)
{
return mImpl->setCacheSaltID(cacheSaltID);
}
} // namespace tensorrt_llm::executor


@ -32,7 +32,6 @@ class Request::Impl
{
public:
// 36 parameters, 36 items in initialization list
Impl(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming, SamplingConfig const& samplingConfig,
OutputConfig outputConfig, std::optional<TokenIdType> const& endId, std::optional<TokenIdType> const& padId,
std::optional<std::vector<SizeType32>> positionIds, std::optional<std::list<VecTokens>> badWords,
@ -48,7 +47,8 @@ public:
std::optional<Tensor> encoderInputFeatures, std::optional<SizeType32> encoderOutputLength,
std::optional<Tensor> crossAttentionMask, SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig,
std::optional<Tensor> skipCrossAttnBlocks, std::optional<GuidedDecodingParams> guidedDecodingParams,
std::optional<SizeType32> languageAdapterUid, std::optional<MillisecondsType> allottedTimeMs)
std::optional<SizeType32> languageAdapterUid, std::optional<MillisecondsType> allottedTimeMs,
std::optional<CacheSaltIDType> cacheSaltID)
: mInputTokenIds(std::move(inputTokenIds))
, mMaxNewTokens(maxNewTokens)
, mStreaming(streaming)
@ -85,6 +85,7 @@ public:
, mGuidedDecodingParams(std::move(guidedDecodingParams))
, mLanguageAdapterUid(languageAdapterUid)
, mAllottedTimeMs(allottedTimeMs)
, mCacheSaltID(cacheSaltID)
{
validate();
}
@ -296,6 +297,11 @@ public:
return mLanguageAdapterUid;
}
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const
{
return mCacheSaltID;
}
void setStreaming(bool streaming)
{
mStreaming = streaming;
@ -470,6 +476,11 @@ public:
mLanguageAdapterUid = languageAdapterUid;
}
void setCacheSaltID(CacheSaltIDType cacheSaltID)
{
mCacheSaltID = cacheSaltID;
}
private:
void validate()
{
@ -543,6 +554,7 @@ private:
lambda(mGuidedDecodingParams);
lambda(mLanguageAdapterUid);
lambda(mAllottedTimeMs ? std::make_optional(mAllottedTimeMs->count()) : std::nullopt);
lambda(mCacheSaltID);
}
VecTokens mInputTokenIds;
@ -581,6 +593,7 @@ private:
std::optional<GuidedDecodingParams> mGuidedDecodingParams;
std::optional<SizeType32> mLanguageAdapterUid;
std::optional<MillisecondsType> mAllottedTimeMs;
std::optional<CacheSaltIDType> mCacheSaltID;
};
} // namespace tensorrt_llm::executor


@ -711,8 +711,8 @@ Request Serialization::deserializeRequest(std::istream& is)
auto allottedTimeMs = allottedTimeInt
? std::optional<std::chrono::milliseconds>(std::chrono::milliseconds(*allottedTimeInt))
: std::nullopt;
auto cacheSaltID = su::deserialize<std::optional<CacheSaltIDType>>(is);
// 35 parameters
return Request(std::move(inputTokenIds), maxNewTokens, streaming, samplingConfig, outputConfig, endId, padId,
std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
@ -721,7 +721,7 @@ Request Serialization::deserializeRequest(std::istream& is)
std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, requestType,
std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength,
std::move(crossAttentionMask), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks),
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs);
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, cacheSaltID);
}
void Serialization::serialize(Request const& request, std::ostream& os)


@ -189,6 +189,7 @@ void initBindings(nb::module_& m)
.def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType)
.def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId)
.def_prop_ro("is_child", &GenLlmReq::isChild)
.def_prop_ro("cache_salt_id", &GenLlmReq::getCacheSaltID)
.def_prop_ro("multimodal_hashes",
[](GenLlmReq& self)
{
@ -287,7 +288,8 @@ void initBindings(nb::module_& m)
std::optional<executor::GuidedDecodingParams> guided_decoding_params,
std::optional<tb::LlmRequest::SizeType32> language_adapter_uid,
std::optional<tb::LlmRequest::MillisecondsType> allotted_time_ms,
std::optional<executor::ContextPhaseParams> context_phase_params)
std::optional<executor::ContextPhaseParams> context_phase_params,
std::optional<tb::LlmRequest::CacheSaltIDType> cache_salt_id)
{
auto makeOptionalTensor = [](std::optional<at::Tensor> const& atTensor, bool unsqueeze = false)
{
@ -316,7 +318,6 @@ void initBindings(nb::module_& m)
auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask);
auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks);
// 49 parameters
new (self) tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr, stop_words_list_tensor_ptr,
position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size, multimodal_hashes,
@ -328,7 +329,8 @@ void initBindings(nb::module_& m)
encoder_input_tokens, return_encoder_output, client_id, priority, encoder_input_features_tensor_ptr,
encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids,
num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics,
guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params};
guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params,
cache_salt_id};
},
nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"),
nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt,
@ -353,7 +355,7 @@ void initBindings(nb::module_& m)
nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt,
nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt,
nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt,
nb::arg("context_phase_params") = std::nullopt)
nb::arg("context_phase_params") = std::nullopt, nb::arg("cache_salt_id") = std::nullopt)
.def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, nb::arg("vocab_size"))
.def(nb::init<tb::LlmRequest const&>())
.def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"),


@ -76,7 +76,6 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
? std::make_shared<std::vector<TokenIdType>>(*mEncoderTokens.value().get())
: nullptr;
auto const optEncoderInputTokens = std::optional<std::shared_ptr<std::vector<TokenIdType>>>(encoderInputTokens);
// 49 parameters
return std::make_shared<tb::LlmRequest>( //
mRequestId, //
mMaxNewTokens, //
@ -126,6 +125,7 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
mGuidedDecodingParams, //
mLanguageAdapterUid, //
mAllottedTimeMs, //
mContextPhaseParams //
mContextPhaseParams, //
mCacheSaltID //
);
}


@ -51,7 +51,6 @@ public:
using VecTokenExtraIds = Base::VecTokenExtraIds;
using LogitsPostProcessor = Base::LogitsPostProcessor;
// 49 parameters
LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@ -85,7 +84,8 @@ public:
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: Base(requestId, //
maxNewTokens, //
std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //
@ -146,7 +146,8 @@ public:
guidedDecodingParams, //
languageAdapterUid, //
allottedTimeMs, //
contextPhaseParams //
contextPhaseParams, //
cacheSaltID //
)
{
}


@ -573,11 +573,11 @@ void initRequestBindings(nb::module_& m)
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
self.getGuidedDecodingParams());
self.getGuidedDecodingParams(), self.getCacheSaltID());
};
auto requestSetstate = [](tle::Request& self, nb::tuple const& state)
{
if (state.size() != 33)
if (state.size() != 34)
{
throw std::runtime_error("Invalid Request state!");
}
@ -601,7 +601,8 @@ void initRequestBindings(nb::module_& m)
nb::cast<std::optional<tle::Tensor>>(state[27]), nb::cast<std::optional<SizeType32>>(state[28]),
nb::cast<std::optional<tle::Tensor>>(state[29]), 1, nb::cast<std::optional<tle::EagleConfig>>(state[30]),
nb::cast<std::optional<tle::Tensor>>(state[31]),
nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]));
nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]),
nb::cast<std::optional<tle::CacheSaltIDType>>(state[33]));
};
nb::class_<tle::Request> request(m, "Request", nb::dynamic_attr());
@ -641,7 +642,8 @@ void initRequestBindings(nb::module_& m)
std::optional<tle::Tensor>, // skipCrossAttnBlocks
std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
std::optional<tle::SizeType32>, // languageAdapterUid
std::optional<tle::MillisecondsType> // allottedTimeMs
std::optional<tle::MillisecondsType>, // allottedTimeMs
std::optional<tle::CacheSaltIDType> // cacheSaltID
>(),
// clang-format off
nb::arg("input_token_ids"),
@ -680,8 +682,9 @@ void initRequestBindings(nb::module_& m)
nb::arg("skip_cross_attn_blocks") = nb::none(),
nb::arg("guided_decoding_params") = nb::none(),
nb::arg("language_adapter_uid") = nb::none(),
nb::arg("allotted_time_ms") = nb::none()
) // clang-format on
nb::arg("allotted_time_ms") = nb::none(),
nb::arg("cache_salt_id") = nb::none()
) // clang-format on
.def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds)
.def_prop_ro("max_tokens", &tle::Request::getMaxTokens)
.def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@ -723,6 +726,7 @@ void initRequestBindings(nb::module_& m)
.def_prop_rw(
"guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams)
.def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
.def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
.def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
.def("__getstate__", requestGetstate)
.def("__setstate__", requestSetstate);
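On the Python side, the new binding argument can be exercised directly. A hedged sketch follows (module path and token values are placeholders; keyword names are taken from the binding above):

import tensorrt_llm.bindings.executor as tle  # module path assumed

# Same prompt, different salts: KV cache blocks will not be shared across the two requests.
req_a = tle.Request(input_token_ids=[100, 101, 102, 103], max_tokens=8, cache_salt_id=12345)
req_b = tle.Request(input_token_ids=[100, 101, 102, 103], max_tokens=8, cache_salt_id=67890)
assert req_a.cache_salt_id != req_b.cache_salt_id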


@ -196,6 +196,7 @@ void initBindings(pybind11::module_& m)
.def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
.def_property_readonly("parent_request_id", &GenLlmReq::getParentRequestId)
.def_property_readonly("is_child", &GenLlmReq::isChild)
.def_property_readonly("cache_salt_id", &GenLlmReq::getCacheSaltID)
.def_property_readonly("multimodal_hashes",
[](GenLlmReq& self)
{
@ -293,7 +294,8 @@ void initBindings(pybind11::module_& m)
std::optional<executor::GuidedDecodingParams> guided_decoding_params,
std::optional<tb::LlmRequest::SizeType32> language_adapter_uid,
std::optional<tb::LlmRequest::MillisecondsType> allotted_time_ms,
std::optional<executor::ContextPhaseParams> context_phase_params)
std::optional<executor::ContextPhaseParams> context_phase_params,
std::optional<tb::LlmRequest::CacheSaltIDType> cache_salt_id)
{
auto makeOptionalTensor = [](std::optional<at::Tensor> const& atTensor, bool unsqueeze = false)
{
@ -322,7 +324,6 @@ void initBindings(pybind11::module_& m)
auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask);
auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks);
// 49 parameters
return tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr,
stop_words_list_tensor_ptr, position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size,
@ -335,7 +336,7 @@ void initBindings(pybind11::module_& m)
encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr,
llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config,
skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params,
language_adapter_uid, allotted_time_ms, context_phase_params};
language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id};
}),
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
@ -361,7 +362,7 @@ void initBindings(pybind11::module_& m)
py::arg("eagle_config") = std::nullopt, py::arg("skip_cross_attn_blocks") = std::nullopt,
py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt,
py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt,
py::arg("context_phase_params") = std::nullopt)
py::arg("context_phase_params") = std::nullopt, py::arg("cache_salt_id") = std::nullopt)
.def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, py::arg("vocab_size"))
.def(py::init<tb::LlmRequest const&>())
.def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"),


@ -75,7 +75,6 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
? std::make_shared<std::vector<TokenIdType>>(*mEncoderTokens.value().get())
: nullptr;
auto const optEncoderInputTokens = std::optional<std::shared_ptr<std::vector<TokenIdType>>>(encoderInputTokens);
// 49 parameters
return std::make_shared<tb::LlmRequest>( //
mRequestId, //
mMaxNewTokens, //
@ -125,6 +124,7 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
mGuidedDecodingParams, //
mLanguageAdapterUid, //
mAllottedTimeMs, //
mContextPhaseParams //
mContextPhaseParams, //
mCacheSaltID //
);
}


@ -49,8 +49,8 @@ public:
using VecTokens = Base::VecTokens;
using VecTokenExtraIds = Base::VecTokenExtraIds;
using LogitsPostProcessor = Base::LogitsPostProcessor;
using CacheSaltIDType = Base::CacheSaltIDType;
// 49 parameters
LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@ -84,7 +84,8 @@ public:
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: Base(requestId, //
maxNewTokens, //
std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //
@ -145,7 +146,8 @@ public:
guidedDecodingParams, //
languageAdapterUid, //
allottedTimeMs, //
contextPhaseParams //
contextPhaseParams, //
cacheSaltID //
)
{
}


@ -526,11 +526,11 @@ void initRequestBindings(pybind11::module_& m)
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
self.getGuidedDecodingParams());
self.getGuidedDecodingParams(), self.getCacheSaltID());
};
auto requestSetstate = [](py::tuple const& state)
{
if (state.size() != 33)
if (state.size() != 34)
{
throw std::runtime_error("Invalid Request state!");
}
@ -550,7 +550,8 @@ void initRequestBindings(pybind11::module_& m)
state[25].cast<tle::RequestType>(), state[26].cast<std::optional<tle::ContextPhaseParams>>(),
state[27].cast<std::optional<tle::Tensor>>(), state[28].cast<std::optional<SizeType32>>(),
state[29].cast<std::optional<tle::Tensor>>(), 1, state[30].cast<std::optional<tle::EagleConfig>>(),
state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>());
state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>(),
state[33].cast<std::optional<tle::CacheSaltIDType>>());
};
py::class_<tle::Request> request(m, "Request", pybind11::dynamic_attr());
@ -590,7 +591,8 @@ void initRequestBindings(pybind11::module_& m)
std::optional<tle::Tensor>, // skipCrossAttnBlocks
std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
std::optional<tle::SizeType32>, // languageAdapterUid
std::optional<tle::MillisecondsType> // allottedTimeMs
std::optional<tle::MillisecondsType>, // allottedTimeMs
std::optional<tle::CacheSaltIDType> // cacheSaltID
>(),
// clang-format off
py::arg("input_token_ids"),
@ -630,8 +632,9 @@ void initRequestBindings(pybind11::module_& m)
py::arg("skip_cross_attn_blocks") = py::none(),
py::arg("guided_decoding_params") = py::none(),
py::arg("language_adapter_uid") = py::none(),
py::arg("allotted_time_ms") = py::none()
) // clang-format on
py::arg("allotted_time_ms") = py::none(),
py::arg("cache_salt_id") = py::none()
) // clang-format on
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
.def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@ -675,6 +678,7 @@ void initRequestBindings(pybind11::module_& m)
.def_property(
"guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams)
.def_property("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
.def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
.def_property(
"context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
.def(py::pickle(requestGetstate, requestSetstate));


@ -1686,6 +1686,207 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
}
TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
{
// Test that cache_salt_id prevents KV cache reuse between requests with the same tokens
// but different cache_salt_id values.
using VecTokenExtraIds = LlmRequest::VecTokenExtraIds;
using CacheSaltIDType = LlmRequest::CacheSaltIDType;
auto constexpr numLayers = 12;
auto constexpr numKvHeads = 6;
auto constexpr sizePerHead = 16;
auto constexpr tokensPerBlock = 4;
auto constexpr maxBlocksPerSeq = 4;
auto constexpr blocksInPrimaryPool = 16;
auto constexpr blocksInSecondaryPool = 0;
auto constexpr maxNumSequences = 8;
auto const stream = std::make_shared<tr::CudaStream>();
auto constexpr onboardBlocks = true;
auto constexpr numReturnSequences = 1;
auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq;
auto constexpr beamWidth = 1;
auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
maxNumSequences, stream, maxAttentionWindow, beamWidth,
std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0,
onboardBlocks);
blockManager.allocatePools(false);
EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock);
EXPECT_EQ(blockManager.getMaxNumBlocks(), blocksInPrimaryPool);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
SizeType32 constexpr maxNewTokens{0};
tr::SamplingConfig const samplingConfig{beamWidth};
bool constexpr isStreaming{false};
// Create shared input tokens
auto inputTokens = std::make_shared<VecTokens>(VecTokens{100, 101, 102, 103, 104, 105, 106, 107, 108});
auto const inputLength = static_cast<SizeType32>(inputTokens->size());
///////////////////////////////////////////////////////////////////////////
// Test Case 1: Request without cache_salt_id
LlmRequest::RequestIdType requestId{0};
auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt); // No cache_salt_id
GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// Add first request and get blocks 0, 1, 2
auto constexpr beamIdx = 0;
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
// Add generated tokens
llmRequest0->addNewToken(3, beamIdx);
llmRequest0->addNewToken(4, beamIdx);
auto numTokens = llmRequest0->getNumTokens(beamIdx);
auto numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
EXPECT_EQ(numBlocks, 3);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
// Release blocks to make them available for reuse
blockManager.releaseBlocks(seq0, llmRequest0);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
///////////////////////////////////////////////////////////////////////////
// Test Case 2: Request with same tokens but with cache_salt_id = 12345
requestId = 1;
CacheSaltIDType cacheSaltId1{12345};
auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
cacheSaltId1); // With cache_salt_id = 12345
GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// Should NOT reuse blocks despite same tokens, because cache_salt_id is different
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); // No reuse, starts from scratch
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5}));
llmRequest1->addNewToken(3, beamIdx);
llmRequest1->addNewToken(4, beamIdx);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
// Release blocks
blockManager.releaseBlocks(seq1, llmRequest1);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
///////////////////////////////////////////////////////////////////////////
// Test Case 3: Request with same tokens and same cache_salt_id = 12345
requestId = 2;
auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
cacheSaltId1); // Same cache_salt_id = 12345
GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// SHOULD reuse blocks because both tokens and cache_salt_id match
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 3,4
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 6}));
llmRequest2->addNewToken(3, beamIdx);
llmRequest2->addNewToken(4, beamIdx);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
// Release blocks
blockManager.releaseBlocks(seq2, llmRequest2);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
///////////////////////////////////////////////////////////////////////////
// Test Case 4: Request with same tokens but different cache_salt_id = 67890
requestId = 3;
CacheSaltIDType cacheSaltId2{67890};
auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
cacheSaltId2); // Different cache_salt_id = 67890
GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// Should NOT reuse blocks from any previous request because cache_salt_id is different
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 0); // No reuse
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9}));
llmRequest3->addNewToken(5, beamIdx);
llmRequest3->addNewToken(6, beamIdx);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
///////////////////////////////////////////////////////////////////////////
// Test Case 5: Request without cache_salt_id again
requestId = 4;
auto llmRequest4 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt); // No cache_salt_id
GenerationRequest seq4{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// Should reuse blocks from request0 (blocks 0,1) because both have no cache_salt_id
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 0,1
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 10}));
llmRequest4->addNewToken(7, beamIdx);
numTokens = llmRequest4->getNumTokens(beamIdx);
numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks * 2);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks * 2);
// Clean up
blockManager.releaseBlocks(seq3, llmRequest3);
blockManager.releaseBlocks(seq4, llmRequest4);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
}
TEST_F(KVCacheManagerTest, KVCacheManagerPerRequestStatsTest)
{
auto constexpr numLayers = 12;


@ -45,6 +45,7 @@ from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.llmapi.utils import enable_llm_debug
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.quantization.utils.fp8_utils import (
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
@ -468,10 +469,13 @@ class Deepseekv3MoE(nn.Module):
layer_idx=layer_idx,
# DS-R1 W4A8 is only supported through custom quantization script from
# examples/quantization/quantize_mixed_precision_moe.py
weight_loading_mode=(MoEWeightLoadingMode.W4A8_CUSTOM
if model_config.quant_config.quant_mode.
is_int4_weight_only_per_group() else
MoEWeightLoadingMode.VANILLA))
weight_loading_mode=(
MoEWeightLoadingMode.W4A8_CUSTOM
if self._get_experts_quant_config(
model_config,
layer_idx).layer_quant_mode.is_int4_weight_only_per_group()
else MoEWeightLoadingMode.VANILLA),
)
self.mapping = model_config.mapping
@ -536,6 +540,13 @@ class Deepseekv3MoE(nn.Module):
return shared_tp_size, shared_output_scale
@staticmethod
def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig:
if getattr(model_config, "quant_config_dict", None) is None:
return model_config.quant_config
return model_config.quant_config_dict.get(
f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)
def compute_routed_output(self, hidden_states, hidden_states_fp4,
all_rank_num_tokens, all_rank_max_num_tokens,
do_finalize):
@ -657,6 +668,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
quant_config = self._get_decoder_layer_quant_config(
model_config, layer_idx)
self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
assert (
quant_config.quant_algo
is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous"
has_tp = mapping.has_tp()
self.allreduce = AllReduce(mapping=model_config.mapping,


@ -40,14 +40,42 @@ class CUDAGraphRunner:
self.max_beam_width = engine.max_beam_width
self.spec_config = engine.spec_config
self.max_possible_draft_len = (self.spec_config.max_draft_len
if self.enable_spec_decode else 0)
self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
self.static_inputs: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = {}
self.graph_outputs: Dict[Tuple[int, int],
Callable[[], Optional[torch.Tensor]]] = {}
self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
self.memory_pool = engine._cuda_graph_mem_pool
self.padding_dummy_request: Optional["Request"] = None
self.shared_static_tensors: Dict[str, torch.Tensor] = {}
if self.enabled:
self._create_shared_static_tensors()
def _create_shared_static_tensors(self):
"""Allocates static tensors sized for the largest possible batch."""
engine = self._get_engine()
token_per_request = self.max_possible_draft_len + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
self.shared_static_tensors = {
"input_ids":
torch.ones((max_total_tokens, ), device="cuda", dtype=torch.int32),
"position_ids":
torch.zeros((1, max_total_tokens), device="cuda",
dtype=torch.int32),
}
if engine.use_mrope:
self.shared_static_tensors["mrope_position_deltas"] = torch.zeros(
(self.max_supported_batch_size, 1),
device="cuda",
dtype=torch.int32)
@property
def enable_spec_decode(self):
return self._get_engine().is_spec_decode
@ -139,38 +167,32 @@ class CUDAGraphRunner:
def capture(self, batch_size: int, forward_fn: Callable,
initial_inputs: Dict[str, Any]):
"""Captures the forward pass for a given batch size."""
engine = self._get_engine()
key = (batch_size, self.draft_len)
spec_metadata = initial_inputs.get("spec_metadata", None)
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
token_per_request = self.max_possible_draft_len + 1
num_tokens_for_capture = (batch_size * self.max_beam_width *
token_per_request)
static_tensors = {
sliced_static_tensors = {
"input_ids":
torch.ones((batch_size * self.max_beam_width * token_per_request, ),
device="cuda",
dtype=torch.int32),
self.shared_static_tensors["input_ids"][:num_tokens_for_capture],
"position_ids":
torch.zeros((
1,
batch_size * self.max_beam_width * token_per_request,
),
device="cuda",
dtype=torch.int32),
self.shared_static_tensors["position_ids"]
[:, :num_tokens_for_capture],
}
if engine.use_mrope:
static_tensors["mrope_position_deltas"] = torch.zeros(
(batch_size, 1), device="cuda", dtype=torch.int32)
self.static_inputs[key] = static_tensors
if "mrope_position_deltas" in self.shared_static_tensors:
sliced_static_tensors["mrope_position_deltas"] = \
self.shared_static_tensors["mrope_position_deltas"][:batch_size]
# Use the sliced tensors for capture
capture_inputs = initial_inputs.copy()
capture_inputs.update(static_tensors)
capture_inputs.update(sliced_static_tensors)
self.graph_metadata[key] = {
"attn_metadata": initial_inputs["attn_metadata"],
"spec_metadata": spec_metadata,
"spec_metadata": initial_inputs.get("spec_metadata", None),
}
# We have to do warm up runs to initialize PyTorch's
@ -198,7 +220,7 @@ class CUDAGraphRunner:
assert current_inputs.get(
"spec_metadata") is stored_meta["spec_metadata"]
static_tensors = self.static_inputs[key]
static_tensors = self.shared_static_tensors
input_ids = current_inputs["input_ids"]
seqlen = input_ids.shape[0]
@ -301,7 +323,6 @@ class CUDAGraphRunner:
for graph in self.graphs.values():
graph.reset()
self.graphs.clear()
self.static_inputs.clear()
self.graph_outputs.clear()
self.graph_metadata.clear()
self.padding_dummy_request = None


@ -546,6 +546,7 @@ def executor_request_to_llm_request(
priority=0.5,
llm_request_type=llm_request_type,
context_phase_params=executor_request.context_phase_params,
cache_salt_id=executor_request.cache_salt_id,
py_multimodal_data=getattr(executor_request, "py_multimodal_data",
None))
if child_req_ids:


@ -426,7 +426,6 @@ class PyTorchModelEngine(ModelEngine):
# the model engine.
self.attn_metadata = None
self.iter_states = {}
self._cuda_graphs = {}
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled


@ -124,6 +124,7 @@ class GenerationExecutor(ABC):
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
cache_salt_id: Optional[int] = None,
) -> GenerationResult:
"""Generate output for the given prompt token ids in the asynchronous mode.
Asynchronous generation accepts single prompt only.
@ -147,7 +148,8 @@ class GenerationExecutor(ABC):
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params)
scheduling_params=scheduling_params,
cache_salt_id=cache_salt_id)
result = self.submit(request)
# release memory in time
if hasattr(request, "multimodal_params"):


@ -97,6 +97,7 @@ class GenerationRequest:
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
cache_salt_id: Optional[int] = None,
):
if isinstance(prompt_token_ids, list):
self.prompt_token_ids = prompt_token_ids
@ -122,6 +123,7 @@ class GenerationRequest:
self.id: Optional[int] = None
self.disaggregated_params = disaggregated_params
self.scheduling_params = scheduling_params
self.cache_salt_id = cache_salt_id
def set_id(self, id):
assert self.id is None, f"Request ID is already set: {self.id}"


@ -572,7 +572,8 @@ class GenerationExecutorWorker(GenerationExecutor):
request.sampling_params.logits_processor,
kv_cache_retention_config=request.kv_cache_retention_config,
context_phase_params=context_phase_params,
type=request_type)
type=request_type,
cache_salt_id=request.cache_salt_id)
executor_request.py_lora_path = py_lora_path
if self._is_pytorch_backend and request.multimodal_params is not None:


@ -11,7 +11,8 @@ from .utils import (ALL_SUPPORTED_AUDIO_MODELS, ALL_SUPPORTED_IMAGE_MODELS,
add_multimodal_placeholders, apply_chat_template,
async_load_audio, async_load_image, async_load_video,
convert_image_mode, default_multimodal_input_loader,
encode_base64_content_from_url, load_image, load_video)
encode_base64_content_from_url, get_cache_salt_id,
load_image, load_video)
__all__ = [
"ALL_SUPPORTED_MULTIMODAL_MODELS",
@ -44,4 +45,5 @@ __all__ = [
"encode_base64_content_from_url",
"load_image",
"load_video",
"get_cache_salt_id",
]


@ -17,6 +17,7 @@ from torchvision.transforms import ToTensor
from transformers import AutoProcessor, ProcessorMixin
from transformers.utils import logging
from tensorrt_llm.inputs.multimodal import default_hasher
from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY,
MultimodalPlaceholderPlacement)
from tensorrt_llm.llmapi.llm_utils import ModelLoader
@ -610,3 +611,14 @@ def default_multimodal_input_loader(
inputs.append(input)
return inputs
def get_cache_salt_id(cache_salt: str) -> int:
b = cache_salt.encode("utf-8")
h = default_hasher(b).digest(length=8)
cache_salt_id = int.from_bytes(h, "little", signed=False)
if cache_salt_id < 0 or cache_salt_id >= (1 << 64):
raise ValueError(
f"cache_salt_id must be in [0, 2**64 - 1], got {cache_salt_id}.")
return cache_salt_id
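
A minimal standalone sketch of the same salt-to-ID mapping, assuming a BLAKE2-style hasher with an 8-byte digest; hashlib.blake2b stands in for default_hasher here, so the exact IDs are illustrative only:

import hashlib

def cache_salt_to_id(cache_salt: str) -> int:
    # 8-byte little-endian digest -> unsigned 64-bit integer, always in [0, 2**64 - 1].
    digest = hashlib.blake2b(cache_salt.encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "little", signed=False)

# Same salt maps to the same ID; different salts map to different IDs
# (with overwhelming probability), which is what keeps their KV blocks apart.
assert cache_salt_to_id("user_123") == cache_salt_to_id("user_123")
assert cache_salt_to_id("user_123") != cache_salt_to_id("user_456")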

View File

@ -27,7 +27,8 @@ from ..executor.postproc_worker import PostprocParams
from ..executor.utils import (create_mpi_comm_session,
get_spawn_proxy_process_env)
from ..inputs import (PromptInputs, create_input_processor,
create_input_processor_with_hash, prompt_inputs)
create_input_processor_with_hash, get_cache_salt_id,
prompt_inputs)
from ..logger import logger
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
@ -325,6 +326,7 @@ class BaseLLM:
disaggregated_params: Optional[DisaggregatedParams] = None,
_postproc_params: Optional[PostprocParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
cache_salt: Optional[str] = None,
) -> RequestOutput:
"""Generate output for the given prompt in the asynchronous mode.
Asynchronous generation accepts single prompt only.
@ -339,7 +341,7 @@ class BaseLLM:
kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, optional): Configuration for the request's retention in the KV Cache. Defaults to None.
disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Disaggregated parameters. Defaults to None.
scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, optional): Scheduling parameters. Defaults to None.
cache_salt (str, optional): If specified, the KV cache will be salted with the provided string, limiting KV cache reuse to requests with the same salt. Defaults to None.
Returns:
tensorrt_llm.llmapi.RequestOutput: The output data of the completion request to the LLM.
"""
@ -349,7 +351,8 @@ class BaseLLM:
raise RuntimeError("LLM is shutting down")
sampling_params = self._prepare_sampling_params(sampling_params)
cache_salt_id = get_cache_salt_id(
cache_salt) if cache_salt is not None else None
# With pytorch backend, py_executor has logic to handle max_tokens of 1,
# so set to 1 to avoid allocating unnecessary KV cache blocks for single request
# TODO: Also support for trt backend
@ -444,6 +447,7 @@ class BaseLLM:
postproc_params=_postproc_params,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params,
cache_salt_id=cache_salt_id,
)
return RequestOutput._from_generation_result(result, prompt,
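
A usage sketch of the new keyword from the LLM API side; the model path and prompt are placeholders, and the blocking .result() accessor on the returned RequestOutput is assumed to follow the existing contract:

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="/path/to/model")  # placeholder model path
params = SamplingParams(max_tokens=32)

# Identical prompts with different cache_salt values do not share KV cache blocks;
# requests that share a salt (or both omit it) remain eligible for reuse.
out_a = llm.generate_async("What is the capital of France?", params, cache_salt="user_123")
out_b = llm.generate_async("What is the capital of France?", params, cache_salt="user_456")
print(out_a.result().outputs[0].text)
print(out_b.result().outputs[0].text)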

View File

@ -19,7 +19,8 @@ from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
from openai_harmony import ReasoningEffort
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (BaseModel, ConfigDict, Field, field_validator,
model_validator)
from typing_extensions import Annotated, Required, TypeAlias, TypedDict
from tensorrt_llm.executor.request import LoRARequest
@ -592,6 +593,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=("Parameters for disaggregated serving"),
)
cache_salt: Optional[str] = Field(
default=None,
description=
("If specified, KV cache will be salted with the provided string "
"to limit the kv cache reuse on with the requests having the same string."
))
# doc: end-chat-completion-extra-params
def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:
@ -671,6 +679,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
raise ValueError("suffix is not supported")
return data
@field_validator("cache_salt")
@classmethod
def check_cache_salt_support(cls, v):
if v is not None:
if not isinstance(v, str) or not v.strip():
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return v
ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
ResponseReasoningItem,
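
On the serving side the field rides along as an extra body parameter of the chat completions request, as the new test below exercises; a minimal client-side sketch, with host, port, and model name as placeholders:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
resp = client.chat.completions.create(
    model="TinyLlama-1.1B-Chat-v1.0",  # placeholder model name
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=30,
    extra_body={"cache_salt": "user_123"},  # must be a non-empty string if provided
)
print(resp.choices[0].message.content)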

View File

@ -462,7 +462,8 @@ class OpenAIServer:
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
streaming=request.stream,
lora_request=request.lora_request,
disaggregated_params=disaggregated_params
disaggregated_params=disaggregated_params,
cache_salt=request.cache_salt,
)
asyncio.create_task(self.await_disconnected(raw_request, promise))
if not self.postproc_worker_enabled:

View File

@ -1557,6 +1557,18 @@ def test_openai_misc_example(llm_root, llm_venv, backend: str):
])
def test_openai_cache_salt(llm_root, llm_venv):
example_root = Path(os.path.join(llm_root, "examples", "serve"))
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd([
"-m", "pip", "install", "-r",
os.path.join(example_root, "requirements.txt")
])
llm_venv.run_cmd(
["-m", "pytest",
str(test_root / "_test_openai_cache_salt.py")])
@pytest.mark.parametrize("backend", ["pytorch", "trt"])
def test_openai_completions_example(llm_root, llm_venv, backend: str):
test_root = unittest_path() / "llmapi" / "apps"

View File

@ -186,6 +186,7 @@ def create_mock_engine(batch_size: int):
_cuda_graph_batch_sizes=[batch_size],
_max_cuda_graph_batch_size=batch_size,
max_beam_width=1,
max_num_tokens=8192,
is_spec_decode=False,
spec_config=None,
_cuda_graph_mem_pool=None,

View File

@ -196,6 +196,9 @@ methods:
annotation: Optional[tensorrt_llm.scheduling_params.SchedulingParams]
default: null
status: prototype
cache_salt:
annotation: Optional[str]
default: null
return_annotation: tensorrt_llm.llmapi.llm.RequestOutput
get_kv_cache_events:
parameters:

View File

@ -0,0 +1,231 @@
"""Test cache_salt functionality in OpenAI API to ensure it prevents cache reuse"""
import os
import tempfile
import openai
import pytest
import yaml
from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
pytestmark = pytest.mark.threadleak(enabled=False)
@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"])
def model_name() -> str:
return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file():
"""Create temporary config file with KV cache enabled for testing"""
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "cache_salt_test_options.yaml")
try:
extra_llm_api_options_dict = {
# Enable KV cache reuse
"kv_cache_config": {
"enable_block_reuse": True,
},
# Enable performance metrics to get cache hit rate
"return_perf_metrics": True,
"enable_iter_perf_stats": True,
"enable_iter_req_stats": True,
# Disable CUDA graph for compatibility
"cuda_graph_config": None,
}
with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
yield temp_file_path
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
@pytest.fixture(scope="module")
def server(model_name: str,
temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer:
model_path = get_model_path(model_name)
args = []
args.extend(["--backend", "pytorch"])
args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file])
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer) -> openai.OpenAI:
return server.get_client()
def get_cache_hit_rate(client: openai.OpenAI) -> float:
"""Get cache hit rate from the metrics endpoint"""
import httpx
# Get the base URL from the OpenAI client (it includes /v1)
# We need to go up one level to access /metrics
base_url = str(client.base_url).rstrip('/')
if base_url.endswith('/v1'):
base_url = base_url[:-3] # Remove /v1
# Make a direct HTTP request to the metrics endpoint
with httpx.Client() as http_client:
response = http_client.get(f"{base_url}/metrics", timeout=5.0)
# Check if metrics endpoint is available
if response.status_code != 200:
raise RuntimeError(
f"Metrics endpoint returned status {response.status_code}")
metrics = response.json()
# Validate that we have metrics data
if not isinstance(metrics, list) or len(metrics) == 0:
raise ValueError("No metrics data available")
# Get the most recent stats
latest_stats = metrics[-1]
# Extract KV cache statistics
kv_cache_stats = latest_stats.get("kvCacheStats", {})
if not kv_cache_stats:
raise ValueError("No KV cache statistics available in metrics")
try:
print(f"kv_cache_stats reused: {kv_cache_stats['reusedBlocks']}")
print(f"kv_cache_stats missed: {kv_cache_stats['missedBlocks']}")
print(f"kv_cache_stats hit rate: {kv_cache_stats['cacheHitRate']}")
return kv_cache_stats["cacheHitRate"]
except Exception as e:
print(f"Warning: Could not get cache metrics: {e}")
return 0.0
def test_cache_salt_prevents_reuse_chat(client: openai.OpenAI, model_name: str):
"""Test that different cache_salt values prevent KV cache reuse in chat completions"""
# Common messages that will be used across all requests
messages = [{
"role": "system",
"content": "You are a helpful assistant. Keep responses brief."
}, {
"role":
"user",
"content":
"What is the capital of France? Answer in one sentence."
}]
# Test configuration
max_tokens = 30
temperature = 0.0 # Deterministic for testing
# Track responses for comparison
responses = []
# Test Case 1: First request without cache_salt (baseline)
print("\n=== Test Case 1: First request without cache_salt ===")
response1 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
responses.append(response1.choices[0].message.content)
print(f"Response 1: {response1.choices[0].message.content[:100]}...")
# Display initial cache metrics
initial_hit_rate = get_cache_hit_rate(client)
print(f"Initial cache hit rate: {initial_hit_rate:.2%}")
# Test Case 2: Same messages without cache_salt (should reuse cache)
print(
"\n=== Test Case 2: Same messages without cache_salt (should reuse) ==="
)
response2 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
responses.append(response2.choices[0].message.content)
print(f"Response 2: {response2.choices[0].message.content[:100]}...")
# Check if metrics are available
hit_rate_after_reuse = get_cache_hit_rate(client)
print(f"Cache hit rate after reuse: {hit_rate_after_reuse:.2%}")
assert hit_rate_after_reuse >= initial_hit_rate, \
"Cache hit rate should increase when reusing cache without salt"
# Test Case 3: Same messages with cache_salt="user_123" (should NOT reuse)
print(
"\n=== Test Case 3: Same messages with cache_salt='user_123' (no reuse) ==="
)
response3 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
extra_body={"cache_salt": "user_123"})
responses.append(response3.choices[0].message.content)
print(f"Response 3: {response3.choices[0].message.content[:100]}...")
# Record metrics after request with different salt
hit_rate_after_salt1 = get_cache_hit_rate(client)
print(f"Cache hit rate after salt 'user_123': {hit_rate_after_salt1:.2%}")
assert hit_rate_after_salt1 < hit_rate_after_reuse, \
"Cache hit rate should decrease when using a different salt"
# Test Case 4: Same messages with same cache_salt="user_123" (should reuse)
print(
"\n=== Test Case 4: Same messages with same cache_salt='user_123' (should reuse) ==="
)
response4 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
extra_body={"cache_salt": "user_123"} # Same salt should enable reuse
)
responses.append(response4.choices[0].message.content)
print(f"Response 4: {response4.choices[0].message.content[:100]}...")
# Cache hit rate should increase again when using same salt
hit_rate_after_salt1_reuse = get_cache_hit_rate(client)
print(
f"Cache hit rate after reusing salt 'user_123': {hit_rate_after_salt1_reuse:.2%}"
)
assert hit_rate_after_salt1_reuse >= hit_rate_after_salt1, \
"Cache hit rate should increase when reusing same salt"
# Test Case 5: Same messages with different cache_salt="user_456" (should NOT reuse)
print(
"\n=== Test Case 5: Same messages with cache_salt='user_456' (no reuse) ==="
)
response5 = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
extra_body={"cache_salt": "user_456"})
responses.append(response5.choices[0].message.content)
print(f"Response 5: {response5.choices[0].message.content[:100]}...")
# Cache hit rate should decrease when using a different salt
hit_rate_after_salt2 = get_cache_hit_rate(client)
print(f"Cache hit rate after salt 'user_456': {hit_rate_after_salt2:.2%}")
assert hit_rate_after_salt2 < hit_rate_after_salt1_reuse, \
"Cache hit rate should decrease when using a different salt"
# Test empty string (should be rejected)
with pytest.raises(Exception) as exc_info:
client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
extra_body={"cache_salt": ""} # Empty string should be rejected
)
print(f"Empty string rejected as expected: {exc_info.value}")