diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 1d05c42e20..8b6ca20322 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -119,6 +119,7 @@ public:
         std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>> multimodalHashes = std::nullopt,
         std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalPositions = std::nullopt,
         std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalLengths = std::nullopt,
+        std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>> multimodalUuids = std::nullopt,
         std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
         std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
         std::optional<SizeType32> mropePositionDeltas = std::nullopt,
@@ -168,6 +169,7 @@ public:
         , mMultimodalHashes(std::move(multimodalHashes))
         , mMultimodalPositions(std::move(multimodalPositions))
         , mMultimodalLengths(std::move(multimodalLengths))
+        , mMultimodalUuids(std::move(multimodalUuids))
         , mMultimodalEmbedding(std::move(multimodalEmbedding))
         , mMropeRotaryCosSin(std::move(mropeRotaryCosSin))
         , mMropePositionDeltas(mropePositionDeltas)
@@ -909,6 +911,11 @@ public:
         return mMultimodalLengths;
     }
 
+    [[nodiscard]] std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>> getMultimodalUuids() const
+    {
+        return mMultimodalUuids;
+    }
+
     [[nodiscard]] std::optional<TensorPtr> getMultimodalEmbedding() const
     {
         return mMultimodalEmbedding;
@@ -1964,6 +1971,7 @@ protected:
     std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>> mMultimodalHashes{std::nullopt};
     std::optional<std::shared_ptr<std::vector<SizeType32>>> mMultimodalPositions{std::nullopt};
     std::optional<std::shared_ptr<std::vector<SizeType32>>> mMultimodalLengths{std::nullopt};
+    std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>> mMultimodalUuids{std::nullopt};
     std::optional<TensorPtr> mMultimodalEmbedding{std::nullopt};
     std::optional<TensorPtr> mMropeRotaryCosSin{std::nullopt};
     std::optional<SizeType32> mMropePositionDeltas{std::nullopt};
@@ -2252,6 +2260,7 @@ public:
         std::optional<std::vector<std::vector<SizeType32>>> multimodalHashes = std::nullopt,
         std::optional<std::vector<SizeType32>> multimodalPositions = std::nullopt,
         std::optional<std::vector<SizeType32>> multimodalLengths = std::nullopt,
+        std::optional<std::vector<std::optional<std::string>>> multimodalUuids = std::nullopt,
         std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
         std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
         std::optional<SizeType32> mropePositionDeltas = std::nullopt,
@@ -2292,6 +2301,9 @@ public:
             multimodalLengths.has_value()
                 ? std::make_shared<std::vector<SizeType32>>(std::move(multimodalLengths.value()))
                 : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
+            multimodalUuids.has_value()
+                ? std::make_shared<std::vector<std::optional<std::string>>>(std::move(multimodalUuids.value()))
+                : std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>>(std::nullopt),
             std::move(multimodalEmbedding), std::move(mropeRotaryCosSin), mropePositionDeltas, loraTaskId,
             std::move(loraWeights), std::move(loraConfig), lookaheadConfig, std::move(kvCacheRetentionConfig),
             returnLogProbs, returnContextLogits, returnGenerationLogits,
diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index 44806b37b0..4c219c9d42 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -48,10 +48,6 @@ namespace tensorrt_llm::executor
 {
 using SizeType32 = tensorrt_llm::runtime::SizeType32;
 
-// Mmkey is used in KVCacheBlock when multimodal data presents in a block.
-// Type alias for hash array + start offset at per-block granularity.
-// This differs from the per-request level multimodal hash in MultimodalInput.
-using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
 
 /// @brief Version of TRT-LLM
 char const* version() noexcept;
@@ -301,11 +297,13 @@ class MultimodalInput
 {
 public:
     explicit MultimodalInput(std::vector<std::vector<SizeType32>> multimodalHashes,
-        std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths);
+        std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths,
+        std::optional<std::vector<std::optional<std::string>>> multimodalUuids = std::nullopt);
 
     [[nodiscard]] std::vector<std::vector<SizeType32>> getMultimodalHashes() const;
     [[nodiscard]] std::vector<SizeType32> getMultimodalPositions() const;
     [[nodiscard]] std::vector<SizeType32> getMultimodalLengths() const;
+    [[nodiscard]] std::optional<std::vector<std::optional<std::string>>> const& getMultimodalUuids() const;
 
 private:
     friend class Serialization;
@@ -315,6 +313,9 @@ private:
     std::vector<SizeType32> mMultimodalPositions;
     /// @brief The multimodal lengths
     std::vector<SizeType32> mMultimodalLengths;
+    /// @brief Optional user-provided UUIDs for multimodal items.
+    /// When provided, these are returned in KV cache events instead of content hashes.
+    std::optional<std::vector<std::optional<std::string>>> mMultimodalUuids;
 };
 
 /// @brief Configuration for mrope
diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h
index 1d30da2027..62546dd70d 100644
--- a/cpp/include/tensorrt_llm/executor/serialization.h
+++ b/cpp/include/tensorrt_llm/executor/serialization.h
@@ -349,6 +349,11 @@ public:
     static void serialize(KVCacheUpdatedData const& data, std::ostream& os);
     [[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is);
 
+    // MmKey
+    [[nodiscard]] static size_t serializedSize(MmKey const& key);
+    static void serialize(MmKey const& key, std::ostream& os);
+    [[nodiscard]] static MmKey deserializeMmKey(std::istream& is);
+
     // UniqueToken
     [[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token);
     static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os);
diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h
index e248cb1c3c..89618dce54 100644
--- a/cpp/include/tensorrt_llm/executor/types.h
+++ b/cpp/include/tensorrt_llm/executor/types.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include <array>
 #include <cstdint>
 #include <optional>
 #include <string>
@@ -70,6 +71,29 @@ using EagleChoices = std::vector<std::vector<SizeType32>>;
 using PriorityType = float;
 using BufferView = std::basic_string_view<uint8_t>;
 
+//! MmKey is used in KVCacheBlock when multimodal data is present in a block.
+//! Hash is a 32-byte array; startOffset is the per-block token offset; uuid is optional.
+struct MmKey
+{
+    std::array<uint8_t, 32> hash;
+    SizeType32 startOffset{};
+    std::optional<std::string> uuid{std::nullopt};
+
+    MmKey() = default;
+
+    MmKey(std::array<uint8_t, 32> hash, SizeType32 startOffset, std::optional<std::string> uuid = std::nullopt)
+        : hash(std::move(hash))
+        , startOffset(startOffset)
+        , uuid(std::move(uuid))
+    {
+    }
+
+    bool operator==(MmKey const& other) const noexcept
+    {
+        return hash == other.hash && startOffset == other.startOffset && uuid == other.uuid;
+    }
+};
+
 enum class DataType
 {
     kBOOL,
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index 8e7b6ed5a8..b22268e3d2 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -100,6 +100,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
     auto const multimodalHashes = llmRequest.getMultimodalHashes();
     auto const multimodalPositions = llmRequest.getMultimodalPositions();
     auto const multimodalLengths = llmRequest.getMultimodalLengths();
+    auto const multimodalUuids = llmRequest.getMultimodalUuids();
 
     if (!multimodalHashes || !multimodalPositions || !multimodalLengths || !(*multimodalHashes)
         || (*multimodalHashes)->empty() || !(*multimodalPositions) || (*multimodalPositions)->empty()
@@ -115,7 +116,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
         return {};
     }
 
-    std::vector<MmKey> extraKeys; // MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
+    std::vector<MmKey> extraKeys;
     extraKeys.reserve((*multimodalPositions)->size());
 
     std::array<uint8_t, 32> mmHashArray;
@@ -145,7 +146,15 @@ std::vector<MmKey> generateBlockHashExtraKeys(
         if (endTokenIdx > startPos && startTokenIdx < startPos + length)
         {
             uint64_t mmStartInBlock = (startPos >= startTokenIdx) ? 0 : static_cast<uint64_t>(startTokenIdx - startPos);
-            extraKeys.emplace_back(mmHashArray, mmStartInBlock);
+
+            // Get UUID if available
+            std::optional<std::string> uuid = std::nullopt;
+            if (multimodalUuids && *multimodalUuids && i < (*multimodalUuids)->size())
+            {
+                uuid = (*(*multimodalUuids))[i];
+            }
+
+            extraKeys.emplace_back(mmHashArray, mmStartInBlock, std::move(uuid));
         }
     }
 
@@ -222,8 +231,10 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no
     // block
     if (!blockKey.extraKeys.empty())
     {
-        for (auto const& [mmHash, startOffset] : blockKey.extraKeys)
+        for (auto const& mmKey : blockKey.extraKeys)
         {
+            auto const& mmHash = mmKey.hash;
+            auto const& startOffset = mmKey.startOffset;
             // Hash the multimodal hash array in 32-bit chunks (more efficient)
             for (size_t i = 0; i < 32; i += 4)
             {
diff --git a/cpp/tensorrt_llm/executor/multimodalInput.cpp b/cpp/tensorrt_llm/executor/multimodalInput.cpp
index 2077e9a261..9a2affdb0e 100644
--- a/cpp/tensorrt_llm/executor/multimodalInput.cpp
+++ b/cpp/tensorrt_llm/executor/multimodalInput.cpp
@@ -21,10 +21,12 @@ namespace tensorrt_llm::executor
 {
 
 MultimodalInput::MultimodalInput(std::vector<std::vector<SizeType32>> multimodalHashes,
-    std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths)
+    std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths,
+    std::optional<std::vector<std::optional<std::string>>> multimodalUuids)
     : mMultimodalHashes(std::move(multimodalHashes))
     , mMultimodalPositions(std::move(multimodalPositions))
     , mMultimodalLengths(std::move(multimodalLengths))
+    , mMultimodalUuids(std::move(multimodalUuids))
 {
 }
 
@@ -43,4 +45,9 @@ std::vector<SizeType32> MultimodalInput::getMultimodalLengths() const
     return mMultimodalLengths;
 }
 
+std::optional<std::vector<std::optional<std::string>>> const& MultimodalInput::getMultimodalUuids() const
+{
+    return mMultimodalUuids;
+}
+
 } // namespace tensorrt_llm::executor
diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp
index 10f238fa75..ad41c1f176 100644
--- a/cpp/tensorrt_llm/executor/serialization.cpp
+++ b/cpp/tensorrt_llm/executor/serialization.cpp
@@ -339,7 +339,9 @@ MultimodalInput Serialization::deserializeMultimodalInput(std::istream& is)
     auto multimodalHashes = su::deserialize<std::vector<std::vector<SizeType32>>>(is);
     auto multimodalPositions = su::deserialize<std::vector<SizeType32>>(is);
     auto multimodalLengths = su::deserialize<std::vector<SizeType32>>(is);
-    return MultimodalInput{std::move(multimodalHashes), std::move(multimodalPositions), std::move(multimodalLengths)};
+    auto multimodalUuids = su::deserialize<std::optional<std::vector<std::optional<std::string>>>>(is);
+    return MultimodalInput{std::move(multimodalHashes), std::move(multimodalPositions), std::move(multimodalLengths),
+        std::move(multimodalUuids)};
 }
 
 void Serialization::serialize(MultimodalInput const& multimodalInput, std::ostream& os)
@@ -347,6 +349,7 @@ void Serialization::serialize(MultimodalInput const& multimodalInput, std::ostre
     su::serialize(multimodalInput.mMultimodalHashes, os);
     su::serialize(multimodalInput.mMultimodalPositions, os);
     su::serialize(multimodalInput.mMultimodalLengths, os);
+    su::serialize(multimodalInput.mMultimodalUuids, os);
 }
 
 size_t Serialization::serializedSize(MultimodalInput const& multimodalInput)
@@ -355,6 +358,7 @@ size_t Serialization::serializedSize(MultimodalInput const& multimodalInput)
     totalSize += su::serializedSize(multimodalInput.mMultimodalHashes);
     totalSize += su::serializedSize(multimodalInput.mMultimodalPositions);
     totalSize += su::serializedSize(multimodalInput.mMultimodalLengths);
+    totalSize += su::serializedSize(multimodalInput.mMultimodalUuids);
     return totalSize;
 }
 
@@ -2441,6 +2445,31 @@ KVCacheUpdatedData Serialization::deserializeKVCacheUpdatedData(std::istream& is
     return KVCacheUpdatedData{blockHash, cacheLevel, priority};
 }
 
+// MmKey
+size_t Serialization::serializedSize(MmKey const& key)
+{
+    size_t totalSize = 0;
+    totalSize += su::serializedSize(key.hash);
+    totalSize += su::serializedSize(key.startOffset);
+    totalSize += su::serializedSize(key.uuid);
+    return totalSize;
+}
+
+void Serialization::serialize(MmKey const& key, std::ostream& os)
+{
+    su::serialize(key.hash, os);
+    su::serialize(key.startOffset, os);
+    su::serialize(key.uuid, os);
+}
+
+MmKey Serialization::deserializeMmKey(std::istream& is)
+{
+    auto hash = su::deserialize<std::array<uint8_t, 32>>(is);
+    auto startOffset = su::deserialize<SizeType32>(is);
+    auto uuid = su::deserialize<std::optional<std::string>>(is);
+    return MmKey{std::move(hash), startOffset, std::move(uuid)};
+}
+
 // UniqueToken
 size_t Serialization::serializedSize(tensorrt_llm::runtime::UniqueToken const& token)
 {
diff --git a/cpp/tensorrt_llm/executor/serializeUtils.h b/cpp/tensorrt_llm/executor/serializeUtils.h
index 1f1e90e0a3..02034741e6 100644
--- a/cpp/tensorrt_llm/executor/serializeUtils.h
+++ b/cpp/tensorrt_llm/executor/serializeUtils.h
@@ -170,6 +170,7 @@ static_assert(hasSerializedSize<SizeType32>(size_t()));
 static_assert(hasSerializedSize<std::optional<SizeType32>>(size_t()));
 static_assert(hasSerializedSize<ContextPhaseParams>(size_t()));
 static_assert(hasSerializedSize<Request>(size_t()));
+static_assert(hasSerializedSize<MmKey>(size_t()));
 
 template <typename T>
 size_t serializedSize(T const& data)
@@ -290,6 +291,7 @@ static_assert(hasSerialize<SizeType32>(nullptr));
 static_assert(hasSerialize<std::optional<SizeType32>>(nullptr));
 static_assert(hasSerialize<ContextPhaseParams>(nullptr));
 static_assert(hasSerialize<Request>(nullptr));
+static_assert(hasSerialize<MmKey>(nullptr));
 
 template <typename T>
 void serialize(T const& data, std::ostream& os)
@@ -476,6 +478,10 @@ T deserialize(std::istream& is)
     {
         return Serialization::deserializeContextPhaseParams(is);
     }
+    else if constexpr (std::is_same_v<T, MmKey>)
+    {
+        return Serialization::deserializeMmKey(is);
+    }
     else if constexpr (std::is_same_v<T, Request>)
     {
         return Serialization::deserializeRequest(is);
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
index a9c4b3e03a..7166531835 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -289,6 +289,7 @@ void initBindings(nb::module_& m)
             std::optional<std::vector<std::vector<SizeType32>>> multimodal_hashes,
             std::optional<std::vector<SizeType32>> multimodal_positions,
             std::optional<std::vector<SizeType32>> multimodal_lengths,
+            std::optional<std::vector<std::optional<std::string>>> multimodal_uuids,
             std::optional<at::Tensor> multimodal_embedding, std::optional<at::Tensor> mrope_rotary_cos_sin,
             std::optional<SizeType32> mrope_position_deltas, std::optional<LoraTaskIdType> lora_task_id,
             std::optional<at::Tensor> lora_weights,
@@ -344,7 +345,7 @@ void initBindings(nb::module_& m)
                 new (self) tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
                     end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr, stop_words_list_tensor_ptr,
                     position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size, multimodal_hashes,
-                    multimodal_positions, multimodal_lengths, multimodal_embedding_tensor_ptr,
+                    multimodal_positions, multimodal_lengths, multimodal_uuids, multimodal_embedding_tensor_ptr,
                     mrope_rotary_cos_sin_tensor_ptr, mrope_position_deltas, lora_task_id, lora_weights_tensor_ptr,
                     lora_config_tensor_ptr, lookahead_config, kv_cache_retention_config, return_log_probs,
                     return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr,
@@ -361,18 +362,19 @@ void initBindings(nb::module_& m)
             nb::arg("stop_words_list") = std::nullopt, nb::arg("position_ids") = std::nullopt,
             nb::arg("prompt_embedding_table") = std::nullopt, nb::arg("prompt_vocab_size") = std::nullopt,
             nb::arg("multimodal_hashes") = std::nullopt, nb::arg("multimodal_positions") = std::nullopt,
-            nb::arg("multimodal_lengths") = std::nullopt, nb::arg("multimodal_embedding") = std::nullopt,
-            nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
-            nb::arg("lora_task_id") = std::nullopt, nb::arg("lora_weights") = std::nullopt,
-            nb::arg("lora_config") = std::nullopt, nb::arg("lookahead_config") = std::nullopt,
-            nb::arg("kv_cache_retention_config") = std::nullopt, nb::arg("return_log_probs") = false,
-            nb::arg("return_context_logits") = false, nb::arg("return_generation_logits") = false,
-            nb::arg("draft_tokens") = std::nullopt, nb::arg("draft_logits") = std::nullopt,
-            nb::arg("exclude_input_from_output") = false, nb::arg("logits_post_processor") = std::nullopt,
-            nb::arg("apply_logits_post_processor_batched") = false, nb::arg("encoder_input_tokens") = std::nullopt,
-            nb::arg("return_encoder_output") = false, nb::arg("client_id") = std::nullopt,
-            nb::arg("priority") = executor::Request::kDefaultPriority, nb::arg("encoder_input_features") = std::nullopt,
-            nb::arg("encoder_output_len") = std::nullopt, nb::arg("cross_attention_mask") = std::nullopt,
+            nb::arg("multimodal_lengths") = std::nullopt, nb::arg("multimodal_uuids") = std::nullopt,
+            nb::arg("multimodal_embedding") = std::nullopt, nb::arg("mrope_rotary_cos_sin") = std::nullopt,
+            nb::arg("mrope_position_deltas") = std::nullopt, nb::arg("lora_task_id") = std::nullopt,
+            nb::arg("lora_weights") = std::nullopt, nb::arg("lora_config") = std::nullopt,
+            nb::arg("lookahead_config") = std::nullopt, nb::arg("kv_cache_retention_config") = std::nullopt,
+            nb::arg("return_log_probs") = false, nb::arg("return_context_logits") = false,
nb::arg("return_generation_logits") = false, nb::arg("draft_tokens") = std::nullopt, + nb::arg("draft_logits") = std::nullopt, nb::arg("exclude_input_from_output") = false, + nb::arg("logits_post_processor") = std::nullopt, nb::arg("apply_logits_post_processor_batched") = false, + nb::arg("encoder_input_tokens") = std::nullopt, nb::arg("return_encoder_output") = false, + nb::arg("client_id") = std::nullopt, nb::arg("priority") = executor::Request::kDefaultPriority, + nb::arg("encoder_input_features") = std::nullopt, nb::arg("encoder_output_len") = std::nullopt, + nb::arg("cross_attention_mask") = std::nullopt, nb::arg("llm_request_type") = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, nb::arg("input_token_extra_ids") = std::nullopt, nb::arg("num_return_sequences") = 1, nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt, diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp index 07d630cb3b..46da3e0570 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp @@ -93,6 +93,7 @@ std::shared_ptr LlmRequest::toTrtLlm() const mMultimodalHashes, // mMultimodalPositions, // mMultimodalLengths, // + mMultimodalUuids, // from_torch(mMultimodalEmbedding), // from_torch(mMropeRotaryCosSin), // mMropePositionDeltas, // diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h index 4ea47fdcc8..3665252517 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h @@ -61,6 +61,7 @@ public: std::optional>> multimodalHashes = std::nullopt, std::optional> multimodalPositions = std::nullopt, std::optional> multimodalLengths = std::nullopt, + std::optional>> multimodalUuids = std::nullopt, std::optional multimodalEmbedding = std::nullopt, std::optional mropeRotaryCosSin = std::nullopt, std::optional mropePositionDeltas = std::nullopt, @@ -111,6 +112,9 @@ public: multimodalLengths.has_value() ? std::make_shared>(std::move(multimodalLengths.value())) // : std::optional>>(std::nullopt), // + multimodalUuids.has_value() + ? 
std::make_shared>>(std::move(multimodalUuids.value())) // + : std::optional>>>(std::nullopt), // multimodalEmbedding, // mropeRotaryCosSin, // mropePositionDeltas, // diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp index f8e69fa1ad..4f873e2ed1 100644 --- a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp @@ -225,15 +225,20 @@ void initBindings(nb::module_& m) .def_prop_ro("mm_keys", [](tle::KVCacheStoredBlockData const& self) { - // Convert std::vector to Python list of tuples (bytes, int) - // MmKey = std::pair, SizeType32> + // Convert std::vector to Python list of tuples (bytes, int, optional) + // MmKey = struct { hash, startOffset, uuid } nb::list result; for (auto const& mmKey : self.mmKeys) { - auto const& hashArray = mmKey.first; - auto offset = mmKey.second; - nb::bytes hashBytes(reinterpret_cast(hashArray.data()), hashArray.size()); - result.append(nb::make_tuple(hashBytes, offset)); + nb::bytes hashBytes(reinterpret_cast(mmKey.hash.data()), mmKey.hash.size()); + if (mmKey.uuid.has_value()) + { + result.append(nb::make_tuple(hashBytes, mmKey.startOffset, mmKey.uuid.value())); + } + else + { + result.append(nb::make_tuple(hashBytes, mmKey.startOffset, nb::none())); + } } return result; }); diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index aa75a5fb60..01736d18f5 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -305,22 +305,29 @@ void initRequestBindings(nb::module_& m) .def("__setstate__", loraConfigSetstate); auto multimodalInputGetstate = [](tle::MultimodalInput const& self) - { return nb::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths()); }; + { + return nb::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths(), + self.getMultimodalUuids()); + }; auto multimodalInputSetstate = [](tle::MultimodalInput& multimodalInput, nb::tuple const& state) { - if (state.size() != 3) + if (state.size() != 4) { throw std::runtime_error("Invalid MultimodalInput state!"); } new (&multimodalInput) tle::MultimodalInput(nb::cast>>(state[0]), - nb::cast>(state[1]), nb::cast>(state[2])); + nb::cast>(state[1]), nb::cast>(state[2]), + nb::cast>>>(state[3])); }; nb::class_(m, "MultimodalInput") - .def(nb::init>, std::vector, std::vector>(), - nb::arg("multimodal_hashes"), nb::arg("multimodal_positions"), nb::arg("multimodal_lengths")) + .def(nb::init>, std::vector, std::vector, + std::optional>>>(), + nb::arg("multimodal_hashes"), nb::arg("multimodal_positions"), nb::arg("multimodal_lengths"), + nb::arg("multimodal_uuids") = nb::none()) .def_prop_ro("multimodal_hashes", &tle::MultimodalInput::getMultimodalHashes) .def_prop_ro("multimodal_positions", &tle::MultimodalInput::getMultimodalPositions) .def_prop_ro("multimodal_lengths", &tle::MultimodalInput::getMultimodalLengths) + .def_prop_ro("multimodal_uuids", &tle::MultimodalInput::getMultimodalUuids) .def("__getstate__", multimodalInputGetstate) .def("__setstate__", multimodalInputSetstate); @@ -703,7 +710,7 @@ void initRequestBindings(nb::module_& m) nb::arg("guided_decoding_params") = nb::none(), nb::arg("language_adapter_uid") = nb::none(), nb::arg("allotted_time_ms") = nb::none(), - nb::arg("cache_salt_id") = nb::none(), + nb::arg("cache_salt_id") = nb::none(), nb::arg("disagg_request_id") = nb::none() ) // 
clang-format on .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds) diff --git a/cpp/tests/e2e_tests/batch_manager/trtGptModelTest.cpp b/cpp/tests/e2e_tests/batch_manager/trtGptModelTest.cpp index 7ef5f4285c..3ceb9aa884 100644 --- a/cpp/tests/e2e_tests/batch_manager/trtGptModelTest.cpp +++ b/cpp/tests/e2e_tests/batch_manager/trtGptModelTest.cpp @@ -968,10 +968,10 @@ TEST_F(TrtGptModelTest, PauseRequestStats) auto llmRequest = std::make_shared(correlationId, maxNewTokens, tokens, inSamplingConfig, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, executor::Request::kDefaultPriority, std::nullopt, std::nullopt, - std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, 1, std::nullopt, - std::nullopt, true /* returnPerfMetrics */); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, executor::Request::kDefaultPriority, std::nullopt, + std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, 1, + std::nullopt, std::nullopt, true /* returnPerfMetrics */); RequestList requestList{llmRequest}; diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 3faee32f4f..020af4dd5c 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -1082,8 +1082,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds, numReturnSequences); GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -1119,8 +1119,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, 
false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds, numReturnSequences); GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -1151,9 +1151,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) llmRequest0 = std::make_shared(seq0_dup.getRequestId(), maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, - std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, - LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds, numReturnSequences); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, + std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, + std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds, numReturnSequences); promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0_dup.getRequestId()); @@ -1175,9 +1175,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) llmRequest1 = std::make_shared(seq1_dup.getRequestId(), maxNewTokens, inputTokens1, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, - std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, - LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1, numReturnSequences); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, + std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, + std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1, numReturnSequences); promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1_dup.getRequestId()); @@ -1207,8 +1207,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto llmRequest2 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + 
std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds2, numReturnSequences); numTokens = llmRequest2->getNumTokens(beamIdx); @@ -1235,8 +1235,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto llmRequest3 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds3, numReturnSequences); numTokens = llmRequest3->getNumTokens(beamIdx); @@ -1312,9 +1312,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, - std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, - std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, + std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, + std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, + numReturnSequences); GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -1354,9 +1355,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, - std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, - std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, + std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, + std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, + numReturnSequences); GenerationRequest seq1{requestId, inputLength, beamWidth, 
blockManager.getWindowSizesMetadata()}; // should reuse blocks 0, 1 and get new block 3 @@ -1391,9 +1393,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto llmRequest2 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, multimodalHashes2, multimodalPositions2, multimodalLengths2, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, - std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, - std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, + std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, + std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, + numReturnSequences); GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; // no reuse, get new blocks 4, 5, 6 @@ -1427,9 +1430,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto llmRequest3 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, multimodalHashes3, multimodalPositions3, multimodalLengths3, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, - std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, - std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, + std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, + std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, + numReturnSequences); GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; // reuse block 0, get new blocks 7, 8 auto promptLen3 = llmRequest3->getNumTokens(beamIdx); @@ -1498,7 +1502,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) LlmRequest::RequestIdType requestId{0}; auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; /////////////////////////////////////////////////////////////////////////// @@ -1533,7 +1537,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) requestId = 1; auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 
std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; // reuse blocks 0, 1 and get new block 3 @@ -1563,7 +1567,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) GenerationRequest seq0_dup{10, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; llmRequest0 = std::make_shared(seq0_dup.getRequestId(), maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + loraTaskId); promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0_dup.getRequestId()); @@ -1586,7 +1591,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) GenerationRequest seq1_dup{11, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; llmRequest1 = std::make_shared(seq1_dup.getRequestId(), maxNewTokens, inputTokens1, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + loraTaskId); promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); // reuse 0, 1, 2(p) ([0,1,2,3], [4,5,6,7], [8]) @@ -1617,7 +1623,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) requestId = 2; auto llmRequest2 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); numTokens = llmRequest2->getNumTokens(beamIdx); GenerationRequest seq2{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -1648,7 +1654,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) requestId = 3; auto llmRequest3 = std::make_shared(requestId, maxNewTokens, inputTokens3, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); numTokens = llmRequest3->getNumTokens(beamIdx); GenerationRequest seq3{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -1680,7 +1686,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) requestId = 4; auto llmRequest4 = std::make_shared(requestId, maxNewTokens, inputTokens4, 
samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId); numTokens = llmRequest4->getNumTokens(beamIdx); GenerationRequest seq4{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -1775,9 +1781,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) LlmRequest::RequestIdType requestId{0}; auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds); GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -1813,9 +1819,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) LlmRequest::LoraTaskIdType loraTaskId2 = static_cast(2); auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds); GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -1844,10 +1850,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) GenerationRequest seq0_dup{10, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; llmRequest0 = std::make_shared(seq0_dup.getRequestId(), maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, - std::nullopt, false, std::nullopt, false, std::nullopt, 
0.5, std::nullopt, std::nullopt, std::nullopt, - LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + loraTaskId1, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, + std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, + std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds); promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); // reuse blocks 0, 1 and get new block 6 @@ -1869,10 +1875,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) GenerationRequest seq1_dup{11, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; llmRequest1 = std::make_shared(seq1_dup.getRequestId(), maxNewTokens, inputTokens1, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, - std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, - LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1); + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + loraTaskId2, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, + std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, + std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds1); promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1_dup.getRequestId()); @@ -1900,9 +1906,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) requestId = 2; auto llmRequest2 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds2); numTokens = llmRequest2->getNumTokens(beamIdx); @@ -1928,9 +1934,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) requestId = 3; auto llmRequest3 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 
std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId1, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds3); numTokens = llmRequest3->getNumTokens(beamIdx); @@ -1955,9 +1961,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) requestId = 4; auto llmRequest4 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, loraTaskId2, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, inputTokenExtraIds3); numTokens = llmRequest4->getNumTokens(beamIdx); @@ -2034,8 +2040,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt); // No cache_salt_id @@ -2075,8 +2081,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, 
std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, cacheSaltId1); // With cache_salt_id = 12345 @@ -2110,8 +2116,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto llmRequest2 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, cacheSaltId1); // Same cache_salt_id = 12345 @@ -2146,8 +2152,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto llmRequest3 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, cacheSaltId2); // Different cache_salt_id = 67890 @@ -2175,8 +2181,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto llmRequest4 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, - false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, + std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, 
std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt); // No cache_salt_id
diff --git a/cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp b/cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp
index 55834f3c66..d727bf31af 100644
--- a/cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp
@@ -64,7 +64,7 @@ protected:
             /*badWordsList=*/std::nullopt, /*stopWordsList=*/std::nullopt, /*positionIds=*/std::nullopt,
             /*promptEmbeddingTable=*/std::nullopt, /*promptVocabSize=*/std::nullopt,
             /*multimodalHashes=*/std::nullopt, /*multimodalPos=*/std::nullopt, /*multimodalLength=*/std::nullopt,
-            /*multimodalEmbedding=*/std::nullopt,
+            /*multimodalUuids=*/std::nullopt, /*multimodalEmbedding=*/std::nullopt,
            /*mropeRotaryCosSin=*/std::nullopt, /*mropePositionDeltas*/ std::nullopt,
             /*loraTaskId=*/std::nullopt, /*loraWeights=*/std::nullopt, /*loraConfig=*/std::nullopt,
             /*lookaheadConfig=*/std::nullopt, /*kvCacheRetentionConfig=*/std::nullopt,
diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp
index 23683f36c7..d0e1222535 100644
--- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp
+++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp
@@ -1092,7 +1092,7 @@ TEST(SerializeUtilsTest, BlockKeyWithExtras)
         h1[i] = static_cast<uint8_t>(i);
         h2[i] = static_cast<uint8_t>(255 - i);
     }
-    std::vector<MmKey> extraKeys{{h1, SizeType32{0}}, {h2, SizeType32{5}}};
+    std::vector<MmKey> extraKeys{{h1, SizeType32{0}, std::nullopt}, {h2, SizeType32{5}, std::nullopt}};
 
     VecUniqueTokens uniqueTokens{UniqueToken{10, 100}, UniqueToken{20, 200}};
     std::optional<LoraTaskIdType> loraTaskId = LoraTaskIdType{42};
@@ -1103,6 +1103,144 @@ TEST(SerializeUtilsTest, BlockKeyWithExtras)
     testSerializeDeserialize(key);
 }
 
+TEST(SerializeUtilsTest, MmKeyWithUuid)
+{
+    using tensorrt_llm::executor::MmKey;
+
+    // Test MmKey serialization with UUID
+    std::array<uint8_t, 32> hash{};
+    for (size_t i = 0; i < hash.size(); ++i)
+    {
+        hash[i] = static_cast<uint8_t>(i * 3);
+    }
+
+    // Test with UUID
+    MmKey keyWithUuid{hash, SizeType32{42}, std::string("test-image-uuid-12345")};
+    testSerializeDeserialize(keyWithUuid);
+
+    // Test without UUID (nullopt)
+    MmKey keyNoUuid{hash, SizeType32{100}, std::nullopt};
+    testSerializeDeserialize(keyNoUuid);
+
+    // Test with empty string UUID
+    MmKey keyEmptyUuid{hash, SizeType32{0}, std::string("")};
+    testSerializeDeserialize(keyEmptyUuid);
+
+    // Test with long UUID (> 32 bytes)
+    MmKey keyLongUuid{hash, SizeType32{255}, std::string("this-is-a-very-long-uuid-that-exceeds-32-bytes-for-testing")};
+    testSerializeDeserialize(keyLongUuid);
+}
+
+TEST(SerializeUtilsTest, BlockKeyWithExtrasAndUuids)
+{
+    using namespace tensorrt_llm::batch_manager::kv_cache_manager;
+
+    // Prepare multimodal extra keys with mixed UUIDs
+    std::array<uint8_t, 32> h1{};
+    std::array<uint8_t, 32> h2{};
+    std::array<uint8_t, 32> h3{};
+    for (size_t i = 0; i < h1.size(); ++i)
+    {
+        h1[i] = static_cast<uint8_t>(i);
+        h2[i] = static_cast<uint8_t>(255 - i);
+        h3[i] = static_cast<uint8_t>(i * 2);
+    }
+
+    // Mix of UUIDs: one with UUID, one without, one with empty string
+    std::vector<MmKey> extraKeys{{h1, SizeType32{0}, std::string("sku-image-001")}, {h2, SizeType32{5}, std::nullopt},
+        {h3, SizeType32{10}, std::string("")}};
+
+    VecUniqueTokens uniqueTokens{UniqueToken{10, 100}, UniqueToken{20, 200}};
+    std::optional<LoraTaskIdType> loraTaskId = LoraTaskIdType{42};
+
+    BlockKey key(true, loraTaskId, uniqueTokens, extraKeys);
+
+    testSerializeDeserialize(key);
+}
MultimodalInputWithUuids) +{ + using tensorrt_llm::executor::MultimodalInput; + + // Helper to verify MultimodalInput serialization round-trip + auto verifyMultimodalInput = [](MultimodalInput const& original) + { + auto size = tensorrt_llm::executor::Serialization::serializedSize(original); + std::ostringstream oss; + tensorrt_llm::executor::Serialization::serialize(original, oss); + EXPECT_EQ(oss.str().size(), size); + + std::istringstream iss(oss.str()); + auto deserialized = tensorrt_llm::executor::Serialization::deserializeMultimodalInput(iss); + + // Verify hashes + auto origHashes = original.getMultimodalHashes(); + auto deserHashes = deserialized.getMultimodalHashes(); + EXPECT_EQ(origHashes.size(), deserHashes.size()); + for (size_t i = 0; i < origHashes.size(); ++i) + { + EXPECT_EQ(origHashes[i], deserHashes[i]); + } + + // Verify positions + auto origPositions = original.getMultimodalPositions(); + auto deserPositions = deserialized.getMultimodalPositions(); + EXPECT_EQ(origPositions, deserPositions); + + // Verify lengths + auto origLengths = original.getMultimodalLengths(); + auto deserLengths = deserialized.getMultimodalLengths(); + EXPECT_EQ(origLengths, deserLengths); + + // Verify UUIDs + auto origUuids = original.getMultimodalUuids(); + auto deserUuids = deserialized.getMultimodalUuids(); + EXPECT_EQ(origUuids.has_value(), deserUuids.has_value()); + if (origUuids.has_value() && deserUuids.has_value()) + { + EXPECT_EQ(origUuids->size(), deserUuids->size()); + for (size_t i = 0; i < origUuids->size(); ++i) + { + EXPECT_EQ((*origUuids)[i], (*deserUuids)[i]); + } + } + }; + + // Test MultimodalInput with UUIDs + std::vector<std::vector<SizeType32>> hashes = { + {1, 2, 3, 4, 5, 6, 7, 8}, // First image hash + {10, 20, 30, 40, 50, 60, 70, 80} // Second image hash + }; + std::vector<SizeType32> positions = {0, 100}; + std::vector<SizeType32> lengths = {50, 75}; + + // Test with full UUIDs + std::vector<std::optional<std::string>> uuids = {std::string("image-uuid-001"), std::string("image-uuid-002")}; + MultimodalInput inputWithUuids(hashes, positions, lengths, uuids); + verifyMultimodalInput(inputWithUuids); + + // Test with partial UUIDs (some items with a UUID, some with nullopt) + std::vector<std::optional<std::string>> partialUuids = {std::string("uuid-a"), std::nullopt}; + MultimodalInput inputPartialUuids(hashes, positions, lengths, partialUuids); + verifyMultimodalInput(inputPartialUuids); + + // Test without UUIDs (nullopt) + MultimodalInput inputNoUuids(hashes, positions, lengths, std::nullopt); + verifyMultimodalInput(inputNoUuids); + + // Test with empty string UUID + std::vector<std::optional<std::string>> emptyUuids = {std::string(""), std::string("valid-uuid")}; + MultimodalInput inputEmptyUuid(hashes, positions, lengths, emptyUuids); + verifyMultimodalInput(inputEmptyUuid); + + // Test with long UUIDs (> 32 bytes) + std::vector<std::optional<std::string>> longUuids + = {std::string("this-is-a-very-long-uuid-string-that-exceeds-the-32-byte-limit-for-testing-purposes"), + std::string("short")}; + MultimodalInput inputLongUuids(hashes, positions, lengths, longUuids); + verifyMultimodalInput(inputLongUuids); +} + // Connection notification tests namespace kv_cache = tensorrt_llm::executor::kv_cache; diff --git a/docs/source/features/kvcache.md b/docs/source/features/kvcache.md index 56b0ed41ce..029c636c32 100644 --- a/docs/source/features/kvcache.md +++ b/docs/source/features/kvcache.md @@ -64,6 +64,39 @@ KV cache salting provides a security mechanism to control which requests can reu To use cache salting, specify the `cache_salt` parameter as a string when creating requests.
Only requests with matching cache salt values can share cached KV blocks. The salt value can be any non-empty string, such as a user ID, tenant ID, or hash string. +### Multimodal UUID Support for Cache Identification + +When working with multimodal models (e.g., vision-language models), the KV cache system needs to identify which cached blocks correspond to which multimodal inputs (images, videos, etc.). By default, the system uses content-based hashing to generate unique identifiers for each multimodal input. However, this approach has limitations for cache management across sessions, as the same content must be re-processed to generate the same hash. + +To enable deterministic cache management, you can provide custom UUID strings for your multimodal data using the `multi_modal_uuids` parameter when creating requests. When provided, these UUIDs are returned in KV cache events instead of computed content hashes, while the cache key itself is computed from **both** the UUID and content together for correctness. + +**Usage Example:** + +```python +from tensorrt_llm.inputs import TextPrompt + +# Provide custom UUIDs for your images +prompt = TextPrompt( + prompt="Describe these images.", + multi_modal_data={"image": [image1, image2]}, + multi_modal_uuids={"image": ["image-uuid-001", "image-uuid-002"]} +) +``` + +**Key Features:** + +- **Cache Correctness**: When a UUID is provided, the cache key is computed from both the UUID and content together using `BLAKE3(UUID || Content)`. This ensures different content always produces different cache entries, even with the same UUID. +- **User Isolation**: Same content with different UUIDs produces different cache entries, enabling per-user or per-session cache isolation. +- **Stable Event Identifiers**: The original UUID string is preserved and returned in KV cache events via `get_kv_cache_events()`, enabling deterministic external cache management. +- **Partial UUID Support**: You can provide UUIDs for some items and use `None` for others to fall back to content-only hashing. +- **Cross-Modality Support**: Different modalities (images, videos) can each have their own UUIDs. + +**UUID Format:** + +- Can be any string (e.g., "image-123", "user-session-img-a", database keys) +- Original UUID strings are preserved and returned in KV cache events + + ### Enable Offloading to Host Memory Before a block is evicted from GPU memory, it can optionally be offloaded to host (CPU) memory. The block remains reusable until it is evicted from host memory. When an offloaded block is reused, it is first copied back into GPU memory. Offloading is controlled with property ```host_cache_size``` which specifies how much host memory (in bytes) should be allocated for offloading. The default is 0. 
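The event-consumption side of the documentation above can be made concrete with a short sketch. The snippet below assumes the event dictionary layout exercised by this change's tests ("stored" events whose `blocks` carry `mm_keys` entries with `hash` and `start_offset` fields) and an LLM created with `event_buffer_max_size` set so events are buffered; `collect_mm_identifiers` is a hypothetical helper name, not part of the API.

```python
# Minimal sketch: read multimodal identifiers back out of KV cache events.
# `llm` is an LLM built with KvCacheConfig(event_buffer_max_size=...);
# the dict layout mirrors the tests added in this change.
def collect_mm_identifiers(llm, timeout=50):
    identifiers = []
    for event in llm.get_kv_cache_events(timeout):
        if not event or event.get("data", {}).get("type") != "stored":
            continue
        for block in event["data"].get("blocks", []):
            for mm_key in block.get("mm_keys", []):
                # "hash" holds the user-supplied UUID when one was given,
                # otherwise the 64-character hex content hash.
                identifiers.append((mm_key["hash"], mm_key["start_offset"]))
    return identifiers
```

An external cache manager can then match these identifiers against its own records without re-downloading or re-hashing the underlying media.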
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index c2ab5fcc81..2f5542ddf2 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -180,7 +180,8 @@ class KvCacheCreator: req_mm_input = trtllm.MultimodalInput( multimodal_hashes=multimodal_input.multimodal_hashes, multimodal_positions=multimodal_input.multimodal_positions, - multimodal_lengths=multimodal_input.multimodal_lengths + multimodal_lengths=multimodal_input.multimodal_lengths, + multimodal_uuids=multimodal_input.multimodal_uuids ) if multimodal_input else None request = trtllm.Request(prompt_token_ids, diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index bc10d00364..092fde5a48 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -846,10 +846,12 @@ def executor_request_to_llm_request( multimodal_hashes = None multimodal_positions = None multimodal_lengths = None + multimodal_uuids = None if executor_request.multimodal_input is not None: multimodal_hashes = executor_request.multimodal_input.multimodal_hashes multimodal_positions = executor_request.multimodal_input.multimodal_positions multimodal_lengths = executor_request.multimodal_input.multimodal_lengths + multimodal_uuids = executor_request.multimodal_input.multimodal_uuids # Extract mrope fields mrope_rotary_cos_sin = None @@ -879,6 +881,7 @@ def executor_request_to_llm_request( multimodal_hashes=multimodal_hashes, multimodal_positions=multimodal_positions, multimodal_lengths=multimodal_lengths, + multimodal_uuids=multimodal_uuids, multimodal_embedding=executor_request.multimodal_embedding, lora_task_id=executor_request.lora_config.task_id if executor_request.lora_config is not None else None, diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index fb35c78d08..720269e3b4 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -1205,14 +1205,24 @@ class KVCacheEventSerializer: @staticmethod def _mm_key_to_json(data): - # MmKey is a pair of (array, SizeType32) - hash_array, start_offset = data + # MmKey is a tuple of (hash_bytes, start_offset, uuid) + # where uuid is optional (None if content-hashed) + if len(data) == 3: + hash_array, start_offset, uuid = data + else: + # Backward compatibility: old format (hash_array, start_offset) + hash_array, start_offset = data + uuid = None # Convert array to hex string hash_hex = ''.join(f'{b:02x}' for b in hash_array) + + # Use UUID from C++ if available, otherwise use hash_hex + hash_or_uuid = uuid if uuid is not None else hash_hex + return { "type": "mm_key", - "hash": hash_hex, + "hash": hash_or_uuid, "start_offset": start_offset } diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 04888a2ad5..4cf8ad56e2 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -427,7 +427,9 @@ class BaseWorker(GenerationExecutor): multimodal_positions=request.multimodal_params. multimodal_input.multimodal_positions, multimodal_lengths=request.multimodal_params. - multimodal_input.multimodal_lengths) + multimodal_input.multimodal_lengths, + multimodal_uuids=request.multimodal_params.multimodal_input. 
+ multimodal_uuids) # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field request.multimodal_params.multimodal_input = None diff --git a/tensorrt_llm/inputs/data.py b/tensorrt_llm/inputs/data.py index 6f28d287cc..615043fe48 100644 --- a/tensorrt_llm/inputs/data.py +++ b/tensorrt_llm/inputs/data.py @@ -17,6 +17,15 @@ class TextPrompt(TypedDict): if the model supports it. """ + multi_modal_uuids: NotRequired[Dict[str, List[Any]]] + """ + Optional user-provided UUIDs for multimodal items. + Structure mirrors multi_modal_data: {"image": ["uuid1", None, "uuid3"]}. + When a UUID is provided for an item, it will be returned in KV cache events + instead of the computed content hash. Use None to fall back to content + hashing for specific items. + """ + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the @@ -39,6 +48,15 @@ class TokensPrompt(TypedDict): if the model supports it. """ + multi_modal_uuids: NotRequired[Dict[str, List[Any]]] + """ + Optional user-provided UUIDs for multimodal items. + Structure mirrors multi_modal_data: {"image": ["uuid1", None, "uuid3"]}. + When a UUID is provided for an item, it will be returned in KV cache events + instead of the computed content hash. Use None to fall back to content + hashing for specific items. + """ + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index d7ff9d08f8..5a1eac4479 100644 --- a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -38,6 +38,22 @@ class MultimodalInput: (e.g., image_end_token, image_break_token for mistral3) mixed with the actual multimodal tokens. """ + multimodal_uuids: Optional[List[Optional[str]]] = None + """Optional user-provided UUIDs for multimodal data items. + + When provided, these UUIDs will be returned in KV cache events instead of the + computed hash hex string. This enables deterministic cache identification across + sessions using user-defined stable identifiers. + + Each element can be: + - A string UUID: Used as the cache identifier (returned in events) + - None: Falls back to content-based hashing for that item + + If the UUID string is longer than 32 bytes, it will be hashed internally + for cache key computation, but the original UUID string is preserved and + returned in KV cache events. 
+ """ + def __post_init__(self): """Validate input data structure and consistency.""" # Validate multimodal_hashes @@ -69,13 +85,32 @@ class MultimodalInput: f"positions={len(self.multimodal_positions)}, lengths={len(self.multimodal_lengths)}" ) + # Validate multimodal_uuids if provided + if self.multimodal_uuids is not None: + if not isinstance(self.multimodal_uuids, list): + raise TypeError("multimodal_uuids must be a list") + if len(self.multimodal_uuids) != len(self.multimodal_hashes): + raise ValueError( + f"multimodal_uuids length ({len(self.multimodal_uuids)}) must match " + f"multimodal_hashes length ({len(self.multimodal_hashes)})") + for i, uuid in enumerate(self.multimodal_uuids): + if uuid is not None and not isinstance(uuid, str): + raise TypeError( + f"multimodal_uuids[{i}] must be a string or None, got {type(uuid)}" + ) + @classmethod - def from_components(cls, mm_hashes: List[List[int]], - mm_positions: List[int], - mm_lengths: List[int]) -> 'MultimodalInput': + def from_components( + cls, + mm_hashes: List[List[int]], + mm_positions: List[int], + mm_lengths: List[int], + mm_uuids: Optional[List[Optional[str]]] = None + ) -> 'MultimodalInput': return cls(multimodal_hashes=mm_hashes, multimodal_positions=mm_positions, - multimodal_lengths=mm_lengths) + multimodal_lengths=mm_lengths, + multimodal_uuids=mm_uuids) def to_tensor(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Convert data to tensors""" @@ -546,46 +581,116 @@ def serialize_item(obj: object) -> bytes: raise ValueError(f"Unsupported object type: {type(obj)}") -def apply_mm_hashes(mm_data: Dict[str, Any], - hash_lib=default_hasher) -> Dict[str, List[str]]: - """Apply hashing to multimodal data items.""" +def apply_mm_hashes( + mm_data: Dict[str, Any], + mm_uuids: Optional[Dict[str, List[Optional[str]]]] = None, + hash_lib=default_hasher +) -> Tuple[Dict[str, List[str]], Optional[List[Optional[str]]]]: + """Apply hashing to multimodal data items, combining UUID with content when provided. - def _hash_image(image): - # TODO: possible hash collision w/ this simplified version (vllm/PR/17378) - hasher = hash_lib() - if isinstance(image, torch.Tensor): + When a UUID is provided for an item, the hash is computed from both the UUID + and the content together: BLAKE3(UUID || Content). This ensures: + - Cache correctness: Different content always produces different hashes + - User isolation: Same content with different UUIDs produces different hashes + - The original UUID string is preserved and returned in KV cache events + + Args: + mm_data: Dictionary of modality -> data items + mm_uuids: Optional dictionary of modality -> list of UUID strings. + Use None for items that should use content-based hashing only. 
+        hash_lib: Hash function to use (default: blake3) + + Returns: + Tuple of: + - Dictionary of modality -> list of hash hex strings (64 chars each) + - Flattened list of original UUID strings (or None for content-hashed items) + """ + + def _hash_content(hasher, item): + """Hash the content of a multimodal item into the provided hasher.""" + if isinstance(item, torch.Tensor): # Ensure tensor is on CPU and contiguous for consistent hashing - image = image.detach().cpu().contiguous() - hasher.update(serialize_item(image)) - elif isinstance(image, list): + item = item.detach().cpu().contiguous() + hasher.update(serialize_item(item)) + elif isinstance(item, list): # Hash each frame with a separator to avoid collisions between [A,B] and [AB] - for frame in image: + for frame in item: hasher.update(b"<frame>") if isinstance(frame, torch.Tensor): frame = frame.detach().cpu().contiguous() hasher.update(serialize_item(frame)) - elif isinstance(image, tensorrt_llm.inputs.utils.VideoData): - frames = image.frames + elif isinstance(item, tensorrt_llm.inputs.utils.VideoData): + frames = item.frames for frame in frames: hasher.update(b"<frame>") if isinstance(frame, torch.Tensor): frame = frame.detach().cpu().contiguous() hasher.update(serialize_item(frame)) else: - hasher.update(serialize_item(image)) + hasher.update(serialize_item(item)) + def _hash_item(item): + """Hash only the content of a multimodal item (no UUID).""" + # TODO: possible hash collision w/ this simplified version (vllm/PR/17378) + hasher = hash_lib() + _hash_content(hasher, item) + return hasher.hexdigest() + + def _hash_item_with_uuid(item, uuid: str): + """Hash UUID and content together: BLAKE3(UUID || Content). + + This creates a unique hash that incorporates both the user-provided + identifier and the actual content, ensuring cache correctness while + supporting user-defined cache isolation.
+        """ + hasher = hash_lib() + # Hash UUID first with delimiters to prevent length-extension ambiguity + hasher.update(b"<uuid>") + hasher.update(uuid.encode('utf-8')) + hasher.update(b"</uuid>") + # Then hash the content + hasher.update(b"<content>") + _hash_content(hasher, item) + hasher.update(b"</content>") return hasher.hexdigest() mm_items = { modality: items if isinstance(items, list) else [items] for modality, items in mm_data.items() } - # TODO: need to hash both modality and item to distinguish modality (vllm/PR) - mm_hashes = { - modality: [_hash_image(item) for item in items] - for modality, items in mm_items.items() - } - return mm_hashes + + # Collect UUIDs in the same order as items + all_uuids: List[Optional[str]] = [] + mm_hashes: Dict[str, List[str]] = {} + + for modality, items in mm_items.items(): + modality_uuids = None + if mm_uuids is not None and modality in mm_uuids: + modality_uuids = mm_uuids[modality] + if not isinstance(modality_uuids, list): + modality_uuids = [modality_uuids] + if len(modality_uuids) != len(items): + raise ValueError( + f"UUID list length ({len(modality_uuids)}) doesn't match " + f"data items length ({len(items)}) for modality '{modality}'" + ) + + hashes = [] + for i, item in enumerate(items): + uuid = modality_uuids[i] if modality_uuids else None + if uuid is not None: + # Hash UUID + content together for cache correctness + hashes.append(_hash_item_with_uuid(item, uuid)) + all_uuids.append(uuid) # Store original UUID + else: + # Fall back to content-only hashing + hashes.append(_hash_item(item)) + all_uuids.append(None) + + mm_hashes[modality] = hashes + + # Return None for uuids if no UUIDs were provided at all + return mm_hashes, all_uuids if mm_uuids is not None else None def hexdigest_to_int32(hex_digest: str) -> List[int]: @@ -604,6 +709,30 @@ def hexdigest_to_int32(hex_digest: str) -> List[int]: return result +def int32_to_hexdigest(int32_values: List[int]) -> str: + """Convert 8 int32 values back to a 64-character hexadecimal digest. + + This is the inverse of hexdigest_to_int32. + + Args: + int32_values: List of 8 signed int32 values + + Returns: + 64-character hexadecimal string representing the 32-byte hash + """ + if len(int32_values) != 8: + raise ValueError(f"Expected 8 int32 values, got {len(int32_values)}") + + result = [] + for value in int32_values: + # Convert signed int32 back to unsigned + if value < 0: + value = value + 0x100000000 + # Format as 8 hex characters (zero-padded) + result.append(f'{value:08x}') + return ''.join(result) + + def find_mm_token_lengths(mm_data: Dict[str, Any], input_processor: Any) -> List[int]: """Get the maximum contiguous multimodal token lengths from multimodal data items. diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py index 8f2d6b940f..0471f3c0af 100644 --- a/tensorrt_llm/inputs/registry.py +++ b/tensorrt_llm/inputs/registry.py @@ -664,10 +664,18 @@ def create_input_processor_with_hash( ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: """ Process the multimodal hashing for media tokens if possible. + + Supports optional user-provided UUIDs via 'multi_modal_uuids' in inputs. + When a UUID is provided for a multimodal item, it is hashed together with the + content for the cache key and returned in KV cache events instead of the content hash. """ assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support."
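The combined scheme implemented in `apply_mm_hashes` above is easy to demonstrate in isolation. This is a minimal sketch, assuming `hashlib.sha256` as a stand-in for the blake3 default and illustrative `<uuid>`/`<content>` framing tokens; it is not the library's exact byte layout.

```python
# Standalone sketch of combined UUID+content hashing. The framing tokens
# make the (uuid, content) boundary unambiguous, so ("ab", b"c") and
# ("a", b"bc") cannot collide.
import hashlib
from typing import Optional


def combined_digest(content: bytes, uuid: Optional[str] = None) -> str:
    hasher = hashlib.sha256()
    if uuid is not None:
        hasher.update(b"<uuid>")
        hasher.update(uuid.encode("utf-8"))
        hasher.update(b"</uuid>")
        hasher.update(b"<content>")
        hasher.update(content)
        hasher.update(b"</content>")
    else:
        hasher.update(content)  # content-only fallback
    return hasher.hexdigest()


img = b"fake-image-bytes"
# Same UUID + same content -> same hash (cache hit)
assert combined_digest(img, "user-123") == combined_digest(img, "user-123")
# Same UUID + different content -> different hash (cache correctness)
assert combined_digest(img, "user-123") != combined_digest(b"other", "user-123")
# Different UUID + same content -> different hash (user isolation)
assert combined_digest(img, "user-123") != combined_digest(img, "user-456")
```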
mm_data = inputs['multi_modal_data'] - mm_hashes = apply_mm_hashes(mm_data, hash_lib) + + # Extract optional UUIDs (can be None, or dict with same structure as mm_data) + mm_uuids = inputs.get('multi_modal_uuids', None) + + mm_hashes, mm_uuid_list = apply_mm_hashes(mm_data, mm_uuids, hash_lib) prompt_token_ids, extra_processed_inputs = input_processor( inputs, sampling_params) @@ -698,15 +706,15 @@ def create_input_processor_with_hash( extra_processed_inputs["multimodal_data"][ "special_token_offsets"] = start_special_token_positions # flatten the hashes from dict to a single list - mm_hashes = [h for hashes in mm_hashes.values() for h in hashes] - validate_mm_inputs(prompt_token_ids, mm_hashes, start_positions, + mm_hashes_flat = [h for hashes in mm_hashes.values() for h in hashes] + validate_mm_inputs(prompt_token_ids, mm_hashes_flat, start_positions, num_mm_tokens) - mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes + mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes_flat ] # nested list w/ multiple int32 per hash extra_processed_inputs[ "multimodal_input"] = MultimodalInput.from_components( - mm_hashes_int32, start_positions, num_mm_tokens) + mm_hashes_int32, start_positions, num_mm_tokens, mm_uuid_list) return prompt_token_ids, extra_processed_inputs def input_processor_wrapper( diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py index abdb45d1f5..88f5aeb684 100644 --- a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py +++ b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py @@ -181,6 +181,361 @@ def model_dir(request, tmp_path_factory: pytest.TempPathFactory) -> Path: return request.param +@pytest.mark.parametrize( + "use_uuids,expected_hash_type", + [ + # Without UUIDs: mm_key hash should be a 64-char hex string + (False, "hex"), + # With UUIDs: mm_key hash should be the original UUID string + (True, "uuid"), + ]) +def test_kv_event_mm_keys_with_uuid(use_uuids, expected_hash_type): + """Test mm_keys in KV cache events return UUID when provided. + + This test verifies that when multi_modal_uuids is provided: + 1. The KV cache event mm_keys 'hash' field contains the original UUID string + 2. Without UUIDs, the hash field contains a 64-char hex string + + The UUID feature allows users to provide stable identifiers for multimodal + items, which are returned in KV cache events for external cache management. 
+ """ + encoder_model_dir = _LLAVA_DIR + + max_tokens = 16 + free_gpu_memory_fraction = 0.2 + + # Use different images to generate different prompts + prompts = ["Describe the natural environment in the image."] + media = [example_images[0]] + + # Define UUIDs if testing with them + test_uuid = "my-test-image-uuid-12345" + + sampling_params = SamplingParams(max_tokens=max_tokens) + kv_cache_config = KvCacheConfig( + enable_block_reuse=True, + free_gpu_memory_fraction=free_gpu_memory_fraction, + event_buffer_max_size=1024, + ) + + llm = LLM(model=encoder_model_dir, + backend='pytorch', + kv_cache_config=kv_cache_config, + max_batch_size=1) + + # Load inputs with or without UUIDs + if use_uuids: + # Create inputs with multi_modal_uuids + inputs = _load_inputs_with_uuids(llm, prompts, media, [test_uuid]) + else: + inputs = _load_inputs(llm, prompts, media) + + with llm: + for inp in inputs: + _ = llm.generate([inp], sampling_params=sampling_params) + + # Wait for KV cache events to be dispatched asynchronously + time.sleep(0.5) + events = llm.get_kv_cache_events(50) + + # Extract mm_keys from stored events + mm_keys_found = [] + for event in events: + if event and event.get("data", {}).get("type") == "stored": + for block in event["data"].get("blocks", []): + mm_keys_found.extend(block.get("mm_keys", [])) + + # Verify mm_keys were found (multimodal model should have them) + assert len(mm_keys_found) > 0, "Expected mm_keys in stored events" + + # Verify the hash field matches expected type + for mm_key in mm_keys_found: + hash_value = mm_key["hash"] + if expected_hash_type == "uuid": + # Should be the original UUID string + assert hash_value == test_uuid, ( + f"Expected UUID '{test_uuid}', got '{hash_value}'") + else: + # Should be a 64-char hex string + assert len(hash_value) == 64, ( + f"Expected 64-char hex hash, got {len(hash_value)} chars") + # Verify it's valid hex (fromhex will raise ValueError if invalid) + bytes.fromhex(hash_value) + + +def _load_inputs_with_uuids(llm: LLM, prompts, media, uuids): + """Load inputs with multi_modal_uuids for testing. + + This function uses the same processing pipeline as _load_inputs but adds + multi_modal_uuids to the processed inputs. + """ + # Use the standard loader to get properly processed inputs with image tokens + inputs = _load_inputs(llm, prompts, media) + + # Add multi_modal_uuids to the processed inputs + for inp, uuid in zip(inputs, uuids): + inp["multi_modal_uuids"] = {"image": [uuid]} + + return inputs + + +@pytest.mark.parametrize( + "uuids,expected_patterns", + [ + # First image has UUID, second uses content hash + (["custom-uuid-first", None], ["custom-uuid-first", "hex"]), + # Both have UUIDs + (["uuid-img-a", "uuid-img-b"], ["uuid-img-a", "uuid-img-b"]), + # Both use content hash (None) + ([None, None], ["hex", "hex"]), + ]) +def test_kv_event_mm_keys_with_partial_uuids(uuids, expected_patterns): + """Test mm_keys with partial UUIDs (some items with UUID, some without). + + This test verifies the mixed UUID scenario where: + 1. Some multimodal items have user-provided UUIDs + 2. Other items fall back to content-based hashing + 3. KV cache events correctly return UUID or hex hash based on input + """ + encoder_model_dir = _LLAVA_DIR + + max_tokens = 16 + free_gpu_memory_fraction = 0.2 + + # Two different images with potentially mixed UUIDs + prompt = "Describe both images in detail." 
+ images = [example_images[0], example_images[1]] + + sampling_params = SamplingParams(max_tokens=max_tokens) + kv_cache_config = KvCacheConfig( + enable_block_reuse=True, + free_gpu_memory_fraction=free_gpu_memory_fraction, + event_buffer_max_size=1024, + ) + + llm = LLM(model=encoder_model_dir, + backend='pytorch', + kv_cache_config=kv_cache_config, + max_batch_size=1) + + # Load input using the multimodal input loader directly for multiple images per prompt + config_path = os.path.join(llm._hf_model_dir, 'config.json') + with open(config_path, 'r') as f: + model_config = json.load(f) + model_type = model_config['model_type'] + + inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer, + model_dir=llm._hf_model_dir, + model_type=model_type, + modality="multiple_image", + prompts=[prompt], + media=[images], + image_data_format="pt") + + # Add multi_modal_uuids to the processed input + inputs[0]["multi_modal_uuids"] = {"image": uuids} + inp = inputs[0] + + with llm: + _ = llm.generate([inp], sampling_params=sampling_params) + + # Wait for KV cache events to be dispatched asynchronously + time.sleep(0.5) + events = llm.get_kv_cache_events(50) + + # Collect all unique mm_key hashes from stored events + mm_key_hashes = set() + for event in events: + if event and event.get("data", {}).get("type") == "stored": + for block in event["data"].get("blocks", []): + if block.get("mm_keys"): + for mm_key in block["mm_keys"]: + mm_key_hashes.add(mm_key["hash"]) + + # Verify we got mm_keys + assert len(mm_key_hashes) > 0, "Expected mm_keys in stored events" + + # Verify each expected pattern appears in the results + for pattern in expected_patterns: + if pattern == "hex": + # Should find at least one 64-char hex string + hex_found = any( + len(h) == 64 and all(c in '0123456789abcdef' for c in h) + for h in mm_key_hashes) + assert hex_found, f"Expected hex hash pattern but got: {mm_key_hashes}" + else: + # Should find the exact UUID string + assert pattern in mm_key_hashes, ( + f"Expected UUID '{pattern}' in mm_keys, got: {mm_key_hashes}") + + +def test_kv_event_mm_keys_with_uuid_multiple_prompts(): + """Test mm_keys with UUIDs across multiple prompts, each with its own image. + + This test verifies that when multiple prompts are processed, each with its own + multimodal data and UUID: + 1. Each prompt's mm_keys correctly return the associated UUID + 2. Different UUIDs are preserved for different prompts + 3. 
KV cache events correctly associate UUIDs with their respective blocks + """ + encoder_model_dir = _LLAVA_DIR + + max_tokens = 16 + free_gpu_memory_fraction = 0.2 + + # Multiple prompts, each with its own image and UUID + prompts = [ + "Describe the natural environment in the image.", + "What objects can you see in the image?", + "Describe the weather in the image.", + ] + media = [example_images[0], example_images[1], example_images[2]] + uuids = ["uuid-image-seashore", "uuid-image-inpaint", "uuid-image-61"] + + sampling_params = SamplingParams(max_tokens=max_tokens) + kv_cache_config = KvCacheConfig( + enable_block_reuse=True, + free_gpu_memory_fraction=free_gpu_memory_fraction, + event_buffer_max_size=2048, + ) + + llm = LLM(model=encoder_model_dir, + backend='pytorch', + kv_cache_config=kv_cache_config, + max_batch_size=1) + + # Load inputs with UUIDs for each prompt + inputs = _load_inputs(llm, prompts, media) + + # Add multi_modal_uuids to each input + for inp, uuid in zip(inputs, uuids): + inp["multi_modal_uuids"] = {"image": [uuid]} + + with llm: + # Generate for each input separately + for inp in inputs: + _ = llm.generate([inp], sampling_params=sampling_params) + + # Wait for KV cache events to be dispatched asynchronously + time.sleep(0.5) + events = llm.get_kv_cache_events(50) + + # Collect all unique mm_key hashes from stored events + mm_key_hashes = set() + for event in events: + if event and event.get("data", {}).get("type") == "stored": + for block in event["data"].get("blocks", []): + if block.get("mm_keys"): + for mm_key in block["mm_keys"]: + mm_key_hashes.add(mm_key["hash"]) + + # Verify we got mm_keys + assert len(mm_key_hashes) > 0, "Expected mm_keys in stored events" + + # Verify each UUID appears in the results + for uuid in uuids: + assert uuid in mm_key_hashes, ( + f"Expected UUID '{uuid}' in mm_keys, got: {mm_key_hashes}") + + +def test_kv_event_mm_keys_with_very_long_uuid(): + """Test mm_keys with UUIDs that exceed 64 bytes. + + This test verifies that the system correctly handles UUIDs that are longer + than the typical 64-character hex hash representation: + 1. Very long UUIDs (>64 bytes) are preserved and returned correctly + 2. No truncation or corruption occurs for long UUID strings + 3. The full UUID is returned in KV cache events + """ + encoder_model_dir = _LLAVA_DIR + + max_tokens = 16 + free_gpu_memory_fraction = 0.2 + + prompt = "Describe the natural environment in the image." 
+ + # Create UUIDs of varying lengths, including very long ones + # Normal UUID (36 chars): standard format + # Medium UUID (80 chars): exceeds 64-byte hash representation + # Very long UUID (200+ chars): stress test for string handling + long_uuid_80 = "sku-product-image-" + "a" * 62 # 80 chars total + very_long_uuid_200 = ( + "enterprise-asset-management-system/region/us-east-1/bucket/media-assets/" + "category/electronics/subcategory/smartphones/brand/example-brand/" + "product-line/flagship-series/sku/SKU-2024-FLAGSHIP-PRO-MAX-256GB-MIDNIGHT-BLACK" + ) # ~200 chars + + # Use 2 images with different long UUIDs + images = [example_images[0], example_images[1]] + uuids = [long_uuid_80, very_long_uuid_200] + + sampling_params = SamplingParams(max_tokens=max_tokens) + kv_cache_config = KvCacheConfig( + enable_block_reuse=True, + free_gpu_memory_fraction=free_gpu_memory_fraction, + event_buffer_max_size=1024, + ) + + llm = LLM(model=encoder_model_dir, + backend='pytorch', + kv_cache_config=kv_cache_config, + max_batch_size=1) + + # Load input using the multimodal input loader for multiple images + config_path = os.path.join(llm._hf_model_dir, 'config.json') + with open(config_path, 'r') as f: + model_config = json.load(f) + model_type = model_config['model_type'] + + inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer, + model_dir=llm._hf_model_dir, + model_type=model_type, + modality="multiple_image", + prompts=[prompt], + media=[images], + image_data_format="pt") + + # Add very long UUIDs + inputs[0]["multi_modal_uuids"] = {"image": uuids} + inp = inputs[0] + + with llm: + _ = llm.generate([inp], sampling_params=sampling_params) + + # Wait for KV cache events to be dispatched asynchronously + time.sleep(0.5) + events = llm.get_kv_cache_events(50) + + # Collect all unique mm_key hashes from stored events + mm_key_hashes = set() + for event in events: + if event and event.get("data", {}).get("type") == "stored": + for block in event["data"].get("blocks", []): + if block.get("mm_keys"): + for mm_key in block["mm_keys"]: + mm_key_hashes.add(mm_key["hash"]) + + # Verify we got mm_keys + assert len(mm_key_hashes) > 0, "Expected mm_keys in stored events" + + # Verify the 80-char UUID is present and not truncated + assert long_uuid_80 in mm_key_hashes, ( + f"Expected 80-char UUID '{long_uuid_80}' in mm_keys, got: {mm_key_hashes}" + ) + + # Verify the 200-char UUID is present and not truncated + assert very_long_uuid_200 in mm_key_hashes, ( + f"Expected 200-char UUID '{very_long_uuid_200}' in mm_keys, got: {mm_key_hashes}" + ) + + # Verify the UUIDs are exactly as provided (no truncation) + for uuid in uuids: + matching = [h for h in mm_key_hashes if h == uuid] + assert len(matching) == 1, ( + f"UUID '{uuid}' (len={len(uuid)}) should appear exactly once, " + f"found {len(matching)} times in {mm_key_hashes}") + + @pytest.fixture(scope="module", params=[False, True]) def pd_disagg(request) -> bool: return request.param diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py index 884247f91c..7139db6c92 100644 --- a/tests/unittest/bindings/test_executor_bindings.py +++ b/tests/unittest/bindings/test_executor_bindings.py @@ -1014,6 +1014,64 @@ def test_multimodal_input(): assert config.multimodal_hashes == multimodal_hashes assert config.multimodal_positions == multimodal_positions assert config.multimodal_lengths == multimodal_lengths + # Default value for multimodal_uuids should be None + assert config.multimodal_uuids is None + 
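For reference alongside the binding tests below, a minimal usage sketch of the extended constructor; it mirrors how `KvCacheCreator` builds the object in `_util.py`, assumes the import path used by these tests, and uses dummy hash words (real values come from `hexdigest_to_int32`).

```python
# Sketch: build the executor binding directly with an optional UUID list.
import tensorrt_llm.bindings.executor as trtllm

mm_input = trtllm.MultimodalInput(
    multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]],  # one 32-byte digest as 8 int32 words
    multimodal_positions=[10],                     # token start offset of the item
    multimodal_lengths=[50],                       # number of multimodal tokens
    multimodal_uuids=["sku-image-001"],            # omit or pass None to keep content hashes
)
assert mm_input.multimodal_uuids == ["sku-image-001"]
```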
+ +@pytest.mark.parametrize( + "multimodal_uuids,expected_uuids", + [ + # Test with all UUIDs provided + (["sku-image-001", "sku-image-002"], ["sku-image-001", "sku-image-002"] + ), + # Test with partial UUIDs (some None) + (["sku-image-001", None], ["sku-image-001", None]), + # Test with empty list of UUIDs + ([], []), + # Test with None (default) + (None, None), + ], + ids=["all_uuids", "partial_uuids", "empty_list", "none_default"]) +def test_multimodal_input_with_uuids(multimodal_uuids, expected_uuids): + """Test MultimodalInput with user-provided UUIDs.""" + multimodal_hashes = [[1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]] + multimodal_positions = [10, 100] + multimodal_lengths = [50, 60] + + config = trtllm.MultimodalInput(multimodal_hashes, multimodal_positions, + multimodal_lengths, multimodal_uuids) + assert config.multimodal_hashes == multimodal_hashes + assert config.multimodal_positions == multimodal_positions + assert config.multimodal_lengths == multimodal_lengths + assert config.multimodal_uuids == expected_uuids + + +def test_multimodal_input_pickle_with_uuids(): + """Test pickling and unpickling of MultimodalInput with UUIDs.""" + multimodal_hashes = [[1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]] + multimodal_positions = [10, 100] + multimodal_lengths = [50, 60] + multimodal_uuids = ["test-uuid-1", None] + + config = trtllm.MultimodalInput(multimodal_hashes, multimodal_positions, + multimodal_lengths, multimodal_uuids) + + # Pickle and unpickle + pickled = pickle.dumps(config) + restored = pickle.loads(pickled) + + assert restored.multimodal_hashes == multimodal_hashes + assert restored.multimodal_positions == multimodal_positions + assert restored.multimodal_lengths == multimodal_lengths + assert restored.multimodal_uuids == multimodal_uuids + + # Test with None UUIDs + config_no_uuids = trtllm.MultimodalInput(multimodal_hashes, + multimodal_positions, + multimodal_lengths) + pickled_no_uuids = pickle.dumps(config_no_uuids) + restored_no_uuids = pickle.loads(pickled_no_uuids) + assert restored_no_uuids.multimodal_uuids is None def test_mrope_config(): diff --git a/tests/unittest/llmapi/test_llm_kv_cache_events.py b/tests/unittest/llmapi/test_llm_kv_cache_events.py index ee5da20c43..6e1793bf12 100644 --- a/tests/unittest/llmapi/test_llm_kv_cache_events.py +++ b/tests/unittest/llmapi/test_llm_kv_cache_events.py @@ -109,23 +109,24 @@ def test_kv_cache_event_data_serialization(): def test_mm_keys_serialization(): """Test serialization of multimodal keys (mm_keys) in KV cache events.""" - # Test _mm_key_to_json with a mock mm_key tuple (bytes, int) - # MmKey from C++ is converted to (bytes, int) tuple by pybind11 + # Test _mm_key_to_json with a mock mm_key tuple (bytes, int, uuid) + # MmKey from C++ is converted to (bytes, int, optional) tuple by pybind11 mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24 # 32 bytes mock_offset = 42 - mock_mm_key = (mock_hash, mock_offset) + # New format: (hash, offset, uuid) - uuid is None for content-hashed items + mock_mm_key = (mock_hash, mock_offset, None) result = KVCacheEventSerializer._mm_key_to_json(mock_mm_key) assert result["type"] == "mm_key" assert result["start_offset"] == 42 - # Hash should be converted to hex string + # Hash should be converted to hex string when UUID is None assert result["hash"] == "0102030405060708" + "00" * 24 assert len(result["hash"]) == 64 # 32 bytes = 64 hex chars # Test with different hash values mock_hash2 = bytes(range(32)) # 0x00 to 0x1f - mock_mm_key2 = (mock_hash2, 100) + 
mock_mm_key2 = (mock_hash2, 100, None) result2 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key2) assert result2["type"] == "mm_key" @@ -136,10 +137,10 @@ def test_mm_keys_serialization(): def test_mm_keys_deserialization(): """Test deserialization of mm_keys JSON back to 32-byte hash.""" - # Test case 1: Simple hash pattern + # Test case 1: Simple hash pattern (no UUID) mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24 # 32 bytes mock_offset = 42 - mock_mm_key = (mock_hash, mock_offset) + mock_mm_key = (mock_hash, mock_offset, None) # New format with None UUID # Serialize to JSON json_result = KVCacheEventSerializer._mm_key_to_json(mock_mm_key) @@ -155,7 +156,7 @@ def test_mm_keys_deserialization(): # Test case 2: Sequential bytes 0x00 to 0x1f mock_hash2 = bytes(range(32)) mock_offset2 = 100 - mock_mm_key2 = (mock_hash2, mock_offset2) + mock_mm_key2 = (mock_hash2, mock_offset2, None) json_result2 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key2) recovered_hash2 = bytes.fromhex(json_result2["hash"]) @@ -167,7 +168,7 @@ def test_mm_keys_deserialization(): # Test case 3: All 0xFF bytes mock_hash3 = b'\xff' * 32 mock_offset3 = 255 - mock_mm_key3 = (mock_hash3, mock_offset3) + mock_mm_key3 = (mock_hash3, mock_offset3, None) json_result3 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key3) recovered_hash3 = bytes.fromhex(json_result3["hash"]) @@ -179,7 +180,7 @@ def test_mm_keys_deserialization(): # Test case 4: Random-like pattern mock_hash4 = bytes([0xde, 0xad, 0xbe, 0xef] + [0xca, 0xfe] * 14) mock_offset4 = 1024 - mock_mm_key4 = (mock_hash4, mock_offset4) + mock_mm_key4 = (mock_hash4, mock_offset4, None) json_result4 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key4) recovered_hash4 = bytes.fromhex(json_result4["hash"]) @@ -188,6 +189,274 @@ def test_mm_keys_deserialization(): assert len(recovered_hash4) == 32 +def test_mm_key_with_uuid(): + """Test _mm_key_to_json returns UUID when provided in the tuple.""" + # Create a mock mm_key with new format (hash, offset, uuid) + mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24 # 32 bytes + mock_offset = 42 + expected_hash = "0102030405060708" + "00" * 24 + + # Test 1: Without UUID (None), should return hex hash + mock_mm_key_no_uuid = (mock_hash, mock_offset, None) + result_no_uuid = KVCacheEventSerializer._mm_key_to_json(mock_mm_key_no_uuid) + assert result_no_uuid["hash"] == expected_hash + assert result_no_uuid["start_offset"] == 42 + + # Test 2: With UUID in tuple, should return UUID directly + test_uuid = "my-custom-image-uuid" + mock_mm_key_with_uuid = (mock_hash, mock_offset, test_uuid) + result_with_uuid = KVCacheEventSerializer._mm_key_to_json( + mock_mm_key_with_uuid) + assert result_with_uuid["hash"] == test_uuid + assert result_with_uuid["start_offset"] == 42 + + # Test 3: Backward compatibility - old format (2 elements) should return hex hash + mock_mm_key_old_format = (mock_hash, mock_offset) + result_old_format = KVCacheEventSerializer._mm_key_to_json( + mock_mm_key_old_format) + assert result_old_format["hash"] == expected_hash + + +def test_apply_mm_hashes_with_uuids(): + """Test apply_mm_hashes with user-provided UUIDs.""" + import torch + + from tensorrt_llm.inputs.multimodal import apply_mm_hashes + + # Create mock multimodal data - use fixed seed for reproducibility + torch.manual_seed(42) + mock_image1 = torch.randn(3, 224, 224) + mock_image2 = torch.randn(3, 224, 224) + mm_data = {"image": [mock_image1, mock_image2]} + + # Test without UUIDs - should use content-only hashing + 
hashes_no_uuid, uuids_no_uuid = apply_mm_hashes(mm_data) + assert len(hashes_no_uuid["image"]) == 2 + assert all(len(h) == 64 for h in hashes_no_uuid["image"]) + assert uuids_no_uuid is None + + # Test with partial UUIDs (first has UUID, second uses content-only hash) + mm_uuids = {"image": ["sku-1234-a", None]} + hashes_partial, uuids_partial = apply_mm_hashes(mm_data, mm_uuids) + + assert len(hashes_partial["image"]) == 2 + # First hash should be combined UUID+content (different from content-only) + assert len(hashes_partial["image"][0]) == 64 + assert hashes_partial["image"][0] != hashes_no_uuid["image"][ + 0] # UUID changes hash + # Second hash should be content-only (same as without UUID) + assert hashes_partial["image"][1] == hashes_no_uuid["image"][1] + # UUIDs list should have the UUID and None + assert uuids_partial == ["sku-1234-a", None] + + # Test with all UUIDs + mm_uuids_all = {"image": ["sku-1234-a", "sku-1234-b"]} + hashes_all, uuids_all = apply_mm_hashes(mm_data, mm_uuids_all) + + assert len(hashes_all["image"]) == 2 + assert all(len(h) == 64 for h in hashes_all["image"]) + # Both hashes should differ from content-only hashes + assert hashes_all["image"][0] != hashes_no_uuid["image"][0] + assert hashes_all["image"][1] != hashes_no_uuid["image"][1] + # Different UUIDs with different content should produce different hashes + assert hashes_all["image"][0] != hashes_all["image"][1] + assert uuids_all == ["sku-1234-a", "sku-1234-b"] + + +def test_apply_mm_hashes_uuid_content_combined(): + """Test that UUID + content hashing ensures cache correctness. + + This test verifies the key properties of combined UUID+content hashing: + 1. Same UUID + same content = same hash (cache hit expected) + 2. Same UUID + different content = different hash (no incorrect cache hit) + 3. 
Different UUID + same content = different hash (user isolation) + """ + import torch + + from tensorrt_llm.inputs.multimodal import apply_mm_hashes + + # Create identical images + torch.manual_seed(42) + image_a = torch.randn(3, 224, 224) + image_a_copy = image_a.clone() # Identical content + + # Create a different image + torch.manual_seed(123) + image_b = torch.randn(3, 224, 224) + + # Property 1: Same UUID + same content = same hash + mm_data_a = {"image": [image_a]} + mm_data_a_copy = {"image": [image_a_copy]} + mm_uuids = {"image": ["user-123-img"]} + + hashes_a, _ = apply_mm_hashes(mm_data_a, mm_uuids) + hashes_a_copy, _ = apply_mm_hashes(mm_data_a_copy, mm_uuids) + assert hashes_a["image"][0] == hashes_a_copy["image"][0], \ + "Same UUID + same content should produce identical hashes" + + # Property 2: Same UUID + different content = different hash + mm_data_b = {"image": [image_b]} + hashes_b, _ = apply_mm_hashes(mm_data_b, mm_uuids) + assert hashes_a["image"][0] != hashes_b["image"][0], \ + "Same UUID + different content must produce different hashes" + + # Property 3: Different UUID + same content = different hash (user isolation) + mm_uuids_user2 = {"image": ["user-456-img"]} + hashes_user2, _ = apply_mm_hashes(mm_data_a, mm_uuids_user2) + assert hashes_a["image"][0] != hashes_user2["image"][0], \ + "Different UUID + same content should produce different hashes" + + +def test_int32_hexdigest_roundtrip(): + """Test that hexdigest_to_int32 and int32_to_hexdigest are inverses.""" + from tensorrt_llm.inputs.multimodal import (hexdigest_to_int32, + int32_to_hexdigest) + + # Test with various hash patterns + test_hashes = [ + "0000000000000000000000000000000000000000000000000000000000000000", + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20", + "deadbeefcafebabefeedfacebadc0ffedeadbeefcafebabefeedfacebadc0ffe", + ] + + for original_hex in test_hashes: + int32_values = hexdigest_to_int32(original_hex) + recovered_hex = int32_to_hexdigest(int32_values) + assert recovered_hex == original_hex, f"Roundtrip failed for {original_hex}" + + +def test_multimodal_input_dataclass_with_uuids(): + """Test Python MultimodalInput dataclass with UUIDs.""" + from tensorrt_llm.inputs.multimodal import MultimodalInput + + # Test with all UUIDs + mm_input = MultimodalInput(multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]], + multimodal_positions=[10], + multimodal_lengths=[50], + multimodal_uuids=["test-uuid-123"]) + + assert mm_input.multimodal_uuids == ["test-uuid-123"] + + # Test with partial UUIDs (some None) + mm_input_partial = MultimodalInput( + multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]], + multimodal_positions=[10, 100], + multimodal_lengths=[50, 60], + multimodal_uuids=["sku-001", None]) + + assert mm_input_partial.multimodal_uuids == ["sku-001", None] + + # Test with None UUIDs (default) + mm_input_no_uuids = MultimodalInput( + multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]], + multimodal_positions=[10], + multimodal_lengths=[50]) + + assert mm_input_no_uuids.multimodal_uuids is None + + +def test_multimodal_input_dataclass_uuid_validation(): + """Test MultimodalInput validation for multimodal_uuids field.""" + from tensorrt_llm.inputs.multimodal import MultimodalInput + + # Test UUID list length mismatch + with pytest.raises(ValueError, match="multimodal_uuids length"): + MultimodalInput(multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8], + [8, 7, 6, 5, 4, 3, 2, 1]], + multimodal_positions=[10, 
100], + multimodal_lengths=[50, 60], + multimodal_uuids=["only-one-uuid"]) + + # Test invalid UUID type + with pytest.raises(TypeError, match="must be a string or None"): + MultimodalInput(multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]], + multimodal_positions=[10], + multimodal_lengths=[50], + multimodal_uuids=[123]) # Integer instead of string + + # Test invalid multimodal_uuids type (not a list) + with pytest.raises(TypeError, match="multimodal_uuids must be a list"): + MultimodalInput(multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]], + multimodal_positions=[10], + multimodal_lengths=[50], + multimodal_uuids="not-a-list") + + +def test_multimodal_input_from_components_with_uuids(): + """Test MultimodalInput.from_components factory method with UUIDs.""" + from tensorrt_llm.inputs.multimodal import MultimodalInput + + mm_hashes = [[1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]] + mm_positions = [10, 100] + mm_lengths = [50, 60] + mm_uuids = ["uuid-a", "uuid-b"] + + mm_input = MultimodalInput.from_components(mm_hashes, mm_positions, + mm_lengths, mm_uuids) + + assert mm_input.multimodal_hashes == mm_hashes + assert mm_input.multimodal_positions == mm_positions + assert mm_input.multimodal_lengths == mm_lengths + assert mm_input.multimodal_uuids == mm_uuids + + # Test without UUIDs + mm_input_no_uuids = MultimodalInput.from_components(mm_hashes, mm_positions, + mm_lengths) + assert mm_input_no_uuids.multimodal_uuids is None + + +def test_apply_mm_hashes_uuid_length_mismatch(): + """Test apply_mm_hashes raises error on UUID list length mismatch.""" + import torch + + from tensorrt_llm.inputs.multimodal import apply_mm_hashes + + mock_image1 = torch.randn(3, 224, 224) + mock_image2 = torch.randn(3, 224, 224) + mm_data = {"image": [mock_image1, mock_image2]} + + # Mismatched UUID list length + mm_uuids_wrong_length = {"image": ["only-one-uuid"]} # Should have 2 + + with pytest.raises(ValueError, + match="UUID list length.*doesn't match.*data items"): + apply_mm_hashes(mm_data, mm_uuids_wrong_length) + + +def test_apply_mm_hashes_multiple_modalities(): + """Test apply_mm_hashes with multiple modalities and UUIDs.""" + import torch + + from tensorrt_llm.inputs.multimodal import apply_mm_hashes + + # Create mock data for multiple modalities + torch.manual_seed(42) + mock_image = torch.randn(3, 224, 224) + mock_video_frames = [torch.randn(3, 224, 224) for _ in range(4)] + + mm_data = {"image": [mock_image], "video": [mock_video_frames]} + + # First, get content-only hashes (without UUIDs) + hashes_no_uuid, _ = apply_mm_hashes(mm_data) + + # UUIDs for each modality + mm_uuids = {"image": ["img-uuid-001"], "video": ["vid-uuid-001"]} + + hashes, uuids_list = apply_mm_hashes(mm_data, mm_uuids) + + # Check hashes are 64-char hex strings (combined UUID+content hashes) + assert len(hashes["image"][0]) == 64 + assert len(hashes["video"][0]) == 64 + + # Verify UUIDs change the hashes (UUID+content != content-only) + assert hashes["image"][0] != hashes_no_uuid["image"][0] + assert hashes["video"][0] != hashes_no_uuid["video"][0] + + # Check flattened UUID list (order may vary based on dict iteration) + assert set(uuids_list) == {"img-uuid-001", "vid-uuid-001"} + + def test_mm_keys_in_stored_events(): """Test that mm_keys field is present in stored block events.""" llm = create_llm()