remove star attention related tests

Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
yuhangh 2026-01-06 07:30:26 +00:00
parent f97a1aaa96
commit c9ff79c900
14 changed files with 5 additions and 914 deletions


@@ -951,7 +951,6 @@ common-files: &common_files |
tests/unittest/_torch/attention/test_attention_no_cache.py |
tests/unittest/_torch/attention/test_attention.py |
tests/unittest/_torch/attention/test_flashinfer_attention.py |
tests/unittest/_torch/attention/test_flashinfer_star_attn.py |
tests/unittest/_torch/attention/test_vanilla_attention.py |
tests/unittest/_torch/compilation/test_add_norm.py |
tests/unittest/_torch/debugger/test_debugger_addon.py |
@@ -1004,7 +1003,6 @@ common-files: &common_files |
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py |
tests/unittest/_torch/multi_gpu/test_mnnvl_memory.py |
tests/unittest/_torch/multi_gpu/test_moe_a2a.py |
tests/unittest/_torch/multi_gpu/test_star_attention.py |
tests/unittest/_torch/multi_gpu/test_user_buffers.py |
tests/unittest/_torch/multimodal/test_external_embedding.py |
tests/unittest/_torch/multimodal/test_find_num_image_tokens.py |


@@ -42,8 +42,7 @@ def parse_arguments():
parser.add_argument(
'--input_file',
type=str,
default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
)
default="tests/unittest/_torch/multi_gpu/NIAH_simple_data.jsonl")
# Build config
parser.add_argument('--algo',


@@ -992,7 +992,6 @@ exclude = [
"tests/unittest/_torch/attention/test_attention_no_cache.py",
"tests/unittest/_torch/attention/test_attention.py",
"tests/unittest/_torch/attention/test_flashinfer_attention.py",
"tests/unittest/_torch/attention/test_flashinfer_star_attn.py",
"tests/unittest/_torch/attention/test_vanilla_attention.py",
"tests/unittest/_torch/compilation/test_add_norm.py",
"tests/unittest/_torch/debugger/test_debugger_addon.py",
@@ -1045,7 +1044,6 @@ exclude = [
"tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py",
"tests/unittest/_torch/multi_gpu/test_mnnvl_memory.py",
"tests/unittest/_torch/multi_gpu/test_moe_a2a.py",
"tests/unittest/_torch/multi_gpu/test_star_attention.py",
"tests/unittest/_torch/multi_gpu/test_user_buffers.py",
"tests/unittest/_torch/multimodal/test_external_embedding.py",
"tests/unittest/_torch/multimodal/test_find_num_image_tokens.py",


@@ -886,7 +886,6 @@
"test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]": 66.199569062097,
"test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]": 62.389084175927565,
"test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]": 7200.0001350759994238615,
"test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]": 3600.00020311586558818817,
"test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]": 137.7278483910486,
"test_e2e.py::test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1]": 12134.278186964104,
"test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]": 109.25386995915323,


@@ -3135,29 +3135,6 @@ def test_ptp_quickstart_bert(llm_root, llm_venv, model_name, model_path,
print("Success: HF model logits match TRTLLM logits!")
@pytest.mark.parametrize("model_name,model_path", [
("Llama3.1-8B-BF16", "llama-3.1-model/Meta-Llama-3.1-8B"),
])
def test_ptp_star_attention_example(llm_root, llm_venv, model_name, model_path,
star_attention_input_root):
print(f"Testing {model_name}.")
workspace = llm_venv.get_working_directory()
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
input_file = Path(
os.path.join(star_attention_input_root,
"test_star_attention_input.jsonl"))
output_file = Path(os.path.join(workspace, "star_attention_output.jsonl"))
llm_venv.run_cmd([
str(example_root / "star_attention.py"),
"--model_path",
f"{llm_models_root()}/{model_path}",
"--sa_block_size=200",
"--sa_anchor_size=200",
f"--input_file={input_file}",
f"--output_file={output_file}",
])
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("model_name,model_path", [
("DeepSeek-R1-Distill-Qwen-7B", "DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B"),


@@ -371,7 +371,6 @@ test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-
test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-models/Llama-3.3-70B-Instruct-FP8-2-2-True]
test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-models/Llama-3.3-70B-Instruct-FP8-2-4-False]
test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-models/Llama-3.3-70B-Instruct-FP8-2-4-True]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8]


@@ -229,7 +229,6 @@ test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]
test_e2e.py::test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]


