TensorRT-LLMs/examples/llm-api/llm_sparse_attention.py
Fanrong Li 0d20a8fd61
[TRTLLM-8536][feat] Add the sparse attention framework and one use case--RocketKV support (#8086)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Co-authored-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
2025-10-14 08:23:16 -07:00

156 lines
4.9 KiB
Python

### :title Sparse Attention
### :order 5
### :section Customization
"""
This example demonstrates how to use sparse attention with TensorRT-LLM.
Supported sparse attention algorithms:
- RocketKV
Usage:
```bash
python llm_sparse_attention.py --algo RocketKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
```
"""
import argparse
import json
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
def read_input(input_file):
    """Load a JSONL file and return its records as a list of dicts.

    Each line of *input_file* must be a standalone JSON document.
    """
    with open(input_file, 'r') as f:
        return [json.loads(line) for line in f]
def parse_arguments():
    """Build the CLI parser for this example and return the parsed args."""
    arg_parser = argparse.ArgumentParser()

    # Model checkpoint and input data.
    arg_parser.add_argument(
        '--model_path',
        type=str,
        default=
        "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
    )
    arg_parser.add_argument(
        '--input_file',
        type=str,
        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
    )

    # Build config: sparse attention algorithm and its hyper-parameters.
    arg_parser.add_argument('--algo', type=str, default='ROCKETKV',
                            choices=['ROCKETKV'])
    arg_parser.add_argument('--attention_backend', type=str, default='TRTLLM',
                            choices=['VANILLA', 'TRTLLM'])
    arg_parser.add_argument('--window_size', type=int, default=32,
                            help="The window size for RocketKV.")
    arg_parser.add_argument('--kernel_size', type=int, default=63,
                            help="The kernel size for RocketKV.")
    arg_parser.add_argument('--prompt_budget', type=int, default=2048,
                            help="The prompt budget for RocketKV.")

    # Engine limits.
    arg_parser.add_argument("--max_seq_len", type=int, default=8192,
                            help="The maximum sequence length.")
    arg_parser.add_argument("--max_batch_size", type=int, default=256,
                            help="The maximum batch size.")
    arg_parser.add_argument("--max_new_tokens", type=int, default=128,
                            help="The maximum new tokens.")
    arg_parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help=
        "The maximum total tokens (context + generation) across all sequences in a batch."
    )
    arg_parser.add_argument('--tensor_parallel_size', type=int, default=1)

    # KV cache settings.
    arg_parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    arg_parser.add_argument("--kv_cache_fraction", type=float, default=None)

    arg_parser.add_argument('--num_samples', type=int, default=10)

    return arg_parser.parse_args()
def run_RocketKV(args):
    """Run generation with RocketKV sparse attention and print the results.

    Reads JSONL samples from ``args.input_file``, builds an LLM configured
    for RocketKV, generates completions, and prints each output next to its
    reference answer.
    """
    samples = read_input(args.input_file)
    # Trim to the requested number of samples; None means "use everything".
    limit = len(samples) if args.num_samples is None else args.num_samples
    samples = samples[:limit]

    kv_cache_config = KvCacheConfig(
        enable_block_reuse=
        False,  # sparse attention does not support kv cache reuse now
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
    )
    sparse_attention_config = RocketSparseAttentionConfig(
        window_size=args.window_size,
        kernel_size=args.kernel_size,
        prompt_budget=args.prompt_budget,
    )
    llm = LLM(
        model=args.model_path,
        backend='pytorch',
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        sparse_attention_config=sparse_attention_config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        max_num_tokens=args.max_num_tokens,
        tensor_parallel_size=args.tensor_parallel_size,
        cuda_graph_config=
        None,  # sparse attention does not support cuda graph now
    )

    # Each sample's prompt is its context immediately followed by the query.
    prompts = [{
        'prompt': sample['input_context'] + sample['input_query']
    } for sample in samples]
    reference = [sample['outputs'] for sample in samples]

    sampling_params = SamplingParams(add_special_tokens=False,
                                     max_tokens=args.max_new_tokens,
                                     temperature=0.8,
                                     top_p=0.95)
    outputs = llm.generate(prompts, sampling_params)
    for output, ref in zip(outputs, reference):
        print(f'Generated text: {output.outputs[0].text!r}, ref: {ref}')
def main():
    """Entry point: parse CLI arguments and dispatch to the chosen algorithm."""
    args = parse_arguments()
    # Guard clause: only RocketKV is supported at the moment; extend this
    # dispatch as more sparse attention algorithms are added.
    if args.algo != 'ROCKETKV':
        raise ValueError(f"Invalid algorithm: {args.algo}")
    run_RocketKV(args)


if __name__ == "__main__":
    main()