diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b9dd903c6c..33a4dca2af 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -951,7 +951,6 @@ common-files: &common_files |
     tests/unittest/_torch/attention/test_attention_no_cache.py |
     tests/unittest/_torch/attention/test_attention.py |
     tests/unittest/_torch/attention/test_flashinfer_attention.py |
-    tests/unittest/_torch/attention/test_flashinfer_star_attn.py |
     tests/unittest/_torch/attention/test_vanilla_attention.py |
     tests/unittest/_torch/compilation/test_add_norm.py |
     tests/unittest/_torch/debugger/test_debugger_addon.py |
@@ -1004,7 +1003,6 @@ common-files: &common_files |
     tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py |
     tests/unittest/_torch/multi_gpu/test_mnnvl_memory.py |
     tests/unittest/_torch/multi_gpu/test_moe_a2a.py |
-    tests/unittest/_torch/multi_gpu/test_star_attention.py |
     tests/unittest/_torch/multi_gpu/test_user_buffers.py |
     tests/unittest/_torch/multimodal/test_external_embedding.py |
     tests/unittest/_torch/multimodal/test_find_num_image_tokens.py |
diff --git a/examples/llm-api/llm_sparse_attention.py b/examples/llm-api/llm_sparse_attention.py
index 3ebe4dcb61..ce052c3367 100644
--- a/examples/llm-api/llm_sparse_attention.py
+++ b/examples/llm-api/llm_sparse_attention.py
@@ -42,8 +42,7 @@ def parse_arguments():
     parser.add_argument(
         '--input_file',
         type=str,
-        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
-    )
+        default="tests/unittest/_torch/multi_gpu/NIAH_simple_data.jsonl")

     # Build config
     parser.add_argument('--algo',
diff --git a/pyproject.toml b/pyproject.toml
index 031d41850d..7a57bbe09c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -992,7 +992,6 @@ exclude = [
     "tests/unittest/_torch/attention/test_attention_no_cache.py",
     "tests/unittest/_torch/attention/test_attention.py",
     "tests/unittest/_torch/attention/test_flashinfer_attention.py",
-    "tests/unittest/_torch/attention/test_flashinfer_star_attn.py",
     "tests/unittest/_torch/attention/test_vanilla_attention.py",
     "tests/unittest/_torch/compilation/test_add_norm.py",
     "tests/unittest/_torch/debugger/test_debugger_addon.py",
@@ -1045,7 +1044,6 @@ exclude = [
     "tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py",
     "tests/unittest/_torch/multi_gpu/test_mnnvl_memory.py",
     "tests/unittest/_torch/multi_gpu/test_moe_a2a.py",
-    "tests/unittest/_torch/multi_gpu/test_star_attention.py",
     "tests/unittest/_torch/multi_gpu/test_user_buffers.py",
     "tests/unittest/_torch/multimodal/test_external_embedding.py",
     "tests/unittest/_torch/multimodal/test_find_num_image_tokens.py",
diff --git a/tests/integration/defs/.test_durations b/tests/integration/defs/.test_durations
index 9943109ff3..bbb4411063 100644
--- a/tests/integration/defs/.test_durations
+++ b/tests/integration/defs/.test_durations
@@ -886,7 +886,6 @@
   "test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]": 66.199569062097,
   "test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]": 62.389084175927565,
   "test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]": 7200.0001350759994238615,
-  "test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]": 3600.00020311586558818817,
   "test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]": 137.7278483910486,
   "test_e2e.py::test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1]": 12134.278186964104,
"test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]": 109.25386995915323, diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index ab1690c87a..344e2f5b3f 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -3135,29 +3135,6 @@ def test_ptp_quickstart_bert(llm_root, llm_venv, model_name, model_path, print("Success: HF model logits match TRTLLM logits!") -@pytest.mark.parametrize("model_name,model_path", [ - ("Llama3.1-8B-BF16", "llama-3.1-model/Meta-Llama-3.1-8B"), -]) -def test_ptp_star_attention_example(llm_root, llm_venv, model_name, model_path, - star_attention_input_root): - print(f"Testing {model_name}.") - workspace = llm_venv.get_working_directory() - example_root = Path(os.path.join(llm_root, "examples", "llm-api")) - input_file = Path( - os.path.join(star_attention_input_root, - "test_star_attention_input.jsonl")) - output_file = Path(os.path.join(workspace, "star_attention_output.jsonl")) - llm_venv.run_cmd([ - str(example_root / "star_attention.py"), - "--model_path", - f"{llm_models_root()}/{model_path}", - "--sa_block_size=200", - "--sa_anchor_size=200", - f"--input_file={input_file}", - f"--output_file={output_file}", - ]) - - @pytest.mark.skip_less_device_memory(80000) @pytest.mark.parametrize("model_name,model_path", [ ("DeepSeek-R1-Distill-Qwen-7B", "DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B"), diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 8e7873f61c..872f109d6e 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -371,7 +371,6 @@ test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3- test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-models/Llama-3.3-70B-Instruct-FP8-2-2-True] test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-models/Llama-3.3-70B-Instruct-FP8-2-4-False] test_e2e.py::test_ptp_quickstart_advanced_pp_enabled[Llama3.3-70B-FP8-llama-3.3-models/Llama-3.3-70B-Instruct-FP8-2-4-True] -test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B] test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False] test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B] test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8] diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt index 75282580e0..3eb5705c64 100644 --- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt @@ -229,7 +229,6 @@ test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/ test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B] test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8] -test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B] 
 test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]
 test_e2e.py::test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1]
 test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
