/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/converter.h"
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h"

#include <cstdint>

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{

// Selects the integer-to-float weight converter for a kernel configuration `Details`, based on the
// activation element type, the weight element type and Details::kUseInterleavedConverter.
// NOTE(review): I2FConverter's template argument order is assumed to be
// <activation Type, weight bit-width, interleaved flag> — confirm against converter.h.
template <typename Details>
struct ConverterWrapper
{
    using TypeDetailsA = typename Details::TypeDetailsA;
    using TypeDetailsW = typename Details::TypeDetailsW;
    static constexpr bool kUseInterleavedConverter = Details::kUseInterleavedConverter;
    using Converter = I2FConverter<typename TypeDetailsA::Type, TypeDetailsW::kElemBits, kUseInterleavedConverter>;
};

// Element-pair arithmetic helpers keyed on the activation type details. Only the FP16DetailsA and
// BF16DetailsA specializations below are defined; the primary template is intentionally empty so an
// unsupported activation type fails to compile at the point of use.
template <typename TypeDetailsA>
struct MathWrapper
{
};

template <>
struct MathWrapper<FP16DetailsA>
{
    using Type = typename FP16DetailsA::Type;
    using Type2 = typename FP16DetailsA::Type2;

    // Broadcasts a scalar into both halves of a pair.
    __device__ __forceinline__ static Type2 to_vec2(Type const& v)
    {
        return __half2half2(v);
    }

    // Element-wise a * b + c on pairs.
    __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c)
    {
        return __hfma2(a, b, c);
    }

    // Element-wise a * b on pairs.
    __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b)
    {
        return __hmul2(a, b);
    }
};

// BF16 specialization. The bf16 pair intrinsics are only used when compiling device code for SM80+
// with ENABLE_BF16 defined; in every other compilation pass each helper returns an all-zero pair as a
// placeholder (it does not compute a result).
template <>
struct MathWrapper<BF16DetailsA>
{
    using Type = typename BF16DetailsA::Type;
    using Type2 = typename BF16DetailsA::Type2;

    // Broadcasts a scalar into both halves of a pair.
    __device__ __forceinline__ static Type2 to_vec2(Type const& v)
    {
#if ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) && defined(ENABLE_BF16))
        return __bfloat162bfloat162(v);
#else
        // Placeholder: reinterpret a 32-bit zero as a pair (assumes sizeof(Type2) == 4).
        uint32_t val = 0;
        Type2 ret = reinterpret_cast<Type2&>(val);
        return ret;
#endif
    }

    // Element-wise a * b + c on pairs.
    __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c)
    {
#if ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) && defined(ENABLE_BF16))
        return __hfma2(a, b, c);
#else
        return to_vec2(static_cast<Type>(0.f));
#endif
    }

    // Element-wise a * b on pairs.
    __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b)
    {
#if ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) && defined(ENABLE_BF16))
        return __hmul2(a, b);
#else
        return to_vec2(static_cast<Type>(0.f));
#endif
    }
};

// When Enable is true, multiplies an M x K activation tile `act` (row-major, M rows of K elements) in
// place by the K-element vector `act_scale`, column by column; otherwise does nothing.
// Both buffers are processed as element pairs, so K must be even.
template <typename Details, int M, int K, bool Enable>
__device__ __forceinline__ void apply_scale(void* act, void* act_scale)
{
    using Math = MathWrapper<typename Details::TypeDetailsA>;
    using Type2 = typename Math::Type2;
    static_assert(K % 2 == 0);
    [[maybe_unused]] static constexpr int VecK = K / 2;
    if constexpr (Enable)
    {
        Type2* pa = reinterpret_cast<Type2*>(act);
        Type2* pb = reinterpret_cast<Type2*>(act_scale);
#pragma unroll
        for (int m = 0; m < M; ++m)
        {
#pragma unroll
            for (int k = 0; k < VecK; ++k)
            {
                pa[m * VecK + k] = Math::mul2(pa[m * VecK + k], pb[k]);
            }
        }
    }
}

