[TRTLLM-7967][feat] Adding Starcoder2 PyTorch Backend Support (#8923)

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
This commit is contained in:
Yibin Li 2025-11-24 11:23:22 -08:00 committed by GitHub
parent 336593cac5
commit 1ce483c999
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 662 additions and 0 deletions

View File

@ -31,6 +31,7 @@ from .modeling_qwen3_next import Qwen3NextForCausalLM
from .modeling_qwen_moe import Qwen2MoeForCausalLM
from .modeling_seedoss import SeedOssForCausalLM
from .modeling_siglip import SiglipVisionModel
from .modeling_starcoder2 import Starcoder2ForCausalLM
from .modeling_utils import get_model_architecture
from .modeling_vila import VilaModel
@ -62,6 +63,7 @@ __all__ = [
"Qwen2ForRewardModel",
"Qwen2MoeForCausalLM",
"SiglipVisionModel",
"Starcoder2ForCausalLM",
"get_model_architecture",
"VilaModel",
"Qwen2VLModel",

View File

@ -0,0 +1,287 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from torch import nn
from transformers import Starcoder2Config
from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (
DecoderModel,
DecoderModelForCausalLM,
_load_weights_impl,
register_auto_model,
)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm._torch.modules.mlp import MLP
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType
class Starcoder2Attention(Attention):
"""
StarCoder2 Attention with Grouped Query Attention and Sliding Window support.
"""
def __init__(
self,
model_config: ModelConfig[Starcoder2Config],
layer_idx: Optional[int] = None,
):
config = model_config.pretrained_config
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.use_bias,
pos_embd_params=PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=RopeParams.from_config(config),
),
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
)
# Configure sliding window attention (4096 tokens)
self.attention_window_size = getattr(config, "sliding_window", 4096)
def forward(
self,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
"""
Overrides parent to pass attention_window_size parameter.
"""
return super().forward(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_window_size=self.attention_window_size,
**kwargs,
)
class Starcoder2DecoderLayer(DecoderLayer):
"""
StarCoder2 Decoder Layer.
Architecture:
- Layer normalization before attention (with bias)
- Self-attention with GQA and sliding window
- Layer normalization before MLP (with bias)
- MLP with GELU activation
"""
def __init__(
self,
model_config: ModelConfig[Starcoder2Config],
layer_idx: int,
):
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx
self.self_attn = Starcoder2Attention(
model_config,
layer_idx=layer_idx,
)
if config.mlp_type == "default":
self.mlp = MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=config.use_bias,
activation=nn.GELU(),
dtype=config.torch_dtype,
config=model_config,
)
else:
raise ValueError(
f"Unsupported mlp_type: {config.mlp_type}. Only default (linear) MLP is supported."
)
norm_eps = getattr(config, "norm_epsilon", 1e-5)
self.input_layernorm = LayerNorm(
hidden_size=config.hidden_size,
eps=norm_eps,
dtype=config.torch_dtype,
has_bias=True, # StarCoder2 uses bias in layer norm
)
self.post_attention_layernorm = LayerNorm(
hidden_size=config.hidden_size,
eps=norm_eps,
dtype=config.torch_dtype,
has_bias=True, # StarCoder2 uses bias in layer norm
)
def forward(
self,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
**kwargs,
)
# Fully Connected (MLP)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if spec_metadata is not None:
spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual)
return hidden_states, residual
class Starcoder2Model(DecoderModel):
"""
StarCoder2 Transformer Model.
"""
def __init__(self, model_config: ModelConfig[Starcoder2Config]):
super().__init__(model_config)
config = self.model_config.pretrained_config
self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
self.layers = nn.ModuleList(
[
Starcoder2DecoderLayer(
model_config,
layer_idx,
)
for layer_idx in range(config.num_hidden_layers)
]
)
# Use norm_epsilon (Starcoder2Config attribute name)
norm_eps = getattr(config, "norm_epsilon", 1e-5)
self.norm = LayerNorm(
hidden_size=config.hidden_size,
eps=norm_eps,
dtype=config.torch_dtype,
has_bias=True, # StarCoder2 uses bias in layer norm
)
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
lora_params=None,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
residual = None
for decoder_layer in self.layers:
hidden_states, residual = decoder_layer(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
spec_metadata=spec_metadata,
lora_params=lora_params,
)
# Use LayerNorm's built-in residual connection support
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@register_auto_model("Starcoder2ForCausalLM")
class Starcoder2ForCausalLM(DecoderModelForCausalLM[Starcoder2Model, Starcoder2Config]):
def __init__(
self,
model_config: ModelConfig[Starcoder2Config],
):
# Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16).
# For the 15B FP32 checkpoint, we cast it to bfloat16 for consistency.
torch_dtype_to_check = model_config.pretrained_config.torch_dtype
if torch_dtype_to_check is None or torch_dtype_to_check == torch.float32:
model_config.pretrained_config.torch_dtype = torch.bfloat16
super().__init__(
Starcoder2Model(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size,
)
def load_weights(self, weights, weight_mapper=None, skip_modules=None):
"""
Load weights with custom mapping for StarCoder2.
StarCoder2 uses GPT-2 style MLP naming (c_fc, c_proj)
while our MLP module expects (up_proj, down_proj).
"""
if skip_modules is None:
skip_modules = []
# Map HuggingFace StarCoder2 weight names to TensorRT-LLM names
params_map = {
r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
}
preload_weight_modules = getattr(self, "preload_weight_modules", None)
_load_weights_impl(
self,
weights,
skip_modules,
params_map=params_map,
preload_weight_modules=preload_weight_modules,
)

View File

@ -270,3 +270,9 @@ zai-org/GLM-4.6:
- quant_algo: NVFP4
spec_dec_algo: MTP
accuracy: 88.0
bigcode/starcoder2-3b:
- accuracy: 20.2
bigcode/starcoder2-7b:
- accuracy: 26.5
bigcode/starcoder2-15b:
- accuracy: 54.5

View File

@ -4264,3 +4264,49 @@ class TestDeepSeekR1LongBenchV2(LlmapiAccuracyTestHarness):
if temp_dir and os.path.exists(temp_dir):
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
class TestStarcoder2_3B(LlmapiAccuracyTestHarness):
MODEL_NAME = "bigcode/starcoder2-3b"
MODEL_PATH = f"{llm_models_root()}/starcoder2-3b/"
@skip_pre_hopper
def test_auto_dtype(self):
with LLM(self.MODEL_PATH,
attn_backend="TRTLLM",
cuda_graph_config=None,
max_batch_size=128,
max_seq_len=4096) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
class TestStarcoder2_7B(LlmapiAccuracyTestHarness):
MODEL_NAME = "bigcode/starcoder2-7b"
MODEL_PATH = f"{llm_models_root()}/starcoder2-7b/"
@skip_pre_hopper
def test_auto_dtype(self):
with LLM(self.MODEL_PATH,
attn_backend="TRTLLM",
cuda_graph_config=None,
max_batch_size=128,
max_seq_len=4096) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
class TestStarcoder2_15B(LlmapiAccuracyTestHarness):
MODEL_NAME = "bigcode/starcoder2-15b"
MODEL_PATH = f"{llm_models_root()}/starcoder2-15b/"
@skip_pre_hopper
@pytest.mark.skip_less_device_memory(80000)
def test_auto_dtype(self):
with LLM(self.MODEL_PATH,
attn_backend="TRTLLM",
cuda_graph_config=None,
max_batch_size=128,
max_seq_len=4096) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

View File

@ -381,6 +381,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-c
accuracy/test_llm_api_pytorch.py::TestQwQ_32B::test_auto_dtype_tp4
accuracy/test_llm_api_pytorch.py::TestCodestral_22B_V01::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestLlava_V1_6_Mistral_7B::test_auto_dtype

View File

@ -19,6 +19,7 @@ l0_a30:
- unittest/_torch/modeling -k "modeling_qwen"
- unittest/_torch/modeling -k "modeling_qwen_moe"
- unittest/_torch/modeling -k "modeling_out_of_tree"
- unittest/_torch/modeling -k "modeling_starcoder2"
- unittest/_torch/auto_deploy/unit/singlegpu
- unittest/_torch/sampler/test_beam_search.py
- unittest/_torch/sampler/test_return_logits.py

View File

@ -265,6 +265,9 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
- test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]

