/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"

#include "Common.h"
#include "Poly.h"

namespace tensorrt_llm
{
namespace kernels
{

typedef void (*ChunkStateKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
    //  const half  *g_mxY_,  // B*L*H*P
    //  const half  *g_mxOs_, // B*C*H*N*P
    //  const half  *g_mxFs_, // B  *H*N*P
    float* g_mxSt_,       // B*C*H*N*P
    float const* g_mxdc_, // B*C*H*Q
    float const* g_mxdA_, // B*C*H*Q
    //  const half  *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
    //  const float *g_mxdb_, // H
    //  const float *g_mxA_,  // H
    //  const half  *g_mxCB_, // B*C*G*Q*Q
    //  const float *g_mxD_,  // H
    half const* g_mxXBC_, // B*L*(H*P+2*G*N)
    //  const half  *g_mxZ_,  // B*L*(2*H*P+2*G*N+H)
    bool removePadding_, int const* lastTokenIdsPtr_);

typedef void (*ChunkStateKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
    //  const bf16  *g_mxY_,  // B*L*H*P
    //  const bf16  *g_mxOs_, // B*C*H*N*P
    //  const bf16  *g_mxFs_, // B  *H*N*P
    float* g_mxSt_,       // B*C*H*N*P
    float const* g_mxdc_, // B*C*H*Q
    float const* g_mxdA_, // B*C*H*Q
    //  const bf16  *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
    //  const float *g_mxdb_, // H
    //  const float *g_mxA_,  // H
    //  const bf16  *g_mxCB_, // B*C*G*Q*Q
    //  const float *g_mxD_,  // H
    bf16 const* g_mxXBC_, // B*L*(H*P+2*G*N)
    //  const bf16  *g_mxZ_,  // B*L*(2*H*P+2*G*N+H)
    bool removePadding_, int const* lastTokenIdsPtr_);

template <int Q_, int tileM_, int tileN_, int tileK_, // shared-memory tile sizes
    int wmmaM_, int wmmaN_, int wmmaK_,               // mma instruction shape
    int warpM_, int warpN_,                           // warp layout per block
    int pipeS_, class Tp_>                            // cp.async pipeline depth, data type
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_state_kernel(
    int B_, int L_, int H_, int P_, int G_, int N_,
    //  const Tp_   *g_mxY_,  // B*L*H*P
    //  const Tp_   *g_mxOs_, // B*C*H*N*P
    //  const Tp_   *g_mxFs_, // B  *H*N*P
    float* g_mxSt_,       // B*C*H*N*P
    float const* g_mxdc_, // B*C*H*Q
    float const* g_mxdA_, // B*C*H*Q
    //  const Tp_   *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
    //  const Wt_   *g_mxdb_, // H
    //  const Wt_   *g_mxA_,  // H
    //  const Tp_   *g_mxCB_, // B*C*G*Q*Q
    //  const Wt_   *g_mxD_,  // H
    Tp_ const* g_mxXBC_,  // B*L*(H*P+2*G*N)
    //  const Tp_   *g_mxZ_,  // B*L*(2*H*P+2*G*N+H)
    bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
    using namespace tensorrt_llm::common;

    auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
    auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
    auto blockIdx_z = Rn<ID>{int(blockIdx.z)};

    auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
    auto threadIdx_y = Rn<ID, warpN_>{int(threadIdx.y)};
    auto threadIdx_z = Rn<ID, warpM_>{int(threadIdx.z)};

    // auto B = Rn<ID>{B_};
    auto L = Rn<ID>{L_};
    auto H = Rn<ID>{H_};
    auto P = Rn<ID>{P_};
    auto G = Rn<ID>{G_};
    auto N = Rn<ID>{N_};
    auto Q = cn<Q_>;
    auto C = Rn<ID>{div_up(L.var, Q_)};

    auto xbcDim = Rn<ID>{H_ * P_ + 2 * G_ * N_};
    auto bOffset = Rn<ID>{H_ * P_};
    auto cOffset = Rn<ID>{H_ * P_ + G_ * N_};

    auto aStart = blockIdx_z * L;
    auto cStart = blockIdx_z * C;

    if (removePadding_)
    {
        aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
        cStart = Rn<ID>{int(blockIdx.z ?
            div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
        L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
        C = Rn<ID>{div_up(L.var, Q_)};
    }
    else
    {
        L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
        C = Rn<ID>{div_up(L.var, Q_)};
    }

    if (blockIdx_y * Q >= L)
        return;

    auto hStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) / (N_ / cn<tileM_>)};
    auto mStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) % (N_ / cn<tileM_>)};
    auto nStart = Rn<ID>{blockIdx_x.var % (P_ / cn<tileN_>)};
    auto gStart = Rn<ID>{hStart.var / (H_ / G_)};

    extern __shared__ float smem[];

    Tp_* s_mxB = (Tp_*) smem;
    Tp_* s_mxX = (Tp_*) smem + tileM_ * tileK_ * pipeS_;

    float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2;
    float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_;

    unsigned b_base = __nvvm_get_smem_pointer(smem);

    unsigned b_mxB = b_base;
    unsigned b_mxX = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_);

    using std::array;

    register array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>
        r_mxSt = array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>();
    register array<array<unsigned, 4>, tileM_ / wmmaM_ / warpM_> r_mxB;
    register array<array<unsigned, 2>, tileN_ / wmmaN_ / warpN_> r_mxX;

    constexpr int step = std::max(
        1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_));

    auto baseB = [](auto iK) { return iK % cn<pipeS_> * cn<tileM_> * cn<tileK_>; };
    auto baseX = [](auto iK) { return iK % cn<pipeS_> * cn<tileN_> * cn<tileK_>; };

    auto thread = [=](auto iStep)
    { return iStep * cn<warpM_ * warpN_ * 256> + threadIdx_z * cn<warpN_ * 256> + threadIdx_y * cn<256> + threadIdx_x * cn<8>; };

    // Stage this chunk's dc and cumulative dA rows (head hStart) into shared memory.
#pragma unroll
    for (Rn<UNROLL, div_up(Q_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
        if (thread(iStep) < cn<Q_>)
        {
#pragma unroll
            for (int i = 0; i < 8; i += 4)
            {
                *(int4*) (s_mxdc + get(thread(iStep)) + i)
                    = *(int4*) (g_mxdc_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
                *(int4*) (s_mxdA + get(thread(iStep)) + i)
                    = *(int4*) (g_mxdA_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
            }
        }

    // Prologue: prefetch the first pipeS_ K-tiles of B and X with cp.async.
#pragma unroll
    for (Rn<UNROLL, pipeS_> iK; iK.var < iK.size; iK.var++)
    {
#pragma unroll
        for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
            if (thread(iStep) < cn<tileM_ * tileK_>
                && thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - iK * cn<tileK_>)
                cp_shared_global<16>(b_mxB + swizzle(thread(iStep) * cn<2>, baseB(iK) * cn<2>),
                    g_mxXBC_
                        + get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileM_>) *xbcDim + bOffset
                            + gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
            else if (thread(iStep) < cn<tileM_ * tileK_>)
                *(int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(iK) * cn<2>)) = int4{0, 0, 0, 0};

#pragma unroll
        for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
            if (thread(iStep) < cn<tileN_ * tileK_>
                && thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - iK * cn<tileK_>)
                cp_shared_global<16>(b_mxX + swizzle(thread(iStep) * cn<2>, baseX(iK) * cn<2>),
                    g_mxXBC_
                        + get((aStart + blockIdx_y * Q + iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *xbcDim
                            + hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
            else if (thread(iStep) < cn<tileN_ * tileK_>)
                *(int4*) ((char*) s_mxX + swizzle(thread(iStep) * cn<2>, baseX(iK) * cn<2>)) = int4{0, 0, 0, 0};

        cp_commit_group();
    }

    asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));

    __syncthreads();

    for (int iK = pipeS_; iK < Q_ / tileK_ + pipeS_; iK++)
    {
        auto jK = Rn<>{iK};

        // Weight B(k, m) by exp(dA(Q-1) - dA(k)) * dc(k) in place before the matmul.
#pragma unroll
        for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
            if (thread(iStep) < cn<tileM_ * tileK_>)
            {
                register Tp_ tmpB[8];

                *(int4*) &tmpB[0] = *(int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>));

#pragma unroll
                for (int i = 0; i < 8; i += 2)
                {
                    float2 tmp2 = std::is_same_v<Tp_, half> ?
                        __half22float2(*(half2*) &tmpB[i])
                        : bf1622float2(*(bf162*) &tmpB[i]);

                    int kStart = (iK - pipeS_) * cn<tileK_>;

                    tmp2.x *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn<tileM_>)])
                        * s_mxdc[kStart + get(thread(iStep) / cn<tileM_>)];
                    tmp2.y *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn<tileM_>)])
                        * s_mxdc[kStart + get(thread(iStep) / cn<tileM_>)];

                    if (std::is_same_v<Tp_, half>)
                        *(half2*) &tmpB[i] = __float22half2_rn(tmp2);
                    else
                        *(bf162*) &tmpB[i] = __float22bfloat162_rn(tmp2);
                }

                *(int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>)) = *(int4*) &tmpB[0];
            }

        __syncthreads();

        // Load B and X fragments with ldmatrix and accumulate St over the K (token) dimension.
#pragma unroll
        for (int k = 0; k < tileK_ / wmmaK_; k++)
        {
#pragma unroll
            for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
                for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
                {
                    if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0)
                    {
                        int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step;
                        int y1 = x1 - tileN_ / wmmaN_ / warpN_
                            + (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1);

                        if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_)
                        {
                            if (wmmaK_ == 16)
                                asm volatile(
                                    "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                                    : "=r"(r_mxB[y1][0]), "=r"(r_mxB[y1][1]), "=r"(r_mxB[y1][2]), "=r"(r_mxB[y1][3])
                                    : "r"(b_mxB + iK % pipeS_ * (tileM_ * tileK_ * 2)
                                        + 2
                                            * swz(y1 * warpM_ * wmmaM_ + k * wmmaK_ * tileM_ + threadIdx.z * wmmaM_
                                                + threadIdx.x % 8 * tileM_ + threadIdx.x / 8 % 2 * 8
                                                + threadIdx.x / wmmaK_ * 8 * tileM_)));
                        }

                        if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_)
                        {
                            if (wmmaK_ == 16 && x1 % 2 == 0)
                                asm volatile(
                                    "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                                    : "=r"(r_mxX[x1][0]), "=r"(r_mxX[x1][1]), "=r"(r_mxX[x1 + 1][0]),
                                    "=r"(r_mxX[x1 + 1][1])
                                    : "r"(b_mxX + iK % pipeS_ * (tileK_ * tileN_ * 2)
                                        + 2
                                            * swz(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_ + threadIdx.y * wmmaN_
                                                + threadIdx.x % wmmaK_ * tileN_
                                                + threadIdx.x / wmmaK_ * warpN_ * wmmaN_)));
                        }
                    }
                }

#pragma unroll
            for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
                for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
                {
                    if (wmmaK_ == 16)
                    {
                        if (std::is_same_v<Tp_, half>)
                            asm volatile(
                                "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
                                "    {%0, %1, %2, %3}, \n"
                                "    {%4, %5, %6, %7}, \n"
                                "    {%8, %9}, \n"
                                "    {%0, %1, %2, %3}; \n"
                                : "+f"(r_mxSt[y][x][0]), "+f"(r_mxSt[y][x][1]), "+f"(r_mxSt[y][x][2]),
                                "+f"(r_mxSt[y][x][3])
                                : "r"(r_mxB[y][0]), "r"(r_mxB[y][1]), "r"(r_mxB[y][2]), "r"(r_mxB[y][3]),
                                "r"(r_mxX[x][0]), "r"(r_mxX[x][1]));
                        else
                            asm volatile(
                                "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
                                "    {%0, %1, %2, %3}, \n"
                                "    {%4, %5, %6, %7}, \n"
                                "    {%8, %9}, \n"
                                "    {%0, %1, %2, %3}; \n"
                                : "+f"(r_mxSt[y][x][0]), "+f"(r_mxSt[y][x][1]), "+f"(r_mxSt[y][x][2]),
                                "+f"(r_mxSt[y][x][3])
                                : "r"(r_mxB[y][0]), "r"(r_mxB[y][1]), "r"(r_mxB[y][2]), "r"(r_mxB[y][3]),
                                "r"(r_mxX[x][0]), "r"(r_mxX[x][1]));
                    }
                }
        }

        __syncthreads();

        // Prefetch the next K-tile of B and X into the pipeline stage that was just consumed.
#pragma unroll
        for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
            if (thread(iStep) < cn<tileM_ * tileK_>
                && thread(iStep) / cn<tileM_> < L - blockIdx_y * Q - jK * cn<tileK_> && jK * cn<tileK_> < Q)
                cp_shared_global<16>(b_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>),
                    g_mxXBC_
                        + get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileM_>) *xbcDim + bOffset
                            + gStart * N + mStart * cn<tileM_> + thread(iStep) % cn<tileM_>));
            else if (thread(iStep) < cn<tileM_ * tileK_> && jK * cn<tileK_> < Q)
                *(int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>)) = int4{0, 0, 0, 0};

#pragma unroll
        for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
            if (thread(iStep) < cn<tileN_ * tileK_>
                && thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_> && jK * cn<tileK_> < Q)
                cp_shared_global<16>(b_mxX +
                    swizzle(thread(iStep) * cn<2>, baseX(jK) * cn<2>),
                    g_mxXBC_
                        + get((aStart + blockIdx_y * Q + jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *xbcDim
                            + hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
            else if (thread(iStep) < cn<tileN_ * tileK_> && jK * cn<tileK_> < Q)
                *(int4*) ((char*) s_mxX + swizzle(thread(iStep) * cn<2>, baseX(jK) * cn<2>)) = int4{0, 0, 0, 0};

        asm volatile("cp.async.commit_group;\n" ::);
        asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));

        __syncthreads();
    }

    // Write the accumulated chunk states (fp32) to g_mxSt_.
#pragma unroll
    for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
        for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
        {
            *(float2*) (g_mxSt_
                + get((cStart + blockIdx_y) * H * N * P + hStart * N * P
                    + (mStart * cn<tileM_> + Rn<ID>{y} * cn<warpM_ * wmmaM_> + threadIdx_z * cn<wmmaM_>
                        + threadIdx_x / cn<4>) *P
                    + nStart * cn<tileN_> + Rn<ID>{x} * cn<warpN_ * wmmaN_> + threadIdx_y * cn<wmmaN_>
                    + threadIdx_x % cn<4> * cn<2>))
                = *(float2*) &r_mxSt[y][x][0];
            *(float2*) (g_mxSt_
                + get((cStart + blockIdx_y) * H * N * P + hStart * N * P
                    + (mStart * cn<tileM_> + Rn<ID>{y} * cn<warpM_ * wmmaM_> + cn<8> + threadIdx_z * cn<wmmaM_>
                        + threadIdx_x / cn<4>) *P
                    + nStart * cn<tileN_> + Rn<ID>{x} * cn<warpN_ * wmmaN_> + threadIdx_y * cn<wmmaN_>
                    + threadIdx_x % cn<4> * cn<2>))
                = *(float2*) &r_mxSt[y][x][2];
        }

    asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
#endif
}

ChunkStateKernelFuncFp16 getChunkStateKernelFp16(
    int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    int H = H_;
    int P = P_;
    // int G = G_;
    int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileM = 64;
    int tileN = 64;
    int tileK = 32;
    int warpM = 1;
    int warpN = 2;
    int pipeS = 3;

    auto sharedMem = (tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8;

    *blockDims_ = dim3(H * P / tileN * N / tileM, C, B);
    *threadDims_ = dim3(32, warpN, warpM);
    *sharedMem_ = sharedMem;

    if (Q_ == 128)
        return chunk_state_kernel<128, 64, 64, 32, 16, 8, 16, 1, 2, 3, half>;
    else if (Q_ == 256)
        return chunk_state_kernel<256, 64, 64, 32, 16, 8, 16, 1, 2, 3, half>;
    else
        return nullptr;
}

ChunkStateKernelFuncBf16 getChunkStateKernelBf16(
    int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    int H = H_;
    int P = P_;
    // int G = G_;
    int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileM = 64;
    int tileN = 64;
    int tileK = 32;
    int warpM = 1;
    int warpN = 2;
    int pipeS = 3;

    auto sharedMem = (tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8;

    *blockDims_ = dim3(H * P / tileN * N / tileM, C, B);
    *threadDims_ = dim3(32, warpN, warpM);
    *sharedMem_ = sharedMem;

    if (Q_ == 128)
        return chunk_state_kernel<128, 64, 64, 32, 16, 8, 16, 1, 2, 3, bf16>;
    else if (Q_ == 256)
        return chunk_state_kernel<256, 64, 64, 32, 16, 8, 16, 1, 2, 3, bf16>;
    else
        return nullptr;
}

} // namespace kernels
} // namespace tensorrt_llm

// vim: ts=2 sw=2 sts=2 et sta
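// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original header). It shows one way
// a host caller could query the launch geometry from getChunkStateKernelFp16
// and launch the returned kernel. The helper name launchChunkStateFp16, the
// device pointers (d_mxSt, d_mxdc, d_mxdA, d_mxXBC, d_lastTokenIds) and the
// stream are hypothetical; only getChunkStateKernelFp16 and the kernel
// parameter list declared above are taken from this file. d_lastTokenIds is
// assumed to follow whatever convention the kernel expects for
// lastTokenIdsPtr_.
// ---------------------------------------------------------------------------
#include <cuda_runtime.h>

inline cudaError_t launchChunkStateFp16(int B, int L, int H, int P, int G, int N, int Q,
    float* d_mxSt,       // B*C*H*N*P, output chunk states
    float const* d_mxdc, // B*C*H*Q
    float const* d_mxdA, // B*C*H*Q
    half const* d_mxXBC, // B*L*(H*P+2*G*N)
    bool removePadding, int const* d_lastTokenIds, cudaStream_t stream)
{
    dim3 blockDims, threadDims;
    int sharedMem = 0;

    // Query grid/block dimensions and the dynamic shared-memory size.
    auto kernel = tensorrt_llm::kernels::getChunkStateKernelFp16(
        B, L, H, P, G, N, Q, &blockDims, &threadDims, &sharedMem);
    if (kernel == nullptr)
        return cudaErrorInvalidValue; // only Q == 128 or Q == 256 are handled above

    // Launch through the returned kernel pointer; arguments follow the
    // ChunkStateKernelFuncFp16 signature declared in this header.
    kernel<<<blockDims, threadDims, sharedMem, stream>>>(
        B, L, H, P, G, N, d_mxSt, d_mxdc, d_mxdA, d_mxXBC, removePadding, d_lastTokenIds);
    return cudaGetLastError();
}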