diff --git a/tensorrt_llm/_torch/models/__init__.py b/tensorrt_llm/_torch/models/__init__.py index 177bae4ad9..6c6c6a4f1d 100644 --- a/tensorrt_llm/_torch/models/__init__.py +++ b/tensorrt_llm/_torch/models/__init__.py @@ -31,6 +31,7 @@ from .modeling_qwen3_next import Qwen3NextForCausalLM from .modeling_qwen_moe import Qwen2MoeForCausalLM from .modeling_seedoss import SeedOssForCausalLM from .modeling_siglip import SiglipVisionModel +from .modeling_starcoder2 import Starcoder2ForCausalLM from .modeling_utils import get_model_architecture from .modeling_vila import VilaModel @@ -62,6 +63,7 @@ __all__ = [ "Qwen2ForRewardModel", "Qwen2MoeForCausalLM", "SiglipVisionModel", + "Starcoder2ForCausalLM", "get_model_architecture", "VilaModel", "Qwen2VLModel", diff --git a/tensorrt_llm/_torch/models/modeling_starcoder2.py b/tensorrt_llm/_torch/models/modeling_starcoder2.py new file mode 100644 index 0000000000..945392d0c2 --- /dev/null +++ b/tensorrt_llm/_torch/models/modeling_starcoder2.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from torch import nn +from transformers import Starcoder2Config + +from tensorrt_llm._torch.attention_backend import AttentionMetadata +from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.modeling_utils import ( + DecoderModel, + DecoderModelForCausalLM, + _load_weights_impl, + register_auto_model, +) +from tensorrt_llm._torch.modules.attention import Attention +from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer +from tensorrt_llm._torch.modules.embedding import Embedding +from tensorrt_llm._torch.modules.layer_norm import LayerNorm +from tensorrt_llm._torch.modules.linear import TensorParallelMode +from tensorrt_llm._torch.modules.mlp import MLP +from tensorrt_llm._torch.speculative import SpecMetadata +from tensorrt_llm.functional import PositionEmbeddingType + + +class Starcoder2Attention(Attention): + """ + StarCoder2 Attention with Grouped Query Attention and Sliding Window support. 
+ """ + + def __init__( + self, + model_config: ModelConfig[Starcoder2Config], + layer_idx: Optional[int] = None, + ): + config = model_config.pretrained_config + super().__init__( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + max_position_embeddings=config.max_position_embeddings, + bias=config.use_bias, + pos_embd_params=PositionalEmbeddingParams( + type=PositionEmbeddingType.rope_gpt_neox, + rope=RopeParams.from_config(config), + ), + layer_idx=layer_idx, + dtype=config.torch_dtype, + config=model_config, + ) + + # Configure sliding window attention (4096 tokens) + self.attention_window_size = getattr(config, "sliding_window", 4096) + + def forward( + self, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + """ + Overrides parent to pass attention_window_size parameter. + """ + return super().forward( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + attention_window_size=self.attention_window_size, + **kwargs, + ) + + +class Starcoder2DecoderLayer(DecoderLayer): + """ + StarCoder2 Decoder Layer. + + Architecture: + - Layer normalization before attention (with bias) + - Self-attention with GQA and sliding window + - Layer normalization before MLP (with bias) + - MLP with GELU activation + """ + + def __init__( + self, + model_config: ModelConfig[Starcoder2Config], + layer_idx: int, + ): + super().__init__() + config = model_config.pretrained_config + self.layer_idx = layer_idx + + self.self_attn = Starcoder2Attention( + model_config, + layer_idx=layer_idx, + ) + + if config.mlp_type == "default": + self.mlp = MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + bias=config.use_bias, + activation=nn.GELU(), + dtype=config.torch_dtype, + config=model_config, + ) + else: + raise ValueError( + f"Unsupported mlp_type: {config.mlp_type}. Only default (linear) MLP is supported." + ) + + norm_eps = getattr(config, "norm_epsilon", 1e-5) + self.input_layernorm = LayerNorm( + hidden_size=config.hidden_size, + eps=norm_eps, + dtype=config.torch_dtype, + has_bias=True, # StarCoder2 uses bias in layer norm + ) + + self.post_attention_layernorm = LayerNorm( + hidden_size=config.hidden_size, + eps=norm_eps, + dtype=config.torch_dtype, + has_bias=True, # StarCoder2 uses bias in layer norm + ) + + def forward( + self, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor] = None, + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + **kwargs, + ) + + # Fully Connected (MLP) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if spec_metadata is not None: + spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual) + + return hidden_states, residual + + +class Starcoder2Model(DecoderModel): + """ + StarCoder2 Transformer Model. 
+ """ + + def __init__(self, model_config: ModelConfig[Starcoder2Config]): + super().__init__(model_config) + config = self.model_config.pretrained_config + + self.embed_tokens = Embedding( + config.vocab_size, + config.hidden_size, + dtype=config.torch_dtype, + mapping=model_config.mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + gather_output=True, + ) + + self.layers = nn.ModuleList( + [ + Starcoder2DecoderLayer( + model_config, + layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + # Use norm_epsilon (Starcoder2Config attribute name) + norm_eps = getattr(config, "norm_epsilon", 1e-5) + self.norm = LayerNorm( + hidden_size=config.hidden_size, + eps=norm_eps, + dtype=config.torch_dtype, + has_bias=True, # StarCoder2 uses bias in layer norm + ) + + def forward( + self, + attn_metadata: AttentionMetadata, + input_ids: Optional[torch.IntTensor] = None, + position_ids: Optional[torch.IntTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + spec_metadata: Optional[SpecMetadata] = None, + lora_params=None, + ) -> torch.Tensor: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + residual = None + for decoder_layer in self.layers: + hidden_states, residual = decoder_layer( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + spec_metadata=spec_metadata, + lora_params=lora_params, + ) + + # Use LayerNorm's built-in residual connection support + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +@register_auto_model("Starcoder2ForCausalLM") +class Starcoder2ForCausalLM(DecoderModelForCausalLM[Starcoder2Model, Starcoder2Config]): + def __init__( + self, + model_config: ModelConfig[Starcoder2Config], + ): + # Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16). + # For the 15B FP32 checkpoint, we cast it to bfloat16 for consistency. + torch_dtype_to_check = model_config.pretrained_config.torch_dtype + if torch_dtype_to_check is None or torch_dtype_to_check == torch.float32: + model_config.pretrained_config.torch_dtype = torch.bfloat16 + + super().__init__( + Starcoder2Model(model_config), + config=model_config, + hidden_size=model_config.pretrained_config.hidden_size, + vocab_size=model_config.pretrained_config.vocab_size, + ) + + def load_weights(self, weights, weight_mapper=None, skip_modules=None): + """ + Load weights with custom mapping for StarCoder2. + + StarCoder2 uses GPT-2 style MLP naming (c_fc, c_proj) + while our MLP module expects (up_proj, down_proj). 
+ """ + if skip_modules is None: + skip_modules = [] + + # Map HuggingFace StarCoder2 weight names to TensorRT-LLM names + params_map = { + r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2", + r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2", + } + preload_weight_modules = getattr(self, "preload_weight_modules", None) + _load_weights_impl( + self, + weights, + skip_modules, + params_map=params_map, + preload_weight_modules=preload_weight_modules, + ) diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index ff93254cb1..e2bca37c51 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -270,3 +270,9 @@ zai-org/GLM-4.6: - quant_algo: NVFP4 spec_dec_algo: MTP accuracy: 88.0 +bigcode/starcoder2-3b: + - accuracy: 20.2 +bigcode/starcoder2-7b: + - accuracy: 26.5 +bigcode/starcoder2-15b: + - accuracy: 54.5 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 684c319f70..af38a021c7 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -4264,3 +4264,49 @@ class TestDeepSeekR1LongBenchV2(LlmapiAccuracyTestHarness): if temp_dir and os.path.exists(temp_dir): import shutil shutil.rmtree(temp_dir, ignore_errors=True) + + +class TestStarcoder2_3B(LlmapiAccuracyTestHarness): + MODEL_NAME = "bigcode/starcoder2-3b" + MODEL_PATH = f"{llm_models_root()}/starcoder2-3b/" + + @skip_pre_hopper + def test_auto_dtype(self): + with LLM(self.MODEL_PATH, + attn_backend="TRTLLM", + cuda_graph_config=None, + max_batch_size=128, + max_seq_len=4096) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + +class TestStarcoder2_7B(LlmapiAccuracyTestHarness): + MODEL_NAME = "bigcode/starcoder2-7b" + MODEL_PATH = f"{llm_models_root()}/starcoder2-7b/" + + @skip_pre_hopper + def test_auto_dtype(self): + with LLM(self.MODEL_PATH, + attn_backend="TRTLLM", + cuda_graph_config=None, + max_batch_size=128, + max_seq_len=4096) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + +class TestStarcoder2_15B(LlmapiAccuracyTestHarness): + MODEL_NAME = "bigcode/starcoder2-15b" + MODEL_PATH = f"{llm_models_root()}/starcoder2-15b/" + + @skip_pre_hopper + @pytest.mark.skip_less_device_memory(80000) + def test_auto_dtype(self): + with LLM(self.MODEL_PATH, + attn_backend="TRTLLM", + cuda_graph_config=None, + max_batch_size=128, + max_seq_len=4096) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) diff --git a/tests/integration/test_lists/qa/llm_function_nim.txt b/tests/integration/test_lists/qa/llm_function_nim.txt index 502319d56a..40daaa151f 100644 --- a/tests/integration/test_lists/qa/llm_function_nim.txt +++ b/tests/integration/test_lists/qa/llm_function_nim.txt @@ -381,6 +381,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-c accuracy/test_llm_api_pytorch.py::TestQwQ_32B::test_auto_dtype_tp4 accuracy/test_llm_api_pytorch.py::TestCodestral_22B_V01::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency] +accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype +accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype +accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype + accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype 
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype accuracy/test_llm_api_pytorch_multimodal.py::TestLlava_V1_6_Mistral_7B::test_auto_dtype diff --git a/tests/integration/test_lists/test-db/l0_a30.yml b/tests/integration/test_lists/test-db/l0_a30.yml index 837467ed17..b63ea04b5f 100644 --- a/tests/integration/test_lists/test-db/l0_a30.yml +++ b/tests/integration/test_lists/test-db/l0_a30.yml @@ -19,6 +19,7 @@ l0_a30: - unittest/_torch/modeling -k "modeling_qwen" - unittest/_torch/modeling -k "modeling_qwen_moe" - unittest/_torch/modeling -k "modeling_out_of_tree" + - unittest/_torch/modeling -k "modeling_starcoder2" - unittest/_torch/auto_deploy/unit/singlegpu - unittest/_torch/sampler/test_beam_search.py - unittest/_torch/sampler/test_return_logits.py diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 9b1f97dc6d..540a858060 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -265,6 +265,9 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance] + - accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype + - accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype + - accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] - test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image] diff --git a/tests/unittest/_torch/modeling/test_modeling_starcoder2.py b/tests/unittest/_torch/modeling/test_modeling_starcoder2.py new file mode 100644 index 0000000000..3eec8dc1e8 --- /dev/null +++ b/tests/unittest/_torch/modeling/test_modeling_starcoder2.py @@ -0,0 +1,313 @@ +from copy import deepcopy +from dataclasses import dataclass + +import pytest +import torch +from transformers import Starcoder2Config +from transformers import Starcoder2ForCausalLM as HFStarcoder2ForCausalLM +from utils.util import default_dtype + +import tensorrt_llm +from tensorrt_llm._torch.attention_backend.utils import get_attention_backend +from tensorrt_llm._torch.metadata import KVCacheParams +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.modeling_starcoder2 import Starcoder2ForCausalLM +from tensorrt_llm._torch.modules.layer_norm import LayerNorm +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.mapping import Mapping + +# Base config for all StarCoder2 models (based on HuggingFace configs) +_STARCODER2_BASE_CONFIG = { + "architectures": ["Starcoder2ForCausalLM"], + "attention_dropout": 0.1, + "residual_dropout": 0.1, + "embedding_dropout": 0.1, + "bos_token_id": 0, + "eos_token_id": 0, + "hidden_act": "gelu_pytorch_tanh", + "max_position_embeddings": 16384, + "mlp_type": "default", + "model_type": "starcoder2", + "norm_epsilon": 1e-5, + "num_hidden_layers": 6, 
# Reduced from 30/32/40 for testing + "sliding_window": 4096, + "transformers_version": "4.37.0.dev0", + "use_bias": True, + "use_cache": True, + "vocab_size": 49152, + "torch_dtype": "bfloat16", +} + +# StarCoder2-3B config (reduced for testing) +STARCODER2_3B_CONFIG = { + **_STARCODER2_BASE_CONFIG, + "hidden_size": 3072, + "initializer_range": 0.018042, + "intermediate_size": 12288, + "num_attention_heads": 24, + "num_key_value_heads": 2, + "rope_theta": 999999.4420358813, +} + +# StarCoder2-7B config (reduced for testing) +STARCODER2_7B_CONFIG = { + **_STARCODER2_BASE_CONFIG, + "hidden_size": 4608, + "initializer_range": 0.018042, + "intermediate_size": 18432, + "num_attention_heads": 36, + "num_key_value_heads": 4, + "rope_theta": 1000000, +} + +# StarCoder2-15B config (reduced for testing) +STARCODER2_15B_CONFIG = { + **_STARCODER2_BASE_CONFIG, + "hidden_size": 6144, + "initializer_range": 0.01275, + "intermediate_size": 24576, + "num_attention_heads": 48, + "num_key_value_heads": 4, + "rope_theta": 100000, +} + + +@dataclass(repr=False) +class Scenario: + backend: str + config_name: str + use_cuda_graph: bool = False + + def __repr__(self) -> str: + return f"backend:{self.backend.lower()}_config:{self.config_name.lower()}_cuda_graph:{self.use_cuda_graph}" + + +def get_kv_cache_manager( + dtype: torch.dtype, + config: Starcoder2Config, + tokens_per_block: int, + max_seq_len: int, + batch_size: int, + num_blocks: int, +): + """Helper to create KV cache manager.""" + if dtype == torch.half: + kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF + elif dtype == torch.bfloat16: + kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16 + else: + raise ValueError(f"Invalid dtype: {dtype}") + + mapping = Mapping(world_size=1, tp_size=1, rank=0) + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + enable_partial_reuse=False, + copy_on_partial_reuse=False, + max_tokens=num_blocks * tokens_per_block, + ) + + head_dim = config.hidden_size // config.num_attention_heads + kv_cache_manager = KVCacheManager( + kv_cache_config, + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, + num_layers=config.num_hidden_layers, + num_kv_heads=config.num_key_value_heads, + head_dim=head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + mapping=mapping, + dtype=kv_cache_dtype, + ) + return kv_cache_manager + + +@pytest.mark.parametrize( + "scenario", + [ + # Test without CUDA graphs + Scenario(backend="TRTLLM", config_name="3B", use_cuda_graph=False), + Scenario(backend="TRTLLM", config_name="7B", use_cuda_graph=False), + Scenario(backend="TRTLLM", config_name="15B", use_cuda_graph=False), + # Test with CUDA graphs + Scenario(backend="TRTLLM", config_name="3B", use_cuda_graph=True), + Scenario(backend="TRTLLM", config_name="7B", use_cuda_graph=True), + Scenario(backend="TRTLLM", config_name="15B", use_cuda_graph=True), + ], + ids=str, +) +@torch.no_grad() +def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None: + """ + Compare TensorRT-LLM StarCoder2 output to HuggingFace. + + Tests both context and generation phases using randomly initialized models. + Optionally tests with CUDA graphs for generation phase optimization. 
+ """ + backend = scenario.backend + config_name = scenario.config_name + use_cuda_graph = scenario.use_cuda_graph + metadata_cls = get_attention_backend(backend).Metadata + + torch.random.manual_seed(0) + + # Create config based on model size + config_mapping = { + "3B": STARCODER2_3B_CONFIG, + "7B": STARCODER2_7B_CONFIG, + "15B": STARCODER2_15B_CONFIG, + } + config_dict = deepcopy(config_mapping[config_name]) + + # Create HuggingFace model from config with random weights + hf_config = Starcoder2Config.from_dict(config_dict) + hf_starcoder2 = HFStarcoder2ForCausalLM(hf_config) + hf_starcoder2 = hf_starcoder2.to(dtype=torch.bfloat16, device="cuda").eval() + + dtype = torch.bfloat16 + device = torch.device("cuda") + + # Build TRT-LLM model and copy the same random weights from HF model + with torch.device(device), default_dtype(dtype): + model_config = ModelConfig(pretrained_config=hf_config, attn_backend=backend) + starcoder2 = Starcoder2ForCausalLM(model_config).to(dtype).to(device).eval() + starcoder2.load_weights(hf_starcoder2.state_dict()) + + # Convert LayerNorm random weights to FP32 for numerical stability + for name, module in starcoder2.named_modules(): + if isinstance(module, LayerNorm): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(torch.float32) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data = module.bias.data.to(torch.float32) + + num_blocks = 1 + tokens_per_block = 128 + max_seq_len = num_blocks * tokens_per_block + batch_size = 1 + + kv_cache_manager = get_kv_cache_manager( + dtype=dtype, + config=hf_config, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + batch_size=batch_size, + num_blocks=num_blocks, + ) + + # Context phase (no CUDA graphs for prefill) + input_ids = torch.tensor( + [100, 200, 300, 400, 500, 600, 700, 800], + dtype=torch.int, + device=device, + ) + num_cached_tokens_per_seq = [0] + request_ids = [1] + token_nums = [input_ids.size(-1)] + prompt_lens = [input_ids.size(-1)] + kv_cache_manager.add_dummy_requests(request_ids, token_nums) + + attn_metadata = metadata_cls( + seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int), + num_contexts=1, + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=num_cached_tokens_per_seq, + ), + max_num_requests=1, + max_num_tokens=8192, + kv_cache_manager=kv_cache_manager, + request_ids=request_ids, + prompt_lens=prompt_lens, + ) + + position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int)] + position_ids = torch.cat(position_ids).unsqueeze(0).cuda() + + with torch.inference_mode(): + attn_metadata.prepare() + logits = starcoder2.forward( + input_ids=input_ids, + position_ids=position_ids, + attn_metadata=attn_metadata, + ) + ref = hf_starcoder2.forward( + input_ids=input_ids.unsqueeze(0), + position_ids=position_ids, + use_cache=True, + ) + torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.1, rtol=0.1) + + # Generation phase (optionally with CUDA graphs) + gen_input_ids = torch.tensor([900], dtype=torch.int, device=device) + num_cached_tokens_per_seq = [input_ids.size(-1)] + + attn_metadata = metadata_cls( + seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int), + num_contexts=0, + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=num_cached_tokens_per_seq, + ), + kv_cache_manager=kv_cache_manager, + request_ids=request_ids, + prompt_lens=prompt_lens, + max_num_requests=1, + max_num_tokens=8192, + ) + + 
gen_position_ids = [ + torch.arange( + input_ids.size(-1), input_ids.size(-1) + gen_input_ids.size(-1), dtype=torch.int + ) + ] + gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + + # Setup CUDA graph runner if requested + graph_runner = None + if use_cuda_graph: + from _torch.helpers import create_mock_cuda_graph_runner + + graph_runner = create_mock_cuda_graph_runner(1) + attn_metadata = attn_metadata.create_cuda_graph_metadata(1) + + # Run generation phase + with torch.inference_mode(): + if not use_cuda_graph: + attn_metadata.prepare() + logits = starcoder2.forward( + input_ids=gen_input_ids, + position_ids=gen_position_ids, + attn_metadata=attn_metadata, + ) + else: + # CUDA graph path + inputs = { + "input_ids": gen_input_ids, + "position_ids": gen_position_ids, + "attn_metadata": attn_metadata, + } + key = (1, 0, False) + + attn_metadata.prepare() + graph_runner.capture(key, lambda inputs: starcoder2.forward(**inputs), inputs) + + # Run twice to catch buffer reallocation issues + for _ in range(2): + attn_metadata.prepare() + logits = graph_runner.replay(key, inputs) + + # Compare with HuggingFace + ref = hf_starcoder2.forward( + input_ids=gen_input_ids.unsqueeze(0), + position_ids=gen_position_ids, + past_key_values=ref.past_key_values, + use_cache=True, + ) + torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.1, rtol=0.1) + + # Cleanup + if graph_runner is not None: + graph_runner.clear() + kv_cache_manager.shutdown()
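
Note (not part of the diff): the `params_map` regexes in `Starcoder2ForCausalLM.load_weights` are the only StarCoder2-specific remapping; everything else goes through the generic `_load_weights_impl`. Below is a minimal standalone sketch of what those two patterns do to HuggingFace checkpoint keys. The helper name `remap_starcoder2_key` is hypothetical and only illustrates the substitution; the real rewrite happens inside `_load_weights_impl`.

```python
import re

# The same key-rewrite patterns passed as params_map in load_weights above.
_PARAMS_MAP = {
    r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
    r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
}


def remap_starcoder2_key(name: str) -> str:
    """Rewrite a HuggingFace StarCoder2 weight name into the TRT-LLM MLP naming."""
    for pattern, repl in _PARAMS_MAP.items():
        new_name, count = re.subn(pattern, repl, name)
        if count:
            return new_name
    return name


# GPT-2 style c_fc/c_proj keys map onto up_proj/down_proj; other keys pass through.
assert remap_starcoder2_key("model.layers.0.mlp.c_fc.weight") == "model.layers.0.mlp.up_proj.weight"
assert remap_starcoder2_key("model.layers.0.mlp.c_proj.bias") == "model.layers.0.mlp.down_proj.bias"
assert remap_starcoder2_key("model.layers.0.self_attn.q_proj.weight") == "model.layers.0.self_attn.q_proj.weight"
```

Only the MLP projections need renaming, which is why `params_map` contains just these two patterns.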
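
Also not part of the change: a minimal usage sketch that exercises the new model through the LLM API with the same arguments the accuracy tests pass (TRTLLM attention backend, CUDA graphs disabled). The checkpoint path and prompt are placeholders.

```python
from tensorrt_llm import LLM, SamplingParams

# Placeholder path; the accuracy tests resolve the real location via llm_models_root().
model_path = "/path/to/starcoder2-7b"

# Mirrors TestStarcoder2_7B::test_auto_dtype above.
with LLM(model_path,
         attn_backend="TRTLLM",
         cuda_graph_config=None,
         max_batch_size=128,
         max_seq_len=4096) as llm:
    outputs = llm.generate(["def fibonacci(n):"],
                           SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)
```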