mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[None][chore] Update Gemma3 closeness check to mitigate flakiness (#6591)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
parent
13cc1c4878
commit
6135f75f87
@ -231,6 +231,15 @@ class TestGemma3(unittest.TestCase):
|
||||
attn_metadata.prepare()
|
||||
self.assertEqual(len(attn_metadata._plan_params_to_wrappers), 0)
|
||||
|
||||
# Allow room for small fraction of elements to fail. This is to mitigate flakiness.
|
||||
def _assert_most_elems_close(self, actual_value, ref_value, atol, rtol,
|
||||
max_failed_fraction):
|
||||
matches = torch.isclose(actual_value, ref_value, atol=atol, rtol=rtol)
|
||||
failed_fraction = (~matches).float().mean().item()
|
||||
assert failed_fraction <= max_failed_fraction, (
|
||||
f"Exceeded tolerance: {failed_fraction*100:.2f}% of elements differ more than allowed "
|
||||
f"(max allowed {max_failed_fraction*100:.2f}%)")
|
||||
|
||||
@parameterized.expand([
|
||||
Scenario(backend="TRTLLM", config_name="1B"),
|
||||
Scenario(backend="VANILLA", config_name="1B"),
|
||||
@ -343,10 +352,11 @@ class TestGemma3(unittest.TestCase):
|
||||
position_ids=position_ids,
|
||||
past_key_values=hf_cache,
|
||||
use_cache=True)
|
||||
torch.testing.assert_close(logits,
|
||||
ref.logits[:, -1].float(),
|
||||
atol=0.4,
|
||||
rtol=0.4)
|
||||
self._assert_most_elems_close(actual_value=logits,
|
||||
ref_value=ref.logits[:, -1].float(),
|
||||
atol=0.4,
|
||||
rtol=0.4,
|
||||
max_failed_fraction=0.001)
|
||||
self._verify_params_flushed_upon_prepare(attn_metadata)
|
||||
|
||||
# Generation phase.
|
||||
@ -383,10 +393,11 @@ class TestGemma3(unittest.TestCase):
|
||||
cache_position=torch.LongTensor(
|
||||
[input_ids.size(-1)]).to(device),
|
||||
last_cache_position=input_ids.size(-1) + 1)
|
||||
torch.testing.assert_close(logits,
|
||||
ref.logits[:, -1].float(),
|
||||
atol=0.4,
|
||||
rtol=0.4)
|
||||
self._assert_most_elems_close(actual_value=logits,
|
||||
ref_value=ref.logits[:, -1].float(),
|
||||
atol=0.4,
|
||||
rtol=0.4,
|
||||
max_failed_fraction=0.001)
|
||||
|
||||
kv_cache_manager.shutdown()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user