# TensorRT-LLM/tests/_torch/test_vanilla_attention.py

import unittest

import torch

import tensorrt_llm
from tensorrt_llm._torch.attention_backend import (VanillaAttention,
                                                   VanillaAttentionMetadata)
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping

class TestVanillaAttention(unittest.TestCase):

    def test_vanilla_attention(self):
        # Grouped-query attention: 32 query heads share 8 KV heads.
        num_heads = 32
        num_kv_heads = 8
        head_dim = 128
        device = torch.device('cuda')
        dtype = torch.bfloat16
        # Batch layout: two context (prefill) requests of 3 and 2 tokens,
        # followed by two generation (decode) requests of 1 token each.
        context_sequence_lengths = [3, 2]
        sequence_lengths = context_sequence_lengths + [1, 1]
        past_seen_tokens = [30, 40, 62, 75]
        request_ids = [0, 1, 2, 3]
        # Total tokens per request = new tokens + previously cached tokens.
        token_nums = (torch.tensor(sequence_lengths) +
                      torch.tensor(past_seen_tokens)).tolist()
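        # Per-request query/key/value activations, flattened to
        # [num_tokens, num_heads * head_dim]. Q carries all 32 query heads;
        # K and V carry only the 8 KV heads.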
        context_1_q = torch.randn(context_sequence_lengths[0],
                                  num_heads * head_dim,
                                  dtype=dtype,
                                  device=device)
        context_1_k = torch.randn(context_sequence_lengths[0],
                                  num_kv_heads * head_dim,
                                  dtype=dtype,
                                  device=device)
        context_1_v = torch.randn(context_sequence_lengths[0],
                                  num_kv_heads * head_dim,
                                  dtype=dtype,
                                  device=device)
        context_2_q = torch.randn(context_sequence_lengths[1],
                                  num_heads * head_dim,
                                  dtype=dtype,
                                  device=device)
        context_2_k = torch.randn(context_sequence_lengths[1],
                                  num_kv_heads * head_dim,
                                  dtype=dtype,
                                  device=device)
        context_2_v = torch.randn(context_sequence_lengths[1],
                                  num_kv_heads * head_dim,
                                  dtype=dtype,
                                  device=device)
        gen_1_q = torch.randn(1,
                              num_heads * head_dim,
                              dtype=dtype,
                              device=device)
        gen_1_k = torch.randn(1,
                              num_kv_heads * head_dim,
                              dtype=dtype,
                              device=device)
        gen_1_v = torch.randn(1,
                              num_kv_heads * head_dim,
                              dtype=dtype,
                              device=device)
        gen_2_q = torch.randn(1,
                              num_heads * head_dim,
                              dtype=dtype,
                              device=device)
        gen_2_k = torch.randn(1,
                              num_kv_heads * head_dim,
                              dtype=dtype,
                              device=device)
        gen_2_v = torch.randn(1,
                              num_kv_heads * head_dim,
                              dtype=dtype,
                              device=device)
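        # Paged KV cache: 100 blocks of 128 tokens, a single layer, and a
        # single-rank mapping (no tensor or pipeline parallelism).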
        num_blocks = 100
        tokens_per_block = 128
        num_layers = 1
        max_seq_len = num_blocks * tokens_per_block
        batch_size = 4
        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)
        if dtype == torch.float16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype for unit test")
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers, num_heads, num_kv_heads, head_dim, tokens_per_block,
            max_seq_len, batch_size, mapping, kv_cache_dtype)
        # Register the four requests, then fill their cache blocks with
        # random values so the previously cached tokens contribute to the
        # attention outputs.
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)
        for i in range(num_layers):
            buf = kv_cache_manager.get_buffers(i)
            if buf is not None:
                torch.nn.init.normal_(buf)
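        # The backend under test: the plain PyTorch attention implementation.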
        vanilla_attn = VanillaAttention(layer_idx=0,
                                        num_heads=num_heads,
                                        head_dim=head_dim,
                                        num_kv_heads=num_kv_heads)

        # [context_1, context_2, gen_1, gen_2]
        attn_metadata = VanillaAttentionMetadata(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
            num_contexts=len(context_sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True, num_cached_tokens_per_seq=past_seen_tokens),
            max_num_requests=4,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
        )
        attn_metadata.prepare()
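        # Forward pass over the full batch: the tokens of all four requests
        # are concatenated along the first (token) dimension.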
        result_1 = vanilla_attn.forward(
            torch.cat((context_1_q, context_2_q, gen_1_q, gen_2_q)),
            torch.cat((context_1_k, context_2_k, gen_1_k, gen_2_k)),
            torch.cat((context_1_v, context_2_v, gen_1_v, gen_2_v)),
            attn_metadata)
        self.assertEqual(result_1.size()[0], sum(context_sequence_lengths) + 2)
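
        # Re-run each (context, generation) pair as its own smaller batch.
        # Attention is computed per request, so the per-token outputs should
        # match the full-batch run above regardless of batch composition.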
        # [context_1, gen_1]
        num_cached_tokens_per_seq = [past_seen_tokens[0], past_seen_tokens[2]]
        attn_metadata = VanillaAttentionMetadata(
            seq_lens=torch.tensor([context_sequence_lengths[0], 1],
                                  dtype=torch.int),
            num_contexts=1,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=[0, 2],
            max_num_requests=4,
            max_num_tokens=8192,
        )
        attn_metadata.prepare()
        result_2 = vanilla_attn.forward(torch.cat((context_1_q, gen_1_q)),
                                        torch.cat((context_1_k, gen_1_k)),
                                        torch.cat((context_1_v, gen_1_v)),
                                        attn_metadata)
        self.assertEqual(result_2.size()[0], context_sequence_lengths[0] + 1)
        # [context_2, gen_2]
        num_cached_tokens_per_seq = [past_seen_tokens[1], past_seen_tokens[3]]
        attn_metadata = VanillaAttentionMetadata(
            seq_lens=torch.tensor([context_sequence_lengths[1], 1],
                                  dtype=torch.int),
            num_contexts=1,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=[1, 3],
            max_num_requests=4,
            max_num_tokens=8192,
        )
        attn_metadata.prepare()
        result_3 = vanilla_attn.forward(torch.cat((context_2_q, gen_2_q)),
                                        torch.cat((context_2_k, gen_2_k)),
                                        torch.cat((context_2_v, gen_2_v)),
                                        attn_metadata)
        self.assertEqual(result_3.size()[0], context_sequence_lengths[1] + 1)
        # Compare values. The full-batch output is ordered [context_1,
        # context_2, gen_1, gen_2], while the split runs produced
        # [context_1, gen_1] and [context_2, gen_2]; interleave the split
        # results to restore the full-batch order before comparing.
        torch.testing.assert_close(
            result_1,
            torch.cat((
                result_2[:context_sequence_lengths[0], :],
                result_3[:context_sequence_lengths[1], :],
                result_2[context_sequence_lengths[0]:, :],
                result_3[context_sequence_lengths[1]:, :],
            )))
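        # Release the cache blocks held by the manager.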
        kv_cache_manager.shutdown()
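

# Standard unittest entry point so the file can also be run directly with
# `python test_vanilla_attention.py`; under pytest this guard is ignored.
if __name__ == "__main__":
    unittest.main()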