@@ -59,6 +59,5 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-mult
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]


@@ -263,7 +263,6 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-mult
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8]
unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora


@@ -152,7 +152,9 @@ test_e2e.py::test_ptp_quickstart_advanced_multi_gpus[Nemotron-Ultra-253B-nemotro
examples/test_multimodal.py::test_llm_multimodal_general[Phi-4-multimodal-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5385992)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int4_awq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5401233)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_2gpu[recurrentgemma-2b] SKIP (https://nvbugs/5401233)
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5409420)
examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5141288)
examples/test_qwen.py::test_llm_qwen_7b_int8_kv_1node_1gpus[qwen2_vl_7b_instruct-enable_gemm_plugin-enable_weight_only] SKIP (https://nvbugs/5419067)
examples/test_qwen.py::test_llm_qwen_awq_single_gpu_summary[qwen2_vl_7b_instruct-nb:4] SKIP (https://nvbugs/5419068)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-fp8-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5419070)
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
@@ -341,7 +343,6 @@ disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backen
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] SKIP (https://nvbugs/5769890,https://nvbugs/5748683)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp] SKIP (https://nvbugs/5779536)
unittest/_torch/attention/test_flashinfer_star_attn.py::TestStarAttention::test_flashinfer_star_attention[num_layers:2-num_heads:32-num_kv_heads:8-head_dim:64-anchor_size:64-block_size:64-dtype:torch.float16] SKIP (https://nvbugs/5781389)
unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py::test_reducescatter_pg_op[var_len:True-seqlen:16-hidden:128] SKIP (https://nvbugs/5781383)
cpp/test_e2e.py::test_model[-mamba-86] SKIP (https://nvbugs/5781665)
unittest/llmapi/test_llm_multi_gpu_pytorch.py::test_tinyllama_logits_processor_tp2pp2 SKIP (https://nvbugs/5781731)


@@ -61,7 +61,7 @@ def test_model(backend, model_name, attention_backend):
current_file = os.path.abspath(__file__)
current_dir = os.path.dirname(os.path.dirname(
os.path.dirname(current_file)))
input_file = f'{current_dir}/multi_gpu/test_star_attention_input.jsonl'
input_file = f'{current_dir}/multi_gpu/NIAH_simple_data.jsonl'
with open(input_file, 'r') as f:
for line in f:
sample = json.loads(line)


@@ -1,738 +0,0 @@
import random
import unittest
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Union
import pytest
import torch
from parameterized import parameterized
import tensorrt_llm
from tensorrt_llm._torch.attention_backend import (StarAttention,
StarAttentionMetadata)
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import CpType, Mapping
class TestingStarAttentionMetadata(StarAttentionMetadata):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_times_planned = defaultdict(int)
def get_num_plans(self, plan_params) -> int:
return self._num_times_planned[plan_params]
def _plan_with_params(self, plan_params):
if self.needs_plan(plan_params):
self._num_times_planned[plan_params] += 1
return super()._plan_with_params(plan_params)
@dataclass(repr=False)
class Scenario:
num_layers: int
num_heads: int
num_kv_heads: Union[int, List[Optional[int]]]
head_dim: int
anchor_size: int
block_size: int
dtype: torch.dtype
def __repr__(self) -> str:
if isinstance(self.num_kv_heads, int):
num_kv_heads_str = str(self.num_kv_heads)
else:
num_kv_heads_str = '_'.join(map(str, self.num_kv_heads))
return f"num_layers:{self.num_layers}-num_heads:{self.num_heads}-num_kv_heads:{num_kv_heads_str}-head_dim:{self.head_dim}-anchor_size:{self.anchor_size}-block_size:{self.block_size}-dtype:{self.dtype}"
@dataclass
class CUDAGraphTestScenario:
batch_size: int
num_heads: int
num_kv_heads: int
head_dim: int
anchor_size: int
block_size: int
dtype: torch.dtype
def __repr__(self) -> str:
if isinstance(self.num_kv_heads, int):
num_kv_heads_str = str(self.num_kv_heads)
else:
num_kv_heads_str = '_'.join(map(str, self.num_kv_heads))
return f"batch_size:{self.batch_size}-num_heads:{self.num_heads}-num_kv_heads:{num_kv_heads_str}-head_dim:{self.head_dim}-anchor_size:{self.anchor_size}-block_size:{self.block_size}-dtype:{self.dtype}"
@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5781389")
class TestStarAttention(unittest.TestCase):
@parameterized.expand([
Scenario(num_layers=1,
num_heads=32,
num_kv_heads=8,
head_dim=128,
anchor_size=64,
block_size=64,
dtype=torch.bfloat16),
Scenario(num_layers=2,
num_heads=32,
num_kv_heads=8,
head_dim=64,
anchor_size=64,
block_size=64,
dtype=torch.float16),
Scenario(num_layers=2,
num_heads=32,
num_kv_heads=[8, 16],
head_dim=128,
anchor_size=64,
block_size=64,
dtype=torch.bfloat16),
Scenario(num_layers=3,
num_heads=32,
num_kv_heads=[8, None, 16],
head_dim=64,
anchor_size=64,
block_size=64,
dtype=torch.float16),
Scenario(num_layers=3,
num_heads=32,
num_kv_heads=[8, None, 16],
head_dim=64,
anchor_size=64,
block_size=128,
dtype=torch.bfloat16),
Scenario(num_layers=3,
num_heads=32,
num_kv_heads=[8, None, 16],
head_dim=64,
anchor_size=64,
block_size=256,
dtype=torch.bfloat16),
], lambda testcase_func, param_num, param:
f"{testcase_func.__name__}[{param.args[0]}]")
def test_flashinfer_star_attention(self, scenario: Scenario):
num_layers = scenario.num_layers
num_heads = scenario.num_heads
num_kv_heads = scenario.num_kv_heads
head_dim = scenario.head_dim
dtype = scenario.dtype
device = torch.device('cuda')
num_gens = 2
context_sequence_lengths = [356, 400]
query_sequence_lengths = [4, 10]
sequence_lengths = context_sequence_lengths + query_sequence_lengths + [
1
] * num_gens
past_seen_tokens = [0, 0, 318, 356, 256, 258]
# 6 7 6 6 5 5
cache_indices = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24],
[25, 26, 27, 28, 29], [30, 31, 32, 33, 34]]
batch_size = len(sequence_lengths)
request_ids = list(range(batch_size))
token_nums = (torch.tensor(sequence_lengths) +
torch.tensor(past_seen_tokens)).tolist()
num_blocks = 64
tokens_per_block = 64
max_seq_len = tokens_per_block * num_blocks
cp_config = {
"cp_type": CpType.STAR,
"cp_anchor_size": scenario.anchor_size,
"block_size": scenario.block_size
}
mapping = Mapping(world_size=1,
tp_size=1,
cp_size=1,
cp_config=cp_config,
rank=0)
if dtype == torch.float16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype for unit test")
kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
tokens_per_block)
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
for i in range(kv_cache_manager.num_layers):
buf = kv_cache_manager.get_buffers(i)
if buf is not None:
torch.nn.init.normal_(buf)
del buf
if isinstance(num_kv_heads, int):
num_kv_heads = [num_kv_heads] * num_layers
contexts_per_layer = []
queries_per_layer = []
gens_per_layer = []
for layer_idx in range(num_layers):
kv_heads = num_kv_heads[layer_idx]
if kv_heads is None:
continue
context_qs = [
torch.randn(sequence_length,
num_heads * head_dim,
dtype=dtype,
device=device)
for sequence_length in context_sequence_lengths
]
context_ks = [
torch.randn(sequence_length,
kv_heads * head_dim,
dtype=dtype,
device=device)
for sequence_length in context_sequence_lengths
]
context_vs = [
torch.randn(sequence_length,
kv_heads * head_dim,
dtype=dtype,
device=device)
for sequence_length in context_sequence_lengths
]
contexts_per_layer.append((context_qs, context_ks, context_vs))
query_qs = [
torch.randn(sequence_length,
num_heads * head_dim,
dtype=dtype,
device=device)
for sequence_length in query_sequence_lengths
]
query_ks = [
torch.randn(sequence_length,
kv_heads * head_dim,
dtype=dtype,
device=device)
for sequence_length in query_sequence_lengths
]
query_vs = [
torch.randn(sequence_length,
kv_heads * head_dim,
dtype=dtype,
device=device)
for sequence_length in query_sequence_lengths
]
queries_per_layer.append((query_qs, query_ks, query_vs))
gen_qs = [
torch.randn(1, num_heads * head_dim, dtype=dtype, device=device)
for _ in range(num_gens)
]
gen_ks = [
torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device)
for _ in range(num_gens)
]
gen_vs = [
torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device)
for _ in range(num_gens)
]
gens_per_layer.append((gen_qs, gen_ks, gen_vs))
layers = [
StarAttention(
layer_idx=layer_idx,
num_heads=num_heads,
head_dim=head_dim,
num_kv_heads=kv_heads,
) for layer_idx, kv_heads in enumerate(num_kv_heads)
if kv_heads is not None
]
# [context_1, context_2, query_1, query_2, gen_1, gen_2]
results_1 = []
block_ids_per_seq = [i for i in cache_indices]
num_cached_tokens_per_seq = [j for j in past_seen_tokens]
seq_lens = torch.tensor(sequence_lengths).int()
attn_metadata = TestingStarAttentionMetadata(
seq_lens=seq_lens,
num_contexts=len(context_sequence_lengths),
num_queries=len(query_sequence_lengths),
kv_cache_params=KVCacheParams(
use_cache=True,
block_ids_per_seq=block_ids_per_seq,
num_cached_tokens_per_seq=past_seen_tokens,
),
max_num_requests=6,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
mapping=mapping,
)
attn_metadata.prepare()
for attn_layer_idx, star_attn in enumerate(layers):
context_qs, context_ks, context_vs = contexts_per_layer[
attn_layer_idx]
query_qs, query_ks, query_vs = queries_per_layer[attn_layer_idx]
gen_qs, gen_ks, gen_vs = gens_per_layer[attn_layer_idx]
q = torch.cat((*context_qs, *query_qs, *gen_qs))
k = torch.cat((*context_ks, *query_ks, *gen_ks))
v = torch.cat((*context_vs, *query_vs, *gen_vs))
result_1 = star_attn.forward(q, k, v, attn_metadata)
self.assertEqual(
result_1.size()[0],
sum(context_sequence_lengths) + sum(query_sequence_lengths) +
num_gens)
# validate kv cache was updated expectedly
cache_buf = kv_cache_manager.get_buffers(
star_attn.layer_idx, kv_layout=attn_metadata.kv_layout)
if attn_metadata.kv_layout == "HND":
cache_buf = cache_buf.transpose(2, 3).contiguous()
assert cache_buf is not None
num_kv_heads = cache_buf.size(-2)
# validate contexts
block_ids_per_seq = kv_cache_manager.get_batch_cache_indices(
request_ids)
for seq_id in range(len(context_sequence_lengths)):
# get a contiguous copy of the cache for the sequence
block_ids = block_ids_per_seq[seq_id]
cached_kvs = torch.concat(cache_buf[block_ids, :].unbind(dim=0),
dim=1)
# only look at new tokens added
cached_kvs = cached_kvs[:, past_seen_tokens[seq_id]:
past_seen_tokens[seq_id] +
context_sequence_lengths[seq_id]]
# compare to input kvs
torch.testing.assert_close(
cached_kvs[0].to(context_ks[seq_id].dtype),
context_ks[seq_id].view(-1, num_kv_heads, head_dim))
torch.testing.assert_close(
cached_kvs[1].to(context_vs[seq_id].dtype),
context_vs[seq_id].view(-1, num_kv_heads, head_dim))
# validate queries
for query_seq_id in range(len(query_sequence_lengths)):
seq_id = query_seq_id + len(context_sequence_lengths)
# get a contiguous copy of the cache for the sequence
block_ids = block_ids_per_seq[seq_id]
cached_kvs = torch.concat(cache_buf[block_ids, :].unbind(dim=0),
dim=1)
# only look at new tokens added
cached_kvs = cached_kvs[:, past_seen_tokens[seq_id]:
past_seen_tokens[seq_id] +
query_sequence_lengths[query_seq_id]]
# compare to input kvs
torch.testing.assert_close(
cached_kvs[0].to(query_ks[query_seq_id].dtype),
query_ks[query_seq_id].view(-1, num_kv_heads, head_dim))
torch.testing.assert_close(
cached_kvs[1].to(query_vs[query_seq_id].dtype),
query_vs[query_seq_id].view(-1, num_kv_heads, head_dim))
# validate generations (same way)
for gen_seq_id in range(num_gens):
seq_id = len(context_sequence_lengths) + len(
query_sequence_lengths) + gen_seq_id
block_ids = block_ids_per_seq[seq_id]
cached_kvs = torch.concat(
cache_buf[block_ids, :].unbind(dim=0),
dim=1)[:,
past_seen_tokens[seq_id]:past_seen_tokens[seq_id] +
1]
torch.testing.assert_close(
cached_kvs[0],
gen_ks[gen_seq_id].view(-1, num_kv_heads, head_dim))
torch.testing.assert_close(
cached_kvs[1],
gen_vs[gen_seq_id].view(-1, num_kv_heads, head_dim))
results_1.append(result_1)
del cache_buf
for plan_params in attn_metadata._plan_params_to_wrappers.keys():
self.assertEqual(attn_metadata.get_num_plans(plan_params), 1)
# Make sure prepare() re-planned all params.
attn_metadata.prepare()
for plan_params in attn_metadata._plan_params_to_wrappers.keys():
self.assertEqual(attn_metadata.get_num_plans(plan_params), 2)
# [context_1, gen_1, gen_2]
results_2 = []
block_ids_per_seq = [
cache_indices[0], cache_indices[-2], cache_indices[-1]
]
num_cached_tokens_per_seq = [
j for j in
[past_seen_tokens[0], past_seen_tokens[-2], past_seen_tokens[-1]]
]
seq_lens = torch.tensor([context_sequence_lengths[0], 1, 1],
dtype=torch.int)
attn_metadata = TestingStarAttentionMetadata(
seq_lens=seq_lens,
num_contexts=1,
num_queries=0,
kv_cache_params=KVCacheParams(
use_cache=True,
block_ids_per_seq=block_ids_per_seq,
num_cached_tokens_per_seq=num_cached_tokens_per_seq),
max_num_requests=3,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=[0, 4, 5],
mapping=mapping,
)
attn_metadata.prepare()
for attn_layer_idx, star_attn in enumerate(layers):
context_qs, context_ks, context_vs = contexts_per_layer[
attn_layer_idx]
gen_qs, gen_ks, gen_vs = gens_per_layer[attn_layer_idx]
result_2 = star_attn.forward(torch.cat((context_qs[0], *gen_qs)),
torch.cat((context_ks[0], *gen_ks)),
torch.cat((context_vs[0], *gen_vs)),
attn_metadata)
self.assertEqual(result_2.size()[0],
context_sequence_lengths[0] + 1 + 1)
results_2.append(result_2)
for plan_params in attn_metadata._plan_params_to_wrappers.keys():
self.assertEqual(attn_metadata.get_num_plans(plan_params), 1)
# Make sure prepare() re-planned all params.
attn_metadata.prepare()
for plan_params in attn_metadata._plan_params_to_wrappers.keys():
self.assertEqual(attn_metadata.get_num_plans(plan_params), 2)
# [context_2, query_1, query_2]
results_3 = []
block_ids_per_seq = [
cache_indices[1], cache_indices[2], cache_indices[3]
]
num_cached_tokens_per_seq = [
j for j in
[past_seen_tokens[1], past_seen_tokens[2], past_seen_tokens[3]]
]
seq_lens = torch.tensor([
context_sequence_lengths[1], query_sequence_lengths[0],
query_sequence_lengths[1]
],
dtype=torch.int)
attn_metadata = TestingStarAttentionMetadata(
seq_lens=seq_lens,
num_contexts=1,
num_queries=2,
kv_cache_params=KVCacheParams(
use_cache=True,
block_ids_per_seq=block_ids_per_seq,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=3,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=[1, 2, 3],
mapping=mapping,
)
attn_metadata.prepare()
for attn_layer_idx, star_attn in enumerate(layers):
context_qs, context_ks, context_vs = contexts_per_layer[
attn_layer_idx]
query_qs, query_ks, query_vs = queries_per_layer[attn_layer_idx]
result_3 = star_attn.forward(torch.cat((context_qs[1], *query_qs)),
torch.cat((context_ks[1], *query_ks)),
torch.cat((context_vs[1], *query_vs)),
attn_metadata)
self.assertEqual(
result_3.size()[0],
context_sequence_lengths[1] + sum(query_sequence_lengths))
results_3.append(result_3)
for plan_params in attn_metadata._plan_params_to_wrappers.keys():
self.assertEqual(attn_metadata.get_num_plans(plan_params), 1)
# Make sure prepare() re-planned all params.
attn_metadata.prepare()
for plan_params in attn_metadata._plan_params_to_wrappers.keys():
self.assertEqual(attn_metadata.get_num_plans(plan_params), 2)
# assert value
for result_1, result_2, result_3 in zip(results_1, results_2,
results_3):
tensor_1 = torch.cat((
result_1[:sum(context_sequence_lengths), :],
result_1[sum(context_sequence_lengths
):sum(context_sequence_lengths) +
sum(query_sequence_lengths), :],
result_1[sum(context_sequence_lengths) +
sum(query_sequence_lengths):, :],
))
tensor_2 = torch.cat((
result_2[:context_sequence_lengths[0], :],
result_3[:context_sequence_lengths[1] +
sum(query_sequence_lengths), :],
result_2[context_sequence_lengths[0]:, :],
))
# Allow a larger absolute difference due to flashinfer's precision issues, especially on PCIe nodes
# atol: 1e-5 -> 0.1
torch.testing.assert_close(tensor_1, tensor_2, atol=0.1, rtol=0.02)
kv_cache_manager.shutdown()
@parameterized.expand([
CUDAGraphTestScenario(
batch_size=1,
num_heads=32,
num_kv_heads=32,
head_dim=128,
anchor_size=64,
block_size=64,
dtype=torch.float16,
),
CUDAGraphTestScenario(
batch_size=16,
num_heads=32,
num_kv_heads=32,
head_dim=128,
anchor_size=64,
block_size=128,
dtype=torch.bfloat16,
),
CUDAGraphTestScenario(
batch_size=16,
num_heads=32,
num_kv_heads=[32, 16],
head_dim=128,
anchor_size=128,
block_size=128,
dtype=torch.bfloat16,
),
],
lambda testcase_func, param_num, param:
f"{testcase_func.__name__}[{param.args[0]}]",
skip_on_empty=True)
def test_attention_with_cuda_graphs(
self, test_scenario: CUDAGraphTestScenario) -> None:
# This test exercises our CUDAGraph metadata class and makes sure
# that the flashinfer attention layer is compatible with graph capture/replay.
# We compare the CUDA graph results to the results without CUDA graph.
batch_size = test_scenario.batch_size
num_heads = test_scenario.num_heads
num_kv_heads = test_scenario.num_kv_heads
head_dim = test_scenario.head_dim
dtype = test_scenario.dtype
device = 'cuda'
tokens_per_block = 64
past_seen_tokens = [random.randint(1, 1024) for _ in range(batch_size)]
cache_indices = []
last_pos = 0
for i in range(batch_size):
used_blocks = (past_seen_tokens[i] + 64) // 64
cache_indices.append(
[j for j in range(last_pos, last_pos + used_blocks)])
last_pos += used_blocks
block_ids_per_seq = [i for i in cache_indices]
[j for j in past_seen_tokens]
request_ids = list(range(batch_size))
token_nums = (torch.tensor(past_seen_tokens) + 1).tolist()
num_blocks = 512
max_seq_len = tokens_per_block * num_blocks
num_layers = 1 if isinstance(num_kv_heads, int) else len(num_kv_heads)
cp_config = {
"cp_type": CpType.STAR,
"cp_anchor_size": test_scenario.anchor_size,
"block_size": test_scenario.block_size
}
mapping = Mapping(world_size=1,
tp_size=1,
cp_size=1,
cp_config=cp_config,
rank=0)
kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
tokens_per_block)
if dtype == torch.float16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype for unit test")
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
gen_qs = []
gen_ks = []
gen_vs = []
for i in range(num_layers):
gen_qs.append([
torch.randn(1, num_heads * head_dim, dtype=dtype, device=device)
for _ in range(batch_size)
])
kv_heads = num_kv_heads if isinstance(num_kv_heads,
int) else num_kv_heads[i]
gen_ks.append([
torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device)
for _ in range(batch_size)
])
gen_vs.append([
torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device)
for _ in range(batch_size)
])
layers = []
for i in range(num_layers):
kv_heads = num_kv_heads if isinstance(num_kv_heads,
int) else num_kv_heads[i]
layers.append(
StarAttention(
layer_idx=i,
head_dim=head_dim,
num_heads=num_heads,
num_kv_heads=kv_heads,
))
seq_lens = torch.ones((batch_size, ), dtype=torch.int)
attn_metadata_ref = TestingStarAttentionMetadata(
seq_lens=seq_lens,
num_contexts=0,
num_queries=0,
kv_cache_params=KVCacheParams(
use_cache=True,
block_ids_per_seq=block_ids_per_seq,
num_cached_tokens_per_seq=past_seen_tokens,
),
max_num_requests=batch_size,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
mapping=mapping,
)
attn_metadata_ref.kv_cache_manager = kv_cache_manager
workspace = torch.empty(1024 * 1024 * 128,
dtype=torch.int,
device='cuda')
attn_metadata_cuda_graph = TestingStarAttentionMetadata(
seq_lens=seq_lens,
num_contexts=0,
num_queries=0,
is_cuda_graph=True,
kv_cache_params=KVCacheParams(
use_cache=True,
block_ids_per_seq=block_ids_per_seq,
num_cached_tokens_per_seq=past_seen_tokens,
),
workspace_buffer=workspace,
max_num_requests=batch_size,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
mapping=mapping,
)
attn_metadata_ref.prepare()
attn_metadata_cuda_graph.prepare()
results_ref = []
for i in range(num_layers):
q = torch.cat(gen_qs[i])
k = torch.cat(gen_ks[i])
v = torch.cat(gen_vs[i])
layer = layers[i]
results_ref.append(layer.forward(q, k, v, attn_metadata_ref))
graph = torch.cuda.CUDAGraph()
for i in range(num_layers):
layer = layers[i]
q = torch.cat(gen_qs[i])
k = torch.cat(gen_ks[i])
v = torch.cat(gen_vs[i])
# Warmup run, required by PT
for _ in range(2):
layer.forward(q, k, v, attn_metadata_cuda_graph)
results_actual = []
with torch.cuda.graph(graph):
for i in range(num_layers):
layer = layers[i]
q = torch.cat(gen_qs[i])
k = torch.cat(gen_ks[i])
v = torch.cat(gen_vs[i])
results_actual.append(
layer.forward(q, k, v, attn_metadata_cuda_graph))
graph.replay()
for result_actual, result_ref in zip(results_actual, results_ref):
torch.testing.assert_close(result_actual,
result_ref,
atol=0.5,
rtol=0.5)
kv_cache_manager.shutdown()
if __name__ == "__main__":
unittest.main()


