fix: Investigate Gemma3 1B decoder output discrepancy (#5564)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
brb-nv 2025-07-02 18:55:25 -07:00 committed by GitHub
parent 92d3a2d0e0
commit a3c0cf02ce
4 changed files with 358 additions and 14 deletions

View File

@@ -1,3 +1,4 @@
import math
from typing import Optional
import torch
@@ -39,6 +40,23 @@ def generate_causal_mask(batch_size: int, target_length: int,
return causal_mask
def generate_sliding_window_mask(batch_size: int, target_length: int,
cache_position: torch.Tensor,
device: torch.device,
attention_window_size: int):
# TRTLLM's sliding window attention is inclusive.
effective_window_size = attention_window_size + 1
attention_mask_1 = torch.arange(
target_length,
device=device).unsqueeze(0) <= cache_position.unsqueeze(-1)
attention_mask_2 = torch.arange(target_length, device=device).unsqueeze(
0) > cache_position.unsqueeze(-1) - effective_window_size
attention_mask = attention_mask_1 & attention_mask_2
attention_mask = attention_mask[None,
None, :, :].expand(batch_size, 1, -1, -1)
return attention_mask
class VanillaAttentionMetadata(AttentionMetadata):
def prepare(self) -> None:
@@ -66,11 +84,17 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
head_dim: int,
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
q_scaling: Optional[float] = None,
**kwargs,
):
- super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
- quant_config, **kwargs)
+ super().__init__(layer_idx,
+ num_heads,
+ head_dim,
+ num_kv_heads=num_kv_heads,
+ quant_config=quant_config,
+ **kwargs)
self.num_key_value_groups = self.num_heads // self.num_kv_heads
self.q_scaling = q_scaling
def _single_request_update_kv_cache(self, k, v, kv_cache_tensor, seq_len,
cache_idx, cache_position):
@@ -86,8 +110,15 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
return k_out[:, :seq_len, :, :], v_out[:, :seq_len, :, :]
- def _single_request_forward(self, q, k, v, attention_mask: AttentionMask,
- kv_cache_tensor, past_seen_token, cache_idx):
+ def _single_request_forward(self,
+ q,
+ k,
+ v,
+ attention_mask: AttentionMask,
+ kv_cache_tensor,
+ past_seen_token,
+ cache_idx,
+ attention_window_size: Optional[int] = None):
bsz = 1
q_len = q.size(0)
@@ -129,7 +160,12 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
is_causal = False
attn_mask = None
if attention_mask == PredefinedAttentionMask.CAUSAL:
- if past_seen_token == 0:
+ # Create custom sliding window mask as sdpa doesn't natively support it.
+ if attention_window_size is not None:
+ attn_mask = generate_sliding_window_mask(
+ bsz, target_seq_len, cache_position, q.device,
+ attention_window_size)
+ elif past_seen_token == 0:
is_causal = True
elif q_len != 1:
# attn_mask: 4-D tensor (batch_size, 1, query_seq_len, seq_len)
@@ -140,12 +176,17 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
else:
raise ValueError("Unexpected attention mask type")
qk_scale = None
if self.q_scaling is not None:
qk_scale = 1 / (math.sqrt(self.head_dim) * self.q_scaling)
attn_output = torch.nn.functional.scaled_dot_product_attention(
q,
key_states,
value_states,
is_causal=is_causal,
attn_mask=attn_mask,
scale=qk_scale,
)
attn_output = attn_output.squeeze(0)
@@ -229,6 +270,7 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
metadata: VanillaAttentionMetadata,
*,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
attention_window_size: Optional[int] = None,
**kwargs) -> torch.Tensor:
if metadata.kv_cache_manager is None:
# NOTE: WAR for no kv cache attn e.g. BERT,
@@ -270,11 +312,9 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
past_seen_token = past_seen_tokens[i]
cache_idx = cache_indices[i]
- attn_output = self._single_request_forward(single_q, single_k,
- single_v, attention_mask,
- kv_cache_tensor,
- past_seen_token,
- cache_idx)
+ attn_output = self._single_request_forward(
+ single_q, single_k, single_v, attention_mask, kv_cache_tensor,
+ past_seen_token, cache_idx, attention_window_size)
attn_outputs.append(attn_output)
offset += seq_len
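
For intuition, here is a minimal standalone sketch (illustrative only, not part of the diff) that re-derives the same boolean mask math as the generate_sliding_window_mask helper added above. With attention_window_size=2, each query position attends to itself plus the two most recent previous positions, i.e. attention_window_size + 1 keys in total:

import torch

def toy_sliding_window_mask(target_length: int, cache_position: torch.Tensor,
                            window: int) -> torch.Tensor:
    # Same boolean math as generate_sliding_window_mask above: key position j is
    # visible to query position i when j <= i and j > i - (window + 1), i.e. the
    # current token plus `window` previous tokens.
    cols = torch.arange(target_length).unsqueeze(0)
    rows = cache_position.unsqueeze(-1)
    return (cols <= rows) & (cols > rows - (window + 1))

print(toy_sliding_window_mask(6, torch.arange(6), window=2).int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])

The q_scaling hook added in the same file is orthogonal: when set, it overrides SDPA's default 1/sqrt(head_dim) scale with 1 / (sqrt(head_dim) * q_scaling).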

View File

@@ -65,7 +65,7 @@ class Gemma3Attention(Attention):
self.attention_window_size = None
if is_sliding:
rope_params.theta = 10000
- self.attention_window_size = config.sliding_window
+ self.attention_window_size = config.sliding_window - 1 # Gemma3 sliding window isn't inclusive.
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=rope_params,
@@ -107,7 +107,6 @@ class Gemma3Attention(Attention):
**kwargs,
) -> torch.Tensor:
- attention_window_size = self.attention_window_size or attn_metadata.max_seq_len
return super().forward(position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
@@ -115,7 +114,7 @@ class Gemma3Attention(Attention):
mrope_config=mrope_config,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
- attention_window_size=attention_window_size,
+ attention_window_size=self.attention_window_size,
**kwargs)
def apply_qk_norm(self, q, k):
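
A quick sanity check on the off-by-one handled here (illustrative arithmetic only, not part of the change; the names mirror the diff). Per the comment above, Gemma3's sliding_window is not inclusive, while the vanilla backend's mask is, so the two conventions line up only after subtracting one:

sliding_window = 4                          # Gemma3 config value; "isn't inclusive" per the comment above
attention_window_size = sliding_window - 1  # what Gemma3Attention now passes to the attention backend
keys_per_query = attention_window_size + 1  # positions admitted by generate_sliding_window_mask
assert keys_per_query == sliding_window     # both conventions describe the same 4-token span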

View File

@@ -437,7 +437,12 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
MODEL_PATH = f"{llm_models_root()}/gemma/gemma-3-1b-it/"
def test_auto_dtype(self):
- with LLM(self.MODEL_PATH) as llm:
+ # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+ kv_cache_config = KvCacheConfig(
+ enable_block_reuse=False,
+ enable_partial_reuse=False,
+ )
+ with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
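
Outside the accuracy harness, the same WAR would look roughly like the sketch below (illustrative only, not part of the diff; it assumes the usual tensorrt_llm.LLM / tensorrt_llm.llmapi.KvCacheConfig entry points, and the model path is a placeholder):

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Same WAR as the test above: disable block reuse so Gemma3's non-inclusive
# sliding window is not broken by reused cache blocks.
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                enable_partial_reuse=False)

# "/path/to/gemma-3-1b-it" is a placeholder for a local Gemma3 1B checkpoint.
with LLM("/path/to/gemma-3-1b-it", kv_cache_config=kv_cache_config) as llm:
    output = llm.generate(["The capital of France is"])[0]
    print(output.outputs[0].text)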

View File

@@ -0,0 +1,300 @@
import unittest
from copy import deepcopy
from dataclasses import dataclass
import torch
from parameterized import parameterized
from transformers import Gemma3Config
from transformers import Gemma3ForCausalLM as HFGemma3ForCausalLM
from transformers.cache_utils import HybridCache
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_gemma3 import Gemma3ForCausalLM
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
# This is copied from https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json.
# Updated to have 1 local layer and 1 global layer. Sliding window size updated to 4.
GEMMA3_1B_MINI_CONFIG = {
"architectures": ["Gemma3ForCausalLM"],
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": None,
"bos_token_id": 2,
"cache_implementation": "hybrid",
"eos_token_id": [1, 106],
"final_logit_softcapping": None,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 1152,
"initializer_range": 0.02,
"intermediate_size": 6912,
"max_position_embeddings": 32768,
"model_type": "gemma3_text",
"num_attention_heads": 4,
"num_hidden_layers": 2, # Modified for testing.
"num_key_value_heads": 1,
"pad_token_id": 0,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": 4, # Modified for testing.
"sliding_window_pattern": 2, # Modified for testing.
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0.dev0",
"use_cache": True,
"vocab_size": 262144
}
@dataclass(repr=False)
class Scenario:
backend: str
def __repr__(self) -> str:
return f"backend:{self.backend.lower()}"
class TestGemma3(unittest.TestCase):
def get_kv_cache_manager(self, dtype: torch.dtype, config: Gemma3Config,
tokens_per_block: int, max_seq_len: int,
batch_size: int, num_blocks: int):
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
enable_partial_reuse=False,
copy_on_partial_reuse=False,
max_tokens=num_blocks *
tokens_per_block)
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=config.num_hidden_layers,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
return kv_cache_manager
def test_gemma3_sanity(self):
config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
gemma3_config = Gemma3Config.from_dict(config_dict)
dtype = gemma3_config.torch_dtype
device = torch.device('cuda')
model_config = ModelConfig(pretrained_config=gemma3_config)
gemma3 = Gemma3ForCausalLM(model_config).to(device)
input_ids = torch.tensor([100, 200, 300, 400, 500, 600, 700, 800],
dtype=torch.int,
device=device)
context_sequence_lengths = [3, 2, 1]
sequence_lengths = context_sequence_lengths + [1, 1]
past_seen_tokens = [0, 0, 0, 62, 75]
request_ids = list(range(len(sequence_lengths)))
token_nums = (torch.tensor(past_seen_tokens) +
torch.tensor(sequence_lengths)).tolist()
prompt_lens = token_nums[:3] + past_seen_tokens[3:]
num_blocks = 100
tokens_per_block = 128
max_seq_len = num_blocks * tokens_per_block
batch_size = len(context_sequence_lengths) + 2
kv_cache_manager = self.get_kv_cache_manager(
dtype=dtype,
config=gemma3_config,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
batch_size=batch_size,
num_blocks=num_blocks)
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
attn_metadata = metadata_cls(
seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
num_contexts=len(context_sequence_lengths),
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=past_seen_tokens,
),
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
max_num_requests=len(context_sequence_lengths) + 2,
max_num_tokens=8192,
)
position_ids = []
for i, tokens in enumerate(past_seen_tokens):
seq_len = context_sequence_lengths[i] if i < len(
context_sequence_lengths) else 1
position_id = torch.arange(tokens,
tokens + seq_len,
device=input_ids.device)
position_ids.append(position_id)
position_ids = torch.cat(position_ids).unsqueeze(0)
with torch.inference_mode():
attn_metadata.prepare()
logits = gemma3.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata)
self.assertEqual(len(past_seen_tokens), logits.shape[0])
with torch.inference_mode():
attn_metadata.prepare()
logits = gemma3.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
return_context_logits=True)
self.assertEqual(input_ids.shape, logits.shape[:-1])
kv_cache_manager.shutdown()
@parameterized.expand([
Scenario(backend="TRTLLM"),
Scenario(backend="VANILLA"),
], lambda testcase_func, param_num, param:
f"{testcase_func.__name__}[{param.args[0]}]")
@torch.no_grad()
def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
"""
Compare output to HF.
"""
backend = scenario.backend
metadata_cls = get_attention_backend(backend).Metadata
torch.random.manual_seed(0)
config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
gemma3_config = Gemma3Config.from_dict(config_dict)
dtype = gemma3_config.torch_dtype
device = torch.device('cuda')
num_blocks = 1
tokens_per_block = 128
max_seq_len = num_blocks * tokens_per_block
batch_size = 1
hf_gemma3 = HFGemma3ForCausalLM(gemma3_config).to(dtype).to(
device).eval()
hf_cache = HybridCache(config=gemma3_config,
max_batch_size=batch_size,
max_cache_len=10,
device=device,
dtype=dtype)
model_config = ModelConfig(pretrained_config=gemma3_config,
attn_backend=backend)
gemma3 = Gemma3ForCausalLM(model_config).to(dtype).to(device)
gemma3.load_weights(hf_gemma3.state_dict())
kv_cache_manager = self.get_kv_cache_manager(
dtype=dtype,
config=gemma3_config,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
batch_size=batch_size,
num_blocks=num_blocks)
# Context phase.
input_ids = torch.tensor([100, 200, 300, 400, 500, 600, 700, 800],
dtype=torch.int32,
device=device)
num_cached_tokens_per_seq = [0]
request_ids = [1]
token_nums = [input_ids.size(-1)]
prompt_lens = [input_ids.size(-1)]
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
attn_metadata = metadata_cls(
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
num_contexts=1,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int32)]
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
with torch.inference_mode():
attn_metadata.prepare()
logits = gemma3.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata)
ref = hf_gemma3.forward(input_ids=input_ids.unsqueeze(0),
position_ids=position_ids,
past_key_values=hf_cache,
use_cache=True)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.05,
rtol=0.05)
# Generation phase.
gen_input_ids = torch.tensor([900], dtype=torch.int, device=device)
num_cached_tokens_per_seq = [input_ids.size(-1)]
attn_metadata = metadata_cls(
seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
num_contexts=0,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
max_num_requests=1,
max_num_tokens=8192,
)
gen_position_ids = [
torch.arange(input_ids.size(-1),
input_ids.size(-1) + gen_input_ids.size(-1))
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
with torch.inference_mode():
attn_metadata.prepare()
logits = gemma3.forward(input_ids=gen_input_ids,
position_ids=gen_position_ids,
attn_metadata=attn_metadata)
ref = hf_gemma3.forward(input_ids=gen_input_ids.unsqueeze(0),
position_ids=gen_position_ids,
past_key_values=hf_cache,
use_cache=True,
cache_position=torch.IntTensor(
[input_ids.size(-1)]).to(device),
last_cache_position=input_ids.size(-1) + 1)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.05,
rtol=0.05)
kv_cache_manager.shutdown()
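
One detail of GEMMA3_1B_MINI_CONFIG worth spelling out: why sliding_window_pattern = 2 with two hidden layers yields exactly one local and one global layer, as the config comment states. The sketch below is illustrative only and assumes HF's Gemma3 convention that a layer uses sliding-window attention when bool((layer_idx + 1) % sliding_window_pattern) is truthy; that rule is an assumption about the HF layer layout, not something this test asserts.

# Illustrative only; see the assumption stated above.
sliding_window_pattern = 2
num_hidden_layers = 2
layer_kinds = [
    "sliding" if (layer_idx + 1) % sliding_window_pattern else "global"
    for layer_idx in range(num_hidden_layers)
]
print(layer_kinds)  # ['sliding', 'global'] -> one local layer, one global layer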