[None][perf] Add custom indexer k cache scatter op (#8960)

Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
Chang Liu 2025-11-07 11:24:26 -08:00 committed by GitHub
parent c232ffd122
commit 7081f254cf
6 changed files with 470 additions and 27 deletions


@@ -0,0 +1,30 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm::kernels
{
void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3,
cudaStream_t stream = 0);
}


@@ -0,0 +1,152 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "IndexerKCacheScatter.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm::kernels
{
namespace
{
/**
* Given a flat element index and tensor shape [d0, d1, d2, d3] with strides [s0, s1, s2, s3],
* find the actual memory offset within the given k cache pool using the strides.
*/
__device__ __forceinline__ int64_t flatIndexToMemoryOffset(
int64_t flat_idx, int32_t d0, int32_t d1, int32_t d2, int32_t d3, int64_t s0, int64_t s1, int64_t s2, int64_t s3)
{
// Unravel from innermost to outermost dimension
int32_t i3 = flat_idx % d3;
flat_idx /= d3;
int32_t i2 = flat_idx % d2;
flat_idx /= d2;
int32_t i1 = flat_idx % d1;
flat_idx /= d1;
int32_t i0 = flat_idx;
// Compute memory offset using strides
return i0 * s0 + i1 * s1 + i2 * s2 + i3 * s3;
}
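// Illustrative worked example (not from the original source): for a contiguous cache of shape
// [d0=4, d1=64, d2=1, d3=132] the strides are [8448, 132, 132, 1], so flat_idx = 8580 unravels to
// (i0=1, i1=1, i2=0, i3=0) and maps to offset 1*8448 + 1*132 = 8580, i.e. the offset equals the
// flat index. Only for a non-contiguous pool (e.g. one layer's slice of a multi-layer buffer) do
// the flat index and memory offset diverge, which is why the strides must be passed explicitly.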
} // anonymous namespace
/**
* CUDA kernel to scatter both FP8 K values and scales into the indexer k cache pool
*
* @param k_fp8_bytes Quantized FP8 data [num_tokens, 128]
* @param k_scale_bytes Quantized scales (1 per token) [num_tokens, 4]
* @param k_cache Indexer k cache pool with shape [num_blocks, block_size, 1, per_token_size] (can be
* non-contiguous)
* @param slot_mapping_fp8 Flat element index for FP8 data start position [num_tokens]
* @param slot_mapping_scale Flat element index for scale data start position [num_tokens]
* @param num_tokens Number of tokens
* @param head_dim Head dimension (must be 128)
* @param scale_size Scale size in bytes (must be 4)
* @param cache_stride_0 Stride for k_cache dimension 0 (in bytes)
* @param cache_stride_1 Stride for k_cache dimension 1 (in bytes)
* @param cache_stride_2 Stride for k_cache dimension 2 (in bytes)
* @param cache_stride_3 Stride for k_cache dimension 3 (in bytes)
* @param cache_dim_0 Size of k_cache dimension 0
* @param cache_dim_1 Size of k_cache dimension 1
* @param cache_dim_2 Size of k_cache dimension 2
* @param cache_dim_3 Size of k_cache dimension 3
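*
* Illustrative example (layout assumed here, not stated in this file): with block_size 64 and
* per_token_size 132 (128 FP8 bytes + 4 scale bytes), a token stored in slot 5 of block 2 would use
* slot_mapping_fp8 = (2 * 64 + 5) * 132 and slot_mapping_scale = (2 * 64 + 5) * 132 + 128,
* assuming each token's scale bytes directly follow its FP8 bytes within the per_token_size segment.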
*/
__global__ void indexerKCacheScatterUnifiedKernel(uint8_t const* __restrict__ k_fp8_bytes,
uint8_t const* __restrict__ k_scale_bytes, uint8_t* __restrict__ k_cache,
int64_t const* __restrict__ slot_mapping_fp8, int64_t const* __restrict__ slot_mapping_scale, int32_t num_tokens,
int32_t head_dim, int32_t scale_size, int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2,
int64_t cache_stride_3, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3)
{
// For head_dim=128, each thread handles 4 bytes/elements per read/write instruction
constexpr int VEC_SIZE = 4;
// Token index from blockIdx.x
int32_t token_idx = blockIdx.x;
if (token_idx >= num_tokens)
{
return;
}
int64_t flat_idx_fp8_base = slot_mapping_fp8[token_idx];
int64_t flat_idx_scale_base = slot_mapping_scale[token_idx];
if (flat_idx_fp8_base < 0 || flat_idx_scale_base < 0)
{
return;
}
int32_t head_dim_idx = threadIdx.x * VEC_SIZE;
int64_t flat_idx = flat_idx_fp8_base + head_dim_idx;
// Convert flat index to memory offset using strides (k cache pool from cpp kv cache manager is non-contiguous)
int64_t dst_offset = flatIndexToMemoryOffset(flat_idx, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3,
cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
int64_t src_offset = token_idx * head_dim + head_dim_idx;
// 4-byte vectorized write of the FP8 data
*reinterpret_cast<uint32_t*>(&k_cache[dst_offset]) = *reinterpret_cast<uint32_t const*>(&k_fp8_bytes[src_offset]);
// Only thread 0 writes the single 4-byte scale value
if (threadIdx.x == 0)
{
int64_t dst_offset_scale = flatIndexToMemoryOffset(flat_idx_scale_base, cache_dim_0, cache_dim_1, cache_dim_2,
cache_dim_3, cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
int64_t src_offset_scale = token_idx * scale_size; // scale_size = 4
// 4-byte write for the scale
*reinterpret_cast<uint32_t*>(&k_cache[dst_offset_scale])
= *reinterpret_cast<uint32_t const*>(&k_scale_bytes[src_offset_scale]);
}
}
void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3, cudaStream_t stream)
{
if (num_tokens == 0)
{
return;
}
// Assertions for DeepSeek-V3.2 configuration
constexpr int32_t QUANT_BLOCK_SIZE = 128;
TLLM_CHECK_WITH_INFO(
head_dim == QUANT_BLOCK_SIZE, "head_dim must equal 128 for DeepSeek-V3.2 indexer cache (got %d)", head_dim);
TLLM_CHECK_WITH_INFO(
scale_size == 4, "scale_size must equal 4 bytes (1 float32 scale per token, got %d)", scale_size);
// For head_dim=128, 32 threads handle the 128 FP8 bytes per token (4 bytes per thread); thread 0 also writes the extra 4-byte scale
constexpr int32_t THREADS_PER_BLOCK = 32;
dim3 block(THREADS_PER_BLOCK);
dim3 grid(num_tokens);
indexerKCacheScatterUnifiedKernel<<<grid, block, 0, stream>>>(k_fp8_bytes, k_scale_bytes, k_cache, slot_mapping_fp8,
slot_mapping_scale, num_tokens, head_dim, scale_size, cache_stride_0, cache_stride_1, cache_stride_2,
cache_stride_3, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3);
// Check for kernel launch errors
TLLM_CUDA_CHECK(cudaGetLastError());
}
} // namespace tensorrt_llm::kernels


@@ -83,6 +83,7 @@ add_library(
fp8PerTensorScaleMoe.cpp
fp4BlockScaleMoe.cpp
noAuxTcOp.cpp
IndexerKCacheScatterOp.cpp
ncclCommunicatorOp.cpp
parallelDecodeKVCacheUpdateOp.cpp
redrafterCurandOp.cpp


@@ -0,0 +1,106 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/kernels/IndexerKCacheScatter.h"
namespace th = torch;
namespace tl = tensorrt_llm;
namespace tk = tensorrt_llm::kernels;
namespace torch_ext
{
void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
{
// Validate all tensors are CUDA tensors
TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
&& slot_mapping_scale.is_cuda(),
"All tensors must be CUDA tensors");
// Validate tensor dimensions
TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");
// Enforce k_cache is 4D tensor
TORCH_CHECK(k_cache.dim() == 4,
"k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got ", k_cache.dim(), " dimensions");
// Validate tensor dtypes
TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");
// Validate tensor shapes are consistent
auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
TORCH_CHECK(
k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");
// Validate tensors are contiguous (except k_cache which may be non-contiguous)
TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
// k_cache can be non-contiguous - we handle this via strides
TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");
int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1)); // head_dim = quant_block_size = 128
int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes
int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size
// Validation for indexer k cache pool for DeepSeek-V3.2 constraints
TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got ", cache_dim_2);
TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got ", head_dim);
TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got ", scale_size);
int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));
auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());
tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
cache_stride_1, cache_stride_2, cache_stride_3, stream);
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, "
"Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("indexer_k_cache_scatter_op", &torch_ext::indexer_k_cache_scatter_op);
}


@@ -872,24 +872,12 @@ class Indexer(nn.Module):
k_scale_bytes = k_scale_flat.view(torch.uint8).view(
num_tokens, scale_size)
# Scatter FP8 data
# Use CUDA kernel to scatter FP8 and scale bytes into cache
flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
byte_offsets = torch.arange(head_dim, device=k_cache.device).unsqueeze(
0) # [1, head_dim]
scatter_indices_fp8 = flat_indices_fp8.unsqueeze(
1) + byte_offsets # [num_tokens, head_dim]
scatter_indices_fp8 = _unravel_indices(scatter_indices_fp8,
k_cache.shape)
k_cache[scatter_indices_fp8] = k_fp8_bytes
flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
byte_offsets = torch.arange(
scale_size, device=k_cache.device).unsqueeze(0) # [1, scale_size]
scatter_indices_scale = flat_indices_scale.unsqueeze(
1) + byte_offsets # [num_tokens, scale_size]
scatter_indices_scale = _unravel_indices(scatter_indices_scale,
k_cache.shape)
k_cache[scatter_indices_scale] = k_scale_bytes
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
k_cache, flat_indices_fp8,
flat_indices_scale)
def _gather_k_cache_for_chunk(
self,


@@ -12,7 +12,7 @@ from unittest.mock import Mock, patch
import pytest
import torch
from utils.util import check_accuracy, getSMVersion
from utils.util import check_accuracy, skip_pre_hopper
from tensorrt_llm import deep_gemm
from tensorrt_llm._torch.attention_backend.interface import (
@@ -70,12 +70,9 @@ def create_dsa_cache_manager(
index_topk=2048)
# Create KV cache config
# Note: max_attention_window expects list[int] (one per layer)
kv_cache_config = KvCacheConfig(
enable_block_reuse=False,
max_tokens=max_seq_len * batch_size,
max_attention_window=[max_seq_len] *
num_layers, # List of max window per layer
)
# Create mapping (single GPU, no parallelism)
@@ -303,8 +300,7 @@ def _ref_fp8_mqa_logits(
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(getSMVersion() < 90,
reason="fp8_mqa_logits is only supported in SM90 and SM100")
@skip_pre_hopper
def test_deepgemm_fp8_mqa_logits_basic():
"""
Basic test for deepgemm.fp8_mqa_logits kernel.
@@ -477,7 +473,179 @@ def _create_mock_metadata(request_ids,
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(getSMVersion() < 90, reason="FP8 operations require SM90+")
@skip_pre_hopper
def test_indexer_k_cache_scatter_custom_op():
"""
Direct comparison: CUDA kernel vs Python reference for k_cache scatter.
This test ensures the new CUDA kernel indexer_k_cache_scatter_op produces
exactly the same results as the Python scatter implementation.
"""
torch.manual_seed(123)
# Test parameters
head_dim = 128
block_size = 64
batch_size = 3
num_tokens = 96 # 3 requests × 32 tokens each
max_seq_len = 512
# Use different layers for CUDA vs Python to test non-contiguous handling
layer_idx_cuda = 1  # CUDA kernel writes to layer 1
layer_idx_python = 2  # Python reference writes to layer 2
# Create cache manager with multiple layers
cache_manager, sparse_attn_config = create_dsa_cache_manager(
batch_size=batch_size,
head_dim=head_dim,
tokens_per_block=block_size,
max_seq_len=max_seq_len,
num_layers=3) # Multi-layer pool for non-contiguous test
# Allocate blocks
request_ids = list(range(batch_size))
tokens_per_req = [32, 32, 32]
cache_manager.add_dummy_requests(request_ids,
tokens_per_req,
is_gen=False,
prepare_resource=True)
# Create metadata
metadata = _create_mock_metadata(
request_ids,
batch_size,
num_contexts=batch_size,
num_generations=0,
seq_lens=torch.tensor(tokens_per_req, dtype=torch.int32),
kv_lens=torch.tensor(tokens_per_req, dtype=torch.int32),
num_cached_tokens=[0] * batch_size,
cache_manager=cache_manager,
num_ctx_tokens=num_tokens,
num_tokens=num_tokens,
)
from tensorrt_llm._torch.attention_backend.sparse.dsa import Indexer
Indexer.prepare(metadata)
# Generate test data
k_original = torch.randn((num_tokens, head_dim),
device="cuda",
dtype=torch.bfloat16)
k_fp8, k_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(k_original)
# Prepare byte-level data
scale_size = k_scale.shape[1] * 4
k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(num_tokens, head_dim)
k_scale_flat = k_scale.view(-1)
if k_scale_flat.stride(-1) != 1:
k_scale_flat = torch.as_strided(k_scale_flat.contiguous(),
size=(k_scale_flat.numel(), ),
stride=(1, ))
k_scale_bytes = k_scale_flat.view(torch.uint8).view(num_tokens, scale_size)
flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
# ========== Use Different Layers for CUDA vs Python ==========
# Simple approach: write to layer 1 for CUDA and layer 2 for Python
# Both get the same input data, but write to different layers
# Then we extract and compare the outputs from each layer
# Get k_cache for CUDA path (layer 1)
k_cache_cuda = cache_manager.get_indexer_k_cache_buffers(layer_idx_cuda)
k_cache_cuda.zero_()
# Get k_cache for Python path (layer 2)
k_cache_python = cache_manager.get_indexer_k_cache_buffers(layer_idx_python)
k_cache_python.zero_()
# Print cache properties
print(f"\n=== Cache Properties ===")
print(f" CUDA (layer {layer_idx_cuda}):")
print(f" Shape: {k_cache_cuda.shape}")
print(f" Stride: {k_cache_cuda.stride()}")
print(f" is_contiguous: {k_cache_cuda.is_contiguous()}")
print(f" Python (layer {layer_idx_python}):")
print(f" Shape: {k_cache_python.shape}")
print(f" Stride: {k_cache_python.stride()}")
print(f" is_contiguous: {k_cache_python.is_contiguous()}")
# ========== Path 1: CUDA Kernel ==========
print(f"\n=== Path 1: CUDA Kernel ===")
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
k_cache_cuda, flat_indices_fp8,
flat_indices_scale)
torch.cuda.synchronize()
print(f"✓ CUDA kernel completed")
# ========== Path 2: Python Reference ==========
print(f"\n=== Path 2: Python Reference ===")
def _unravel_indices(flat_indices, shape):
d3 = shape[3]
i3 = flat_indices % d3
flat_indices = flat_indices // d3
d2 = shape[2]
i2 = flat_indices % d2
flat_indices = flat_indices // d2
d1 = shape[1]
i1 = flat_indices % d1
flat_indices = flat_indices // d1
i0 = flat_indices
return i0, i1, i2, i3
# Scatter FP8 data
byte_offsets = torch.arange(head_dim,
device=k_cache_python.device).unsqueeze(0)
scatter_indices_fp8 = flat_indices_fp8.unsqueeze(1) + byte_offsets
scatter_indices_fp8 = _unravel_indices(scatter_indices_fp8,
k_cache_python.shape)
k_cache_python[scatter_indices_fp8] = k_fp8_bytes
# Scatter scale data
byte_offsets = torch.arange(scale_size,
device=k_cache_python.device).unsqueeze(0)
scatter_indices_scale = flat_indices_scale.unsqueeze(1) + byte_offsets
scatter_indices_scale = _unravel_indices(scatter_indices_scale,
k_cache_python.shape)
k_cache_python[scatter_indices_scale] = k_scale_bytes
# ========== Validation: Byte-for-Byte Comparison ==========
print(f"\n=== Validation ===")
total_bytes = k_cache_cuda.numel()
# Compare entire cache tensors
if torch.equal(k_cache_cuda, k_cache_python):
print(f"✅ PERFECT MATCH! CUDA and Python produce identical cache")
print(f" Total bytes compared: {total_bytes}")
print(
f" Tokens: {num_tokens}, head_dim: {head_dim}, block_size: {block_size}"
)
else:
# Find differences
diff_mask = k_cache_cuda != k_cache_python
num_diffs = diff_mask.sum().item()
print(
f"⚠️ Found {num_diffs}/{total_bytes} byte differences ({100*num_diffs/total_bytes:.4f}%)"
)
# Show first few differences
diff_indices = torch.nonzero(diff_mask.view(-1))[:5]
for idx in diff_indices:
flat_idx = idx.item()
print(
f" Byte {flat_idx}: CUDA={k_cache_cuda.view(-1)[flat_idx].item()}, "
f"Python={k_cache_python.view(-1)[flat_idx].item()}")
# Fail the test
raise AssertionError(
"CUDA kernel produced different results than Python reference")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@skip_pre_hopper
def test_fp8_k_cache_roundtrip():
"""Verify FP8 quantization scales survive write/read cycle for multiple requests."""
torch.manual_seed(42)
@@ -562,9 +730,7 @@ def test_fp8_k_cache_roundtrip():
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
getSMVersion() < 90,
reason="fp8_paged_mqa_logits is only supported in SM90 and SM100")
@skip_pre_hopper
@pytest.mark.parametrize("batch_size,next_n", [(4, 1), (2, 2)])
def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
"""
@@ -859,7 +1025,7 @@ def test_split_prefill_chunks(max_chunk_size, seq_lens, start_idx,
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(getSMVersion() < 90, reason="FP8 operations require SM90+")
@skip_pre_hopper
@pytest.mark.parametrize(
"chunk_size,seq_lens_list,chunking_type",
[