mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Attention][Perf][Kernel] Replace torch.cat with vectorized CUDA kernel MLA query concat - DeepSeek-V3.2 (#34917)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
2b28b9b269
commit
580864d81e
@@ -8,8 +8,9 @@ steps:
|
||||
- csrc/
|
||||
- tests/kernels/core
|
||||
- tests/kernels/test_top_k_per_row.py
|
||||
- tests/kernels/test_concat_mla_q.py
|
||||
commands:
|
||||
- pytest -v -s kernels/core kernels/test_top_k_per_row.py
|
||||
- pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py
|
||||
|
||||
- label: Kernels Attention Test %N
|
||||
timeout_in_minutes: 35
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
# DeepSeek V3 dimensions
|
||||
NOPE_DIM = 512
|
||||
ROPE_DIM = 64
|
||||
NUM_HEADS = 128
|
||||
|
||||
NUM_TOKENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
|
||||
|
||||
def get_configs():
    """Return the token counts swept along the benchmark's x-axis."""
    token_counts = NUM_TOKENS
    return token_counts
|
||||
|
||||
|
||||
def make_inputs(num_tokens, dtype):
    """Build the (ql_nope, q_pe) pair fed to the benchmarked concat.

    ``ql_nope`` is allocated as ``[NUM_HEADS, num_tokens, NOPE_DIM]`` and then
    returned as its ``transpose(0, 1)`` view, so it carries the stride pattern
    of a transposed BMM output (``[N, B, L].transpose(0, 1) -> [B, N, L]``)
    rather than that of a freshly allocated ``[B, N, L]`` tensor.
    ``q_pe`` is a plain ``[num_tokens, NUM_HEADS, ROPE_DIM]`` tensor.

    Args:
        num_tokens: Size of the token dimension of both tensors.
        dtype: Torch dtype used for both tensors.

    Returns:
        Tuple ``(ql_nope, q_pe)`` of CUDA tensors.
    """
    tensor_opts = {"dtype": dtype, "device": "cuda"}

    # Allocation order is kept (nope first, then rope) so the RNG stream is
    # consumed exactly as before.
    nope_heads_first = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, **tensor_opts)
    q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, **tensor_opts)

    return nope_heads_first.transpose(0, 1), q_pe
|
||||
|
||||
|
||||
# ---- Non-contiguous nope benchmark (real code path) ----
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=get_configs(),
        line_arg="provider",
        line_vals=["torch_cat", "concat_mla_q"],
        line_names=["torch.cat", "concat_mla_q (v8)"],
        styles=[("blue", "--"), ("green", "-")],
        ylabel="Latency (us)",
        plot_name="concat_mla_q-transposed",
        args={},
    )
)
def bench_transposed(num_tokens, provider):
    """Time one provider at one token count using CUDA-graph benchmarking.

    Args:
        num_tokens: Token-dimension size passed to ``make_inputs``.
        provider: ``"torch_cat"`` times ``torch.cat((ql_nope, q_pe), dim=-1)``;
            any other value times ``ops.concat_mla_q`` writing into a
            preallocated output buffer.

    Returns:
        ``(median, p20, p80)`` latency in microseconds. The order is
        ``(value, min, max)``, which is how ``perf_report`` unpacks the
        result of the benchmarked function.
    """
    dtype = torch.bfloat16
    ql_nope, q_pe = make_inputs(num_tokens, dtype)

    # Only the custom kernel writes into this buffer; torch.cat allocates its
    # own output on every call.
    q_out = torch.empty(
        num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, dtype=dtype, device="cuda"
    )

    # Requested quantiles: median first, then the lower and upper bounds.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch_cat":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.cat((ql_nope, q_pe), dim=-1), quantiles=quantiles, rep=500
        )
    else:
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: ops.concat_mla_q(ql_nope, q_pe, q_out), quantiles=quantiles, rep=500
        )

    # Latency is reported directly (not inverted into a throughput), so the
    # lower quantile is the minimum and must come before the upper quantile.
    return ms * 1000, min_ms * 1000, max_ms * 1000  # us
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI: the single option is forwarded to perf_report's ``save_path``.
    parser = argparse.ArgumentParser(description="Benchmark concat_mla_q vs torch.cat")
    parser.add_argument(
        "--save-path", type=str, default=None, help="Path to save benchmark results"
    )
    args = parser.parse_args()

    rule = "=" * 70

    # Banner describing the fixed problem dimensions and the swept sizes.
    banner = [
        "\n" + rule,
        "CONCAT MLA Q KERNEL BENCHMARKS",
        rule,
        f"Dimensions: nope={NOPE_DIM}, rope={ROPE_DIM}, heads={NUM_HEADS}",
        f"Per-head output: {NOPE_DIM + ROPE_DIM} bf16 = "
        f"{(NOPE_DIM + ROPE_DIM) * 2} bytes",
        f"num_tokens (decode=batch_size, prefill=chunk_size): {NUM_TOKENS}",
        rule,
    ]
    for banner_line in banner:
        print(banner_line)

    print("\n--- Non-contiguous nope inputs (transposed BMM output) ---")
    bench_transposed.run(print_data=True, save_path=args.save_path)

    for footer_line in ("\n" + rule, "Benchmarking complete!", rule):
        print(footer_line)