// Converts N rows of K quantized weights into `w` (N * K elements of Type, row-major).
//   quantized_w : row n starts at byte offset n * K / Details::kElemsPerByteW.
//   scales      : N elements of Type, one per row.
//   zeros       : N elements of Type, one per row; only read when EnableZero.
//   alpha       : only used when ApplyAlphaInAdvance.
// When EnableZero or ApplyAlphaInAdvance is set, each row is finished in place as
// w = w * scale[n] + zero[n]. Under ApplyAlphaInAdvance, scale (and zero, if enabled) are first
// multiplied by alpha in float; without EnableZero the zero term is 0.
// When neither flag is set, only the integer-to-float conversion happens here and the per-row scale
// is applied later by mma().
template <typename Details, int N, int K, bool EnableZero, bool ApplyAlphaInAdvance>
__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros, float alpha)
{
    using Math = MathWrapper<typename Details::TypeDetailsA>;
    using Type = typename Math::Type;
    using Type2 = typename Math::Type2;
    using Converter = typename ConverterWrapper<Details>::Converter;
    static_assert(K % 2 == 0);
    static constexpr int VecK = K / 2;
#pragma unroll
    for (int n = 0; n < N; ++n)
    {
        // `convert` is a member template of a dependent type, so the `template` disambiguator is
        // required by standard C++ (some compilers merely tolerate its absence).
        Converter::template convert<K>(reinterpret_cast<uint8_t*>(quantized_w) + n * K / Details::kElemsPerByteW,
            reinterpret_cast<Type*>(w) + n * K);
        if constexpr (EnableZero || ApplyAlphaInAdvance)
        {
            Type2 vec_scale, vec_zero;
            if constexpr (ApplyAlphaInAdvance)
            {
                // Fold alpha in float before narrowing back to Type.
                Type const scale_n = static_cast<Type>(static_cast<float>(reinterpret_cast<Type*>(scales)[n]) * alpha);
                vec_scale = Math::to_vec2(scale_n);
                vec_zero = Math::to_vec2(static_cast<Type>(0.f));
                if constexpr (EnableZero)
                {
                    vec_zero = Math::to_vec2(
                        static_cast<Type>(static_cast<float>(reinterpret_cast<Type*>(zeros)[n]) * alpha));
                }
            }
            else
            {
                vec_scale = Math::to_vec2(reinterpret_cast<Type*>(scales)[n]);
                vec_zero = Math::to_vec2(static_cast<Type>(0.f));
                if constexpr (EnableZero)
                {
                    vec_zero = Math::to_vec2(reinterpret_cast<Type*>(zeros)[n]);
                }
            }
#pragma unroll
            for (int k = 0; k < VecK; ++k)
            {
                reinterpret_cast<Type2*>(w)[n * VecK + k]
                    = Math::fma2(reinterpret_cast<Type2*>(w)[n * VecK + k], vec_scale, vec_zero);
            }
        }
    }
}

// Copies weight row `n` (K elements of Type in `src`, gathered through Details::LayoutDetails::Mapper)
// into `dst` so that rows 2i and 2i+1 are interleaved element by element:
//     dst[(n & ~1) * K + 2 * k + (n & 1)] = src[mapper(k)]
// mma() then reads each (row pair, k) entry as a single Type2.
template <typename Details, int K>
__device__ __forceinline__ void pack_to_vec2(void* dst, void* src, int n)
{
    using Type = typename MathWrapper<typename Details::TypeDetailsA>::Type;
    typename Details::LayoutDetails::Mapper mapper;
    int const n0 = n & ~0x1; // first row of the pair
    int const n1 = n & 0x1;  // position within the pair
    for (int k = 0; k < K; ++k)
    {
        int const physical_idx = mapper(k);
        reinterpret_cast<Type*>(dst)[n0 * K + k * 2 + n1] = reinterpret_cast<Type*>(src)[physical_idx];
    }
}

