[None][fix] [Gemma3] Fix RoPE for local attention for Gemma3 (#9961)

Signed-off-by: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
shivghai 2025-12-27 14:50:59 -05:00 committed by GitHub
parent 1865020b6f
commit ee07a7c55e
2 changed files with 127 additions and 8 deletions
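
For context, the fix comes down to which inverse-frequency table each layer type gets: global attention layers keep the configured linear RoPE scaling, while Gemma3's local (sliding-window) layers use their own base frequency with no scaling. A minimal numpy sketch of that difference, using the same illustrative values as the new unit test below (not real Gemma3 checkpoint values):

import numpy as np

# Illustrative values mirroring the mock config in the new test below,
# not a real Gemma3 checkpoint.
dim = 128                        # rotary_embedding_dim = head_size * rotary_pct
global_base, local_base = 100.0, 10.0
linear_factor = 8.0              # rope_scaling = {"factor": 8.0, "rope_type": "linear"}

exponent = np.arange(0, dim, 2) / dim
# Global layers: linear scaling divides inv_freq by the factor.
global_inv_freq = (1.0 / linear_factor) / (global_base ** exponent)
# Local (sliding-window) layers: own base, scale pinned to 1.0.
local_inv_freq = 1.0 / (local_base ** exponent)

assert not np.allclose(global_inv_freq, local_inv_freq)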


@@ -702,12 +702,16 @@ class Attention(Module):
is_buffer=True))
else:
def register_rope_params(rotary_base, names_to_register):
def register_rope_params(rotary_base, rotary_embedding_scale,
rotary_embedding_scale_type,
rotary_embedding_scaling,
names_to_register):
# Rotary const weights.
embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
max_position_embeddings,
rotary_embedding_dim,
)
rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
max_position_embeddings, rotary_embedding_dim, rotary_base,
rotary_embedding_scale, rotary_embedding_scale_type,
@@ -724,11 +728,15 @@ class Attention(Module):
dtype='float32',
is_buffer=True))
register_rope_params(rotary_base=rotary_embedding_base,
names_to_register=[
'embed_positions', 'rotary_inv_freq',
'embed_positions_for_gpt_attention'
])
register_rope_params(
rotary_base=rotary_embedding_base,
rotary_embedding_scale=rotary_embedding_scale,
rotary_embedding_scale_type=rotary_embedding_scale_type,
rotary_embedding_scaling=rotary_embedding_scaling,
names_to_register=[
'embed_positions', 'rotary_inv_freq',
'embed_positions_for_gpt_attention'
])
# For models with non-homogeneous attention layers requiring a second set of rope params, e.g. Gemma3.
rotary_embedding_base_local = getattr(config,
@@ -736,6 +744,9 @@ class Attention(Module):
if rotary_embedding_base_local is not None:
register_rope_params(
rotary_base=rotary_embedding_base_local,
rotary_embedding_scale=1.0,
rotary_embedding_scale_type=RotaryScalingType.none,
rotary_embedding_scaling=None,
names_to_register=[
'embed_positions_local', 'rotary_inv_freq_local',
'embed_positions_for_gpt_attention_local'
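
As the comment above notes, models with non-homogeneous attention layers, such as Gemma3, need a second set of RoPE parameters. A hypothetical, self-contained sketch of that idea (plain dicts with illustrative names and magnitudes; not the actual TensorRT-LLM API):

# Hypothetical sketch: when a config carries rope_local_base_freq, a second,
# unscaled RoPE parameter set is registered alongside the global one.
def rope_param_sets(config: dict) -> dict:
    sets = {
        "global": {
            "rotary_base": config["rotary_base"],
            "scale": config.get("rotary_scale", 1.0),
            "scale_type": config.get("rotary_scale_type", "none"),
        }
    }
    local_base = config.get("rope_local_base_freq")
    if local_base is not None:
        # Local (sliding-window) layers: own base, never scaled.
        sets["local"] = {"rotary_base": local_base, "scale": 1.0, "scale_type": "none"}
    return sets

# Gemma3-style magnitudes, for illustration only.
print(rope_param_sets({"rotary_base": 1e6, "rotary_scale": 1.0 / 8.0,
                       "rotary_scale_type": "linear", "rope_local_base_freq": 1e4}))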
@@ -1141,10 +1152,12 @@ class Attention(Module):
rotary_embedding_dim=self.rotary_embedding_dim,
rotary_embedding_base=self.rotary_embedding_base
if not self.is_local else self.rotary_embedding_base_local,
rotary_embedding_scale_type=self.rotary_embedding_scale_type,
rotary_embedding_scale_type=self.rotary_embedding_scale_type
if not self.is_local else RotaryScalingType.none,
rotary_embedding_short_m_scale=attention_params.short_mscale,
rotary_embedding_long_m_scale=attention_params.long_mscale,
rotary_embedding_scale=self.rotary_embedding_scale,
rotary_embedding_scale=self.rotary_embedding_scale
if not self.is_local else 1.0,
rotary_embedding_max_positions=self.max_position_embeddings,
rotary_embedding_original_max_positions=self.
original_max_position_embeddings,
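
The last hunk of this file makes the forward path consistent with the registration: when a layer is local, the attention op receives the local base, scale 1.0, and RotaryScalingType.none instead of the configured values. A minimal sketch of that selection, assuming RotaryScalingType is importable from tensorrt_llm.functional as in the surrounding code:

from tensorrt_llm.functional import RotaryScalingType

# Mirrors the conditional kwargs in the hunk above; `layer` stands in for the
# Attention module, with is_local and the rotary_embedding_* attributes from the diff.
def rope_runtime_kwargs(layer):
    if layer.is_local:
        return dict(
            rotary_embedding_base=layer.rotary_embedding_base_local,
            rotary_embedding_scale=1.0,
            rotary_embedding_scale_type=RotaryScalingType.none,
        )
    return dict(
        rotary_embedding_base=layer.rotary_embedding_base,
        rotary_embedding_scale=layer.rotary_embedding_scale,
        rotary_embedding_scale_type=layer.rotary_embedding_scale_type,
    )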


@@ -2115,6 +2115,112 @@ class TestLayer(unittest.TestCase):
atol=atol,
rtol=rtol)
def test_gemma3_local_attention_rope_scaling(self):
"""
Test that local attention layers in Gemma3 do NOT apply rope scaling,
even when the config has rope_scaling defined.
This is important for Gemma3, which uses different RoPE parameters for
local (sliding-window) attention vs. global attention layers. The fix
ensures that local attention layers get scale=1.0 and scale_type=none,
while global layers get the configured scaling.
"""
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.layers.attention import Attention
# Create a mock config similar to Gemma3 27B with rope_scaling
class MockGemma3Config:
hidden_size = 5376
num_attention_heads = 32
head_size = 128
max_position_embeddings = 32768
position_embedding_type = PositionEmbeddingType.rope_gpt_neox
# Use small rotary base values to avoid numerical instability in tests.
# Large bases (e.g. 1000000) get exponentiated, causing potential flakiness
# when comparing floating point results.
rotary_base = 100.0
rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
rotary_pct = 1.0
# Local attention uses a different base frequency
rope_local_base_freq = 10.0
# Create a mock model class to receive registered parameters
class MockModelCls:
position_embedding_type = PositionEmbeddingType.rope_gpt_neox
@classmethod
def register_parameter(cls, name, param):
setattr(cls, name, param)
config = MockGemma3Config()
# Call the method that creates attention const params
Attention.create_attention_const_params(MockModelCls, config)
# Verify that global rope parameters are registered
self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
"Global embed_positions should be registered")
self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
"Global rotary_inv_freq should be registered")
self.assertTrue(
hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
"Global embed_positions_for_gpt_attention should be registered")
# Verify that local rope parameters are registered (since rope_local_base_freq is set)
self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
"Local embed_positions should be registered")
self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
"Local rotary_inv_freq should be registered")
self.assertTrue(
hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
"Local embed_positions_for_gpt_attention should be registered")
# Get the parameter values
global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value
# The global and local inv_freq should be different because:
# 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
# 2. Local uses scale=1.0 (no scaling)
self.assertFalse(
np.allclose(global_inv_freq, local_inv_freq),
"Global and local rotary_inv_freq should be different "
"(global has scaling, local does not)")
# The cos/sin embeddings should also be different
self.assertFalse(
np.allclose(global_cos_sin, local_cos_sin),
"Global and local embed_positions_for_gpt_attention should be different "
"(global has scaling, local does not)")
# Additional verification: Check that local inv_freq matches unscaled calculation
# For local attention with scale=1.0 and base=10:
# inv_freq = 1.0 / (10 ** (arange(0, dim, 2) / dim))
dim = config.head_size # rotary_embedding_dim = head_size * rotary_pct = 128
expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
**(np.arange(0, dim, 2) / dim))
np.testing.assert_allclose(
local_inv_freq,
expected_local_inv_freq,
rtol=1e-5,
err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")
# For global attention with linear scaling (factor=8.0):
# scale = 1.0 / 8.0 = 0.125
# inv_freq = 0.125 / (100 ** (arange(0, dim, 2) / dim))
expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
(np.arange(0, dim, 2) / dim))
np.testing.assert_allclose(
global_inv_freq,
expected_global_inv_freq,
rtol=1e-5,
err_msg=
"Global rotary_inv_freq should be computed WITH linear scaling")
if __name__ == '__main__':
unittest.main()
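
To run only the new case, one option is the standard unittest loader; the module name below is an assumption and may differ from the actual test file in the repo:

import unittest

# Hypothetical module name -- substitute the file that defines TestLayer.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_layer.TestLayer.test_gemma3_local_attention_rope_scaling")
unittest.TextTestRunner(verbosity=2).run(suite)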