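# Unit tests for the PyTorch attention backends in `tensorrt_llm._torch`:
# FlashInferAttention and VanillaAttention are run on shared random Q/KV data
# and compared against eager SDPA / manual softmax references, for prefill-only
# batches and for mixed context + generation (in-flight batching) batches, with
# FP16 and FP8 KV caches.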
import math
import os
import sys
from dataclasses import dataclass
from typing import Callable, Optional, Sequence

import flashinfer
import pytest
import torch

import tensorrt_llm
from tensorrt_llm._torch.attention_backend import (AttentionBackend,
                                                    FlashInferAttention,
                                                    VanillaAttention)
from tensorrt_llm._torch.attention_backend.interface import \
    PredefinedAttentionMask
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion

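# `repeat_kv` expands grouped KV heads so K/V line up with the query heads in
# the eager references below, e.g. a (batch=1, num_kv_heads=2, seqlen=3,
# head_dim=8) tensor with n_rep=4 becomes (1, 8, 3, 8).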
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :,
                                  None, :, :].expand(batch, num_key_value_heads,
                                                     n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
                                 head_dim)

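# Comparison tolerances; `fp8_atol` is the looser absolute tolerance used when
# the KV cache is quantized to FP8 (see `test_attention_backend_ifb`).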
atol = 1e-2
rtol = 1e-3
fp8_atol = 5e-2

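# `Scenario` bundles the shape parameters shared by the tests. Derived sizes
# with the defaults below (illustrative arithmetic only):
#   kv_cache_len  = page_size * num_pages         = 256 * 4 = 1024
#   past_kv_len   = kv_cache_len - kv_len_resolved = 1024 - 32 = 992
#   max_num_pages = batch_size * num_pages         = 7 * 4 = 28
#   nnz_qo        = batch_size * qo_len            = 7 * 32 = 224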
@dataclass(kw_only=True, frozen=True)
class Scenario:
    dtype: torch.dtype = torch.float16
    kvcache_dtype: torch.dtype = torch.float16
    num_layers: int
    num_heads: int = 64
    num_kv_heads: int = 16
    head_dim: int = 128
    page_size: int = 256
    """flash-attention requires `page_size` to be a multiple of 256"""
    num_pages: int = 4
    qo_len: int = 32
    kv_len: int = 0
    """Set `kv_len` to a non-zero value to test cross attention."""
    causal: bool = True
    batch_size: int = 7

    @property
    def cross(self) -> bool:
        return self.kv_len != 0

    @property
    def kv_len_resolved(self) -> int:
        return self.kv_len or self.qo_len

    @property
    def num_kv_groups(self) -> int:
        return self.num_heads // self.num_kv_heads

    @property
    def kv_cache_len(self) -> int:
        return self.page_size * self.num_pages

    @property
    def past_kv_len(self) -> int:
        return self.kv_cache_len - self.kv_len_resolved

    @property
    def max_num_pages(self) -> int:
        return self.batch_size * self.num_pages

    @property
    def nnz_qo(self):
        return self.batch_size * self.qo_len

    @property
    def nnz_kv(self):
        return self.batch_size * self.kv_len_resolved

    def __post_init__(self) -> None:
        assert self.kv_len <= self.kv_cache_len, "KV len larger than cache len"
        assert self.kv_len != 0 or self.qo_len <= self.kv_cache_len, "Seq len larger than cache len"
        assert not (self.cross
                    and self.causal), "Cross attention cannot be causal"

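# `PagedScenario` models an in-flight-batching step: the last `num_generations`
# requests are single-token decode steps. With num_generations=5 and the
# defaults (batch_size=7, qo_len=32): num_contexts = 2 and
# nnz_qo = 2 * 32 + 5 = 69.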
@dataclass(kw_only=True, frozen=True)
class PagedScenario(Scenario):
    num_generations: int

    @property
    def num_contexts(self) -> int:
        return self.batch_size - self.num_generations

    @property
    def num_ctx_q_tokens(self) -> int:
        return self.num_contexts * self.qo_len

    @property
    def num_ctx_kv_tokens(self) -> int:
        return self.num_contexts * self.kv_len_resolved

    @property
    def nnz_qo(self) -> int:
        return self.num_ctx_q_tokens + self.num_generations

    @property
    def nnz_kv(self) -> int:
        n = self.num_ctx_kv_tokens
        if not self.cross:
            n += self.num_generations
        return n

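# Whether each backend consumes a paged KV-cache layout. The non-paged
# VanillaAttention gets a single contiguous block of `kv_cache_len` tokens per
# sequence instead (see `kv_cache_manager_from`).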
paged_backends = {
    VanillaAttention: False,
    FlashInferAttention: True,
}

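# Builds a KVCacheManager sized for the scenario (paged or contiguous depending
# on the backend) and seeds its per-layer buffers from the pre-generated
# `kv_cache` tensor.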
def kv_cache_manager_from(Attention: type[AttentionBackend], s: Scenario,
                          kv_cache: torch.Tensor) -> KVCacheManager:
    paged = paged_backends[Attention]

    num_blocks = s.max_num_pages if paged else s.batch_size
    tokens_per_block = s.page_size if paged else s.kv_cache_len
    num_layers = s.num_layers
    num_kv_heads = s.num_kv_heads
    head_dim = s.head_dim
    num_heads = s.num_kv_heads
    max_seq_len = num_blocks * tokens_per_block
    batch_size = s.batch_size

    if s.kvcache_dtype == torch.float16:
        kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
    elif s.kvcache_dtype == torch.bfloat16:
        kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
    elif s.kvcache_dtype == torch.float8_e4m3fn:
        kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
    else:
        raise ValueError("Invalid dtype for unit test")

    kv_cache_config = KvCacheConfig(max_tokens=num_blocks * tokens_per_block)

    mapping = Mapping(world_size=1, tp_size=1, rank=0)

    cache_type = tensorrt_llm.bindings.internal.batch_manager.CacheType.CROSS if s.cross else tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF

    result = KVCacheManager(kv_cache_config, cache_type, num_layers, num_heads,
                            num_kv_heads, head_dim, tokens_per_block,
                            max_seq_len, batch_size, mapping, kv_cache_dtype)

    for i in range(s.num_layers):
        result.get_buffers(i).view_as(kv_cache[i]).copy_(kv_cache[i])
    return result

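# Runs `Attention` layer by layer and returns the per-layer outputs, each of
# shape [s.nnz_qo, s.num_heads * s.head_dim]. Sketch of a typical call (mirrors
# the `produce` closures in the tests below; `kv_cache` must match the layout
# expected by the chosen backend):
#
#     produce_outputs(FlashInferAttention, q_at_layer, kv, s,
#                     kv_cache=flashinfer_kv_cache,
#                     num_cached_tokens=s.past_kv_len,
#                     seq_lens=torch.full((s.batch_size, ), s.qo_len).int())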
def produce_outputs(
    Attention: type[AttentionBackend],
    q_at_layer: torch.Tensor,
    kv: Optional[torch.Tensor],
    s: Scenario,
    *,
    kv_cache: torch.Tensor,
    num_cached_tokens: Callable[[int], int] | int,
    num_contexts: int | None = None,
    seq_lens: torch.Tensor,
    seq_lens_kv: Optional[torch.Tensor] = None,
    quant_config: Optional[QuantConfig] = None,
) -> list[torch.Tensor]:
    num_cached_tokens_per_seq = [
        num_cached_tokens
        if isinstance(num_cached_tokens, int) else num_cached_tokens(i)
        for i in range(s.batch_size)
    ]

    kv_cache_params = KVCacheParams(
        use_cache=True, num_cached_tokens_per_seq=num_cached_tokens_per_seq)
    kv_cache_manager = kv_cache_manager_from(Attention, s, kv_cache)
    request_ids = list(range(s.batch_size))
    seq_lens_append = seq_lens_kv if seq_lens_kv is not None else seq_lens
    token_nums = (torch.tensor(num_cached_tokens_per_seq) +
                  seq_lens_append).tolist()
    kv_cache_manager.add_dummy_requests(request_ids, token_nums)

    metadata = Attention.Metadata(
        num_contexts=num_contexts if num_contexts is not None else s.batch_size,
        kv_cache_params=kv_cache_params,
        seq_lens=seq_lens,
        seq_lens_kv=seq_lens_kv,
        max_num_requests=s.batch_size,
        max_num_tokens=8192,
        kv_cache_manager=kv_cache_manager,
        request_ids=request_ids,
        prompt_lens=token_nums,
    )
    metadata.prepare()
    mask = PredefinedAttentionMask.CAUSAL if s.causal else PredefinedAttentionMask.FULL
    outputs = []
    for i in range(s.num_layers):
        q = q_at_layer[i]
        if kv is not None:
            k, v = kv[i][0], kv[i][1]
        else:
            k, v = None, None
        attention = Attention(
            layer_idx=i,
            num_heads=s.num_heads,
            num_kv_heads=s.num_kv_heads,
            head_dim=s.head_dim,
            quant_config=quant_config,
        )
        o = attention.forward(q, k, v, metadata, attention_mask=mask)
        assert list(o.shape) == [s.nnz_qo, s.num_heads * s.head_dim]
        outputs.append(o)
    kv_cache_manager.shutdown()
    return outputs

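# Prints mean magnitudes and mean absolute differences for one layer, then
# asserts every implementation matches the reference within atol/rtol.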
def allclose(ref: Sequence[torch.Tensor],
             impls: dict[str, Sequence[torch.Tensor]],
             *,
             layer=0,
             atol=atol,
             rtol=rtol):
    for name, outputs in impls.items():
        print(f"{name} output: ", float(outputs[layer].abs().mean()))
    print("ref outputs: ", float(ref[layer].abs().mean()))
    for name, outputs in impls.items():
        print(f"{name} & ref diff: ",
              float((ref[layer] - outputs[layer]).abs().mean()))
    for name, outputs in impls.items():
        torch.testing.assert_close(outputs[layer],
                                   ref[layer],
                                   atol=atol,
                                   rtol=rtol,
                                   msg=f"Allclose failed: ref<->{name}")

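# Sanity check of the raw flashinfer prefill wrappers (no TRT-LLM backend
# involved): the paged and ragged wrappers are run on the same random Q/KV data
# and both are compared against torch SDPA and a manual softmax reference.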
def test_flashinfer_prefill():
    s = Scenario(num_layers=1)
    dtype = s.dtype
    num_layers = s.num_layers
    num_qo_heads = s.num_heads
    num_kv_heads = s.num_kv_heads
    num_kv_groups = s.num_kv_groups
    head_dim = s.head_dim
    page_size = s.page_size
    num_pages = s.num_pages
    kv_cache_len = s.kv_cache_len
    qo_len = s.qo_len
    past_kv_len = s.past_kv_len
    batch_size = s.batch_size
    nnz_qo = s.nnz_qo
    max_num_pages = s.max_num_pages

    # allocate 128MB workspace buffer
    workspace_buffer = torch.empty(128 * 1024 * 1024,
                                   dtype=torch.uint8,
                                   device="cuda")
    paged_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", backend="fa2")
    ragged_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, "NHD", backend="fa2")
    paged_kv_indices = torch.arange(max_num_pages).int().cuda()
    paged_kv_indptr = torch.arange(0, batch_size + 1).int().cuda() * num_pages
    # 1 <= paged_kv_last_page_len <= page_size
    paged_kv_last_page_len = torch.full((batch_size, ), page_size).int().cuda()
    qo_indptr = torch.arange(0, batch_size + 1).int().to("cuda") * qo_len
    kv_indptr = torch.arange(0, batch_size + 1).int().to("cuda") * kv_cache_len

    q_at_layer = torch.randn(num_layers,
                             nnz_qo,
                             num_qo_heads,
                             head_dim,
                             device="cuda").to(dtype)
    kv_cache_at_layer = torch.randn(num_layers,
                                    max_num_pages,
                                    2,
                                    page_size,
                                    num_kv_heads,
                                    head_dim,
                                    device="cuda").to(s.kvcache_dtype)
    kv_data = kv_cache_at_layer.transpose(1, 2).contiguous().view(
        num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)

    causal_mask = torch.full((qo_len, kv_cache_len),
                             fill_value=torch.finfo(dtype).min,
                             dtype=dtype,
                             device="cuda")
    cache_position = torch.arange(past_kv_len, kv_cache_len).cuda()
    bool_causal_mask = torch.arange(
        kv_cache_len).cuda() <= cache_position.reshape(-1, 1)
    causal_mask *= ~bool_causal_mask
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    # create auxiliary data structures for batch prefill attention
    paged_wrapper.plan(
        qo_indptr,
        paged_kv_indptr,
        paged_kv_indices,
        paged_kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        causal=True,
    )
    ragged_wrapper.plan(
        qo_indptr,
        kv_indptr,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        causal=True,
    )
    flashinfer_outputs = []
    for i in range(num_layers):
        q = q_at_layer[i]
        kv_cache = kv_cache_at_layer[i]
        o = paged_wrapper.run(q, kv_cache)
        k = kv_data[i][0]
        v = kv_data[i][1]
        k = k.view(-1, num_kv_heads, head_dim)
        v = v.view(-1, num_kv_heads, head_dim)
        ragged_o = ragged_wrapper.run(q, k, v)
        assert list(o.shape) == [nnz_qo, num_qo_heads, head_dim]
        print("paged output: ", float(o.abs().mean()))
        print("ragged output: ", float(ragged_o.abs().mean()))
        print("paged & ragged diff: ", float((ragged_o - o).abs().mean()))
        assert torch.allclose(o, ragged_o, atol=atol, rtol=rtol)
        flashinfer_outputs.append(o)

    sdpa_outputs = []
    for i in range(num_layers):
        q = q_at_layer[i]
        k = kv_data[i][0]
        v = kv_data[i][1]
        q = q.view(batch_size, qo_len, num_qo_heads, head_dim)
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()
        k = repeat_kv(k, num_kv_groups)
        v = repeat_kv(v, num_kv_groups)
        o = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=causal_mask,
        )
        o = o.transpose(1, 2).contiguous().view(nnz_qo, num_qo_heads, head_dim)
        sdpa_outputs.append(o)

    ref_outputs = []
    for i in range(num_layers):
        q = q_at_layer[i]
        k = kv_data[i][0]
        v = kv_data[i][1]
        q = q.view(batch_size, qo_len, num_qo_heads, head_dim)
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()
        k = repeat_kv(k, num_kv_groups)
        v = repeat_kv(v, num_kv_groups)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
        attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = torch.nn.functional.softmax(attn_weights,
                                                   dim=-1,
                                                   dtype=torch.float32).to(
                                                       q.dtype)
        o = torch.matmul(attn_weights, v)

        o = o.transpose(1, 2).contiguous().view(nnz_qo, num_qo_heads, head_dim)
        ref_outputs.append(o)

    allclose(
        ref_outputs,
        {
            "flashinfer": flashinfer_outputs,
            "sdpa": sdpa_outputs,
        },
    )

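# Prefill-only comparison of FlashInferAttention and VanillaAttention against a
# manual softmax reference: causal self-attention, non-causal, and cross
# attention with equal and different KV lengths.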
@pytest.mark.parametrize(
    "s", [
        Scenario(num_layers=1),
        Scenario(num_layers=1, causal=False),
        Scenario(num_layers=1, qo_len=32, kv_len=32, causal=False),
        Scenario(num_layers=1, qo_len=32, kv_len=64, causal=False)
    ],
    ids=["typical", "non-causal", "cross", "cross-diff-kv-len"])
def test_attention_backend(s: Scenario):
    dtype = s.dtype
    num_layers = s.num_layers
    num_heads = s.num_heads
    num_kv_heads = s.num_kv_heads
    num_kv_groups = s.num_kv_groups
    head_dim = s.head_dim
    page_size = s.page_size
    kv_cache_len = s.kv_cache_len
    qo_len = s.qo_len
    kv_len = s.kv_len_resolved
    past_kv_len = s.past_kv_len
    batch_size = s.batch_size
    nnz_qo = s.nnz_qo
    nnz_kv = s.nnz_kv
    causal = s.causal

    q_at_layer = torch.randn(num_layers,
                             nnz_qo,
                             num_heads * head_dim,
                             device="cuda").to(dtype)
    flashinfer_kv_cache = torch.randn(num_layers,
                                      s.max_num_pages,
                                      2,
                                      page_size,
                                      num_kv_heads,
                                      head_dim,
                                      device="cuda").to(s.kvcache_dtype)
    ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).contiguous().view(
        num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
    kv = torch.randn(num_layers,
                     2,
                     nnz_kv,
                     num_kv_heads * head_dim,
                     device="cuda").to(dtype)

    def produce(Attention: type[AttentionBackend], kv_cache: torch.Tensor):
        return produce_outputs(
            Attention,
            q_at_layer,
            kv,
            s,
            kv_cache=kv_cache,
            num_cached_tokens=past_kv_len,
            seq_lens=torch.full((batch_size, ), qo_len).int(),
            seq_lens_kv=torch.full(
                (batch_size, ), kv_len).int() if s.cross else None,
        )

    flashinfer_outputs = produce(FlashInferAttention, flashinfer_kv_cache)
    sdpa_outputs = produce(VanillaAttention, ref_kv_cache.transpose(1, 2))

    # Test reference attention
    if causal:
        causal_mask = torch.full((qo_len, kv_cache_len),
                                 fill_value=torch.finfo(dtype).min,
                                 dtype=dtype,
                                 device="cuda")
        cache_position = torch.arange(past_kv_len, kv_cache_len).cuda()
        bool_causal_mask = torch.arange(
            kv_cache_len).cuda() <= cache_position.reshape(-1, 1)
        causal_mask *= ~bool_causal_mask
        causal_mask = causal_mask[None,
                                  None, :, :].expand(batch_size, 1, -1, -1)
    else:
        causal_mask = 0

    ref_outputs = []
    for i in range(num_layers):
        q = q_at_layer[i]
        ref_kv_cache[i][:, :, past_kv_len:kv_cache_len] = kv[i].view(
            2, batch_size, kv_len, num_kv_heads, head_dim)
        k = ref_kv_cache[i][0]
        v = ref_kv_cache[i][1]
        q = q.view(batch_size, qo_len, num_heads, head_dim)
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()
        k = repeat_kv(k, num_kv_groups)
        v = repeat_kv(v, num_kv_groups)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
        attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = torch.nn.functional.softmax(attn_weights,
                                                   dim=-1,
                                                   dtype=torch.float32).to(
                                                       q.dtype)
        o = torch.matmul(attn_weights, v)

        o = o.transpose(1, 2).contiguous().view(nnz_qo, num_heads * head_dim)
        ref_outputs.append(o)

    allclose(
        ref_outputs,
        {
            "flashinfer": flashinfer_outputs,
            "sdpa": sdpa_outputs,
        },
    )

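# Builds a batched additive causal mask for per-request (ragged) query/KV
# lengths: allowed positions are 0, masked positions are torch.finfo(dtype).min,
# and each request is padded out to (max_qo_len, max_seq_len). For example, with
# seq_lens=[4] and qo_lens=[2], the two query rows sit at cache positions 2 and
# 3, so row 0 may attend to KV 0..2 and row 1 to KV 0..3.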
def generate_causal_mask(seq_lens, qo_lens, batch_size, dtype):
    causal_masks = []
    max_seq_len = int(seq_lens.max())
    max_qo_len = int(qo_lens.max())
    for i in range(batch_size):
        kv_len = seq_lens[i]
        qo_len = qo_lens[i]
        past_kv_len = kv_len - qo_len
        causal_mask = torch.full((qo_len, kv_len),
                                 fill_value=torch.finfo(dtype).min,
                                 dtype=dtype,
                                 device="cuda")
        cache_position = torch.arange(past_kv_len, kv_len).cuda()
        causal_mask *= torch.arange(kv_len).cuda() > cache_position.reshape(
            -1, 1)
        causal_mask = torch.nn.functional.pad(
            causal_mask, (0, max_seq_len - kv_len, 0, max_qo_len - qo_len),
            'constant',
            torch.finfo(dtype).min)
        causal_masks.append(causal_mask)
    causal_mask = torch.stack(causal_masks).view(batch_size, 1, max_qo_len,
                                                 max_seq_len)
    return causal_mask

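# In-flight-batching variant: the first `num_contexts` requests are context
# (prefill) requests and the remaining `num_generations` are single-token
# decode steps, all run in one batch; covers FP16 and FP8 KV caches for both
# self- and cross-attention.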
@pytest.mark.parametrize("s", [
|
|
PagedScenario(num_layers=32, num_generations=5),
|
|
PagedScenario(num_layers=32, num_generations=5, kv_len=64, causal=False),
|
|
PagedScenario(
|
|
num_layers=32, num_generations=5, kvcache_dtype=torch.float8_e4m3fn),
|
|
PagedScenario(num_layers=32,
|
|
num_generations=5,
|
|
kv_len=64,
|
|
causal=False,
|
|
kvcache_dtype=torch.float8_e4m3fn),
|
|
],
|
|
ids=["fp16", "fp16-cross", "fp8", "fp8-cross"])
|
|
def test_attention_backend_ifb(s: PagedScenario):
|
|
dtype = s.dtype
|
|
is_fp8 = s.kvcache_dtype == torch.float8_e4m3fn
|
|
if is_fp8 and getSMVersion() < 89:
|
|
pytest.skip("This test is not supported in pre-Ada architecture.")
|
|
torch.manual_seed(0)
|
|
num_layers = s.num_layers
|
|
num_heads = s.num_heads
|
|
num_kv_heads = s.num_kv_heads
|
|
num_kv_groups = s.num_kv_groups
|
|
head_dim = s.head_dim
|
|
page_size = s.page_size
|
|
kv_cache_len = s.kv_cache_len
|
|
past_kv_len = s.past_kv_len
|
|
qo_len = s.qo_len
|
|
kv_len = s.kv_len_resolved
|
|
batch_size = s.batch_size
|
|
num_generations = s.num_generations
|
|
num_contexts = s.num_contexts
|
|
num_ctx_q_tokens = s.num_ctx_q_tokens
|
|
num_ctx_kv_tokens = s.num_ctx_kv_tokens
|
|
nnz_qo = s.nnz_qo
|
|
nnz_kv = s.nnz_kv
|
|
cross = s.cross
|
|
|
|
q_at_layer = torch.randn(num_layers, nnz_qo,
|
|
num_heads * head_dim).half().cuda()
|
|
flashinfer_kv_cache = torch.randn(num_layers,
|
|
s.max_num_pages,
|
|
2,
|
|
page_size,
|
|
num_kv_heads,
|
|
head_dim,
|
|
device="cuda").to(s.kvcache_dtype)
|
|
ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).contiguous().view(
|
|
num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
|
|
vanilla_kv_cache = ref_kv_cache.transpose(1, 2).contiguous()
|
|
kv = torch.randn(num_layers,
|
|
2,
|
|
nnz_kv,
|
|
num_kv_heads * head_dim,
|
|
device="cuda").to(dtype)
|
|
|
|
# Test flashinfer attention
|
|
|
|
context_lens = torch.full((num_contexts, ), qo_len).int()
|
|
qo_lens = torch.concat([context_lens, torch.ones(num_generations).int()])
|
|
|
|
if cross:
|
|
context_lens_kv = torch.full((num_contexts, ), kv_len).int()
|
|
seq_lens_kv = torch.concat(
|
|
[context_lens_kv,
|
|
torch.zeros(num_generations).int()])
|
|
else:
|
|
seq_lens_kv = None
|
|
|
|
num_cached_tokens_prefill = past_kv_len
|
|
num_cached_tokens_decode = kv_cache_len - (0 if cross else 1)
|
|
|
|
def produce(Attention: type[AttentionBackend], kv_cache: torch.Tensor):
|
|
return produce_outputs(
|
|
Attention,
|
|
q_at_layer,
|
|
kv,
|
|
s,
|
|
kv_cache=kv_cache,
|
|
num_cached_tokens=lambda i: num_cached_tokens_prefill
|
|
if i < num_contexts else num_cached_tokens_decode,
|
|
seq_lens=qo_lens,
|
|
seq_lens_kv=seq_lens_kv,
|
|
num_contexts=num_contexts,
|
|
quant_config=QuantConfig(
|
|
quant_algo=QuantAlgo.FP8,
|
|
kv_cache_quant_algo=QuantAlgo.FP8,
|
|
) if is_fp8 else None)
|
|
|
|
flashinfer_outputs = produce(FlashInferAttention, flashinfer_kv_cache)
|
|
vanilla_outputs = produce(VanillaAttention, vanilla_kv_cache)
|
|
|
|
# Test reference attention
|
|
|
|
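    # Reference path: context KV is written into the cache at past_kv_len:,
    # each decode request appends its single K/V at the last cache slot
    # (self-attention only), and every decode query is zero-padded from 1 to
    # qo_len tokens so the whole batch runs as one dense matmul; the padding
    # rows are dropped again before the comparison.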
    kv_lens = torch.full((batch_size, ), kv_cache_len).int().cuda()
    causal_mask = generate_causal_mask(kv_lens, qo_lens, batch_size,
                                       dtype) if s.causal else 0

    ref_outputs = []
    for i in range(num_layers):
        q = q_at_layer[i]
        ref_kv_cache[i][:, :num_contexts, past_kv_len:kv_cache_len] = kv[
            i][:, :num_ctx_kv_tokens].view(2, num_contexts, kv_len,
                                           num_kv_heads, head_dim)
        if not cross:
            ref_kv_cache[i][:, num_contexts:,
                            -1:] = kv[i][:, num_ctx_kv_tokens:].view(
                                2, num_generations, 1, num_kv_heads, head_dim)
        k = ref_kv_cache[i][0]
        v = ref_kv_cache[i][1]
        ctx_q, gen_q = q.split([num_ctx_q_tokens, num_generations])
        gen_q = torch.nn.functional.pad(gen_q.unsqueeze(1),
                                        (0, 0, 0, qo_len - 1), 'constant',
                                        0).view(num_generations * qo_len,
                                                num_heads * head_dim)
        q = torch.cat([ctx_q, gen_q], dim=0)
        q = q.view(batch_size, qo_len, num_heads, head_dim)
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous().to(dtype)
        v = v.transpose(1, 2).contiguous().to(dtype)
        k = repeat_kv(k, num_kv_groups)
        v = repeat_kv(v, num_kv_groups)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
        attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = torch.nn.functional.softmax(attn_weights,
                                                   dim=-1,
                                                   dtype=torch.float32).to(
                                                       q.dtype)
        o = torch.matmul(attn_weights, v)

        o = o.transpose(1, 2).contiguous()
        o = o.view(batch_size, qo_len, -1)
        ctx_o, gen_o = o.split([num_contexts, num_generations], dim=0)
        gen_o = gen_o[:, 0].view(num_generations, num_heads * head_dim)
        ctx_o = ctx_o.view(num_ctx_q_tokens, num_heads * head_dim)
        o = torch.cat([ctx_o, gen_o], dim=0)
        assert list(o.shape) == [nnz_qo, num_heads * head_dim]
        ref_outputs.append(o)

    for i in range(num_layers):
        print(f"validate accuracy for layer {i}")

        allclose(ref_outputs, {
            "flashinfer": flashinfer_outputs,
            "vanilla": vanilla_outputs
        },
                 layer=i,
                 atol=fp8_atol if is_fp8 else atol)
        assert torch.allclose(flashinfer_outputs[i],
                              vanilla_outputs[i],
                              atol=fp8_atol if is_fp8 else atol,
                              rtol=rtol)

if __name__ == "__main__":
    test_attention_backend(Scenario(num_layers=1))
    # test_attention_backend(Scenario(num_layers=1, qo_len=32, kv_len=32, causal=False))