mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[TRTLLM-7967][feat] Adding Starcoder2 PyTorch Backend Support (#8923)
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
This commit is contained in:
parent
336593cac5
commit
1ce483c999
@ -31,6 +31,7 @@ from .modeling_qwen3_next import Qwen3NextForCausalLM
|
||||
from .modeling_qwen_moe import Qwen2MoeForCausalLM
|
||||
from .modeling_seedoss import SeedOssForCausalLM
|
||||
from .modeling_siglip import SiglipVisionModel
|
||||
from .modeling_starcoder2 import Starcoder2ForCausalLM
|
||||
from .modeling_utils import get_model_architecture
|
||||
from .modeling_vila import VilaModel
|
||||
|
||||
@ -62,6 +63,7 @@ __all__ = [
|
||||
"Qwen2ForRewardModel",
|
||||
"Qwen2MoeForCausalLM",
|
||||
"SiglipVisionModel",
|
||||
"Starcoder2ForCausalLM",
|
||||
"get_model_architecture",
|
||||
"VilaModel",
|
||||
"Qwen2VLModel",
|
||||
|
||||
287
tensorrt_llm/_torch/models/modeling_starcoder2.py
Normal file
287
tensorrt_llm/_torch/models/modeling_starcoder2.py
Normal file
@ -0,0 +1,287 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Starcoder2Config
|
||||
|
||||
from tensorrt_llm._torch.attention_backend import AttentionMetadata
|
||||
from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_utils import (
|
||||
DecoderModel,
|
||||
DecoderModelForCausalLM,
|
||||
_load_weights_impl,
|
||||
register_auto_model,
|
||||
)
|
||||
from tensorrt_llm._torch.modules.attention import Attention
|
||||
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
|
||||
from tensorrt_llm._torch.modules.embedding import Embedding
|
||||
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
|
||||
from tensorrt_llm._torch.modules.linear import TensorParallelMode
|
||||
from tensorrt_llm._torch.modules.mlp import MLP
|
||||
from tensorrt_llm._torch.speculative import SpecMetadata
|
||||
from tensorrt_llm.functional import PositionEmbeddingType
|
||||
|
||||
|
||||
class Starcoder2Attention(Attention):
|
||||
"""
|
||||
StarCoder2 Attention with Grouped Query Attention and Sliding Window support.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig[Starcoder2Config],
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
config = model_config.pretrained_config
|
||||
super().__init__(
|
||||
hidden_size=config.hidden_size,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
bias=config.use_bias,
|
||||
pos_embd_params=PositionalEmbeddingParams(
|
||||
type=PositionEmbeddingType.rope_gpt_neox,
|
||||
rope=RopeParams.from_config(config),
|
||||
),
|
||||
layer_idx=layer_idx,
|
||||
dtype=config.torch_dtype,
|
||||
config=model_config,
|
||||
)
|
||||
|
||||
# Configure sliding window attention (4096 tokens)
|
||||
self.attention_window_size = getattr(config, "sliding_window", 4096)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.IntTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Overrides parent to pass attention_window_size parameter.
|
||||
"""
|
||||
return super().forward(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
attention_window_size=self.attention_window_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class Starcoder2DecoderLayer(DecoderLayer):
|
||||
"""
|
||||
StarCoder2 Decoder Layer.
|
||||
|
||||
Architecture:
|
||||
- Layer normalization before attention (with bias)
|
||||
- Self-attention with GQA and sliding window
|
||||
- Layer normalization before MLP (with bias)
|
||||
- MLP with GELU activation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig[Starcoder2Config],
|
||||
layer_idx: int,
|
||||
):
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.self_attn = Starcoder2Attention(
|
||||
model_config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
if config.mlp_type == "default":
|
||||
self.mlp = MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
bias=config.use_bias,
|
||||
activation=nn.GELU(),
|
||||
dtype=config.torch_dtype,
|
||||
config=model_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported mlp_type: {config.mlp_type}. Only default (linear) MLP is supported."
|
||||
)
|
||||
|
||||
norm_eps = getattr(config, "norm_epsilon", 1e-5)
|
||||
self.input_layernorm = LayerNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
has_bias=True, # StarCoder2 uses bias in layer norm
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = LayerNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
has_bias=True, # StarCoder2 uses bias in layer norm
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.IntTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Fully Connected (MLP)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
if spec_metadata is not None:
|
||||
spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class Starcoder2Model(DecoderModel):
|
||||
"""
|
||||
StarCoder2 Transformer Model.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig[Starcoder2Config]):
|
||||
super().__init__(model_config)
|
||||
config = self.model_config.pretrained_config
|
||||
|
||||
self.embed_tokens = Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
dtype=config.torch_dtype,
|
||||
mapping=model_config.mapping,
|
||||
tensor_parallel_mode=TensorParallelMode.COLUMN,
|
||||
gather_output=True,
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Starcoder2DecoderLayer(
|
||||
model_config,
|
||||
layer_idx,
|
||||
)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Use norm_epsilon (Starcoder2Config attribute name)
|
||||
norm_eps = getattr(config, "norm_epsilon", 1e-5)
|
||||
self.norm = LayerNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
has_bias=True, # StarCoder2 uses bias in layer norm
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_ids: Optional[torch.IntTensor] = None,
|
||||
position_ids: Optional[torch.IntTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
lora_params=None,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
residual = None
|
||||
for decoder_layer in self.layers:
|
||||
hidden_states, residual = decoder_layer(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
spec_metadata=spec_metadata,
|
||||
lora_params=lora_params,
|
||||
)
|
||||
|
||||
# Use LayerNorm's built-in residual connection support
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@register_auto_model("Starcoder2ForCausalLM")
|
||||
class Starcoder2ForCausalLM(DecoderModelForCausalLM[Starcoder2Model, Starcoder2Config]):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig[Starcoder2Config],
|
||||
):
|
||||
# Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16).
|
||||
# For the 15B FP32 checkpoint, we cast it to bfloat16 for consistency.
|
||||
torch_dtype_to_check = model_config.pretrained_config.torch_dtype
|
||||
if torch_dtype_to_check is None or torch_dtype_to_check == torch.float32:
|
||||
model_config.pretrained_config.torch_dtype = torch.bfloat16
|
||||
|
||||
super().__init__(
|
||||
Starcoder2Model(model_config),
|
||||
config=model_config,
|
||||
hidden_size=model_config.pretrained_config.hidden_size,
|
||||
vocab_size=model_config.pretrained_config.vocab_size,
|
||||
)
|
||||
|
||||
def load_weights(self, weights, weight_mapper=None, skip_modules=None):
|
||||
"""
|
||||
Load weights with custom mapping for StarCoder2.
|
||||
|
||||
StarCoder2 uses GPT-2 style MLP naming (c_fc, c_proj)
|
||||
while our MLP module expects (up_proj, down_proj).
|
||||
"""
|
||||
if skip_modules is None:
|
||||
skip_modules = []
|
||||
|
||||
# Map HuggingFace StarCoder2 weight names to TensorRT-LLM names
|
||||
params_map = {
|
||||
r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
|
||||
r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
|
||||
}
|
||||
preload_weight_modules = getattr(self, "preload_weight_modules", None)
|
||||
_load_weights_impl(
|
||||
self,
|
||||
weights,
|
||||
skip_modules,
|
||||
params_map=params_map,
|
||||
preload_weight_modules=preload_weight_modules,
|
||||
)
|
||||
@ -270,3 +270,9 @@ zai-org/GLM-4.6:
|
||||
- quant_algo: NVFP4
|
||||
spec_dec_algo: MTP
|
||||
accuracy: 88.0
|
||||
bigcode/starcoder2-3b:
|
||||
- accuracy: 20.2
|
||||
bigcode/starcoder2-7b:
|
||||
- accuracy: 26.5
|
||||
bigcode/starcoder2-15b:
|
||||
- accuracy: 54.5
|
||||
|
||||
@ -4264,3 +4264,49 @@ class TestDeepSeekR1LongBenchV2(LlmapiAccuracyTestHarness):
|
||||
if temp_dir and os.path.exists(temp_dir):
|
||||
import shutil
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
class TestStarcoder2_3B(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "bigcode/starcoder2-3b"
|
||||
MODEL_PATH = f"{llm_models_root()}/starcoder2-3b/"
|
||||
|
||||
@skip_pre_hopper
|
||||
def test_auto_dtype(self):
|
||||
with LLM(self.MODEL_PATH,
|
||||
attn_backend="TRTLLM",
|
||||
cuda_graph_config=None,
|
||||
max_batch_size=128,
|
||||
max_seq_len=4096) as llm:
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
|
||||
class TestStarcoder2_7B(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "bigcode/starcoder2-7b"
|
||||
MODEL_PATH = f"{llm_models_root()}/starcoder2-7b/"
|
||||
|
||||
@skip_pre_hopper
|
||||
def test_auto_dtype(self):
|
||||
with LLM(self.MODEL_PATH,
|
||||
attn_backend="TRTLLM",
|
||||
cuda_graph_config=None,
|
||||
max_batch_size=128,
|
||||
max_seq_len=4096) as llm:
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
|
||||
class TestStarcoder2_15B(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "bigcode/starcoder2-15b"
|
||||
MODEL_PATH = f"{llm_models_root()}/starcoder2-15b/"
|
||||
|
||||
@skip_pre_hopper
|
||||
@pytest.mark.skip_less_device_memory(80000)
|
||||
def test_auto_dtype(self):
|
||||
with LLM(self.MODEL_PATH,
|
||||
attn_backend="TRTLLM",
|
||||
cuda_graph_config=None,
|
||||
max_batch_size=128,
|
||||
max_seq_len=4096) as llm:
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
@ -381,6 +381,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-c
|
||||
accuracy/test_llm_api_pytorch.py::TestQwQ_32B::test_auto_dtype_tp4
|
||||
accuracy/test_llm_api_pytorch.py::TestCodestral_22B_V01::test_auto_dtype
|
||||
accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
|
||||
accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
|
||||
accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
|
||||
accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
|
||||
|
||||
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
|
||||
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
|
||||
accuracy/test_llm_api_pytorch_multimodal.py::TestLlava_V1_6_Mistral_7B::test_auto_dtype
|
||||
|
||||
@ -19,6 +19,7 @@ l0_a30:
|
||||
- unittest/_torch/modeling -k "modeling_qwen"
|
||||
- unittest/_torch/modeling -k "modeling_qwen_moe"
|
||||
- unittest/_torch/modeling -k "modeling_out_of_tree"
|
||||
- unittest/_torch/modeling -k "modeling_starcoder2"
|
||||
- unittest/_torch/auto_deploy/unit/singlegpu
|
||||
- unittest/_torch/sampler/test_beam_search.py
|
||||
- unittest/_torch/sampler/test_return_logits.py
|
||||
|
||||
@ -265,6 +265,9 @@ l0_h100:
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
|
||||
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
|
||||
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
|
||||
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]
|
||||
|
||||
313
tests/unittest/_torch/modeling/test_modeling_starcoder2.py
Normal file
313
tests/unittest/_torch/modeling/test_modeling_starcoder2.py
Normal file
@ -0,0 +1,313 @@
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import Starcoder2Config
|
||||
from transformers import Starcoder2ForCausalLM as HFStarcoder2ForCausalLM
|
||||
from utils.util import default_dtype
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
|
||||
from tensorrt_llm._torch.metadata import KVCacheParams
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_starcoder2 import Starcoder2ForCausalLM
|
||||
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
# Base config for all StarCoder2 models (based on HuggingFace configs)
|
||||
_STARCODER2_BASE_CONFIG = {
|
||||
"architectures": ["Starcoder2ForCausalLM"],
|
||||
"attention_dropout": 0.1,
|
||||
"residual_dropout": 0.1,
|
||||
"embedding_dropout": 0.1,
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 0,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"max_position_embeddings": 16384,
|
||||
"mlp_type": "default",
|
||||
"model_type": "starcoder2",
|
||||
"norm_epsilon": 1e-5,
|
||||
"num_hidden_layers": 6, # Reduced from 30/32/40 for testing
|
||||
"sliding_window": 4096,
|
||||
"transformers_version": "4.37.0.dev0",
|
||||
"use_bias": True,
|
||||
"use_cache": True,
|
||||
"vocab_size": 49152,
|
||||
"torch_dtype": "bfloat16",
|
||||
}
|
||||
|
||||
# StarCoder2-3B config (reduced for testing)
|
||||
STARCODER2_3B_CONFIG = {
|
||||
**_STARCODER2_BASE_CONFIG,
|
||||
"hidden_size": 3072,
|
||||
"initializer_range": 0.018042,
|
||||
"intermediate_size": 12288,
|
||||
"num_attention_heads": 24,
|
||||
"num_key_value_heads": 2,
|
||||
"rope_theta": 999999.4420358813,
|
||||
}
|
||||
|
||||
# StarCoder2-7B config (reduced for testing)
|
||||
STARCODER2_7B_CONFIG = {
|
||||
**_STARCODER2_BASE_CONFIG,
|
||||
"hidden_size": 4608,
|
||||
"initializer_range": 0.018042,
|
||||
"intermediate_size": 18432,
|
||||
"num_attention_heads": 36,
|
||||
"num_key_value_heads": 4,
|
||||
"rope_theta": 1000000,
|
||||
}
|
||||
|
||||
# StarCoder2-15B config (reduced for testing)
|
||||
STARCODER2_15B_CONFIG = {
|
||||
**_STARCODER2_BASE_CONFIG,
|
||||
"hidden_size": 6144,
|
||||
"initializer_range": 0.01275,
|
||||
"intermediate_size": 24576,
|
||||
"num_attention_heads": 48,
|
||||
"num_key_value_heads": 4,
|
||||
"rope_theta": 100000,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class Scenario:
|
||||
backend: str
|
||||
config_name: str
|
||||
use_cuda_graph: bool = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"backend:{self.backend.lower()}_config:{self.config_name.lower()}_cuda_graph:{self.use_cuda_graph}"
|
||||
|
||||
|
||||
def get_kv_cache_manager(
|
||||
dtype: torch.dtype,
|
||||
config: Starcoder2Config,
|
||||
tokens_per_block: int,
|
||||
max_seq_len: int,
|
||||
batch_size: int,
|
||||
num_blocks: int,
|
||||
):
|
||||
"""Helper to create KV cache manager."""
|
||||
if dtype == torch.half:
|
||||
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
|
||||
elif dtype == torch.bfloat16:
|
||||
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
|
||||
else:
|
||||
raise ValueError(f"Invalid dtype: {dtype}")
|
||||
|
||||
mapping = Mapping(world_size=1, tp_size=1, rank=0)
|
||||
kv_cache_config = KvCacheConfig(
|
||||
enable_block_reuse=False,
|
||||
enable_partial_reuse=False,
|
||||
copy_on_partial_reuse=False,
|
||||
max_tokens=num_blocks * tokens_per_block,
|
||||
)
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
kv_cache_manager = KVCacheManager(
|
||||
kv_cache_config,
|
||||
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
|
||||
num_layers=config.num_hidden_layers,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
tokens_per_block=tokens_per_block,
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=batch_size,
|
||||
mapping=mapping,
|
||||
dtype=kv_cache_dtype,
|
||||
)
|
||||
return kv_cache_manager
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scenario",
|
||||
[
|
||||
# Test without CUDA graphs
|
||||
Scenario(backend="TRTLLM", config_name="3B", use_cuda_graph=False),
|
||||
Scenario(backend="TRTLLM", config_name="7B", use_cuda_graph=False),
|
||||
Scenario(backend="TRTLLM", config_name="15B", use_cuda_graph=False),
|
||||
# Test with CUDA graphs
|
||||
Scenario(backend="TRTLLM", config_name="3B", use_cuda_graph=True),
|
||||
Scenario(backend="TRTLLM", config_name="7B", use_cuda_graph=True),
|
||||
Scenario(backend="TRTLLM", config_name="15B", use_cuda_graph=True),
|
||||
],
|
||||
ids=str,
|
||||
)
|
||||
@torch.no_grad()
|
||||
def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
|
||||
"""
|
||||
Compare TensorRT-LLM StarCoder2 output to HuggingFace.
|
||||
|
||||
Tests both context and generation phases using randomly initialized models.
|
||||
Optionally tests with CUDA graphs for generation phase optimization.
|
||||
"""
|
||||
backend = scenario.backend
|
||||
config_name = scenario.config_name
|
||||
use_cuda_graph = scenario.use_cuda_graph
|
||||
metadata_cls = get_attention_backend(backend).Metadata
|
||||
|
||||
torch.random.manual_seed(0)
|
||||
|
||||
# Create config based on model size
|
||||
config_mapping = {
|
||||
"3B": STARCODER2_3B_CONFIG,
|
||||
"7B": STARCODER2_7B_CONFIG,
|
||||
"15B": STARCODER2_15B_CONFIG,
|
||||
}
|
||||
config_dict = deepcopy(config_mapping[config_name])
|
||||
|
||||
# Create HuggingFace model from config with random weights
|
||||
hf_config = Starcoder2Config.from_dict(config_dict)
|
||||
hf_starcoder2 = HFStarcoder2ForCausalLM(hf_config)
|
||||
hf_starcoder2 = hf_starcoder2.to(dtype=torch.bfloat16, device="cuda").eval()
|
||||
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Build TRT-LLM model and copy the same random weights from HF model
|
||||
with torch.device(device), default_dtype(dtype):
|
||||
model_config = ModelConfig(pretrained_config=hf_config, attn_backend=backend)
|
||||
starcoder2 = Starcoder2ForCausalLM(model_config).to(dtype).to(device).eval()
|
||||
starcoder2.load_weights(hf_starcoder2.state_dict())
|
||||
|
||||
# Convert LayerNorm random weights to FP32 for numerical stability
|
||||
for name, module in starcoder2.named_modules():
|
||||
if isinstance(module, LayerNorm):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data = module.weight.data.to(torch.float32)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data = module.bias.data.to(torch.float32)
|
||||
|
||||
num_blocks = 1
|
||||
tokens_per_block = 128
|
||||
max_seq_len = num_blocks * tokens_per_block
|
||||
batch_size = 1
|
||||
|
||||
kv_cache_manager = get_kv_cache_manager(
|
||||
dtype=dtype,
|
||||
config=hf_config,
|
||||
tokens_per_block=tokens_per_block,
|
||||
max_seq_len=max_seq_len,
|
||||
batch_size=batch_size,
|
||||
num_blocks=num_blocks,
|
||||
)
|
||||
|
||||
# Context phase (no CUDA graphs for prefill)
|
||||
input_ids = torch.tensor(
|
||||
[100, 200, 300, 400, 500, 600, 700, 800],
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
num_cached_tokens_per_seq = [0]
|
||||
request_ids = [1]
|
||||
token_nums = [input_ids.size(-1)]
|
||||
prompt_lens = [input_ids.size(-1)]
|
||||
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
|
||||
|
||||
attn_metadata = metadata_cls(
|
||||
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
|
||||
num_contexts=1,
|
||||
kv_cache_params=KVCacheParams(
|
||||
use_cache=True,
|
||||
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
|
||||
),
|
||||
max_num_requests=1,
|
||||
max_num_tokens=8192,
|
||||
kv_cache_manager=kv_cache_manager,
|
||||
request_ids=request_ids,
|
||||
prompt_lens=prompt_lens,
|
||||
)
|
||||
|
||||
position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int)]
|
||||
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
|
||||
|
||||
with torch.inference_mode():
|
||||
attn_metadata.prepare()
|
||||
logits = starcoder2.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
ref = hf_starcoder2.forward(
|
||||
input_ids=input_ids.unsqueeze(0),
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.1, rtol=0.1)
|
||||
|
||||
# Generation phase (optionally with CUDA graphs)
|
||||
gen_input_ids = torch.tensor([900], dtype=torch.int, device=device)
|
||||
num_cached_tokens_per_seq = [input_ids.size(-1)]
|
||||
|
||||
attn_metadata = metadata_cls(
|
||||
seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
|
||||
num_contexts=0,
|
||||
kv_cache_params=KVCacheParams(
|
||||
use_cache=True,
|
||||
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
|
||||
),
|
||||
kv_cache_manager=kv_cache_manager,
|
||||
request_ids=request_ids,
|
||||
prompt_lens=prompt_lens,
|
||||
max_num_requests=1,
|
||||
max_num_tokens=8192,
|
||||
)
|
||||
|
||||
gen_position_ids = [
|
||||
torch.arange(
|
||||
input_ids.size(-1), input_ids.size(-1) + gen_input_ids.size(-1), dtype=torch.int
|
||||
)
|
||||
]
|
||||
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
|
||||
|
||||
# Setup CUDA graph runner if requested
|
||||
graph_runner = None
|
||||
if use_cuda_graph:
|
||||
from _torch.helpers import create_mock_cuda_graph_runner
|
||||
|
||||
graph_runner = create_mock_cuda_graph_runner(1)
|
||||
attn_metadata = attn_metadata.create_cuda_graph_metadata(1)
|
||||
|
||||
# Run generation phase
|
||||
with torch.inference_mode():
|
||||
if not use_cuda_graph:
|
||||
attn_metadata.prepare()
|
||||
logits = starcoder2.forward(
|
||||
input_ids=gen_input_ids,
|
||||
position_ids=gen_position_ids,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
else:
|
||||
# CUDA graph path
|
||||
inputs = {
|
||||
"input_ids": gen_input_ids,
|
||||
"position_ids": gen_position_ids,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
key = (1, 0, False)
|
||||
|
||||
attn_metadata.prepare()
|
||||
graph_runner.capture(key, lambda inputs: starcoder2.forward(**inputs), inputs)
|
||||
|
||||
# Run twice to catch buffer reallocation issues
|
||||
for _ in range(2):
|
||||
attn_metadata.prepare()
|
||||
logits = graph_runner.replay(key, inputs)
|
||||
|
||||
# Compare with HuggingFace
|
||||
ref = hf_starcoder2.forward(
|
||||
input_ids=gen_input_ids.unsqueeze(0),
|
||||
position_ids=gen_position_ids,
|
||||
past_key_values=ref.past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.1, rtol=0.1)
|
||||
|
||||
# Cleanup
|
||||
if graph_runner is not None:
|
||||
graph_runner.clear()
|
||||
kv_cache_manager.shutdown()
|
||||
Loading…
Reference in New Issue
Block a user