TensorRT-LLMs/tests/test_kv_cache_update.py

import random

import pytest
import torch
from utils.util import getSMVersion

import tensorrt_llm  # noqa
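
# This test exercises torch.ops.tensorrt_llm.update_kv_cache_draft_token_location,
# the kernel that compacts the key/value entries of accepted draft tokens
# (speculative decoding) in place inside a linear, contiguous KV cache, and
# compares the result against a host-side reference implementation.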


def is_integer_type(torch_dtype):
    integer_types = {
        torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8,
        torch.short, torch.int, torch.long
    }
    return torch_dtype in integer_types


def is_float_type(torch_dtype):
    float_types = {
        torch.float16, torch.float32, torch.float64, torch.float, torch.double,
        torch.half
    }
    return torch_dtype in float_types
@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
@pytest.mark.parametrize("head_dim", [64, 67, 128])
@pytest.mark.parametrize("layer_count", [1, 32, 45])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("kv_cache_type", [torch.int8, torch.float16])
@pytest.mark.parametrize("rewind_draft_token_count", [5, 63])
@pytest.mark.parametrize("separate_draft_count", [False, True])
@pytest.mark.parametrize("max_kv_cache_length", [100, 200])
def test_linear_kvcache_update(num_kv_heads: int, head_dim: int,
layer_count: int, batch_size: int,
kv_cache_type: torch.dtype,
rewind_draft_token_count: int,
separate_draft_count: bool,
max_kv_cache_length: int):
    if getSMVersion() <= 70:
        pytest.skip(
            "Volta has invalid-device-function issues when running this test.")

    torch.cuda.set_device(0)
    torch.cuda.manual_seed(1234)
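
    # Per-layer linear KV cache layout:
    # [batch, 2 (keys/values), num_kv_heads, max_kv_cache_length, head_dim].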
    cache_shape = (
        batch_size,
        2,
        num_kv_heads,
        max_kv_cache_length,
        head_dim,
    )
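    # Size in bytes of one cache element; the kernel call below passes the head
    # size in bytes as head_dim * elt_size.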
    elt_size = torch.zeros(1, dtype=kv_cache_type).element_size()
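
    # Populate one randomly-filled KV cache tensor per layer with the requested dtype.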
    if is_integer_type(kv_cache_type):
        past_key_values = [
            torch.randint(0,
                          100,
                          cache_shape,
                          dtype=kv_cache_type,
                          device='cuda') for i in range(layer_count)
        ]
    elif is_float_type(kv_cache_type):
        past_key_values = [
            torch.rand(cache_shape, dtype=kv_cache_type, device='cuda')
            for i in range(layer_count)
        ]
    else:
        raise ValueError("dtype is neither float nor integer.")
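
    # Number of draft tokens to rewind per sequence: either a random per-sequence
    # count (separate_draft_count=True) or one shared count for every sequence.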
    rewind_draft_token_count_tensor_cuda = None
    if separate_draft_count:
        rewind_draft_token_count_tensor_cpu = torch.randint(
            1, rewind_draft_token_count, (batch_size, ), dtype=torch.int32)
        rewind_draft_token_count_tensor_cuda = rewind_draft_token_count_tensor_cpu.cuda(
        )
    else:
        rewind_draft_token_count_tensor_cpu = torch.full(
            (batch_size, ), rewind_draft_token_count, dtype=torch.int32)
    rewind_draft_token_tensor_list = rewind_draft_token_count_tensor_cpu.tolist(
    )
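
    # Each sequence accepts a random number of its draft tokens
    # (anywhere from none to all of them).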
    accepted_draft_token_counts_list = [
        random.randint(0, rewind_draft_token_tensor_list_value) for
        rewind_draft_token_tensor_list_value in rewind_draft_token_tensor_list
    ]
    accepted_draft_token_counts = torch.tensor(accepted_draft_token_counts_list,
                                               dtype=torch.int32).cuda()
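    # Prefix-sum the accepted counts into per-sequence start/end offsets.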
    accepted_draft_token_offsets = torch.zeros(batch_size + 1,
                                               dtype=torch.int32,
                                               device='cuda')
    accepted_draft_token_offsets[1:] = torch.cumsum(accepted_draft_token_counts,
                                                    dim=0)
    accepted_draft_token_offsets_cpu = accepted_draft_token_offsets.to('cpu')
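
    # For every sequence, choose which draft positions were accepted by taking the
    # first entries of a random permutation, packed back-to-back across sequences.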
    packed_accepted_draft_tokens_indices_cpu = torch.empty(
        accepted_draft_token_offsets_cpu[batch_size], dtype=torch.int32)
    for seq_idx in range(batch_size):
        rand_perm = torch.randperm(rewind_draft_token_tensor_list[seq_idx],
                                   dtype=torch.int32)
        seq_start = accepted_draft_token_offsets_cpu[seq_idx]
        seq_end = accepted_draft_token_offsets_cpu[seq_idx + 1]
        packed_accepted_draft_tokens_indices_cpu[
            seq_start:seq_end] = rand_perm[:seq_end - seq_start]
    packed_accepted_draft_tokens_indices = packed_accepted_draft_tokens_indices_cpu.to(
        'cuda')
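
    # Current KV length of each sequence, always at least rewind_draft_token_count
    # so that rewinding never underflows.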
    past_key_value_lengths = torch.randint(rewind_draft_token_count,
                                           max_kv_cache_length, (batch_size, ),
                                           dtype=torch.int32,
                                           device='cuda')
    past_key_value_lengths_cpu = past_key_value_lengths.to('cpu')

    # Compute the ground truth on the host first: for every accepted draft token,
    # its key/value entry is copied from its draft position to a contiguous slot
    # starting at the rewound sequence length.
    ground_truth_past_key_values = []
    for i in range(layer_count):
        layer_past_key_value = past_key_values[i]
        new_layer_past_key_value = layer_past_key_value.clone()
        for seq_idx in range(batch_size):
            token_start = accepted_draft_token_offsets_cpu[seq_idx]
            token_end = accepted_draft_token_offsets_cpu[seq_idx + 1]
            for relative_target_idx in range(token_end - token_start):
                relative_draft_idx = packed_accepted_draft_tokens_indices_cpu[
                    token_start + relative_target_idx]
                past_key_value_len = past_key_value_lengths_cpu[seq_idx]
                rewind_key_value_len = past_key_value_len - rewind_draft_token_tensor_list[
                    seq_idx]
                new_layer_past_key_value[
                    seq_idx, :, :, rewind_key_value_len +
                    relative_target_idx] = layer_past_key_value[
                        seq_idx, :, :,
                        rewind_key_value_len + relative_draft_idx]
        ground_truth_past_key_values.append(new_layer_past_key_value)
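
    # Run the kernel under test. Rewinding is specified either as one shared
    # scalar count or as a per-sequence tensor (separate_draft_count); the
    # remaining arguments are left as None in this linear-cache test.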
    torch.cuda.synchronize()
    torch.ops.tensorrt_llm.update_kv_cache_draft_token_location(
        accepted_draft_token_offsets,
        packed_accepted_draft_tokens_indices,
        past_key_value_lengths,
        False,
        layer_count,
        num_kv_heads,
        head_dim * elt_size,
        0 if separate_draft_count else rewind_draft_token_count,
        max_kv_cache_length,
        rewind_draft_token_count_tensor_cuda if separate_draft_count else None,
        past_key_values,
        None,
        None,
        None,
        None,
        None,
    )
    torch.cuda.synchronize()
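
    # The in-place update must reproduce the host-side reference for every layer.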
    for i in range(layer_count):
        layer_past_key_value = past_key_values[i]
        ground_truth_layer_past_key_value = ground_truth_past_key_values[i]
        assert torch.allclose(layer_past_key_value,
                              ground_truth_layer_past_key_value)