Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[None][perf] Add custom indexer k cache scatter op (#8960)

Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>

parent: c232ffd122
commit: 7081f254cf
cpp/tensorrt_llm/kernels/IndexerKCacheScatter.h (new file, 30 lines)
@@ -0,0 +1,30 @@
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm::kernels
{

void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
    int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
    int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
    int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3,
    cudaStream_t stream = 0);

} // namespace tensorrt_llm::kernels
cpp/tensorrt_llm/kernels/indexerKCacheScatter.cu (new file, 152 lines)
@@ -0,0 +1,152 @@
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "IndexerKCacheScatter.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm::kernels
{

namespace
{
/**
 * Given a flat element index and tensor shape [d0, d1, d2, d3] with strides [s0, s1, s2, s3],
 * find the actual memory offset within the given k cache pool using the strides.
 */
__device__ __forceinline__ int64_t flatIndexToMemoryOffset(
    int64_t flat_idx, int32_t d0, int32_t d1, int32_t d2, int32_t d3, int64_t s0, int64_t s1, int64_t s2, int64_t s3)
{
    // Unravel from innermost to outermost dimension
    int32_t i3 = flat_idx % d3;
    flat_idx /= d3;

    int32_t i2 = flat_idx % d2;
    flat_idx /= d2;

    int32_t i1 = flat_idx % d1;
    flat_idx /= d1;

    int32_t i0 = flat_idx;

    // Compute memory offset using strides
    return i0 * s0 + i1 * s1 + i2 * s2 + i3 * s3;
}

} // anonymous namespace

/**
 * CUDA kernel to scatter both FP8 K values and scales into the indexer k cache pool.
 *
 * @param k_fp8_bytes Quantized FP8 data [num_tokens, 128]
 * @param k_scale_bytes Quantized scales (1 per token) [num_tokens, 4]
 * @param k_cache Indexer k cache pool with shape [num_blocks, block_size, 1, per_token_size] (can be non-contiguous)
 * @param slot_mapping_fp8 Flat element index for FP8 data start position [num_tokens]
 * @param slot_mapping_scale Flat element index for scale data start position [num_tokens]
 * @param num_tokens Number of tokens
 * @param head_dim Head dimension (must be 128)
 * @param scale_size Scale size in bytes (must be 4)
 * @param cache_stride_0 Stride for k_cache dimension 0 (in bytes)
 * @param cache_stride_1 Stride for k_cache dimension 1 (in bytes)
 * @param cache_stride_2 Stride for k_cache dimension 2 (in bytes)
 * @param cache_stride_3 Stride for k_cache dimension 3 (in bytes)
 * @param cache_dim_0 Size of k_cache dimension 0
 * @param cache_dim_1 Size of k_cache dimension 1
 * @param cache_dim_2 Size of k_cache dimension 2
 * @param cache_dim_3 Size of k_cache dimension 3
 */
__global__ void indexerKCacheScatterUnifiedKernel(uint8_t const* __restrict__ k_fp8_bytes,
    uint8_t const* __restrict__ k_scale_bytes, uint8_t* __restrict__ k_cache,
    int64_t const* __restrict__ slot_mapping_fp8, int64_t const* __restrict__ slot_mapping_scale, int32_t num_tokens,
    int32_t head_dim, int32_t scale_size, int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2,
    int64_t cache_stride_3, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3)
{
    // For head_dim=128, each thread handles 4 bytes/elements per read/write instruction
    constexpr int VEC_SIZE = 4;

    // Token index from blockIdx.x
    int32_t token_idx = blockIdx.x;

    if (token_idx >= num_tokens)
    {
        return;
    }

    int64_t flat_idx_fp8_base = slot_mapping_fp8[token_idx];
    int64_t flat_idx_scale_base = slot_mapping_scale[token_idx];

    if (flat_idx_fp8_base < 0 || flat_idx_scale_base < 0)
    {
        return;
    }

    int32_t head_dim_idx = threadIdx.x * VEC_SIZE;
    int64_t flat_idx = flat_idx_fp8_base + head_dim_idx;

    // Convert flat index to memory offset using strides (the k cache pool from the cpp kv cache manager is
    // non-contiguous)
    int64_t dst_offset = flatIndexToMemoryOffset(flat_idx, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3,
        cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
    int64_t src_offset = token_idx * head_dim + head_dim_idx;

    // 4-byte vectorized write of the FP8 payload
    *reinterpret_cast<uint32_t*>(&k_cache[dst_offset]) = *reinterpret_cast<uint32_t const*>(&k_fp8_bytes[src_offset]);

    // Only thread 0 writes the single 4-byte scale value
    if (threadIdx.x == 0)
    {
        int64_t dst_offset_scale = flatIndexToMemoryOffset(flat_idx_scale_base, cache_dim_0, cache_dim_1, cache_dim_2,
            cache_dim_3, cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
        int64_t src_offset_scale = token_idx * scale_size; // scale_size = 4

        // 4-byte write for the scale
        *reinterpret_cast<uint32_t*>(&k_cache[dst_offset_scale])
            = *reinterpret_cast<uint32_t const*>(&k_scale_bytes[src_offset_scale]);
    }
}

void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
    int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
    int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
    int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3, cudaStream_t stream)
{
    if (num_tokens == 0)
    {
        return;
    }

    // Assertions for the DeepSeek-V3.2 configuration
    constexpr int32_t QUANT_BLOCK_SIZE = 128;
    TLLM_CHECK_WITH_INFO(
        head_dim == QUANT_BLOCK_SIZE, "head_dim must equal 128 for DeepSeek-V3 indexer cache (got %d)", head_dim);
    TLLM_CHECK_WITH_INFO(
        scale_size == 4, "scale_size must equal 4 bytes (1 float32 scale per token, got %d)", scale_size);

    // For head_dim=128, we use 32 threads to handle 128 bytes per token plus an extra 4 bytes for the scale
    constexpr int32_t THREADS_PER_BLOCK = 32;

    dim3 block(THREADS_PER_BLOCK);
    dim3 grid(num_tokens);

    indexerKCacheScatterUnifiedKernel<<<grid, block, 0, stream>>>(k_fp8_bytes, k_scale_bytes, k_cache, slot_mapping_fp8,
        slot_mapping_scale, num_tokens, head_dim, scale_size, cache_stride_0, cache_stride_1, cache_stride_2,
        cache_stride_3, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3);

    // Check for kernel launch errors
    TLLM_CUDA_CHECK(cudaGetLastError());
}

} // namespace tensorrt_llm::kernels
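For intuition about the stride handling above, here is a small PyTorch sketch (not part of this commit) that mirrors flatIndexToMemoryOffset on the host. The pool layout, the per_token_size of 132 bytes (128 FP8 bytes plus a 4-byte scale), and the layer-interleaved slicing are illustrative assumptions chosen only to produce a non-contiguous 4D view.

import torch


def flat_index_to_memory_offset(flat_idx, shape, strides):
    """Host-side mirror of the device helper: unravel a flat element index
    over `shape`, then project it onto possibly non-contiguous `strides`."""
    offset = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        offset += (flat_idx % dim) * stride
        flat_idx //= dim
    return offset


# Hypothetical layer-interleaved pool; slicing out one layer yields a
# non-contiguous [num_blocks, block_size, 1, per_token_size] view.
pool = torch.zeros(2, 3, 64, 1, 132, dtype=torch.uint8)
k_cache = pool[:, 1]
assert not k_cache.is_contiguous()

# Flat element index of byte 17 of the token in slot 5 of block 0.
flat_idx = 5 * 132 + 17
offset = flat_index_to_memory_offset(flat_idx, k_cache.shape, k_cache.stride())

# The strided offset addresses the same byte as normal multi-dim indexing.
k_cache[0, 5, 0, 17] = 42
storage = pool.view(-1)
assert storage[k_cache.storage_offset() + offset] == 42

This is the same conversion that lets the kernel write through the strided view handed back by the C++ KV cache manager without first materializing a contiguous copy.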
@@ -83,6 +83,7 @@ add_library(
   fp8PerTensorScaleMoe.cpp
   fp4BlockScaleMoe.cpp
   noAuxTcOp.cpp
+  IndexerKCacheScatterOp.cpp
   ncclCommunicatorOp.cpp
   parallelDecodeKVCacheUpdateOp.cpp
   redrafterCurandOp.cpp
cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp (new file, 106 lines)
@@ -0,0 +1,106 @@
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"

#include "tensorrt_llm/kernels/IndexerKCacheScatter.h"

namespace th = torch;
namespace tl = tensorrt_llm;
namespace tk = tensorrt_llm::kernels;

namespace torch_ext
{

void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
    th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
{
    // Validate all tensors are CUDA tensors
    TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
            && slot_mapping_scale.is_cuda(),
        "All tensors must be CUDA tensors");

    // Validate tensor dimensions
    TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
    TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
    TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
    TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");

    // Enforce that k_cache is a 4D tensor
    TORCH_CHECK(k_cache.dim() == 4,
        "k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions",
        static_cast<int>(k_cache.dim()));

    // Validate tensor dtypes
    TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
    TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
    TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
    TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");

    // Validate tensor shapes are consistent
    auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
    TORCH_CHECK(
        k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
    TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
    TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");

    // Validate tensors are contiguous (except k_cache, which may be non-contiguous)
    TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
    TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
    // k_cache can be non-contiguous - we handle this via strides
    TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
    TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");

    int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1));     // head_dim = quant_block_size = 128
    int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes

    int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
    int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
    int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
    int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size

    // Validate the indexer k cache pool against DeepSeek-V3.2 constraints
    TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d", cache_dim_2);
    TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d", head_dim);
    TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d", scale_size);

    int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
    int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
    int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
    int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));

    auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());

    tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
        k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
        num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
        cache_stride_1, cache_stride_2, cache_stride_3, stream);
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, "
        "Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("indexer_k_cache_scatter_op", &torch_ext::indexer_k_cache_scatter_op);
}
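As a usage illustration only (not part of this commit), the sketch below drives the registered op directly with synthetic tensors. The cache geometry, the per_token_size of 132, and the choice to place each token's 4 scale bytes at byte offset 128 of its cache row are assumptions made for the example; in TensorRT-LLM the slot mappings come from the indexer metadata, as in the Python change below.

# Requires a GPU and a TensorRT-LLM build that registers trtllm::indexer_k_cache_scatter_op.
import torch
import tensorrt_llm  # noqa: F401  (loads the custom op library)

num_blocks, block_size, per_token_size = 4, 64, 132  # assumed cache geometry
num_tokens, head_dim, scale_size = 8, 128, 4

k_fp8_bytes = torch.randint(0, 256, (num_tokens, head_dim),
                            device="cuda").to(torch.uint8)
k_scale_bytes = torch.randint(0, 256, (num_tokens, scale_size),
                              device="cuda").to(torch.uint8)
k_cache = torch.zeros(num_blocks, block_size, 1, per_token_size,
                      dtype=torch.uint8, device="cuda")

# Flat element indices into k_cache (viewed by its shape): token t goes to
# block t // block_size, slot t % block_size; its FP8 bytes start at
# per-token offset 0 and its scale at offset head_dim (assumed layout).
tokens = torch.arange(num_tokens, dtype=torch.int64, device="cuda")
blocks, slots = tokens // block_size, tokens % block_size
row_start = ((blocks * block_size + slots) * 1 + 0) * per_token_size  # single kv head
slot_mapping_fp8 = row_start
slot_mapping_scale = row_start + head_dim

torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes, k_cache,
                                            slot_mapping_fp8, slot_mapping_scale)
torch.cuda.synchronize()

# Each token's row now holds its 128 FP8 bytes followed by its 4 scale bytes.
assert torch.equal(k_cache[0, :num_tokens, 0, :head_dim], k_fp8_bytes)
assert torch.equal(k_cache[0, :num_tokens, 0, head_dim:head_dim + scale_size],
                   k_scale_bytes)

The cache here happens to be contiguous; the kernel accepts non-contiguous 4D views as well, since it resolves destinations through the explicit strides.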
@@ -872,24 +872,12 @@ class Indexer(nn.Module):
         k_scale_bytes = k_scale_flat.view(torch.uint8).view(
             num_tokens, scale_size)
 
-        # Scatter FP8 data
+        # Use CUDA kernel to scatter FP8 and scale bytes into cache
         flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
-        byte_offsets = torch.arange(head_dim, device=k_cache.device).unsqueeze(
-            0)  # [1, head_dim]
-        scatter_indices_fp8 = flat_indices_fp8.unsqueeze(
-            1) + byte_offsets  # [num_tokens, head_dim]
-        scatter_indices_fp8 = _unravel_indices(scatter_indices_fp8,
-                                               k_cache.shape)
-        k_cache[scatter_indices_fp8] = k_fp8_bytes
-
         flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
-        byte_offsets = torch.arange(
-            scale_size, device=k_cache.device).unsqueeze(0)  # [1, scale_size]
-        scatter_indices_scale = flat_indices_scale.unsqueeze(
-            1) + byte_offsets  # [num_tokens, scale_size]
-        scatter_indices_scale = _unravel_indices(scatter_indices_scale,
-                                                 k_cache.shape)
-        k_cache[scatter_indices_scale] = k_scale_bytes
+        torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
+                                                    k_cache, flat_indices_fp8,
+                                                    flat_indices_scale)
 
     def _gather_k_cache_for_chunk(
         self,
@@ -12,7 +12,7 @@ from unittest.mock import Mock, patch
 
 import pytest
 import torch
-from utils.util import check_accuracy, getSMVersion
+from utils.util import check_accuracy, skip_pre_hopper
 
 from tensorrt_llm import deep_gemm
 from tensorrt_llm._torch.attention_backend.interface import (
@@ -70,12 +70,9 @@ def create_dsa_cache_manager(
         index_topk=2048)
 
     # Create KV cache config
-    # Note: max_attention_window expects list[int] (one per layer)
     kv_cache_config = KvCacheConfig(
         enable_block_reuse=False,
         max_tokens=max_seq_len * batch_size,
-        max_attention_window=[max_seq_len] *
-        num_layers,  # List of max window per layer
     )
 
     # Create mapping (single GPU, no parallelism)
@@ -303,8 +300,7 @@ def _ref_fp8_mqa_logits(
 
 
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(getSMVersion() < 90,
-                    reason="fp8_mqa_logits is only supported in SM90 and SM100")
+@skip_pre_hopper
 def test_deepgemm_fp8_mqa_logits_basic():
     """
     Basic test for deepgemm.fp8_mqa_logits kernel.
@@ -477,7 +473,179 @@ def _create_mock_metadata(request_ids,
 
 
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(getSMVersion() < 90, reason="FP8 operations require SM90+")
+@skip_pre_hopper
+def test_indexer_k_cache_scatter_custom_op():
+    """
+    Direct comparison: CUDA kernel vs Python reference for k_cache scatter.
+
+    This test ensures the new CUDA kernel indexer_k_cache_scatter_op produces
+    exactly the same results as the Python scatter implementation.
+    """
+    torch.manual_seed(123)
+
+    # Test parameters
+    head_dim = 128
+    block_size = 64
+    batch_size = 3
+    num_tokens = 96  # 3 requests × 32 tokens each
+    max_seq_len = 512
+
+    # Use different layers for CUDA vs Python to test non-contiguous handling
+    layer_idx_cuda = 1  # CUDA kernel writes to layer 1
+    layer_idx_python = 2  # Python reference writes to layer 2
+
+    # Create cache manager with multiple layers
+    cache_manager, sparse_attn_config = create_dsa_cache_manager(
+        batch_size=batch_size,
+        head_dim=head_dim,
+        tokens_per_block=block_size,
+        max_seq_len=max_seq_len,
+        num_layers=3)  # Multi-layer pool for non-contiguous test
+
+    # Allocate blocks
+    request_ids = list(range(batch_size))
+    tokens_per_req = [32, 32, 32]
+    cache_manager.add_dummy_requests(request_ids,
+                                     tokens_per_req,
+                                     is_gen=False,
+                                     prepare_resource=True)
+
+    # Create metadata
+    metadata = _create_mock_metadata(
+        request_ids,
+        batch_size,
+        num_contexts=batch_size,
+        num_generations=0,
+        seq_lens=torch.tensor(tokens_per_req, dtype=torch.int32),
+        kv_lens=torch.tensor(tokens_per_req, dtype=torch.int32),
+        num_cached_tokens=[0] * batch_size,
+        cache_manager=cache_manager,
+        num_ctx_tokens=num_tokens,
+        num_tokens=num_tokens,
+    )
+
+    from tensorrt_llm._torch.attention_backend.sparse.dsa import Indexer
+    Indexer.prepare(metadata)
+
+    # Generate test data
+    k_original = torch.randn((num_tokens, head_dim),
+                             device="cuda",
+                             dtype=torch.bfloat16)
+    k_fp8, k_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(k_original)
+
+    # Prepare byte-level data
+    scale_size = k_scale.shape[1] * 4
+    k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(num_tokens, head_dim)
+    k_scale_flat = k_scale.view(-1)
+    if k_scale_flat.stride(-1) != 1:
+        k_scale_flat = torch.as_strided(k_scale_flat.contiguous(),
+                                        size=(k_scale_flat.numel(), ),
+                                        stride=(1, ))
+    k_scale_bytes = k_scale_flat.view(torch.uint8).view(num_tokens, scale_size)
+
+    flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
+    flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
+
+    # ========== Use Different Layers for CUDA vs Python ==========
+    # Simple approach: use different layers for the CUDA and Python paths
+    # Both get the same input data, but write to different layers
+    # Then we extract and compare the outputs from each layer
+
+    # Get k_cache for CUDA path (layer 1)
+    k_cache_cuda = cache_manager.get_indexer_k_cache_buffers(layer_idx_cuda)
+    k_cache_cuda.zero_()
+
+    # Get k_cache for Python path (layer 2)
+    k_cache_python = cache_manager.get_indexer_k_cache_buffers(layer_idx_python)
+    k_cache_python.zero_()
+
+    # Print cache properties
+    print(f"\n=== Cache Properties ===")
+    print(f"  CUDA (layer {layer_idx_cuda}):")
+    print(f"    Shape: {k_cache_cuda.shape}")
+    print(f"    Stride: {k_cache_cuda.stride()}")
+    print(f"    is_contiguous: {k_cache_cuda.is_contiguous()}")
+    print(f"  Python (layer {layer_idx_python}):")
+    print(f"    Shape: {k_cache_python.shape}")
+    print(f"    Stride: {k_cache_python.stride()}")
+    print(f"    is_contiguous: {k_cache_python.is_contiguous()}")
+
+    # ========== Path 1: CUDA Kernel ==========
+    print(f"\n=== Path 1: CUDA Kernel ===")
+    torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
+                                                k_cache_cuda, flat_indices_fp8,
+                                                flat_indices_scale)
+    torch.cuda.synchronize()
+    print(f"✓ CUDA kernel completed")
+
+    # ========== Path 2: Python Reference ==========
+    print(f"\n=== Path 2: Python Reference ===")
+
+    def _unravel_indices(flat_indices, shape):
+        d3 = shape[3]
+        i3 = flat_indices % d3
+        flat_indices = flat_indices // d3
+        d2 = shape[2]
+        i2 = flat_indices % d2
+        flat_indices = flat_indices // d2
+        d1 = shape[1]
+        i1 = flat_indices % d1
+        flat_indices = flat_indices // d1
+        i0 = flat_indices
+        return i0, i1, i2, i3
+
+    # Scatter FP8 data
+    byte_offsets = torch.arange(head_dim,
+                                device=k_cache_python.device).unsqueeze(0)
+    scatter_indices_fp8 = flat_indices_fp8.unsqueeze(1) + byte_offsets
+    scatter_indices_fp8 = _unravel_indices(scatter_indices_fp8,
+                                           k_cache_python.shape)
+    k_cache_python[scatter_indices_fp8] = k_fp8_bytes
+
+    # Scatter scale data
+    byte_offsets = torch.arange(scale_size,
+                                device=k_cache_python.device).unsqueeze(0)
+    scatter_indices_scale = flat_indices_scale.unsqueeze(1) + byte_offsets
+    scatter_indices_scale = _unravel_indices(scatter_indices_scale,
+                                             k_cache_python.shape)
+    k_cache_python[scatter_indices_scale] = k_scale_bytes
+
+    # ========== Validation: Byte-for-Byte Comparison ==========
+    print(f"\n=== Validation ===")
+
+    total_bytes = k_cache_cuda.numel()
+
+    # Compare entire cache tensors
+    if torch.equal(k_cache_cuda, k_cache_python):
+        print(f"✅ PERFECT MATCH! CUDA and Python produce identical cache")
+        print(f"  Total bytes compared: {total_bytes}")
+        print(
+            f"  Tokens: {num_tokens}, head_dim: {head_dim}, block_size: {block_size}"
+        )
+    else:
+        # Find differences
+        diff_mask = k_cache_cuda != k_cache_python
+        num_diffs = diff_mask.sum().item()
+
+        print(
+            f"⚠️ Found {num_diffs}/{total_bytes} byte differences ({100*num_diffs/total_bytes:.4f}%)"
+        )
+
+        # Show first few differences
+        diff_indices = torch.nonzero(diff_mask.view(-1))[:5]
+        for idx in diff_indices:
+            flat_idx = idx.item()
+            print(
+                f"  Byte {flat_idx}: CUDA={k_cache_cuda.view(-1)[flat_idx].item()}, "
+                f"Python={k_cache_python.view(-1)[flat_idx].item()}")
+
+        # Fail the test
+        raise AssertionError(
+            "CUDA kernel produced different results than Python reference")
+
+
+@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
+@skip_pre_hopper
 def test_fp8_k_cache_roundtrip():
     """Verify FP8 quantization scales survive write/read cycle for multiple requests."""
     torch.manual_seed(42)
@@ -562,9 +730,7 @@ def test_fp8_k_cache_roundtrip():
 
 
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(
-    getSMVersion() < 90,
-    reason="fp8_paged_mqa_logits is only supported in SM90 and SM100")
+@skip_pre_hopper
 @pytest.mark.parametrize("batch_size,next_n", [(4, 1), (2, 2)])
 def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
     """
@@ -859,7 +1025,7 @@ def test_split_prefill_chunks(max_chunk_size, seq_lens, start_idx,
 
 
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(getSMVersion() < 90, reason="FP8 operations require SM90+")
+@skip_pre_hopper
 @pytest.mark.parametrize(
     "chunk_size,seq_lens_list,chunking_type",
     [