|
||||
@@ -74,6 +74,12 @@ void indexer_k_quant_and_cache(
|
||||
int64_t quant_block_size, // quantization block size
|
||||
const std::string& scale_fmt);
|
||||
|
||||
// Concatenate query nope and rope for MLA/DSA attention
|
||||
void concat_mla_q(
|
||||
torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
|
||||
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
|
||||
torch::Tensor& q_out); // [num_tokens, num_heads, nope_dim + rope_dim]
|
||||
|
||||
// Extract function to gather quantized K cache
|
||||
void cp_gather_indexer_k_quant_cache(
|
||||
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "quantization/vectorization_utils.cuh"
|
||||
#include "concat_mla_q.cuh"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
|
||||
@@ -1358,3 +1359,43 @@ void cp_gather_indexer_k_quant_cache(
|
||||
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
|
||||
}
|
||||
}
|
||||
|
||||
// Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA.
|
||||
// Replaces torch.cat((ql_nope, q_pe), dim=-1).
|
||||
void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
|
||||
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
|
||||
torch::Tensor& q_out // [num_tokens, num_heads, nope_dim +
|
||||
// rope_dim]
|
||||
) {
|
||||
const int num_tokens = ql_nope.size(0);
|
||||
const int num_heads = ql_nope.size(1);
|
||||
const int nope_dim = ql_nope.size(2);
|
||||
const int rope_dim = q_pe.size(2);
|
||||
|
||||
TORCH_CHECK(nope_dim % 512 == 0, "nope_dim must be a multiple of 512, got ",
|
||||
nope_dim);
|
||||
TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim);
|
||||
TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim);
|
||||
|
||||
TORCH_CHECK(ql_nope.stride(2) == 1, "ql_nope must have stride 1 in dim 2");
|
||||
TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2");
|
||||
TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2");
|
||||
|
||||
if (num_tokens == 0) return;
|
||||
|
||||
constexpr int warps_per_block = 8;
|
||||
const int total_warps = num_tokens * num_heads;
|
||||
const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block;
|
||||
const int block_size = warps_per_block * 32;
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] {
|
||||
vllm::ConcatMLAQKernel<scalar_t, 512><<<grid_size, block_size, 0, stream>>>(
|
||||
q_out.data_ptr<scalar_t>(), ql_nope.data_ptr<scalar_t>(),
|
||||
q_pe.data_ptr<scalar_t>(), num_tokens, num_heads, q_out.stride(0),
|
||||
q_out.stride(1), ql_nope.stride(0), ql_nope.stride(1), q_pe.stride(0),
|
||||
q_pe.stride(1));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
#ifndef CONCAT_MLA_Q_CUH_
|
||||
#define CONCAT_MLA_Q_CUH_
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "cuda_vec_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Concatenates ql_nope [num_tokens, num_heads, NOPE_DIM] and
|
||||
// q_pe [num_tokens, num_heads, 64]
|
||||
// into q_out [num_tokens, num_heads, NOPE_DIM+64].
|
||||
// Currently instantiated only for NOPE_DIM=512.
|
||||
// Rope dim is hardcoded to 64 (DeepSeek V3.2 MLA)
|
||||
template <typename DType, int NOPE_DIM>
|
||||
__global__ void ConcatMLAQKernel(
|
||||
DType* __restrict__ q_out, const DType* __restrict__ ql_nope,
|
||||
const DType* __restrict__ q_pe, const int num_tokens, const int num_heads,
|
||||
const int64_t out_stride_0, const int64_t out_stride_1,
|
||||
const int64_t nope_stride_0, const int64_t nope_stride_1,
|
||||
const int64_t pe_stride_0, const int64_t pe_stride_1) {
|
||||
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
|
||||
if (flat_warp_id >= num_tokens * num_heads) return;
|
||||
|
||||
const int token_id = flat_warp_id / num_heads;
|
||||
const int head_id = flat_warp_id % num_heads;
|
||||
const int lane_id = threadIdx.x & 31;
|
||||
|
||||
constexpr bool use_256b = VLLM_256B_PTX_ENABLED;
|
||||
constexpr int nope_vec_loads =
|
||||
NOPE_DIM * sizeof(DType) / (VecTraits<use_256b>::ARCH_MAX_VEC_SIZE * 32);
|
||||
|
||||
const DType* nope_src =
|
||||
ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1;
|
||||
DType* nope_dst = q_out + token_id * out_stride_0 + head_id * out_stride_1;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < nope_vec_loads; i++) {
|
||||
const int offset = i * 32 + lane_id;
|
||||
if constexpr (use_256b) {
|
||||
st256_cs(reinterpret_cast<u32x8_t*>(nope_dst) + offset,
|
||||
ld256_cs(reinterpret_cast<const u32x8_t*>(nope_src) + offset));
|
||||
} else {
|
||||
st128_cs(reinterpret_cast<int4*>(nope_dst) + offset,
|
||||
ld128_cs(reinterpret_cast<const int4*>(nope_src) + offset));
|
||||
}
|
||||
}
|
||||
|
||||
const int* rope_src = reinterpret_cast<const int*>(
|
||||
q_pe + token_id * pe_stride_0 + head_id * pe_stride_1);
|
||||
int* rope_dst = reinterpret_cast<int*>(q_out + token_id * out_stride_0 +
|
||||
head_id * out_stride_1 + NOPE_DIM);
|
||||
|
||||
st32_cs(rope_dst + lane_id, ld32_cs(rope_src + lane_id));
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#endif // CONCAT_MLA_Q_CUH_
|
||||
+37
-10
@@ -196,7 +196,6 @@ __forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) {
|
||||
return val;
|
||||
#else
|
||||
assert(false && "ld256_cs requires SM100+ with CUDA 12.9+");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -211,23 +210,51 @@ __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 32-bit cache-streaming (.cs) load / store — SM100+ only.
|
||||
// 32-bit load / store.
|
||||
__device__ __forceinline__ int ld32(const int* addr) { return __ldg(addr); }
|
||||
|
||||
__device__ __forceinline__ void st32(int* addr, int val) { *addr = val; }
|
||||
|
||||
// 32-bit cache-streaming (.cs) load / store.
|
||||
// Falls back to ld32/st32 on ROCm (no .cs hint).
|
||||
__forceinline__ __device__ int ld32_cs(const int* addr) {
|
||||
#if VLLM_256B_PTX_ENABLED
|
||||
int val;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr));
|
||||
return val;
|
||||
#else
|
||||
assert(false && "ld32_cs requires SM100+ with CUDA 12.9+");
|
||||
return 0;
|
||||
val = ld32(addr);
|
||||
#endif
|
||||
return val;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void st32_cs(int* addr, int val) {
|
||||
#if VLLM_256B_PTX_ENABLED
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
|
||||
#else
|
||||
assert(false && "st32_cs requires SM100+ with CUDA 12.9+");
|
||||
st32(addr, val);
|
||||
#endif
|
||||
}
|
||||
|
||||
// 128-bit cache-streaming (.cs) load / store.
|
||||
// Falls back to ld128/st128 on ROCm (no .cs hint).
|
||||
__forceinline__ __device__ int4 ld128_cs(const int4* addr) {
|
||||
int4 val;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
|
||||
: "l"(addr));
|
||||
#else
|
||||
ld128(val, addr);
|
||||
#endif
|
||||
return val;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void st128_cs(int4* addr, int4 val) {
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr),
|
||||
"r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
|
||||
#else
|
||||
st128(val, addr);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -260,7 +287,7 @@ __device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr,
|
||||
|
||||
__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,
|
||||
bool pred) {
|
||||
#if VLLM_256B_PTX_ENABLED
|
||||
#ifndef USE_ROCM
|
||||
uint32_t r0, r1, r2, r3;
|
||||
|
||||
asm volatile(
|
||||
@@ -278,7 +305,7 @@ __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,
|
||||
|
||||
val = uint4{r0, r1, r2, r3};
|
||||
#else
|
||||
assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+");
|
||||
assert(false && "ld128_cg_or_zero is not supported on ROCm");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -802,6 +802,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
|
||||
&indexer_k_quant_and_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
|
||||
cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q);
|
||||
|
||||
cache_ops.def(
|
||||
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
|
||||
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
NUM_TOKENS = [1, 4, 16, 64, 128]
|
||||
NUM_HEADS = [128]
|
||||
NOPE_DIM = [512]
|
||||
ROPE_DIM = [64]
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("nope_dim", NOPE_DIM)
@pytest.mark.parametrize("rope_dim", ROPE_DIM)
@pytest.mark.parametrize("dtype", DTYPES)
def test_concat_mla_q_contiguous(num_tokens, num_heads, nope_dim, rope_dim, dtype):
    """Both inputs freshly allocated in the standard [B, N, D] layout."""
    torch.manual_seed(42)
    cuda_opts = {"dtype": dtype, "device": "cuda"}

    ql_nope = torch.randn(num_tokens, num_heads, nope_dim, **cuda_opts)
    q_pe = torch.randn(num_tokens, num_heads, rope_dim, **cuda_opts)
    expected = torch.cat((ql_nope, q_pe), dim=-1)

    q_out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, **cuda_opts)
    ops.concat_mla_q(ql_nope, q_pe, q_out)

    # The op is a pure copy, so the result must match exactly.
    torch.testing.assert_close(q_out, expected, atol=0, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [t for t in NUM_TOKENS if t > 1])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("nope_dim", NOPE_DIM)
@pytest.mark.parametrize("rope_dim", ROPE_DIM)
@pytest.mark.parametrize("dtype", DTYPES)
def test_concat_mla_q_transposed_nope(num_tokens, num_heads, nope_dim, rope_dim, dtype):
    """Nope input is a transposed view, as produced after the BMM.

    In the real code path, mqa_ql_nope is the result of:
        torch.bmm(q_nope, W_UK_T)   # [N, B, L]
        .transpose(0, 1)            # [B, N, L] — non-contiguous!
    """
    torch.manual_seed(42)
    cuda_opts = {"dtype": dtype, "device": "cuda"}

    # Allocate heads-first, then view as [B, N, L].
    heads_first = torch.randn(num_heads, num_tokens, nope_dim, **cuda_opts)
    ql_nope = heads_first.transpose(0, 1)
    assert not ql_nope.is_contiguous()

    q_pe = torch.randn(num_tokens, num_heads, rope_dim, **cuda_opts)
    expected = torch.cat((ql_nope, q_pe), dim=-1)

    q_out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, **cuda_opts)
    ops.concat_mla_q(ql_nope, q_pe, q_out)

    torch.testing.assert_close(q_out, expected, atol=0, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_concat_mla_q_split_rope(num_tokens, num_heads, dtype):
    """Rope input is the tail view of a split, as in the actual code path.

    In the real code path, q_pe comes from:
        mqa_q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    which creates a non-contiguous view with stride(1) != rope_dim.
    """
    torch.manual_seed(42)
    cuda_opts = {"dtype": dtype, "device": "cuda"}
    nope_dim, rope_dim = 512, 64
    orig_dim = 128 + 64  # original q before absorption: [B, N, 192]

    # Carve q_pe out of the original (pre-absorption) query tensor.
    q_orig = torch.randn(num_tokens, num_heads, orig_dim, **cuda_opts)
    _, q_pe = q_orig.split([128, 64], dim=-1)

    # Rows of q_pe are 192 elements apart, but each row itself is dense.
    assert q_pe.stride(1) == orig_dim
    assert q_pe.stride(2) == 1

    # The absorbed nope part is a separate contiguous tensor of another width.
    ql_nope = torch.randn(num_tokens, num_heads, nope_dim, **cuda_opts)
    expected = torch.cat((ql_nope, q_pe), dim=-1)

    q_out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, **cuda_opts)
    ops.concat_mla_q(ql_nope, q_pe, q_out)

    torch.testing.assert_close(q_out, expected, atol=0, rtol=0)