diff --git a/tests/integration/test_lists/qa/llm_function_l20.txt b/tests/integration/test_lists/qa/llm_function_l20.txt
index 772e39a683..b5f733001f 100644
--- a/tests/integration/test_lists/qa/llm_function_l20.txt
+++ b/tests/integration/test_lists/qa/llm_function_l20.txt
@@ -59,6 +59,5 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-mult
 test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio]
 test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
 test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
-test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
 test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
 test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
diff --git a/tests/integration/test_lists/qa/llm_function_nim.txt b/tests/integration/test_lists/qa/llm_function_nim.txt
index 0b26975d26..0a3dfd3cae 100644
--- a/tests/integration/test_lists/qa/llm_function_nim.txt
+++ b/tests/integration/test_lists/qa/llm_function_nim.txt
@@ -263,7 +263,6 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-mult
 test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image_audio]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
 test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
-test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
 test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
 test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8]
 unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 2ede4acdf4..3a9e4fa9e1 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -152,7 +152,9 @@ test_e2e.py::test_ptp_quickstart_advanced_multi_gpus[Nemotron-Ultra-253B-nemotro
 examples/test_multimodal.py::test_llm_multimodal_general[Phi-4-multimodal-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5385992)
 examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int4_awq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5401233)
 examples/test_recurrentgemma.py::test_llm_recurrentgemma_2gpu[recurrentgemma-2b] SKIP (https://nvbugs/5401233)
-test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5409420)
+examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5141288)
+examples/test_qwen.py::test_llm_qwen_7b_int8_kv_1node_1gpus[qwen2_vl_7b_instruct-enable_gemm_plugin-enable_weight_only] SKIP (https://nvbugs/5419067)
+examples/test_qwen.py::test_llm_qwen_awq_single_gpu_summary[qwen2_vl_7b_instruct-nb:4] SKIP (https://nvbugs/5419068)
 examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-fp8-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5419070)
 examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
 examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
@@ -341,7 +343,6 @@ disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backen
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] SKIP (https://nvbugs/5769890,https://nvbugs/5748683)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp] SKIP (https://nvbugs/5779536)
-unittest/_torch/attention/test_flashinfer_star_attn.py::TestStarAttention::test_flashinfer_star_attention[num_layers:2-num_heads:32-num_kv_heads:8-head_dim:64-anchor_size:64-block_size:64-dtype:torch.float16] SKIP (https://nvbugs/5781389)
 unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py::test_reducescatter_pg_op[var_len:True-seqlen:16-hidden:128] SKIP (https://nvbugs/5781383)
 cpp/test_e2e.py::test_model[-mamba-86] SKIP (https://nvbugs/5781665)
 unittest/llmapi/test_llm_multi_gpu_pytorch.py::test_tinyllama_logits_processor_tp2pp2 SKIP (https://nvbugs/5781731)
diff --git a/tests/unittest/_torch/attention/sparse/test_rocketkv.py b/tests/unittest/_torch/attention/sparse/test_rocketkv.py
index 0112fa5489..b96691b3cb 100644
--- a/tests/unittest/_torch/attention/sparse/test_rocketkv.py
+++ b/tests/unittest/_torch/attention/sparse/test_rocketkv.py
@@ -61,7 +61,7 @@ def test_model(backend, model_name, attention_backend):
     current_file = os.path.abspath(__file__)
     current_dir = os.path.dirname(os.path.dirname(
         os.path.dirname(current_file)))
-    input_file = f'{current_dir}/multi_gpu/test_star_attention_input.jsonl'
+    input_file = f'{current_dir}/multi_gpu/NIAH_simple_data.jsonl'
     with open(input_file, 'r') as f:
         for line in f:
             sample = json.loads(line)
diff --git a/tests/unittest/_torch/attention/test_flashinfer_star_attn.py b/tests/unittest/_torch/attention/test_flashinfer_star_attn.py
deleted file mode 100644
index 007298574a..0000000000
--- a/tests/unittest/_torch/attention/test_flashinfer_star_attn.py
+++ /dev/null
@@ -1,738 +0,0 @@
-import random
-import unittest
-from collections import defaultdict
-from dataclasses import dataclass
-from typing import List, Optional, Union
-
-import pytest
-import torch
-from parameterized import parameterized
-
-import tensorrt_llm
-from tensorrt_llm._torch.attention_backend import (StarAttention,
-                                                   StarAttentionMetadata)
-from tensorrt_llm._torch.metadata import KVCacheParams
-from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
-from tensorrt_llm.bindings.executor import KvCacheConfig
-from tensorrt_llm.mapping import CpType, Mapping
-
-
-class TestingStarAttentionMetadata(StarAttentionMetadata):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._num_times_planned = defaultdict(int)
-
-    def 
get_num_plans(self, plan_params) -> int: - return self._num_times_planned[plan_params] - - def _plan_with_params(self, plan_params): - if self.needs_plan(plan_params): - self._num_times_planned[plan_params] += 1 - return super()._plan_with_params(plan_params) - - -@dataclass(repr=False) -class Scenario: - num_layers: int - num_heads: int - num_kv_heads: Union[int, List[Optional[int]]] - head_dim: int - anchor_size: int - block_size: int - dtype: torch.dtype - - def __repr__(self) -> str: - if isinstance(self.num_kv_heads, int): - num_kv_heads_str = str(self.num_kv_heads) - else: - num_kv_heads_str = '_'.join(map(str, self.num_kv_heads)) - return f"num_layers:{self.num_layers}-num_heads:{self.num_heads}-num_kv_heads:{num_kv_heads_str}-head_dim:{self.head_dim}-anchor_size:{self.anchor_size}-block_size:{self.block_size}-dtype:{self.dtype}" - - -@dataclass -class CUDAGraphTestScenario: - batch_size: int - num_heads: int - num_kv_heads: int - head_dim: int - anchor_size: int - block_size: int - dtype: torch.dtype - - def __repr__(self) -> str: - if isinstance(self.num_kv_heads, int): - num_kv_heads_str = str(self.num_kv_heads) - else: - num_kv_heads_str = '_'.join(map(str, self.num_kv_heads)) - return f"batch_size:{self.batch_size}-num_heads:{self.num_heads}-num_kv_heads:{num_kv_heads_str}-head_dim:{self.head_dim}-anchor_size:{self.anchor_size}-block_size:{self.block_size}-dtype:{self.dtype}" - - -@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5781389") -class TestStarAttention(unittest.TestCase): - - @parameterized.expand([ - Scenario(num_layers=1, - num_heads=32, - num_kv_heads=8, - head_dim=128, - anchor_size=64, - block_size=64, - dtype=torch.bfloat16), - Scenario(num_layers=2, - num_heads=32, - num_kv_heads=8, - head_dim=64, - anchor_size=64, - block_size=64, - dtype=torch.float16), - Scenario(num_layers=2, - num_heads=32, - num_kv_heads=[8, 16], - head_dim=128, - anchor_size=64, - block_size=64, - dtype=torch.bfloat16), - Scenario(num_layers=3, - num_heads=32, - num_kv_heads=[8, None, 16], - head_dim=64, - anchor_size=64, - block_size=64, - dtype=torch.float16), - Scenario(num_layers=3, - num_heads=32, - num_kv_heads=[8, None, 16], - head_dim=64, - anchor_size=64, - block_size=128, - dtype=torch.bfloat16), - Scenario(num_layers=3, - num_heads=32, - num_kv_heads=[8, None, 16], - head_dim=64, - anchor_size=64, - block_size=256, - dtype=torch.bfloat16), - ], lambda testcase_func, param_num, param: - f"{testcase_func.__name__}[{param.args[0]}]") - def test_flashinfer_star_attention(self, scenario: Scenario): - num_layers = scenario.num_layers - num_heads = scenario.num_heads - num_kv_heads = scenario.num_kv_heads - head_dim = scenario.head_dim - dtype = scenario.dtype - - device = torch.device('cuda') - - num_gens = 2 - context_sequence_lengths = [356, 400] - query_sequence_lengths = [4, 10] - - sequence_lengths = context_sequence_lengths + query_sequence_lengths + [ - 1 - ] * num_gens - past_seen_tokens = [0, 0, 318, 356, 256, 258] - # 6 7 6 6 5 5 - cache_indices = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12], - [13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24], - [25, 26, 27, 28, 29], [30, 31, 32, 33, 34]] - batch_size = len(sequence_lengths) - request_ids = list(range(batch_size)) - token_nums = (torch.tensor(sequence_lengths) + - torch.tensor(past_seen_tokens)).tolist() - - num_blocks = 64 - tokens_per_block = 64 - max_seq_len = tokens_per_block * num_blocks - cp_config = { - "cp_type": CpType.STAR, - "cp_anchor_size": scenario.anchor_size, - "block_size": 
scenario.block_size - } - mapping = Mapping(world_size=1, - tp_size=1, - cp_size=1, - cp_config=cp_config, - rank=0) - - if dtype == torch.float16: - kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF - elif dtype == torch.bfloat16: - kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16 - else: - raise ValueError("Invalid dtype for unit test") - - kv_cache_config = KvCacheConfig(max_tokens=num_blocks * - tokens_per_block) - kv_cache_manager = KVCacheManager( - kv_cache_config, - tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - tokens_per_block=tokens_per_block, - max_seq_len=max_seq_len, - max_batch_size=batch_size, - mapping=mapping, - dtype=kv_cache_dtype, - ) - kv_cache_manager.add_dummy_requests(request_ids, token_nums) - - for i in range(kv_cache_manager.num_layers): - buf = kv_cache_manager.get_buffers(i) - if buf is not None: - torch.nn.init.normal_(buf) - del buf - - if isinstance(num_kv_heads, int): - num_kv_heads = [num_kv_heads] * num_layers - - contexts_per_layer = [] - queries_per_layer = [] - gens_per_layer = [] - - for layer_idx in range(num_layers): - kv_heads = num_kv_heads[layer_idx] - if kv_heads is None: - continue - - context_qs = [ - torch.randn(sequence_length, - num_heads * head_dim, - dtype=dtype, - device=device) - for sequence_length in context_sequence_lengths - ] - context_ks = [ - torch.randn(sequence_length, - kv_heads * head_dim, - dtype=dtype, - device=device) - for sequence_length in context_sequence_lengths - ] - context_vs = [ - torch.randn(sequence_length, - kv_heads * head_dim, - dtype=dtype, - device=device) - for sequence_length in context_sequence_lengths - ] - - contexts_per_layer.append((context_qs, context_ks, context_vs)) - - query_qs = [ - torch.randn(sequence_length, - num_heads * head_dim, - dtype=dtype, - device=device) - for sequence_length in query_sequence_lengths - ] - - query_ks = [ - torch.randn(sequence_length, - kv_heads * head_dim, - dtype=dtype, - device=device) - for sequence_length in query_sequence_lengths - ] - query_vs = [ - torch.randn(sequence_length, - kv_heads * head_dim, - dtype=dtype, - device=device) - for sequence_length in query_sequence_lengths - ] - - queries_per_layer.append((query_qs, query_ks, query_vs)) - - gen_qs = [ - torch.randn(1, num_heads * head_dim, dtype=dtype, device=device) - for _ in range(num_gens) - ] - - gen_ks = [ - torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device) - for _ in range(num_gens) - ] - - gen_vs = [ - torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device) - for _ in range(num_gens) - ] - - gens_per_layer.append((gen_qs, gen_ks, gen_vs)) - - layers = [ - StarAttention( - layer_idx=layer_idx, - num_heads=num_heads, - head_dim=head_dim, - num_kv_heads=kv_heads, - ) for layer_idx, kv_heads in enumerate(num_kv_heads) - if kv_heads is not None - ] - - # [context_1, context_2, query_1, query_2, gen_1, gen_2] - results_1 = [] - - block_ids_per_seq = [i for i in cache_indices] - num_cached_tokens_per_seq = [j for j in past_seen_tokens] - - seq_lens = torch.tensor(sequence_lengths).int() - attn_metadata = TestingStarAttentionMetadata( - seq_lens=seq_lens, - num_contexts=len(context_sequence_lengths), - num_queries=len(query_sequence_lengths), - kv_cache_params=KVCacheParams( - use_cache=True, - block_ids_per_seq=block_ids_per_seq, - num_cached_tokens_per_seq=past_seen_tokens, - ), - max_num_requests=6, - max_num_tokens=8192, - kv_cache_manager=kv_cache_manager, - 
request_ids=request_ids, - mapping=mapping, - ) - - attn_metadata.prepare() - for attn_layer_idx, star_attn in enumerate(layers): - context_qs, context_ks, context_vs = contexts_per_layer[ - attn_layer_idx] - query_qs, query_ks, query_vs = queries_per_layer[attn_layer_idx] - gen_qs, gen_ks, gen_vs = gens_per_layer[attn_layer_idx] - - q = torch.cat((*context_qs, *query_qs, *gen_qs)) - k = torch.cat((*context_ks, *query_ks, *gen_ks)) - v = torch.cat((*context_vs, *query_vs, *gen_vs)) - - result_1 = star_attn.forward(q, k, v, attn_metadata) - self.assertEqual( - result_1.size()[0], - sum(context_sequence_lengths) + sum(query_sequence_lengths) + - num_gens) - - # validate kv cache was updated expectedly - cache_buf = kv_cache_manager.get_buffers( - star_attn.layer_idx, kv_layout=attn_metadata.kv_layout) - if attn_metadata.kv_layout == "HND": - cache_buf = cache_buf.transpose(2, 3).contiguous() - assert cache_buf is not None - num_kv_heads = cache_buf.size(-2) - - # validate contexts - block_ids_per_seq = kv_cache_manager.get_batch_cache_indices( - request_ids) - for seq_id in range(len(context_sequence_lengths)): - # get a contiguous copy of the cache for the sequence - block_ids = block_ids_per_seq[seq_id] - cached_kvs = torch.concat(cache_buf[block_ids, :].unbind(dim=0), - dim=1) - # only look at new tokens added - cached_kvs = cached_kvs[:, past_seen_tokens[seq_id]: - past_seen_tokens[seq_id] + - context_sequence_lengths[seq_id]] - - # compare to input kvs - torch.testing.assert_close( - cached_kvs[0].to(context_ks[seq_id].dtype), - context_ks[seq_id].view(-1, num_kv_heads, head_dim)) - torch.testing.assert_close( - cached_kvs[1].to(context_vs[seq_id].dtype), - context_vs[seq_id].view(-1, num_kv_heads, head_dim)) - - # validate queries - for query_seq_id in range(len(query_sequence_lengths)): - seq_id = query_seq_id + len(context_sequence_lengths) - # get a contiguous copy of the cache for the sequence - block_ids = block_ids_per_seq[seq_id] - cached_kvs = torch.concat(cache_buf[block_ids, :].unbind(dim=0), - dim=1) - # only look at new tokens added - cached_kvs = cached_kvs[:, past_seen_tokens[seq_id]: - past_seen_tokens[seq_id] + - query_sequence_lengths[query_seq_id]] - - # compare to input kvs - torch.testing.assert_close( - cached_kvs[0].to(query_ks[query_seq_id].dtype), - query_ks[query_seq_id].view(-1, num_kv_heads, head_dim)) - torch.testing.assert_close( - cached_kvs[1].to(query_vs[query_seq_id].dtype), - query_vs[query_seq_id].view(-1, num_kv_heads, head_dim)) - - # validate generations (same way) - for gen_seq_id in range(num_gens): - seq_id = len(context_sequence_lengths) + len( - query_sequence_lengths) + gen_seq_id - block_ids = block_ids_per_seq[seq_id] - cached_kvs = torch.concat( - cache_buf[block_ids, :].unbind(dim=0), - dim=1)[:, - past_seen_tokens[seq_id]:past_seen_tokens[seq_id] + - 1] - - torch.testing.assert_close( - cached_kvs[0], - gen_ks[gen_seq_id].view(-1, num_kv_heads, head_dim)) - torch.testing.assert_close( - cached_kvs[1], - gen_vs[gen_seq_id].view(-1, num_kv_heads, head_dim)) - - results_1.append(result_1) - del cache_buf - - for plan_params in attn_metadata._plan_params_to_wrappers.keys(): - self.assertEqual(attn_metadata.get_num_plans(plan_params), 1) - - # Make sure prepare() re-planned all params. 
- attn_metadata.prepare() - for plan_params in attn_metadata._plan_params_to_wrappers.keys(): - self.assertEqual(attn_metadata.get_num_plans(plan_params), 2) - - # [context_1, gen_1, gen_2] - results_2 = [] - - block_ids_per_seq = [ - cache_indices[0], cache_indices[-2], cache_indices[-1] - ] - num_cached_tokens_per_seq = [ - j for j in - [past_seen_tokens[0], past_seen_tokens[-2], past_seen_tokens[-1]] - ] - - seq_lens = torch.tensor([context_sequence_lengths[0], 1, 1], - dtype=torch.int) - attn_metadata = TestingStarAttentionMetadata( - seq_lens=seq_lens, - num_contexts=1, - num_queries=0, - kv_cache_params=KVCacheParams( - use_cache=True, - block_ids_per_seq=block_ids_per_seq, - num_cached_tokens_per_seq=num_cached_tokens_per_seq), - max_num_requests=3, - max_num_tokens=8192, - kv_cache_manager=kv_cache_manager, - request_ids=[0, 4, 5], - mapping=mapping, - ) - - attn_metadata.prepare() - - for attn_layer_idx, star_attn in enumerate(layers): - context_qs, context_ks, context_vs = contexts_per_layer[ - attn_layer_idx] - gen_qs, gen_ks, gen_vs = gens_per_layer[attn_layer_idx] - - result_2 = star_attn.forward(torch.cat((context_qs[0], *gen_qs)), - torch.cat((context_ks[0], *gen_ks)), - torch.cat((context_vs[0], *gen_vs)), - attn_metadata) - self.assertEqual(result_2.size()[0], - context_sequence_lengths[0] + 1 + 1) - results_2.append(result_2) - - for plan_params in attn_metadata._plan_params_to_wrappers.keys(): - self.assertEqual(attn_metadata.get_num_plans(plan_params), 1) - - # Make sure prepare() re-planned all params. - attn_metadata.prepare() - for plan_params in attn_metadata._plan_params_to_wrappers.keys(): - self.assertEqual(attn_metadata.get_num_plans(plan_params), 2) - - # [context_2, query_1, query_2] - results_3 = [] - - block_ids_per_seq = [ - cache_indices[1], cache_indices[2], cache_indices[3] - ] - num_cached_tokens_per_seq = [ - j for j in - [past_seen_tokens[1], past_seen_tokens[2], past_seen_tokens[3]] - ] - - seq_lens = torch.tensor([ - context_sequence_lengths[1], query_sequence_lengths[0], - query_sequence_lengths[1] - ], - dtype=torch.int) - attn_metadata = TestingStarAttentionMetadata( - seq_lens=seq_lens, - num_contexts=1, - num_queries=2, - kv_cache_params=KVCacheParams( - use_cache=True, - block_ids_per_seq=block_ids_per_seq, - num_cached_tokens_per_seq=num_cached_tokens_per_seq, - ), - max_num_requests=3, - max_num_tokens=8192, - kv_cache_manager=kv_cache_manager, - request_ids=[1, 2, 3], - mapping=mapping, - ) - - attn_metadata.prepare() - for attn_layer_idx, star_attn in enumerate(layers): - context_qs, context_ks, context_vs = contexts_per_layer[ - attn_layer_idx] - query_qs, query_ks, query_vs = queries_per_layer[attn_layer_idx] - - result_3 = star_attn.forward(torch.cat((context_qs[1], *query_qs)), - torch.cat((context_ks[1], *query_ks)), - torch.cat((context_vs[1], *query_vs)), - attn_metadata) - self.assertEqual( - result_3.size()[0], - context_sequence_lengths[1] + sum(query_sequence_lengths)) - results_3.append(result_3) - - for plan_params in attn_metadata._plan_params_to_wrappers.keys(): - self.assertEqual(attn_metadata.get_num_plans(plan_params), 1) - - # Make sure prepare() re-planned all params. 
- attn_metadata.prepare() - for plan_params in attn_metadata._plan_params_to_wrappers.keys(): - self.assertEqual(attn_metadata.get_num_plans(plan_params), 2) - - # assert value - for result_1, result_2, result_3 in zip(results_1, results_2, - results_3): - tensor_1 = torch.cat(( - result_1[:sum(context_sequence_lengths), :], - result_1[sum(context_sequence_lengths - ):sum(context_sequence_lengths) + - sum(query_sequence_lengths), :], - result_1[sum(context_sequence_lengths) + - sum(query_sequence_lengths):, :], - )) - tensor_2 = torch.cat(( - result_2[:context_sequence_lengths[0], :], - result_3[:context_sequence_lengths[1] + - sum(query_sequence_lengths), :], - result_2[context_sequence_lengths[0]:, :], - )) - # Allow larger absolute difference due to flash_infer's precision problems, especially on PCIE nodes - # atol: 1e-5 -> 0.1 - torch.testing.assert_close(tensor_1, tensor_2, atol=0.1, rtol=0.02) - - kv_cache_manager.shutdown() - - @parameterized.expand([ - CUDAGraphTestScenario( - batch_size=1, - num_heads=32, - num_kv_heads=32, - head_dim=128, - anchor_size=64, - block_size=64, - dtype=torch.float16, - ), - CUDAGraphTestScenario( - batch_size=16, - num_heads=32, - num_kv_heads=32, - head_dim=128, - anchor_size=64, - block_size=128, - dtype=torch.bfloat16, - ), - CUDAGraphTestScenario( - batch_size=16, - num_heads=32, - num_kv_heads=[32, 16], - head_dim=128, - anchor_size=128, - block_size=128, - dtype=torch.bfloat16, - ), - ], - lambda testcase_func, param_num, param: - f"{testcase_func.__name__}[{param.args[0]}]", - skip_on_empty=True) - def test_attention_with_cuda_graphs( - self, test_scenario: CUDAGraphTestScenario) -> None: - # This test exercises our CUDAGraph metadata class and makes sure - # that the flashinfer attention layer is compatible with graph capture/replay. - # We compare the CUDA graph results to the results without CUDA graph. 
- batch_size = test_scenario.batch_size - num_heads = test_scenario.num_heads - num_kv_heads = test_scenario.num_kv_heads - head_dim = test_scenario.head_dim - dtype = test_scenario.dtype - device = 'cuda' - - tokens_per_block = 64 - past_seen_tokens = [random.randint(1, 1024) for _ in range(batch_size)] - cache_indices = [] - last_pos = 0 - for i in range(batch_size): - used_blocks = (past_seen_tokens[i] + 64) // 64 - cache_indices.append( - [j for j in range(last_pos, last_pos + used_blocks)]) - last_pos += used_blocks - - block_ids_per_seq = [i for i in cache_indices] - [j for j in past_seen_tokens] - - request_ids = list(range(batch_size)) - token_nums = (torch.tensor(past_seen_tokens) + 1).tolist() - - num_blocks = 512 - max_seq_len = tokens_per_block * num_blocks - num_layers = 1 if isinstance(num_kv_heads, int) else len(num_kv_heads) - cp_config = { - "cp_type": CpType.STAR, - "cp_anchor_size": test_scenario.anchor_size, - "block_size": test_scenario.block_size - } - mapping = Mapping(world_size=1, - tp_size=1, - cp_size=1, - cp_config=cp_config, - rank=0) - - kv_cache_config = KvCacheConfig(max_tokens=num_blocks * - tokens_per_block) - if dtype == torch.float16: - kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF - elif dtype == torch.bfloat16: - kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16 - else: - raise ValueError("Invalid dtype for unit test") - - kv_cache_manager = KVCacheManager( - kv_cache_config, - tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - tokens_per_block=tokens_per_block, - max_seq_len=max_seq_len, - max_batch_size=batch_size, - mapping=mapping, - dtype=kv_cache_dtype, - ) - kv_cache_manager.add_dummy_requests(request_ids, token_nums) - - gen_qs = [] - gen_ks = [] - gen_vs = [] - - for i in range(num_layers): - gen_qs.append([ - torch.randn(1, num_heads * head_dim, dtype=dtype, device=device) - for _ in range(batch_size) - ]) - - kv_heads = num_kv_heads if isinstance(num_kv_heads, - int) else num_kv_heads[i] - gen_ks.append([ - torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device) - for _ in range(batch_size) - ]) - - gen_vs.append([ - torch.randn(1, kv_heads * head_dim, dtype=dtype, device=device) - for _ in range(batch_size) - ]) - - layers = [] - for i in range(num_layers): - kv_heads = num_kv_heads if isinstance(num_kv_heads, - int) else num_kv_heads[i] - layers.append( - StarAttention( - layer_idx=i, - head_dim=head_dim, - num_heads=num_heads, - num_kv_heads=kv_heads, - )) - - seq_lens = torch.ones((batch_size, ), dtype=torch.int) - attn_metadata_ref = TestingStarAttentionMetadata( - seq_lens=seq_lens, - num_contexts=0, - num_queries=0, - kv_cache_params=KVCacheParams( - use_cache=True, - block_ids_per_seq=block_ids_per_seq, - num_cached_tokens_per_seq=past_seen_tokens, - ), - max_num_requests=batch_size, - max_num_tokens=8192, - kv_cache_manager=kv_cache_manager, - request_ids=request_ids, - mapping=mapping, - ) - - attn_metadata_ref.kv_cache_manager = kv_cache_manager - - workspace = torch.empty(1024 * 1024 * 128, - dtype=torch.int, - device='cuda') - attn_metadata_cuda_graph = TestingStarAttentionMetadata( - seq_lens=seq_lens, - num_contexts=0, - num_queries=0, - is_cuda_graph=True, - kv_cache_params=KVCacheParams( - use_cache=True, - block_ids_per_seq=block_ids_per_seq, - num_cached_tokens_per_seq=past_seen_tokens, - ), - workspace_buffer=workspace, - max_num_requests=batch_size, - max_num_tokens=8192, - kv_cache_manager=kv_cache_manager, - 
request_ids=request_ids, - mapping=mapping, - ) - - attn_metadata_ref.prepare() - attn_metadata_cuda_graph.prepare() - - results_ref = [] - - for i in range(num_layers): - q = torch.cat(gen_qs[i]) - k = torch.cat(gen_ks[i]) - v = torch.cat(gen_vs[i]) - layer = layers[i] - results_ref.append(layer.forward(q, k, v, attn_metadata_ref)) - - graph = torch.cuda.CUDAGraph() - for i in range(num_layers): - layer = layers[i] - q = torch.cat(gen_qs[i]) - k = torch.cat(gen_ks[i]) - v = torch.cat(gen_vs[i]) - # Warmup run, required by PT - for _ in range(2): - layer.forward(q, k, v, attn_metadata_cuda_graph) - - results_actual = [] - with torch.cuda.graph(graph): - for i in range(num_layers): - layer = layers[i] - q = torch.cat(gen_qs[i]) - k = torch.cat(gen_ks[i]) - v = torch.cat(gen_vs[i]) - results_actual.append( - layer.forward(q, k, v, attn_metadata_cuda_graph)) - - graph.replay() - - for result_actual, result_ref in zip(results_actual, results_ref): - torch.testing.assert_close(result_actual, - result_ref, - atol=0.5, - rtol=0.5) - - kv_cache_manager.shutdown() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl b/tests/unittest/_torch/multi_gpu/NIAH_simple_data.jsonl similarity index 100% rename from tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl rename to tests/unittest/_torch/multi_gpu/NIAH_simple_data.jsonl diff --git a/tests/unittest/_torch/multi_gpu/test_star_attention.py b/tests/unittest/_torch/multi_gpu/test_star_attention.py deleted file mode 100644 index abad54e6bc..0000000000 --- a/tests/unittest/_torch/multi_gpu/test_star_attention.py +++ /dev/null @@ -1,139 +0,0 @@ -import json -import os - -import pytest -import torch -from utils.llm_data import llm_models_root - -from tensorrt_llm import LLM, SamplingParams -from tensorrt_llm.llmapi import KvCacheConfig -from tensorrt_llm.llmapi.utils import get_total_gpu_memory -from tensorrt_llm.mapping import CpType -from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig - -MAX_SEQ_LEN = 4096 + 1024 - - -@pytest.mark.post_merge -@pytest.mark.parametrize("backend", ["pytorch"]) -@pytest.mark.parametrize("model_name", - ["llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k"], - ids=["llama-3-8b-1048k"]) -@pytest.mark.parametrize("quant", ["bf16", "fp8"]) -@pytest.mark.parametrize("sp_size", [1, 2, 4], ids=["sp1", "sp2", "sp4"]) -@pytest.mark.parametrize("sa_block_size", [256, 1024], - ids=["block1024", "block4096"]) -@pytest.mark.parametrize("sa_anchor_size", [256, 1024], - ids=["anchor1024", "anchor4096"]) -def test_model(backend, model_name, quant, sp_size, sa_block_size, - sa_anchor_size): - pytest.skip("https://nvbugs/5391679") - quant_configs = { - "bf16": - QuantConfig(), - "fp8": - QuantConfig(quant_algo=QuantAlgo.FP8), - "fp8_kv_cache": - QuantConfig( - quant_algo=QuantAlgo.FP8, - kv_cache_quant_algo=QuantAlgo.FP8, - ), - } - quant_config = quant_configs[quant] - if sp_size != 1: - pytest.skip(f"skip multi gpu tests due to flashinfer's jitting mode") - if torch.cuda.device_count() < sp_size: - pytest.skip(f"Not enough GPUs available, need {sp_size} " - f"but only have {torch.cuda.device_count()}") - if sa_anchor_size > sa_block_size: - pytest.skip( - f"Unsupported sa_anchor_size {sa_anchor_size} > sa_block_size {sa_block_size}" - ) - - if get_total_gpu_memory(0) < 32 * 1024**3: - pytest.skip("Not enough GPU memory to run BF16 model") - - model_dir = str(llm_models_root() / model_name) - cp_config = { - "cp_type": CpType.STAR, - 
"cp_anchor_size": sa_anchor_size, - "block_size": sa_block_size - } - max_batch_size = 20 - max_output_tokens = 128 - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) - pytorch_backend_options = dict(attn_backend='FLASHINFER_STAR_ATTENTION', - disable_overlap_scheduler=True) - - llm = LLM(model=model_dir, - backend=backend, - kv_cache_config=kv_cache_config, - tensor_parallel_size=1, - quant_config=quant_config, - context_parallel_size=sp_size, - cp_config=cp_config, - **pytorch_backend_options, - max_batch_size=max_batch_size, - max_input_len=MAX_SEQ_LEN - max_output_tokens, - max_seq_len=MAX_SEQ_LEN, - max_num_tokens=(sa_block_size + sa_anchor_size) * max_batch_size) - - inputs, references = [], [] - current_file = os.path.abspath(__file__) - current_dir = os.path.dirname(current_file) - with open(f'{current_dir}/test_star_attention_input.jsonl', 'r') as f: - for line in f: - sample = json.loads(line) - inputs.append({ - 'prompt': sample['input_context'], - 'query': sample['input_query'] - }) - references.append(sample['outputs'][0]) - with llm: - outputs = llm.generate( - inputs, - use_tqdm=True, - sampling_params=SamplingParams( - max_tokens=max_output_tokens, - add_special_tokens=False, - ), - ) - - count = 0 - for ref, ret in zip(references, outputs): - #print(f'reference = {ref}') - #print(f'prediction = {ret.outputs[0].text}') - if ref not in ret.outputs[0].text: - print(f'reference {ref} is not in the output {ret.outputs[0].text}') - else: - count = count + 1 - acc = count / len(outputs) - if acc < 1.0: - assert False, 'accuracy test of star attention failed' - - -if __name__ == '__main__': - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "bf16", 1, 256, 256) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "bf16", 1, 1024, 256) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "bf16", 1, 1024, 1024) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "fp8", 1, 256, 256) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "fp8", 1, 1024, 256) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "fp8", 1, 1024, 1024) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "bf16", 2, 1024, 256) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "bf16", 2, 1024, 1024) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "bf16", 2, 256, 256) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "bf16", 4, 1024, 256) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "bf16", 4, 1024, 1024) - test_model("pytorch", "llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k", - "bf16", 4, 256, 256)