/*
 * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once
#include "cuda_hint.cuh"
#include "mha_stdheaders.cuh"
#include "utils.cuh"
#ifndef __CUDACC__
#include <cuda_runtime.h>
#endif
#include <cuda_fp16.h>
#include <cuda_fp8.h>

namespace gmma
{

// Shared-memory swizzle mode, encoded as in the GMMA matrix descriptor
// (see MatDesc::swizzle below): no swizzling, or a 128B/64B/32B pattern.
enum class SwizzleMode : uint64_t
{
    kNONE = 0,
    k128 = 1,
    k64 = 2,
    k32 = 3
};

struct MatDesc
{
    uint64_t addr : 16;
    uint64_t dimKOffset : 16;
    uint64_t dimMNOffset : 16;
    uint64_t pad0 : 1;
    uint64_t baseOffset : 3;
    uint64_t pad1 : 10;
    SwizzleMode swizzle : 2;

    enum class Raw : uint64_t
    {
    };

    [[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const
    {
        MatDesc ret = *this;
        ret.addr = encode(__cvta_generic_to_shared(data));
        return ret;
    }

    static __device__ inline uint32_t encode(uint32_t val)
    {
        return (val & 0x3FFFFU) >> 4;
    }

    __device__ inline bool operator==(MatDesc const& other) const
    {
        return raw() == other.raw();
    }

    __device__ inline Raw const& raw() const
    {
        static_assert(sizeof(MatDesc) == 8);
        return reinterpret_cast<Raw const&>(*this);
    }

    static __device__ inline MatDesc fromRaw(Raw const& raw)
    {
        return reinterpret_cast<MatDesc const&>(raw);
    }
};

static_assert(sizeof(MatDesc) == 8);
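
// Note (editor's addition): the bit-field layout above is intended to match
// the 64-bit shared-memory matrix descriptor consumed by Hopper
// wgmma.mma_async: a 14-bit encoded start address, 14-bit encoded
// leading-dim/stride byte offsets (encode() keeps bits 4..17, i.e. divides an
// 18-bit byte address by 16), a 3-bit base offset for swizzled layouts, and a
// 2-bit swizzle mode.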

[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data)
{
    assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0);
    MatDesc::Raw ret = base;
    auto& u32x2 = reinterpret_cast<uint32_t(&)[2]>(ret);
    u32x2[0] += static_cast<uint32_t>(__cvta_generic_to_shared(data)) >> 4;
    return ret;
}
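
// Note (editor's addition): addAddr() bumps only the low 32 bits of a
// prebuilt descriptor, which hold the encoded addr and dimKOffset fields. The
// assert checks that the shared-memory address fits in 18 bits; callers must
// also ensure the addition does not carry out of the 16-bit addr field into
// dimKOffset.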

__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset, uint32_t dimMNByteOffset,
    void const* patternStartAddr, SwizzleMode swizzleMode)
{
    uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr);
    uint32_t const baseAlign = [&]() -> uint32_t
    {
        switch (swizzleMode)
        {
        case SwizzleMode::kNONE: return 1;
        case SwizzleMode::k128: return 1024;
        case SwizzleMode::k64: return 512;
        case SwizzleMode::k32: return 256;
        }
        asm volatile("trap;\n");
        return 0;
    }();
    uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
    return MatDesc{
        /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
        /*dimKOffset=*/MatDesc::encode(dimKByteOffset),
        /*dimMNOffset=*/MatDesc::encode(dimMNByteOffset),
        /*pad0=*/0,
        /*baseOffset=*/baseOffset,
        /*pad1=*/0,
        /*swizzle=*/swizzleMode,
    };
}
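
// Note (editor's addition): baseOffset is only nonzero when the swizzle
// pattern does not start at an address aligned to the full pattern period
// (1024B/512B/256B for 128B/64B/32B swizzling); in that case bits 7..9 of the
// pattern start address select the phase of the swizzle pattern.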

__device__ inline MatDesc makeMatDesc(
    void const* data, uint32_t dimKByteOffset, uint32_t dimMNByteOffset, SwizzleMode swizzleMode)
{
    return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode);
}

inline constexpr uint32_t instM = 64;

template <typename MathElem>
inline constexpr uint32_t instK = 32 / sizeof(MathElem);

inline constexpr uint32_t instNBase = 8;
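
// Editor's note: instK is the K extent of one wgmma instruction in elements,
// with 32 being the K extent in bytes: e.g. instK<half> == 16 (m64nNk16) and
// instK<__nv_fp8_e4m3> == 32 (m64nNk32).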

// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N
// acc is used as both input and output.
template <typename InputElem, uint32_t n, bool transA = false, bool transB = false>
__device__ void mma_async_shmA(
    float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal);
template <typename InputElem, uint32_t n, bool transA = false, bool transB = false>
__device__ void mma_async_regA(
    float (&acc)[exactDiv(n, instNBase)][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal);
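
// Editor's note: acc holds one thread's fragment of the 64xN fp32
// accumulator: n/8 repeats of the 2x2 register block that wgmma assigns to
// each thread of the 128-thread warp group (n/2 floats per thread in total).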

__device__ inline void fence()
{
    asm volatile("wgmma.fence.sync.aligned;\n");
}

__device__ inline void commit_group()
{
    asm volatile("wgmma.commit_group.sync.aligned;\n");
}

template <uint32_t targetNbInFlightGroups>
__device__ inline void wait_group()
{
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(targetNbInFlightGroups));
}
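
// Editor's sketch (not part of the original file): a minimal caller showing
// the intended fence -> mma_async -> commit_group -> wait_group sequence.
// Names (exampleGmmaStep, smemA/smemB) and the byte offsets are hypothetical
// placeholders; real values depend on the shared-memory tile layout in use.
template <typename InputElem, uint32_t n>
__device__ inline void exampleGmmaStep(float (&acc)[exactDiv(n, instNBase)][2][2], void const* smemA,
    void const* smemB, uint32_t strideABytes, uint32_t strideBBytes)
{
    // Descriptors for unswizzled K-major tiles (an assumption for brevity).
    MatDesc::Raw const descA
        = makeMatDesc(smemA, /*dimKByteOffset=*/strideABytes, /*dimMNByteOffset=*/128, SwizzleMode::kNONE).raw();
    MatDesc::Raw const descB
        = makeMatDesc(smemB, /*dimKByteOffset=*/strideBBytes, /*dimMNByteOffset=*/128, SwizzleMode::kNONE).raw();
    fence();         // make prior register/shared-memory writes visible to wgmma
    mma_async_shmA<InputElem, n>(acc, descA, descB, /*accHasVal=*/true);
    commit_group();  // close the group of async wgmma operations issued above
    wait_group<0>(); // block until no wgmma groups remain in flight
}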

template <bool swizzle, typename T, uint32_t rows, uint32_t cols, bool alignedForSwizzle>
constexpr SwizzleMode getSwizzleMode(Array2D<T, rows, cols, alignedForSwizzle> const&)
{
    constexpr auto rowBytes = Array2D<T, rows, cols, alignedForSwizzle>::rowBytes;
    if constexpr (!swizzle)
    {
        return SwizzleMode::kNONE;
    }
    if constexpr (rowBytes % 128 == 0)
    {
        return SwizzleMode::k128;
    }
    else if constexpr (rowBytes == 64)
    {
        return SwizzleMode::k64;
    }
    else
    {
        static_assert(rowBytes == 32);
        return SwizzleMode::k32;
    }
}
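
// Editor's note: getSwizzleMode derives the descriptor swizzle mode from a
// shared-memory tile type at compile time, e.g. (tile name hypothetical):
// constexpr SwizzleMode mode = getSwizzleMode<true>(tile);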

} // namespace gmma

#include "gmma_impl.cuh"