From ee07a7c55e438b14a4398199d517a2fa6e59b534 Mon Sep 17 00:00:00 2001
From: shivghai <8965168+shivghai@users.noreply.github.com>
Date: Sat, 27 Dec 2025 14:50:59 -0500
Subject: [PATCH] [None][fix] [Gemma3] Fix RoPE for local attention for Gemma3
 (#9961)

Signed-off-by: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
---
 tensorrt_llm/layers/attention.py    |  29 +++++---
 tests/unittest/others/test_layer.py | 106 ++++++++++++++++++++++++++++
 2 files changed, 127 insertions(+), 8 deletions(-)

diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py
index f995b6390d..29b63a4258 100755
--- a/tensorrt_llm/layers/attention.py
+++ b/tensorrt_llm/layers/attention.py
@@ -702,12 +702,16 @@ class Attention(Module):
                           is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base, names_to_register):
+            def register_rope_params(rotary_base, rotary_embedding_scale,
+                                      rotary_embedding_scale_type,
+                                      rotary_embedding_scaling,
+                                      names_to_register):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
                     rotary_embedding_dim,
                 )
+
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
                     rotary_embedding_scale, rotary_embedding_scale_type,
@@ -724,11 +728,15 @@ class Attention(Module):
                               dtype='float32',
                               is_buffer=True))
 
-            register_rope_params(rotary_base=rotary_embedding_base,
-                                 names_to_register=[
-                                     'embed_positions', 'rotary_inv_freq',
-                                     'embed_positions_for_gpt_attention'
-                                 ])
+            register_rope_params(
+                rotary_base=rotary_embedding_base,
+                rotary_embedding_scale=rotary_embedding_scale,
+                rotary_embedding_scale_type=rotary_embedding_scale_type,
+                rotary_embedding_scaling=rotary_embedding_scaling,
+                names_to_register=[
+                    'embed_positions', 'rotary_inv_freq',
+                    'embed_positions_for_gpt_attention'
+                ])
 
             # For models with non-homegeneous attention layers requiring a second set of rope params. e.g. Gemma3.
             rotary_embedding_base_local = getattr(config,
@@ -736,6 +744,9 @@ class Attention(Module):
             if rotary_embedding_base_local is not None:
                 register_rope_params(
                     rotary_base=rotary_embedding_base_local,
+                    rotary_embedding_scale=1.0,
+                    rotary_embedding_scale_type=RotaryScalingType.none,
+                    rotary_embedding_scaling=None,
                     names_to_register=[
                         'embed_positions_local', 'rotary_inv_freq_local',
                         'embed_positions_for_gpt_attention_local'
@@ -1141,10 +1152,12 @@ class Attention(Module):
                 rotary_embedding_dim=self.rotary_embedding_dim,
                 rotary_embedding_base=self.rotary_embedding_base
                 if not self.is_local else self.rotary_embedding_base_local,
-                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
+                rotary_embedding_scale_type=self.rotary_embedding_scale_type
+                if not self.is_local else RotaryScalingType.none,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
-                rotary_embedding_scale=self.rotary_embedding_scale,
+                rotary_embedding_scale=self.rotary_embedding_scale
+                if not self.is_local else 1.0,
                 rotary_embedding_max_positions=self.max_position_embeddings,
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
diff --git a/tests/unittest/others/test_layer.py b/tests/unittest/others/test_layer.py
index 8de1189dc9..38bb7f1ef7 100644
--- a/tests/unittest/others/test_layer.py
+++ b/tests/unittest/others/test_layer.py
@@ -2115,6 +2115,112 @@ class TestLayer(unittest.TestCase):
                                    atol=atol,
                                    rtol=rtol)
 
+    def test_gemma3_local_attention_rope_scaling(self):
+        """
+        Test that local attention layers in Gemma3 do NOT apply rope scaling,
+        even when the config has rope_scaling defined.
+
+        This is important for Gemma3 which uses different RoPE parameters for
+        local (sliding window) attention vs global attention layers. The fix
+        ensures that local attention layers get scale=1.0 and scale_type=none,
+        while global layers get the configured scaling.
+        """
+        from tensorrt_llm.functional import PositionEmbeddingType
+        from tensorrt_llm.layers.attention import Attention
+
+        # Create a mock config similar to Gemma3 27B with rope_scaling
+        class MockGemma3Config:
+            hidden_size = 5376
+            num_attention_heads = 32
+            head_size = 128
+            max_position_embeddings = 32768
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+            # Use small rotary base values to avoid numerical instability in tests.
+            # Large bases (e.g. 1000000) get exponentiated, causing potential flakiness
+            # when comparing floating point results.
+            rotary_base = 100.0
+            rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
+            rotary_pct = 1.0
+            # Local attention uses a different base frequency
+            rope_local_base_freq = 10.0
+
+        # Create a mock model class to receive registered parameters
+        class MockModelCls:
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+
+            @classmethod
+            def register_parameter(cls, name, param):
+                setattr(cls, name, param)
+
+        config = MockGemma3Config()
+
+        # Call the method that creates attention const params
+        Attention.create_attention_const_params(MockModelCls, config)
+
+        # Verify that global rope parameters are registered
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
+                        "Global embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
+                        "Global rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
+            "Global embed_positions_for_gpt_attention should be registered")
+
+        # Verify that local rope parameters are registered (since rope_local_base_freq is set)
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
+                        "Local embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
+                        "Local rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
+            "Local embed_positions_for_gpt_attention should be registered")
+
+        # Get the parameter values
+        global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
+        local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
+        global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
+        local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value
+
+        # The global and local inv_freq should be different because:
+        # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
+        # 2. Local uses scale=1.0 (no scaling)
+        self.assertFalse(
+            np.allclose(global_inv_freq, local_inv_freq),
+            "Global and local rotary_inv_freq should be different "
+            "(global has scaling, local does not)")
+
+        # The cos/sin embeddings should also be different
+        self.assertFalse(
+            np.allclose(global_cos_sin, local_cos_sin),
+            "Global and local embed_positions_for_gpt_attention should be different "
+            "(global has scaling, local does not)")
+
+        # Additional verification: Check that local inv_freq matches unscaled calculation
+        # For local attention with scale=1.0 and base=10:
+        #   inv_freq = 1.0 / (10 ** (arange(0, dim, 2) / dim))
+        dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
+                                         **(np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            local_inv_freq,
+            expected_local_inv_freq,
+            rtol=1e-5,
+            err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")
+
+        # For global attention with linear scaling (factor=8.0):
+        #   scale = 1.0 / 8.0 = 0.125
+        #   inv_freq = 0.125 / (100 ** (arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
+                                                  (np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            global_inv_freq,
+            expected_global_inv_freq,
+            rtol=1e-5,
+            err_msg=
+            "Global rotary_inv_freq should be computed WITH linear scaling")
+
 
 if __name__ == '__main__':
     unittest.main()
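
Note on the arithmetic the new test asserts: linear RoPE scaling folds a 1/factor term into the inverse frequencies, while the local-attention path uses the unscaled formula with its own base. Below is a minimal standalone NumPy sketch of that relationship, illustrative only; the helper name rope_inv_freq is invented here and is not a TensorRT-LLM API.

import numpy as np

def rope_inv_freq(base, dim, scale=1.0):
    # inv_freq[i] = scale / base**(2*i / dim) for i = 0 .. dim/2 - 1
    return scale / base**(np.arange(0, dim, 2) / dim)

dim = 128  # rotary_embedding_dim in the test (head_size * rotary_pct)
global_inv_freq = rope_inv_freq(100.0, dim, scale=1.0 / 8.0)  # linear scaling, factor=8
local_inv_freq = rope_inv_freq(10.0, dim)  # rope_local_base_freq, no scaling

assert not np.allclose(global_inv_freq, local_inv_freq)

These are the same expected_global_inv_freq and expected_local_inv_freq values that test_gemma3_local_attention_rope_scaling compares against the registered parameters.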