[ROCm][CI] Fix stale wvSplitK GEMM fallback test for N=5 (#44368)

Signed-off-by: JartX <sagformas@epdcenter.es>
This commit is contained in:
JartX
2026-06-03 05:00:25 +02:00
committed by GitHub
parent 02a01496fc
commit 4454a18695
@@ -41,8 +41,10 @@ def test_rocm_unquantized_gemm_gfx1x_wvsplitk_path(monkeypatch):
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
def test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back(monkeypatch):
x = torch.randn(5, 64, dtype=torch.float16)
def test_rocm_unquantized_gemm_gfx1x_n_gt_5_falls_back(monkeypatch):
# wvSplitK skinny GEMM handles n in [1, 5] (see PR #40687); n > 5 must
# fall back to torch.nn.functional.linear.
x = torch.randn(6, 64, dtype=torch.float16)
weight = torch.randn(128, 64, dtype=torch.float16)
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)