From a3c0cf02ce6ca4df78f956d14f9a767c9e228dac Mon Sep 17 00:00:00 2001 From: brb-nv <169953907+brb-nv@users.noreply.github.com> Date: Wed, 2 Jul 2025 18:55:25 -0700 Subject: [PATCH] fix: Investigate Gemma3 1B decoder output discrepancy (#5564) Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com> --- .../_torch/attention_backend/vanilla.py | 60 +++- tensorrt_llm/_torch/models/modeling_gemma3.py | 5 +- .../defs/accuracy/test_llm_api_pytorch.py | 7 +- .../_torch/modeling/test_modeling_gemma3.py | 300 ++++++++++++++++++ 4 files changed, 358 insertions(+), 14 deletions(-) create mode 100644 tests/unittest/_torch/modeling/test_modeling_gemma3.py diff --git a/tensorrt_llm/_torch/attention_backend/vanilla.py b/tensorrt_llm/_torch/attention_backend/vanilla.py index 5578ac0280..3397ded646 100644 --- a/tensorrt_llm/_torch/attention_backend/vanilla.py +++ b/tensorrt_llm/_torch/attention_backend/vanilla.py @@ -1,3 +1,4 @@ +import math from typing import Optional import torch @@ -39,6 +40,23 @@ def generate_causal_mask(batch_size: int, target_length: int, return causal_mask +def generate_sliding_window_mask(batch_size: int, target_length: int, + cache_position: torch.Tensor, + device: torch.device, + attention_window_size: int): + # TRTLLM's sliding window attention is inclusive. + effective_window_size = attention_window_size + 1 + attention_mask_1 = torch.arange( + target_length, + device=device).unsqueeze(0) <= cache_position.unsqueeze(-1) + attention_mask_2 = torch.arange(target_length, device=device).unsqueeze( + 0) > cache_position.unsqueeze(-1) - effective_window_size + attention_mask = attention_mask_1 & attention_mask_2 + attention_mask = attention_mask[None, + None, :, :].expand(batch_size, 1, -1, -1) + return attention_mask + + class VanillaAttentionMetadata(AttentionMetadata): def prepare(self) -> None: @@ -66,11 +84,17 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): head_dim: int, num_kv_heads: Optional[int] = None, quant_config: Optional[QuantConfig] = None, + q_scaling: Optional[float] = None, **kwargs, ): - super().__init__(layer_idx, num_heads, head_dim, num_kv_heads, - quant_config, **kwargs) + super().__init__(layer_idx, + num_heads, + head_dim, + num_kv_heads=num_kv_heads, + quant_config=quant_config, + **kwargs) self.num_key_value_groups = self.num_heads // self.num_kv_heads + self.q_scaling = q_scaling def _single_request_update_kv_cache(self, k, v, kv_cache_tensor, seq_len, cache_idx, cache_position): @@ -86,8 +110,15 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): return k_out[:, :seq_len, :, :], v_out[:, :seq_len, :, :] - def _single_request_forward(self, q, k, v, attention_mask: AttentionMask, - kv_cache_tensor, past_seen_token, cache_idx): + def _single_request_forward(self, + q, + k, + v, + attention_mask: AttentionMask, + kv_cache_tensor, + past_seen_token, + cache_idx, + attention_window_size: Optional[int] = None): bsz = 1 q_len = q.size(0) @@ -129,7 +160,12 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): is_causal = False attn_mask = None if attention_mask == PredefinedAttentionMask.CAUSAL: - if past_seen_token == 0: + # Create custom sliding window mask as sdpa doesn't natively support it. 
+ if attention_window_size is not None: + attn_mask = generate_sliding_window_mask( + bsz, target_seq_len, cache_position, q.device, + attention_window_size) + elif past_seen_token == 0: is_causal = True elif q_len != 1: # attn_mask: 4-D tensor (batch_size, 1, query_seq_len, seq_len) @@ -140,12 +176,17 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): else: raise ValueError("Unexpected attention mask type") + qk_scale = None + if self.q_scaling is not None: + qk_scale = 1 / (math.sqrt(self.head_dim) * self.q_scaling) + attn_output = torch.nn.functional.scaled_dot_product_attention( q, key_states, value_states, is_causal=is_causal, attn_mask=attn_mask, + scale=qk_scale, ) attn_output = attn_output.squeeze(0) @@ -229,6 +270,7 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): metadata: VanillaAttentionMetadata, *, attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL, + attention_window_size: Optional[int] = None, **kwargs) -> torch.Tensor: if metadata.kv_cache_manager is None: # NOTE: WAR for no kv cache attn e.g. BERT, @@ -270,11 +312,9 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): past_seen_token = past_seen_tokens[i] cache_idx = cache_indices[i] - attn_output = self._single_request_forward(single_q, single_k, - single_v, attention_mask, - kv_cache_tensor, - past_seen_token, - cache_idx) + attn_output = self._single_request_forward( + single_q, single_k, single_v, attention_mask, kv_cache_tensor, + past_seen_token, cache_idx, attention_window_size) attn_outputs.append(attn_output) offset += seq_len diff --git a/tensorrt_llm/_torch/models/modeling_gemma3.py b/tensorrt_llm/_torch/models/modeling_gemma3.py index 4a55d42714..aa3a5badbb 100644 --- a/tensorrt_llm/_torch/models/modeling_gemma3.py +++ b/tensorrt_llm/_torch/models/modeling_gemma3.py @@ -65,7 +65,7 @@ class Gemma3Attention(Attention): self.attention_window_size = None if is_sliding: rope_params.theta = 10000 - self.attention_window_size = config.sliding_window + self.attention_window_size = config.sliding_window - 1 # Gemma3 sliding window isn't inclusive. pos_embd_params = PositionalEmbeddingParams( type=PositionEmbeddingType.rope_gpt_neox, rope=rope_params, @@ -107,7 +107,6 @@ class Gemma3Attention(Attention): **kwargs, ) -> torch.Tensor: - attention_window_size = self.attention_window_size or attn_metadata.max_seq_len return super().forward(position_ids=position_ids, hidden_states=hidden_states, attn_metadata=attn_metadata, @@ -115,7 +114,7 @@ class Gemma3Attention(Attention): mrope_config=mrope_config, all_reduce_params=all_reduce_params, lora_params=lora_params, - attention_window_size=attention_window_size, + attention_window_size=self.attention_window_size, **kwargs) def apply_qk_norm(self, q, k): diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 97f9b23382..4ffa08e86c 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -437,7 +437,12 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): MODEL_PATH = f"{llm_models_root()}/gemma/gemma-3-1b-it/" def test_auto_dtype(self): - with LLM(self.MODEL_PATH) as llm: + # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size. 
+ kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + enable_partial_reuse=False, + ) + with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config) as llm: task = CnnDailymail(self.MODEL_NAME) task.evaluate(llm) diff --git a/tests/unittest/_torch/modeling/test_modeling_gemma3.py b/tests/unittest/_torch/modeling/test_modeling_gemma3.py new file mode 100644 index 0000000000..03f4d7c8c1 --- /dev/null +++ b/tests/unittest/_torch/modeling/test_modeling_gemma3.py @@ -0,0 +1,300 @@ +import unittest +from copy import deepcopy +from dataclasses import dataclass + +import torch +from parameterized import parameterized +from transformers import Gemma3Config +from transformers import Gemma3ForCausalLM as HFGemma3ForCausalLM +from transformers.cache_utils import HybridCache + +import tensorrt_llm +from tensorrt_llm._torch.attention_backend.utils import get_attention_backend +from tensorrt_llm._torch.metadata import KVCacheParams +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.modeling_gemma3 import Gemma3ForCausalLM +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.mapping import Mapping + +# This is copied from https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json. +# Updated to have 1 local layer and 1 global layer. Sliding window size updated to 4. +GEMMA3_1B_MINI_CONFIG = { + "architectures": ["Gemma3ForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": None, + "bos_token_id": 2, + "cache_implementation": "hybrid", + "eos_token_id": [1, 106], + "final_logit_softcapping": None, + "head_dim": 256, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 1152, + "initializer_range": 0.02, + "intermediate_size": 6912, + "max_position_embeddings": 32768, + "model_type": "gemma3_text", + "num_attention_heads": 4, + "num_hidden_layers": 2, # Modified for testing. + "num_key_value_heads": 1, + "pad_token_id": 0, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_local_base_freq": 10000, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": 4, # Modified for testing. + "sliding_window_pattern": 2, # Modified for testing. 
+ "torch_dtype": "bfloat16", + "transformers_version": "4.50.0.dev0", + "use_cache": True, + "vocab_size": 262144 +} + + +@dataclass(repr=False) +class Scenario: + backend: str + + def __repr__(self) -> str: + return f"backend:{self.backend.lower()}" + + +class TestGemma3(unittest.TestCase): + + def get_kv_cache_manager(self, dtype: torch.dtype, config: Gemma3Config, + tokens_per_block: int, max_seq_len: int, + batch_size: int, num_blocks: int): + if dtype == torch.half: + kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF + elif dtype == torch.bfloat16: + kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16 + else: + raise ValueError("Invalid dtype") + + mapping = Mapping(world_size=1, tp_size=1, rank=0) + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + enable_partial_reuse=False, + copy_on_partial_reuse=False, + max_tokens=num_blocks * + tokens_per_block) + kv_cache_manager = KVCacheManager( + kv_cache_config, + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, + num_layers=config.num_hidden_layers, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + mapping=mapping, + dtype=kv_cache_dtype, + ) + return kv_cache_manager + + def test_gemma3_sanity(self): + + config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG) + gemma3_config = Gemma3Config.from_dict(config_dict) + + dtype = gemma3_config.torch_dtype + device = torch.device('cuda') + + model_config = ModelConfig(pretrained_config=gemma3_config) + gemma3 = Gemma3ForCausalLM(model_config).to(device) + + input_ids = torch.tensor([100, 200, 300, 400, 500, 600, 700, 800], + dtype=torch.int, + device=device) + + context_sequence_lengths = [3, 2, 1] + sequence_lengths = context_sequence_lengths + [1, 1] + past_seen_tokens = [0, 0, 0, 62, 75] + request_ids = list(range(len(sequence_lengths))) + token_nums = (torch.tensor(past_seen_tokens) + + torch.tensor(sequence_lengths)).tolist() + prompt_lens = token_nums[:3] + past_seen_tokens[3:] + + num_blocks = 100 + tokens_per_block = 128 + max_seq_len = num_blocks * tokens_per_block + batch_size = len(context_sequence_lengths) + 2 + kv_cache_manager = self.get_kv_cache_manager( + dtype=dtype, + config=gemma3_config, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + batch_size=batch_size, + num_blocks=num_blocks) + kv_cache_manager.add_dummy_requests(request_ids, token_nums) + + metadata_cls = get_attention_backend(model_config.attn_backend).Metadata + attn_metadata = metadata_cls( + seq_lens=torch.tensor(sequence_lengths, dtype=torch.int), + num_contexts=len(context_sequence_lengths), + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=past_seen_tokens, + ), + kv_cache_manager=kv_cache_manager, + request_ids=request_ids, + prompt_lens=prompt_lens, + max_num_requests=len(context_sequence_lengths) + 2, + max_num_tokens=8192, + ) + + position_ids = [] + for i, tokens in enumerate(past_seen_tokens): + seq_len = context_sequence_lengths[i] if i < len( + context_sequence_lengths) else 1 + position_id = torch.arange(tokens, + tokens + seq_len, + device=input_ids.device) + position_ids.append(position_id) + + position_ids = torch.cat(position_ids).unsqueeze(0) + + with torch.inference_mode(): + attn_metadata.prepare() + logits = gemma3.forward(input_ids=input_ids, + position_ids=position_ids, + attn_metadata=attn_metadata) + + self.assertEqual(len(past_seen_tokens), logits.shape[0]) + + with torch.inference_mode(): + 
attn_metadata.prepare() + logits = gemma3.forward(input_ids=input_ids, + position_ids=position_ids, + attn_metadata=attn_metadata, + return_context_logits=True) + self.assertEqual(input_ids.shape, logits.shape[:-1]) + + kv_cache_manager.shutdown() + + @parameterized.expand([ + Scenario(backend="TRTLLM"), + Scenario(backend="VANILLA"), + ], lambda testcase_func, param_num, param: + f"{testcase_func.__name__}[{param.args[0]}]") + @torch.no_grad() + def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None: + """ + Compare output to HF. + """ + backend = scenario.backend + metadata_cls = get_attention_backend(backend).Metadata + + torch.random.manual_seed(0) + config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG) + gemma3_config = Gemma3Config.from_dict(config_dict) + dtype = gemma3_config.torch_dtype + device = torch.device('cuda') + + num_blocks = 1 + tokens_per_block = 128 + max_seq_len = num_blocks * tokens_per_block + batch_size = 1 + + hf_gemma3 = HFGemma3ForCausalLM(gemma3_config).to(dtype).to( + device).eval() + hf_cache = HybridCache(config=gemma3_config, + max_batch_size=batch_size, + max_cache_len=10, + device=device, + dtype=dtype) + + model_config = ModelConfig(pretrained_config=gemma3_config, + attn_backend=backend) + gemma3 = Gemma3ForCausalLM(model_config).to(dtype).to(device) + gemma3.load_weights(hf_gemma3.state_dict()) + + kv_cache_manager = self.get_kv_cache_manager( + dtype=dtype, + config=gemma3_config, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + batch_size=batch_size, + num_blocks=num_blocks) + + # Context phase. + input_ids = torch.tensor([100, 200, 300, 400, 500, 600, 700, 800], + dtype=torch.int32, + device=device) + num_cached_tokens_per_seq = [0] + request_ids = [1] + token_nums = [input_ids.size(-1)] + prompt_lens = [input_ids.size(-1)] + kv_cache_manager.add_dummy_requests(request_ids, token_nums) + + attn_metadata = metadata_cls( + seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int), + num_contexts=1, + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=num_cached_tokens_per_seq, + ), + max_num_requests=1, + max_num_tokens=8192, + kv_cache_manager=kv_cache_manager, + request_ids=request_ids, + prompt_lens=prompt_lens, + ) + position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int32)] + position_ids = torch.cat(position_ids).unsqueeze(0).cuda() + + with torch.inference_mode(): + attn_metadata.prepare() + logits = gemma3.forward(input_ids=input_ids, + position_ids=position_ids, + attn_metadata=attn_metadata) + ref = hf_gemma3.forward(input_ids=input_ids.unsqueeze(0), + position_ids=position_ids, + past_key_values=hf_cache, + use_cache=True) + torch.testing.assert_close(logits, + ref.logits[:, -1].float(), + atol=0.05, + rtol=0.05) + + # Generation phase. 
+ gen_input_ids = torch.tensor([900], dtype=torch.int, device=device) + num_cached_tokens_per_seq = [input_ids.size(-1)] + attn_metadata = metadata_cls( + seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int), + num_contexts=0, + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=num_cached_tokens_per_seq, + ), + kv_cache_manager=kv_cache_manager, + request_ids=request_ids, + prompt_lens=prompt_lens, + max_num_requests=1, + max_num_tokens=8192, + ) + + gen_position_ids = [ + torch.arange(input_ids.size(-1), + input_ids.size(-1) + gen_input_ids.size(-1)) + ] + gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + with torch.inference_mode(): + attn_metadata.prepare() + logits = gemma3.forward(input_ids=gen_input_ids, + position_ids=gen_position_ids, + attn_metadata=attn_metadata) + ref = hf_gemma3.forward(input_ids=gen_input_ids.unsqueeze(0), + position_ids=gen_position_ids, + past_key_values=hf_cache, + use_cache=True, + cache_position=torch.IntTensor( + [input_ids.size(-1)]).to(device), + last_cache_position=input_ids.size(-1) + 1) + torch.testing.assert_close(logits, + ref.logits[:, -1].float(), + atol=0.05, + rtol=0.05) + + kv_cache_manager.shutdown()
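
Note: the standalone sketch below is not part of the patch; it restates the window-size convention the diff reconciles, as implied by its arithmetic. TRT-LLM's attention_window_size counts the predecessors a query may attend to in addition to itself, while Gemma3/HF's sliding_window counts the full span including the current token, hence the "- 1" in Gemma3Attention and the "+ 1" in generate_sliding_window_mask(). The helper name and the toy sizes used here are illustrative only.

import math

import torch


def sliding_window_mask(target_length: int, cache_position: torch.Tensor,
                        attention_window_size: int) -> torch.Tensor:
    # Mirrors generate_sliding_window_mask() above with the batch/head dims
    # dropped; True marks key positions a query is allowed to attend to.
    effective_window_size = attention_window_size + 1  # window + current token
    keys = torch.arange(target_length).unsqueeze(0)    # (1, target_length)
    queries = cache_position.unsqueeze(-1)             # (q_len, 1)
    causal = keys <= queries
    in_window = keys > queries - effective_window_size
    return causal & in_window


if __name__ == "__main__":
    hf_sliding_window = 4                  # Gemma3-style span, current token included
    trtllm_window = hf_sliding_window - 1  # conversion done in Gemma3Attention
    q_len = 6
    mask = sliding_window_mask(q_len, torch.arange(q_len), trtllm_window)
    # Every query sees at most hf_sliding_window keys (itself plus 3 predecessors).
    assert int(mask.sum(dim=-1).max()) == hf_sliding_window
    print(mask.int())

    # The q_scaling hook maps onto SDPA's scale argument as in VanillaAttention.
    head_dim, q_scaling = 256, 1.0
    print("sdpa scale =", 1.0 / (math.sqrt(head_dim) * q_scaling))

Running the sketch prints a lower-triangular band of width 4, which is the shape the vanilla backend now feeds to scaled_dot_product_attention in place of its unsupported native sliding-window path.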