TensorRT-LLM/cpp/tests/unit_tests/kernels/prepareCustomMaskTest.cpp

/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h"
#include "tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h"
#include "tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
namespace
{
using tensorrt_llm::kernels::FmhaKernelType;
using tensorrt_llm::kernels::runPrepareCustomMask;
using tensorrt_llm::kernels::TllmGenFmhaKernelMetaInfo;
using tensorrt_llm::kernels::TllmGenFmhaRunnerParams;
using tensorrt_llm::runtime::BufferManager;
using tensorrt_llm::runtime::CudaStream;
using tensorrt_llm::runtime::MemoryType;
using tensorrt_llm::runtime::bufferCast;

inline int32_t ceilDiv(int32_t dividend, int32_t divisor)
{
    return (dividend + divisor - 1) / divisor;
}
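// For example, ceilDiv(48, 32) == 2, so padding 48 up to a multiple of 32 below yields 64.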

// CPU reference implementation for preparing custom mask buffers.
std::tuple<std::vector<uint32_t>, std::vector<int64_t>, std::vector<int32_t>> prepareCustomMaskBuffersCPU(
    int32_t batchSize, int32_t numHeadsQPerKv, int32_t tileSizeQ, int32_t tileSizeKv, int32_t numInstsQ,
    int32_t numInstsKv, std::vector<int32_t> const& seqLensQ, std::vector<int32_t> const& seqLensKv,
    std::vector<int32_t> const& firstSparseMaskOffsetsKv,
    std::vector<int32_t> const& inputTreeMask) // Non-packed tree mask, shape [bs, seqLenQ, seqLenQ]
{
    // Pad tileSizeKv to a multiple of 32 for the keepsMmaAb kernel.
    int32_t tileSizeKvPadded = ceilDiv(tileSizeKv, 32) * 32;
    int32_t tileSizeQPerCta = tileSizeQ * numInstsQ;
    int32_t tileSizeKvPerCta = tileSizeKvPadded * numInstsKv;
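    // With the defaults used by the tests below (tileSizeQ = 128, tileSizeKv = 128, numInstsQ = 2,
    // numInstsKv = 1): tileSizeKvPadded = 128, tileSizeQPerCta = 256, tileSizeKvPerCta = 128.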
    std::vector<int32_t> cumSeqLensQ(batchSize + 1, 0);
    for (int32_t i = 0; i < batchSize; ++i)
    {
        cumSeqLensQ[i + 1] = cumSeqLensQ[i] + seqLensQ[i];
    }
    std::vector<int64_t> customMaskOffsets(batchSize, 0);
    std::vector<int32_t> adjustedFirstSparseMaskOffsetsKv(batchSize, 0);
    int64_t totalMaskSize = 0;
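    // Each (Q tile, KV tile) pair stores numInstsQ * numInstsKv instruction tiles of
    // tileSizeQ x tileSizeKvPadded mask bits, packed 32 bits per uint32 word; hence the
    // division by 32 in the size computation below.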
    for (int32_t batchIdx = 0; batchIdx < batchSize; ++batchIdx)
    {
        int32_t seqLenQ = seqLensQ[batchIdx];
        int32_t seqLenKv = seqLensKv[batchIdx];
        int32_t firstSparseMaskOffsetKv = firstSparseMaskOffsetsKv[batchIdx];
        int32_t numTilesQ = ceilDiv(seqLenQ * numHeadsQPerKv, tileSizeQPerCta);
        int32_t firstSparseTile = firstSparseMaskOffsetKv / tileSizeKvPerCta;
        int32_t numCustomMaskTilesKv = ceilDiv(seqLenKv, tileSizeKvPerCta) - firstSparseTile;
        customMaskOffsets[batchIdx] = totalMaskSize;
        adjustedFirstSparseMaskOffsetsKv[batchIdx] = firstSparseTile * tileSizeKvPerCta;
        int64_t maskSize = static_cast<int64_t>(numTilesQ) * numCustomMaskTilesKv * numInstsQ * numInstsKv
            * (tileSizeQ * tileSizeKvPadded) / 32;
        totalMaskSize += maskSize;
    }
    std::vector<uint32_t> customMask(totalMaskSize, 0);
    // Fill the custom mask from the non-packed input tree mask.
    for (int32_t batchIdx = 0; batchIdx < batchSize; ++batchIdx)
    {
        int32_t seqLenQ = seqLensQ[batchIdx];
        int32_t seqLenKv = seqLensKv[batchIdx];
        int32_t firstSparseMaskOffsetKv = firstSparseMaskOffsetsKv[batchIdx];
        int32_t adjustedFirstSparseMaskOffsetKv = adjustedFirstSparseMaskOffsetsKv[batchIdx];
        int32_t numTilesQ = ceilDiv(seqLenQ * numHeadsQPerKv, tileSizeQPerCta);
        int32_t firstSparseTile = firstSparseMaskOffsetKv / tileSizeKvPerCta;
        int32_t numCustomMaskTilesKv = ceilDiv(seqLenKv, tileSizeKvPerCta) - firstSparseTile;
        uint32_t* localCustomMask = customMask.data() + customMaskOffsets[batchIdx];
        // Tree mask layout: [bs, seqLenQ, seqLenQ] (non-packed).
        int32_t batchMaskOffset = batchIdx * seqLenQ * seqLenQ;
        for (int32_t tokenIdxQ = 0; tokenIdxQ < seqLenQ; ++tokenIdxQ)
        {
            for (int32_t tokenIdxKv = 0; tokenIdxKv < seqLenKv; ++tokenIdxKv)
            {
                bool maskFlag = false;
                if (tokenIdxKv < firstSparseMaskOffsetKv)
                {
                    maskFlag = true; // Dense region: always attend.
                }
                else
                {
                    int32_t qPosInTree = tokenIdxKv - firstSparseMaskOffsetKv;
                    if (qPosInTree < seqLenQ)
                    {
                        int32_t maskIdx = batchMaskOffset + tokenIdxQ * seqLenQ + qPosInTree;
                        maskFlag = static_cast<bool>(inputTreeMask[maskIdx]);
                    }
                    else
                    {
                        maskFlag = false;
                    }
                }
                // Only the custom mask region is packed; the dense region before
                // adjustedFirstSparseMaskOffsetKv is skipped.
                if (tokenIdxKv >= adjustedFirstSparseMaskOffsetKv)
                {
                    int32_t customMaskTokenIdxKv = tokenIdxKv - adjustedFirstSparseMaskOffsetKv;
                    int32_t tileIdxKv = customMaskTokenIdxKv / tileSizeKvPerCta;
                    int32_t instIdxKv = (customMaskTokenIdxKv % tileSizeKvPerCta) / tileSizeKvPadded;
                    int32_t tokenIdxInTileKv = (customMaskTokenIdxKv % tileSizeKvPerCta) % tileSizeKvPadded;
                    for (int32_t headIdxInGrp = 0; headIdxInGrp < numHeadsQPerKv; ++headIdxInGrp)
                    {
                        int32_t customMaskTokenIdxQ = tokenIdxQ * numHeadsQPerKv + headIdxInGrp;
                        int32_t tileIdxQ = customMaskTokenIdxQ / tileSizeQPerCta;
                        int32_t instIdxQ = (customMaskTokenIdxQ % tileSizeQPerCta) / tileSizeQ;
                        int32_t tokenIdxInTileQ = (customMaskTokenIdxQ % tileSizeQPerCta) % tileSizeQ;
                        // Compute the bit offset: bits are laid out as
                        // [tileIdxQ][tileIdxKv][instIdxQ][instIdxKv][rowInTile][colInTile],
                        // then packed 32 bits per uint32 word.
                        int64_t tileOffset = tileIdxQ * numCustomMaskTilesKv + tileIdxKv;
                        int64_t instOffset = tileOffset * numInstsQ * numInstsKv + (instIdxQ * numInstsKv + instIdxKv);
                        int64_t maskOffset = instOffset * tileSizeQ * tileSizeKvPadded
                            + (tokenIdxInTileQ * tileSizeKvPadded + tokenIdxInTileKv);
                        int64_t offsetAsUInt32 = maskOffset >> 5;
                        int64_t bitPosInUInt32 = maskOffset & 0x1F;
                        localCustomMask[offsetAsUInt32] |= (uint32_t(maskFlag) << bitPosInUInt32);
                    }
                }
            }
        }
    }
    return std::make_tuple(customMask, customMaskOffsets, adjustedFirstSparseMaskOffsetsKv);
}
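
// Illustration only (not used by the test and not part of the production API): a hypothetical
// helper that inverts the offset math above to read a single bit back out of the packed custom
// mask, assuming the same tiling parameters that were used to build it.
inline bool getCustomMaskBit(uint32_t const* mask, int32_t customMaskTokenIdxQ, int32_t customMaskTokenIdxKv,
    int32_t numCustomMaskTilesKv, int32_t tileSizeQ, int32_t tileSizeKvPadded, int32_t numInstsQ, int32_t numInstsKv)
{
    int32_t tileSizeQPerCta = tileSizeQ * numInstsQ;
    int32_t tileSizeKvPerCta = tileSizeKvPadded * numInstsKv;
    // Decompose the flattened Q/KV token indices into (tile, instruction, in-tile) coordinates.
    int32_t tileIdxQ = customMaskTokenIdxQ / tileSizeQPerCta;
    int32_t instIdxQ = (customMaskTokenIdxQ % tileSizeQPerCta) / tileSizeQ;
    int32_t tokenIdxInTileQ = (customMaskTokenIdxQ % tileSizeQPerCta) % tileSizeQ;
    int32_t tileIdxKv = customMaskTokenIdxKv / tileSizeKvPerCta;
    int32_t instIdxKv = (customMaskTokenIdxKv % tileSizeKvPerCta) / tileSizeKvPadded;
    int32_t tokenIdxInTileKv = (customMaskTokenIdxKv % tileSizeKvPerCta) % tileSizeKvPadded;
    // Same flattening order as the write path above.
    int64_t tileOffset = static_cast<int64_t>(tileIdxQ) * numCustomMaskTilesKv + tileIdxKv;
    int64_t instOffset = tileOffset * numInstsQ * numInstsKv + (instIdxQ * numInstsKv + instIdxKv);
    int64_t maskOffset
        = instOffset * tileSizeQ * tileSizeKvPadded + (tokenIdxInTileQ * tileSizeKvPadded + tokenIdxInTileKv);
    return ((mask[maskOffset >> 5] >> (maskOffset & 0x1F)) & 1u) != 0;
}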

class PrepareCustomMaskTest : public ::testing::Test
{
protected:
    static bool shouldSkip()
    {
        return !tensorrt_llm::common::isSM100Family();
    }

    void SetUp() override
    {
        if (shouldSkip())
        {
            GTEST_SKIP() << "Skipping test: requires an SM100-family GPU";
        }
        mStream = std::make_shared<CudaStream>();
        mBufferManager = std::make_shared<BufferManager>(mStream);
    }

    void TearDown() override
    {
        if (mStream)
        {
            cudaStreamSynchronize(mStream->get());
        }
        cudaDeviceSynchronize();
        mBufferManager.reset();
        mStream.reset();
    }

    void testPrepareCustomMask(int32_t batchSize, int32_t maxSeqLenQ, int32_t maxSeqLenKv, int32_t numHeadsQPerKv,
        int32_t tileSizeQ = 128, int32_t tileSizeKv = 128, int32_t numInstsQ = 2, int32_t numInstsKv = 1)
    {
        std::mt19937 gen(42);
        std::uniform_int_distribution<> seqLenQDist(1, maxSeqLenQ);
        std::uniform_int_distribution<> seqLenKvDist(maxSeqLenQ, maxSeqLenKv);
        std::vector<int32_t> seqLensQ(batchSize);
        std::vector<int32_t> seqLensKv(batchSize);
        std::vector<int32_t> firstSparseMaskOffsetsKv(batchSize);
        std::vector<int32_t> cumSeqLensQ(batchSize + 1, 0);
        std::vector<int32_t> specDecodingGenerationLengths(batchSize);
        // Use one uniform seqLenQ for every request in the batch.
        int32_t uniformSeqLenQ = seqLenQDist(gen);
        for (int32_t i = 0; i < batchSize; ++i)
        {
            seqLensQ[i] = uniformSeqLenQ;
            seqLensKv[i] = seqLenKvDist(gen);
            firstSparseMaskOffsetsKv[i] = seqLensKv[i] - seqLensQ[i];
            cumSeqLensQ[i + 1] = cumSeqLensQ[i] + seqLensQ[i];
            specDecodingGenerationLengths[i] = seqLensQ[i];
        }
        // Generate a random tree-mask input.
        // Non-packed mask shape: [bs, seqLenQ, seqLenQ].
        int32_t totalTreeMaskSize = batchSize * uniformSeqLenQ * uniformSeqLenQ;
        std::vector<int32_t> inputTreeMaskHost(totalTreeMaskSize, 0);
        std::uniform_int_distribution<int32_t> binaryDist(0, 1);
        for (int32_t batchIdx = 0; batchIdx < batchSize; ++batchIdx)
        {
            int32_t batchOffset = batchIdx * uniformSeqLenQ * uniformSeqLenQ;
            for (int32_t i = 0; i < uniformSeqLenQ * uniformSeqLenQ; ++i)
            {
                inputTreeMaskHost[batchOffset + i] = binaryDist(gen); // Random 0 or 1.
            }
        }
        // Pack the tree mask for the GPU kernel input.
        // Packed mask shape: [bs, seqLenQ, ceilDiv(seqLenQ, 32)].
        int32_t const numBitsPerPackedMask = 32;
        int32_t const numPackedMasksPerToken = ceilDiv(uniformSeqLenQ, numBitsPerPackedMask);
        int32_t totalPackedMaskSize = batchSize * uniformSeqLenQ * numPackedMasksPerToken;
        std::vector<int32_t> inputPackedMaskHost(totalPackedMaskSize, 0);
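        // Bit k of packed word j holds tree-mask column j * 32 + k; e.g. with uniformSeqLenQ = 40,
        // each row packs into two words (columns 0-31 and 32-39, upper bits of word 1 left zero).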
        for (int32_t batchIdx = 0; batchIdx < batchSize; ++batchIdx)
        {
            int32_t treeMaskBatchOffset = batchIdx * uniformSeqLenQ * uniformSeqLenQ;
            int32_t packedBatchOffset = batchIdx * uniformSeqLenQ * numPackedMasksPerToken;
            for (int32_t i = 0; i < uniformSeqLenQ; ++i)
            {
                for (int32_t j = 0; j < numPackedMasksPerToken; ++j)
                {
                    int32_t mask = 0;
                    for (int32_t k = 0; k < numBitsPerPackedMask; ++k)
                    {
                        int32_t const bitIndex = j * numBitsPerPackedMask + k;
                        if (bitIndex < uniformSeqLenQ)
                        {
                            int32_t maskFlag = inputTreeMaskHost[treeMaskBatchOffset + i * uniformSeqLenQ + bitIndex];
                            mask |= (maskFlag << k);
                        }
                    }
                    inputPackedMaskHost[packedBatchOffset + i * numPackedMasksPerToken + j] = mask;
                }
            }
        }
        auto seqLensQDevice = mBufferManager->copyFrom(seqLensQ, MemoryType::kGPU);
        auto seqLensKvDevice = mBufferManager->copyFrom(seqLensKv, MemoryType::kGPU);
        auto cumSeqLensQDevice = mBufferManager->copyFrom(cumSeqLensQ, MemoryType::kGPU);
        auto specDecodingGenerationLengthsDevice
            = mBufferManager->copyFrom(specDecodingGenerationLengths, MemoryType::kGPU);
        auto firstSparseMaskOffsetsKvDevice = mBufferManager->copyFrom(firstSparseMaskOffsetsKv, MemoryType::kGPU);
        auto inputPackedMaskDevice = mBufferManager->copyFrom(inputPackedMaskHost, MemoryType::kGPU);
        // Calculate output buffer sizes using a conservative upper bound.
        int32_t tileSizeKvPadded = ceilDiv(tileSizeKv, 32) * 32;
        int32_t tileSizeQPerCta = tileSizeQ * numInstsQ;
        int32_t tileSizeKvPerCta = tileSizeKvPadded * numInstsKv;
        // Find the extreme values across all batches.
        int32_t actualMaxSeqLenQ = *std::max_element(seqLensQ.begin(), seqLensQ.end());
        int32_t actualMaxSeqLenKv = *std::max_element(seqLensKv.begin(), seqLensKv.end());
        int32_t minFirstSparseMaskOffsetKv
            = *std::min_element(firstSparseMaskOffsetsKv.begin(), firstSparseMaskOffsetsKv.end());
        // Conservative upper bounds on the per-batch tile counts.
        int32_t maxNumTilesQ = ceilDiv(actualMaxSeqLenQ * numHeadsQPerKv, tileSizeQPerCta);
        int32_t firstSparseTile = minFirstSparseMaskOffsetKv / tileSizeKvPerCta;
        int32_t maxNumCustomMaskTilesKv = ceilDiv(actualMaxSeqLenKv, tileSizeKvPerCta) - firstSparseTile;
        // Total size in uint32 elements.
        int64_t totalMaskSize = static_cast<int64_t>(batchSize) * maxNumTilesQ * maxNumCustomMaskTilesKv * numInstsQ
            * numInstsKv * (tileSizeQ * tileSizeKvPadded) / 32;
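        // This over-allocates on purpose: combining the max seq lens with the min sparse offset across
        // the batch guarantees every per-request mask fits, at the cost of some unused tail space.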
        auto customMaskOffsetsDevice = mBufferManager->gpu(batchSize, nvinfer1::DataType::kINT64);
        auto customMaskDevice = mBufferManager->gpu(totalMaskSize, nvinfer1::DataType::kINT32);
        // Clear the GPU buffers so no stale data from previous tests leaks into the comparison.
        cudaMemsetAsync(bufferCast<int64_t>(*customMaskOffsetsDevice), 0, batchSize * sizeof(int64_t), mStream->get());
        cudaMemsetAsync(bufferCast<int32_t>(*customMaskDevice), 0, totalMaskSize * sizeof(int32_t), mStream->get());
        cudaStreamSynchronize(mStream->get());
        // Set up the kernel parameters.
        TllmGenFmhaKernelMetaInfo kernelMeta{};
        kernelMeta.mTileSizeQ = tileSizeQ;
        kernelMeta.mTileSizeKv = tileSizeKv;
        kernelMeta.mStepQ = tileSizeQ * numInstsQ;
        kernelMeta.mStepKv = tileSizeKv * numInstsKv;
        kernelMeta.mKernelType = static_cast<int>(FmhaKernelType::KeepsMmaAbForGeneration);
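        // Note: mStepQ/mStepKv give the Q/KV extent one CTA advances per step (tile size times the
        // instruction count), matching tileSizeQPerCta in the CPU reference; for KV the reference
        // additionally pads the tile size up to a multiple of 32.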
        TllmGenFmhaRunnerParams runnerParams;
        runnerParams.mBatchSize = batchSize;
        runnerParams.mNumHeadsQPerKv = numHeadsQPerKv;
        runnerParams.mMaxSeqLenQ = uniformSeqLenQ; // All requests share the same Q length.
        runnerParams.mMaxSeqLenKv = *std::max_element(seqLensKv.begin(), seqLensKv.end());
        runnerParams.seqLensKvPtr = bufferCast<int32_t>(*seqLensKvDevice);
        runnerParams.cumSeqLensQPtr = bufferCast<int32_t>(*cumSeqLensQDevice);
        runnerParams.seqlensQPtr = bufferCast<int32_t>(*specDecodingGenerationLengthsDevice);
        runnerParams.firstSparseMaskOffsetsKvPtr = bufferCast<int32_t>(*firstSparseMaskOffsetsKvDevice);
        runnerParams.generalPackedCustoMaskPtr = bufferCast<int32_t>(*inputPackedMaskDevice);
        runnerParams.customMaskOffsetsPtr = bufferCast<int64_t>(*customMaskOffsetsDevice);
        runnerParams.customMaskPtr = reinterpret_cast<uint32_t*>(bufferCast<int32_t>(*customMaskDevice));

        runPrepareCustomMask(kernelMeta, runnerParams, mStream->get());
        cudaError_t cudaErr = cudaStreamSynchronize(mStream->get());
        if (cudaErr != cudaSuccess)
        {
            FAIL() << "CUDA error: " << cudaGetErrorString(cudaErr);
        }
        // Copy the GPU results back to the host.
        auto customMaskOffsetsHost = mBufferManager->copyFrom(*customMaskOffsetsDevice, MemoryType::kCPU);
        auto customMaskHost = mBufferManager->copyFrom(*customMaskDevice, MemoryType::kCPU);
        // Run the CPU reference with the non-packed tree mask.
        auto [cpuMask, cpuOffsets, cpuAdjustedOffsets]
            = prepareCustomMaskBuffersCPU(batchSize, numHeadsQPerKv, tileSizeQ, tileSizeKv, numInstsQ, numInstsKv,
                seqLensQ, seqLensKv, firstSparseMaskOffsetsKv, inputTreeMaskHost);
        auto* gpuOffsets = bufferCast<int64_t>(*customMaskOffsetsHost);
        auto* gpuMask = reinterpret_cast<uint32_t*>(bufferCast<int32_t>(*customMaskHost));
        auto firstSparseMaskOffsetsKvHost = mBufferManager->copyFrom(*firstSparseMaskOffsetsKvDevice, MemoryType::kCPU);
        auto* gpuAdjustedOffsets = bufferCast<int32_t>(*firstSparseMaskOffsetsKvHost);
        // Compare only the effective portion; the GPU buffer may be larger due to the conservative bound.
        for (size_t i = 0; i < cpuMask.size(); ++i)
        {
            EXPECT_EQ(gpuMask[i], cpuMask[i]);
        }
        for (size_t i = 0; i < cpuOffsets.size(); ++i)
        {
            EXPECT_EQ(gpuOffsets[i], cpuOffsets[i]);
        }
        for (size_t i = 0; i < cpuAdjustedOffsets.size(); ++i)
        {
            EXPECT_EQ(gpuAdjustedOffsets[i], cpuAdjustedOffsets[i]);
        }
    }

    std::shared_ptr<CudaStream> mStream;
    std::shared_ptr<BufferManager> mBufferManager;
};

TEST_F(PrepareCustomMaskTest, SmallBatch)
{
    testPrepareCustomMask(/* batchSize */ 2,
        /* maxSeqLenQ */ 16,
        /* maxSeqLenKv */ 128,
        /* numHeadsQPerKv */ 4);
}

TEST_F(PrepareCustomMaskTest, MediumBatch)
{
    testPrepareCustomMask(/* batchSize */ 4,
        /* maxSeqLenQ */ 32,
        /* maxSeqLenKv */ 256,
        /* numHeadsQPerKv */ 8);
}
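
// The fixture self-skips on non-SM100 GPUs. To run just these cases, use the standard
// GoogleTest filter, e.g. `--gtest_filter=PrepareCustomMaskTest.*`.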
} // namespace