fix: Investigate Gemma3 1B decoder output discrepancy (#5564)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in: parent 92d3a2d0e0, commit a3c0cf02ce
@@ -1,3 +1,4 @@
+import math
 from typing import Optional
 
 import torch
@@ -39,6 +40,23 @@ def generate_causal_mask(batch_size: int, target_length: int,
     return causal_mask
 
 
+def generate_sliding_window_mask(batch_size: int, target_length: int,
+                                 cache_position: torch.Tensor,
+                                 device: torch.device,
+                                 attention_window_size: int):
+    # TRTLLM's sliding window attention is inclusive.
+    effective_window_size = attention_window_size + 1
+    attention_mask_1 = torch.arange(
+        target_length,
+        device=device).unsqueeze(0) <= cache_position.unsqueeze(-1)
+    attention_mask_2 = torch.arange(target_length, device=device).unsqueeze(
+        0) > cache_position.unsqueeze(-1) - effective_window_size
+    attention_mask = attention_mask_1 & attention_mask_2
+    attention_mask = attention_mask[None,
+                                    None, :, :].expand(batch_size, 1, -1, -1)
+    return attention_mask
+
+
 class VanillaAttentionMetadata(AttentionMetadata):
 
     def prepare(self) -> None:
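A minimal standalone sketch (illustrative only, not part of this diff) of what the new helper produces: for a prefill of length 5 with attention_window_size=2, each query position attends to itself plus at most the 2 previous positions, matching the inclusive-window comment above.

import torch

def generate_sliding_window_mask(batch_size, target_length, cache_position,
                                 device, attention_window_size):
    # Same logic as the helper added above, without type annotations.
    effective_window_size = attention_window_size + 1
    mask_1 = torch.arange(
        target_length,
        device=device).unsqueeze(0) <= cache_position.unsqueeze(-1)
    mask_2 = torch.arange(target_length, device=device).unsqueeze(
        0) > cache_position.unsqueeze(-1) - effective_window_size
    mask = mask_1 & mask_2
    return mask[None, None, :, :].expand(batch_size, 1, -1, -1)

cache_position = torch.arange(5)  # query positions 0..4 during prefill
mask = generate_sliding_window_mask(1, 5, cache_position,
                                    torch.device("cpu"), 2)
print(mask[0, 0].int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)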
@@ -66,11 +84,17 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        q_scaling: Optional[float] = None,
         **kwargs,
     ):
-        super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
-                         quant_config, **kwargs)
+        super().__init__(layer_idx,
+                         num_heads,
+                         head_dim,
+                         num_kv_heads=num_kv_heads,
+                         quant_config=quant_config,
+                         **kwargs)
         self.num_key_value_groups = self.num_heads // self.num_kv_heads
+        self.q_scaling = q_scaling
 
     def _single_request_update_kv_cache(self, k, v, kv_cache_tensor, seq_len,
                                         cache_idx, cache_position):
@@ -86,8 +110,15 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
 
         return k_out[:, :seq_len, :, :], v_out[:, :seq_len, :, :]
 
-    def _single_request_forward(self, q, k, v, attention_mask: AttentionMask,
-                                kv_cache_tensor, past_seen_token, cache_idx):
+    def _single_request_forward(self,
+                                q,
+                                k,
+                                v,
+                                attention_mask: AttentionMask,
+                                kv_cache_tensor,
+                                past_seen_token,
+                                cache_idx,
+                                attention_window_size: Optional[int] = None):
 
         bsz = 1
         q_len = q.size(0)
@@ -129,7 +160,12 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
         is_causal = False
         attn_mask = None
         if attention_mask == PredefinedAttentionMask.CAUSAL:
-            if past_seen_token == 0:
+            # Create custom sliding window mask as sdpa doesn't natively support it.
+            if attention_window_size is not None:
+                attn_mask = generate_sliding_window_mask(
+                    bsz, target_seq_len, cache_position, q.device,
+                    attention_window_size)
+            elif past_seen_token == 0:
                 is_causal = True
             elif q_len != 1:
                 # attn_mask: 4-D tensor (batch_size, 1, query_seq_len, seq_len)
@@ -140,12 +176,17 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
         else:
             raise ValueError("Unexpected attention mask type")
 
+        qk_scale = None
+        if self.q_scaling is not None:
+            qk_scale = 1 / (math.sqrt(self.head_dim) * self.q_scaling)
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             q,
             key_states,
             value_states,
             is_causal=is_causal,
             attn_mask=attn_mask,
+            scale=qk_scale,
         )
 
         attn_output = attn_output.squeeze(0)
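For reference (an illustrative sketch, not part of this diff): VanillaAttention now derives the SDPA scale as 1 / (sqrt(head_dim) * q_scaling) whenever q_scaling is set. Assuming a caller wants to reproduce the 1 / sqrt(query_pre_attn_scalar) scaling used by HF Gemma3 (query_pre_attn_scalar is 256 in the mini config of the new unit test below; the mapping itself is an assumption, not shown in this diff), it would pass q_scaling = sqrt(query_pre_attn_scalar / head_dim):

import math

head_dim = 256
query_pre_attn_scalar = 256  # value from GEMMA3_1B_MINI_CONFIG in the new test below

q_scaling = math.sqrt(query_pre_attn_scalar / head_dim)
qk_scale = 1 / (math.sqrt(head_dim) * q_scaling)  # what the new code passes to SDPA
assert math.isclose(qk_scale, query_pre_attn_scalar**-0.5)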
@@ -229,6 +270,7 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
                 metadata: VanillaAttentionMetadata,
                 *,
                 attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
+                attention_window_size: Optional[int] = None,
                 **kwargs) -> torch.Tensor:
         if metadata.kv_cache_manager is None:
             # NOTE: WAR for no kv cache attn e.g. BERT,
@@ -270,11 +312,9 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
             past_seen_token = past_seen_tokens[i]
             cache_idx = cache_indices[i]
 
-            attn_output = self._single_request_forward(single_q, single_k,
-                                                       single_v, attention_mask,
-                                                       kv_cache_tensor,
-                                                       past_seen_token,
-                                                       cache_idx)
+            attn_output = self._single_request_forward(
+                single_q, single_k, single_v, attention_mask, kv_cache_tensor,
+                past_seen_token, cache_idx, attention_window_size)
             attn_outputs.append(attn_output)
 
             offset += seq_len
@@ -65,7 +65,7 @@ class Gemma3Attention(Attention):
         self.attention_window_size = None
         if is_sliding:
             rope_params.theta = 10000
-            self.attention_window_size = config.sliding_window
+            self.attention_window_size = config.sliding_window - 1  # Gemma3 sliding window isn't inclusive.
         pos_embd_params = PositionalEmbeddingParams(
             type=PositionEmbeddingType.rope_gpt_neox,
             rope=rope_params,
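Reading the two comments together ("TRTLLM's sliding window attention is inclusive" in the backend hunk above vs. "Gemma3 sliding window isn't inclusive" here): TRT-LLM's attention_window_size counts only the previous tokens a query may see, while Gemma3's config.sliding_window also counts the current token. The -1 applied here and the +1 applied in effective_window_size therefore cancel, so both sides attend to the same span. A small sanity sketch (illustrative only, not part of this diff):

sliding_window = 4                                  # HF config value, as in the new unit test
attention_window_size = sliding_window - 1          # what Gemma3Attention now passes down
effective_window_size = attention_window_size + 1   # what the vanilla backend uses internally
assert effective_window_size == sliding_window      # both attend to 4 positions, current token included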
@@ -107,7 +107,6 @@ class Gemma3Attention(Attention):
         **kwargs,
     ) -> torch.Tensor:
 
-        attention_window_size = self.attention_window_size or attn_metadata.max_seq_len
         return super().forward(position_ids=position_ids,
                                hidden_states=hidden_states,
                                attn_metadata=attn_metadata,
@@ -115,7 +114,7 @@ class Gemma3Attention(Attention):
                                mrope_config=mrope_config,
                                all_reduce_params=all_reduce_params,
                                lora_params=lora_params,
-                               attention_window_size=attention_window_size,
+                               attention_window_size=self.attention_window_size,
                                **kwargs)
 
     def apply_qk_norm(self, q, k):
@@ -437,7 +437,12 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
     MODEL_PATH = f"{llm_models_root()}/gemma/gemma-3-1b-it/"
 
     def test_auto_dtype(self):
-        with LLM(self.MODEL_PATH) as llm:
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(
+            enable_block_reuse=False,
+            enable_partial_reuse=False,
+        )
+        with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config) as llm:
             task = CnnDailymail(self.MODEL_NAME)
             task.evaluate(llm)
 

tests/unittest/_torch/modeling/test_modeling_gemma3.py (new file, 300 lines added)
@@ -0,0 +1,300 @@
import unittest
from copy import deepcopy
from dataclasses import dataclass

import torch
from parameterized import parameterized
from transformers import Gemma3Config
from transformers import Gemma3ForCausalLM as HFGemma3ForCausalLM
from transformers.cache_utils import HybridCache

import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_gemma3 import Gemma3ForCausalLM
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping

# This is copied from https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json.
# Updated to have 1 local layer and 1 global layer. Sliding window size updated to 4.
GEMMA3_1B_MINI_CONFIG = {
    "architectures": ["Gemma3ForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "attn_logit_softcapping": None,
    "bos_token_id": 2,
    "cache_implementation": "hybrid",
    "eos_token_id": [1, 106],
    "final_logit_softcapping": None,
    "head_dim": 256,
    "hidden_activation": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "initializer_range": 0.02,
    "intermediate_size": 6912,
    "max_position_embeddings": 32768,
    "model_type": "gemma3_text",
    "num_attention_heads": 4,
    "num_hidden_layers": 2,  # Modified for testing.
    "num_key_value_heads": 1,
    "pad_token_id": 0,
    "query_pre_attn_scalar": 256,
    "rms_norm_eps": 1e-06,
    "rope_local_base_freq": 10000,
    "rope_scaling": None,
    "rope_theta": 1000000,
    "sliding_window": 4,  # Modified for testing.
    "sliding_window_pattern": 2,  # Modified for testing.
    "torch_dtype": "bfloat16",
    "transformers_version": "4.50.0.dev0",
    "use_cache": True,
    "vocab_size": 262144
}

@dataclass(repr=False)
class Scenario:
    backend: str

    def __repr__(self) -> str:
        return f"backend:{self.backend.lower()}"


class TestGemma3(unittest.TestCase):

    def get_kv_cache_manager(self, dtype: torch.dtype, config: Gemma3Config,
                             tokens_per_block: int, max_seq_len: int,
                             batch_size: int, num_blocks: int):
        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype")

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                        enable_partial_reuse=False,
                                        copy_on_partial_reuse=False,
                                        max_tokens=num_blocks *
                                        tokens_per_block)
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=config.num_hidden_layers,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            max_batch_size=batch_size,
            mapping=mapping,
            dtype=kv_cache_dtype,
        )
        return kv_cache_manager

    def test_gemma3_sanity(self):

        config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
        gemma3_config = Gemma3Config.from_dict(config_dict)

        dtype = gemma3_config.torch_dtype
        device = torch.device('cuda')

        model_config = ModelConfig(pretrained_config=gemma3_config)
        gemma3 = Gemma3ForCausalLM(model_config).to(device)

        input_ids = torch.tensor([100, 200, 300, 400, 500, 600, 700, 800],
                                 dtype=torch.int,
                                 device=device)

        context_sequence_lengths = [3, 2, 1]
        sequence_lengths = context_sequence_lengths + [1, 1]
        past_seen_tokens = [0, 0, 0, 62, 75]
        request_ids = list(range(len(sequence_lengths)))
        token_nums = (torch.tensor(past_seen_tokens) +
                      torch.tensor(sequence_lengths)).tolist()
        prompt_lens = token_nums[:3] + past_seen_tokens[3:]

        num_blocks = 100
        tokens_per_block = 128
        max_seq_len = num_blocks * tokens_per_block
        batch_size = len(context_sequence_lengths) + 2
        kv_cache_manager = self.get_kv_cache_manager(
            dtype=dtype,
            config=gemma3_config,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            num_blocks=num_blocks)
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
            num_contexts=len(context_sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=len(context_sequence_lengths) + 2,
            max_num_tokens=8192,
        )

        position_ids = []
        for i, tokens in enumerate(past_seen_tokens):
            seq_len = context_sequence_lengths[i] if i < len(
                context_sequence_lengths) else 1
            position_id = torch.arange(tokens,
                                       tokens + seq_len,
                                       device=input_ids.device)
            position_ids.append(position_id)

        position_ids = torch.cat(position_ids).unsqueeze(0)

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = gemma3.forward(input_ids=input_ids,
                                    position_ids=position_ids,
                                    attn_metadata=attn_metadata)

        self.assertEqual(len(past_seen_tokens), logits.shape[0])

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = gemma3.forward(input_ids=input_ids,
                                    position_ids=position_ids,
                                    attn_metadata=attn_metadata,
                                    return_context_logits=True)
        self.assertEqual(input_ids.shape, logits.shape[:-1])

        kv_cache_manager.shutdown()

    @parameterized.expand([
        Scenario(backend="TRTLLM"),
        Scenario(backend="VANILLA"),
    ], lambda testcase_func, param_num, param:
       f"{testcase_func.__name__}[{param.args[0]}]")
    @torch.no_grad()
    def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
        """
        Compare output to HF.
        """
        backend = scenario.backend
        metadata_cls = get_attention_backend(backend).Metadata

        torch.random.manual_seed(0)
        config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
        gemma3_config = Gemma3Config.from_dict(config_dict)
        dtype = gemma3_config.torch_dtype
        device = torch.device('cuda')

        num_blocks = 1
        tokens_per_block = 128
        max_seq_len = num_blocks * tokens_per_block
        batch_size = 1

        hf_gemma3 = HFGemma3ForCausalLM(gemma3_config).to(dtype).to(
            device).eval()
        hf_cache = HybridCache(config=gemma3_config,
                               max_batch_size=batch_size,
                               max_cache_len=10,
                               device=device,
                               dtype=dtype)

        model_config = ModelConfig(pretrained_config=gemma3_config,
                                   attn_backend=backend)
        gemma3 = Gemma3ForCausalLM(model_config).to(dtype).to(device)
        gemma3.load_weights(hf_gemma3.state_dict())

        kv_cache_manager = self.get_kv_cache_manager(
            dtype=dtype,
            config=gemma3_config,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            num_blocks=num_blocks)

        # Context phase.
        input_ids = torch.tensor([100, 200, 300, 400, 500, 600, 700, 800],
                                 dtype=torch.int32,
                                 device=device)
        num_cached_tokens_per_seq = [0]
        request_ids = [1]
        token_nums = [input_ids.size(-1)]
        prompt_lens = [input_ids.size(-1)]
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
            num_contexts=1,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            max_num_requests=1,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
        )
        position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int32)]
        position_ids = torch.cat(position_ids).unsqueeze(0).cuda()

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = gemma3.forward(input_ids=input_ids,
                                    position_ids=position_ids,
                                    attn_metadata=attn_metadata)
            ref = hf_gemma3.forward(input_ids=input_ids.unsqueeze(0),
                                    position_ids=position_ids,
                                    past_key_values=hf_cache,
                                    use_cache=True)
            torch.testing.assert_close(logits,
                                       ref.logits[:, -1].float(),
                                       atol=0.05,
                                       rtol=0.05)

        # Generation phase.
        gen_input_ids = torch.tensor([900], dtype=torch.int, device=device)
        num_cached_tokens_per_seq = [input_ids.size(-1)]
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
            num_contexts=0,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=1,
            max_num_tokens=8192,
        )

        gen_position_ids = [
            torch.arange(input_ids.size(-1),
                         input_ids.size(-1) + gen_input_ids.size(-1))
        ]
        gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = gemma3.forward(input_ids=gen_input_ids,
                                    position_ids=gen_position_ids,
                                    attn_metadata=attn_metadata)
            ref = hf_gemma3.forward(input_ids=gen_input_ids.unsqueeze(0),
                                    position_ids=gen_position_ids,
                                    past_key_values=hf_cache,
                                    use_cache=True,
                                    cache_position=torch.IntTensor(
                                        [input_ids.size(-1)]).to(device),
                                    last_cache_position=input_ids.size(-1) + 1)
            torch.testing.assert_close(logits,
                                       ref.logits[:, -1].float(),
                                       atol=0.05,
                                       rtol=0.05)

        kv_cache_manager.shutdown()
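A note on the mini config used by this new test (illustrative sketch, not part of the file above): it sets num_hidden_layers=2 and sliding_window_pattern=2 so that the test covers one local (sliding-window) layer and one global layer. Assuming HF's Gemma3 convention that a layer is sliding when (layer_idx + 1) % sliding_window_pattern != 0 (an assumption about the upstream modeling code, not shown in this diff), the layout works out as follows:

num_hidden_layers = 2
sliding_window_pattern = 2
layer_types = [
    "local" if (layer_idx + 1) % sliding_window_pattern else "global"
    for layer_idx in range(num_hidden_layers)
]
print(layer_types)  # ['local', 'global'] -- matches the "1 local layer and 1 global layer" comment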