[TRTLLM-10487][feat] Add user-provided UUID support for multimodal KV cache identification. (#11075)

Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
Simeng Liu 2026-02-11 21:48:47 -08:00 committed by GitHub
parent 936220e746
commit 12085536df
28 changed files with 1297 additions and 153 deletions

View File

@ -119,6 +119,7 @@ public:
std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>> multimodalHashes = std::nullopt,
std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalPositions = std::nullopt,
std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalLengths = std::nullopt,
std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>> multimodalUuids = std::nullopt,
std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
@ -168,6 +169,7 @@ public:
, mMultimodalHashes(std::move(multimodalHashes))
, mMultimodalPositions(std::move(multimodalPositions))
, mMultimodalLengths(std::move(multimodalLengths))
, mMultimodalUuids(std::move(multimodalUuids))
, mMultimodalEmbedding(std::move(multimodalEmbedding))
, mMropeRotaryCosSin(std::move(mropeRotaryCosSin))
, mMropePositionDeltas(mropePositionDeltas)
@ -909,6 +911,11 @@ public:
return mMultimodalLengths;
}
[[nodiscard]] std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>> getMultimodalUuids() const
{
return mMultimodalUuids;
}
[[nodiscard]] std::optional<TensorPtr> getMultimodalEmbedding() const
{
return mMultimodalEmbedding;
@ -1964,6 +1971,7 @@ protected:
std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>> mMultimodalHashes{std::nullopt};
std::optional<std::shared_ptr<std::vector<SizeType32>>> mMultimodalPositions{std::nullopt};
std::optional<std::shared_ptr<std::vector<SizeType32>>> mMultimodalLengths{std::nullopt};
std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>> mMultimodalUuids{std::nullopt};
std::optional<TensorPtr> mMultimodalEmbedding{std::nullopt};
std::optional<TensorPtr> mMropeRotaryCosSin{std::nullopt};
std::optional<SizeType32> mMropePositionDeltas{std::nullopt};
@ -2252,6 +2260,7 @@ public:
std::optional<std::vector<std::vector<SizeType32>>> multimodalHashes = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalPositions = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalLengths = std::nullopt,
std::optional<std::vector<std::optional<std::string>>> multimodalUuids = std::nullopt,
std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
@ -2292,6 +2301,9 @@ public:
multimodalLengths.has_value()
? std::make_shared<std::vector<SizeType32>>(std::move(multimodalLengths.value()))
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
multimodalUuids.has_value()
? std::make_shared<std::vector<std::optional<std::string>>>(std::move(multimodalUuids.value()))
: std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>>(std::nullopt),
std::move(multimodalEmbedding), std::move(mropeRotaryCosSin), mropePositionDeltas, loraTaskId,
std::move(loraWeights), std::move(loraConfig), lookaheadConfig, std::move(kvCacheRetentionConfig),
returnLogProbs, returnContextLogits, returnGenerationLogits,
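
Each multimodal argument above goes through the same `has_value()` ternary to turn an `std::optional<std::vector<T>>` into an `std::optional<std::shared_ptr<std::vector<T>>>`. A hypothetical helper capturing that idiom (the helper name is not part of this change):

```cpp
#include <memory>
#include <optional>
#include <vector>

// Hypothetical helper showing the wrapping idiom used by the delegating
// constructor: move an std::optional<std::vector<T>> into a shared_ptr so
// downstream consumers can share the data without copying.
template <typename T>
std::optional<std::shared_ptr<std::vector<T>>> wrapOptionalVector(std::optional<std::vector<T>> v)
{
    if (v.has_value())
    {
        return std::make_shared<std::vector<T>>(std::move(v.value()));
    }
    return std::nullopt;
}
```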

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,10 +48,6 @@ namespace tensorrt_llm::executor
{
using SizeType32 = tensorrt_llm::runtime::SizeType32;
// Mmkey is used in KVCacheBlock when multimodal data presents in a block.
// Type alias for hash array + start offset at per-block granularity.
// This differs from the per-request level multimodal hash in MultimodalInput.
using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
/// @brief Version of TRT-LLM
char const* version() noexcept;
@ -301,11 +297,13 @@ class MultimodalInput
{
public:
explicit MultimodalInput(std::vector<std::vector<SizeType32>> multimodalHashes,
std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths);
std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths,
std::optional<std::vector<std::optional<std::string>>> multimodalUuids = std::nullopt);
[[nodiscard]] std::vector<std::vector<SizeType32>> getMultimodalHashes() const;
[[nodiscard]] std::vector<SizeType32> getMultimodalPositions() const;
[[nodiscard]] std::vector<SizeType32> getMultimodalLengths() const;
[[nodiscard]] std::optional<std::vector<std::optional<std::string>>> const& getMultimodalUuids() const;
private:
friend class Serialization;
@ -315,6 +313,9 @@ private:
std::vector<SizeType32> mMultimodalPositions;
/// @brief The multimodal lengths
std::vector<SizeType32> mMultimodalLengths;
/// @brief Optional user-provided UUIDs for multimodal items.
/// When provided, these are returned in KV cache events instead of content hashes.
std::optional<std::vector<std::optional<std::string>>> mMultimodalUuids;
};
/// @brief Configuration for mrope
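
A minimal construction sketch for the extended `MultimodalInput` interface (values are illustrative; assumes the executor headers are on the include path):

```cpp
#include <optional>
#include <string>
#include <vector>

using tensorrt_llm::executor::MultimodalInput;
using tensorrt_llm::runtime::SizeType32;

void multimodalInputExample()
{
    std::vector<std::vector<SizeType32>> hashes{{1, 2, 3, 4}};
    std::vector<SizeType32> positions{0};
    std::vector<SizeType32> lengths{50};
    std::vector<std::optional<std::string>> uuids{std::string("image-uuid-001")};

    MultimodalInput input{hashes, positions, lengths, uuids};

    // The getter mirrors the optional nesting: an absent list means "no
    // UUIDs anywhere"; an absent element means "no UUID for that item".
    auto const& maybeUuids = input.getMultimodalUuids();
    bool hasUuids = maybeUuids.has_value(); // true here
    (void) hasUuids;
}
```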

View File

@ -349,6 +349,11 @@ public:
static void serialize(KVCacheUpdatedData const& data, std::ostream& os);
[[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is);
// MmKey
[[nodiscard]] static size_t serializedSize(MmKey const& key);
static void serialize(MmKey const& key, std::ostream& os);
[[nodiscard]] static MmKey deserializeMmKey(std::istream& is);
// UniqueToken
[[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token);
static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os);

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#pragma once
#include <array>
#include <chrono>
#include <cstdint>
#include <functional>
@ -70,6 +71,29 @@ using EagleChoices = std::vector<std::vector<SizeType32>>;
using PriorityType = float;
using BufferView = std::basic_string_view<uint8_t>;
//! MmKey is used in KVCacheBlock when multimodal data is present in a block.
//! Hash is a 32-byte array; startOffset is the per-block token offset; uuid is optional.
struct MmKey
{
std::array<uint8_t, 32> hash;
SizeType32 startOffset{};
std::optional<std::string> uuid{std::nullopt};
MmKey() = default;
MmKey(std::array<uint8_t, 32> hash, SizeType32 startOffset, std::optional<std::string> uuid = std::nullopt)
: hash(std::move(hash))
, startOffset(startOffset)
, uuid(std::move(uuid))
{
}
bool operator==(MmKey const& other) const noexcept
{
return hash == other.hash && startOffset == other.startOffset && uuid == other.uuid;
}
};
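Because operator== also compares uuid, keys with identical content hashes and offsets are distinct when their UUIDs differ. A small illustrative check (assumes MmKey is in scope):

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

void mmKeyEqualityExample()
{
    std::array<uint8_t, 32> h{}; // identical content hash for all keys
    MmKey a{h, 0, std::string("img-001")};
    MmKey b{h, 0, std::nullopt};
    MmKey c{h, 0, std::string("img-001")};
    assert(!(a == b)); // same hash and offset, different uuid
    assert(a == c);
}
```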
enum class DataType
{
kBOOL,

View File

@ -100,6 +100,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
auto const multimodalHashes = llmRequest.getMultimodalHashes();
auto const multimodalPositions = llmRequest.getMultimodalPositions();
auto const multimodalLengths = llmRequest.getMultimodalLengths();
auto const multimodalUuids = llmRequest.getMultimodalUuids();
if (!multimodalHashes || !multimodalPositions || !multimodalLengths || !(*multimodalHashes)
|| (*multimodalHashes)->empty() || !(*multimodalPositions) || (*multimodalPositions)->empty()
@ -115,7 +116,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
return {};
}
std::vector<MmKey> extraKeys; // MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
std::vector<MmKey> extraKeys;
extraKeys.reserve((*multimodalPositions)->size());
std::array<uint8_t, 32> mmHashArray;
@ -145,7 +146,15 @@ std::vector<MmKey> generateBlockHashExtraKeys(
if (endTokenIdx > startPos && startTokenIdx < startPos + length)
{
uint64_t mmStartInBlock = (startPos >= startTokenIdx) ? 0 : static_cast<uint64_t>(startTokenIdx - startPos);
extraKeys.emplace_back(mmHashArray, mmStartInBlock);
// Get UUID if available
std::optional<std::string> uuid = std::nullopt;
if (multimodalUuids && *multimodalUuids && i < (*multimodalUuids)->size())
{
uuid = (*(*multimodalUuids))[i];
}
extraKeys.emplace_back(mmHashArray, mmStartInBlock, std::move(uuid));
}
}
@ -222,8 +231,10 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no
// block
if (!blockKey.extraKeys.empty())
{
for (auto const& [mmHash, startOffset] : blockKey.extraKeys)
for (auto const& mmKey : blockKey.extraKeys)
{
auto const& mmHash = mmKey.hash;
auto const& startOffset = mmKey.startOffset;
// Hash the multimodal hash array in 32-bit chunks (more efficient)
for (size_t i = 0; i < 32; i += 4)
{

View File

@ -21,10 +21,12 @@
namespace tensorrt_llm::executor
{
MultimodalInput::MultimodalInput(std::vector<std::vector<SizeType32>> multimodalHashes,
std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths)
std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths,
std::optional<std::vector<std::optional<std::string>>> multimodalUuids)
: mMultimodalHashes(std::move(multimodalHashes))
, mMultimodalPositions(std::move(multimodalPositions))
, mMultimodalLengths(std::move(multimodalLengths))
, mMultimodalUuids(std::move(multimodalUuids))
{
}
@ -43,4 +45,9 @@ std::vector<SizeType32> MultimodalInput::getMultimodalLengths() const
return mMultimodalLengths;
}
std::optional<std::vector<std::optional<std::string>>> const& MultimodalInput::getMultimodalUuids() const
{
return mMultimodalUuids;
}
} // namespace tensorrt_llm::executor

View File

@ -339,7 +339,9 @@ MultimodalInput Serialization::deserializeMultimodalInput(std::istream& is)
auto multimodalHashes = su::deserialize<std::vector<std::vector<SizeType32>>>(is);
auto multimodalPositions = su::deserialize<std::vector<SizeType32>>(is);
auto multimodalLengths = su::deserialize<std::vector<SizeType32>>(is);
return MultimodalInput{std::move(multimodalHashes), std::move(multimodalPositions), std::move(multimodalLengths)};
auto multimodalUuids = su::deserialize<std::optional<std::vector<std::optional<std::string>>>>(is);
return MultimodalInput{std::move(multimodalHashes), std::move(multimodalPositions), std::move(multimodalLengths),
std::move(multimodalUuids)};
}
void Serialization::serialize(MultimodalInput const& multimodalInput, std::ostream& os)
@ -347,6 +349,7 @@ void Serialization::serialize(MultimodalInput const& multimodalInput, std::ostre
su::serialize(multimodalInput.mMultimodalHashes, os);
su::serialize(multimodalInput.mMultimodalPositions, os);
su::serialize(multimodalInput.mMultimodalLengths, os);
su::serialize(multimodalInput.mMultimodalUuids, os);
}
size_t Serialization::serializedSize(MultimodalInput const& multimodalInput)
@ -355,6 +358,7 @@ size_t Serialization::serializedSize(MultimodalInput const& multimodalInput)
totalSize += su::serializedSize(multimodalInput.mMultimodalHashes);
totalSize += su::serializedSize(multimodalInput.mMultimodalPositions);
totalSize += su::serializedSize(multimodalInput.mMultimodalLengths);
totalSize += su::serializedSize(multimodalInput.mMultimodalUuids);
return totalSize;
}
@ -2441,6 +2445,31 @@ KVCacheUpdatedData Serialization::deserializeKVCacheUpdatedData(std::istream& is
return KVCacheUpdatedData{blockHash, cacheLevel, priority};
}
// MmKey
size_t Serialization::serializedSize(MmKey const& key)
{
size_t totalSize = 0;
totalSize += su::serializedSize(key.hash);
totalSize += su::serializedSize(key.startOffset);
totalSize += su::serializedSize(key.uuid);
return totalSize;
}
void Serialization::serialize(MmKey const& key, std::ostream& os)
{
su::serialize(key.hash, os);
su::serialize(key.startOffset, os);
su::serialize(key.uuid, os);
}
MmKey Serialization::deserializeMmKey(std::istream& is)
{
auto hash = su::deserialize<std::array<uint8_t, 32>>(is);
auto startOffset = su::deserialize<SizeType32>(is);
auto uuid = su::deserialize<std::optional<std::string>>(is);
return MmKey{std::move(hash), startOffset, std::move(uuid)};
}
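A round-trip sketch of the new MmKey path (illustrative; uses the three methods above, with MmKey and Serialization assumed in scope):

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <optional>
#include <sstream>
#include <string>

void mmKeyRoundTrip()
{
    std::array<uint8_t, 32> h{};
    h[0] = 0xAB;
    MmKey original{h, 7, std::string("asset-42")};

    std::ostringstream os;
    Serialization::serialize(original, os);
    assert(os.str().size() == Serialization::serializedSize(original));

    std::istringstream is(os.str());
    MmKey restored = Serialization::deserializeMmKey(is);
    assert(restored == original);
}
```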
// UniqueToken
size_t Serialization::serializedSize(tensorrt_llm::runtime::UniqueToken const& token)
{

View File

@ -170,6 +170,7 @@ static_assert(hasSerializedSize<KVCacheRemovedData>(size_t()));
static_assert(hasSerializedSize<KVCacheEventDiff<SizeType32>>(size_t()));
static_assert(hasSerializedSize<KVCacheUpdatedData>(size_t()));
static_assert(hasSerializedSize<tensorrt_llm::runtime::UniqueToken>(size_t()));
static_assert(hasSerializedSize<MmKey>(size_t()));
template <typename T>
size_t serializedSize(T const& data)
@ -290,6 +291,7 @@ static_assert(hasSerialize<KVCacheRemovedData>(nullptr));
static_assert(hasSerialize<KVCacheEventDiff<SizeType32>>(nullptr));
static_assert(hasSerialize<KVCacheUpdatedData>(nullptr));
static_assert(hasSerialize<tensorrt_llm::runtime::UniqueToken>(nullptr));
static_assert(hasSerialize<MmKey>(nullptr));
template <typename T>
void serialize(T const& data, std::ostream& os)
@ -476,6 +478,10 @@ T deserialize(std::istream& is)
{
return Serialization::deserializeContextPhaseParams(is);
}
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::MmKey>)
{
return Serialization::deserializeMmKey(is);
}
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::Request>)
{
return Serialization::deserializeRequest(is);

View File

@ -289,6 +289,7 @@ void initBindings(nb::module_& m)
std::optional<std::vector<std::vector<tb::LlmRequest::SizeType32>>> multimodal_hashes,
std::optional<std::vector<tb::LlmRequest::SizeType32>> multimodal_positions,
std::optional<std::vector<tb::LlmRequest::SizeType32>> multimodal_lengths,
std::optional<std::vector<std::optional<std::string>>> multimodal_uuids,
std::optional<at::Tensor> multimodal_embedding, std::optional<at::Tensor> mrope_rotary_cos_sin,
std::optional<tb::LlmRequest::SizeType32> mrope_position_deltas,
std::optional<LoraTaskIdType> lora_task_id, std::optional<at::Tensor> lora_weights,
@ -344,7 +345,7 @@ void initBindings(nb::module_& m)
new (self) tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr, stop_words_list_tensor_ptr,
position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size, multimodal_hashes,
multimodal_positions, multimodal_lengths, multimodal_embedding_tensor_ptr,
multimodal_positions, multimodal_lengths, multimodal_uuids, multimodal_embedding_tensor_ptr,
mrope_rotary_cos_sin_tensor_ptr, mrope_position_deltas, lora_task_id, lora_weights_tensor_ptr,
lora_config_tensor_ptr, lookahead_config, kv_cache_retention_config, return_log_probs,
return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr,
@ -361,18 +362,19 @@ void initBindings(nb::module_& m)
nb::arg("stop_words_list") = std::nullopt, nb::arg("position_ids") = std::nullopt,
nb::arg("prompt_embedding_table") = std::nullopt, nb::arg("prompt_vocab_size") = std::nullopt,
nb::arg("multimodal_hashes") = std::nullopt, nb::arg("multimodal_positions") = std::nullopt,
nb::arg("multimodal_lengths") = std::nullopt, nb::arg("multimodal_embedding") = std::nullopt,
nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
nb::arg("lora_task_id") = std::nullopt, nb::arg("lora_weights") = std::nullopt,
nb::arg("lora_config") = std::nullopt, nb::arg("lookahead_config") = std::nullopt,
nb::arg("kv_cache_retention_config") = std::nullopt, nb::arg("return_log_probs") = false,
nb::arg("return_context_logits") = false, nb::arg("return_generation_logits") = false,
nb::arg("draft_tokens") = std::nullopt, nb::arg("draft_logits") = std::nullopt,
nb::arg("exclude_input_from_output") = false, nb::arg("logits_post_processor") = std::nullopt,
nb::arg("apply_logits_post_processor_batched") = false, nb::arg("encoder_input_tokens") = std::nullopt,
nb::arg("return_encoder_output") = false, nb::arg("client_id") = std::nullopt,
nb::arg("priority") = executor::Request::kDefaultPriority, nb::arg("encoder_input_features") = std::nullopt,
nb::arg("encoder_output_len") = std::nullopt, nb::arg("cross_attention_mask") = std::nullopt,
nb::arg("multimodal_lengths") = std::nullopt, nb::arg("multimodal_uuids") = std::nullopt,
nb::arg("multimodal_embedding") = std::nullopt, nb::arg("mrope_rotary_cos_sin") = std::nullopt,
nb::arg("mrope_position_deltas") = std::nullopt, nb::arg("lora_task_id") = std::nullopt,
nb::arg("lora_weights") = std::nullopt, nb::arg("lora_config") = std::nullopt,
nb::arg("lookahead_config") = std::nullopt, nb::arg("kv_cache_retention_config") = std::nullopt,
nb::arg("return_log_probs") = false, nb::arg("return_context_logits") = false,
nb::arg("return_generation_logits") = false, nb::arg("draft_tokens") = std::nullopt,
nb::arg("draft_logits") = std::nullopt, nb::arg("exclude_input_from_output") = false,
nb::arg("logits_post_processor") = std::nullopt, nb::arg("apply_logits_post_processor_batched") = false,
nb::arg("encoder_input_tokens") = std::nullopt, nb::arg("return_encoder_output") = false,
nb::arg("client_id") = std::nullopt, nb::arg("priority") = executor::Request::kDefaultPriority,
nb::arg("encoder_input_features") = std::nullopt, nb::arg("encoder_output_len") = std::nullopt,
nb::arg("cross_attention_mask") = std::nullopt,
nb::arg("llm_request_type") = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
nb::arg("input_token_extra_ids") = std::nullopt, nb::arg("num_return_sequences") = 1,
nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt,

View File

@ -93,6 +93,7 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
mMultimodalHashes, //
mMultimodalPositions, //
mMultimodalLengths, //
mMultimodalUuids, //
from_torch(mMultimodalEmbedding), //
from_torch(mMropeRotaryCosSin), //
mMropePositionDeltas, //

View File

@ -61,6 +61,7 @@ public:
std::optional<std::vector<std::vector<SizeType32>>> multimodalHashes = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalPositions = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalLengths = std::nullopt,
std::optional<std::vector<std::optional<std::string>>> multimodalUuids = std::nullopt,
std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
@ -111,6 +112,9 @@ public:
multimodalLengths.has_value()
? std::make_shared<std::vector<SizeType32>>(std::move(multimodalLengths.value())) //
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
multimodalUuids.has_value()
? std::make_shared<std::vector<std::optional<std::string>>>(std::move(multimodalUuids.value())) //
: std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>>(std::nullopt), //
multimodalEmbedding, //
mropeRotaryCosSin, //
mropePositionDeltas, //

View File

@ -225,15 +225,20 @@ void initBindings(nb::module_& m)
.def_prop_ro("mm_keys",
[](tle::KVCacheStoredBlockData const& self)
{
// Convert std::vector<MmKey> to Python list of tuples (bytes, int)
// MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
// Convert std::vector<MmKey> to Python list of tuples (bytes, int, optional<str>)
// MmKey = struct { hash, startOffset, uuid }
nb::list result;
for (auto const& mmKey : self.mmKeys)
{
auto const& hashArray = mmKey.first;
auto offset = mmKey.second;
nb::bytes hashBytes(reinterpret_cast<char const*>(hashArray.data()), hashArray.size());
result.append(nb::make_tuple(hashBytes, offset));
nb::bytes hashBytes(reinterpret_cast<char const*>(mmKey.hash.data()), mmKey.hash.size());
if (mmKey.uuid.has_value())
{
result.append(nb::make_tuple(hashBytes, mmKey.startOffset, mmKey.uuid.value()));
}
else
{
result.append(nb::make_tuple(hashBytes, mmKey.startOffset, nb::none()));
}
}
return result;
});

View File

@ -305,22 +305,29 @@ void initRequestBindings(nb::module_& m)
.def("__setstate__", loraConfigSetstate);
auto multimodalInputGetstate = [](tle::MultimodalInput const& self)
{ return nb::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths()); };
{
return nb::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths(),
self.getMultimodalUuids());
};
auto multimodalInputSetstate = [](tle::MultimodalInput& multimodalInput, nb::tuple const& state)
{
if (state.size() != 3)
if (state.size() != 4)
{
throw std::runtime_error("Invalid MultimodalInput state!");
}
new (&multimodalInput) tle::MultimodalInput(nb::cast<std::vector<std::vector<SizeType32>>>(state[0]),
nb::cast<std::vector<SizeType32>>(state[1]), nb::cast<std::vector<SizeType32>>(state[2]));
nb::cast<std::vector<SizeType32>>(state[1]), nb::cast<std::vector<SizeType32>>(state[2]),
nb::cast<std::optional<std::vector<std::optional<std::string>>>>(state[3]));
};
nb::class_<tle::MultimodalInput>(m, "MultimodalInput")
.def(nb::init<std::vector<std::vector<SizeType32>>, std::vector<SizeType32>, std::vector<SizeType32>>(),
nb::arg("multimodal_hashes"), nb::arg("multimodal_positions"), nb::arg("multimodal_lengths"))
.def(nb::init<std::vector<std::vector<SizeType32>>, std::vector<SizeType32>, std::vector<SizeType32>,
std::optional<std::vector<std::optional<std::string>>>>(),
nb::arg("multimodal_hashes"), nb::arg("multimodal_positions"), nb::arg("multimodal_lengths"),
nb::arg("multimodal_uuids") = nb::none())
.def_prop_ro("multimodal_hashes", &tle::MultimodalInput::getMultimodalHashes)
.def_prop_ro("multimodal_positions", &tle::MultimodalInput::getMultimodalPositions)
.def_prop_ro("multimodal_lengths", &tle::MultimodalInput::getMultimodalLengths)
.def_prop_ro("multimodal_uuids", &tle::MultimodalInput::getMultimodalUuids)
.def("__getstate__", multimodalInputGetstate)
.def("__setstate__", multimodalInputSetstate);
@ -703,7 +710,7 @@ void initRequestBindings(nb::module_& m)
nb::arg("guided_decoding_params") = nb::none(),
nb::arg("language_adapter_uid") = nb::none(),
nb::arg("allotted_time_ms") = nb::none(),
nb::arg("cache_salt_id") = nb::none(),
nb::arg("cache_salt_id") = nb::none(),
nb::arg("disagg_request_id") = nb::none()
) // clang-format on
.def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds)

View File

@ -968,10 +968,10 @@ TEST_F(TrtGptModelTest, PauseRequestStats)
auto llmRequest = std::make_shared<LlmRequest>(correlationId, maxNewTokens, tokens, inSamplingConfig, false,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, executor::Request::kDefaultPriority, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, 1, std::nullopt,
std::nullopt, true /* returnPerfMetrics */);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, executor::Request::kDefaultPriority, std::nullopt,
std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, 1,
std::nullopt, std::nullopt, true /* returnPerfMetrics */);
RequestList requestList{llmRequest};

View File

@ -1082,8 +1082,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds, numReturnSequences);
GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
@ -1119,8 +1119,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds, numReturnSequences);
GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
@ -1151,9 +1151,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
llmRequest0 = std::make_shared<LlmRequest>(seq0_dup.getRequestId(), maxNewTokens, inputTokens, samplingConfig,
isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds, numReturnSequences);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds, numReturnSequences);
promptLen0 = llmRequest0->getNumTokens(beamIdx);
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0_dup.getRequestId());
@ -1175,9 +1175,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
llmRequest1 = std::make_shared<LlmRequest>(seq1_dup.getRequestId(), maxNewTokens, inputTokens1, samplingConfig,
isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1, numReturnSequences);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1, numReturnSequences);
promptLen1 = llmRequest1->getNumTokens(beamIdx);
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1_dup.getRequestId());
@ -1207,8 +1207,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds2, numReturnSequences);
numTokens = llmRequest2->getNumTokens(beamIdx);
@ -1235,8 +1235,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds3, numReturnSequences);
numTokens = llmRequest3->getNumTokens(beamIdx);
@ -1312,9 +1312,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false,
std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt,
std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt,
numReturnSequences);
GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
@ -1354,9 +1355,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false,
std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt,
std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt,
numReturnSequences);
GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// should reuse blocks 0, 1 and get new block 3
@ -1391,9 +1393,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
multimodalHashes2, multimodalPositions2, multimodalLengths2, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false,
std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt,
std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt,
numReturnSequences);
GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// no reuse, get new blocks 4, 5, 6
@ -1427,9 +1430,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
multimodalHashes3, multimodalPositions3, multimodalLengths3, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false,
std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt,
std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt,
numReturnSequences);
GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// reuse block 0, get new blocks 7, 8
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
@ -1498,7 +1502,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
LlmRequest::RequestIdType requestId{0};
auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
///////////////////////////////////////////////////////////////////////////
@ -1533,7 +1537,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
requestId = 1;
auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// reuse blocks 0, 1 and get new block 3
@ -1563,7 +1567,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
GenerationRequest seq0_dup{10, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
llmRequest0 = std::make_shared<LlmRequest>(seq0_dup.getRequestId(), maxNewTokens, inputTokens, samplingConfig,
isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
loraTaskId);
promptLen0 = llmRequest0->getNumTokens(beamIdx);
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0_dup.getRequestId());
@ -1586,7 +1591,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
GenerationRequest seq1_dup{11, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
llmRequest1 = std::make_shared<LlmRequest>(seq1_dup.getRequestId(), maxNewTokens, inputTokens1, samplingConfig,
isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
loraTaskId);
promptLen1 = llmRequest1->getNumTokens(beamIdx);
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
// reuse 0, 1, 2(p) ([0,1,2,3], [4,5,6,7], [8])
@ -1617,7 +1623,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
requestId = 2;
auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
numTokens = llmRequest2->getNumTokens(beamIdx);
GenerationRequest seq2{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()};
@ -1648,7 +1654,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
requestId = 3;
auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens3, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
numTokens = llmRequest3->getNumTokens(beamIdx);
GenerationRequest seq3{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()};
@ -1680,7 +1686,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
requestId = 4;
auto llmRequest4 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens4, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId);
numTokens = llmRequest4->getNumTokens(beamIdx);
GenerationRequest seq4{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()};
@ -1775,9 +1781,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
LlmRequest::RequestIdType requestId{0};
auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds);
GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
@ -1813,9 +1819,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
LlmRequest::LoraTaskIdType loraTaskId2 = static_cast<LlmRequest::LoraTaskIdType>(2);
auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds);
GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
@ -1844,10 +1850,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
GenerationRequest seq0_dup{10, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
llmRequest0 = std::make_shared<LlmRequest>(seq0_dup.getRequestId(), maxNewTokens, inputTokens, samplingConfig,
isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
loraTaskId1, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds);
promptLen0 = llmRequest0->getNumTokens(beamIdx);
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
// reuse blocks 0, 1 and get new block 6
@ -1869,10 +1875,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
GenerationRequest seq1_dup{11, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
llmRequest1 = std::make_shared<LlmRequest>(seq1_dup.getRequestId(), maxNewTokens, inputTokens1, samplingConfig,
isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1);
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
loraTaskId2, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1);
promptLen1 = llmRequest1->getNumTokens(beamIdx);
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1_dup.getRequestId());
@ -1900,9 +1906,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
requestId = 2;
auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds2);
numTokens = llmRequest2->getNumTokens(beamIdx);
@ -1928,9 +1934,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
requestId = 3;
auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds3);
numTokens = llmRequest3->getNumTokens(beamIdx);
@ -1955,9 +1961,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
requestId = 4;
auto llmRequest4 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds3);
numTokens = llmRequest4->getNumTokens(beamIdx);
@ -2034,8 +2040,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt); // No cache_salt_id
@ -2075,8 +2081,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
cacheSaltId1); // With cache_salt_id = 12345
@ -2110,8 +2116,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
cacheSaltId1); // Same cache_salt_id = 12345
@ -2146,8 +2152,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
cacheSaltId2); // Different cache_salt_id = 67890
@ -2175,8 +2181,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto llmRequest4 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt,
false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false,
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt,
LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt,
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::nullopt); // No cache_salt_id

View File

@ -64,7 +64,7 @@ protected:
/*badWordsList=*/std::nullopt, /*stopWordsList=*/std::nullopt, /*positionIds=*/std::nullopt,
/*promptEmbeddingTable=*/std::nullopt, /*promptVocabSize=*/std::nullopt,
/*multimodalHashes=*/std::nullopt, /*multimodalPos=*/std::nullopt, /*multimodalLength=*/std::nullopt,
/*multimodalEmbedding=*/std::nullopt,
/*multimodalUuids=*/std::nullopt, /*multimodalEmbedding=*/std::nullopt,
/*mropeRotaryCosSin=*/std::nullopt, /*mropePositionDeltas*/ std::nullopt,
/*loraTaskId=*/std::nullopt, /*loraWeights=*/std::nullopt,
/*loraConfig=*/std::nullopt, /*lookaheadConfig=*/std::nullopt, /*kvCacheRetentionConfig=*/std::nullopt,

View File

@ -1092,7 +1092,7 @@ TEST(SerializeUtilsTest, BlockKeyWithExtras)
h1[i] = static_cast<uint8_t>(i);
h2[i] = static_cast<uint8_t>(255 - i);
}
std::vector<MmKey> extraKeys{{h1, SizeType32{0}}, {h2, SizeType32{5}}};
std::vector<MmKey> extraKeys{{h1, SizeType32{0}, std::nullopt}, {h2, SizeType32{5}, std::nullopt}};
VecUniqueTokens uniqueTokens{UniqueToken{10, 100}, UniqueToken{20, 200}};
std::optional<LoraTaskIdType> loraTaskId = LoraTaskIdType{42};
@ -1103,6 +1103,144 @@ TEST(SerializeUtilsTest, BlockKeyWithExtras)
testSerializeDeserialize(key);
}
TEST(SerializeUtilsTest, MmKeyWithUuid)
{
using tensorrt_llm::executor::MmKey;
// Test MmKey serialization with UUID
std::array<uint8_t, 32> hash{};
for (size_t i = 0; i < hash.size(); ++i)
{
hash[i] = static_cast<uint8_t>(i * 3);
}
// Test with UUID
MmKey keyWithUuid{hash, SizeType32{42}, std::string("test-image-uuid-12345")};
testSerializeDeserialize(keyWithUuid);
// Test without UUID (nullopt)
MmKey keyNoUuid{hash, SizeType32{100}, std::nullopt};
testSerializeDeserialize(keyNoUuid);
// Test with empty string UUID
MmKey keyEmptyUuid{hash, SizeType32{0}, std::string("")};
testSerializeDeserialize(keyEmptyUuid);
// Test with long UUID (> 32 bytes)
MmKey keyLongUuid{hash, SizeType32{255}, std::string("this-is-a-very-long-uuid-that-exceeds-32-bytes-for-testing")};
testSerializeDeserialize(keyLongUuid);
}
TEST(SerializeUtilsTest, BlockKeyWithExtrasAndUuids)
{
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
// Prepare multimodal extra keys with mixed UUIDs
std::array<uint8_t, 32> h1{};
std::array<uint8_t, 32> h2{};
std::array<uint8_t, 32> h3{};
for (size_t i = 0; i < h1.size(); ++i)
{
h1[i] = static_cast<uint8_t>(i);
h2[i] = static_cast<uint8_t>(255 - i);
h3[i] = static_cast<uint8_t>(i * 2);
}
// Mix of UUIDs: one with UUID, one without, one with empty string
std::vector<MmKey> extraKeys{{h1, SizeType32{0}, std::string("sku-image-001")}, {h2, SizeType32{5}, std::nullopt},
{h3, SizeType32{10}, std::string("")}};
VecUniqueTokens uniqueTokens{UniqueToken{10, 100}, UniqueToken{20, 200}};
std::optional<LoraTaskIdType> loraTaskId = LoraTaskIdType{42};
BlockKey key(true, loraTaskId, uniqueTokens, extraKeys);
testSerializeDeserialize(key);
}
TEST(SerializeUtilsTest, MultimodalInputWithUuids)
{
using tensorrt_llm::executor::MultimodalInput;
// Helper to verify MultimodalInput serialization round-trip
auto verifyMultimodalInput = [](MultimodalInput const& original)
{
auto size = tensorrt_llm::executor::Serialization::serializedSize(original);
std::ostringstream oss;
tensorrt_llm::executor::Serialization::serialize(original, oss);
EXPECT_EQ(oss.str().size(), size);
std::istringstream iss(oss.str());
auto deserialized = tensorrt_llm::executor::Serialization::deserializeMultimodalInput(iss);
// Verify hashes
auto origHashes = original.getMultimodalHashes();
auto deserHashes = deserialized.getMultimodalHashes();
EXPECT_EQ(origHashes.size(), deserHashes.size());
for (size_t i = 0; i < origHashes.size(); ++i)
{
EXPECT_EQ(origHashes[i], deserHashes[i]);
}
// Verify positions
auto origPositions = original.getMultimodalPositions();
auto deserPositions = deserialized.getMultimodalPositions();
EXPECT_EQ(origPositions, deserPositions);
// Verify lengths
auto origLengths = original.getMultimodalLengths();
auto deserLengths = deserialized.getMultimodalLengths();
EXPECT_EQ(origLengths, deserLengths);
// Verify UUIDs
auto origUuids = original.getMultimodalUuids();
auto deserUuids = deserialized.getMultimodalUuids();
EXPECT_EQ(origUuids.has_value(), deserUuids.has_value());
if (origUuids.has_value() && deserUuids.has_value())
{
EXPECT_EQ(origUuids->size(), deserUuids->size());
for (size_t i = 0; i < origUuids->size(); ++i)
{
EXPECT_EQ((*origUuids)[i], (*deserUuids)[i]);
}
}
};
// Test MultimodalInput with UUIDs
std::vector<std::vector<SizeType32>> hashes = {
{1, 2, 3, 4, 5, 6, 7, 8}, // First image hash
{10, 20, 30, 40, 50, 60, 70, 80} // Second image hash
};
std::vector<SizeType32> positions = {0, 100};
std::vector<SizeType32> lengths = {50, 75};
// Test with full UUIDs
std::vector<std::optional<std::string>> uuids = {std::string("image-uuid-001"), std::string("image-uuid-002")};
MultimodalInput inputWithUuids(hashes, positions, lengths, uuids);
verifyMultimodalInput(inputWithUuids);
// Test with partial UUIDs (mixed Some and None)
std::vector<std::optional<std::string>> partialUuids = {std::string("uuid-a"), std::nullopt};
MultimodalInput inputPartialUuids(hashes, positions, lengths, partialUuids);
verifyMultimodalInput(inputPartialUuids);
// Test without UUIDs (nullopt)
MultimodalInput inputNoUuids(hashes, positions, lengths, std::nullopt);
verifyMultimodalInput(inputNoUuids);
// Test with empty string UUID
std::vector<std::optional<std::string>> emptyUuids = {std::string(""), std::string("valid-uuid")};
MultimodalInput inputEmptyUuid(hashes, positions, lengths, emptyUuids);
verifyMultimodalInput(inputEmptyUuid);
// Test with long UUIDs (> 32 bytes)
std::vector<std::optional<std::string>> longUuids
= {std::string("this-is-a-very-long-uuid-string-that-exceeds-the-32-byte-limit-for-testing-purposes"),
std::string("short")};
MultimodalInput inputLongUuids(hashes, positions, lengths, longUuids);
verifyMultimodalInput(inputLongUuids);
}
// Connection notification tests
namespace kv_cache = tensorrt_llm::executor::kv_cache;

View File

@ -64,6 +64,39 @@ KV cache salting provides a security mechanism to control which requests can reu
To use cache salting, specify the `cache_salt` parameter as a string when creating requests. Only requests with matching cache salt values can share cached KV blocks. The salt value can be any non-empty string, such as a user ID, tenant ID, or hash string.
### Multimodal UUID Support for Cache Identification
When working with multimodal models (e.g., vision-language models), the KV cache system needs to identify which cached blocks correspond to which multimodal inputs (images, videos, etc.). By default, the system uses content-based hashing to generate unique identifiers for each multimodal input. However, this approach has limitations for cache management across sessions, as the same content must be re-processed to generate the same hash.
To enable deterministic cache management, you can provide custom UUID strings for your multimodal data using the `multi_modal_uuids` parameter when creating requests. When provided, these UUIDs are returned in KV cache events instead of computed content hashes, while the cache key itself is computed from **both** the UUID and content together for correctness.
**Usage Example:**
```python
from tensorrt_llm.inputs import TextPrompt

# Provide custom UUIDs for your images
prompt = TextPrompt(
    prompt="Describe these images.",
    multi_modal_data={"image": [image1, image2]},
    multi_modal_uuids={"image": ["image-uuid-001", "image-uuid-002"]}
)
```
**Key Features:**
- **Cache Correctness**: When a UUID is provided, the cache key is computed from both the UUID and content together using `BLAKE3(UUID || Content)`. This ensures different content always produces different cache entries, even with the same UUID.
- **User Isolation**: Same content with different UUIDs produces different cache entries, enabling per-user or per-session cache isolation.
- **Stable Event Identifiers**: The original UUID string is preserved and returned in KV cache events via `get_kv_cache_events()`, enabling deterministic external cache management (see the sketch below).
- **Partial UUID Support**: You can provide UUIDs for some items and use `None` for others to fall back to content-only hashing.
- **Cross-Modality Support**: Different modalities (images, videos) can each have their own UUIDs.
**UUID Format:**
- Can be any string (e.g., "image-123", "user-session-img-a", database keys)
- Original UUID strings are preserved and returned in KV cache events
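**Reading UUIDs from Cache Events:**

A minimal sketch of consuming the UUIDs downstream (it reuses `prompt` from the example above; the model path is a placeholder, and the polling pattern mirrors this change's tests):
```python
import time

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(model="<your-vlm-model>",
          kv_cache_config=KvCacheConfig(enable_block_reuse=True,
                                        event_buffer_max_size=1024))
llm.generate([prompt], sampling_params=SamplingParams(max_tokens=16))

time.sleep(0.5)  # KV cache events are dispatched asynchronously
for event in llm.get_kv_cache_events(50):
    if event and event.get("data", {}).get("type") == "stored":
        for block in event["data"].get("blocks", []):
            for mm_key in block.get("mm_keys", []):
                # "image-uuid-001" / "image-uuid-002" when UUIDs were given,
                # otherwise a 64-character hex content hash
                print(mm_key["hash"], mm_key["start_offset"])
```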
### Enable Offloading to Host Memory
Before a block is evicted from GPU memory, it can optionally be offloaded to host (CPU) memory. The block remains reusable until it is evicted from host memory. When an offloaded block is reused, it is first copied back into GPU memory. Offloading is controlled with property ```host_cache_size``` which specifies how much host memory (in bytes) should be allocated for offloading. The default is 0.

View File

@ -180,7 +180,8 @@ class KvCacheCreator:
req_mm_input = trtllm.MultimodalInput(
multimodal_hashes=multimodal_input.multimodal_hashes,
multimodal_positions=multimodal_input.multimodal_positions,
multimodal_lengths=multimodal_input.multimodal_lengths
multimodal_lengths=multimodal_input.multimodal_lengths,
multimodal_uuids=multimodal_input.multimodal_uuids
) if multimodal_input else None
request = trtllm.Request(prompt_token_ids,

View File

@ -846,10 +846,12 @@ def executor_request_to_llm_request(
multimodal_hashes = None
multimodal_positions = None
multimodal_lengths = None
multimodal_uuids = None
if executor_request.multimodal_input is not None:
multimodal_hashes = executor_request.multimodal_input.multimodal_hashes
multimodal_positions = executor_request.multimodal_input.multimodal_positions
multimodal_lengths = executor_request.multimodal_input.multimodal_lengths
multimodal_uuids = executor_request.multimodal_input.multimodal_uuids
# Extract mrope fields
mrope_rotary_cos_sin = None
@ -879,6 +881,7 @@ def executor_request_to_llm_request(
multimodal_hashes=multimodal_hashes,
multimodal_positions=multimodal_positions,
multimodal_lengths=multimodal_lengths,
multimodal_uuids=multimodal_uuids,
multimodal_embedding=executor_request.multimodal_embedding,
lora_task_id=executor_request.lora_config.task_id
if executor_request.lora_config is not None else None,

View File

@ -1205,14 +1205,24 @@ class KVCacheEventSerializer:
@staticmethod
def _mm_key_to_json(data):
# MmKey is a pair of (array<uint8_t, 32>, SizeType32)
hash_array, start_offset = data
# MmKey is a tuple of (hash_bytes, start_offset, uuid)
# where uuid is optional (None if content-hashed)
if len(data) == 3:
hash_array, start_offset, uuid = data
else:
# Backward compatibility: old format (hash_array, start_offset)
hash_array, start_offset = data
uuid = None
# Convert array to hex string
hash_hex = ''.join(f'{b:02x}' for b in hash_array)
# Use UUID from C++ if available, otherwise use hash_hex
hash_or_uuid = uuid if uuid is not None else hash_hex
return {
"type": "mm_key",
"hash": hash_hex,
"hash": hash_or_uuid,
"start_offset": start_offset
}
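The dispatch above is small enough to sanity-check standalone. A dependency-free replica (not the library class) that makes the with/without-UUID behavior concrete:
```python
# Sketch of the serializer logic above; tuples mirror what pybind11 produces.
def mm_key_to_json(data):
    if len(data) == 3:
        hash_array, start_offset, uuid = data
    else:  # backward compatibility: old (hash, offset) format
        hash_array, start_offset = data
        uuid = None
    hash_hex = ''.join(f'{b:02x}' for b in hash_array)
    return {"type": "mm_key",
            "hash": uuid if uuid is not None else hash_hex,
            "start_offset": start_offset}

assert mm_key_to_json((bytes(range(32)), 0, None))["hash"].startswith("0001")
assert mm_key_to_json((bytes(range(32)), 0, "sku-001"))["hash"] == "sku-001"
# Old 2-tuples serialize identically to 3-tuples with a None UUID
assert mm_key_to_json((bytes(range(32)), 7))["hash"] == \
       mm_key_to_json((bytes(range(32)), 7, None))["hash"]
```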

View File

@ -427,7 +427,9 @@ class BaseWorker(GenerationExecutor):
multimodal_positions=request.multimodal_params.
multimodal_input.multimodal_positions,
multimodal_lengths=request.multimodal_params.
multimodal_input.multimodal_lengths)
multimodal_input.multimodal_lengths,
multimodal_uuids=request.multimodal_params.multimodal_input.
multimodal_uuids)
# NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field
request.multimodal_params.multimodal_input = None

View File

@ -17,6 +17,15 @@ class TextPrompt(TypedDict):
if the model supports it.
"""
multi_modal_uuids: NotRequired[Dict[str, List[Any]]]
"""
Optional user-provided UUIDs for multimodal items.
Structure mirrors multi_modal_data: {"image": ["uuid1", None, "uuid3"]}.
When a UUID is provided for an item, it will be returned in KV cache events
instead of the computed content hash. Use None to fall back to content
hashing for specific items.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
@ -39,6 +48,15 @@ class TokensPrompt(TypedDict):
if the model supports it.
"""
multi_modal_uuids: NotRequired[Dict[str, List[Any]]]
"""
Optional user-provided UUIDs for multimodal items.
Structure mirrors multi_modal_data: {"image": ["uuid1", None, "uuid3"]}.
When a UUID is provided for an item, it will be returned in KV cache events
instead of the computed content hash. Use None to fall back to content
hashing for specific items.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the

View File

@ -38,6 +38,22 @@ class MultimodalInput:
(e.g., image_end_token, image_break_token for mistral3) mixed with the actual multimodal tokens.
"""
multimodal_uuids: Optional[List[Optional[str]]] = None
"""Optional user-provided UUIDs for multimodal data items.
When provided, these UUIDs will be returned in KV cache events instead of the
computed hash hex string. This enables deterministic cache identification across
sessions using user-defined stable identifiers.
Each element can be:
- A string UUID: Used as the cache identifier (returned in events)
- None: Falls back to content-based hashing for that item
If the UUID string is longer than 32 bytes, it will be hashed internally
for cache key computation, but the original UUID string is preserved and
returned in KV cache events.
"""
def __post_init__(self):
"""Validate input data structure and consistency."""
# Validate multimodal_hashes
@ -69,13 +85,32 @@ class MultimodalInput:
f"positions={len(self.multimodal_positions)}, lengths={len(self.multimodal_lengths)}"
)
# Validate multimodal_uuids if provided
if self.multimodal_uuids is not None:
if not isinstance(self.multimodal_uuids, list):
raise TypeError("multimodal_uuids must be a list")
if len(self.multimodal_uuids) != len(self.multimodal_hashes):
raise ValueError(
f"multimodal_uuids length ({len(self.multimodal_uuids)}) must match "
f"multimodal_hashes length ({len(self.multimodal_hashes)})")
for i, uuid in enumerate(self.multimodal_uuids):
if uuid is not None and not isinstance(uuid, str):
raise TypeError(
f"multimodal_uuids[{i}] must be a string or None, got {type(uuid)}"
)
@classmethod
def from_components(cls, mm_hashes: List[List[int]],
mm_positions: List[int],
mm_lengths: List[int]) -> 'MultimodalInput':
def from_components(
cls,
mm_hashes: List[List[int]],
mm_positions: List[int],
mm_lengths: List[int],
mm_uuids: Optional[List[Optional[str]]] = None
) -> 'MultimodalInput':
return cls(multimodal_hashes=mm_hashes,
multimodal_positions=mm_positions,
multimodal_lengths=mm_lengths)
multimodal_lengths=mm_lengths,
multimodal_uuids=mm_uuids)
def to_tensor(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert data to tensors"""
@ -546,46 +581,116 @@ def serialize_item(obj: object) -> bytes:
raise ValueError(f"Unsupported object type: {type(obj)}")
def apply_mm_hashes(mm_data: Dict[str, Any],
hash_lib=default_hasher) -> Dict[str, List[str]]:
"""Apply hashing to multimodal data items."""
def apply_mm_hashes(
mm_data: Dict[str, Any],
mm_uuids: Optional[Dict[str, List[Optional[str]]]] = None,
hash_lib=default_hasher
) -> Tuple[Dict[str, List[str]], Optional[List[Optional[str]]]]:
"""Apply hashing to multimodal data items, combining UUID with content when provided.
def _hash_image(image):
# TODO: possible hash collision w/ this simplified version (vllm/PR/17378)
hasher = hash_lib()
if isinstance(image, torch.Tensor):
When a UUID is provided for an item, the hash is computed from both the UUID
and the content together: BLAKE3(UUID || Content). This ensures:
- Cache correctness: Different content always produces different hashes
- User isolation: Same content with different UUIDs produces different hashes
- The original UUID string is preserved and returned in KV cache events
Args:
mm_data: Dictionary of modality -> data items
mm_uuids: Optional dictionary of modality -> list of UUID strings.
Use None for items that should use content-based hashing only.
hash_lib: Hash function to use (default: blake3)
Returns:
Tuple of:
- Dictionary of modality -> list of hash hex strings (64 chars each)
- Flattened list of original UUID strings (or None for content-hashed items)
"""
def _hash_content(hasher, item):
"""Hash the content of a multimodal item into the provided hasher."""
if isinstance(item, torch.Tensor):
# Ensure tensor is on CPU and contiguous for consistent hashing
image = image.detach().cpu().contiguous()
hasher.update(serialize_item(image))
elif isinstance(image, list):
item = item.detach().cpu().contiguous()
hasher.update(serialize_item(item))
elif isinstance(item, list):
# Hash each frame with a separator to avoid collisions between [A,B] and [AB]
for frame in image:
for frame in item:
hasher.update(b"<frame>")
if isinstance(frame, torch.Tensor):
frame = frame.detach().cpu().contiguous()
hasher.update(serialize_item(frame))
elif isinstance(image, tensorrt_llm.inputs.utils.VideoData):
frames = image.frames
elif isinstance(item, tensorrt_llm.inputs.utils.VideoData):
frames = item.frames
for frame in frames:
hasher.update(b"<frame>")
if isinstance(frame, torch.Tensor):
frame = frame.detach().cpu().contiguous()
hasher.update(serialize_item(frame))
else:
hasher.update(serialize_item(image))
hasher.update(serialize_item(item))
def _hash_item(item):
"""Hash only the content of a multimodal item (no UUID)."""
# TODO: possible hash collision w/ this simplified version (vllm/PR/17378)
hasher = hash_lib()
_hash_content(hasher, item)
return hasher.hexdigest()
def _hash_item_with_uuid(item, uuid: str):
"""Hash UUID and content together: BLAKE3(UUID || Content).
This creates a unique hash that incorporates both the user-provided
identifier and the actual content, ensuring cache correctness while
supporting user-defined cache isolation.
"""
hasher = hash_lib()
# Hash UUID first with delimiters to prevent length-extension ambiguity
hasher.update(b"<uuid>")
hasher.update(uuid.encode('utf-8'))
hasher.update(b"</uuid>")
# Then hash the content
hasher.update(b"<content>")
_hash_content(hasher, item)
hasher.update(b"</content>")
return hasher.hexdigest()
mm_items = {
modality: items if isinstance(items, list) else [items]
for modality, items in mm_data.items()
}
# TODO: need to hash both modality and item to distinguish modality (vllm/PR)
mm_hashes = {
modality: [_hash_image(item) for item in items]
for modality, items in mm_items.items()
}
return mm_hashes
# Collect UUIDs in the same order as items
all_uuids: List[Optional[str]] = []
mm_hashes: Dict[str, List[str]] = {}
for modality, items in mm_items.items():
modality_uuids = None
if mm_uuids is not None and modality in mm_uuids:
modality_uuids = mm_uuids[modality]
if not isinstance(modality_uuids, list):
modality_uuids = [modality_uuids]
if len(modality_uuids) != len(items):
raise ValueError(
f"UUID list length ({len(modality_uuids)}) doesn't match "
f"data items length ({len(items)}) for modality '{modality}'"
)
hashes = []
for i, item in enumerate(items):
uuid = modality_uuids[i] if modality_uuids else None
if uuid is not None:
# Hash UUID + content together for cache correctness
hashes.append(_hash_item_with_uuid(item, uuid))
all_uuids.append(uuid) # Store original UUID
else:
# Fall back to content-only hashing
hashes.append(_hash_item(item))
all_uuids.append(None)
mm_hashes[modality] = hashes
# Return None for uuids if no UUIDs were provided at all
return mm_hashes, all_uuids if mm_uuids is not None else None
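The delimiter scheme in `_hash_item_with_uuid` is what keeps the combined key unambiguous. A dependency-free sketch of the same construction (sha256 stands in for the blake3 `hash_lib` used above):
```python
import hashlib


def combined_digest(uuid: str, content: bytes) -> str:
    # Tagged fields prevent boundary ambiguity: ("ab", b"c") and
    # ("a", b"bc") hash differently because the </uuid>/<content>
    # delimiters pin where the UUID ends and the content begins.
    h = hashlib.sha256()
    h.update(b"<uuid>" + uuid.encode("utf-8") + b"</uuid>")
    h.update(b"<content>" + content + b"</content>")
    return h.hexdigest()


# Same UUID + same content -> same digest (a cache hit is safe);
# changing either the UUID or the content changes the digest.
assert combined_digest("user-1", b"img") == combined_digest("user-1", b"img")
assert combined_digest("user-2", b"img") != combined_digest("user-1", b"img")
assert combined_digest("user-1", b"new") != combined_digest("user-1", b"img")
```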
def hexdigest_to_int32(hex_digest: str) -> List[int]:
@ -604,6 +709,30 @@ def hexdigest_to_int32(hex_digest: str) -> List[int]:
return result
def int32_to_hexdigest(int32_values: List[int]) -> str:
"""Convert 8 int32 values back to a 64-character hexadecimal digest.
This is the inverse of hexdigest_to_int32.
Args:
int32_values: List of 8 signed int32 values
Returns:
64-character hexadecimal string representing the 32-byte hash
"""
if len(int32_values) != 8:
raise ValueError(f"Expected 8 int32 values, got {len(int32_values)}")
result = []
for value in int32_values:
# Convert signed int32 back to unsigned
if value < 0:
value = value + 0x100000000
# Format as 8 hex characters (zero-padded)
result.append(f'{value:08x}')
return ''.join(result)
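A quick sanity check of the inverse pair (the import path matches the tests later in this diff):
```python
from tensorrt_llm.inputs.multimodal import (hexdigest_to_int32,
                                            int32_to_hexdigest)

digest = "deadbeef" * 8             # 64 hex chars = a 32-byte hash
ints = hexdigest_to_int32(digest)   # 8 signed int32 values
assert len(ints) == 8
assert int32_to_hexdigest(ints) == digest  # lossless roundtrip
```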
def find_mm_token_lengths(mm_data: Dict[str, Any],
input_processor: Any) -> List[int]:
"""Get the maximum contiguous multimodal token lengths from multimodal data items.

View File

@ -664,10 +664,18 @@ def create_input_processor_with_hash(
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
"""
Process the multimodal hashing for media tokens if possible.
Supports optional user-provided UUIDs via 'multi_modal_uuids' in inputs.
When a UUID is provided for a multimodal item, it will be used as the
cache identifier and returned in KV cache events instead of the content hash.
"""
assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support."
mm_data = inputs['multi_modal_data']
mm_hashes = apply_mm_hashes(mm_data, hash_lib)
# Extract optional UUIDs (can be None, or dict with same structure as mm_data)
mm_uuids = inputs.get('multi_modal_uuids', None)
mm_hashes, mm_uuid_list = apply_mm_hashes(mm_data, mm_uuids, hash_lib)
prompt_token_ids, extra_processed_inputs = input_processor(
inputs, sampling_params)
@ -698,15 +706,15 @@ def create_input_processor_with_hash(
extra_processed_inputs["multimodal_data"][
"special_token_offsets"] = start_special_token_positions
# flatten the hashes from dict to a single list
mm_hashes = [h for hashes in mm_hashes.values() for h in hashes]
validate_mm_inputs(prompt_token_ids, mm_hashes, start_positions,
mm_hashes_flat = [h for hashes in mm_hashes.values() for h in hashes]
validate_mm_inputs(prompt_token_ids, mm_hashes_flat, start_positions,
num_mm_tokens)
mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes
mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes_flat
] # nested list w/ multiple int32 per hash
extra_processed_inputs[
"multimodal_input"] = MultimodalInput.from_components(
mm_hashes_int32, start_positions, num_mm_tokens)
mm_hashes_int32, start_positions, num_mm_tokens, mm_uuid_list)
return prompt_token_ids, extra_processed_inputs
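End to end, the UUID path through this helper reduces to the following hedged sketch (the positions and lengths are illustrative values; in the real path they come from the input processor):
```python
import torch
from tensorrt_llm.inputs.multimodal import (MultimodalInput, apply_mm_hashes,
                                            hexdigest_to_int32)

mm_data = {"image": [torch.zeros(3, 224, 224)]}
mm_uuids = {"image": ["image-uuid-001"]}

# Combined UUID+content hashes, plus the preserved original UUID strings
mm_hashes, mm_uuid_list = apply_mm_hashes(mm_data, mm_uuids)
flat = [h for hashes in mm_hashes.values() for h in hashes]
mm_hashes_int32 = [hexdigest_to_int32(h) for h in flat]

# Illustrative positions/lengths (one 576-token image at offset 0)
mm_input = MultimodalInput.from_components(mm_hashes_int32, [0], [576],
                                           mm_uuid_list)
```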
def input_processor_wrapper(

View File

@ -181,6 +181,361 @@ def model_dir(request, tmp_path_factory: pytest.TempPathFactory) -> Path:
return request.param
@pytest.mark.parametrize(
"use_uuids,expected_hash_type",
[
# Without UUIDs: mm_key hash should be a 64-char hex string
(False, "hex"),
# With UUIDs: mm_key hash should be the original UUID string
(True, "uuid"),
])
def test_kv_event_mm_keys_with_uuid(use_uuids, expected_hash_type):
"""Test mm_keys in KV cache events return UUID when provided.
This test verifies that when multi_modal_uuids is provided:
1. The KV cache event mm_keys 'hash' field contains the original UUID string
2. Without UUIDs, the hash field contains a 64-char hex string
The UUID feature allows users to provide stable identifiers for multimodal
items, which are returned in KV cache events for external cache management.
"""
encoder_model_dir = _LLAVA_DIR
max_tokens = 16
free_gpu_memory_fraction = 0.2
# Use different images to generate different prompts
prompts = ["Describe the natural environment in the image."]
media = [example_images[0]]
# Define UUIDs if testing with them
test_uuid = "my-test-image-uuid-12345"
sampling_params = SamplingParams(max_tokens=max_tokens)
kv_cache_config = KvCacheConfig(
enable_block_reuse=True,
free_gpu_memory_fraction=free_gpu_memory_fraction,
event_buffer_max_size=1024,
)
llm = LLM(model=encoder_model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
max_batch_size=1)
# Load inputs with or without UUIDs
if use_uuids:
# Create inputs with multi_modal_uuids
inputs = _load_inputs_with_uuids(llm, prompts, media, [test_uuid])
else:
inputs = _load_inputs(llm, prompts, media)
with llm:
for inp in inputs:
_ = llm.generate([inp], sampling_params=sampling_params)
# Wait for KV cache events to be dispatched asynchronously
time.sleep(0.5)
events = llm.get_kv_cache_events(50)
# Extract mm_keys from stored events
mm_keys_found = []
for event in events:
if event and event.get("data", {}).get("type") == "stored":
for block in event["data"].get("blocks", []):
mm_keys_found.extend(block.get("mm_keys", []))
# Verify mm_keys were found (multimodal model should have them)
assert len(mm_keys_found) > 0, "Expected mm_keys in stored events"
# Verify the hash field matches expected type
for mm_key in mm_keys_found:
hash_value = mm_key["hash"]
if expected_hash_type == "uuid":
# Should be the original UUID string
assert hash_value == test_uuid, (
f"Expected UUID '{test_uuid}', got '{hash_value}'")
else:
# Should be a 64-char hex string
assert len(hash_value) == 64, (
f"Expected 64-char hex hash, got {len(hash_value)} chars")
# Verify it's valid hex (fromhex will raise ValueError if invalid)
bytes.fromhex(hash_value)
def _load_inputs_with_uuids(llm: LLM, prompts, media, uuids):
"""Load inputs with multi_modal_uuids for testing.
This function uses the same processing pipeline as _load_inputs but adds
multi_modal_uuids to the processed inputs.
"""
# Use the standard loader to get properly processed inputs with image tokens
inputs = _load_inputs(llm, prompts, media)
# Add multi_modal_uuids to the processed inputs
for inp, uuid in zip(inputs, uuids):
inp["multi_modal_uuids"] = {"image": [uuid]}
return inputs
@pytest.mark.parametrize(
"uuids,expected_patterns",
[
# First image has UUID, second uses content hash
(["custom-uuid-first", None], ["custom-uuid-first", "hex"]),
# Both have UUIDs
(["uuid-img-a", "uuid-img-b"], ["uuid-img-a", "uuid-img-b"]),
# Both use content hash (None)
([None, None], ["hex", "hex"]),
])
def test_kv_event_mm_keys_with_partial_uuids(uuids, expected_patterns):
"""Test mm_keys with partial UUIDs (some items with UUID, some without).
This test verifies the mixed UUID scenario where:
1. Some multimodal items have user-provided UUIDs
2. Other items fall back to content-based hashing
3. KV cache events correctly return UUID or hex hash based on input
"""
encoder_model_dir = _LLAVA_DIR
max_tokens = 16
free_gpu_memory_fraction = 0.2
# Two different images with potentially mixed UUIDs
prompt = "Describe both images in detail."
images = [example_images[0], example_images[1]]
sampling_params = SamplingParams(max_tokens=max_tokens)
kv_cache_config = KvCacheConfig(
enable_block_reuse=True,
free_gpu_memory_fraction=free_gpu_memory_fraction,
event_buffer_max_size=1024,
)
llm = LLM(model=encoder_model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
max_batch_size=1)
# Load input using the multimodal input loader directly for multiple images per prompt
config_path = os.path.join(llm._hf_model_dir, 'config.json')
with open(config_path, 'r') as f:
model_config = json.load(f)
model_type = model_config['model_type']
inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
model_dir=llm._hf_model_dir,
model_type=model_type,
modality="multiple_image",
prompts=[prompt],
media=[images],
image_data_format="pt")
# Add multi_modal_uuids to the processed input
inputs[0]["multi_modal_uuids"] = {"image": uuids}
inp = inputs[0]
with llm:
_ = llm.generate([inp], sampling_params=sampling_params)
# Wait for KV cache events to be dispatched asynchronously
time.sleep(0.5)
events = llm.get_kv_cache_events(50)
# Collect all unique mm_key hashes from stored events
mm_key_hashes = set()
for event in events:
if event and event.get("data", {}).get("type") == "stored":
for block in event["data"].get("blocks", []):
if block.get("mm_keys"):
for mm_key in block["mm_keys"]:
mm_key_hashes.add(mm_key["hash"])
# Verify we got mm_keys
assert len(mm_key_hashes) > 0, "Expected mm_keys in stored events"
# Verify each expected pattern appears in the results
for pattern in expected_patterns:
if pattern == "hex":
# Should find at least one 64-char hex string
hex_found = any(
len(h) == 64 and all(c in '0123456789abcdef' for c in h)
for h in mm_key_hashes)
assert hex_found, f"Expected hex hash pattern but got: {mm_key_hashes}"
else:
# Should find the exact UUID string
assert pattern in mm_key_hashes, (
f"Expected UUID '{pattern}' in mm_keys, got: {mm_key_hashes}")
def test_kv_event_mm_keys_with_uuid_multiple_prompts():
"""Test mm_keys with UUIDs across multiple prompts, each with its own image.
This test verifies that when multiple prompts are processed, each with its own
multimodal data and UUID:
1. Each prompt's mm_keys correctly return the associated UUID
2. Different UUIDs are preserved for different prompts
3. KV cache events correctly associate UUIDs with their respective blocks
"""
encoder_model_dir = _LLAVA_DIR
max_tokens = 16
free_gpu_memory_fraction = 0.2
# Multiple prompts, each with its own image and UUID
prompts = [
"Describe the natural environment in the image.",
"What objects can you see in the image?",
"Describe the weather in the image.",
]
media = [example_images[0], example_images[1], example_images[2]]
uuids = ["uuid-image-seashore", "uuid-image-inpaint", "uuid-image-61"]
sampling_params = SamplingParams(max_tokens=max_tokens)
kv_cache_config = KvCacheConfig(
enable_block_reuse=True,
free_gpu_memory_fraction=free_gpu_memory_fraction,
event_buffer_max_size=2048,
)
llm = LLM(model=encoder_model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
max_batch_size=1)
# Load inputs with UUIDs for each prompt
inputs = _load_inputs(llm, prompts, media)
# Add multi_modal_uuids to each input
for inp, uuid in zip(inputs, uuids):
inp["multi_modal_uuids"] = {"image": [uuid]}
with llm:
# Generate for each input separately
for inp in inputs:
_ = llm.generate([inp], sampling_params=sampling_params)
# Wait for KV cache events to be dispatched asynchronously
time.sleep(0.5)
events = llm.get_kv_cache_events(50)
# Collect all unique mm_key hashes from stored events
mm_key_hashes = set()
for event in events:
if event and event.get("data", {}).get("type") == "stored":
for block in event["data"].get("blocks", []):
if block.get("mm_keys"):
for mm_key in block["mm_keys"]:
mm_key_hashes.add(mm_key["hash"])
# Verify we got mm_keys
assert len(mm_key_hashes) > 0, "Expected mm_keys in stored events"
# Verify each UUID appears in the results
for uuid in uuids:
assert uuid in mm_key_hashes, (
f"Expected UUID '{uuid}' in mm_keys, got: {mm_key_hashes}")
def test_kv_event_mm_keys_with_very_long_uuid():
"""Test mm_keys with UUIDs that exceed 64 bytes.
This test verifies that the system correctly handles UUIDs that are longer
than the typical 64-character hex hash representation:
1. Very long UUIDs (>64 bytes) are preserved and returned correctly
2. No truncation or corruption occurs for long UUID strings
3. The full UUID is returned in KV cache events
"""
encoder_model_dir = _LLAVA_DIR
max_tokens = 16
free_gpu_memory_fraction = 0.2
prompt = "Describe the natural environment in the image."
# Create UUIDs of varying lengths, including very long ones
# Normal UUID (36 chars): standard format
# Medium UUID (80 chars): exceeds 64-byte hash representation
# Very long UUID (200+ chars): stress test for string handling
long_uuid_80 = "sku-product-image-" + "a" * 62 # 80 chars total
very_long_uuid_200 = (
"enterprise-asset-management-system/region/us-east-1/bucket/media-assets/"
"category/electronics/subcategory/smartphones/brand/example-brand/"
"product-line/flagship-series/sku/SKU-2024-FLAGSHIP-PRO-MAX-256GB-MIDNIGHT-BLACK"
) # ~200 chars
# Use 2 images with different long UUIDs
images = [example_images[0], example_images[1]]
uuids = [long_uuid_80, very_long_uuid_200]
sampling_params = SamplingParams(max_tokens=max_tokens)
kv_cache_config = KvCacheConfig(
enable_block_reuse=True,
free_gpu_memory_fraction=free_gpu_memory_fraction,
event_buffer_max_size=1024,
)
llm = LLM(model=encoder_model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
max_batch_size=1)
# Load input using the multimodal input loader for multiple images
config_path = os.path.join(llm._hf_model_dir, 'config.json')
with open(config_path, 'r') as f:
model_config = json.load(f)
model_type = model_config['model_type']
inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
model_dir=llm._hf_model_dir,
model_type=model_type,
modality="multiple_image",
prompts=[prompt],
media=[images],
image_data_format="pt")
# Add very long UUIDs
inputs[0]["multi_modal_uuids"] = {"image": uuids}
inp = inputs[0]
with llm:
_ = llm.generate([inp], sampling_params=sampling_params)
# Wait for KV cache events to be dispatched asynchronously
time.sleep(0.5)
events = llm.get_kv_cache_events(50)
# Collect all unique mm_key hashes from stored events
mm_key_hashes = set()
for event in events:
if event and event.get("data", {}).get("type") == "stored":
for block in event["data"].get("blocks", []):
if block.get("mm_keys"):
for mm_key in block["mm_keys"]:
mm_key_hashes.add(mm_key["hash"])
# Verify we got mm_keys
assert len(mm_key_hashes) > 0, "Expected mm_keys in stored events"
# Verify the 80-char UUID is present and not truncated
assert long_uuid_80 in mm_key_hashes, (
f"Expected 80-char UUID '{long_uuid_80}' in mm_keys, got: {mm_key_hashes}"
)
# Verify the 200-char UUID is present and not truncated
assert very_long_uuid_200 in mm_key_hashes, (
f"Expected 200-char UUID '{very_long_uuid_200}' in mm_keys, got: {mm_key_hashes}"
)
# Verify the UUIDs are exactly as provided (no truncation)
for uuid in uuids:
matching = [h for h in mm_key_hashes if h == uuid]
assert len(matching) == 1, (
f"UUID '{uuid}' (len={len(uuid)}) should appear exactly once, "
f"found {len(matching)} times in {mm_key_hashes}")
@pytest.fixture(scope="module", params=[False, True])
def pd_disagg(request) -> bool:
return request.param

View File

@ -1014,6 +1014,64 @@ def test_multimodal_input():
assert config.multimodal_hashes == multimodal_hashes
assert config.multimodal_positions == multimodal_positions
assert config.multimodal_lengths == multimodal_lengths
# Default value for multimodal_uuids should be None
assert config.multimodal_uuids is None
@pytest.mark.parametrize(
"multimodal_uuids,expected_uuids",
[
# Test with all UUIDs provided
(["sku-image-001", "sku-image-002"], ["sku-image-001", "sku-image-002"]
),
# Test with partial UUIDs (some None)
(["sku-image-001", None], ["sku-image-001", None]),
# Test with empty list of UUIDs
([], []),
# Test with None (default)
(None, None),
],
ids=["all_uuids", "partial_uuids", "empty_list", "none_default"])
def test_multimodal_input_with_uuids(multimodal_uuids, expected_uuids):
"""Test MultimodalInput with user-provided UUIDs."""
multimodal_hashes = [[1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]]
multimodal_positions = [10, 100]
multimodal_lengths = [50, 60]
config = trtllm.MultimodalInput(multimodal_hashes, multimodal_positions,
multimodal_lengths, multimodal_uuids)
assert config.multimodal_hashes == multimodal_hashes
assert config.multimodal_positions == multimodal_positions
assert config.multimodal_lengths == multimodal_lengths
assert config.multimodal_uuids == expected_uuids
def test_multimodal_input_pickle_with_uuids():
"""Test pickling and unpickling of MultimodalInput with UUIDs."""
multimodal_hashes = [[1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]]
multimodal_positions = [10, 100]
multimodal_lengths = [50, 60]
multimodal_uuids = ["test-uuid-1", None]
config = trtllm.MultimodalInput(multimodal_hashes, multimodal_positions,
multimodal_lengths, multimodal_uuids)
# Pickle and unpickle
pickled = pickle.dumps(config)
restored = pickle.loads(pickled)
assert restored.multimodal_hashes == multimodal_hashes
assert restored.multimodal_positions == multimodal_positions
assert restored.multimodal_lengths == multimodal_lengths
assert restored.multimodal_uuids == multimodal_uuids
# Test with None UUIDs
config_no_uuids = trtllm.MultimodalInput(multimodal_hashes,
multimodal_positions,
multimodal_lengths)
pickled_no_uuids = pickle.dumps(config_no_uuids)
restored_no_uuids = pickle.loads(pickled_no_uuids)
assert restored_no_uuids.multimodal_uuids is None
def test_mrope_config():

View File

@ -109,23 +109,24 @@ def test_kv_cache_event_data_serialization():
def test_mm_keys_serialization():
"""Test serialization of multimodal keys (mm_keys) in KV cache events."""
# Test _mm_key_to_json with a mock mm_key tuple (bytes, int)
# MmKey from C++ is converted to (bytes, int) tuple by pybind11
# Test _mm_key_to_json with a mock mm_key tuple (bytes, int, uuid)
# MmKey from C++ is converted to (bytes, int, optional<str>) tuple by pybind11
mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24 # 32 bytes
mock_offset = 42
mock_mm_key = (mock_hash, mock_offset)
# New format: (hash, offset, uuid) - uuid is None for content-hashed items
mock_mm_key = (mock_hash, mock_offset, None)
result = KVCacheEventSerializer._mm_key_to_json(mock_mm_key)
assert result["type"] == "mm_key"
assert result["start_offset"] == 42
# Hash should be converted to hex string
# Hash should be converted to hex string when UUID is None
assert result["hash"] == "0102030405060708" + "00" * 24
assert len(result["hash"]) == 64 # 32 bytes = 64 hex chars
# Test with different hash values
mock_hash2 = bytes(range(32)) # 0x00 to 0x1f
mock_mm_key2 = (mock_hash2, 100)
mock_mm_key2 = (mock_hash2, 100, None)
result2 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key2)
assert result2["type"] == "mm_key"
@ -136,10 +137,10 @@ def test_mm_keys_serialization():
def test_mm_keys_deserialization():
"""Test deserialization of mm_keys JSON back to 32-byte hash."""
# Test case 1: Simple hash pattern
# Test case 1: Simple hash pattern (no UUID)
mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24 # 32 bytes
mock_offset = 42
mock_mm_key = (mock_hash, mock_offset)
mock_mm_key = (mock_hash, mock_offset, None) # New format with None UUID
# Serialize to JSON
json_result = KVCacheEventSerializer._mm_key_to_json(mock_mm_key)
@ -155,7 +156,7 @@ def test_mm_keys_deserialization():
# Test case 2: Sequential bytes 0x00 to 0x1f
mock_hash2 = bytes(range(32))
mock_offset2 = 100
mock_mm_key2 = (mock_hash2, mock_offset2)
mock_mm_key2 = (mock_hash2, mock_offset2, None)
json_result2 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key2)
recovered_hash2 = bytes.fromhex(json_result2["hash"])
@ -167,7 +168,7 @@ def test_mm_keys_deserialization():
# Test case 3: All 0xFF bytes
mock_hash3 = b'\xff' * 32
mock_offset3 = 255
mock_mm_key3 = (mock_hash3, mock_offset3)
mock_mm_key3 = (mock_hash3, mock_offset3, None)
json_result3 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key3)
recovered_hash3 = bytes.fromhex(json_result3["hash"])
@ -179,7 +180,7 @@ def test_mm_keys_deserialization():
# Test case 4: Random-like pattern
mock_hash4 = bytes([0xde, 0xad, 0xbe, 0xef] + [0xca, 0xfe] * 14)
mock_offset4 = 1024
mock_mm_key4 = (mock_hash4, mock_offset4)
mock_mm_key4 = (mock_hash4, mock_offset4, None)
json_result4 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key4)
recovered_hash4 = bytes.fromhex(json_result4["hash"])
@ -188,6 +189,274 @@ def test_mm_keys_deserialization():
assert len(recovered_hash4) == 32
def test_mm_key_with_uuid():
"""Test _mm_key_to_json returns UUID when provided in the tuple."""
# Create a mock mm_key with new format (hash, offset, uuid)
mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24 # 32 bytes
mock_offset = 42
expected_hash = "0102030405060708" + "00" * 24
# Test 1: Without UUID (None), should return hex hash
mock_mm_key_no_uuid = (mock_hash, mock_offset, None)
result_no_uuid = KVCacheEventSerializer._mm_key_to_json(mock_mm_key_no_uuid)
assert result_no_uuid["hash"] == expected_hash
assert result_no_uuid["start_offset"] == 42
# Test 2: With UUID in tuple, should return UUID directly
test_uuid = "my-custom-image-uuid"
mock_mm_key_with_uuid = (mock_hash, mock_offset, test_uuid)
result_with_uuid = KVCacheEventSerializer._mm_key_to_json(
mock_mm_key_with_uuid)
assert result_with_uuid["hash"] == test_uuid
assert result_with_uuid["start_offset"] == 42
# Test 3: Backward compatibility - old format (2 elements) should return hex hash
mock_mm_key_old_format = (mock_hash, mock_offset)
result_old_format = KVCacheEventSerializer._mm_key_to_json(
mock_mm_key_old_format)
assert result_old_format["hash"] == expected_hash
def test_apply_mm_hashes_with_uuids():
"""Test apply_mm_hashes with user-provided UUIDs."""
import torch
from tensorrt_llm.inputs.multimodal import apply_mm_hashes
# Create mock multimodal data - use fixed seed for reproducibility
torch.manual_seed(42)
mock_image1 = torch.randn(3, 224, 224)
mock_image2 = torch.randn(3, 224, 224)
mm_data = {"image": [mock_image1, mock_image2]}
# Test without UUIDs - should use content-only hashing
hashes_no_uuid, uuids_no_uuid = apply_mm_hashes(mm_data)
assert len(hashes_no_uuid["image"]) == 2
assert all(len(h) == 64 for h in hashes_no_uuid["image"])
assert uuids_no_uuid is None
# Test with partial UUIDs (first has UUID, second uses content-only hash)
mm_uuids = {"image": ["sku-1234-a", None]}
hashes_partial, uuids_partial = apply_mm_hashes(mm_data, mm_uuids)
assert len(hashes_partial["image"]) == 2
# First hash should be combined UUID+content (different from content-only)
assert len(hashes_partial["image"][0]) == 64
assert hashes_partial["image"][0] != hashes_no_uuid["image"][
0] # UUID changes hash
# Second hash should be content-only (same as without UUID)
assert hashes_partial["image"][1] == hashes_no_uuid["image"][1]
# UUIDs list should have the UUID and None
assert uuids_partial == ["sku-1234-a", None]
# Test with all UUIDs
mm_uuids_all = {"image": ["sku-1234-a", "sku-1234-b"]}
hashes_all, uuids_all = apply_mm_hashes(mm_data, mm_uuids_all)
assert len(hashes_all["image"]) == 2
assert all(len(h) == 64 for h in hashes_all["image"])
# Both hashes should differ from content-only hashes
assert hashes_all["image"][0] != hashes_no_uuid["image"][0]
assert hashes_all["image"][1] != hashes_no_uuid["image"][1]
# Different UUIDs with different content should produce different hashes
assert hashes_all["image"][0] != hashes_all["image"][1]
assert uuids_all == ["sku-1234-a", "sku-1234-b"]
def test_apply_mm_hashes_uuid_content_combined():
"""Test that UUID + content hashing ensures cache correctness.
This test verifies the key properties of combined UUID+content hashing:
1. Same UUID + same content = same hash (cache hit expected)
2. Same UUID + different content = different hash (no incorrect cache hit)
3. Different UUID + same content = different hash (user isolation)
"""
import torch
from tensorrt_llm.inputs.multimodal import apply_mm_hashes
# Create identical images
torch.manual_seed(42)
image_a = torch.randn(3, 224, 224)
image_a_copy = image_a.clone() # Identical content
# Create a different image
torch.manual_seed(123)
image_b = torch.randn(3, 224, 224)
# Property 1: Same UUID + same content = same hash
mm_data_a = {"image": [image_a]}
mm_data_a_copy = {"image": [image_a_copy]}
mm_uuids = {"image": ["user-123-img"]}
hashes_a, _ = apply_mm_hashes(mm_data_a, mm_uuids)
hashes_a_copy, _ = apply_mm_hashes(mm_data_a_copy, mm_uuids)
assert hashes_a["image"][0] == hashes_a_copy["image"][0], \
"Same UUID + same content should produce identical hashes"
# Property 2: Same UUID + different content = different hash
mm_data_b = {"image": [image_b]}
hashes_b, _ = apply_mm_hashes(mm_data_b, mm_uuids)
assert hashes_a["image"][0] != hashes_b["image"][0], \
"Same UUID + different content must produce different hashes"
# Property 3: Different UUID + same content = different hash (user isolation)
mm_uuids_user2 = {"image": ["user-456-img"]}
hashes_user2, _ = apply_mm_hashes(mm_data_a, mm_uuids_user2)
assert hashes_a["image"][0] != hashes_user2["image"][0], \
"Different UUID + same content should produce different hashes"
def test_int32_hexdigest_roundtrip():
"""Test that hexdigest_to_int32 and int32_to_hexdigest are inverses."""
from tensorrt_llm.inputs.multimodal import (hexdigest_to_int32,
int32_to_hexdigest)
# Test with various hash patterns
test_hashes = [
"0000000000000000000000000000000000000000000000000000000000000000",
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
"0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20",
"deadbeefcafebabefeedfacebadc0ffedeadbeefcafebabefeedfacebadc0ffe",
]
for original_hex in test_hashes:
int32_values = hexdigest_to_int32(original_hex)
recovered_hex = int32_to_hexdigest(int32_values)
assert recovered_hex == original_hex, f"Roundtrip failed for {original_hex}"
def test_multimodal_input_dataclass_with_uuids():
"""Test Python MultimodalInput dataclass with UUIDs."""
from tensorrt_llm.inputs.multimodal import MultimodalInput
# Test with all UUIDs
mm_input = MultimodalInput(multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]],
multimodal_positions=[10],
multimodal_lengths=[50],
multimodal_uuids=["test-uuid-123"])
assert mm_input.multimodal_uuids == ["test-uuid-123"]
# Test with partial UUIDs (some None)
mm_input_partial = MultimodalInput(
multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]],
multimodal_positions=[10, 100],
multimodal_lengths=[50, 60],
multimodal_uuids=["sku-001", None])
assert mm_input_partial.multimodal_uuids == ["sku-001", None]
# Test with None UUIDs (default)
mm_input_no_uuids = MultimodalInput(
multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]],
multimodal_positions=[10],
multimodal_lengths=[50])
assert mm_input_no_uuids.multimodal_uuids is None
def test_multimodal_input_dataclass_uuid_validation():
"""Test MultimodalInput validation for multimodal_uuids field."""
from tensorrt_llm.inputs.multimodal import MultimodalInput
# Test UUID list length mismatch
with pytest.raises(ValueError, match="multimodal_uuids length"):
MultimodalInput(multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8],
[8, 7, 6, 5, 4, 3, 2, 1]],
multimodal_positions=[10, 100],
multimodal_lengths=[50, 60],
multimodal_uuids=["only-one-uuid"])
# Test invalid UUID type
with pytest.raises(TypeError, match="must be a string or None"):
MultimodalInput(multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]],
multimodal_positions=[10],
multimodal_lengths=[50],
multimodal_uuids=[123]) # Integer instead of string
# Test invalid multimodal_uuids type (not a list)
with pytest.raises(TypeError, match="multimodal_uuids must be a list"):
MultimodalInput(multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]],
multimodal_positions=[10],
multimodal_lengths=[50],
multimodal_uuids="not-a-list")
def test_multimodal_input_from_components_with_uuids():
"""Test MultimodalInput.from_components factory method with UUIDs."""
from tensorrt_llm.inputs.multimodal import MultimodalInput
mm_hashes = [[1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]]
mm_positions = [10, 100]
mm_lengths = [50, 60]
mm_uuids = ["uuid-a", "uuid-b"]
mm_input = MultimodalInput.from_components(mm_hashes, mm_positions,
mm_lengths, mm_uuids)
assert mm_input.multimodal_hashes == mm_hashes
assert mm_input.multimodal_positions == mm_positions
assert mm_input.multimodal_lengths == mm_lengths
assert mm_input.multimodal_uuids == mm_uuids
# Test without UUIDs
mm_input_no_uuids = MultimodalInput.from_components(mm_hashes, mm_positions,
mm_lengths)
assert mm_input_no_uuids.multimodal_uuids is None
def test_apply_mm_hashes_uuid_length_mismatch():
"""Test apply_mm_hashes raises error on UUID list length mismatch."""
import torch
from tensorrt_llm.inputs.multimodal import apply_mm_hashes
mock_image1 = torch.randn(3, 224, 224)
mock_image2 = torch.randn(3, 224, 224)
mm_data = {"image": [mock_image1, mock_image2]}
# Mismatched UUID list length
mm_uuids_wrong_length = {"image": ["only-one-uuid"]} # Should have 2
with pytest.raises(ValueError,
match="UUID list length.*doesn't match.*data items"):
apply_mm_hashes(mm_data, mm_uuids_wrong_length)
def test_apply_mm_hashes_multiple_modalities():
"""Test apply_mm_hashes with multiple modalities and UUIDs."""
import torch
from tensorrt_llm.inputs.multimodal import apply_mm_hashes
# Create mock data for multiple modalities
torch.manual_seed(42)
mock_image = torch.randn(3, 224, 224)
mock_video_frames = [torch.randn(3, 224, 224) for _ in range(4)]
mm_data = {"image": [mock_image], "video": [mock_video_frames]}
# First, get content-only hashes (without UUIDs)
hashes_no_uuid, _ = apply_mm_hashes(mm_data)
# UUIDs for each modality
mm_uuids = {"image": ["img-uuid-001"], "video": ["vid-uuid-001"]}
hashes, uuids_list = apply_mm_hashes(mm_data, mm_uuids)
# Check hashes are 64-char hex strings (combined UUID+content hashes)
assert len(hashes["image"][0]) == 64
assert len(hashes["video"][0]) == 64
# Verify UUIDs change the hashes (UUID+content != content-only)
assert hashes["image"][0] != hashes_no_uuid["image"][0]
assert hashes["video"][0] != hashes_no_uuid["video"][0]
# Check flattened UUID list (order may vary based on dict iteration)
assert set(uuids_list) == {"img-uuid-001", "vid-uuid-001"}
def test_mm_keys_in_stored_events():
"""Test that mm_keys field is present in stored block events."""
llm = create_llm()