Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Fix test_fp4_quantize_gemm_torch (#3551)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
parent: 668a0335e4
commit: 5881a65374
@@ -15,7 +15,6 @@
 import unittest
 
-import pytest
 import torch
 from parameterized import parameterized
 from utils.util import skip_pre_blackwell_unittest, unittest_name_func
@@ -48,14 +47,13 @@ class TestFunctional(unittest.TestCase):
     @parameterized.expand(
         list([
             [1024, 1024, 1024],
             [7, 32, 32],
             [256, 128, 512],
         ]),
         name_func=unittest_name_func,
     )
     @skip_pre_blackwell_unittest
     # TODO: add GEMM test for linear SF layout when kernel is ready
     def test_fp4_quantize_gemm_torch(self, m, n, k):
-        pytest.skip("https://nvbugs/5100633")
         a = torch.randn([m, k], dtype=torch.float32)
         b = torch.randn([n, k], dtype=torch.float32)
         a_global_sf = (448 * 6) / a.abs().max().float()
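For context, the re-enabled test starts from a global scale factor of (448 * 6) / amax: 448 is the maximum of FP8 E4M3 (the dtype typically used for per-block scales) and 6 is the maximum of FP4 E2M1 (the element dtype), so their product bounds what the two-level scaled format can represent. The sketch below is a hypothetical plain-PyTorch reference of that scale-factor math, not the TensorRT-LLM kernel under test; the names fp4_global_scale and fp4_quantize_ref, and the block size of 16, are illustrative assumptions.

import torch

# Illustrative constants (assumption, mirroring the 448 * 6 in the test):
E4M3_MAX = 448.0   # largest FP8 E4M3 value (per-block scale dtype)
E2M1_MAX = 6.0     # largest FP4 E2M1 value (element dtype)
BLOCK = 16         # assumed block size along the reduction dimension


def fp4_global_scale(x: torch.Tensor) -> torch.Tensor:
    # Same formula as the test: map the tensor's absolute maximum onto the
    # largest value representable by (FP8 scale) * (FP4 element).
    return (E4M3_MAX * E2M1_MAX) / x.abs().max().float()


def fp4_quantize_ref(x: torch.Tensor, global_sf: torch.Tensor) -> torch.Tensor:
    # Simplified fake-quantization reference: split the last dim into blocks,
    # derive a per-block scale, snap each element to the E2M1 grid
    # {0, 0.5, 1, 1.5, 2, 3, 4, 6}, then undo the scaling.
    m, k = x.shape
    assert k % BLOCK == 0
    xb = (x.float() * global_sf).reshape(m, k // BLOCK, BLOCK)
    block_sf = xb.abs().amax(dim=-1, keepdim=True) / E2M1_MAX
    block_sf = torch.clamp(block_sf, min=torch.finfo(torch.float32).tiny)
    grid = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
    scaled = xb / block_sf
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    q = grid[idx] * scaled.sign()
    return (q * block_sf).reshape(m, k) / global_sf


if __name__ == "__main__":
    a = torch.randn([7, 32], dtype=torch.float32)
    a_sf = fp4_global_scale(a)
    a_dq = fp4_quantize_ref(a, a_sf)
    print((a - a_dq).abs().max())  # quantization error of the reference path

A reference like this only illustrates the scale-factor handling; the real FP4 path additionally quantizes the per-block scales themselves to FP8 E4M3 and packs the E2M1 elements two per byte before running the GEMM.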