|
||||
|
||||
|
||||
def test_concat_mla_q_zero_tokens():
    """Edge case: a call with an empty token dimension must not raise."""
    cuda_bf16 = {"dtype": torch.bfloat16, "device": "cuda"}
    ql_nope = torch.empty(0, 128, 512, **cuda_bf16)
    q_pe = torch.empty(0, 128, 64, **cuda_bf16)
    q_out = torch.empty(0, 128, 576, **cuda_bf16)

    ops.concat_mla_q(ql_nope, q_pe, q_out)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 64])
def test_concat_mla_q_values_preserved(num_tokens):
    """Check the copy is bit-exact (the op performs no arithmetic).

    Inputs and output are compared as raw int16 bits so that NaN bit
    patterns, which never compare equal as floats, do not cause failures.
    """
    nope_dim, rope_dim = 512, 64

    def _bit_pattern(width):
        # Distinct int16 bit patterns laid out as [num_tokens, 128, width].
        return torch.arange(
            num_tokens * 128 * width, dtype=torch.int16, device="cuda"
        ).view(num_tokens, 128, width)

    ql_nope_bits = _bit_pattern(nope_dim)
    q_pe_bits = _bit_pattern(rope_dim)

    # Reinterpret the same storage as bf16 for the call.
    ql_nope = ql_nope_bits.view(torch.bfloat16)
    q_pe = q_pe_bits.view(torch.bfloat16)

    q_out = torch.empty(
        num_tokens, 128, nope_dim + rope_dim, dtype=torch.bfloat16, device="cuda"
    )
    ops.concat_mla_q(ql_nope, q_pe, q_out)

    out_bits = q_out.view(torch.int16)
    nope_part, rope_part = out_bits[..., :nope_dim], out_bits[..., nope_dim:]

    assert torch.equal(nope_part, ql_nope_bits)
    assert torch.equal(rope_part, q_pe_bits)