// Accumulates an M x K activation tile `act` (Type, row-major) against N weight rows packed in pairs
// by pack_to_vec2 (`w_pack2`, N / 2 pairs of K Type2 entries) into `acc` (M x N / 2 Type2 entries).
// For each row m and weight pair n this computes sum_k w_pack2[n][k] * act[m][k].
//  - EnableZero or ApplyAlphaInAdvance: weights were fully dequantized by dequantize(), so products
//    are accumulated directly into acc.
//  - Otherwise: the dot product is accumulated in a local pair first, then scaled once by the pair
//    `scale[n]` (per-row scales read as Type2) and added to acc.
// N must be even.
template <typename Details, int M, int N, int K, bool EnableZero, bool ApplyAlphaInAdvance>
__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act, void* scale)
{
    using Math = MathWrapper<typename Details::TypeDetailsA>;
    using Type = typename Math::Type;
    using Type2 = typename Math::Type2;
    static_assert(N % 2 == 0);
    static constexpr int VecN = N / 2;
#pragma unroll
    for (int m = 0; m < M; ++m)
    {
#pragma unroll
        for (int n = 0; n < VecN; ++n)
        {
            if constexpr (EnableZero || ApplyAlphaInAdvance)
            {
#pragma unroll
                for (int k = 0; k < K; ++k)
                {
                    reinterpret_cast<Type2*>(acc)[m * VecN + n] = Math::fma2(reinterpret_cast<Type2*>(w_pack2)[n * K + k],
                        Math::to_vec2(reinterpret_cast<Type*>(act)[m * K + k]), reinterpret_cast<Type2*>(acc)[m * VecN + n]);
                }
            }
            else
            {
                Type2 local_acc{};
#pragma unroll
                for (int k = 0; k < K; ++k)
                {
                    local_acc = Math::fma2(reinterpret_cast<Type2*>(w_pack2)[n * K + k],
                        Math::to_vec2(reinterpret_cast<Type*>(act)[m * K + k]), local_acc);
                }
                reinterpret_cast<Type2*>(acc)[m * VecN + n]
                    = Math::fma2(local_acc, reinterpret_cast<Type2*>(scale)[n], reinterpret_cast<Type2*>(acc)[m * VecN + n]);
            }
        }
    }
}

// Butterfly (xor-shuffle) sum over a full 32-lane warp. The xor-4 and/or xor-2 steps are skipped so
// that values from different interleaved lane groups are not mixed:
//   Interleave == 1 : every lane ends with the sum over all 32 lanes;
//   Interleave == 2 : lane bit 2 is preserved (two groups);
//   Interleave == 4 : lane bits 1 and 2 are preserved (four groups).
// `val` is updated in place and also returned. All 32 lanes must execute this call (full mask).
template <int Interleave, typename T>
__device__ __forceinline__ T warp_reduce_sum(T& val)
{
    static_assert(Interleave == 1 || Interleave == 2 || Interleave == 4,
        "warp_reduce_sum only separates lane groups for Interleave of 1, 2 or 4");
    constexpr unsigned kFullWarpMask = 0xffffffffu;
    val += __shfl_xor_sync(kFullWarpMask, val, 16);
    val += __shfl_xor_sync(kFullWarpMask, val, 8);
    if constexpr (Interleave != 2 && Interleave != 4)
    {
        val += __shfl_xor_sync(kFullWarpMask, val, 4);
    }
    if constexpr (Interleave != 4)
    {
        val += __shfl_xor_sync(kFullWarpMask, val, 2);
    }
    val += __shfl_xor_sync(kFullWarpMask, val, 1);
    return val;
}

