From e27088421e4f536860ed5991de8768984fe43852 Mon Sep 17 00:00:00 2001
From: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com>
Date: Tue, 19 Aug 2025 22:45:09 -0700
Subject: [PATCH] [None][infra] "[TRTLLM-6960][fix] enable scaled_mm tests
 (#6936)" (#7059)

Signed-off-by: Iman Tabrizian
---
 tests/unittest/_torch/thop/test_scaled_mm.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/unittest/_torch/thop/test_scaled_mm.py b/tests/unittest/_torch/thop/test_scaled_mm.py
index 40dfa4e0cb..31149de731 100644
--- a/tests/unittest/_torch/thop/test_scaled_mm.py
+++ b/tests/unittest/_torch/thop/test_scaled_mm.py
@@ -38,6 +38,11 @@ from utils.util import getSMVersion
     [torch.float16, torch.float32, torch.bfloat16],
 )
 def test_fp8_scaled_mm(output_dtype, m, k_n):
+    if getSMVersion() == 90:
+        pytest.skip(
+            "Skip test for sm90 because it's too flaky. https://nvbugspro.nvidia.com/bug/5441734"
+        )
+
     k, n = k_n
     torch.random.manual_seed(0)
     shape_x = (m, k)
@@ -71,7 +76,7 @@ def test_fp8_scaled_mm(output_dtype, m, k_n):
     os.environ["CUBLASLT_WORKSPACE_SIZE"] = old_env
     np.testing.assert_allclose(ref.float().cpu(),
                                output.float().cpu(),
-                               atol=0.01,
+                               atol=1,
                                rtol=0.01)
     if getSMVersion() == 90: