TensorRT-LLMs/tests/_torch/test_attention.py
import math
import os
import sys
from dataclasses import dataclass
from typing import Callable, Optional, Sequence

import flashinfer
import pytest
import torch

import tensorrt_llm
from tensorrt_llm._torch.attention_backend import (AttentionBackend,
                                                    FlashInferAttention,
                                                    VanillaAttention)
from tensorrt_llm._torch.attention_backend.interface import \
    PredefinedAttentionMask
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :,
None, :, :].expand(batch, num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
head_dim)
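

# Numeric tolerances for comparing backend outputs against the reference;
# FP8 KV-cache runs use the looser `fp8_atol`.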
atol = 1e-2
rtol = 1e-3
fp8_atol = 5e-2


@dataclass(kw_only=True, frozen=True)
class Scenario:
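    """Common configuration for a single attention test case."""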
dtype: torch.dtype = torch.float16
kvcache_dtype: torch.dtype = torch.float16
num_layers: int
num_heads: int = 64
num_kv_heads: int = 16
head_dim: int = 128
page_size: int = 256
"""flash-attention requires `page_size` to be a multiple of 256"""
num_pages: int = 4
    qo_len: int = 32
    kv_len: int = 0
    """setting kv_len to non-zero to test cross attention"""
causal: bool = True
batch_size: int = 7
@property
def cross(self) -> bool:
return self.kv_len != 0
@property
def kv_len_resolved(self) -> int:
return self.kv_len or self.qo_len
@property
def num_kv_groups(self) -> int:
return self.num_heads // self.num_kv_heads
@property
def kv_cache_len(self) -> int:
return self.page_size * self.num_pages
@property
def past_kv_len(self) -> int:
return self.kv_cache_len - self.kv_len_resolved
@property
def max_num_pages(self) -> int:
return self.batch_size * self.num_pages
@property
def nnz_qo(self):
return self.batch_size * self.qo_len
@property
def nnz_kv(self):
return self.batch_size * self.kv_len_resolved
def __post_init__(self) -> None:
assert self.kv_len <= self.kv_cache_len, "KV len larger than cache len"
assert self.kv_len != 0 or self.qo_len <= self.kv_cache_len, "Seq len larger than cache len"
assert not (self.cross
and self.causal), "Cross attention cannot be causal"


@dataclass(kw_only=True, frozen=True)
class PagedScenario(Scenario):
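    """Scenario mixing context (prefill) and generation (decode) requests."""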
num_generations: int
@property
def num_contexts(self) -> int:
return self.batch_size - self.num_generations
@property
def num_ctx_q_tokens(self) -> int:
return self.num_contexts * self.qo_len
@property
def num_ctx_kv_tokens(self) -> int:
return self.num_contexts * self.kv_len_resolved
@property
def nnz_qo(self) -> int:
return self.num_ctx_q_tokens + self.num_generations
@property
def nnz_kv(self) -> int:
n = self.num_ctx_kv_tokens
if not self.cross:
n += self.num_generations
return n
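

# Whether each backend under test stores its KV cache in a paged layout.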
paged_backends = {
VanillaAttention: False,
FlashInferAttention: True,
}


def kv_cache_manager_from(Attention: type[AttentionBackend], s: Scenario,
                          kv_cache: torch.Tensor) -> KVCacheManager:
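    """Build a KVCacheManager matching the scenario and preload it with `kv_cache`."""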
paged = paged_backends[Attention]
num_blocks = s.max_num_pages if paged else s.batch_size
tokens_per_block = s.page_size if paged else s.kv_cache_len
num_layers = s.num_layers
num_kv_heads = s.num_kv_heads
head_dim = s.head_dim
num_heads = s.num_kv_heads
max_seq_len = num_blocks * tokens_per_block
batch_size = s.batch_size
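    # Map the torch KV-cache dtype onto the corresponding TRT-LLM binding enum.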
if s.kvcache_dtype == torch.float16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif s.kvcache_dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
elif s.kvcache_dtype == torch.float8_e4m3fn:
kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
else:
raise ValueError("Invalid dtype for unit test")
kv_cache_config = KvCacheConfig(max_tokens=num_blocks * tokens_per_block)
mapping = Mapping(world_size=1, tp_size=1, rank=0)
cache_type = tensorrt_llm.bindings.internal.batch_manager.CacheType.CROSS if s.cross else tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF
result = KVCacheManager(kv_cache_config, cache_type, num_layers, num_heads,
num_kv_heads, head_dim, tokens_per_block,
max_seq_len, batch_size, mapping, kv_cache_dtype)
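    # Preload the manager's per-layer buffers with the test-provided KV cache.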
for i in range(s.num_layers):
result.get_buffers(i).view_as(kv_cache[i]).copy_(kv_cache[i])
return result


def produce_outputs(
Attention: type[AttentionBackend],
q_at_layer: torch.Tensor,
kv: Optional[torch.Tensor],
s: Scenario,
*,
kv_cache: torch.Tensor,
num_cached_tokens: Callable[[int], int] | int,
num_contexts: int | None = None,
seq_lens: torch.Tensor,
seq_lens_kv: Optional[torch.Tensor] = None,
quant_config: Optional[QuantConfig] = None,
) -> list[torch.Tensor]:
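    """Run `Attention` for every layer of the scenario and return the per-layer outputs."""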
num_cached_tokens_per_seq = [
num_cached_tokens
if isinstance(num_cached_tokens, int) else num_cached_tokens(i)
for i in range(s.batch_size)
]
kv_cache_params = KVCacheParams(
use_cache=True, num_cached_tokens_per_seq=num_cached_tokens_per_seq)
kv_cache_manager = kv_cache_manager_from(Attention, s, kv_cache)
request_ids = list(range(s.batch_size))
seq_lens_append = seq_lens_kv if seq_lens_kv is not None else seq_lens
token_nums = (torch.tensor(num_cached_tokens_per_seq) +
seq_lens_append).tolist()
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
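    # Build and prepare the backend-specific metadata for this batch.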
metadata = Attention.Metadata(
num_contexts=num_contexts if num_contexts is not None else s.batch_size,
kv_cache_params=kv_cache_params,
seq_lens=seq_lens,
seq_lens_kv=seq_lens_kv,
max_num_requests=s.batch_size,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=token_nums,
)
metadata.prepare()
mask = PredefinedAttentionMask.CAUSAL if s.causal else PredefinedAttentionMask.FULL
outputs = []
for i in range(s.num_layers):
q = q_at_layer[i]
if kv is not None:
k, v = kv[i][0], kv[i][1]
else:
k, v = None, None
attention = Attention(
layer_idx=i,
num_heads=s.num_heads,
num_kv_heads=s.num_kv_heads,
head_dim=s.head_dim,
quant_config=quant_config,
)
o = attention.forward(q, k, v, metadata, attention_mask=mask)
assert list(o.shape) == [s.nnz_qo, s.num_heads * s.head_dim]
outputs.append(o)
kv_cache_manager.shutdown()
return outputs


def allclose(ref: Sequence[torch.Tensor],
impls: dict[str, Sequence[torch.Tensor]],
*,
layer=0,
atol=atol,
rtol=rtol):
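    """Print summary statistics and assert each implementation matches the reference at `layer`."""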
for name, outputs in impls.items():
print(f"{name} output: ", float(outputs[layer].abs().mean()))
print("ref outputs: ", float(ref[layer].abs().mean()))
for name, outputs in impls.items():
print(f"{name} & ref diff: ",
float((ref[layer] - outputs[layer]).abs().mean()))
    for name, outputs in impls.items():
        torch.testing.assert_close(outputs[layer],
                                   ref[layer],
                                   atol=atol,
                                   rtol=rtol,
                                   msg=f"Allclose failed: ref<->{name}")


def test_flashinfer_prefill():
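    """Cross-check flashinfer paged and ragged prefill against SDPA and a manual softmax reference."""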
s = Scenario(num_layers=1)
dtype = s.dtype
num_layers = s.num_layers
num_qo_heads = s.num_heads
num_kv_heads = s.num_kv_heads
num_kv_groups = s.num_kv_groups
head_dim = s.head_dim
page_size = s.page_size
num_pages = s.num_pages
kv_cache_len = s.kv_cache_len
qo_len = s.qo_len
past_kv_len = s.past_kv_len
batch_size = s.batch_size
nnz_qo = s.nnz_qo
max_num_pages = s.max_num_pages
# allocate 128MB workspace buffer
workspace_buffer = torch.empty(128 * 1024 * 1024,
dtype=torch.uint8,
device="cuda")
paged_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD", backend="fa2")
ragged_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer, "NHD", backend="fa2")
paged_kv_indices = torch.arange(max_num_pages).int().cuda()
paged_kv_indptr = torch.arange(0, batch_size + 1).int().cuda() * num_pages
# 1 <= paged_kv_last_page_len <= page_size
paged_kv_last_page_len = torch.full((batch_size, ), page_size).int().cuda()
qo_indptr = torch.arange(0, batch_size + 1).int().to("cuda") * qo_len
kv_indptr = torch.arange(0, batch_size + 1).int().to("cuda") * kv_cache_len
q_at_layer = torch.randn(num_layers,
nnz_qo,
num_qo_heads,
head_dim,
device="cuda").to(dtype).cuda()
kv_cache_at_layer = torch.randn(num_layers,
max_num_pages,
2,
page_size,
num_kv_heads,
head_dim,
device="cuda").to(s.kvcache_dtype)
kv_data = kv_cache_at_layer.transpose(1, 2).contiguous().view(
num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
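    # Build an additive causal mask: 0 where a query may attend, dtype-min elsewhere.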
causal_mask = torch.full((qo_len, kv_cache_len),
fill_value=torch.finfo(dtype).min,
dtype=dtype,
device="cuda")
cache_position = torch.arange(past_kv_len, kv_cache_len).cuda()
bool_causal_mask = torch.arange(
kv_cache_len).cuda() <= cache_position.reshape(-1, 1)
causal_mask *= ~bool_causal_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
# create auxiliary data structures for batch prefill attention
paged_wrapper.plan(
qo_indptr,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=True,
)
ragged_wrapper.plan(
qo_indptr,
kv_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
causal=True,
)
flashinfer_outputs = []
for i in range(num_layers):
q = q_at_layer[i]
kv_cache = kv_cache_at_layer[i]
o = paged_wrapper.run(q, kv_cache)
k = kv_data[i][0]
v = kv_data[i][1]
k = k.view(-1, num_kv_heads, head_dim)
v = v.view(-1, num_kv_heads, head_dim)
ragged_o = ragged_wrapper.run(q, k, v)
assert list(o.shape) == [nnz_qo, num_qo_heads, head_dim]
print("paged output: ", float(o.abs().mean()))
print("ragged output: ", float(ragged_o.abs().mean()))
print("paged & ragged diff: ", float((ragged_o - o).abs().mean()))
assert torch.allclose(o, ragged_o, atol=atol, rtol=rtol)
flashinfer_outputs.append(o)
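    # Reference using torch.nn.functional.scaled_dot_product_attention with the same additive mask.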
sdpa_outputs = []
for i in range(num_layers):
q = q_at_layer[i]
k = kv_data[i][0]
v = kv_data[i][1]
q = q.view(batch_size, qo_len, num_qo_heads, head_dim)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
k = repeat_kv(k, num_kv_groups)
v = repeat_kv(v, num_kv_groups)
o = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=causal_mask,
)
o = o.transpose(1, 2).contiguous().view(nnz_qo, num_qo_heads, head_dim)
sdpa_outputs.append(o)
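    # Manual attention reference with an fp32 softmax.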
ref_outputs = []
for i in range(num_layers):
q = q_at_layer[i]
k = kv_data[i][0]
v = kv_data[i][1]
q = q.view(batch_size, qo_len, num_qo_heads, head_dim)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
k = repeat_kv(k, num_kv_groups)
v = repeat_kv(v, num_kv_groups)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
q.dtype)
o = torch.matmul(attn_weights, v)
o = o.transpose(1, 2).contiguous().view(nnz_qo, num_qo_heads, head_dim)
ref_outputs.append(o)
allclose(
ref_outputs,
{
"flashinfer": flashinfer_outputs,
"sdpa": sdpa_outputs,
},
)


@pytest.mark.parametrize(
"s", [
Scenario(num_layers=1),
Scenario(num_layers=1, causal=False),
Scenario(num_layers=1, qo_len=32, kv_len=32, causal=False),
Scenario(num_layers=1, qo_len=32, kv_len=64, causal=False)
],
ids=["typical", "non-causal", "cross", "cross-diff-kv-len"])
def test_attention_backend(s: Scenario):
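    """Compare the FlashInfer and Vanilla backends against a manual attention reference."""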
dtype = s.dtype
num_layers = s.num_layers
num_heads = s.num_heads
num_kv_heads = s.num_kv_heads
num_kv_groups = s.num_kv_groups
head_dim = s.head_dim
page_size = s.page_size
kv_cache_len = s.kv_cache_len
qo_len = s.qo_len
kv_len = s.kv_len_resolved
past_kv_len = s.past_kv_len
batch_size = s.batch_size
nnz_qo = s.nnz_qo
nnz_kv = s.nnz_kv
causal = s.causal
q_at_layer = torch.randn(num_layers,
nnz_qo,
num_heads * head_dim,
device="cuda").to(dtype)
flashinfer_kv_cache = torch.randn(num_layers,
s.max_num_pages,
2,
page_size,
num_kv_heads,
head_dim,
device="cuda").to(s.kvcache_dtype)
ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).contiguous().view(
num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
kv = torch.randn(num_layers,
2,
nnz_kv,
num_kv_heads * head_dim,
device="cuda").to(dtype)
def produce(Attention: type[AttentionBackend], kv_cache: torch.Tensor):
return produce_outputs(
Attention,
q_at_layer,
kv,
s,
kv_cache=kv_cache,
num_cached_tokens=past_kv_len,
seq_lens=torch.full((batch_size, ), qo_len).int(),
seq_lens_kv=torch.full(
(batch_size, ), kv_len).int() if s.cross else None,
)
flashinfer_outputs = produce(FlashInferAttention, flashinfer_kv_cache)
sdpa_outputs = produce(VanillaAttention, ref_kv_cache.transpose(1, 2))
# Test reference attention
if causal:
causal_mask = torch.full((qo_len, kv_cache_len),
fill_value=torch.finfo(dtype).min,
dtype=dtype,
device="cuda")
cache_position = torch.arange(past_kv_len, kv_cache_len).cuda()
bool_causal_mask = torch.arange(
kv_cache_len).cuda() <= cache_position.reshape(-1, 1)
causal_mask *= ~bool_causal_mask
causal_mask = causal_mask[None,
None, :, :].expand(batch_size, 1, -1, -1)
else:
causal_mask = 0
ref_outputs = []
for i in range(num_layers):
q = q_at_layer[i]
ref_kv_cache[i][:, :, past_kv_len:kv_cache_len] = kv[i].view(
2, batch_size, kv_len, num_kv_heads, head_dim)
k = ref_kv_cache[i][0]
v = ref_kv_cache[i][1]
q = q.view(batch_size, qo_len, num_heads, head_dim)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
k = repeat_kv(k, num_kv_groups)
v = repeat_kv(v, num_kv_groups)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
q.dtype)
o = torch.matmul(attn_weights, v)
o = o.transpose(1, 2).contiguous().view(nnz_qo, num_heads * head_dim)
ref_outputs.append(o)
allclose(
ref_outputs,
{
"flashinfer": flashinfer_outputs,
"sdpa": sdpa_outputs,
},
)


def generate_causal_mask(seq_lens, qo_lens, batch_size, dtype):
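    """Build a batch of additive causal masks padded to the maximum KV and query lengths."""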
causal_masks = []
max_seq_len = int(seq_lens.max())
max_qo_len = int(qo_lens.max())
for i in range(batch_size):
kv_len = seq_lens[i]
qo_len = qo_lens[i]
past_kv_len = kv_len - qo_len
causal_mask = torch.full((qo_len, kv_len),
fill_value=torch.finfo(dtype).min,
dtype=dtype,
device="cuda")
cache_position = torch.arange(past_kv_len, kv_len).cuda()
causal_mask *= torch.arange(kv_len).cuda() > cache_position.reshape(
-1, 1)
causal_mask = torch.nn.functional.pad(
causal_mask, (0, max_seq_len - kv_len, 0, max_qo_len - qo_len),
'constant',
torch.finfo(dtype).min)
causal_masks.append(causal_mask)
causal_mask = torch.stack(causal_masks).view(batch_size, 1, max_qo_len,
max_seq_len)
return causal_mask
@pytest.mark.parametrize("s", [
PagedScenario(num_layers=32, num_generations=5),
PagedScenario(num_layers=32, num_generations=5, kv_len=64, causal=False),
PagedScenario(
num_layers=32, num_generations=5, kvcache_dtype=torch.float8_e4m3fn),
PagedScenario(num_layers=32,
num_generations=5,
kv_len=64,
causal=False,
kvcache_dtype=torch.float8_e4m3fn),
],
ids=["fp16", "fp16-cross", "fp8", "fp8-cross"])
def test_attention_backend_ifb(s: PagedScenario):
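    """In-flight batching: mixed context and generation requests, optionally with an FP8 KV cache."""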
dtype = s.dtype
is_fp8 = s.kvcache_dtype == torch.float8_e4m3fn
if is_fp8 and getSMVersion() < 89:
pytest.skip("This test is not supported in pre-Ada architecture.")
torch.manual_seed(0)
num_layers = s.num_layers
num_heads = s.num_heads
num_kv_heads = s.num_kv_heads
num_kv_groups = s.num_kv_groups
head_dim = s.head_dim
page_size = s.page_size
kv_cache_len = s.kv_cache_len
past_kv_len = s.past_kv_len
qo_len = s.qo_len
kv_len = s.kv_len_resolved
batch_size = s.batch_size
num_generations = s.num_generations
num_contexts = s.num_contexts
num_ctx_q_tokens = s.num_ctx_q_tokens
num_ctx_kv_tokens = s.num_ctx_kv_tokens
nnz_qo = s.nnz_qo
nnz_kv = s.nnz_kv
cross = s.cross
q_at_layer = torch.randn(num_layers, nnz_qo,
num_heads * head_dim).half().cuda()
flashinfer_kv_cache = torch.randn(num_layers,
s.max_num_pages,
2,
page_size,
num_kv_heads,
head_dim,
device="cuda").to(s.kvcache_dtype)
ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).contiguous().view(
num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
vanilla_kv_cache = ref_kv_cache.transpose(1, 2).contiguous()
kv = torch.randn(num_layers,
2,
nnz_kv,
num_kv_heads * head_dim,
device="cuda").to(dtype)
# Test flashinfer attention
context_lens = torch.full((num_contexts, ), qo_len).int()
qo_lens = torch.concat([context_lens, torch.ones(num_generations).int()])
if cross:
context_lens_kv = torch.full((num_contexts, ), kv_len).int()
seq_lens_kv = torch.concat(
[context_lens_kv,
torch.zeros(num_generations).int()])
else:
seq_lens_kv = None
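    # Context requests have `past_kv_len` tokens cached; generation requests have the cache
    # full except (for self attention) one slot left for the newly appended token.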
num_cached_tokens_prefill = past_kv_len
num_cached_tokens_decode = kv_cache_len - (0 if cross else 1)
def produce(Attention: type[AttentionBackend], kv_cache: torch.Tensor):
return produce_outputs(
Attention,
q_at_layer,
kv,
s,
kv_cache=kv_cache,
num_cached_tokens=lambda i: num_cached_tokens_prefill
if i < num_contexts else num_cached_tokens_decode,
seq_lens=qo_lens,
seq_lens_kv=seq_lens_kv,
num_contexts=num_contexts,
quant_config=QuantConfig(
quant_algo=QuantAlgo.FP8,
kv_cache_quant_algo=QuantAlgo.FP8,
) if is_fp8 else None)
flashinfer_outputs = produce(FlashInferAttention, flashinfer_kv_cache)
vanilla_outputs = produce(VanillaAttention, vanilla_kv_cache)
# Test reference attention
kv_lens = torch.full((batch_size, ), kv_cache_len).int().cuda()
causal_mask = generate_causal_mask(kv_lens, qo_lens, batch_size,
dtype) if s.causal else 0
ref_outputs = []
for i in range(num_layers):
q = q_at_layer[i]
ref_kv_cache[i][:, :num_contexts, past_kv_len:kv_cache_len] = kv[
i][:, :num_ctx_kv_tokens].view(2, num_contexts, kv_len,
num_kv_heads, head_dim)
if not cross:
ref_kv_cache[i][:, num_contexts:,
-1:] = kv[i][:, num_ctx_kv_tokens:].view(
2, num_generations, 1, num_kv_heads, head_dim)
k = ref_kv_cache[i][0]
v = ref_kv_cache[i][1]
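        # Pad each single-token generation query to qo_len so every request shares a
        # [batch_size, qo_len, num_heads, head_dim] layout for the reference computation.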
ctx_q, gen_q = q.split([num_ctx_q_tokens, num_generations])
gen_q = torch.nn.functional.pad(gen_q.unsqueeze(1),
(0, 0, 0, qo_len - 1), 'constant',
0).view(num_generations * qo_len,
num_heads * head_dim)
q = torch.cat([ctx_q, gen_q], dim=0)
q = q.view(batch_size, qo_len, num_heads, head_dim)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous().to(dtype)
v = v.transpose(1, 2).contiguous().to(dtype)
k = repeat_kv(k, num_kv_groups)
v = repeat_kv(v, num_kv_groups)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
q.dtype)
o = torch.matmul(attn_weights, v)
o = o.transpose(1, 2).contiguous()
o = o.view(batch_size, qo_len, -1)
ctx_o, gen_o = o.split([num_contexts, num_generations], dim=0)
gen_o = gen_o[:, 0].view(num_generations, num_heads * head_dim)
ctx_o = ctx_o.view(num_ctx_q_tokens, num_heads * head_dim)
o = torch.cat([ctx_o, gen_o], dim=0)
assert list(o.shape) == [nnz_qo, num_heads * head_dim]
ref_outputs.append(o)
for i in range(num_layers):
print(f"validate accuracy for layer {i}")
allclose(ref_outputs, {
"flashinfer": flashinfer_outputs,
"vanilla": vanilla_outputs
},
layer=i,
atol=fp8_atol if is_fp8 else atol)
assert torch.allclose(flashinfer_outputs[i],
vanilla_outputs[i],
atol=fp8_atol if is_fp8 else atol,
rtol=rtol)
if __name__ == "__main__":
test_attention_backend(Scenario(num_layers=1))
# test_attention_backend(Scenario(num_layers=1, qo_len=32, kv_len=32, causal=False))