|
||||
@@ -2672,6 +2672,21 @@ def cp_gather_and_upconvert_fp8_kv_cache(
|
||||
)
|
||||
|
||||
|
||||
def concat_mla_q(
    ql_nope: torch.Tensor,
    q_pe: torch.Tensor,
    q_out: torch.Tensor,
) -> None:
    """Write the MLA/DSA query concat of ``ql_nope`` and ``q_pe`` to ``q_out``.

    Thin wrapper that forwards to the ``_C_cache_ops.concat_mla_q`` custom op;
    the result is written into ``q_out`` and nothing is returned.

    Args:
        ql_nope: Query nope component [num_tokens, num_heads, nope_dim]
        q_pe: Query rope component [num_tokens, num_heads, rope_dim]
        q_out: Output tensor [num_tokens, num_heads, nope_dim + rope_dim]
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.concat_mla_q(ql_nope, q_pe, q_out)
|
||||
|
||||
|
||||
def indexer_k_quant_and_cache(
|
||||
k: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
|
||||
@@ -568,6 +568,9 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
)
|
||||
self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
q_concat_shape = (max_tokens, num_heads, head_size)
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
assert kv_cache_dtype == "fp8_ds_mla", (
|
||||
"FlashMLA Sparse Attention backend fp8 only supports "
|
||||
@@ -576,17 +579,21 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
|
||||
if kv_cache_dtype == "fp8_ds_mla":
|
||||
# Reserve workspace during initialization
|
||||
vllm_config = get_current_vllm_config()
|
||||
assert vllm_config is not None and vllm_config.model_config is not None
|
||||
prefill_workspace_size = get_prefill_workspace_size(
|
||||
vllm_config.model_config.max_model_len
|
||||
)
|
||||
self.prefill_workspace_shape = (prefill_workspace_size, head_size)
|
||||
(self.prefill_bf16_workspace,) = (
|
||||
self.q_concat_buffer, self.prefill_bf16_workspace = (
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(self.prefill_workspace_shape, torch.bfloat16)
|
||||
(q_concat_shape, torch.bfloat16),
|
||||
(self.prefill_workspace_shape, torch.bfloat16),
|
||||
)
|
||||
)
|
||||
else:
|
||||
(self.q_concat_buffer,) = current_workspace_manager().get_simultaneous(
|
||||
(q_concat_shape, torch.bfloat16),
|
||||
)
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self,
|
||||
@@ -828,7 +835,9 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
|
||||
# Concatenate q if it's a tuple (ql_nope, q_pe)
|
||||
if isinstance(q, tuple):
|
||||
q = torch.cat(q, dim=-1)
|
||||
ql_nope, q_pe = q
|
||||
q = self.q_concat_buffer[: ql_nope.shape[0]]
|
||||
ops.concat_mla_q(ql_nope, q_pe, q)
|
||||
|
||||
num_actual_toks = q.shape[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user