mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-7967][feat] Adding Starcoder2 PyTorch Backend Support (#8923)
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
This commit is contained in:
parent 336593cac5
commit 1ce483c999
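The new architecture is exercised through the LLM API in the accuracy tests added below; a minimal usage sketch along the same lines (the local checkpoint path and prompt are illustrative, not part of this commit):

from tensorrt_llm import LLM

# Assumes a local StarCoder2 checkpoint directory; settings mirror the accuracy tests in this commit.
with LLM("./starcoder2-3b",
         attn_backend="TRTLLM",
         cuda_graph_config=None,
         max_batch_size=128,
         max_seq_len=4096) as llm:
    outputs = llm.generate(["def fibonacci(n):"])
    print(outputs[0].outputs[0].text)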
@@ -31,6 +31,7 @@ from .modeling_qwen3_next import Qwen3NextForCausalLM
 from .modeling_qwen_moe import Qwen2MoeForCausalLM
 from .modeling_seedoss import SeedOssForCausalLM
 from .modeling_siglip import SiglipVisionModel
+from .modeling_starcoder2 import Starcoder2ForCausalLM
 from .modeling_utils import get_model_architecture
 from .modeling_vila import VilaModel
@@ -62,6 +63,7 @@ __all__ = [
 "Qwen2ForRewardModel",
 "Qwen2MoeForCausalLM",
 "SiglipVisionModel",
+"Starcoder2ForCausalLM",
 "get_model_architecture",
 "VilaModel",
 "Qwen2VLModel",
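The import added above is what pulls in modeling_starcoder2 at package import time, so the @register_auto_model("Starcoder2ForCausalLM") decorator in that file runs and a checkpoint whose HF config lists that architecture can be resolved to the new PyTorch-backend class. A simplified sketch of that decorator/registry pattern, with illustrative names rather than the actual TensorRT-LLM internals:

# Illustrative registry pattern; names here are hypothetical, not TensorRT-LLM's API.
_MODEL_REGISTRY: dict = {}

def register_auto_model(architecture: str):
    def decorator(cls):
        _MODEL_REGISTRY[architecture] = cls
        return cls
    return decorator

def resolve_model_class(architecture: str):
    # e.g. "Starcoder2ForCausalLM" taken from the HF config's "architectures" field
    return _MODEL_REGISTRY[architecture]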
tensorrt_llm/_torch/models/modeling_starcoder2.py (new file, 287 lines)
@@ -0,0 +1,287 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch import nn
from transformers import Starcoder2Config

from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (
    DecoderModel,
    DecoderModelForCausalLM,
    _load_weights_impl,
    register_auto_model,
)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm._torch.modules.mlp import MLP
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType


class Starcoder2Attention(Attention):
    """
    StarCoder2 Attention with Grouped Query Attention and Sliding Window support.
    """

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
        layer_idx: Optional[int] = None,
    ):
        config = model_config.pretrained_config
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=config.use_bias,
            pos_embd_params=PositionalEmbeddingParams(
                type=PositionEmbeddingType.rope_gpt_neox,
                rope=RopeParams.from_config(config),
            ),
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            config=model_config,
        )

        # Configure sliding window attention (4096 tokens)
        self.attention_window_size = getattr(config, "sliding_window", 4096)

    def forward(
        self,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """
        Overrides parent to pass attention_window_size parameter.
        """
        return super().forward(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            attention_window_size=self.attention_window_size,
            **kwargs,
        )


class Starcoder2DecoderLayer(DecoderLayer):
    """
    StarCoder2 Decoder Layer.

    Architecture:
    - Layer normalization before attention (with bias)
    - Self-attention with GQA and sliding window
    - Layer normalization before MLP (with bias)
    - MLP with GELU activation
    """

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
        layer_idx: int,
    ):
        super().__init__()
        config = model_config.pretrained_config
        self.layer_idx = layer_idx

        self.self_attn = Starcoder2Attention(
            model_config,
            layer_idx=layer_idx,
        )

        if config.mlp_type == "default":
            self.mlp = MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                bias=config.use_bias,
                activation=nn.GELU(),
                dtype=config.torch_dtype,
                config=model_config,
            )
        else:
            raise ValueError(
                f"Unsupported mlp_type: {config.mlp_type}. Only default (linear) MLP is supported."
            )

        norm_eps = getattr(config, "norm_epsilon", 1e-5)
        self.input_layernorm = LayerNorm(
            hidden_size=config.hidden_size,
            eps=norm_eps,
            dtype=config.torch_dtype,
            has_bias=True,  # StarCoder2 uses bias in layer norm
        )

        self.post_attention_layernorm = LayerNorm(
            hidden_size=config.hidden_size,
            eps=norm_eps,
            dtype=config.torch_dtype,
            has_bias=True,  # StarCoder2 uses bias in layer norm
        )

    def forward(
        self,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        hidden_states = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            **kwargs,
        )

        # Fully Connected (MLP)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if spec_metadata is not None:
            spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual)

        return hidden_states, residual


class Starcoder2Model(DecoderModel):
    """
    StarCoder2 Transformer Model.
    """

    def __init__(self, model_config: ModelConfig[Starcoder2Config]):
        super().__init__(model_config)
        config = self.model_config.pretrained_config

        self.embed_tokens = Embedding(
            config.vocab_size,
            config.hidden_size,
            dtype=config.torch_dtype,
            mapping=model_config.mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            gather_output=True,
        )

        self.layers = nn.ModuleList(
            [
                Starcoder2DecoderLayer(
                    model_config,
                    layer_idx,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

        # Use norm_epsilon (Starcoder2Config attribute name)
        norm_eps = getattr(config, "norm_epsilon", 1e-5)
        self.norm = LayerNorm(
            hidden_size=config.hidden_size,
            eps=norm_eps,
            dtype=config.torch_dtype,
            has_bias=True,  # StarCoder2 uses bias in layer norm
        )

    def forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.IntTensor] = None,
        position_ids: Optional[torch.IntTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        lora_params=None,
    ) -> torch.Tensor:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(
                position_ids=position_ids,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                spec_metadata=spec_metadata,
                lora_params=lora_params,
            )

        # Use LayerNorm's built-in residual connection support
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@register_auto_model("Starcoder2ForCausalLM")
class Starcoder2ForCausalLM(DecoderModelForCausalLM[Starcoder2Model, Starcoder2Config]):

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
    ):
        # Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16).
        # For the 15B FP32 checkpoint, we cast it to bfloat16 for consistency.
        torch_dtype_to_check = model_config.pretrained_config.torch_dtype
        if torch_dtype_to_check is None or torch_dtype_to_check == torch.float32:
            model_config.pretrained_config.torch_dtype = torch.bfloat16

        super().__init__(
            Starcoder2Model(model_config),
            config=model_config,
            hidden_size=model_config.pretrained_config.hidden_size,
            vocab_size=model_config.pretrained_config.vocab_size,
        )

    def load_weights(self, weights, weight_mapper=None, skip_modules=None):
        """
        Load weights with custom mapping for StarCoder2.

        StarCoder2 uses GPT-2 style MLP naming (c_fc, c_proj)
        while our MLP module expects (up_proj, down_proj).
        """
        if skip_modules is None:
            skip_modules = []

        # Map HuggingFace StarCoder2 weight names to TensorRT-LLM names
        params_map = {
            r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
            r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
        }
        preload_weight_modules = getattr(self, "preload_weight_modules", None)
        _load_weights_impl(
            self,
            weights,
            skip_modules,
            params_map=params_map,
            preload_weight_modules=preload_weight_modules,
        )
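To illustrate the params_map used in load_weights above: each HuggingFace StarCoder2 MLP weight name is rewritten with a regex substitution before loading. A small standalone sketch (the example key is illustrative; _load_weights_impl applies the mapping internally):

import re

params_map = {
    r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
    r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
}

hf_key = "model.layers.0.mlp.c_fc.weight"  # hypothetical HF checkpoint key
for pattern, repl in params_map.items():
    new_key, n = re.subn(pattern, repl, hf_key)
    if n:
        print(new_key)  # -> model.layers.0.mlp.up_proj.weight
        break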
@@ -270,3 +270,9 @@ zai-org/GLM-4.6:
 - quant_algo: NVFP4
   spec_dec_algo: MTP
   accuracy: 88.0
+bigcode/starcoder2-3b:
+- accuracy: 20.2
+bigcode/starcoder2-7b:
+- accuracy: 26.5
+bigcode/starcoder2-15b:
+- accuracy: 54.5
@@ -4264,3 +4264,49 @@ class TestDeepSeekR1LongBenchV2(LlmapiAccuracyTestHarness):
         if temp_dir and os.path.exists(temp_dir):
             import shutil
             shutil.rmtree(temp_dir, ignore_errors=True)
+
+
+class TestStarcoder2_3B(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "bigcode/starcoder2-3b"
+    MODEL_PATH = f"{llm_models_root()}/starcoder2-3b/"
+
+    @skip_pre_hopper
+    def test_auto_dtype(self):
+        with LLM(self.MODEL_PATH,
+                 attn_backend="TRTLLM",
+                 cuda_graph_config=None,
+                 max_batch_size=128,
+                 max_seq_len=4096) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+
+class TestStarcoder2_7B(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "bigcode/starcoder2-7b"
+    MODEL_PATH = f"{llm_models_root()}/starcoder2-7b/"
+
+    @skip_pre_hopper
+    def test_auto_dtype(self):
+        with LLM(self.MODEL_PATH,
+                 attn_backend="TRTLLM",
+                 cuda_graph_config=None,
+                 max_batch_size=128,
+                 max_seq_len=4096) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+
+class TestStarcoder2_15B(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "bigcode/starcoder2-15b"
+    MODEL_PATH = f"{llm_models_root()}/starcoder2-15b/"
+
+    @skip_pre_hopper
+    @pytest.mark.skip_less_device_memory(80000)
+    def test_auto_dtype(self):
+        with LLM(self.MODEL_PATH,
+                 attn_backend="TRTLLM",
+                 cuda_graph_config=None,
+                 max_batch_size=128,
+                 max_seq_len=4096) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
@@ -381,6 +381,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-c
 accuracy/test_llm_api_pytorch.py::TestQwQ_32B::test_auto_dtype_tp4
 accuracy/test_llm_api_pytorch.py::TestCodestral_22B_V01::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
+accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
+accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
+accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
+
 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
 accuracy/test_llm_api_pytorch_multimodal.py::TestLlava_V1_6_Mistral_7B::test_auto_dtype
@@ -19,6 +19,7 @@ l0_a30:
 - unittest/_torch/modeling -k "modeling_qwen"
 - unittest/_torch/modeling -k "modeling_qwen_moe"
 - unittest/_torch/modeling -k "modeling_out_of_tree"
+- unittest/_torch/modeling -k "modeling_starcoder2"
 - unittest/_torch/auto_deploy/unit/singlegpu
 - unittest/_torch/sampler/test_beam_search.py
 - unittest/_torch/sampler/test_return_logits.py
@@ -265,6 +265,9 @@ l0_h100:
 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=False]
 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
+- accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
+- accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
+- accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
 - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
 - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
 - test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]
tests/unittest/_torch/modeling/test_modeling_starcoder2.py (new file, 313 lines)
@@ -0,0 +1,313 @@
from copy import deepcopy
from dataclasses import dataclass

import pytest
import torch
from transformers import Starcoder2Config
from transformers import Starcoder2ForCausalLM as HFStarcoder2ForCausalLM
from utils.util import default_dtype

import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_starcoder2 import Starcoder2ForCausalLM
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping

# Base config for all StarCoder2 models (based on HuggingFace configs)
_STARCODER2_BASE_CONFIG = {
    "architectures": ["Starcoder2ForCausalLM"],
    "attention_dropout": 0.1,
    "residual_dropout": 0.1,
    "embedding_dropout": 0.1,
    "bos_token_id": 0,
    "eos_token_id": 0,
    "hidden_act": "gelu_pytorch_tanh",
    "max_position_embeddings": 16384,
    "mlp_type": "default",
    "model_type": "starcoder2",
    "norm_epsilon": 1e-5,
    "num_hidden_layers": 6,  # Reduced from 30/32/40 for testing
    "sliding_window": 4096,
    "transformers_version": "4.37.0.dev0",
    "use_bias": True,
    "use_cache": True,
    "vocab_size": 49152,
    "torch_dtype": "bfloat16",
}

# StarCoder2-3B config (reduced for testing)
STARCODER2_3B_CONFIG = {
    **_STARCODER2_BASE_CONFIG,
    "hidden_size": 3072,
    "initializer_range": 0.018042,
    "intermediate_size": 12288,
    "num_attention_heads": 24,
    "num_key_value_heads": 2,
    "rope_theta": 999999.4420358813,
}

# StarCoder2-7B config (reduced for testing)
STARCODER2_7B_CONFIG = {
    **_STARCODER2_BASE_CONFIG,
    "hidden_size": 4608,
    "initializer_range": 0.018042,
    "intermediate_size": 18432,
    "num_attention_heads": 36,
    "num_key_value_heads": 4,
    "rope_theta": 1000000,
}

# StarCoder2-15B config (reduced for testing)
STARCODER2_15B_CONFIG = {
    **_STARCODER2_BASE_CONFIG,
    "hidden_size": 6144,
    "initializer_range": 0.01275,
    "intermediate_size": 24576,
    "num_attention_heads": 48,
    "num_key_value_heads": 4,
    "rope_theta": 100000,
}


@dataclass(repr=False)
class Scenario:
    backend: str
    config_name: str
    use_cuda_graph: bool = False

    def __repr__(self) -> str:
        return f"backend:{self.backend.lower()}_config:{self.config_name.lower()}_cuda_graph:{self.use_cuda_graph}"


def get_kv_cache_manager(
    dtype: torch.dtype,
    config: Starcoder2Config,
    tokens_per_block: int,
    max_seq_len: int,
    batch_size: int,
    num_blocks: int,
):
    """Helper to create KV cache manager."""
    if dtype == torch.half:
        kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
    elif dtype == torch.bfloat16:
        kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
    else:
        raise ValueError(f"Invalid dtype: {dtype}")

    mapping = Mapping(world_size=1, tp_size=1, rank=0)
    kv_cache_config = KvCacheConfig(
        enable_block_reuse=False,
        enable_partial_reuse=False,
        copy_on_partial_reuse=False,
        max_tokens=num_blocks * tokens_per_block,
    )

    head_dim = config.hidden_size // config.num_attention_heads
    kv_cache_manager = KVCacheManager(
        kv_cache_config,
        tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
        num_layers=config.num_hidden_layers,
        num_kv_heads=config.num_key_value_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=batch_size,
        mapping=mapping,
        dtype=kv_cache_dtype,
    )
    return kv_cache_manager


@pytest.mark.parametrize(
    "scenario",
    [
        # Test without CUDA graphs
        Scenario(backend="TRTLLM", config_name="3B", use_cuda_graph=False),
        Scenario(backend="TRTLLM", config_name="7B", use_cuda_graph=False),
        Scenario(backend="TRTLLM", config_name="15B", use_cuda_graph=False),
        # Test with CUDA graphs
        Scenario(backend="TRTLLM", config_name="3B", use_cuda_graph=True),
        Scenario(backend="TRTLLM", config_name="7B", use_cuda_graph=True),
        Scenario(backend="TRTLLM", config_name="15B", use_cuda_graph=True),
    ],
    ids=str,
)
@torch.no_grad()
def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
    """
    Compare TensorRT-LLM StarCoder2 output to HuggingFace.

    Tests both context and generation phases using randomly initialized models.
    Optionally tests with CUDA graphs for generation phase optimization.
    """
    backend = scenario.backend
    config_name = scenario.config_name
    use_cuda_graph = scenario.use_cuda_graph
    metadata_cls = get_attention_backend(backend).Metadata

    torch.random.manual_seed(0)

    # Create config based on model size
    config_mapping = {
        "3B": STARCODER2_3B_CONFIG,
        "7B": STARCODER2_7B_CONFIG,
        "15B": STARCODER2_15B_CONFIG,
    }
    config_dict = deepcopy(config_mapping[config_name])

    # Create HuggingFace model from config with random weights
    hf_config = Starcoder2Config.from_dict(config_dict)
    hf_starcoder2 = HFStarcoder2ForCausalLM(hf_config)
    hf_starcoder2 = hf_starcoder2.to(dtype=torch.bfloat16, device="cuda").eval()

    dtype = torch.bfloat16
    device = torch.device("cuda")

    # Build TRT-LLM model and copy the same random weights from HF model
    with torch.device(device), default_dtype(dtype):
        model_config = ModelConfig(pretrained_config=hf_config, attn_backend=backend)
        starcoder2 = Starcoder2ForCausalLM(model_config).to(dtype).to(device).eval()
        starcoder2.load_weights(hf_starcoder2.state_dict())

    # Convert LayerNorm random weights to FP32 for numerical stability
    for name, module in starcoder2.named_modules():
        if isinstance(module, LayerNorm):
            if hasattr(module, "weight") and module.weight is not None:
                module.weight.data = module.weight.data.to(torch.float32)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data = module.bias.data.to(torch.float32)

    num_blocks = 1
    tokens_per_block = 128
    max_seq_len = num_blocks * tokens_per_block
    batch_size = 1

    kv_cache_manager = get_kv_cache_manager(
        dtype=dtype,
        config=hf_config,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        batch_size=batch_size,
        num_blocks=num_blocks,
    )

    # Context phase (no CUDA graphs for prefill)
    input_ids = torch.tensor(
        [100, 200, 300, 400, 500, 600, 700, 800],
        dtype=torch.int,
        device=device,
    )
    num_cached_tokens_per_seq = [0]
    request_ids = [1]
    token_nums = [input_ids.size(-1)]
    prompt_lens = [input_ids.size(-1)]
    kv_cache_manager.add_dummy_requests(request_ids, token_nums)

    attn_metadata = metadata_cls(
        seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
        num_contexts=1,
        kv_cache_params=KVCacheParams(
            use_cache=True,
            num_cached_tokens_per_seq=num_cached_tokens_per_seq,
        ),
        max_num_requests=1,
        max_num_tokens=8192,
        kv_cache_manager=kv_cache_manager,
        request_ids=request_ids,
        prompt_lens=prompt_lens,
    )

    position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int)]
    position_ids = torch.cat(position_ids).unsqueeze(0).cuda()

    with torch.inference_mode():
        attn_metadata.prepare()
        logits = starcoder2.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            attn_metadata=attn_metadata,
        )
        ref = hf_starcoder2.forward(
            input_ids=input_ids.unsqueeze(0),
            position_ids=position_ids,
            use_cache=True,
        )
        torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.1, rtol=0.1)

    # Generation phase (optionally with CUDA graphs)
    gen_input_ids = torch.tensor([900], dtype=torch.int, device=device)
    num_cached_tokens_per_seq = [input_ids.size(-1)]

    attn_metadata = metadata_cls(
        seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
        num_contexts=0,
        kv_cache_params=KVCacheParams(
            use_cache=True,
            num_cached_tokens_per_seq=num_cached_tokens_per_seq,
        ),
        kv_cache_manager=kv_cache_manager,
        request_ids=request_ids,
        prompt_lens=prompt_lens,
        max_num_requests=1,
        max_num_tokens=8192,
    )

    gen_position_ids = [
        torch.arange(
            input_ids.size(-1), input_ids.size(-1) + gen_input_ids.size(-1), dtype=torch.int
        )
    ]
    gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

    # Setup CUDA graph runner if requested
    graph_runner = None
    if use_cuda_graph:
        from _torch.helpers import create_mock_cuda_graph_runner

        graph_runner = create_mock_cuda_graph_runner(1)
        attn_metadata = attn_metadata.create_cuda_graph_metadata(1)

    # Run generation phase
    with torch.inference_mode():
        if not use_cuda_graph:
            attn_metadata.prepare()
            logits = starcoder2.forward(
                input_ids=gen_input_ids,
                position_ids=gen_position_ids,
                attn_metadata=attn_metadata,
            )
        else:
            # CUDA graph path
            inputs = {
                "input_ids": gen_input_ids,
                "position_ids": gen_position_ids,
                "attn_metadata": attn_metadata,
            }
            key = (1, 0, False)

            attn_metadata.prepare()
            graph_runner.capture(key, lambda inputs: starcoder2.forward(**inputs), inputs)

            # Run twice to catch buffer reallocation issues
            for _ in range(2):
                attn_metadata.prepare()
                logits = graph_runner.replay(key, inputs)

        # Compare with HuggingFace
        ref = hf_starcoder2.forward(
            input_ids=gen_input_ids.unsqueeze(0),
            position_ids=gen_position_ids,
            past_key_values=ref.past_key_values,
            use_cache=True,
        )
        torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.1, rtol=0.1)

    # Cleanup
    if graph_runner is not None:
        graph_runner.clear()
    kv_cache_manager.shutdown()