Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][fix] [Gemma3] Fix RoPE for local attention for Gemma3 (#9961)
Signed-off-by: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
parent: 1865020b6f
commit: ee07a7c55e
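Background on the fix (illustrative, not part of the commit): Gemma3 configures rope_scaling for its global attention layers, but its local (sliding-window) layers are meant to use unscaled RoPE with their own base frequency (rope_local_base_freq). Before this change, local layers inherited the global scale and scale type. A minimal numpy sketch of why that matters; rope_inv_freq is a hypothetical helper for illustration, not TensorRT-LLM API:

    import numpy as np

    def rope_inv_freq(dim, base, scale=1.0):
        # RoPE inverse frequencies for even channel indices. Linear scaling
        # multiplies them by 1/factor; scale type "none" leaves scale at 1.0.
        return scale / (base**(np.arange(0, dim, 2) / dim))

    # Global Gemma3 layers: configured base, linear scaling with factor 8.
    global_inv_freq = rope_inv_freq(dim=128, base=100.0, scale=1.0 / 8.0)
    # Local (sliding-window) layers: rope_local_base_freq, no scaling.
    local_inv_freq = rope_inv_freq(dim=128, base=10.0, scale=1.0)

The values mirror the unit test below: linear scaling divides every inverse frequency by the factor, so applying it to local layers would stretch their sliding-window RoPE positions eightfold.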
@@ -702,12 +702,16 @@ class Attention(Module):
                     is_buffer=True))
         else:
-            def register_rope_params(rotary_base, names_to_register):
+
+            def register_rope_params(rotary_base, rotary_embedding_scale,
+                                     rotary_embedding_scale_type,
+                                     rotary_embedding_scaling,
+                                     names_to_register):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
                     rotary_embedding_dim,
                 )
 
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
                     rotary_embedding_scale, rotary_embedding_scale_type,
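Note that the old register_rope_params closed over the enclosing scope's rotary_embedding_scale, rotary_embedding_scale_type, and rotary_embedding_scaling, so every caller got the same global scaling; threading them through as parameters is what lets the local-attention call below override them.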
@@ -724,7 +728,11 @@ class Attention(Module):
                         dtype='float32',
                         is_buffer=True))
 
-            register_rope_params(rotary_base=rotary_embedding_base,
+            register_rope_params(
+                rotary_base=rotary_embedding_base,
+                rotary_embedding_scale=rotary_embedding_scale,
+                rotary_embedding_scale_type=rotary_embedding_scale_type,
+                rotary_embedding_scaling=rotary_embedding_scaling,
                 names_to_register=[
                     'embed_positions', 'rotary_inv_freq',
                     'embed_positions_for_gpt_attention'
@@ -736,6 +744,9 @@ class Attention(Module):
             if rotary_embedding_base_local is not None:
                 register_rope_params(
                     rotary_base=rotary_embedding_base_local,
+                    rotary_embedding_scale=1.0,
+                    rotary_embedding_scale_type=RotaryScalingType.none,
+                    rotary_embedding_scaling=None,
                     names_to_register=[
                         'embed_positions_local', 'rotary_inv_freq_local',
                         'embed_positions_for_gpt_attention_local'
@@ -1141,10 +1152,12 @@ class Attention(Module):
                 rotary_embedding_dim=self.rotary_embedding_dim,
                 rotary_embedding_base=self.rotary_embedding_base
                 if not self.is_local else self.rotary_embedding_base_local,
-                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
+                rotary_embedding_scale_type=self.rotary_embedding_scale_type
+                if not self.is_local else RotaryScalingType.none,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
-                rotary_embedding_scale=self.rotary_embedding_scale,
+                rotary_embedding_scale=self.rotary_embedding_scale
+                if not self.is_local else 1.0,
                 rotary_embedding_max_positions=self.max_position_embeddings,
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
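The forward-path edit just above applies one pattern: every scaling-dependent RoPE argument switches on self.is_local. A condensed, standalone restatement of that selection (select_rope_args is a hypothetical helper written for this note, not the module's real structure):

    from tensorrt_llm.functional import RotaryScalingType

    def select_rope_args(layer):
        # `layer` stands in for an Attention instance; only the arguments
        # changed by this commit are shown.
        if layer.is_local:
            # Gemma3 sliding-window layer: local base, no scaling.
            return dict(
                rotary_embedding_base=layer.rotary_embedding_base_local,
                rotary_embedding_scale=1.0,
                rotary_embedding_scale_type=RotaryScalingType.none)
        # Global layer: configured base and scaling.
        return dict(
            rotary_embedding_base=layer.rotary_embedding_base,
            rotary_embedding_scale=layer.rotary_embedding_scale,
            rotary_embedding_scale_type=layer.rotary_embedding_scale_type)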
@@ -2115,6 +2115,112 @@ class TestLayer(unittest.TestCase):
                            atol=atol,
                            rtol=rtol)
 
+    def test_gemma3_local_attention_rope_scaling(self):
+        """
+        Test that local attention layers in Gemma3 do NOT apply rope scaling,
+        even when the config has rope_scaling defined.
+
+        This is important for Gemma3, which uses different RoPE parameters
+        for local (sliding-window) attention vs. global attention layers.
+        The fix ensures that local attention layers get scale=1.0 and
+        scale_type=none, while global layers get the configured scaling.
+        """
+        from tensorrt_llm.functional import PositionEmbeddingType
+        from tensorrt_llm.layers.attention import Attention
+
+        # Create a mock config similar to Gemma3 27B with rope_scaling.
+        class MockGemma3Config:
+            hidden_size = 5376
+            num_attention_heads = 32
+            head_size = 128
+            max_position_embeddings = 32768
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+            # Use small rotary base values to avoid numerical instability in
+            # tests. Large bases (e.g. 1000000) get exponentiated, causing
+            # potential flakiness when comparing floating-point results.
+            rotary_base = 100.0
+            rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
+            rotary_pct = 1.0
+            # Local attention uses a different base frequency.
+            rope_local_base_freq = 10.0
+
+        # Create a mock model class to receive registered parameters.
+        class MockModelCls:
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+
+            @classmethod
+            def register_parameter(cls, name, param):
+                setattr(cls, name, param)
+
+        config = MockGemma3Config()
+
+        # Call the method that creates attention const params.
+        Attention.create_attention_const_params(MockModelCls, config)
+
+        # Verify that global rope parameters are registered.
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
+                        "Global embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
+                        "Global rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
+            "Global embed_positions_for_gpt_attention should be registered")
+
+        # Verify that local rope parameters are registered (since
+        # rope_local_base_freq is set).
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
+                        "Local embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
+                        "Local rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
+            "Local embed_positions_for_gpt_attention should be registered")
+
+        # Get the parameter values.
+        global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
+        local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
+        global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
+        local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value
+
+        # The global and local inv_freq should be different because:
+        # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies
+        #    1/8 to inv_freq).
+        # 2. Local uses scale=1.0 (no scaling).
+        self.assertFalse(
+            np.allclose(global_inv_freq, local_inv_freq),
+            "Global and local rotary_inv_freq should be different "
+            "(global has scaling, local does not)")
+
+        # The cos/sin embeddings should also be different.
+        self.assertFalse(
+            np.allclose(global_cos_sin, local_cos_sin),
+            "Global and local embed_positions_for_gpt_attention should be "
+            "different (global has scaling, local does not)")
+
+        # Additional verification: check that local inv_freq matches the
+        # unscaled calculation. For local attention with scale=1.0 and base=10:
+        #   inv_freq = 1.0 / (10 ** (arange(0, dim, 2) / dim))
+        dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
+                                         **(np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            local_inv_freq,
+            expected_local_inv_freq,
+            rtol=1e-5,
+            err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")
+
+        # For global attention with linear scaling (factor=8.0):
+        #   scale = 1.0 / 8.0 = 0.125
+        #   inv_freq = 0.125 / (100 ** (arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
+                                                  (np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            global_inv_freq,
+            expected_global_inv_freq,
+            rtol=1e-5,
+            err_msg=
+            "Global rotary_inv_freq should be computed WITH linear scaling")
+
 
 if __name__ == '__main__':
     unittest.main()
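The new test needs no engine build; it only exercises create_attention_const_params against mocks, so it can run standalone. A hedged invocation sketch; the module name 'test_layer' is an assumption, so point the dotted path at wherever TestLayer lives in the repo's test tree:

    import unittest

    # Module name 'test_layer' is assumed for illustration only.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        'test_layer.TestLayer.test_gemma3_local_attention_rope_scaling')
    unittest.TextTestRunner(verbosity=2).run(suite)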