View File

@ -0,0 +1,313 @@
from copy import deepcopy
from dataclasses import dataclass
import pytest
import torch
from transformers import Starcoder2Config
from transformers import Starcoder2ForCausalLM as HFStarcoder2ForCausalLM
from utils.util import default_dtype
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_starcoder2 import Starcoder2ForCausalLM
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
# Base config for all StarCoder2 models (based on HuggingFace configs)
_STARCODER2_BASE_CONFIG = {
"architectures": ["Starcoder2ForCausalLM"],
"attention_dropout": 0.1,
"residual_dropout": 0.1,
"embedding_dropout": 0.1,
"bos_token_id": 0,
"eos_token_id": 0,
"hidden_act": "gelu_pytorch_tanh",
"max_position_embeddings": 16384,
"mlp_type": "default",
"model_type": "starcoder2",
"norm_epsilon": 1e-5,
"num_hidden_layers": 6, # Reduced from 30/32/40 for testing
"sliding_window": 4096,
"transformers_version": "4.37.0.dev0",
"use_bias": True,
"use_cache": True,
"vocab_size": 49152,
"torch_dtype": "bfloat16",
}
# StarCoder2-3B config (reduced for testing)
STARCODER2_3B_CONFIG = {
**_STARCODER2_BASE_CONFIG,
"hidden_size": 3072,
"initializer_range": 0.018042,
"intermediate_size": 12288,
"num_attention_heads": 24,
"num_key_value_heads": 2,
"rope_theta": 999999.4420358813,
}
# StarCoder2-7B config (reduced for testing)
STARCODER2_7B_CONFIG = {
**_STARCODER2_BASE_CONFIG,
"hidden_size": 4608,
"initializer_range": 0.018042,
"intermediate_size": 18432,
"num_attention_heads": 36,
"num_key_value_heads": 4,
"rope_theta": 1000000,
}
# StarCoder2-15B config (reduced for testing)
STARCODER2_15B_CONFIG = {
**_STARCODER2_BASE_CONFIG,
"hidden_size": 6144,
"initializer_range": 0.01275,
"intermediate_size": 24576,
"num_attention_heads": 48,
"num_key_value_heads": 4,
"rope_theta": 100000,
}
@dataclass(repr=False)
class Scenario:
backend: str
config_name: str
use_cuda_graph: bool = False
def __repr__(self) -> str:
return f"backend:{self.backend.lower()}_config:{self.config_name.lower()}_cuda_graph:{self.use_cuda_graph}"
def get_kv_cache_manager(
dtype: torch.dtype,
config: Starcoder2Config,
tokens_per_block: int,
max_seq_len: int,
batch_size: int,
num_blocks: int,
):
"""Helper to create KV cache manager."""
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError(f"Invalid dtype: {dtype}")
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(
enable_block_reuse=False,
enable_partial_reuse=False,
copy_on_partial_reuse=False,
max_tokens=num_blocks * tokens_per_block,
)
head_dim = config.hidden_size // config.num_attention_heads
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=config.num_hidden_layers,
num_kv_heads=config.num_key_value_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
return kv_cache_manager
@pytest.mark.parametrize(
"scenario",
[
# Test without CUDA graphs
Scenario(backend="TRTLLM", config_name="3B", use_cuda_graph=False),
Scenario(backend="TRTLLM", config_name="7B", use_cuda_graph=False),
Scenario(backend="TRTLLM", config_name="15B", use_cuda_graph=False),
# Test with CUDA graphs
Scenario(backend="TRTLLM", config_name="3B", use_cuda_graph=True),
Scenario(backend="TRTLLM", config_name="7B", use_cuda_graph=True),
Scenario(backend="TRTLLM", config_name="15B", use_cuda_graph=True),
],
ids=str,
)
@torch.no_grad()
def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
"""
Compare TensorRT-LLM StarCoder2 output to HuggingFace.
Tests both context and generation phases using randomly initialized models.
Optionally tests with CUDA graphs for generation phase optimization.
"""
backend = scenario.backend
config_name = scenario.config_name
use_cuda_graph = scenario.use_cuda_graph
metadata_cls = get_attention_backend(backend).Metadata
torch.random.manual_seed(0)
# Create config based on model size
config_mapping = {
"3B": STARCODER2_3B_CONFIG,
"7B": STARCODER2_7B_CONFIG,
"15B": STARCODER2_15B_CONFIG,
}
config_dict = deepcopy(config_mapping[config_name])
# Create HuggingFace model from config with random weights
hf_config = Starcoder2Config.from_dict(config_dict)
hf_starcoder2 = HFStarcoder2ForCausalLM(hf_config)
hf_starcoder2 = hf_starcoder2.to(dtype=torch.bfloat16, device="cuda").eval()
dtype = torch.bfloat16
device = torch.device("cuda")
# Build TRT-LLM model and copy the same random weights from HF model
with torch.device(device), default_dtype(dtype):
model_config = ModelConfig(pretrained_config=hf_config, attn_backend=backend)
starcoder2 = Starcoder2ForCausalLM(model_config).to(dtype).to(device).eval()
starcoder2.load_weights(hf_starcoder2.state_dict())
# Convert LayerNorm random weights to FP32 for numerical stability
for name, module in starcoder2.named_modules():
if isinstance(module, LayerNorm):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(torch.float32)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data = module.bias.data.to(torch.float32)
num_blocks = 1
tokens_per_block = 128
max_seq_len = num_blocks * tokens_per_block
batch_size = 1
kv_cache_manager = get_kv_cache_manager(
dtype=dtype,
config=hf_config,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
batch_size=batch_size,
num_blocks=num_blocks,
)
# Context phase (no CUDA graphs for prefill)
input_ids = torch.tensor(
[100, 200, 300, 400, 500, 600, 700, 800],
dtype=torch.int,
device=device,
)
num_cached_tokens_per_seq = [0]
request_ids = [1]
token_nums = [input_ids.size(-1)]
prompt_lens = [input_ids.size(-1)]
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
attn_metadata = metadata_cls(
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
num_contexts=1,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int)]
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
with torch.inference_mode():
attn_metadata.prepare()
logits = starcoder2.forward(
input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
)
ref = hf_starcoder2.forward(
input_ids=input_ids.unsqueeze(0),
position_ids=position_ids,
use_cache=True,
)
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.1, rtol=0.1)
# Generation phase (optionally with CUDA graphs)
gen_input_ids = torch.tensor([900], dtype=torch.int, device=device)
num_cached_tokens_per_seq = [input_ids.size(-1)]
attn_metadata = metadata_cls(
seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
num_contexts=0,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
max_num_requests=1,
max_num_tokens=8192,
)
gen_position_ids = [
torch.arange(
input_ids.size(-1), input_ids.size(-1) + gen_input_ids.size(-1), dtype=torch.int
)
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
# Setup CUDA graph runner if requested
graph_runner = None
if use_cuda_graph:
from _torch.helpers import create_mock_cuda_graph_runner
graph_runner = create_mock_cuda_graph_runner(1)
attn_metadata = attn_metadata.create_cuda_graph_metadata(1)
# Run generation phase
with torch.inference_mode():
if not use_cuda_graph:
attn_metadata.prepare()
logits = starcoder2.forward(
input_ids=gen_input_ids,
position_ids=gen_position_ids,
attn_metadata=attn_metadata,
)
else:
# CUDA graph path
inputs = {
"input_ids": gen_input_ids,
"position_ids": gen_position_ids,
"attn_metadata": attn_metadata,
}
key = (1, 0, False)
attn_metadata.prepare()
graph_runner.capture(key, lambda inputs: starcoder2.forward(**inputs), inputs)
# Run twice to catch buffer reallocation issues
for _ in range(2):
attn_metadata.prepare()
logits = graph_runner.replay(key, inputs)
# Compare with HuggingFace
ref = hf_starcoder2.forward(
input_ids=gen_input_ids.unsqueeze(0),
position_ids=gen_position_ids,
past_key_values=ref.past_key_values,
use_cache=True,
)
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.1, rtol=0.1)
# Cleanup
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()