@@ -1,139 +0,0 @@
import json
import os
import pytest
import torch
from utils.llm_data import llm_models_root
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
from tensorrt_llm.mapping import CpType
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
MAX_SEQ_LEN = 4096 + 1024
@pytest.mark.post_merge
@pytest.mark.parametrize("backend", ["pytorch"])
@pytest.mark.parametrize("model_name",
["llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k"],
ids=["llama-3-8b-1048k"])
@pytest.mark.parametrize("quant", ["bf16", "fp8"])
@pytest.mark.parametrize("sp_size", [1, 2, 4], ids=["sp1", "sp2", "sp4"])
@pytest.mark.parametrize("sa_block_size", [256, 1024],
ids=["block1024", "block4096"])
@pytest.mark.parametrize("sa_anchor_size", [256, 1024],
ids=["anchor1024", "anchor4096"])
def test_model(backend, model_name, quant, sp_size, sa_block_size,
sa_anchor_size):
pytest.skip("https://nvbugs/5391679")
quant_configs = {
"bf16":
QuantConfig(),
"fp8":
QuantConfig(quant_algo=QuantAlgo.FP8),
"fp8_kv_cache":
QuantConfig(
quant_algo=QuantAlgo.FP8,
kv_cache_quant_algo=QuantAlgo.FP8,
),
}
quant_config = quant_configs[quant]
if sp_size != 1:
pytest.skip(f"skip multi gpu tests due to flashinfer's jitting mode")
if torch.cuda.device_count() < sp_size:
pytest.skip(f"Not enough GPUs available, need {sp_size} "
f"but only have {torch.cuda.device_count()}")
if sa_anchor_size > sa_block_size:
pytest.skip(
f"Unsupported sa_anchor_size {sa_anchor_size} > sa_block_size {sa_block_size}"
)
if get_total_gpu_memory(0) < 32 * 1024**3:
pytest.skip("Not enough GPU memory to run BF16 model")
model_dir = str(llm_models_root() / model_name)
cp_config = {
"cp_type": CpType.STAR,
"cp_anchor_size": sa_anchor_size,
"block_size": sa_block_size
}
max_batch_size = 20
max_output_tokens = 128
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
pytorch_backend_options = dict(attn_backend='FLASHINFER_STAR_ATTENTION',
disable_overlap_scheduler=True)
llm = LLM(model=model_dir,
backend=backend,
kv_cache_config=kv_cache_config,
tensor_parallel_size=1,
quant_config=quant_config,
context_parallel_size=sp_size,
cp_config=cp_config,
**pytorch_backend_options,
max_batch_size=max_batch_size,
max_input_len=MAX_SEQ_LEN - max_output_tokens,
max_seq_len=MAX_SEQ_LEN,
max_num_tokens=(sa_block_size + sa_anchor_size) * max_batch_size)
inputs, references = [], []
current_file = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file)
with open(f'{current_dir}/test_star_attention_input.jsonl', 'r') as f:
for line in f:
sample = json.loads(line)
inputs.append({
'prompt': sample['input_context'],
'query': sample['input_query']
})
references.append(sample['outputs'][0])
with llm:
outputs = llm.generate(
inputs,
use_tqdm=True,
sampling_params=SamplingParams(
max_tokens=max_output_tokens,
add_special_tokens=False,
),
)
count = 0
for ref, ret in zip(references, outputs):
#print(f'reference = {ref}')
#print(f'prediction = {ret.outputs[0].text}')
if ref not in ret.outputs[0].text:
print(f'reference {ref} is not in the output {ret.outputs[0].text}')
else:
count = count + 1
acc = count / len(outputs)
if acc < 1.0:
assert False, 'accuracy test of star attention failed'
if __name__ == '__main__':
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"bf16", 1, 256, 256)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"bf16", 1, 1024, 256)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"bf16", 1, 1024, 1024)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"fp8", 1, 256, 256)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"fp8", 1, 1024, 256)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"fp8", 1, 1024, 1024)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"bf16", 2, 1024, 256)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"bf16", 2, 1024, 1024)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"bf16", 2, 256, 256)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"bf16", 4, 1024, 256)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"bf16", 4, 1024, 1024)
test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
"bf16", 4, 256, 256)