[XPU][Mamba] Triton-based selective scan forward op for XPU (#43421)

Signed-off-by: Marceli Fylcek <marceli.fylcek@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Marceli Fylcek
2026-06-02 13:50:26 +03:00
committed by GitHub
parent 2a2b5ca791
commit f69ede495b
2 changed files with 523 additions and 22 deletions
+474
View File
@@ -8,6 +8,7 @@ from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@@ -407,6 +408,312 @@ def _xpu_mxfp4_quantize_fake(
return x_q, x_s
@triton.jit
def _softplus(x):
return tl.where(x <= 20.0, tl.math.log(tl.math.exp(x) + 1.0), x)
@triton.jit
def _selective_scan_fwd_kernel(
# Pointers to input tensors
u_ptr,
delta_ptr,
A_ptr,
B_ptr,
C_ptr,
D_ptr,
z_ptr,
delta_bias_ptr,
# Pointers to output tensors (out aliases delta, out_z aliases z)
out_ptr,
out_z_ptr,
# SSM states
ssm_states_ptr,
# Optional pointers
query_start_loc_ptr,
cache_indices_ptr,
has_initial_state_ptr,
# APC pointers
block_idx_first_ptr,
block_idx_last_ptr,
initial_state_idx_ptr,
cu_chunk_seqlen_ptr,
last_chunk_indices_ptr,
# Dimensions
batch: tl.int32,
dim: tl.int32,
seqlen: tl.int32,
dstate: tl.int32,
n_groups: tl.int32,
dim_ngroups_ratio: tl.int32,
# Strides for u (and out, since out = delta which has same layout)
u_batch_stride: tl.int64,
u_d_stride: tl.int64,
# Strides for delta
delta_batch_stride: tl.int64,
delta_d_stride: tl.int64,
# Strides for A
A_d_stride: tl.int64,
A_dstate_stride: tl.int64,
# Strides for B
B_batch_stride: tl.int64,
B_group_stride: tl.int64,
B_dstate_stride: tl.int64,
# Strides for C
C_batch_stride: tl.int64,
C_group_stride: tl.int64,
C_dstate_stride: tl.int64,
# Strides for z
z_batch_stride: tl.int64,
z_d_stride: tl.int64,
# Strides for out
out_batch_stride: tl.int64,
out_d_stride: tl.int64,
# Strides for out_z
out_z_batch_stride: tl.int64,
out_z_d_stride: tl.int64,
# Strides for ssm_states
ssm_batch_stride: tl.int64,
ssm_dim_stride: tl.int64,
ssm_dstate_stride: tl.int64,
# Cache strides
cache_indices_stride: tl.int64,
# Scalar params
null_block_id: tl.int64,
block_size: tl.int32,
# Compile-time constants
delta_softplus: tl.constexpr,
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_DELTA_BIAS: tl.constexpr,
IS_VARLEN: tl.constexpr,
HAS_CACHE_INDICES: tl.constexpr,
CACHE_ENABLED: tl.constexpr,
BLOCK_DSTATE: tl.constexpr,
):
batch_idx = tl.program_id(0)
dim_idx = tl.program_id(1)
group_idx = dim_idx // dim_ngroups_ratio
# Determine sequence boundaries
if IS_VARLEN:
seq_start = tl.load(query_start_loc_ptr + batch_idx).to(tl.int32)
seq_end = tl.load(query_start_loc_ptr + batch_idx + 1).to(tl.int32)
actual_seqlen = seq_end - seq_start
else:
seq_start = 0
actual_seqlen = seqlen
# Determine cache index for ssm_states
if CACHE_ENABLED:
init_state_idx = tl.load(initial_state_idx_ptr + batch_idx).to(tl.int32)
load_cache_slot = tl.load(
cache_indices_ptr + batch_idx * cache_indices_stride + init_state_idx
).to(tl.int64)
if load_cache_slot == null_block_id:
return
elif HAS_CACHE_INDICES:
cache_index = tl.load(cache_indices_ptr + batch_idx).to(tl.int64)
if cache_index == null_block_id:
return
load_cache_slot = cache_index
else:
load_cache_slot = batch_idx.to(tl.int64)
# Load D value
D_val = 0.0
if HAS_D:
D_val = tl.load(D_ptr + dim_idx).to(tl.float32)
# Load delta_bias value
delta_bias_val = 0.0
if HAS_DELTA_BIAS:
delta_bias_val = tl.load(delta_bias_ptr + dim_idx).to(tl.float32)
# Load A values for this dim - shape (dstate,)
dstate_offs = tl.arange(0, BLOCK_DSTATE)
dstate_mask = dstate_offs < dstate
A_vals = tl.load(
A_ptr + dim_idx * A_d_stride + dstate_offs * A_dstate_stride,
mask=dstate_mask,
other=0.0,
).to(tl.float32)
# Initialize state vector
state = tl.zeros((BLOCK_DSTATE,), dtype=tl.float32)
# Load initial state if available
has_init = False
if has_initial_state_ptr is not None:
has_init = tl.load(has_initial_state_ptr + batch_idx)
if has_init:
state = tl.load(
ssm_states_ptr
+ load_cache_slot * ssm_batch_stride
+ dim_idx * ssm_dim_stride
+ dstate_offs * ssm_dstate_stride,
mask=dstate_mask,
other=0.0,
).to(tl.float32)
# Compute base addresses for u and delta
if IS_VARLEN:
u_base = u_ptr + dim_idx * u_d_stride + seq_start * u_batch_stride
delta_base = (
delta_ptr + dim_idx * delta_d_stride + seq_start * delta_batch_stride
)
out_base = out_ptr + dim_idx * out_d_stride + seq_start * out_batch_stride
B_base = B_ptr + group_idx * B_group_stride + seq_start * B_batch_stride
C_base = C_ptr + group_idx * C_group_stride + seq_start * C_batch_stride
else:
u_base = u_ptr + batch_idx * u_batch_stride + dim_idx * u_d_stride
delta_base = (
delta_ptr + batch_idx * delta_batch_stride + dim_idx * delta_d_stride
)
out_base = out_ptr + batch_idx * out_batch_stride + dim_idx * out_d_stride
B_base = B_ptr + batch_idx * B_batch_stride + group_idx * B_group_stride
C_base = C_ptr + batch_idx * C_batch_stride + group_idx * C_group_stride
if HAS_Z:
if IS_VARLEN:
z_base = z_ptr + dim_idx * z_d_stride + seq_start * z_batch_stride
out_z_base = (
out_z_ptr + dim_idx * out_z_d_stride + seq_start * out_z_batch_stride
)
else:
z_base = z_ptr + batch_idx * z_batch_stride + dim_idx * z_d_stride
out_z_base = (
out_z_ptr + batch_idx * out_z_batch_stride + dim_idx * out_z_d_stride
)
# Determine chunk boundaries for APC mode
if CACHE_ENABLED:
last_chunk_idx = tl.load(last_chunk_indices_ptr + batch_idx).to(tl.int32)
if batch_idx == 0:
first_chunk_idx = 0
else:
first_chunk_idx = (
tl.load(last_chunk_indices_ptr + batch_idx - 1).to(tl.int32) + 1
)
n_chunks = last_chunk_idx - first_chunk_idx + 1
first_chunk_tokens = tl.load(cu_chunk_seqlen_ptr + first_chunk_idx + 1).to(
tl.int32
) - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx).to(tl.int32)
block_idx_first = tl.load(block_idx_first_ptr + batch_idx).to(tl.int32)
chunk_start_offset = 0
if n_chunks > 1 and first_chunk_tokens < block_size:
chunk_start_offset = block_size - first_chunk_tokens
current_position = block_idx_first * block_size + chunk_start_offset
else:
n_chunks = 1
first_chunk_idx = 0
# Sequential scan over the sequence
tokens_processed = 0
for chunk in range(0, n_chunks if CACHE_ENABLED else 1):
if CACHE_ENABLED:
chunk_tokens = tl.load(
cu_chunk_seqlen_ptr + first_chunk_idx + chunk + 1
).to(tl.int32) - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx + chunk).to(
tl.int32
)
else:
chunk_tokens = actual_seqlen
for local_pos in range(chunk_tokens):
pos = tokens_processed + local_pos
# Load u value
u_val = tl.load(u_base + pos).to(tl.float32)
# Load delta value
delta_val = tl.load(delta_base + pos).to(tl.float32)
# Apply delta bias
if HAS_DELTA_BIAS:
delta_val = delta_val + delta_bias_val
# Apply softplus
if delta_softplus:
delta_val = _softplus(delta_val)
delta_u = delta_val * u_val
# Compute dA = exp(delta * A) for all dstate elements
dA = tl.exp(delta_val * A_vals)
# Load B values for this position
B_vals = tl.load(
B_base + dstate_offs * B_dstate_stride + pos,
mask=dstate_mask,
other=0.0,
).to(tl.float32)
# Load C values for this position
C_vals = tl.load(
C_base + dstate_offs * C_dstate_stride + pos,
mask=dstate_mask,
other=0.0,
).to(tl.float32)
# Update state: state = dA * state + delta * u * B
state = dA * state + delta_u * B_vals
# Compute output: out = sum(state * C) + D * u
out_val = tl.sum(state * C_vals, axis=0)
if HAS_D:
out_val = out_val + D_val * u_val
# Store output
tl.store(out_base + pos, out_val.to(out_ptr.dtype.element_ty))
if HAS_Z:
z_val = tl.load(z_base + pos).to(tl.float32)
out_z_val = out_val * z_val / (1.0 + tl.exp(-z_val))
tl.store(
out_z_base + pos,
out_z_val.to(out_z_ptr.dtype.element_ty),
)
tokens_processed += chunk_tokens
# Store intermediate state for APC mode
if CACHE_ENABLED:
if chunk == n_chunks - 1:
store_slot = tl.load(
cache_indices_ptr
+ batch_idx * cache_indices_stride
+ tl.load(block_idx_last_ptr + batch_idx).to(tl.int32)
).to(tl.int64)
else:
block_idx_done = (current_position + chunk_tokens - 1) // block_size
store_slot = tl.load(
cache_indices_ptr
+ batch_idx * cache_indices_stride
+ block_idx_done
).to(tl.int64)
tl.store(
ssm_states_ptr
+ store_slot * ssm_batch_stride
+ dim_idx * ssm_dim_stride
+ dstate_offs * ssm_dstate_stride,
state.to(ssm_states_ptr.dtype.element_ty),
mask=dstate_mask,
)
current_position += chunk_tokens
# Store final state for non-APC mode
if not CACHE_ENABLED:
tl.store(
ssm_states_ptr
+ load_cache_slot * ssm_batch_stride
+ dim_idx * ssm_dim_stride
+ dstate_offs * ssm_dstate_stride,
state.to(ssm_states_ptr.dtype.element_ty),
mask=dstate_mask,
)
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
@@ -549,6 +856,173 @@ class xpu_ops:
)
return None
@staticmethod
def selective_scan_fwd(
u: torch.Tensor,
delta: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
D_: torch.Tensor | None,
z_: torch.Tensor | None,
delta_bias_: torch.Tensor | None,
delta_softplus: bool,
query_start_loc: torch.Tensor | None,
cache_indices: torch.Tensor | None,
has_initial_state: torch.Tensor | None,
ssm_states: torch.Tensor,
null_block_id: int,
block_size: int = 1024,
block_idx_first_scheduled_token: torch.Tensor | None = None,
block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None,
cu_chunk_seqlen: torch.Tensor | None = None,
last_chunk_indices: torch.Tensor | None = None,
) -> None:
varlen = query_start_loc is not None
batch_size = (
(query_start_loc.shape[0] - 1)
if query_start_loc is not None
else u.shape[0]
)
dim = u.shape[0] if varlen else u.shape[1]
total_seqlen = u.shape[1] if varlen else u.shape[2]
dstate = A.size(1)
n_groups = B.size(0) if varlen else B.size(1)
dim_ngroups_ratio = dim // n_groups
has_z = z_ is not None
has_D = D_ is not None
has_delta_bias = delta_bias_ is not None
has_cache_indices = cache_indices is not None
cache_enabled = block_idx_first_scheduled_token is not None
# out and out_z alias delta and z respectively
out = delta
out_z = z_ if z_ is not None else delta # won't be used if not has_z
BLOCK_DSTATE = triton.next_power_of_2(dstate)
# Compute strides
if varlen:
u_batch_stride = u.stride(1)
u_d_stride = u.stride(0)
delta_batch_stride = delta.stride(1)
delta_d_stride = delta.stride(0)
B_batch_stride = B.stride(2)
B_group_stride = B.stride(0)
B_dstate_stride = B.stride(1)
C_batch_stride = C.stride(2)
C_group_stride = C.stride(0)
C_dstate_stride = C.stride(1)
out_batch_stride = out.stride(1)
out_d_stride = out.stride(0)
if z_ is not None:
z_batch_stride = z_.stride(1)
z_d_stride = z_.stride(0)
out_z_batch_stride = out_z.stride(1)
out_z_d_stride = out_z.stride(0)
else:
z_batch_stride = 0
z_d_stride = 0
out_z_batch_stride = 0
out_z_d_stride = 0
else:
u_batch_stride = u.stride(0)
u_d_stride = u.stride(1)
delta_batch_stride = delta.stride(0)
delta_d_stride = delta.stride(1)
B_batch_stride = B.stride(0)
B_group_stride = B.stride(1)
B_dstate_stride = B.stride(2)
C_batch_stride = C.stride(0)
C_group_stride = C.stride(1)
C_dstate_stride = C.stride(2)
out_batch_stride = out.stride(0)
out_d_stride = out.stride(1)
if z_ is not None:
z_batch_stride = z_.stride(0)
z_d_stride = z_.stride(1)
out_z_batch_stride = out_z.stride(0)
out_z_d_stride = out_z.stride(1)
else:
z_batch_stride = 0
z_d_stride = 0
out_z_batch_stride = 0
out_z_d_stride = 0
ssm_batch_stride = ssm_states.stride(0)
ssm_dim_stride = ssm_states.stride(1)
ssm_dstate_stride = ssm_states.stride(2)
cache_indices_stride = (
cache_indices.stride(0) if cache_indices is not None else 0
)
grid = (batch_size, dim)
_selective_scan_fwd_kernel[grid](
u,
delta,
A,
B,
C,
D_ if has_D else u, # dummy, won't be dereferenced
z_ if has_z else u, # dummy
delta_bias_ if has_delta_bias else u, # dummy
out,
out_z,
ssm_states,
query_start_loc if varlen else u, # dummy
cache_indices if has_cache_indices else u, # dummy
has_initial_state,
# APC pointers
block_idx_first_scheduled_token if cache_enabled else u,
block_idx_last_scheduled_token if cache_enabled else u,
initial_state_idx if cache_enabled else u,
cu_chunk_seqlen if cache_enabled else u,
last_chunk_indices if cache_enabled else u,
# Dimensions
batch_size,
dim,
total_seqlen,
dstate,
n_groups,
dim_ngroups_ratio,
# Strides
u_batch_stride,
u_d_stride,
delta_batch_stride,
delta_d_stride,
A.stride(0),
A.stride(1),
B_batch_stride,
B_group_stride,
B_dstate_stride,
C_batch_stride,
C_group_stride,
C_dstate_stride,
z_batch_stride,
z_d_stride,
out_batch_stride,
out_d_stride,
out_z_batch_stride,
out_z_d_stride,
ssm_batch_stride,
ssm_dim_stride,
ssm_dstate_stride,
cache_indices_stride,
null_block_id,
block_size,
# Compile-time constants
delta_softplus=delta_softplus,
HAS_D=has_D,
HAS_Z=has_z,
HAS_DELTA_BIAS=has_delta_bias,
IS_VARLEN=varlen,
HAS_CACHE_INDICES=has_cache_indices,
CACHE_ENABLED=cache_enabled,
BLOCK_DSTATE=BLOCK_DSTATE,
)
@staticmethod
def register_ops_once() -> None:
global _OPS_REGISTERED
@@ -21,6 +21,9 @@ from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
if current_platform.is_xpu():
from vllm._xpu_ops import xpu_ops
logger = init_logger(__name__)
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
@@ -790,28 +793,52 @@ def selective_scan_fn(
if C.dim() == 2 and query_start_loc is not None:
C = C.unsqueeze(0)
ops.selective_scan_fwd(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
query_start_loc,
cache_indices,
has_initial_state,
ssm_states,
null_block_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices,
)
if current_platform.is_xpu():
xpu_ops.selective_scan_fwd(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
query_start_loc,
cache_indices,
has_initial_state,
ssm_states,
null_block_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices,
)
else:
ops.selective_scan_fwd(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
query_start_loc,
cache_indices,
has_initial_state,
ssm_states,
null_block_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices,
)
if z is None:
return delta # output written inplace to delta