From 5881a6537459c69b7346b6612253632d2bd34688 Mon Sep 17 00:00:00 2001
From: hlu1 <14827759+hlu1@users.noreply.github.com>
Date: Mon, 14 Apr 2025 23:58:31 -0700
Subject: [PATCH] Fix test_fp4_quantize_gemm_torch (#3551)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
---
 tests/unittest/_torch/test_fp4_gemm_quantize.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/unittest/_torch/test_fp4_gemm_quantize.py b/tests/unittest/_torch/test_fp4_gemm_quantize.py
index b013aa616b..22dec5dbf1 100644
--- a/tests/unittest/_torch/test_fp4_gemm_quantize.py
+++ b/tests/unittest/_torch/test_fp4_gemm_quantize.py
@@ -15,7 +15,6 @@
 import unittest
 
-import pytest
 import torch
 from parameterized import parameterized
 from utils.util import skip_pre_blackwell_unittest, unittest_name_func
 
@@ -48,14 +47,13 @@ class TestFunctional(unittest.TestCase):
     @parameterized.expand(
         list([
             [1024, 1024, 1024],
-            [7, 32, 32],
+            [256, 128, 512],
         ]),
         name_func=unittest_name_func,
    )
     @skip_pre_blackwell_unittest
     # TODO: add GEMM test for linear SF layout when kernel is ready
     def test_fp4_quantize_gemm_torch(self, m, n, k):
-        pytest.skip("https://nvbugs/5100633")
         a = torch.randn([m, k], dtype=torch.float32)
         b = torch.randn([n, k], dtype=torch.float32)
         a_global_sf = (448 * 6) / a.abs().max().float()
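
Note (not part of the patch): the last context line shows how the test derives the global scale factor for FP4 quantization. Below is a minimal sketch of that computation, assuming the standard NVFP4 convention where per-block scale factors are stored in FP8 E4M3 (max finite value 448) and quantized values are FP4 E2M1 (max finite value 6); the helper name nvfp4_global_scale is illustrative, not from the repository.

import torch

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value (assumed convention)
FP4_E2M1_MAX = 6.0    # largest finite FP4 E2M1 value (assumed convention)

def nvfp4_global_scale(x: torch.Tensor) -> torch.Tensor:
    # Map the tensor's absolute maximum onto FP8_E4M3_MAX * FP4_E2M1_MAX,
    # so the product (FP8 per-block SF) * (FP4 value) can span x's range.
    return (FP8_E4M3_MAX * FP4_E2M1_MAX) / x.abs().max().float()

# Usage with the test's new [m, k] activation shape from [256, 128, 512]:
a = torch.randn([256, 512], dtype=torch.float32)
a_global_sf = nvfp4_global_scale(a)  # same as (448 * 6) / a.abs().max().float()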