TensorRT-LLM/tests/test_kv_cache_manager.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import torch

import tensorrt_llm
from tensorrt_llm.runtime.kv_cache_manager import (Block, BlocksManager,
                                                   GenerationSequence,
                                                   KVCacheManager)

class TestKVCacheManager(unittest.TestCase):

    _sizeof = {torch.float32: 4, torch.float16: 2, torch.int8: 1}

    def setUp(self):
        tensorrt_llm.logger.set_level('error')
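
    # Block keeps a per-block reference count and raw K/V pool pointers; the
    # test below exercises the add_link/remove_link bookkeeping and the
    # pointer getters.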
    def test_block(self):
        block = Block(block_idx=0, k_ptrs=[123321], v_ptrs=[321123])
        block.add_link()
        self.assertEqual(block.ref_count, 1)
        block.add_link()
        self.assertEqual(block.ref_count, 2)
        self.assertTrue(block.has_link())

        block.remove_link()
        self.assertEqual(block.ref_count, 1)
        block.remove_link()
        self.assertEqual(block.ref_count, 0)
        self.assertFalse(block.has_link())

        self.assertEqual(block.get_k_ptr(0), 123321)
        self.assertEqual(block.get_v_ptr(0), 321123)
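
    # GenerationSequence objects compare equal only when both seq_idx and
    # batch_idx match.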
    def test_sequence(self):
        seq = GenerationSequence(seq_idx=1, batch_idx=0)
        self.assertEqual(seq.get_batch_idx(), 0)
        self.assertEqual(seq.get_seq_idx(), 1)

        seq1 = GenerationSequence(seq_idx=1, batch_idx=1)
        seq2 = GenerationSequence(seq_idx=1, batch_idx=0)
        seq3 = GenerationSequence(seq_idx=0, batch_idx=0)
        self.assertNotEqual(seq, seq1)
        self.assertEqual(seq, seq2)
        self.assertNotEqual(seq, seq3)
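
    # Helper: allocate block_len blocks for every sequence; the tests size the
    # pool so that this consumes every free block.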
    def allocate_blocks(self, manager, sequences, block_len):
        for _ in range(block_len):
            for seq in sequences:
                self.assertTrue(manager.has_free_block())
                manager.allocate(seq)
        # All blocks should be allocated by now
        self.assertFalse(manager.has_free_block())
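
    # get_pointer_array returns a [batch, beam, 2 (K/V), max_blocks_per_seq]
    # tensor of raw pointers; allocate_blocks above hands blocks out in an
    # interleaved order, so block b of sequence s lands at linear block index
    # b * len(sequences) + s in the pool.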
    def verify_pointer_array(self,
                             manager,
                             sequences,
                             block_len,
                             total_blocks,
                             max_blocks_per_seq,
                             block_elts,
                             memory_pool,
                             pool_idx=0):
        pointers = manager.get_pointer_array(pool_idx, beam_width=1)
        self.assertEqual(
            pointers.shape,
            torch.Size([len(sequences), 1, 2, max_blocks_per_seq]))

        # Check that every K and V pointer matches its expected offset in the pool
        for seq in sequences:
            for block in range(block_len):
                linear_block_idx = (block * len(sequences) +
                                    seq.get_batch_idx())
                self.assertEqual(pointers[seq.get_batch_idx()][0][0][block],
                                 memory_pool.data_ptr() +
                                 linear_block_idx * block_elts *
                                 self._sizeof[memory_pool.dtype])
                self.assertEqual(pointers[seq.get_batch_idx()][0][1][block],
                                 memory_pool.data_ptr() +
                                 (linear_block_idx * block_elts +
                                  total_blocks * block_elts) *
                                 self._sizeof[memory_pool.dtype])

    def free_blocks(self, manager, sequences, block_len):
        for seq in sequences:
            manager.free(seq)
            # No blocks are shared here, so freeing each sequence returns
            # exactly block_len blocks to the free list
            self.assertEqual(len(manager.free_blocks),
                             (seq.get_batch_idx() + 1) * block_len)

    def full_allocate_free_test(self, manager, sequences, block_len,
                                total_blocks, max_blocks_per_seq, block_elts,
                                memory_pool):
        self.allocate_blocks(manager, sequences, block_len)
        self.verify_pointer_array(manager, sequences, block_len, total_blocks,
                                  max_blocks_per_seq, block_elts, memory_pool)
        self.free_blocks(manager, sequences, block_len)
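
    # Single-pool BlocksManager: all blocks can be spread over max_seq
    # sequences, or over twice as many sequences with half as many blocks
    # each, and allocating beyond the pool size raises a RuntimeError.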
    def test_blocks_manager_single_pool(self):
        max_seq = 32
        max_blocks_per_seq = 32
        block_elts = 64

        memory_pool = torch.zeros(max_seq,
                                  2,
                                  max_blocks_per_seq,
                                  block_elts,
                                  dtype=torch.float,
                                  device='cuda')

        sequences = [
            GenerationSequence(seq_idx=idx, batch_idx=idx)
            for idx in range(max_seq)
        ]

        manager = BlocksManager(memory_pools=[memory_pool],
                                blocks=max_seq * max_blocks_per_seq,
                                max_blocks_per_seq=max_blocks_per_seq)
        self.assertEqual(len(manager.free_blocks),
                         max_seq * max_blocks_per_seq)
        self.assertTrue(manager.has_free_block())

        # Allocate the maximum number of blocks for the maximum number of sequences
        self.full_allocate_free_test(manager, sequences, max_blocks_per_seq,
                                     max_seq * max_blocks_per_seq,
                                     max_blocks_per_seq, block_elts,
                                     memory_pool)

        manager = BlocksManager(memory_pools=[memory_pool],
                                blocks=max_seq * max_blocks_per_seq,
                                max_blocks_per_seq=max_blocks_per_seq)
        # Allocate twice as many sequences, each using half as many blocks
        sequences_2x = [
            GenerationSequence(seq_idx=idx, batch_idx=idx)
            for idx in range(2 * max_seq)
        ]
        self.full_allocate_free_test(manager, sequences_2x,
                                     max_blocks_per_seq // 2,
                                     max_seq * max_blocks_per_seq,
                                     max_blocks_per_seq, block_elts,
                                     memory_pool)

        manager = BlocksManager(memory_pools=[memory_pool],
                                blocks=max_seq * max_blocks_per_seq,
                                max_blocks_per_seq=max_blocks_per_seq)
        # Allocate the maximum number of blocks for the maximum number of sequences
        self.allocate_blocks(manager, sequences, max_blocks_per_seq)
        # The pool is now exhausted, so one more allocation must fail
        with self.assertRaises(RuntimeError) as context:
            manager.allocate(sequences[0])
        self.assertEqual("Can't allocate new block for KV cache",
                         str(context.exception))
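
    # Two memory pools share a single set of block bookkeeping; the per-pool
    # pointer arrays must both reflect the same allocation pattern.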
    def test_blocks_manager_multi_pool(self):
        max_seq = 32
        max_blocks_per_seq = 32
        block_elts_1 = 64
        block_elts_2 = 128

        memory_pool_1 = torch.zeros(max_seq,
                                    2,
                                    max_blocks_per_seq,
                                    block_elts_1,
                                    dtype=torch.float,
                                    device='cuda')
        memory_pool_2 = torch.zeros(max_seq,
                                    2,
                                    max_blocks_per_seq,
                                    block_elts_2,
                                    dtype=torch.float,
                                    device='cuda')

        sequences = [
            GenerationSequence(seq_idx=idx, batch_idx=idx)
            for idx in range(max_seq)
        ]

        manager = BlocksManager(memory_pools=[memory_pool_1, memory_pool_2],
                                blocks=max_seq * max_blocks_per_seq,
                                max_blocks_per_seq=max_blocks_per_seq)
        self.allocate_blocks(manager, sequences, max_blocks_per_seq)

        # Verify that the pointer arrays for both pools are correct
        self.verify_pointer_array(manager,
                                  sequences,
                                  max_blocks_per_seq,
                                  max_seq * max_blocks_per_seq,
                                  max_blocks_per_seq,
                                  block_elts_1,
                                  memory_pool_1,
                                  pool_idx=0)
        self.verify_pointer_array(manager,
                                  sequences,
                                  max_blocks_per_seq,
                                  max_seq * max_blocks_per_seq,
                                  max_blocks_per_seq,
                                  block_elts_2,
                                  memory_pool_2,
                                  pool_idx=1)
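
    # A block allocated with share_across_beam=True is shared by every beam
    # (one block, ref_count == beam_width); with share_across_beam=False each
    # beam gets its own block with ref_count == 1.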
    def test_blocks_manager_beam(self):
        max_seq = 32
        max_blocks_per_seq = 32
        block_elts = 64
        beam_width = 4
        blocks = max_seq * max_blocks_per_seq

        memory_pool = torch.zeros(max_seq * beam_width,
                                  2,
                                  max_blocks_per_seq,
                                  block_elts,
                                  dtype=torch.float,
                                  device='cuda')

        sequences = [
            GenerationSequence(seq_idx=idx, batch_idx=idx)
            for idx in range(max_seq)
        ]

        manager = BlocksManager(memory_pools=[memory_pool],
                                blocks=blocks,
                                max_blocks_per_seq=max_blocks_per_seq,
                                beam_width=beam_width)

        manager.allocate(sequences[0], share_across_beam=True)
        beams_blocks = manager.allocated_blocks[sequences[0]]
        self.assertEqual(beams_blocks[0][0].idx, beams_blocks[1][0].idx)
        self.assertEqual(beams_blocks[1][0].idx, beams_blocks[2][0].idx)
        self.assertEqual(beams_blocks[2][0].idx, beams_blocks[3][0].idx)
        self.assertEqual(beams_blocks[1][0].ref_count, beam_width)

        manager.allocate(sequences[1], share_across_beam=False)
        beams_blocks = manager.allocated_blocks[sequences[1]]
        self.assertNotEqual(beams_blocks[0][0].idx, beams_blocks[1][0].idx)
        self.assertNotEqual(beams_blocks[1][0].idx, beams_blocks[2][0].idx)
        self.assertNotEqual(beams_blocks[2][0].idx, beams_blocks[3][0].idx)
        self.assertEqual(beams_blocks[0][0].ref_count, 1)
        self.assertEqual(beams_blocks[1][0].ref_count, 1)
        self.assertEqual(beams_blocks[2][0].ref_count, 1)
        self.assertEqual(beams_blocks[3][0].ref_count, 1)

        manager.free(sequences[1])
        self.assertEqual(len(manager.free_blocks), blocks - 1)
        manager.free(sequences[0])
        self.assertEqual(len(manager.free_blocks), blocks)
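
    # KVCacheManager translates token counts into block allocations:
    # add_sequence() reserves blocks for the prompt tokens, and step() adds
    # one token to every live sequence and frees sequences marked finished.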
    def test_kv_cache_manager(self):
        blocks = 128
        tokens_per_block = 32
        max_blocks_per_seq = 16
        dims_per_head_1 = 64
        dims_per_head_2 = 128

        memory_pool_1 = torch.zeros(2,
                                    blocks,
                                    tokens_per_block,
                                    dims_per_head_1,
                                    dtype=torch.float,
                                    device='cuda')
        memory_pool_2 = torch.zeros(2,
                                    blocks,
                                    tokens_per_block,
                                    dims_per_head_2,
                                    dtype=torch.float,
                                    device='cuda')

        manager = KVCacheManager(memory_pools=[memory_pool_1, memory_pool_2],
                                 blocks=blocks,
                                 tokens_per_block=tokens_per_block,
                                 max_kv_cache_len=max_blocks_per_seq *
                                 tokens_per_block,
                                 max_blocks_per_seq=max_blocks_per_seq)

        manager.add_sequence(GenerationSequence(seq_idx=0, batch_idx=0), 30)
        manager.add_sequence(GenerationSequence(seq_idx=1, batch_idx=1), 35)
        manager.add_sequence(GenerationSequence(seq_idx=2, batch_idx=2), 31)
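        # With tokens_per_block = 32, the 30- and 31-token sequences fit in a
        # single block, while the 35-token sequence already needs two.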

        def check_amount_of_blocks(sequence, expected_blocks):
            for bi in range(max_blocks_per_seq):
                if bi < expected_blocks:
                    self.assertNotEqual(sequence[bi], 0)
                else:
                    self.assertEqual(sequence[bi], 0)

        arrays = manager.get_pointer_arrays(beam_width=1)
        arrays = [arr.view(dtype=torch.int64) for arr in arrays]
        # Expect 2 arrays for 2 memory pools
        self.assertEqual(len(arrays), 2)

        check_amount_of_blocks(arrays[0][0][0][0], 1)
        check_amount_of_blocks(arrays[0][1][0][0], 2)
        check_amount_of_blocks(arrays[0][2][0][0], 1)
        self.assertEqual(manager.lens[0], 30)
        self.assertEqual(manager.lens[1], 35)
        self.assertEqual(manager.lens[2], 31)

        # After this loop the first sequence has 33 tokens and spills into a
        # second block; the other two sequences need two blocks as well
        for _ in range(3):
            manager.step([False, False, False])
        arrays = manager.get_pointer_arrays(beam_width=1)
        arrays = [arr.view(dtype=torch.int64) for arr in arrays]
        check_amount_of_blocks(arrays[0][0][0][0], 2)
        check_amount_of_blocks(arrays[0][1][0][0], 2)
        check_amount_of_blocks(arrays[0][2][0][0], 2)
        self.assertEqual(manager.lens[0], 33)
        self.assertEqual(manager.lens[1], 38)
        self.assertEqual(manager.lens[2], 34)

        # The second sequence finishes and is removed from the batch
        manager.step([False, True, False])
        self.assertEqual(len(manager.sequences), 2)
        self.assertEqual(len(manager.lens), 2)
        arrays = manager.get_pointer_arrays(beam_width=1)
        arrays = [arr.view(dtype=torch.int64) for arr in arrays]
        self.assertEqual(manager.lens[0], 34)
        self.assertEqual(manager.lens[1], 35)
        check_amount_of_blocks(arrays[0][0][0][0], 2)
        check_amount_of_blocks(arrays[0][1][0][0], 2)

        # The sequence now at batch position 1 (originally the third one) finishes too
        manager.step([False, True])
        self.assertEqual(len(manager.sequences), 1)
        self.assertEqual(len(manager.lens), 1)
        arrays = manager.get_pointer_arrays(beam_width=1)
        arrays = [arr.view(dtype=torch.int64) for arr in arrays]
        self.assertEqual(manager.lens[0], 35)
        check_amount_of_blocks(arrays[0][0][0][0], 2)


if __name__ == '__main__':
    unittest.main()