[None][chore] Update Gemma3 closeness check to mitigate flakiness (#6591)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
brb-nv 2025-08-04 07:10:58 -07:00 committed by GitHub
parent 13cc1c4878
commit 6135f75f87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -231,6 +231,15 @@ class TestGemma3(unittest.TestCase):
attn_metadata.prepare()
self.assertEqual(len(attn_metadata._plan_params_to_wrappers), 0)
# Allow room for a small fraction of elements to fail. This is to mitigate flakiness.
def _assert_most_elems_close(self, actual_value, ref_value, atol, rtol,
max_failed_fraction):
matches = torch.isclose(actual_value, ref_value, atol=atol, rtol=rtol)
failed_fraction = (~matches).float().mean().item()
assert failed_fraction <= max_failed_fraction, (
f"Exceeded tolerance: {failed_fraction*100:.2f}% of elements differ more than allowed "
f"(max allowed {max_failed_fraction*100:.2f}%)")
@parameterized.expand([
Scenario(backend="TRTLLM", config_name="1B"),
Scenario(backend="VANILLA", config_name="1B"),
@ -343,10 +352,11 @@ class TestGemma3(unittest.TestCase):
position_ids=position_ids,
past_key_values=hf_cache,
use_cache=True)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4)
self._assert_most_elems_close(actual_value=logits,
ref_value=ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4,
max_failed_fraction=0.001)
self._verify_params_flushed_upon_prepare(attn_metadata)
# Generation phase.
@ -383,10 +393,11 @@ class TestGemma3(unittest.TestCase):
cache_position=torch.LongTensor(
[input_ids.size(-1)]).to(device),
last_cache_position=input_ids.size(-1) + 1)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4)
self._assert_most_elems_close(actual_value=logits,
ref_value=ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4,
max_failed_fraction=0.001)
kv_cache_manager.shutdown()