[Kernel][Test] Extend lightning_attn and awq_triton kernel tests to XPU (#43307)

Signed-off-by: Dobrzyniewicz, Agata <agata.dobrzyniewicz@intel.com>
This commit is contained in:
Agata Dobrzyniewicz
2026-06-04 20:25:59 +02:00
committed by GitHub
parent 439203d32c
commit a947f7a420
3 changed files with 37 additions and 19 deletions
+19 -11
View File
@@ -5,8 +5,16 @@ import pytest
import torch
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
DEVICE = current_platform.device_type
pytestmark = pytest.mark.skipif(
not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
reason="Lightning attention Triton kernels require CUDA/ROCm or XPU.",
)
NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
@@ -121,7 +129,7 @@ def test_linear_decode_forward_triton(
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.set_default_device(DEVICE)
set_random_seed(42)
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
@@ -129,16 +137,16 @@ def test_linear_decode_forward_triton(
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
)
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
slope_rate = torch.zeros(num_heads, device=DEVICE)
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.arange(batch_size, device="cuda")
slot_idx = torch.arange(batch_size, device=DEVICE)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
@@ -162,7 +170,7 @@ def test_linear_decode_forward_triton_with_padding(
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.set_default_device(DEVICE)
set_random_seed(42)
batch_size = 4
@@ -172,16 +180,16 @@ def test_linear_decode_forward_triton_with_padding(
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
)
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
slope_rate = torch.zeros(num_heads, device=DEVICE)
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
slot_idx = torch.tensor([0, 1, -1, 2], device=DEVICE)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
@@ -224,7 +232,7 @@ def test_lightning_attention_reference(
seq_len: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.set_default_device(DEVICE)
set_random_seed(42)
base = 0.01
@@ -232,12 +240,12 @@ def test_lightning_attention_reference(
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
ed = torch.zeros(num_heads, device="cuda")
ed = torch.zeros(num_heads, device=DEVICE)
for h in range(num_heads):
ed[h] = 0.1 * (h + 1)
kv_history = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
)
kv_history_clone = kv_history.clone()
@@ -13,9 +13,15 @@ from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton,
awq_gemm_triton,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
device = "cuda"
pytestmark = pytest.mark.skipif(
not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
reason="AWQ Triton kernels require CUDA/ROCm or XPU.",
)
device = current_platform.device_type
def reverse_awq_order(t: torch.Tensor):
+11 -7
View File
@@ -4,6 +4,7 @@
import torch
from einops import rearrange
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -403,13 +404,16 @@ class _attention(torch.autograd.Function):
v = v.contiguous()
s = s.contiguous()
# Check CUDA compute capability
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError(
"Flash attention currently only supported",
"for compute capability >= 80",
)
# Enforce the CUDA compute-capability floor (major version >= 8, i.e.
# Ampere+) for the flash attention path. The check runs only when the
# current platform is CUDA; other accelerators (ROCm, XPU) rely on their
# own Triton backend support and skip it.
if current_platform.is_cuda():
    capability = torch.cuda.get_device_capability()
    if capability[0] < 8:
        # Adjacent string literals are concatenated into ONE message.
        # Separating them with a comma would pass two positional args to
        # RuntimeError, and the error would render as a tuple.
        raise RuntimeError(
            "Flash attention currently only supported "
            "for compute capability >= 80"
        )
# Get input dimensions
b, h, n, d = q.shape