// Reduces each thread's CtaM x CtaN accumulator tile (`tile_acc`, Type) across the whole thread block.
// Writes CtaM rows of CtaN * Interleave results to `out`; row m starts at out + m * stride (Type elements).
// Adds bias[n] when EnableBias. Multiplies by alpha unless ApplyAlphaInAdvance (alpha was then already
// folded into the weights by dequantize()).
// Must be called by all `Threads` threads of the block: it uses full-warp shuffles and __syncthreads().
// NOTE: there is no barrier after the final shared-memory reads, so calling this twice in one kernel
// would need a __syncthreads() between the calls.
template <typename Details, int CtaM, int CtaN, int Threads, bool EnableBias, bool ApplyAlphaInAdvance>
__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha)
{
    using Type = typename MathWrapper<typename Details::TypeDetailsA>::Type;
    static constexpr int Interleave = Details::kInterleave;
    static constexpr int ThreadsPerInterleavedTile = Details::kThreadsPerInterleavedTile;
    static constexpr int WarpSize = Details::kWarpSize;
    static constexpr int WarpNum = Threads / WarpSize;
    static_assert(Threads % WarpSize == 0);
    // warp_reduce_sum hard-codes a 32-lane butterfly.
    static_assert(WarpSize == 32, "epilogue assumes 32-lane warps");
    static constexpr int TileElems = CtaM * CtaN * Interleave; // outputs per block tile

    // One float partial per output per warp.
    __shared__ float shmem[TileElems * WarpNum];
    int const tid = threadIdx.x;
    int const warp_id = tid / WarpSize;
    int const lane_id = tid % WarpSize;

    // Stage 1: reduce every accumulator entry inside each warp. The first lane of each of the
    // `Interleave` lane groups stores its group's sum into this warp's slice of shmem.
#pragma unroll
    for (int m = 0; m < CtaM; ++m)
    {
#pragma unroll
        for (int n = 0; n < CtaN; ++n)
        {
            float v = static_cast<float>(reinterpret_cast<Type*>(tile_acc)[m * CtaN + n]);
            v = warp_reduce_sum<Interleave>(v);
            if (lane_id < Interleave * ThreadsPerInterleavedTile && lane_id % ThreadsPerInterleavedTile == 0)
            {
                shmem[warp_id * TileElems + m * CtaN * Interleave + n * Interleave
                    + lane_id / ThreadsPerInterleavedTile]
                    = v;
            }
        }
    }
    __syncthreads();

    // Stage 2: threads stride over the outputs, summing the WarpNum per-warp partials, then apply
    // alpha / bias and narrow to Type.
#pragma unroll
    for (int ii = tid; ii < TileElems; ii += Threads)
    {
        int const m = ii / (CtaN * Interleave);
        int const n = ii % (CtaN * Interleave);
        float val = 0.f;
        float v_bias = 0.f;
        if constexpr (EnableBias)
        {
            v_bias = static_cast<float>(reinterpret_cast<Type*>(bias)[n]);
        }
#pragma unroll
        for (int jj = 0; jj < WarpNum; ++jj)
        {
            val += shmem[jj * TileElems + ii];
        }
        if constexpr (ApplyAlphaInAdvance)
        {
            reinterpret_cast<Type*>(out)[m * stride + n] = static_cast<Type>(val + v_bias);
        }
        else
        {
            reinterpret_cast<Type*>(out)[m * stride + n] = static_cast<Type>(alpha * val + v_bias);
        }
    }
}

// Writes `v` into the first N elements of `tile` (interpreted as T*).
template <int N, typename T>
__device__ __forceinline__ void fill(void* tile, T v)
{
#pragma unroll
    for (int ii = 0; ii < N; ++ii)
    {
        reinterpret_cast<T*>(tile)[ii] = v;
    }
}

// Strided loader over a T buffer (presumably device global memory, per the name — confirm at call
// sites). Each load() copies `Continuous` consecutive TVec values starting at
// base + iter * step + ii * stride, with offsets counted in T elements.
// When Enable is false the iterator stores a null pointer and load() does nothing, so optional inputs
// can be handled by the same code path. The `Strided` parameter is not used by this class.
template <bool Enable, typename TVec, int Strided, int Continuous, typename T>
class GMemIterator
{
public:
    __device__ __forceinline__ GMemIterator(T* addr, int offset, int step, int stride)
        : addr_(Enable ? (addr + offset) : nullptr)
        , step_(step)
        , stride_(stride)
    {
    }

    // Copies Continuous TVec values into `dst` (no-op when Enable is false).
    __device__ __forceinline__ void load(void* dst, int iter, int ii = 0)
    {
        if constexpr (Enable)
        {
#pragma unroll
            for (int jj = 0; jj < Continuous; ++jj)
            {
                reinterpret_cast<TVec*>(dst)[jj] = reinterpret_cast<TVec*>(addr_ + iter * step_ + ii * stride_)[jj];
            }
        }
    }

private:
    T* addr_;
    int step_;
    int stride_;
};

} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm