Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
Merge c9ff79c900 into 6df2c8a074
This commit is contained in commit fa1a79fcd7
@@ -951,7 +951,6 @@ common-files: &common_files |
tests/unittest/_torch/attention/test_attention_no_cache.py |
tests/unittest/_torch/attention/test_attention.py |
tests/unittest/_torch/attention/test_flashinfer_attention.py |
tests/unittest/_torch/attention/test_flashinfer_star_attn.py |
tests/unittest/_torch/attention/test_vanilla_attention.py |
tests/unittest/_torch/compilation/test_add_norm.py |
tests/unittest/_torch/debugger/test_debugger_addon.py |
@@ -1004,7 +1003,6 @@ common-files: &common_files |
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py |
tests/unittest/_torch/multi_gpu/test_mnnvl_memory.py |
tests/unittest/_torch/multi_gpu/test_moe_a2a.py |
tests/unittest/_torch/multi_gpu/test_star_attention.py |
tests/unittest/_torch/multi_gpu/test_user_buffers.py |
tests/unittest/_torch/multimodal/test_external_embedding.py |
tests/unittest/_torch/multimodal/test_find_num_image_tokens.py |
@@ -42,8 +42,7 @@ def parse_arguments():
    parser.add_argument(
        '--input_file',
        type=str,
        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
    )
        default="tests/unittest/_torch/multi_gpu/NIAH_simple_data.jsonl")

    # Build config
    parser.add_argument('--algo',
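Note: a minimal sketch of how a JSONL input like the new default could be loaded. The field names (input_context, input_query, outputs) follow the star-attention tests further down in this diff; whether NIAH_simple_data.jsonl uses exactly the same schema is an assumption, and load_samples is a hypothetical helper, not part of the example script.

import json

def load_samples(path: str):
    """Yield (prompt, query, reference) tuples from a needle-in-a-haystack JSONL file."""
    with open(path, "r") as f:
        for line in f:
            sample = json.loads(line)
            # Schema assumed from the star-attention tests removed below.
            yield (sample["input_context"],
                   sample["input_query"],
                   sample["outputs"][0])

# Hypothetical usage:
# for prompt, query, ref in load_samples(
#         "tests/unittest/_torch/multi_gpu/NIAH_simple_data.jsonl"):
#     print(query, ref)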
@@ -992,7 +992,6 @@ exclude = [
"tests/unittest/_torch/attention/test_attention_no_cache.py",
"tests/unittest/_torch/attention/test_attention.py",
"tests/unittest/_torch/attention/test_flashinfer_attention.py",
"tests/unittest/_torch/attention/test_flashinfer_star_attn.py",
"tests/unittest/_torch/attention/test_vanilla_attention.py",
"tests/unittest/_torch/compilation/test_add_norm.py",
"tests/unittest/_torch/debugger/test_debugger_addon.py",
@@ -1045,7 +1044,6 @@ exclude = [
"tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py",
"tests/unittest/_torch/multi_gpu/test_mnnvl_memory.py",
"tests/unittest/_torch/multi_gpu/test_moe_a2a.py",
"tests/unittest/_torch/multi_gpu/test_star_attention.py",
"tests/unittest/_torch/multi_gpu/test_user_buffers.py",
"tests/unittest/_torch/multimodal/test_external_embedding.py",
"tests/unittest/_torch/multimodal/test_find_num_image_tokens.py",
@@ -886,7 +886,6 @@
"test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]": 66.199569062097,
"test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]": 62.389084175927565,
"test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]": 7200.0001350759994238615,
"test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]": 3600.00020311586558818817,
"test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]": 137.7278483910486,
"test_e2e.py::test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1]": 12134.278186964104,
"test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]": 109.25386995915323,
@@ -3135,29 +3135,6 @@ def test_ptp_quickstart_bert(llm_root, llm_venv, model_name, model_path,
    print("Success: HF model logits match TRTLLM logits!")


@pytest.mark.parametrize("model_name,model_path", [
    ("Llama3.1-8B-BF16", "llama-3.1-model/Meta-Llama-3.1-8B"),
])
def test_ptp_star_attention_example(llm_root, llm_venv, model_name, model_path,
                                    star_attention_input_root):
    print(f"Testing {model_name}.")
    workspace = llm_venv.get_working_directory()
    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
    input_file = Path(
        os.path.join(star_attention_input_root,
                     "test_star_attention_input.jsonl"))
    output_file = Path(os.path.join(workspace, "star_attention_output.jsonl"))
    llm_venv.run_cmd([
        str(example_root / "star_attention.py"),
        "--model_path",
        f"{llm_models_root()}/{model_path}",
        "--sa_block_size=200",
        "--sa_anchor_size=200",
        f"--input_file={input_file}",
        f"--output_file={output_file}",
    ])


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("model_name,model_path", [
    ("DeepSeek-R1-Distill-Qwen-7B", "DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B"),
@@ -371,7 +371,6 @@ test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-
test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-models/Llama-3.3-70B-Instruct-FP8-2-2-True]
test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-models/Llama-3.3-70B-Instruct-FP8-2-4-False]
test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-models/Llama-3.3-70B-Instruct-FP8-2-4-True]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8]
@@ -229,7 +229,6 @@ test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]
test_e2e.py::test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
@@ -59,6 +59,5 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-mult
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
@@ -263,7 +263,6 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-mult
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8]
unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora
@@ -152,7 +152,9 @@ test_e2e.py::test_ptp_quickstart_advanced_multi_gpus[Nemotron-Ultra-253B-nemotro
examples/test_multimodal.py::test_llm_multimodal_general[Phi-4-multimodal-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5385992)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int4_awq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5401233)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_2gpu[recurrentgemma-2b] SKIP (https://nvbugs/5401233)
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5409420)
examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5141288)
examples/test_qwen.py::test_llm_qwen_7b_int8_kv_1node_1gpus[qwen2_vl_7b_instruct-enable_gemm_plugin-enable_weight_only] SKIP (https://nvbugs/5419067)
examples/test_qwen.py::test_llm_qwen_awq_single_gpu_summary[qwen2_vl_7b_instruct-nb:4] SKIP (https://nvbugs/5419068)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-fp8-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5419070)
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
@@ -341,7 +343,6 @@ disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backen
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] SKIP (https://nvbugs/5769890,https://nvbugs/5748683)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp] SKIP (https://nvbugs/5779536)
unittest/_torch/attention/test_flashinfer_star_attn.py::TestStarAttention::test_flashinfer_star_attention[num_layers:2-num_heads:32-num_kv_heads:8-head_dim:64-anchor_size:64-block_size:64-dtype:torch.float16] SKIP (https://nvbugs/5781389)
unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py::test_reducescatter_pg_op[var_len:True-seqlen:16-hidden:128] SKIP (https://nvbugs/5781383)
cpp/test_e2e.py::test_model[-mamba-86] SKIP (https://nvbugs/5781665)
unittest/llmapi/test_llm_multi_gpu_pytorch.py::test_tinyllama_logits_processor_tp2pp2 SKIP (https://nvbugs/5781731)
@@ -61,7 +61,7 @@ def test_model(backend, model_name, attention_backend):
    current_file = os.path.abspath(__file__)
    current_dir = os.path.dirname(os.path.dirname(
        os.path.dirname(current_file)))
    input_file = f'{current_dir}/multi_gpu/test_star_attention_input.jsonl'
    input_file = f'{current_dir}/multi_gpu/NIAH_simple_data.jsonl'
    with open(input_file, 'r') as f:
        for line in f:
            sample = json.loads(line)
@@ -1,736 +0,0 @@
import random
import unittest
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from parameterized import parameterized

import tensorrt_llm
from tensorrt_llm._torch.attention_backend import (StarAttention,
                                                    StarAttentionMetadata)
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import CpType, Mapping


class TestingStarAttentionMetadata(StarAttentionMetadata):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_times_planned = defaultdict(int)

    def get_num_plans(self, plan_params) -> int:
        return self._num_times_planned[plan_params]

    def _plan_with_params(self, plan_params):
        if self.needs_plan(plan_params):
            self._num_times_planned[plan_params] += 1
        return super()._plan_with_params(plan_params)


@dataclass(repr=False)
class Scenario:
    num_layers: int
    num_heads: int
    num_kv_heads: Union[int, List[Optional[int]]]
    head_dim: int
    anchor_size: int
    block_size: int
    dtype: torch.dtype

    def __repr__(self) -> str:
        if isinstance(self.num_kv_heads, int):
            num_kv_heads_str = str(self.num_kv_heads)
        else:
            num_kv_heads_str = '_'.join(map(str, self.num_kv_heads))
        return f"num_layers:{self.num_layers}-num_heads:{self.num_heads}-num_kv_heads:{num_kv_heads_str}-head_dim:{self.head_dim}-anchor_size:{self.anchor_size}-block_size:{self.block_size}-dtype:{self.dtype}"


@dataclass
class CUDAGraphTestScenario:
    batch_size: int
    num_heads: int
    num_kv_heads: int
    head_dim: int
    anchor_size: int
    block_size: int
    dtype: torch.dtype

    def __repr__(self) -> str:
        if isinstance(self.num_kv_heads, int):
            num_kv_heads_str = str(self.num_kv_heads)
        else:
            num_kv_heads_str = '_'.join(map(str, self.num_kv_heads))
        return f"batch_size:{self.batch_size}-num_heads:{self.num_heads}-num_kv_heads:{num_kv_heads_str}-head_dim:{self.head_dim}-anchor_size:{self.anchor_size}-block_size:{self.block_size}-dtype:{self.dtype}"


class TestStarAttention(unittest.TestCase):

    @parameterized.expand([
        Scenario(num_layers=1,
                 num_heads=32,
                 num_kv_heads=8,
                 head_dim=128,
                 anchor_size=64,
                 block_size=64,
                 dtype=torch.bfloat16),
        Scenario(num_layers=2,
                 num_heads=32,
                 num_kv_heads=8,
                 head_dim=64,
                 anchor_size=64,
                 block_size=64,
                 dtype=torch.float16),
        Scenario(num_layers=2,
                 num_heads=32,
                 num_kv_heads=[8, 16],
                 head_dim=128,
                 anchor_size=64,
                 block_size=64,
                 dtype=torch.bfloat16),
        Scenario(num_layers=3,
                 num_heads=32,
                 num_kv_heads=[8, None, 16],
                 head_dim=64,
                 anchor_size=64,
                 block_size=64,
                 dtype=torch.float16),
        Scenario(num_layers=3,
                 num_heads=32,
                 num_kv_heads=[8, None, 16],
                 head_dim=64,
                 anchor_size=64,
                 block_size=128,
                 dtype=torch.bfloat16),
        Scenario(num_layers=3,
                 num_heads=32,
                 num_kv_heads=[8, None, 16],
                 head_dim=64,
                 anchor_size=64,
                 block_size=256,
                 dtype=torch.bfloat16),
    ], lambda testcase_func, param_num, param:
       f"{testcase_func.__name__}[{param.args[0]}]")
    def test_flashinfer_star_attention(self, scenario: Scenario):
        num_layers = scenario.num_layers
        num_heads = scenario.num_heads
        num_kv_heads = scenario.num_kv_heads
        head_dim = scenario.head_dim
        dtype = scenario.dtype

        device = torch.device('cuda')

        num_gens = 2
        context_sequence_lengths = [356, 400]
        query_sequence_lengths = [4, 10]

        sequence_lengths = context_sequence_lengths + query_sequence_lengths + [
            1
        ] * num_gens
        past_seen_tokens = [0, 0, 318, 356, 256, 258]
        # 6 7 6 6 5 5
        cache_indices = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12],
                         [13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24],
                         [25, 26, 27, 28, 29], [30, 31, 32, 33, 34]]
        batch_size = len(sequence_lengths)
        request_ids = list(range(batch_size))
        token_nums = (torch.tensor(sequence_lengths) +
                      torch.tensor(past_seen_tokens)).tolist()

        num_blocks = 64
        tokens_per_block = 64
        max_seq_len = tokens_per_block * num_blocks
        cp_config = {
            "cp_type": CpType.STAR,
            "cp_anchor_size": scenario.anchor_size,
            "block_size": scenario.block_size
        }
        mapping = Mapping(world_size=1,
                          tp_size=1,
                          cp_size=1,
                          cp_config=cp_config,
                          rank=0)

        if dtype == torch.float16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype for unit test")

        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            max_batch_size=batch_size,
            mapping=mapping,
            dtype=kv_cache_dtype,
        )
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        for i in range(kv_cache_manager.num_layers):
            buf = kv_cache_manager.get_buffers(i)
            if buf is not None:
                torch.nn.init.normal_(buf)
            del buf

        if isinstance(num_kv_heads, int):
            num_kv_heads = [num_kv_heads] * num_layers

        contexts_per_layer = []
        queries_per_layer = []
        gens_per_layer = []

        for layer_idx in range(num_layers):
            kv_heads = num_kv_heads[layer_idx]
            if kv_heads is None:
                continue

            context_qs = [
                torch.randn(sequence_length,
                            num_heads * head_dim,
                            dtype=dtype,
                            device=device)
                for sequence_length in context_sequence_lengths
            ]
            context_ks = [
                torch.randn(sequence_length,
                            kv_heads * head_dim,
                            dtype=dtype,
                            device=device)
                for sequence_length in context_sequence_lengths
            ]
            context_vs = [
                torch.randn(sequence_length,
                            kv_heads * head_dim,
                            dtype=dtype,
                            device=device)
                for sequence_length in context_sequence_lengths
            ]

            contexts_per_layer.append((context_qs, context_ks, context_vs))

            query_qs = [
                torch.randn(sequence_length,
                            num_heads * head_dim,
                            dtype=dtype,
                            device=device)
                for sequence_length in query_sequence_lengths
            ]

            query_ks = [
                torch.randn(sequence_length,
                            kv_heads * head_dim,
                            dtype=dtype,
                            device=device)
                for sequence_length in query_sequence_lengths
            ]
            query_vs = [
                torch.randn(sequence_length,
                            kv_heads * head_dim,
                            dtype=dtype,
                            device=device)
                for sequence_length in query_sequence_lengths
            ]

            queries_per_layer.append((query_qs, query_ks, query_vs))

            gen_qs = [
                torch.randn(1, num_heads * head_dim, dtype=dtype, device=device)
                for _ in range(num_gens)
            ]

            gen_ks = [
                torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device)
                for _ in range(num_gens)
            ]

            gen_vs = [
                torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device)
                for _ in range(num_gens)
            ]

            gens_per_layer.append((gen_qs, gen_ks, gen_vs))

        layers = [
            StarAttention(
                layer_idx=layer_idx,
                num_heads=num_heads,
                head_dim=head_dim,
                num_kv_heads=kv_heads,
            ) for layer_idx, kv_heads in enumerate(num_kv_heads)
            if kv_heads is not None
        ]

        # [context_1, context_2, query_1, query_2, gen_1, gen_2]
        results_1 = []

        block_ids_per_seq = [i for i in cache_indices]
        num_cached_tokens_per_seq = [j for j in past_seen_tokens]

        seq_lens = torch.tensor(sequence_lengths).int()
        attn_metadata = TestingStarAttentionMetadata(
            seq_lens=seq_lens,
            num_contexts=len(context_sequence_lengths),
            num_queries=len(query_sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True,
                block_ids_per_seq=block_ids_per_seq,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            max_num_requests=6,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            mapping=mapping,
        )

        attn_metadata.prepare()
        for attn_layer_idx, star_attn in enumerate(layers):
            context_qs, context_ks, context_vs = contexts_per_layer[
                attn_layer_idx]
            query_qs, query_ks, query_vs = queries_per_layer[attn_layer_idx]
            gen_qs, gen_ks, gen_vs = gens_per_layer[attn_layer_idx]

            q = torch.cat((*context_qs, *query_qs, *gen_qs))
            k = torch.cat((*context_ks, *query_ks, *gen_ks))
            v = torch.cat((*context_vs, *query_vs, *gen_vs))

            result_1 = star_attn.forward(q, k, v, attn_metadata)
            self.assertEqual(
                result_1.size()[0],
                sum(context_sequence_lengths) + sum(query_sequence_lengths) +
                num_gens)

            # validate kv cache was updated expectedly
            cache_buf = kv_cache_manager.get_buffers(
                star_attn.layer_idx, kv_layout=attn_metadata.kv_layout)
            if attn_metadata.kv_layout == "HND":
                cache_buf = cache_buf.transpose(2, 3).contiguous()
            assert cache_buf is not None
            num_kv_heads = cache_buf.size(-2)

            # validate contexts
            block_ids_per_seq = kv_cache_manager.get_batch_cache_indices(
                request_ids)
            for seq_id in range(len(context_sequence_lengths)):
                # get a contiguous copy of the cache for the sequence
                block_ids = block_ids_per_seq[seq_id]
                cached_kvs = torch.concat(cache_buf[block_ids, :].unbind(dim=0),
                                          dim=1)
                # only look at new tokens added
                cached_kvs = cached_kvs[:, past_seen_tokens[seq_id]:
                                        past_seen_tokens[seq_id] +
                                        context_sequence_lengths[seq_id]]

                # compare to input kvs
                torch.testing.assert_close(
                    cached_kvs[0].to(context_ks[seq_id].dtype),
                    context_ks[seq_id].view(-1, num_kv_heads, head_dim))
                torch.testing.assert_close(
                    cached_kvs[1].to(context_vs[seq_id].dtype),
                    context_vs[seq_id].view(-1, num_kv_heads, head_dim))

            # validate queries
            for query_seq_id in range(len(query_sequence_lengths)):
                seq_id = query_seq_id + len(context_sequence_lengths)
                # get a contiguous copy of the cache for the sequence
                block_ids = block_ids_per_seq[seq_id]
                cached_kvs = torch.concat(cache_buf[block_ids, :].unbind(dim=0),
                                          dim=1)
                # only look at new tokens added
                cached_kvs = cached_kvs[:, past_seen_tokens[seq_id]:
                                        past_seen_tokens[seq_id] +
                                        query_sequence_lengths[query_seq_id]]

                # compare to input kvs
                torch.testing.assert_close(
                    cached_kvs[0].to(query_ks[query_seq_id].dtype),
                    query_ks[query_seq_id].view(-1, num_kv_heads, head_dim))
                torch.testing.assert_close(
                    cached_kvs[1].to(query_vs[query_seq_id].dtype),
                    query_vs[query_seq_id].view(-1, num_kv_heads, head_dim))

            # validate generations (same way)
            for gen_seq_id in range(num_gens):
                seq_id = len(context_sequence_lengths) + len(
                    query_sequence_lengths) + gen_seq_id
                block_ids = block_ids_per_seq[seq_id]
                cached_kvs = torch.concat(
                    cache_buf[block_ids, :].unbind(dim=0),
                    dim=1)[:,
                           past_seen_tokens[seq_id]:past_seen_tokens[seq_id] +
                           1]

                torch.testing.assert_close(
                    cached_kvs[0],
                    gen_ks[gen_seq_id].view(-1, num_kv_heads, head_dim))
                torch.testing.assert_close(
                    cached_kvs[1],
                    gen_vs[gen_seq_id].view(-1, num_kv_heads, head_dim))

            results_1.append(result_1)
            del cache_buf

        for plan_params in attn_metadata._plan_params_to_wrappers.keys():
            self.assertEqual(attn_metadata.get_num_plans(plan_params), 1)

        # Make sure prepare() re-planned all params.
        attn_metadata.prepare()
        for plan_params in attn_metadata._plan_params_to_wrappers.keys():
            self.assertEqual(attn_metadata.get_num_plans(plan_params), 2)

        # [context_1, gen_1, gen_2]
        results_2 = []

        block_ids_per_seq = [
            cache_indices[0], cache_indices[-2], cache_indices[-1]
        ]
        num_cached_tokens_per_seq = [
            j for j in
            [past_seen_tokens[0], past_seen_tokens[-2], past_seen_tokens[-1]]
        ]

        seq_lens = torch.tensor([context_sequence_lengths[0], 1, 1],
                                dtype=torch.int)
        attn_metadata = TestingStarAttentionMetadata(
            seq_lens=seq_lens,
            num_contexts=1,
            num_queries=0,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                block_ids_per_seq=block_ids_per_seq,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq),
            max_num_requests=3,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=[0, 4, 5],
            mapping=mapping,
        )

        attn_metadata.prepare()

        for attn_layer_idx, star_attn in enumerate(layers):
            context_qs, context_ks, context_vs = contexts_per_layer[
                attn_layer_idx]
            gen_qs, gen_ks, gen_vs = gens_per_layer[attn_layer_idx]

            result_2 = star_attn.forward(torch.cat((context_qs[0], *gen_qs)),
                                         torch.cat((context_ks[0], *gen_ks)),
                                         torch.cat((context_vs[0], *gen_vs)),
                                         attn_metadata)
            self.assertEqual(result_2.size()[0],
                             context_sequence_lengths[0] + 1 + 1)
            results_2.append(result_2)

        for plan_params in attn_metadata._plan_params_to_wrappers.keys():
            self.assertEqual(attn_metadata.get_num_plans(plan_params), 1)

        # Make sure prepare() re-planned all params.
        attn_metadata.prepare()
        for plan_params in attn_metadata._plan_params_to_wrappers.keys():
            self.assertEqual(attn_metadata.get_num_plans(plan_params), 2)

        # [context_2, query_1, query_2]
        results_3 = []

        block_ids_per_seq = [
            cache_indices[1], cache_indices[2], cache_indices[3]
        ]
        num_cached_tokens_per_seq = [
            j for j in
            [past_seen_tokens[1], past_seen_tokens[2], past_seen_tokens[3]]
        ]

        seq_lens = torch.tensor([
            context_sequence_lengths[1], query_sequence_lengths[0],
            query_sequence_lengths[1]
        ],
                                dtype=torch.int)
        attn_metadata = TestingStarAttentionMetadata(
            seq_lens=seq_lens,
            num_contexts=1,
            num_queries=2,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                block_ids_per_seq=block_ids_per_seq,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            max_num_requests=3,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=[1, 2, 3],
            mapping=mapping,
        )

        attn_metadata.prepare()
        for attn_layer_idx, star_attn in enumerate(layers):
            context_qs, context_ks, context_vs = contexts_per_layer[
                attn_layer_idx]
            query_qs, query_ks, query_vs = queries_per_layer[attn_layer_idx]

            result_3 = star_attn.forward(torch.cat((context_qs[1], *query_qs)),
                                         torch.cat((context_ks[1], *query_ks)),
                                         torch.cat((context_vs[1], *query_vs)),
                                         attn_metadata)
            self.assertEqual(
                result_3.size()[0],
                context_sequence_lengths[1] + sum(query_sequence_lengths))
            results_3.append(result_3)

        for plan_params in attn_metadata._plan_params_to_wrappers.keys():
            self.assertEqual(attn_metadata.get_num_plans(plan_params), 1)

        # Make sure prepare() re-planned all params.
        attn_metadata.prepare()
        for plan_params in attn_metadata._plan_params_to_wrappers.keys():
            self.assertEqual(attn_metadata.get_num_plans(plan_params), 2)

        # assert value
        for result_1, result_2, result_3 in zip(results_1, results_2,
                                                results_3):
            tensor_1 = torch.cat((
                result_1[:sum(context_sequence_lengths), :],
                result_1[sum(context_sequence_lengths
                             ):sum(context_sequence_lengths) +
                         sum(query_sequence_lengths), :],
                result_1[sum(context_sequence_lengths) +
                         sum(query_sequence_lengths):, :],
            ))
            tensor_2 = torch.cat((
                result_2[:context_sequence_lengths[0], :],
                result_3[:context_sequence_lengths[1] +
                         sum(query_sequence_lengths), :],
                result_2[context_sequence_lengths[0]:, :],
            ))
            # Allow larger absolute difference due to flash_infer's precision problems, especially on PCIE nodes
            # atol: 1e-5 -> 0.1
            torch.testing.assert_close(tensor_1, tensor_2, atol=0.1, rtol=0.02)

        kv_cache_manager.shutdown()

    @parameterized.expand([
        CUDAGraphTestScenario(
            batch_size=1,
            num_heads=32,
            num_kv_heads=32,
            head_dim=128,
            anchor_size=64,
            block_size=64,
            dtype=torch.float16,
        ),
        CUDAGraphTestScenario(
            batch_size=16,
            num_heads=32,
            num_kv_heads=32,
            head_dim=128,
            anchor_size=64,
            block_size=128,
            dtype=torch.bfloat16,
        ),
        CUDAGraphTestScenario(
            batch_size=16,
            num_heads=32,
            num_kv_heads=[32, 16],
            head_dim=128,
            anchor_size=128,
            block_size=128,
            dtype=torch.bfloat16,
        ),
    ],
                          lambda testcase_func, param_num, param:
                          f"{testcase_func.__name__}[{param.args[0]}]",
                          skip_on_empty=True)
    def test_attention_with_cuda_graphs(
            self, test_scenario: CUDAGraphTestScenario) -> None:
        # This test exercises our CUDAGraph metadata class and makes sure
        # that the flashinfer attention layer is compatible with graph capture/replay.
        # We compare the CUDA graph results to the results without CUDA graph.
        batch_size = test_scenario.batch_size
        num_heads = test_scenario.num_heads
        num_kv_heads = test_scenario.num_kv_heads
        head_dim = test_scenario.head_dim
        dtype = test_scenario.dtype
        device = 'cuda'

        tokens_per_block = 64
        past_seen_tokens = [random.randint(1, 1024) for _ in range(batch_size)]
        cache_indices = []
        last_pos = 0
        for i in range(batch_size):
            used_blocks = (past_seen_tokens[i] + 64) // 64
            cache_indices.append(
                [j for j in range(last_pos, last_pos + used_blocks)])
            last_pos += used_blocks

        block_ids_per_seq = [i for i in cache_indices]
        [j for j in past_seen_tokens]

        request_ids = list(range(batch_size))
        token_nums = (torch.tensor(past_seen_tokens) + 1).tolist()

        num_blocks = 512
        max_seq_len = tokens_per_block * num_blocks
        num_layers = 1 if isinstance(num_kv_heads, int) else len(num_kv_heads)
        cp_config = {
            "cp_type": CpType.STAR,
            "cp_anchor_size": test_scenario.anchor_size,
            "block_size": test_scenario.block_size
        }
        mapping = Mapping(world_size=1,
                          tp_size=1,
                          cp_size=1,
                          cp_config=cp_config,
                          rank=0)

        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)
        if dtype == torch.float16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype for unit test")

        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            max_batch_size=batch_size,
            mapping=mapping,
            dtype=kv_cache_dtype,
        )
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        gen_qs = []
        gen_ks = []
        gen_vs = []

        for i in range(num_layers):
            gen_qs.append([
                torch.randn(1, num_heads * head_dim, dtype=dtype, device=device)
                for _ in range(batch_size)
            ])

            kv_heads = num_kv_heads if isinstance(num_kv_heads,
                                                  int) else num_kv_heads[i]
            gen_ks.append([
                torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device)
                for _ in range(batch_size)
            ])

            gen_vs.append([
                torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device)
                for _ in range(batch_size)
            ])

        layers = []
        for i in range(num_layers):
            kv_heads = num_kv_heads if isinstance(num_kv_heads,
                                                  int) else num_kv_heads[i]
            layers.append(
                StarAttention(
                    layer_idx=i,
                    head_dim=head_dim,
                    num_heads=num_heads,
                    num_kv_heads=kv_heads,
                ))

        seq_lens = torch.ones((batch_size, ), dtype=torch.int)
        attn_metadata_ref = TestingStarAttentionMetadata(
            seq_lens=seq_lens,
            num_contexts=0,
            num_queries=0,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                block_ids_per_seq=block_ids_per_seq,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            max_num_requests=batch_size,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            mapping=mapping,
        )

        attn_metadata_ref.kv_cache_manager = kv_cache_manager

        workspace = torch.empty(1024 * 1024 * 128,
                                dtype=torch.int,
                                device='cuda')
        attn_metadata_cuda_graph = TestingStarAttentionMetadata(
            seq_lens=seq_lens,
            num_contexts=0,
            num_queries=0,
            is_cuda_graph=True,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                block_ids_per_seq=block_ids_per_seq,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            workspace_buffer=workspace,
            max_num_requests=batch_size,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            mapping=mapping,
        )

        attn_metadata_ref.prepare()
        attn_metadata_cuda_graph.prepare()

        results_ref = []

        for i in range(num_layers):
            q = torch.cat(gen_qs[i])
            k = torch.cat(gen_ks[i])
            v = torch.cat(gen_vs[i])
            layer = layers[i]
            results_ref.append(layer.forward(q, k, v, attn_metadata_ref))

        graph = torch.cuda.CUDAGraph()
        for i in range(num_layers):
            layer = layers[i]
            q = torch.cat(gen_qs[i])
            k = torch.cat(gen_ks[i])
            v = torch.cat(gen_vs[i])
            # Warmup run, required by PT
            for _ in range(2):
                layer.forward(q, k, v, attn_metadata_cuda_graph)

        results_actual = []
        with torch.cuda.graph(graph):
            for i in range(num_layers):
                layer = layers[i]
                q = torch.cat(gen_qs[i])
                k = torch.cat(gen_ks[i])
                v = torch.cat(gen_vs[i])
                results_actual.append(
                    layer.forward(q, k, v, attn_metadata_cuda_graph))

        graph.replay()

        for result_actual, result_ref in zip(results_actual, results_ref):
            torch.testing.assert_close(result_actual,
                                       result_ref,
                                       atol=0.5,
                                       rtol=0.5)

        kv_cache_manager.shutdown()


if __name__ == "__main__":
    unittest.main()
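Note: the deleted unit test above names its cases through parameterized.expand with a custom name function, where the Scenario dataclass __repr__ produces the bracketed ids (e.g. num_layers:2-...-dtype:torch.float16) that still appear in the waive list earlier in this diff. A self-contained sketch of that naming pattern, using a hypothetical MiniScenario dataclass rather than the real Scenario:

import unittest
from dataclasses import dataclass

from parameterized import parameterized


@dataclass(repr=False)
class MiniScenario:
    num_layers: int
    block_size: int

    def __repr__(self) -> str:
        # This string becomes the bracketed test id.
        return f"num_layers:{self.num_layers}-block_size:{self.block_size}"


class TestNaming(unittest.TestCase):

    @parameterized.expand(
        [MiniScenario(num_layers=2, block_size=64)],
        name_func=lambda func, num, param: f"{func.__name__}[{param.args[0]}]")
    def test_scenario(self, scenario: MiniScenario):
        # Collected as test_scenario[num_layers:2-block_size:64].
        self.assertGreater(scenario.block_size, 0)


if __name__ == "__main__":
    unittest.main()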
@@ -1,139 +0,0 @@
import json
import os

import pytest
import torch
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
from tensorrt_llm.mapping import CpType
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

MAX_SEQ_LEN = 4096 + 1024


@pytest.mark.post_merge
@pytest.mark.parametrize("backend", ["pytorch"])
@pytest.mark.parametrize("model_name",
                         ["llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k"],
                         ids=["llama-3-8b-1048k"])
@pytest.mark.parametrize("quant", ["bf16", "fp8"])
@pytest.mark.parametrize("sp_size", [1, 2, 4], ids=["sp1", "sp2", "sp4"])
@pytest.mark.parametrize("sa_block_size", [256, 1024],
                         ids=["block1024", "block4096"])
@pytest.mark.parametrize("sa_anchor_size", [256, 1024],
                         ids=["anchor1024", "anchor4096"])
def test_model(backend, model_name, quant, sp_size, sa_block_size,
               sa_anchor_size):
    pytest.skip("https://nvbugs/5391679")
    quant_configs = {
        "bf16":
        QuantConfig(),
        "fp8":
        QuantConfig(quant_algo=QuantAlgo.FP8),
        "fp8_kv_cache":
        QuantConfig(
            quant_algo=QuantAlgo.FP8,
            kv_cache_quant_algo=QuantAlgo.FP8,
        ),
    }
    quant_config = quant_configs[quant]
    if sp_size != 1:
        pytest.skip(f"skip multi gpu tests due to flashinfer's jitting mode")
    if torch.cuda.device_count() < sp_size:
        pytest.skip(f"Not enough GPUs available, need {sp_size} "
                    f"but only have {torch.cuda.device_count()}")
    if sa_anchor_size > sa_block_size:
        pytest.skip(
            f"Unsupported sa_anchor_size {sa_anchor_size} > sa_block_size {sa_block_size}"
        )

    if get_total_gpu_memory(0) < 32 * 1024**3:
        pytest.skip("Not enough GPU memory to run BF16 model")

    model_dir = str(llm_models_root() / model_name)
    cp_config = {
        "cp_type": CpType.STAR,
        "cp_anchor_size": sa_anchor_size,
        "block_size": sa_block_size
    }
    max_batch_size = 20
    max_output_tokens = 128
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
    pytorch_backend_options = dict(attn_backend='FLASHINFER_STAR_ATTENTION',
                                   disable_overlap_scheduler=True)

    llm = LLM(model=model_dir,
              backend=backend,
              kv_cache_config=kv_cache_config,
              tensor_parallel_size=1,
              quant_config=quant_config,
              context_parallel_size=sp_size,
              cp_config=cp_config,
              **pytorch_backend_options,
              max_batch_size=max_batch_size,
              max_input_len=MAX_SEQ_LEN - max_output_tokens,
              max_seq_len=MAX_SEQ_LEN,
              max_num_tokens=(sa_block_size + sa_anchor_size) * max_batch_size)

    inputs, references = [], []
    current_file = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file)
    with open(f'{current_dir}/test_star_attention_input.jsonl', 'r') as f:
        for line in f:
            sample = json.loads(line)
            inputs.append({
                'prompt': sample['input_context'],
                'query': sample['input_query']
            })
            references.append(sample['outputs'][0])
    with llm:
        outputs = llm.generate(
            inputs,
            use_tqdm=True,
            sampling_params=SamplingParams(
                max_tokens=max_output_tokens,
                add_special_tokens=False,
            ),
        )

    count = 0
    for ref, ret in zip(references, outputs):
        #print(f'reference = {ref}')
        #print(f'prediction = {ret.outputs[0].text}')
        if ref not in ret.outputs[0].text:
            print(f'reference {ref} is not in the output {ret.outputs[0].text}')
        else:
            count = count + 1
    acc = count / len(outputs)
    if acc < 1.0:
        assert False, 'accuracy test of star attention failed'


if __name__ == '__main__':
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "bf16", 1, 256, 256)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "bf16", 1, 1024, 256)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "bf16", 1, 1024, 1024)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "fp8", 1, 256, 256)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "fp8", 1, 1024, 256)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "fp8", 1, 1024, 1024)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "bf16", 2, 1024, 256)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "bf16", 2, 1024, 1024)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "bf16", 2, 256, 256)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "bf16", 4, 1024, 256)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "bf16", 4, 1024, 1024)
    test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k",
               "bf16", 4, 256, 256)
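Note: for reference, a condensed sketch of the star-attention configuration that the deleted integration test above exercised through the LLM API. Parameter values are illustrative, the model path is a placeholder, and it is assumed the pytorch backend still accepts attn_backend='FLASHINFER_STAR_ATTENTION' together with a CpType.STAR cp_config as in that test.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.mapping import CpType

# Star attention context-parallel settings (mirrors the deleted test's cp_config).
cp_config = {
    "cp_type": CpType.STAR,
    "cp_anchor_size": 256,  # anchor tokens prepended to each context block
    "block_size": 1024,     # context block size handled per rank
}

llm = LLM(
    model="<path to Llama-3-8B-Instruct-Gradient-1048k>",  # placeholder, e.g. under llm_models_root()
    backend="pytorch",
    attn_backend="FLASHINFER_STAR_ATTENTION",
    disable_overlap_scheduler=True,
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7),
    tensor_parallel_size=1,
    context_parallel_size=1,
    cp_config=cp_config,
)

with llm:
    # The star-attention path takes separate context and query fields per request.
    outputs = llm.generate(
        [{"prompt": "<long context>", "query": "<question about the context>"}],
        sampling_params=SamplingParams(max_tokens=128, add_special_tokens=False),
    )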