TensorRT-LLM/cpp/kernels/xqa/utils.h
Ming Wei ed887940d4
infra: open source XQA kernels (#3762)
Replace libtensorrt_llm_nvrtc_wrapper.so with its source code, which
consists of two parts:

1. NVRTC glue code
2. XQA kernel code

During the TensorRT-LLM build, the XQA kernel code is embedded as C++ arrays via
gen_cpp_header.py and passed to NVRTC for JIT compilation.

Signed-off-by: Ming Wei <2345434+ming-wei@users.noreply.github.com>
2025-04-30 18:05:15 +08:00

/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */
#pragma once
#ifndef GENERATE_CUBIN
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <stdexcept>
#endif
#include "mha_stdheaders.cuh"
// Silences unused-variable/parameter warnings in both host and device code.
template <typename T>
HOST_DEVICE_FUNC constexpr inline void unused(T&& x)
{
    static_cast<void>(x);
}
#ifndef GENERATE_CUBIN
inline void checkCuda(cudaError_t err)
{
    if (err != cudaSuccess)
    {
        printf("%s\n", cudaGetErrorName(err));
        throw std::runtime_error(cudaGetErrorName(err));
    }
}

inline void checkCu(CUresult err)
{
    if (err != CUDA_SUCCESS)
    {
        char const* str = nullptr;
        if (cuGetErrorName(err, &str) != CUDA_SUCCESS)
        {
            str = "A CUDA driver API error happened, but we failed to query the error name";
        }
        printf("%s\n", str);
        throw std::runtime_error(str);
    }
}
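
// Usage sketch for the checkers above; cudaMalloc/cuCtxGetDevice are just
// placeholder API calls:
//   void* buf = nullptr;
//   checkCuda(cudaMalloc(&buf, 4096)); // throws std::runtime_error on failure
//   CUdevice dev;
//   checkCu(cuCtxGetDevice(&dev));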
#endif
// Returns the largest power of 2 that divides x (for x > 0), e.g. 4 for x == 12.
HOST_DEVICE_FUNC constexpr inline uint32_t greatestPowerOf2Divisor(uint32_t x)
{
    return x & ~(x - 1);
}

// Largest power-of-2 divisor of the byte size of T[size], assuming sizeof(T) is a power of 2.
template <typename T>
HOST_DEVICE_FUNC constexpr uint32_t maxArrayAlign(uint32_t size)
{
    return sizeof(T) * greatestPowerOf2Divisor(size);
}

// Division that asserts there is no remainder.
HOST_DEVICE_FUNC constexpr inline uint32_t exactDiv(uint32_t a, uint32_t b)
{
    assert(a % b == 0);
    return a / b;
}

// Ceiling division.
template <typename T>
HOST_DEVICE_FUNC constexpr inline T divUp(T a, T b)
{
    return (a + b - 1) / b;
}

// Rounds a up to the next multiple of b.
template <typename T>
HOST_DEVICE_FUNC constexpr inline T roundUp(T a, T b)
{
    return divUp(a, b) * b;
}
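
// Illustrative compile-time checks of the helpers above (values are arbitrary
// examples):
static_assert(greatestPowerOf2Divisor(12U) == 4U, "12 = 4 * 3");
static_assert(maxArrayAlign<float>(6U) == 8U, "sizeof(float) * greatestPowerOf2Divisor(6)");
static_assert(divUp(7U, 3U) == 3U, "ceiling division");
static_assert(roundUp(7U, 4U) == 8U, "next multiple of 4");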
// upperBound is exclusive, i.e. the valid range is [0, upperBound).
template <uint32_t upperBound>
struct BoundedVal
{
    template <uint32_t divisor>
    HOST_DEVICE_FUNC inline BoundedVal<upperBound / divisor> divBy() const
    {
        assert(value < upperBound);
        return {upperBound <= divisor ? 0 : value / divisor};
    }

    template <uint32_t divisor>
    HOST_DEVICE_FUNC inline BoundedVal<mha::min(divisor, upperBound)> mod() const
    {
        assert(value < upperBound);
        return {upperBound <= divisor ? value : value % divisor};
    }

    HOST_DEVICE_FUNC inline bool operator<=(uint32_t rhs) const
    {
        assert(value < upperBound);
        return upperBound <= rhs || value <= rhs;
    }

    HOST_DEVICE_FUNC inline uint32_t get() const
    {
        assert(value < upperBound);
        return upperBound == 1 ? 0 : value;
    }

    uint32_t value;
};
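
// Usage sketch: decompose a bounded linear index into 2D coordinates (the
// bound 128 and divisor 32 are arbitrary examples). Knowing the bound at
// compile time lets divBy/mod/get fold into cheaper code when the result is
// trivial:
//   BoundedVal<128> const tid{threadIdx.x};
//   uint32_t const row = tid.divBy<32>().get(); // in [0, 4)
//   uint32_t const col = tid.mod<32>().get();   // in [0, 32)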
template <typename T, uint32_t size_>
struct alignas(mha::max<uint32_t>(alignof(T), mha::min<uint32_t>(maxArrayAlign<T>(size_), 16))) Vec
{
    using Elem = T;
    static constexpr uint32_t size = size_;

    Elem data[size];

    HOST_DEVICE_FUNC inline void fill(T val)
    {
#pragma unroll
        for (uint32_t i = 0; i < size; i++)
        {
            data[i] = val;
        }
    }

    static HOST_DEVICE_FUNC inline Vec<T, size> filled(T val)
    {
        Vec<T, size> ret;
        ret.fill(val);
        return ret;
    }

    HOST_DEVICE_FUNC inline Elem const& operator[](uint32_t i) const
    {
        assert(i < size);
        return data[BoundedVal<size>{i}.get()];
    }

    HOST_DEVICE_FUNC inline Elem& operator[](uint32_t i)
    {
        assert(i < size);
        return data[BoundedVal<size>{i}.get()];
    }
};
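
// Usage sketch (element type and size are arbitrary examples): Vec<float, 4>
// is 16-byte aligned, so the whole vector can move as a single 128-bit
// load/store:
//   Vec<float, 4> acc = Vec<float, 4>::filled(0.F);
//   acc[0] += 1.F;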
template <uint32_t nbBuffers_>
struct CircIdx
{
public:
    static constexpr uint32_t nbBuffers = nbBuffers_;
    static_assert(nbBuffers >= 1);

    __device__ inline CircIdx(uint32_t init)
        : mIndex{init % nbBuffers}
    {
    }

    __device__ inline operator uint32_t() const
    {
        return mIndex;
    }

    __device__ inline CircIdx operator+(uint32_t i) const
    {
        return CircIdx{(mIndex + i) % nbBuffers};
    }

    // Adds (nbBuffers - 1) * i instead of subtracting i, so the intermediate
    // value never underflows the unsigned range.
    __device__ inline CircIdx operator-(uint32_t i) const
    {
        return CircIdx{(mIndex + (nbBuffers - 1) * i) % nbBuffers};
    }

    __device__ inline CircIdx next() const
    {
        return *this + 1u;
    }

    __device__ inline CircIdx& operator++()
    {
        mIndex = next();
        return *this;
    }

    __device__ inline CircIdx operator++(int)
    {
        CircIdx old = *this;
        operator++();
        return old;
    }

    __device__ inline CircIdx prev() const
    {
        return *this - 1u;
    }

    __device__ inline CircIdx& operator--()
    {
        mIndex = prev();
        return *this;
    }

    __device__ inline CircIdx operator--(int)
    {
        CircIdx old = *this;
        operator--();
        return old;
    }

private:
    uint32_t mIndex;
};
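
// Usage sketch: double buffering with a wrapping index (prefetch, compute,
// buffers, and nbIters are placeholders):
//   CircIdx<2> idx{0};
//   for (uint32_t iter = 0; iter < nbIters; iter++)
//   {
//       prefetch(buffers[idx.next()]);
//       compute(buffers[idx]);
//       idx++; // wraps 0 -> 1 -> 0 -> ...
//   }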
// The base pointer usually lives in constant memory or uniform registers, so
// storing a TinyPtr usually requires only one (non-uniform) register: the
// 32-bit offset.
template <typename T>
struct TinyPtr
{
    T* base;         // typically in constant memory or uniform registers
    uint32_t offset; // may be non-uniform

    template <typename D>
    __device__ __host__ inline TinyPtr<D> cast() const
    {
        D* const p = reinterpret_cast<D*>(base);
        assert(reinterpret_cast<uintptr_t>(p) % alignof(D) == 0);
        if constexpr (mha::is_void_v<T>)
        {
            assert(offset == 0);
            return TinyPtr<D>{p, 0};
        }
        else if constexpr (sizeof(T) < sizeof(D))
        {
            return TinyPtr<D>{p, exactDiv(offset, exactDiv(sizeof(D), sizeof(T)))};
        }
        else
        {
            return TinyPtr<D>{p, offset * exactDiv(sizeof(T), sizeof(D))};
        }
    }

    __device__ __host__ inline T& operator*() const
    {
        return base[offset];
    }

    __device__ __host__ inline TinyPtr<T> operator+(uint32_t i) const
    {
        return TinyPtr<T>{base, offset + i};
    }

    __device__ __host__ inline T& operator[](uint32_t i) const
    {
        return *(*this + i);
    }

    __device__ __host__ inline operator T*() const
    {
        return base + offset;
    }
};
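
// Usage sketch (smemBase and rowOffsetInBytes are placeholders): reinterpret a
// byte span as fp16 elements; cast() rescales the offset by the sizeof ratio
// (and asserts it divides exactly):
//   TinyPtr<mha::byte> const bytes{smemBase, rowOffsetInBytes};
//   TinyPtr<half> const row = bytes.cast<half>();
//   half const x = row[3];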
template <typename OffsetInt = uint32_t>
class Segmenter
{
public:
    HOST_DEVICE_FUNC Segmenter(uint32_t offset = 0)
        : mNextOffset{offset}
    {
    }

    // offset is in bytes
    template <typename T>
    HOST_DEVICE_FUNC OffsetInt newSeg(uint32_t count = 1, uint32_t alignment = alignof(T))
    {
        mMaxAlignment = mha::max<uint32_t>(mMaxAlignment, alignment);
        OffsetInt const offset = roundUp<OffsetInt>(mNextOffset, alignment);
        mNextOffset = offset + sizeof(T) * count;
        return offset;
    }

    HOST_DEVICE_FUNC OffsetInt getEndOffset() const
    {
        return mNextOffset;
    }

    HOST_DEVICE_FUNC uint32_t getMaxAlignment() const
    {
        return mMaxAlignment;
    }

private:
    OffsetInt mNextOffset;
    uint32_t mMaxAlignment = 1;
};
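
// Usage sketch: computing a buffer layout (segment types and counts are
// arbitrary examples; headDim and ctaSize are placeholders):
//   Segmenter<> seg;
//   uint32_t const offQ = seg.newSeg<half>(headDim);  // byte offset, aligned to alignof(half)
//   uint32_t const offX = seg.newSeg<float>(ctaSize); // starts at the next float-aligned offset
//   uint32_t const totalBytes = seg.getEndOffset();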
template <typename T, bool addConst>
using AddConst = mha::conditional_t<addConst, T const, T>;

template <bool isConst, typename OffsetInt = uint32_t>
class MemSegmenter
{
public:
    HOST_DEVICE_FUNC MemSegmenter(AddConst<void, isConst>* base, uint32_t offset = 0)
        : mBase{static_cast<AddConst<mha::byte, isConst>*>(base)}
        , mSegmenter{offset}
    {
    }

    // to use TinyPtr, alignment must be sizeof(T)
    template <typename T>
    HOST_DEVICE_FUNC TinyPtr<AddConst<T, isConst>> newSeg(uint32_t count = 1, uint32_t alignment = sizeof(T))
    {
        assert(reinterpret_cast<uintptr_t>(mBase) % alignment == 0);
        OffsetInt const offset = mSegmenter.template newSeg<T>(count, alignment);
        return TinyPtr<AddConst<mha::byte, isConst>>{mBase, offset}.template cast<AddConst<T, isConst>>();
    }

    HOST_DEVICE_FUNC OffsetInt getEndOffset() const
    {
        return mSegmenter.getEndOffset();
    }

    HOST_DEVICE_FUNC uint32_t getMaxAlignment() const
    {
        return mSegmenter.getMaxAlignment();
    }

private:
    AddConst<mha::byte, isConst>* mBase;
    Segmenter<OffsetInt> mSegmenter;
};
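
// Usage sketch: carving typed segments out of dynamic shared memory (headDim
// and ctaSize are placeholders):
//   extern __shared__ char smemBuf[];
//   MemSegmenter<false> seg{smemBuf};
//   TinyPtr<half> const q = seg.newSeg<half>(headDim);
//   TinyPtr<float> const x = seg.newSeg<float>(ctaSize);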
// Dims stored in little-endian order, i.e. d[0] is the innermost
// (fastest-varying) dimension.
template <uint32_t nbDims_>
struct DimsLE
{
    static constexpr uint32_t nbDims = nbDims_;

    __device__ __host__ inline uint32_t& operator[](uint32_t i)
    {
        return d[i];
    }

    __device__ __host__ inline uint32_t const& operator[](uint32_t i) const
    {
        return d[i];
    }

    uint32_t d[nbDims];
};
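
// E.g. a row-major [batch][seqLen][headDim] shape would be stored as
// DimsLE<3>{{headDim, seqLen, batch}} under this convention (arbitrary
// example).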
// Checks whether val is in the half-open range [lb, ub).
template <typename T>
constexpr bool inRange(T val, T lb, T ub)
{
    return val >= lb && val < ub;
}
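
// Illustrative compile-time checks (values are arbitrary examples):
static_assert(inRange(3, 0, 5), "3 is in [0, 5)");
static_assert(!inRange(5, 0, 5), "the upper bound is exclusive");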