Mirror of https://github.com/vllm-project/vllm.git (synced 2026-06-06 00:16:14 +00:00)
[ROCm][CI] Fix stale wvSplitK GEMM fallback test for N=5 (#44368)
Signed-off-by: JartX <sagformas@epdcenter.es>
This commit is contained in:
@@ -41,8 +41,10 @@ def test_rocm_unquantized_gemm_gfx1x_wvsplitk_path(monkeypatch):
     assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
 
 
-def test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back(monkeypatch):
-    x = torch.randn(5, 64, dtype=torch.float16)
+def test_rocm_unquantized_gemm_gfx1x_n_gt_5_falls_back(monkeypatch):
+    # wvSplitK skinny GEMM handles n in [1, 5] (see PR #40687); n > 5 must
+    # fall back to torch.nn.functional.linear.
+    x = torch.randn(6, 64, dtype=torch.float16)
     weight = torch.randn(128, 64, dtype=torch.float16)
 
     monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
Reference in New Issue
Block a user