test: Fix Gemma3 unit tests due to transformers upgrade (#5921)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
brb-nv 2025-07-10 17:24:10 -07:00 committed by GitHub
parent 854655f2f7
commit 0385f89abc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 39 additions and 21 deletions

View File

@ -192,7 +192,7 @@ class Gemma3DecoderLayer(DecoderLayer):
super().__init__()
self.layer_idx = layer_idx
config = model_config.pretrained_config
is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
is_sliding = (config.layer_types[layer_idx] == "sliding_attention")
self.self_attn = Gemma3Attention(
model_config,
layer_idx=layer_idx,

View File

@ -1964,6 +1964,7 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):
("qwen2-vl-7b-instruct", "Qwen2-VL-7B-Instruct"),
("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct"),
("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
])
def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
modality, use_cuda_graph):
@ -2064,6 +2065,13 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
["highway", "traffic", "directions", "lanes", "Jurong"],
],
},
"gemma-3-27b-it": {
"image": [
["dramatic", "turbulent", "waves", "ocean", "overcast"],
["half", "dome", "yosemite", "landmark", "rounded"],
["flowing", "standstill", "vehicles", "road", "Changi"],
],
},
}
cmd = [
@ -2083,6 +2091,14 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
cmd.append("--max_num_tokens=16384")
if use_cuda_graph:
cmd.append("--use_cuda_graph")
# Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently.
# Custom mask involves bidirectional masking of image tokens in context phase. To get this
# correct, chunked prefill and kv cache reuse need to be turned off.
if model_name == "gemma-3-27b-it":
cmd.append("--image_format=pil")
cmd.append("--attention_backend=FLASHINFER")
cmd.append("--disable_kv_cache_reuse")
output = llm_venv.run_cmd(cmd, caller=check_output)
def parse_output(text):

View File

@ -525,6 +525,7 @@ test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B
test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-True]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-False]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]

View File

@ -19,6 +19,7 @@ l0_h100:
- unittest/_torch -k "modeling_llama"
- unittest/_torch/modeling -k "modeling_mixtral"
- unittest/_torch/modeling -k "modeling_nemotron"
- unittest/_torch/modeling -k "modeling_gemma3"
- unittest/disaggregated/test_disagg_utils.py
- unittest/disaggregated/test_router.py
- unittest/disaggregated/test_remoteDictionary.py

View File

@ -6,6 +6,7 @@ import torch
from parameterized import parameterized
from transformers import Gemma3Config
from transformers import Gemma3ForCausalLM as HFGemma3ForCausalLM
from transformers import Gemma3TextConfig
from transformers.cache_utils import HybridCache
import tensorrt_llm
@ -35,7 +36,7 @@ GEMMA3_1B_CONFIG = {
"max_position_embeddings": 32768,
"model_type": "gemma3_text",
"num_attention_heads": 4,
"num_hidden_layers": 26,
"num_hidden_layers": 6,
"num_key_value_heads": 1,
"pad_token_id": 0,
"query_pre_attn_scalar": 256,
@ -43,7 +44,7 @@ GEMMA3_1B_CONFIG = {
"rope_local_base_freq": 10000,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": 512,
"sliding_window": 4,
"sliding_window_pattern": 6,
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0.dev0",
@ -66,14 +67,15 @@ GEMMA3_27B_CONFIG = {
"intermediate_size": 21504,
"model_type": "gemma3_text",
"num_attention_heads": 32,
"num_hidden_layers": 62,
"num_hidden_layers": 6,
"num_key_value_heads": 16,
"query_pre_attn_scalar": 168,
"rope_scaling": {
"factor": 8.0,
"rope_type": "linear"
},
"sliding_window": 1024
"sliding_window": 4,
"sliding_window_pattern": 6,
},
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0.dev0",
@ -101,7 +103,7 @@ class Scenario:
class TestGemma3(unittest.TestCase):
def get_kv_cache_manager(self, dtype: torch.dtype, config: Gemma3Config,
def get_kv_cache_manager(self, dtype: torch.dtype, config: Gemma3TextConfig,
tokens_per_block: int, max_seq_len: int,
batch_size: int, num_blocks: int):
if dtype == torch.half:
@ -135,7 +137,7 @@ class TestGemma3(unittest.TestCase):
# Using 1B config for sanity test.
config_dict = deepcopy(GEMMA3_1B_CONFIG)
gemma3_config = Gemma3Config.from_dict(config_dict)
gemma3_config = Gemma3TextConfig.from_dict(config_dict)
dtype = gemma3_config.torch_dtype
device = torch.device('cuda')
@ -240,18 +242,16 @@ class TestGemma3(unittest.TestCase):
else:
raise ValueError(f"Unknown config_name: {config_name}")
gemma3_config = Gemma3Config.from_dict(config_dict)
if config_name == "27B":
gemma3_config = Gemma3Config.from_dict(config_dict)
gemma3_config.text_config.torch_dtype = gemma3_config.torch_dtype
gemma3_config = gemma3_config.text_config
else:
gemma3_config = Gemma3TextConfig.from_dict(config_dict)
dtype = gemma3_config.torch_dtype
device = torch.device('cuda')
# 2-layer network with one local (sliding window=4) and one global layer.
gemma3_config.num_hidden_layers = 2
gemma3_config.sliding_window = 4
gemma3_config.sliding_window_pattern = 2
num_blocks = 1
tokens_per_block = 128
max_seq_len = num_blocks * tokens_per_block
@ -326,8 +326,8 @@ class TestGemma3(unittest.TestCase):
use_cache=True)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.05,
rtol=0.05)
atol=0.4,
rtol=0.4)
# Generation phase.
gen_input_ids = torch.tensor([900], dtype=torch.int, device=device)
@ -360,19 +360,19 @@ class TestGemma3(unittest.TestCase):
position_ids=gen_position_ids,
past_key_values=hf_cache,
use_cache=True,
cache_position=torch.IntTensor(
cache_position=torch.LongTensor(
[input_ids.size(-1)]).to(device),
last_cache_position=input_ids.size(-1) + 1)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.05,
rtol=0.05)
atol=0.4,
rtol=0.4)
kv_cache_manager.shutdown()
def test_gemma3_flashinfer_mask(self):
config_dict = deepcopy(GEMMA3_1B_CONFIG)
gemma3_config = Gemma3Config.from_dict(config_dict)
gemma3_config = Gemma3TextConfig.from_dict(config_dict)
dtype = gemma3_config.torch_dtype
device = torch.device('cuda')
@ -450,7 +450,7 @@ class TestGemma3(unittest.TestCase):
def test_gemma3_global_context_mask(self) -> None:
config_dict = deepcopy(GEMMA3_1B_CONFIG)
gemma3_config = Gemma3Config.from_dict(config_dict)
gemma3_config = Gemma3TextConfig.from_dict(config_dict)
device = torch.device('cuda')
model_config = ModelConfig(pretrained_config=gemma3_config,
attn_backend="FLASHINFER")
@ -487,7 +487,7 @@ class TestGemma3(unittest.TestCase):
def test_gemma3_local_context_mask(self) -> None:
config_dict = deepcopy(GEMMA3_1B_CONFIG)
gemma3_config = Gemma3Config.from_dict(config_dict)
gemma3_config = Gemma3TextConfig.from_dict(config_dict)
device = torch.device('cuda')
model_config = ModelConfig(pretrained_config=gemma3_config,
attn_backend="FLASHINFER")