mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[TRTLLM-9601][feat] Expose mmKeys for multimodal to integrate with dynamo. (#9604)
Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
This commit is contained in:
parent
9a1750c8f9
commit
f21e2b3329
@ -78,9 +78,7 @@ using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
|
||||
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
|
||||
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
|
||||
using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;
|
||||
|
||||
// Type alias for multimodal hash key (hash array + start offset)
|
||||
using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
|
||||
using MmKey = tensorrt_llm::executor::MmKey;
|
||||
|
||||
template <typename T>
|
||||
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
|
||||
@ -325,6 +323,8 @@ public:
|
||||
|
||||
size_t getHash() const;
|
||||
|
||||
std::vector<MmKey> getExtraKeys() const;
|
||||
|
||||
private:
|
||||
// Linear ID of block independent of pool
|
||||
IdType mBlockId;
|
||||
|
||||
@ -47,6 +47,12 @@ class BaseKVCacheManager;
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
// Mmkey is used in KVCacheBlock when multimodal data presents in a block.
|
||||
// Type alias for hash array + start offset at per-block granularity.
|
||||
// This differs from the per-request level multimodal hash in MultimodalInput.
|
||||
using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
|
||||
|
||||
/// @brief Version of TRT-LLM
|
||||
char const* version() noexcept;
|
||||
|
||||
@ -1691,12 +1697,14 @@ struct KVCacheStoredBlockData
|
||||
{
|
||||
|
||||
KVCacheStoredBlockData(IdType blockHash, tensorrt_llm::runtime::VecUniqueTokens tokens,
|
||||
std::optional<tensorrt_llm::runtime::LoraTaskIdType> loraId, SizeType32 cacheLevel, SizeType32 priority)
|
||||
std::optional<tensorrt_llm::runtime::LoraTaskIdType> loraId, SizeType32 cacheLevel, SizeType32 priority,
|
||||
std::vector<MmKey> mmKeys = {})
|
||||
: blockHash{blockHash}
|
||||
, tokens{std::move(tokens)}
|
||||
, loraId{loraId}
|
||||
, cacheLevel{cacheLevel}
|
||||
, priority{priority}
|
||||
, mmKeys{std::move(mmKeys)}
|
||||
{
|
||||
}
|
||||
|
||||
@ -1710,6 +1718,8 @@ struct KVCacheStoredBlockData
|
||||
SizeType32 cacheLevel;
|
||||
/// @brief The priority of the block
|
||||
SizeType32 priority;
|
||||
/// @brief The multimodal keys of the block
|
||||
std::vector<MmKey> mmKeys;
|
||||
};
|
||||
|
||||
struct KVCacheStoredData
|
||||
|
||||
@ -102,7 +102,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
|
||||
for (auto const& block : blocks)
|
||||
{
|
||||
data.blocks.emplace_back(block->getHash(), block->getUniqueTokens(), block->getBlockKey().loraTaskId,
|
||||
block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority());
|
||||
block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority(), block->getExtraKeys());
|
||||
}
|
||||
|
||||
enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
|
||||
|
||||
@ -284,6 +284,11 @@ tk::KVCacheIndex::UnderlyingType KVCacheBlock::getMemoryPoolBlockIndex() const
|
||||
return mMemoryPoolBlockIndex.get();
|
||||
}
|
||||
|
||||
std::vector<MmKey> KVCacheBlock::getExtraKeys() const
|
||||
{
|
||||
return mBlockKey.extraKeys;
|
||||
}
|
||||
|
||||
bool KVCacheBlock::isPrimary() const
|
||||
{
|
||||
return mMemoryPoolBlockIndex.isPrimary();
|
||||
|
||||
@ -2333,6 +2333,7 @@ size_t Serialization::serializedSize(KVCacheStoredBlockData const& data)
|
||||
totalSize += su::serializedSize(data.loraId);
|
||||
totalSize += su::serializedSize(data.cacheLevel);
|
||||
totalSize += su::serializedSize(data.priority);
|
||||
totalSize += su::serializedSize(data.mmKeys);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
@ -2343,6 +2344,7 @@ void Serialization::serialize(KVCacheStoredBlockData const& data, std::ostream&
|
||||
su::serialize(data.loraId, os);
|
||||
su::serialize(data.cacheLevel, os);
|
||||
su::serialize(data.priority, os);
|
||||
su::serialize(data.mmKeys, os);
|
||||
}
|
||||
|
||||
KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::istream& is)
|
||||
@ -2352,8 +2354,9 @@ KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::ist
|
||||
auto loraId = su::deserialize<std::optional<tensorrt_llm::runtime::LoraTaskIdType>>(is);
|
||||
auto cacheLevel = su::deserialize<SizeType32>(is);
|
||||
auto priority = su::deserialize<SizeType32>(is);
|
||||
auto mmKeys = su::deserialize<std::vector<tensorrt_llm::batch_manager::kv_cache_manager::MmKey>>(is);
|
||||
|
||||
return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority};
|
||||
return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority, mmKeys};
|
||||
}
|
||||
|
||||
// KVcacheRemovedData
|
||||
|
||||
@ -221,7 +221,22 @@ void initBindings(nb::module_& m)
|
||||
.def_ro("tokens", &tle::KVCacheStoredBlockData::tokens)
|
||||
.def_ro("lora_id", &tle::KVCacheStoredBlockData::loraId)
|
||||
.def_ro("cache_level", &tle::KVCacheStoredBlockData::cacheLevel)
|
||||
.def_ro("priority", &tle::KVCacheStoredBlockData::priority);
|
||||
.def_ro("priority", &tle::KVCacheStoredBlockData::priority)
|
||||
.def_prop_ro("mm_keys",
|
||||
[](tle::KVCacheStoredBlockData const& self)
|
||||
{
|
||||
// Convert std::vector<MmKey> to Python list of tuples (bytes, int)
|
||||
// MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
|
||||
nb::list result;
|
||||
for (auto const& mmKey : self.mmKeys)
|
||||
{
|
||||
auto const& hashArray = mmKey.first;
|
||||
auto offset = mmKey.second;
|
||||
nb::bytes hashBytes(reinterpret_cast<char const*>(hashArray.data()), hashArray.size());
|
||||
result.append(nb::make_tuple(hashBytes, offset));
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
nb::class_<tle::KVCacheStoredData>(executor_kv_cache, "KVCacheStoredData")
|
||||
.def_ro("parent_hash", &tle::KVCacheStoredData::parentHash)
|
||||
|
||||
@ -221,7 +221,22 @@ void initBindings(pybind11::module_& m)
|
||||
.def_readonly("tokens", &tle::KVCacheStoredBlockData::tokens)
|
||||
.def_readonly("lora_id", &tle::KVCacheStoredBlockData::loraId)
|
||||
.def_readonly("cache_level", &tle::KVCacheStoredBlockData::cacheLevel)
|
||||
.def_readonly("priority", &tle::KVCacheStoredBlockData::priority);
|
||||
.def_readonly("priority", &tle::KVCacheStoredBlockData::priority)
|
||||
.def_property_readonly("mm_keys",
|
||||
[](tle::KVCacheStoredBlockData const& self)
|
||||
{
|
||||
// Convert std::vector<MmKey> to Python list of tuples (bytes, int)
|
||||
// MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
|
||||
py::list result;
|
||||
for (auto const& mmKey : self.mmKeys)
|
||||
{
|
||||
auto const& hashArray = mmKey.first;
|
||||
auto offset = mmKey.second;
|
||||
py::bytes hashBytes(reinterpret_cast<char const*>(hashArray.data()), hashArray.size());
|
||||
result.append(py::make_tuple(hashBytes, offset));
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
py::class_<tle::KVCacheStoredData>(executor_kv_cache, "KVCacheStoredData")
|
||||
.def_readonly("parent_hash", &tle::KVCacheStoredData::parentHash)
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
|
||||
@ -90,6 +92,9 @@ def add_llm_args(parser):
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument("--tokens_per_block", type=int, default=32)
|
||||
parser.add_argument('--log_kv_cache_events',
|
||||
default=False,
|
||||
action='store_true')
|
||||
|
||||
# Runtime
|
||||
parser.add_argument('--disable_overlap_scheduler',
|
||||
@ -190,7 +195,7 @@ def setup_llm(args, **kwargs):
|
||||
free_gpu_memory_fraction=args.kv_cache_fraction,
|
||||
dtype=args.kv_cache_dtype,
|
||||
tokens_per_block=args.tokens_per_block,
|
||||
)
|
||||
event_buffer_max_size=1024 if args.log_kv_cache_events else 0)
|
||||
|
||||
spec_decode_algo = args.spec_decode_algo.upper(
|
||||
) if args.spec_decode_algo is not None else None
|
||||
@ -355,6 +360,13 @@ def main():
|
||||
f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
|
||||
)
|
||||
|
||||
if args.log_kv_cache_events:
|
||||
time.sleep(1) # Wait for events to be dispatched
|
||||
events = llm.get_kv_cache_events(5)
|
||||
print("=== KV_CACHE_EVENTS_START ===")
|
||||
print(json.dumps(events, indent=2))
|
||||
print("=== KV_CACHE_EVENTS_END ===")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from quickstart_advanced import add_llm_args, setup_llm
|
||||
|
||||
@ -264,6 +265,14 @@ def main():
|
||||
print(
|
||||
f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
|
||||
)
|
||||
|
||||
if args.log_kv_cache_events:
|
||||
time.sleep(1) # Wait for events to be dispatched
|
||||
events = llm.get_kv_cache_events(5)
|
||||
print("=== KV_CACHE_EVENTS_START ===")
|
||||
print(json.dumps(events, indent=2))
|
||||
print("=== KV_CACHE_EVENTS_END ===")
|
||||
|
||||
return
|
||||
|
||||
# Original single-turn processing logic
|
||||
@ -272,6 +281,7 @@ def main():
|
||||
args.prompt = example_medias_and_prompts[args.modality]["prompt"]
|
||||
if args.media is None:
|
||||
args.media = example_medias_and_prompts[args.modality]["media"]
|
||||
|
||||
inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
|
||||
model_dir=str(llm._hf_model_dir),
|
||||
model_type=model_type,
|
||||
@ -281,7 +291,6 @@ def main():
|
||||
image_data_format=image_format,
|
||||
num_frames=args.num_frames,
|
||||
device=args.device)
|
||||
|
||||
lora_request = None
|
||||
if args.load_lora:
|
||||
lora_request = model_class.lora_request(len(inputs), args.modality,
|
||||
@ -306,6 +315,13 @@ def main():
|
||||
if args.logprobs:
|
||||
print(f"[{i}] Logprobs: {output.outputs[0].logprobs}")
|
||||
|
||||
if args.log_kv_cache_events:
|
||||
time.sleep(1) # Wait for events to be dispatched
|
||||
events = llm.get_kv_cache_events(5)
|
||||
print("=== KV_CACHE_EVENTS_START ===")
|
||||
print(json.dumps(events, indent=2))
|
||||
print("=== KV_CACHE_EVENTS_END ===")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -1117,7 +1117,9 @@ class KVCacheEventSerializer:
|
||||
"cache_level":
|
||||
data.cache_level,
|
||||
"priority":
|
||||
data.priority
|
||||
data.priority,
|
||||
"mm_keys":
|
||||
KVCacheEventSerializer._mm_keys_to_json(data)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@ -1153,6 +1155,30 @@ class KVCacheEventSerializer:
|
||||
"token_extra_id": data.token_extra_id
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _mm_key_to_json(data):
|
||||
# MmKey is a pair of (array<uint8_t, 32>, SizeType32)
|
||||
hash_array, start_offset = data
|
||||
|
||||
# Convert array to hex string
|
||||
hash_hex = ''.join(f'{b:02x}' for b in hash_array)
|
||||
return {
|
||||
"type": "mm_key",
|
||||
"hash": hash_hex,
|
||||
"start_offset": start_offset
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _mm_keys_to_json(data):
|
||||
# MmKeys is a list of MmKey
|
||||
if hasattr(data, 'mm_keys') and data.mm_keys:
|
||||
return [
|
||||
KVCacheEventSerializer._mm_key_to_json(mm_key)
|
||||
for mm_key in data.mm_keys
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def set_prometheus_multiproc_dir() -> object:
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266
|
||||
|
||||
@ -254,6 +254,7 @@ accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2 SKIP (https://nvbugs/5
|
||||
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype SKIP (https://nvbugs/5481075)
|
||||
accuracy/test_llm_api.py::TestPhi4MiniInstruct::test_fp8 SKIP (https://nvbugs/5465143, 5481206 WNF)
|
||||
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5488118)
|
||||
accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] SKIP (https://nvbugs/5738168)
|
||||
test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523)
|
||||
accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype SKIP (https://nvbugs/5520319)
|
||||
examples/test_llama.py::test_llm_llama_1gpu_fp8_kv_cache[llama-v2-7b-hf-bfloat16] SKIP (https://nvbugs/5527940)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@ -288,3 +289,85 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
|
||||
zip(ref_output.outputs, test_output.outputs)):
|
||||
assert ref_gen.text == test_gen.text, \
|
||||
f"Generated text doesn't match for output {i}, generation {j}:\nReference: {ref_gen.text!r}\nTest: {test_gen.text!r}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompts,expected_num_duplicates",
|
||||
[
|
||||
# Full reuse: same media + same prompts
|
||||
# All blocks are reused, thus no duplicates
|
||||
(["Describe the natural environment in the image."] * 2, 0),
|
||||
# Partial reuse: same media + different prompts
|
||||
# Prefix blocks are reused, thus 2 duplicates
|
||||
([
|
||||
"Describe the natural environment in the image.",
|
||||
"What objects can you see in the image?",
|
||||
"Describe the weather in the image.",
|
||||
], 2),
|
||||
])
|
||||
def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates,
|
||||
multimodal_model_config):
|
||||
"""Test mm_keys in KV cache events with cache reuse scenarios.
|
||||
|
||||
This test verifies:
|
||||
1. KV cache events contain mm_keys for multimodal blocks
|
||||
2. mm_keys have the expected structure (hash + start_offset)
|
||||
3. Cache reuse behavior based on media and prompts:
|
||||
- Same media + same prompts: full reuse (0 duplicate offsets)
|
||||
- Same media + different prompts: partial reuse (prefix blocks reused)
|
||||
"""
|
||||
encoder_model_dir = multimodal_model_config['model_dir']
|
||||
|
||||
max_tokens = 16
|
||||
free_gpu_memory_fraction = 0.6
|
||||
|
||||
# Use same image for all prompts
|
||||
media = [example_images[0]] * len(prompts)
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens)
|
||||
kv_cache_config = KvCacheConfig(
|
||||
enable_block_reuse=True,
|
||||
free_gpu_memory_fraction=free_gpu_memory_fraction,
|
||||
event_buffer_max_size=1024, # Enable KV cache events
|
||||
)
|
||||
|
||||
llm = LLM(model=encoder_model_dir,
|
||||
backend='pytorch',
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_batch_size=1)
|
||||
|
||||
config_path = os.path.join(llm._hf_model_dir, 'config.json')
|
||||
with open(config_path, 'r') as f:
|
||||
model_config = json.load(f)
|
||||
model_type = model_config['model_type']
|
||||
|
||||
inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
|
||||
model_dir=llm._hf_model_dir,
|
||||
model_type=model_type,
|
||||
modality="image",
|
||||
prompts=prompts,
|
||||
media=media,
|
||||
image_data_format="pt")
|
||||
|
||||
# Generate for each input separately to test KV cache reuse
|
||||
for inp in inputs:
|
||||
_ = llm.generate([inp], sampling_params=sampling_params)
|
||||
|
||||
time.sleep(0.5) # Wait for events to be dispatched
|
||||
events = llm.get_kv_cache_events(10)
|
||||
|
||||
# Extract mm_keys offsets from stored events
|
||||
mm_keys_offsets = []
|
||||
for event in events:
|
||||
if event and event.get("data", {}).get("type") == "stored":
|
||||
for block in event["data"].get("blocks", []):
|
||||
if block.get("mm_keys"):
|
||||
for mm_key in block["mm_keys"]:
|
||||
assert "hash" in mm_key, "mm_key should have 'hash' field"
|
||||
assert "start_offset" in mm_key, "mm_key should have 'start_offset' field"
|
||||
mm_keys_offsets.append(mm_key["start_offset"])
|
||||
|
||||
num_duplicates = len(mm_keys_offsets) - len(set(mm_keys_offsets))
|
||||
assert num_duplicates == expected_num_duplicates, (
|
||||
f"Expected {expected_num_duplicates} duplicate mm_keys offsets, "
|
||||
f"got {num_duplicates}. Offsets: {mm_keys_offsets}")
|
||||
|
||||
@ -93,6 +93,9 @@ def test_kv_cache_event_data_serialization():
|
||||
assert serialized_event[0]["data"]["parent_hash"] is None
|
||||
assert len(serialized_event[0]["data"]["blocks"]) == 1
|
||||
assert len(serialized_event[0]["data"]["blocks"][0]["tokens"]) == 4
|
||||
# Verify mm_keys field exists (empty for text-only requests)
|
||||
assert "mm_keys" in serialized_event[0]["data"]["blocks"][0]
|
||||
assert serialized_event[0]["data"]["blocks"][0]["mm_keys"] == []
|
||||
|
||||
req2 = create_llm_request(1, [1, 2, 3, 4, 5])
|
||||
kv_cache_manager.impl.add_sequence(req2.py_request_id, req2.prompt_len, 1,
|
||||
@ -104,6 +107,109 @@ def test_kv_cache_event_data_serialization():
|
||||
serialized_event = KVCacheEventSerializer.serialize(events)
|
||||
|
||||
|
||||
def test_mm_keys_serialization():
|
||||
"""Test serialization of multimodal keys (mm_keys) in KV cache events."""
|
||||
# Test _mm_key_to_json with a mock mm_key tuple (bytes, int)
|
||||
# MmKey from C++ is converted to (bytes, int) tuple by pybind11
|
||||
mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24 # 32 bytes
|
||||
mock_offset = 42
|
||||
mock_mm_key = (mock_hash, mock_offset)
|
||||
|
||||
result = KVCacheEventSerializer._mm_key_to_json(mock_mm_key)
|
||||
|
||||
assert result["type"] == "mm_key"
|
||||
assert result["start_offset"] == 42
|
||||
# Hash should be converted to hex string
|
||||
assert result["hash"] == "0102030405060708" + "00" * 24
|
||||
assert len(result["hash"]) == 64 # 32 bytes = 64 hex chars
|
||||
|
||||
# Test with different hash values
|
||||
mock_hash2 = bytes(range(32)) # 0x00 to 0x1f
|
||||
mock_mm_key2 = (mock_hash2, 100)
|
||||
result2 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key2)
|
||||
|
||||
assert result2["type"] == "mm_key"
|
||||
assert result2["start_offset"] == 100
|
||||
expected_hash = ''.join(f'{i:02x}' for i in range(32))
|
||||
assert result2["hash"] == expected_hash
|
||||
|
||||
|
||||
def test_mm_keys_deserialization():
|
||||
"""Test deserialization of mm_keys JSON back to 32-byte hash."""
|
||||
# Test case 1: Simple hash pattern
|
||||
mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24 # 32 bytes
|
||||
mock_offset = 42
|
||||
mock_mm_key = (mock_hash, mock_offset)
|
||||
|
||||
# Serialize to JSON
|
||||
json_result = KVCacheEventSerializer._mm_key_to_json(mock_mm_key)
|
||||
|
||||
# Deserialize hex string back to bytes
|
||||
recovered_hash = bytes.fromhex(json_result["hash"])
|
||||
|
||||
# Verify the recovered hash matches the original
|
||||
assert recovered_hash == mock_hash
|
||||
assert len(recovered_hash) == 32
|
||||
assert json_result["start_offset"] == mock_offset
|
||||
|
||||
# Test case 2: Sequential bytes 0x00 to 0x1f
|
||||
mock_hash2 = bytes(range(32))
|
||||
mock_offset2 = 100
|
||||
mock_mm_key2 = (mock_hash2, mock_offset2)
|
||||
|
||||
json_result2 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key2)
|
||||
recovered_hash2 = bytes.fromhex(json_result2["hash"])
|
||||
|
||||
assert recovered_hash2 == mock_hash2
|
||||
assert len(recovered_hash2) == 32
|
||||
assert json_result2["start_offset"] == mock_offset2
|
||||
|
||||
# Test case 3: All 0xFF bytes
|
||||
mock_hash3 = b'\xff' * 32
|
||||
mock_offset3 = 255
|
||||
mock_mm_key3 = (mock_hash3, mock_offset3)
|
||||
|
||||
json_result3 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key3)
|
||||
recovered_hash3 = bytes.fromhex(json_result3["hash"])
|
||||
|
||||
assert recovered_hash3 == mock_hash3
|
||||
assert len(recovered_hash3) == 32
|
||||
assert json_result3["hash"] == "ff" * 32
|
||||
|
||||
# Test case 4: Random-like pattern
|
||||
mock_hash4 = bytes([0xde, 0xad, 0xbe, 0xef] + [0xca, 0xfe] * 14)
|
||||
mock_offset4 = 1024
|
||||
mock_mm_key4 = (mock_hash4, mock_offset4)
|
||||
|
||||
json_result4 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key4)
|
||||
recovered_hash4 = bytes.fromhex(json_result4["hash"])
|
||||
|
||||
assert recovered_hash4 == mock_hash4
|
||||
assert len(recovered_hash4) == 32
|
||||
|
||||
|
||||
def test_mm_keys_in_stored_events():
|
||||
"""Test that mm_keys field is present in stored block events."""
|
||||
llm = create_llm()
|
||||
sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
|
||||
prompt = "Hello, my name is"
|
||||
|
||||
_ = llm.generate(prompt, sampling_params=sampling_params)
|
||||
|
||||
events = llm.get_kv_cache_events(5)
|
||||
|
||||
# Find stored events and verify mm_keys field
|
||||
for event in events:
|
||||
if event and event["data"]["type"] == "stored":
|
||||
blocks = event["data"]["blocks"]
|
||||
for block in blocks:
|
||||
# mm_keys should always be present (empty list for text-only)
|
||||
assert "mm_keys" in block
|
||||
assert isinstance(block["mm_keys"], list)
|
||||
# For text-only requests, mm_keys should be empty
|
||||
assert block["mm_keys"] == []
|
||||
|
||||
|
||||
def test_expected_kv_cache_events():
|
||||
llm = create_llm()
|
||||
sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user