diff --git a/tests/unittest/_torch/test_fp4_gemm_quantize.py b/tests/unittest/_torch/test_fp4_gemm_quantize.py index b013aa616b..22dec5dbf1 100644 --- a/tests/unittest/_torch/test_fp4_gemm_quantize.py +++ b/tests/unittest/_torch/test_fp4_gemm_quantize.py @@ -15,7 +15,6 @@ import unittest -import pytest import torch from parameterized import parameterized from utils.util import skip_pre_blackwell_unittest, unittest_name_func @@ -48,14 +47,13 @@ class TestFunctional(unittest.TestCase): @parameterized.expand( list([ [1024, 1024, 1024], - [7, 32, 32], + [256, 128, 512], ]), name_func=unittest_name_func, ) @skip_pre_blackwell_unittest # TODO: add GEMM test for linear SF layout when kernel is ready def test_fp4_quantize_gemm_torch(self, m, n, k): - pytest.skip("https://nvbugs/5100633") a = torch.randn([m, k], dtype=torch.float32) b = torch.randn([n, k], dtype=torch.float32) a_global_sf = (448 * 6) / a.abs().max().float()