/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"

#include "Common.h"
#include "Poly.h"

namespace tensorrt_llm
{
namespace kernels
{

typedef void (*ChunkScanKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
    half* g_mxY_,             // B*L*H*P
    half const* g_mxOs_,      // B*C*H*N*P
    // const half *g_mxFs_,   // B  *H*N*P
    // const float *g_mxSt_,  // B*C*H*N*P
    float const* g_mxdc_,     // B*C*H*Q
    float const* g_mxdA_,     // B*C*H*Q
    // const half *g_mxdt_,   // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
    // const float *g_mxdb_,  // H
    // const float *g_mxA_,   // H
    half const* g_mxCB_,      // B*C*G*Q*Q
    float const* g_mxD_,      // H
    half const* g_mxXBC_,     // B*L*(H*P+2*G*N)
    half const* g_mxZ_,       // B*L*(2*H*P+2*G*N+H)
    bool removePadding_, int const* lastTokenIdsPtr_);

typedef void (*ChunkScanKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
    bf16* g_mxY_,             // B*L*H*P
    bf16 const* g_mxOs_,      // B*C*H*N*P
    // const bf16 *g_mxFs_,   // B  *H*N*P
    // const float *g_mxSt_,  // B*C*H*N*P
    float const* g_mxdc_,     // B*C*H*Q
    float const* g_mxdA_,     // B*C*H*Q
    // const bf16 *g_mxdt_,   // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
    // const float *g_mxdb_,  // H
    // const float *g_mxA_,   // H
    bf16 const* g_mxCB_,      // B*C*G*Q*Q
    float const* g_mxD_,      // H
    bf16 const* g_mxXBC_,     // B*L*(H*P+2*G*N)
    bf16 const* g_mxZ_,       // B*L*(2*H*P+2*G*N+H)
    bool removePadding_, int const* lastTokenIdsPtr_);
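// Shape key for the comments above and below (a reading aid; the letters follow
// the Mamba2/SSD chunked-scan convention used by these kernels): B = batch,
// L = sequence length, H = heads, P = head dim, G = B/C groups, N = state dim,
// Q = chunk length, C = div_up(L, Q) chunks. g_mxOs_ holds the per-chunk
// states, g_mxCB_ the intra-chunk C*B^T tiles, g_mxXBC_ the fused x/B/C
// activations, g_mxZ_ the gate activations, and g_mxdc_/g_mxdA_ the per-token
// dt coefficients and cumulative decay.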
template <int Q_, int tileM_, int tileN_, int tileK_, // tile shape staged in shared memory
    int wmmaM_, int wmmaN_, int wmmaK_,               // mma instruction shape
    int warpM_, int warpN_,                           // warp layout per block
    int pipeS_,                                       // cp.async pipeline depth
    class Tp_, class Wt_ = float>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, bf16>> chunk_scan_kernel(int B_, int L_,
    int H_, int P_, int G_, int N_,
    Tp_* g_mxY_,              // B*L*H*P
    Tp_ const* g_mxOs_,       // B*C*H*N*P
    // const Tp_ *g_mxFs_,    // B  *H*N*P
    // const float *g_mxSt_,  // B*C*H*N*P
    float const* g_mxdc_,     // B*C*H*Q
    float const* g_mxdA_,     // B*C*H*Q
    // const Tp_ *g_mxdt_,    // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
    // const Wt_ *g_mxdb_,    // H
    // const Wt_ *g_mxA_,     // H
    Tp_ const* g_mxCB_,       // B*C*G*Q*Q
    Wt_ const* g_mxD_,        // H
    Tp_ const* g_mxXBC_,      // B*L*(H*P+2*G*N)
    Tp_ const* g_mxZ_,        // B*L*(2*H*P+2*G*N+H)
    bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
    using namespace tensorrt_llm::common;

    auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
    auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
    auto blockIdx_z = Rn<ID>{int(blockIdx.z)};

    auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
    auto threadIdx_y = Rn<ID, warpN_>{int(threadIdx.y)};
    auto threadIdx_z = Rn<ID, warpM_>{int(threadIdx.z)};

    // auto B = Rn<ID>{B_};
    auto L = Rn<ID>{L_};
    auto H = Rn<ID>{H_};
    auto P = Rn<ID>{P_};
    auto G = Rn<ID>{G_};
    auto N = Rn<ID>{N_};
    auto Q = cn<Q_>;
    auto C = Rn<ID>{div_up(L.var, Q_)};

    auto xbcDim = Rn<ID>{H_ * P_ + 2 * G_ * N_};
    auto zdtDim = Rn<ID>{2 * H_ * P_ + 2 * G_ * N_ + H_};
    auto cOffset = Rn<ID>{H_ * P_ + G_ * N_};

    auto aStart = blockIdx_z * L;
    auto cStart = blockIdx_z * C;

    if (removePadding_)
    {
        aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
        cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
        L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
        C = Rn<ID>{div_up(L.var, Q_)};
    }
    else
    {
        L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
        C = Rn<ID>{div_up(L.var, Q_)};
    }

    if (blockIdx_y * Q >= L)
        return;

    auto hStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) / (Q / cn<tileM_>)};
    auto mStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) % (Q / cn<tileM_>)};
    auto nStart = Rn<ID>{blockIdx_x.var % (P_ / cn<tileN_>)};
    auto gStart = Rn<ID>{hStart.var / (H_ / G_)};

    extern __shared__ float smem[];

    Tp_* s_mxC = (Tp_*) smem;
    Tp_* s_mxOs = (Tp_*) smem + tileM_ * tileK_ * pipeS_;
    Tp_* s_mxY = (Tp_*) smem;

    float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2;
    float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_;

    unsigned b_base = __nvvm_get_smem_pointer(smem);

    unsigned b_mxC = b_base;
    unsigned b_mxOs = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_);
    unsigned b_mxY = b_base;

    using std::array;

    register array<array<array<float, 4>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxY
        = array<array<array<float, 4>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>();
    register array<array<unsigned, 4>, tileM_ / wmmaM_ / warpM_> r_mxC;
    register array<array<unsigned, 2>, tileN_ / wmmaN_ / warpN_> r_mxOs;

    constexpr int step = std::max(
        1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_));

    auto baseC = [](auto iK) { return iK % cn<pipeS_> * cn<tileM_> * cn<tileK_>; };
    auto baseOs = [](auto iK) { return iK % cn<pipeS_> * cn<tileK_> * cn<tileN_>; };
    auto thread = [=](auto iStep)
    { return iStep * cn<warpM_ * warpN_ * 256> + threadIdx_z * cn<warpN_ * 256> + threadIdx_y * cn<256> + threadIdx_x * cn<8>; };

#pragma unroll
    for (Rn<UNROLL, div_up(Q_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
        if (thread(iStep) < cn<Q_>)
        {
#pragma unroll
            for (int i = 0; i < 8; i += 4)
            {
                *(int4*) (s_mxdc + get(thread(iStep)) + i)
                    = *(int4*) (g_mxdc_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
                *(int4*) (s_mxdA + get(thread(iStep)) + i)
                    = *(int4*) (g_mxdA_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
            }
        }

#pragma unroll
    for (Rn<UNROLL, pipeS_> iK; iK.var < iK.size; iK.var++)
    {
#pragma unroll
        for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
            if (thread(iStep) < cn<tileM_ * tileK_>
                && thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
                cp_shared_global<16>(b_mxC + swizzle(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
                    g_mxXBC_
                        + get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) * xbcDim
                            + cOffset + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
            else if (thread(iStep) < cn<tileM_ * tileK_>)
                *(int4*) ((char*) s_mxC + swizzle(thread(iStep) * cn<2>, baseC(iK) * cn<2>)) = int4{0, 0, 0, 0};

#pragma unroll
        for (Rn<UNROLL, div_up(tileK_ * tileN_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
            if (thread(iStep) < cn<tileK_ * tileN_>)
                cp_shared_global<16>(b_mxOs + swizzle(thread(iStep) * cn<2>, baseOs(iK) * cn<2>),
                    g_mxOs_
                        + get((cStart + blockIdx_y) * H * N * P + hStart * N * P
                            + (iK * cn<tileK_> + thread(iStep) / cn<tileN_>) * P + nStart * cn<tileN_>
                            + thread(iStep) % cn<tileN_>));

        cp_commit_group();
    }

    asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));

    __syncthreads();
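    // Software pipeline: the main loop below consumes one K-slice per iteration
    // while cp.async prefetches the slice pipeS_ iterations ahead into one of
    // the pipeS_ staging buffers (selected round-robin via baseC/baseOs); the
    // wait_group(pipeS_ - 1) above guarantees the first slice has landed before
    // the first ldmatrix touches it.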
    for (int iK = pipeS_; iK < (N_ + Q_) / tileK_ + pipeS_; iK++)
    {
        auto jK = Rn<>{iK};

        // Crossing from the inter-chunk (state) phase into the intra-chunk
        // phase: scale the partial accumulators by exp(dA) once, at K == N_.
        if ((iK - pipeS_) * cn<tileK_> == N_)
        {
#pragma unroll
            for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
                for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
                {
                    float2 tmp2 = float2{expf(s_mxdA[get(mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
                                              + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)]),
                        expf(s_mxdA[get(mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
                            + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)])};

                    r_mxY[y][x][0] *= tmp2.x;
                    r_mxY[y][x][1] *= tmp2.x;
                    r_mxY[y][x][2] *= tmp2.y;
                    r_mxY[y][x][3] *= tmp2.y;
                }
        }

        // Intra-chunk phase: weight the CB tile by the decay factors and apply
        // the causal (lower-triangular) mask before it is fed to the mma.
        if ((iK - pipeS_) * cn<tileK_> >= N_)
        {
#pragma unroll
            for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size;
                 iStep.var++)
                if (thread(iStep) < cn<tileM_ * tileK_>)
                {
                    register Tp_ tmpCB[8];

                    *(int4*) &tmpCB[0] = *(int4*) ((char*) s_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>));

#pragma unroll
                    for (int i = 0; i < 8; i += 2)
                    {
                        float2 tmp2 = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmpCB[i])
                                                                : bf1622float2(*(bf162*) &tmpCB[i]);

                        int kStart = (iK - pipeS_) * cn<tileK_> - N_;

                        tmp2.x *= expf(s_mxdA[get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)]
                                      - s_mxdA[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i})])
                            * s_mxdc[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i})];
                        tmp2.y *= expf(s_mxdA[get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)]
                                      - s_mxdA[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1})])
                            * s_mxdc[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1})];

                        if (get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)
                            < kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i}))
                            tmp2.x = 0;
                        if (get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)
                            < kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1}))
                            tmp2.y = 0;

                        if (std::is_same_v<Tp_, half>)
                            *(half2*) &tmpCB[i] = __float22half2_rn(tmp2);
                        else
                            *(bf162*) &tmpCB[i] = __float22bfloat162_rn(tmp2);
                    }

                    *(int4*) ((char*) s_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>)) = *(int4*) &tmpCB[0];
                }

            __syncthreads();
        }

#pragma unroll
        for (int k = 0; k < tileK_ / wmmaK_; k++)
        {
#pragma unroll
            for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
                for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
                {
                    if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0)
                    {
                        int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step;
                        int y1 = x1 - tileN_ / wmmaN_ / warpN_
                            + (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1);

                        if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_)
                        {
                            if (wmmaK_ == 16)
                                asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                                    : "=r"(r_mxC[y1][0]), "=r"(r_mxC[y1][1]), "=r"(r_mxC[y1][2]), "=r"(r_mxC[y1][3])
                                    : "r"(b_mxC + iK % pipeS_ * (tileM_ * tileK_ * 2)
                                        + 2
                                            * swz(y1 * warpM_ * wmmaM_ * tileK_ + k * wmmaK_
                                                + threadIdx.z * wmmaM_ * tileK_ + threadIdx.x % 16 * tileK_
                                                + threadIdx.x / 16 * 8)));
                        }

                        if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_)
                        {
                            if (wmmaK_ == 16 && x1 % 2 == 0)
                                asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                                    : "=r"(r_mxOs[x1][0]), "=r"(r_mxOs[x1][1]), "=r"(r_mxOs[x1 + 1][0]),
                                    "=r"(r_mxOs[x1 + 1][1])
                                    : "r"(b_mxOs + iK % pipeS_ * (tileK_ * tileN_ * 2)
                                        + 2
                                            * swz(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_ + threadIdx.y * wmmaN_
                                                + threadIdx.x % wmmaK_ * tileN_
                                                + threadIdx.x / wmmaK_ * warpN_ * wmmaN_)));
                        }
                    }
                }

#pragma unroll
            for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
                for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
                {
                    if (wmmaK_ == 16)
                    {
                        if (std::is_same_v<Tp_, half>)
                            asm volatile(
                                "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
                                "    {%0, %1, %2, %3}, \n"
                                "    {%4, %5, %6, %7}, \n"
                                "    {%8, %9}, \n"
                                "    {%0, %1, %2, %3}; \n"
                                : "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]),
                                "+f"(r_mxY[y][x][3])
                                : "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
                                "r"(r_mxOs[x][0]), "r"(r_mxOs[x][1]));
                        else
                            asm volatile(
                                "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
                                "    {%0, %1, %2, %3}, \n"
                                "    {%4, %5, %6, %7}, \n"
                                "    {%8, %9}, \n"
                                "    {%0, %1, %2, %3}; \n"
                                : "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]),
                                "+f"(r_mxY[y][x][3])
                                : "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
                                "r"(r_mxOs[x][0]), "r"(r_mxOs[x][1]));
                    }
                }
        }

        __syncthreads();
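        // Prefetch the K-slice that is pipeS_ iterations ahead. While the K
        // loop still walks the N state dimension the operands are C (sliced
        // out of g_mxXBC_) and the chunk states g_mxOs_; once it crosses N_
        // they switch to the intra-chunk CB tile (g_mxCB_, decayed and
        // causally masked just before use, see above) and the raw x
        // activations from g_mxXBC_.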
        if (iK * cn<tileK_> < N_)
        {
#pragma unroll
            for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size;
                 iStep.var++)
                if (thread(iStep) < cn<tileM_ * tileK_>
                    && thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
                    cp_shared_global<16>(b_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
                        g_mxXBC_
                            + get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)
                                    * xbcDim
                                + cOffset + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
                else if (thread(iStep) < cn<tileM_ * tileK_>)
                    *(int4*) ((char*) s_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>)) = int4{0, 0, 0, 0};

#pragma unroll
            for (Rn<UNROLL, div_up(tileK_ * tileN_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size;
                 iStep.var++)
                if (thread(iStep) < cn<tileK_ * tileN_>)
                    cp_shared_global<16>(b_mxOs + swizzle(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
                        g_mxOs_
                            + get((cStart + blockIdx_y) * H * N * P + hStart * N * P
                                + (jK * cn<tileK_> + thread(iStep) / cn<tileN_>) * P + nStart * cn<tileN_>
                                + thread(iStep) % cn<tileN_>));
        }
        else if (iK * cn<tileK_> < N_ + Q_)
        {
#pragma unroll
            for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size;
                 iStep.var++)
                if (thread(iStep) < cn<tileM_ * tileK_>)
                    cp_shared_global<16>(b_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
                        g_mxCB_
                            + get((cStart + blockIdx_y) * G * Q * Q + gStart * Q * Q
                                + (mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) * Q + jK * cn<tileK_> - N
                                + thread(iStep) % cn<tileK_>));

#pragma unroll
            for (Rn<UNROLL, div_up(tileK_ * tileN_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size;
                 iStep.var++)
                if (thread(iStep) < cn<tileK_ * tileN_>
                    && thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_> + N)
                    cp_shared_global<16>(b_mxOs + swizzle(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
                        g_mxXBC_
                            + get((aStart + blockIdx_y * Q + jK * cn<tileK_> - N + thread(iStep) / cn<tileN_>)
                                    * xbcDim
                                + hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
                else if (thread(iStep) < cn<tileK_ * tileN_>)
                    *(int4*) ((char*) s_mxOs + swizzle(thread(iStep) * cn<2>, baseOs(jK) * cn<2>))
                        = int4{0, 0, 0, 0};
        }

        asm volatile("cp.async.commit_group;\n" ::);
        asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));

        __syncthreads();
    }

    // Optional skip connection: Y += D * x.
    if (g_mxD_)
    {
        float r_D = g_mxD_[hStart.var];

#pragma unroll
        for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
            for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
            {
                Tp_ tmp16[4] = {0};
                float tmp32[4] = {0};

                if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
                        + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
                    < L)
                {
                    *(int*) &tmp16[0] = *(int*) (g_mxXBC_
                        + get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
                                  + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)
                                * xbcDim
                            + hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
                            + threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));

                    *(float2*) &tmp32[0] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[0])
                                                                     : bf1622float2(*(bf162*) &tmp16[0]);

                    r_mxY[y][x][0] += r_D * tmp32[0];
                    r_mxY[y][x][1] += r_D * tmp32[1];
                }

                if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
                        + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
                    < L)
                {
                    *(int*) &tmp16[2] = *(int*) (g_mxXBC_
                        + get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
                                  + cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)
                                * xbcDim
                            + hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
                            + threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));

                    *(float2*) &tmp32[2] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[2])
                                                                     : bf1622float2(*(bf162*) &tmp16[2]);

                    r_mxY[y][x][2] += r_D * tmp32[2];
                    r_mxY[y][x][3] += r_D * tmp32[3];
                }
            }
    }
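    // Optional gating: multiply Y by SiLU(z) = z * sigmoid(z) = z / (1 + exp(-z)).
    // For z > 32, sigmoid(z) equals 1.0f to within float precision, so z is
    // used directly and the expf is skipped.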
    if (g_mxZ_)
    {
#pragma unroll
        for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
            for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
            {
                Tp_ tmp16[4] = {0};
                float tmp32[4] = {0};

                if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
                        + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
                    < L)
                {
                    *(int*) &tmp16[0] = *(int*) (g_mxZ_
                        + get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
                                  + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)
                                * zdtDim
                            + hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
                            + threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));

                    *(float2*) &tmp32[0] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[0])
                                                                     : bf1622float2(*(bf162*) &tmp16[0]);

                    r_mxY[y][x][0] *= tmp32[0] > 32.f ? tmp32[0] : tmp32[0] / (1.f + expf(-tmp32[0]));
                    r_mxY[y][x][1] *= tmp32[1] > 32.f ? tmp32[1] : tmp32[1] / (1.f + expf(-tmp32[1]));
                }

                if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
                        + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
                    < L)
                {
                    *(int*) &tmp16[2] = *(int*) (g_mxZ_
                        + get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
                                  + cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)
                                * zdtDim
                            + hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
                            + threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));

                    *(float2*) &tmp32[2] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[2])
                                                                     : bf1622float2(*(bf162*) &tmp16[2]);

                    r_mxY[y][x][2] *= tmp32[2] > 32.f ? tmp32[2] : tmp32[2] / (1.f + expf(-tmp32[2]));
                    r_mxY[y][x][3] *= tmp32[3] > 32.f ? tmp32[3] : tmp32[3] / (1.f + expf(-tmp32[3]));
                }
            }
    }

    // Pack the float accumulators to 16-bit and stage the Y tile in shared
    // memory so the global store below can be fully coalesced.
#pragma unroll
    for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
        for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
        {
            if (std::is_same_v<Tp_, half>)
            {
                *(half2*) &r_mxY[y][x][0] = __floats2half2_rn(r_mxY[y][x][0], r_mxY[y][x][1]);
                *(half2*) &r_mxY[y][x][2] = __floats2half2_rn(r_mxY[y][x][2], r_mxY[y][x][3]);
            }
            else
            {
                *(bf162*) &r_mxY[y][x][0] = __floats2bfloat162_rn(r_mxY[y][x][0], r_mxY[y][x][1]);
                *(bf162*) &r_mxY[y][x][2] = __floats2bfloat162_rn(r_mxY[y][x][2], r_mxY[y][x][3]);
            }
        }

#pragma unroll
    for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
        for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
        {
            asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY
                             + 2
                                 * swz(y * warpM_ * wmmaM_ * tileN_ + x * warpN_ * wmmaN_
                                     + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
                                     + (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
                "r"(*(unsigned*) &r_mxY[y][x][0]));
            asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY
                             + 2
                                 * swz(y * warpM_ * wmmaM_ * tileN_ + 8 * tileN_ + x * warpN_ * wmmaN_
                                     + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
                                     + (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
                "r"(*(unsigned*) &r_mxY[y][x][2]));
        }

    __syncthreads();

#pragma unroll
    for (Rn<UNROLL, div_up(tileM_ * tileN_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
        if (thread(iStep) < cn<tileM_ * tileN_>
            && thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
            *(int4*) (g_mxY_
                + get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileN_>) * H * P
                    + hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>))
                = *(int4*) ((char*) s_mxY + swizzle(thread(iStep) * cn<2>));

    asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
#endif
}

ChunkScanKernelFuncFp16 getChunkScanKernelFp16(
    int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    int H = H_;
    int P = P_;
    // int G = G_;
    // int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileM = 128;
    int tileN = 64;
    int tileK = 32;
    int warpM = 4;
    int warpN = 1;
    int pipeS = 2;

    auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2);

    *blockDims_ = dim3(H * P / tileN * Q / tileM, C, B);
    *threadDims_ = dim3(32, warpN, warpM);
    *sharedMem_ = sharedMem;

    if (Q_ == 128)
        return chunk_scan_kernel<128, 128, 64, 32, 16, 8, 16, 4, 1, 2, half>;
    else if (Q_ == 256)
        return chunk_scan_kernel<256, 128, 64, 32, 16, 8, 16, 4, 1, 2, half>;
    else
        return nullptr;
}
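// Host-side usage sketch (illustrative only: the buffer and launch variable
// names below are assumptions, not part of this header; only the getter and
// the standard CUDA calls are real):
//
//   dim3 blocks, threads;
//   int sharedMem = 0;
//   ChunkScanKernelFuncFp16 kernel
//       = getChunkScanKernelFp16(B, L, H, P, G, N, /*Q_=*/128, &blocks, &threads, &sharedMem);
//   // Opt in to a larger dynamic shared-memory carve-out if the request ever
//   // exceeds the 48 KB default.
//   cudaFuncSetAttribute((void const*) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, sharedMem);
//   kernel<<<blocks, threads, sharedMem, stream>>>(B, L, H, P, G, N, y, os, dc, dA, cb, D, xbc, z,
//       /*removePadding_=*/false, lastTokenIds);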
ChunkScanKernelFuncBf16 getChunkScanKernelBf16(
    int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    int H = H_;
    int P = P_;
    // int G = G_;
    // int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileM = 128;
    int tileN = 64;
    int tileK = 32;
    int warpM = 4;
    int warpN = 1;
    int pipeS = 2;

    auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2);

    *blockDims_ = dim3(H * P / tileN * Q / tileM, C, B);
    *threadDims_ = dim3(32, warpN, warpM);
    *sharedMem_ = sharedMem;

    if (Q_ == 128)
        return chunk_scan_kernel<128, 128, 64, 32, 16, 8, 16, 4, 1, 2, bf16>;
    else if (Q_ == 256)
        return chunk_scan_kernel<256, 128, 64, 32, 16, 8, 16, 4, 1, 2, bf16>;
    else
        return nullptr;
}

} // namespace kernels
} // namespace tensorrt_llm

// vim: ts=2 sw=2 sts=2 et sta