mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Kernel][Test] Extend lightning_attn and awq_triton kernel tests to XPU (#43307)
Signed-off-by: Dobrzyniewicz, Agata <agata.dobrzyniewicz@intel.com>
This commit is contained in:
committed by
GitHub
parent
439203d32c
commit
a947f7a420
@@ -5,8 +5,16 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
DEVICE = current_platform.device_type
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
|
||||
reason="Lightning attention Triton kernels require CUDA/ROCm or XPU.",
|
||||
)
|
||||
|
||||
NUM_HEADS = [4, 8]
|
||||
HEAD_SIZES = [64]
|
||||
BATCH_SIZES = [1, 2]
|
||||
@@ -121,7 +129,7 @@ def test_linear_decode_forward_triton(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_device(DEVICE)
|
||||
set_random_seed(42)
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
@@ -129,16 +137,16 @@ def test_linear_decode_forward_triton(
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
slope_rate = torch.zeros(num_heads, device=DEVICE)
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.arange(batch_size, device="cuda")
|
||||
slot_idx = torch.arange(batch_size, device=DEVICE)
|
||||
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
@@ -162,7 +170,7 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_device(DEVICE)
|
||||
set_random_seed(42)
|
||||
|
||||
batch_size = 4
|
||||
@@ -172,16 +180,16 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
slope_rate = torch.zeros(num_heads, device=DEVICE)
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device=DEVICE)
|
||||
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
@@ -224,7 +232,7 @@ def test_lightning_attention_reference(
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_device(DEVICE)
|
||||
set_random_seed(42)
|
||||
|
||||
base = 0.01
|
||||
@@ -232,12 +240,12 @@ def test_lightning_attention_reference(
|
||||
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
|
||||
ed = torch.zeros(num_heads, device="cuda")
|
||||
ed = torch.zeros(num_heads, device=DEVICE)
|
||||
for h in range(num_heads):
|
||||
ed[h] = 0.1 * (h + 1)
|
||||
|
||||
kv_history = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
|
||||
)
|
||||
|
||||
kv_history_clone = kv_history.clone()
|
||||
|
||||
@@ -13,9 +13,15 @@ from vllm.model_executor.layers.quantization.awq_triton import (
|
||||
awq_dequantize_triton,
|
||||
awq_gemm_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
device = "cuda"
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
|
||||
reason="AWQ Triton kernels require CUDA/ROCm or XPU.",
|
||||
)
|
||||
|
||||
device = current_platform.device_type
|
||||
|
||||
|
||||
def reverse_awq_order(t: torch.Tensor):
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
|
||||
@@ -403,13 +404,16 @@ class _attention(torch.autograd.Function):
|
||||
v = v.contiguous()
|
||||
s = s.contiguous()
|
||||
|
||||
# Check CUDA compute capability
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
raise RuntimeError(
|
||||
"Flash attention currently only supported",
|
||||
"for compute capability >= 80",
|
||||
)
|
||||
# Check CUDA compute capability (Ampere+ required for flash attention
|
||||
# path). Other accelerators (ROCm, XPU) rely on their own Triton
|
||||
# backend support and skip this check.
|
||||
if current_platform.is_cuda():
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
raise RuntimeError(
|
||||
"Flash attention currently only supported",
|
||||
"for compute capability >= 80",
|
||||
)
|
||||
|
||||
# Get input dimensions
|
||||
b, h, n, d = q.shape
|
||||
|
||||
Reference in New Issue
Block a user