[Attention][Perf][Kernel] Replace torch.cat with vectorized CUDA kernel MLA query concat - DeepSeek-V3.2 (#34917)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
This commit is contained in:
Roberto L. Castro
2026-03-09 17:50:36 +01:00
committed by GitHub
parent 2b28b9b269
commit 580864d81e
10 changed files with 415 additions and 15 deletions
+2 -1
View File
@@ -8,8 +8,9 @@ steps:
- csrc/
- tests/kernels/core
- tests/kernels/test_top_k_per_row.py
- tests/kernels/test_concat_mla_q.py
commands:
- pytest -v -s kernels/core kernels/test_top_k_per_row.py
- pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py
- label: Kernels Attention Test %N
timeout_in_minutes: 35
+98
View File
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import torch
from vllm import _custom_ops as ops
from vllm.triton_utils import triton
# DeepSeek V3 dimensions
NOPE_DIM = 512
ROPE_DIM = 64
NUM_HEADS = 128
NUM_TOKENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
def get_configs():
return NUM_TOKENS
def make_inputs(num_tokens, dtype):
"""Create inputs matching the real code path.
Args:
contiguous_nope: If False, simulate the transposed BMM output
(non-contiguous nope with stride pattern from
[N,B,L].transpose(0,1)).
"""
# Simulate: bmm output [N, B, L].transpose(0, 1) -> [B, N, L]
raw = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, dtype=dtype, device="cuda")
ql_nope = raw.transpose(0, 1)
q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, dtype=dtype, device="cuda")
return ql_nope, q_pe
# ---- Non-contiguous nope benchmark (real code path) ----
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=get_configs(),
line_arg="provider",
line_vals=["torch_cat", "concat_mla_q"],
line_names=["torch.cat", "concat_mla_q (v8)"],
styles=[("blue", "--"), ("green", "-")],
ylabel="Latency (us)",
plot_name="concat_mla_q-transposed",
args={},
)
)
def bench_transposed(num_tokens, provider):
dtype = torch.bfloat16
ql_nope, q_pe = make_inputs(num_tokens, dtype)
q_out = torch.empty(
num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, dtype=dtype, device="cuda"
)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch_cat":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.cat((ql_nope, q_pe), dim=-1), quantiles=quantiles, rep=500
)
else:
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: ops.concat_mla_q(ql_nope, q_pe, q_out), quantiles=quantiles, rep=500
)
return ms * 1000, max_ms * 1000, min_ms * 1000 # us
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark concat_mla_q vs torch.cat")
parser.add_argument(
"--save-path", type=str, default=None, help="Path to save benchmark results"
)
args = parser.parse_args()
print("\n" + "=" * 70)
print("CONCAT MLA Q KERNEL BENCHMARKS")
print("=" * 70)
print(f"Dimensions: nope={NOPE_DIM}, rope={ROPE_DIM}, heads={NUM_HEADS}")
print(
f"Per-head output: {NOPE_DIM + ROPE_DIM} bf16 = "
f"{(NOPE_DIM + ROPE_DIM) * 2} bytes"
)
print(f"num_tokens (decode=batch_size, prefill=chunk_size): {NUM_TOKENS}")
print("=" * 70)
print("\n--- Non-contiguous nope inputs (transposed BMM output) ---")
bench_transposed.run(print_data=True, save_path=args.save_path)
print("\n" + "=" * 70)
print("Benchmarking complete!")
print("=" * 70)
+6
View File
@@ -74,6 +74,12 @@ void indexer_k_quant_and_cache(
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
// Concatenate query nope and rope for MLA/DSA attention
void concat_mla_q(
torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::Tensor& q_out); // [num_tokens, num_heads, nope_dim + rope_dim]
// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
+41
View File
@@ -8,6 +8,7 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#include "concat_mla_q.cuh"
#ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
@@ -1358,3 +1359,43 @@ void cp_gather_indexer_k_quant_cache(
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
}
}
// Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA.
// Replaces torch.cat((ql_nope, q_pe), dim=-1).
void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::Tensor& q_out // [num_tokens, num_heads, nope_dim +
// rope_dim]
) {
const int num_tokens = ql_nope.size(0);
const int num_heads = ql_nope.size(1);
const int nope_dim = ql_nope.size(2);
const int rope_dim = q_pe.size(2);
TORCH_CHECK(nope_dim % 512 == 0, "nope_dim must be a multiple of 512, got ",
nope_dim);
TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim);
TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim);
TORCH_CHECK(ql_nope.stride(2) == 1, "ql_nope must have stride 1 in dim 2");
TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2");
TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2");
if (num_tokens == 0) return;
constexpr int warps_per_block = 8;
const int total_warps = num_tokens * num_heads;
const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block;
const int block_size = warps_per_block * 32;
const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] {
vllm::ConcatMLAQKernel<scalar_t, 512><<<grid_size, block_size, 0, stream>>>(
q_out.data_ptr<scalar_t>(), ql_nope.data_ptr<scalar_t>(),
q_pe.data_ptr<scalar_t>(), num_tokens, num_heads, q_out.stride(0),
q_out.stride(1), ql_nope.stride(0), ql_nope.stride(1), q_pe.stride(0),
q_pe.stride(1));
});
}
+60
View File
@@ -0,0 +1,60 @@
#ifndef CONCAT_MLA_Q_CUH_
#define CONCAT_MLA_Q_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cuda_vec_utils.cuh"
namespace vllm {
// Concatenates ql_nope [num_tokens, num_heads, NOPE_DIM] and
// q_pe [num_tokens, num_heads, 64]
// into q_out [num_tokens, num_heads, NOPE_DIM+64].
// Currently instantiated only for NOPE_DIM=512.
// Rope dim is hardcoded to 64 (DeepSeek V3.2 MLA)
template <typename DType, int NOPE_DIM>
__global__ void ConcatMLAQKernel(
DType* __restrict__ q_out, const DType* __restrict__ ql_nope,
const DType* __restrict__ q_pe, const int num_tokens, const int num_heads,
const int64_t out_stride_0, const int64_t out_stride_1,
const int64_t nope_stride_0, const int64_t nope_stride_1,
const int64_t pe_stride_0, const int64_t pe_stride_1) {
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
if (flat_warp_id >= num_tokens * num_heads) return;
const int token_id = flat_warp_id / num_heads;
const int head_id = flat_warp_id % num_heads;
const int lane_id = threadIdx.x & 31;
constexpr bool use_256b = VLLM_256B_PTX_ENABLED;
constexpr int nope_vec_loads =
NOPE_DIM * sizeof(DType) / (VecTraits<use_256b>::ARCH_MAX_VEC_SIZE * 32);
const DType* nope_src =
ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1;
DType* nope_dst = q_out + token_id * out_stride_0 + head_id * out_stride_1;
#pragma unroll
for (int i = 0; i < nope_vec_loads; i++) {
const int offset = i * 32 + lane_id;
if constexpr (use_256b) {
st256_cs(reinterpret_cast<u32x8_t*>(nope_dst) + offset,
ld256_cs(reinterpret_cast<const u32x8_t*>(nope_src) + offset));
} else {
st128_cs(reinterpret_cast<int4*>(nope_dst) + offset,
ld128_cs(reinterpret_cast<const int4*>(nope_src) + offset));
}
}
const int* rope_src = reinterpret_cast<const int*>(
q_pe + token_id * pe_stride_0 + head_id * pe_stride_1);
int* rope_dst = reinterpret_cast<int*>(q_out + token_id * out_stride_0 +
head_id * out_stride_1 + NOPE_DIM);
st32_cs(rope_dst + lane_id, ld32_cs(rope_src + lane_id));
}
} // namespace vllm
#endif // CONCAT_MLA_Q_CUH_
+37 -10
View File
@@ -196,7 +196,6 @@ __forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) {
return val;
#else
assert(false && "ld256_cs requires SM100+ with CUDA 12.9+");
return {};
#endif
}
@@ -211,23 +210,51 @@ __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) {
#endif
}
// 32-bit cache-streaming (.cs) load / store — SM100+ only.
// 32-bit load / store.
__device__ __forceinline__ int ld32(const int* addr) { return __ldg(addr); }
__device__ __forceinline__ void st32(int* addr, int val) { *addr = val; }
// 32-bit cache-streaming (.cs) load / store.
// Falls back to ld32/st32 on ROCm (no .cs hint).
__forceinline__ __device__ int ld32_cs(const int* addr) {
#if VLLM_256B_PTX_ENABLED
int val;
#ifndef USE_ROCM
asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr));
return val;
#else
assert(false && "ld32_cs requires SM100+ with CUDA 12.9+");
return 0;
val = ld32(addr);
#endif
return val;
}
__forceinline__ __device__ void st32_cs(int* addr, int val) {
#if VLLM_256B_PTX_ENABLED
#ifndef USE_ROCM
asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
#else
assert(false && "st32_cs requires SM100+ with CUDA 12.9+");
st32(addr, val);
#endif
}
// 128-bit cache-streaming (.cs) load / store.
// Falls back to ld128/st128 on ROCm (no .cs hint).
__forceinline__ __device__ int4 ld128_cs(const int4* addr) {
int4 val;
#ifndef USE_ROCM
asm volatile("ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(addr));
#else
ld128(val, addr);
#endif
return val;
}
__forceinline__ __device__ void st128_cs(int4* addr, int4 val) {
#ifndef USE_ROCM
asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr),
"r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
#else
st128(val, addr);
#endif
}
@@ -260,7 +287,7 @@ __device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr,
__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,
bool pred) {
#if VLLM_256B_PTX_ENABLED
#ifndef USE_ROCM
uint32_t r0, r1, r2, r3;
asm volatile(
@@ -278,7 +305,7 @@ __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,
val = uint4{r0, r1, r2, r3};
#else
assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+");
assert(false && "ld128_cg_or_zero is not supported on ROCm");
#endif
}
+4
View File
@@ -802,6 +802,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);
cache_ops.def(
"concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q);
cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
+139
View File
@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm import _custom_ops as ops
NUM_TOKENS = [1, 4, 16, 64, 128]
NUM_HEADS = [128]
NOPE_DIM = [512]
ROPE_DIM = [64]
DTYPES = [torch.bfloat16, torch.float16]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("nope_dim", NOPE_DIM)
@pytest.mark.parametrize("rope_dim", ROPE_DIM)
@pytest.mark.parametrize("dtype", DTYPES)
def test_concat_mla_q_contiguous(num_tokens, num_heads, nope_dim, rope_dim, dtype):
"""Test with contiguous inputs (standard layout)."""
torch.manual_seed(42)
ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, device="cuda")
q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, device="cuda")
ref = torch.cat((ql_nope, q_pe), dim=-1)
q_out = torch.empty(
num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda"
)
ops.concat_mla_q(ql_nope, q_pe, q_out)
torch.testing.assert_close(q_out, ref, atol=0, rtol=0)
@pytest.mark.parametrize("num_tokens", [t for t in NUM_TOKENS if t > 1])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("nope_dim", NOPE_DIM)
@pytest.mark.parametrize("rope_dim", ROPE_DIM)
@pytest.mark.parametrize("dtype", DTYPES)
def test_concat_mla_q_transposed_nope(num_tokens, num_heads, nope_dim, rope_dim, dtype):
"""Test with transposed nope input (simulates BMM output after transpose).
In the real code path, mqa_ql_nope is the result of:
torch.bmm(q_nope, W_UK_T) # [N, B, L]
.transpose(0, 1) # [B, N, L] — non-contiguous!
"""
torch.manual_seed(42)
nope_raw = torch.randn(num_heads, num_tokens, nope_dim, dtype=dtype, device="cuda")
ql_nope = nope_raw.transpose(0, 1) # [B, N, L], non-contiguous
assert not ql_nope.is_contiguous()
q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, device="cuda")
ref = torch.cat((ql_nope, q_pe), dim=-1)
q_out = torch.empty(
num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda"
)
ops.concat_mla_q(ql_nope, q_pe, q_out)
torch.testing.assert_close(q_out, ref, atol=0, rtol=0)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_concat_mla_q_split_rope(num_tokens, num_heads, dtype):
"""Test with rope from a split (simulates the actual code path).
In the real code path, q_pe comes from:
mqa_q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
which creates a non-contiguous view with stride(1) != rope_dim.
"""
torch.manual_seed(42)
nope_dim = 512
rope_dim = 64
orig_dim = 128 + 64 # original q before absorption: [B, N, 192]
# Simulate split from original q tensor
q_orig = torch.randn(num_tokens, num_heads, orig_dim, dtype=dtype, device="cuda")
q_nope_orig, q_pe = q_orig.split([128, 64], dim=-1)
# q_pe is non-contiguous: stride(1) = 192, not 64
assert q_pe.stride(1) == orig_dim
assert q_pe.stride(2) == 1 # but innermost is fine
# Simulate absorbed nope (contiguous, different size)
ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, device="cuda")
ref = torch.cat((ql_nope, q_pe), dim=-1)
q_out = torch.empty(
num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda"
)
ops.concat_mla_q(ql_nope, q_pe, q_out)
torch.testing.assert_close(q_out, ref, atol=0, rtol=0)
def test_concat_mla_q_zero_tokens():
"""Test with zero tokens (edge case)."""
ql_nope = torch.empty(0, 128, 512, dtype=torch.bfloat16, device="cuda")
q_pe = torch.empty(0, 128, 64, dtype=torch.bfloat16, device="cuda")
q_out = torch.empty(0, 128, 576, dtype=torch.bfloat16, device="cuda")
ops.concat_mla_q(ql_nope, q_pe, q_out)
@pytest.mark.parametrize("num_tokens", [1, 64])
def test_concat_mla_q_values_preserved(num_tokens):
"""Verify exact bit-level preservation (no computation, pure copy).
Compares raw int16 bits to avoid NaN != NaN issues from IEEE 754.
"""
nope_dim, rope_dim = 512, 64
# Use specific bit patterns (stay in int16 for bit-exact comparison)
ql_nope_bits = torch.arange(
num_tokens * 128 * nope_dim, dtype=torch.int16, device="cuda"
).view(num_tokens, 128, nope_dim)
q_pe_bits = torch.arange(
num_tokens * 128 * rope_dim, dtype=torch.int16, device="cuda"
).view(num_tokens, 128, rope_dim)
ql_nope = ql_nope_bits.view(torch.bfloat16)
q_pe = q_pe_bits.view(torch.bfloat16)
q_out = torch.empty(
num_tokens, 128, nope_dim + rope_dim, dtype=torch.bfloat16, device="cuda"
)
ops.concat_mla_q(ql_nope, q_pe, q_out)
out_bits = q_out.view(torch.int16)
assert torch.equal(out_bits[..., :nope_dim], ql_nope_bits)
assert torch.equal(out_bits[..., nope_dim:], q_pe_bits)
+15
View File
@@ -2672,6 +2672,21 @@ def cp_gather_and_upconvert_fp8_kv_cache(
)
def concat_mla_q(
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
q_out: torch.Tensor,
) -> None:
"""Concatenate query nope and rope for MLA/DSA attention.
Args:
ql_nope: Query nope component [num_tokens, num_heads, nope_dim]
q_pe: Query rope component [num_tokens, num_heads, rope_dim]
q_out: Output tensor [num_tokens, num_heads, nope_dim + rope_dim]
"""
torch.ops._C_cache_ops.concat_mla_q(ql_nope, q_pe, q_out)
def indexer_k_quant_and_cache(
k: torch.Tensor,
kv_cache: torch.Tensor,
@@ -568,6 +568,9 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
)
self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
vllm_config = get_current_vllm_config()
max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
q_concat_shape = (max_tokens, num_heads, head_size)
if kv_cache_dtype.startswith("fp8"):
assert kv_cache_dtype == "fp8_ds_mla", (
"FlashMLA Sparse Attention backend fp8 only supports "
@@ -576,17 +579,21 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
if kv_cache_dtype == "fp8_ds_mla":
# Reserve workspace during initialization
vllm_config = get_current_vllm_config()
assert vllm_config is not None and vllm_config.model_config is not None
prefill_workspace_size = get_prefill_workspace_size(
vllm_config.model_config.max_model_len
)
self.prefill_workspace_shape = (prefill_workspace_size, head_size)
(self.prefill_bf16_workspace,) = (
self.q_concat_buffer, self.prefill_bf16_workspace = (
current_workspace_manager().get_simultaneous(
(self.prefill_workspace_shape, torch.bfloat16)
(q_concat_shape, torch.bfloat16),
(self.prefill_workspace_shape, torch.bfloat16),
)
)
else:
(self.q_concat_buffer,) = current_workspace_manager().get_simultaneous(
(q_concat_shape, torch.bfloat16),
)
def _forward_bf16_kv(
self,
@@ -828,7 +835,9 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
# Concatenate q if it's a tuple (ql_nope, q_pe)
if isinstance(q, tuple):
q = torch.cat(q, dim=-1)
ql_nope, q_pe = q
q = self.q_concat_buffer[: ql_nope.shape[0]]
ops.concat_mla_q(ql_nope, q_pe, q)
num_actual_toks = q.shape[0]