/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"

#include "Common.h"
#include "Poly.h"

namespace tensorrt_llm
{
namespace kernels
{

typedef void (*ChunkCumsumKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_,
    //  const half  *g_mxY_,  // B*L*H*P
    //  const half  *g_mxOs_, // B*C*H*N*P
    //  const half  *g_mxFs_, // B  *H*N*P
    //  const float *g_mxSt_, // B*C*H*N*P
    float* g_mxdc_,       // B*C*H*Q
    float* g_mxdA_,       // B*C*H*Q
    half const* g_mxdt_,  // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
    float const* g_mxdb_, // H
    float const* g_mxA_,  // H
    //  const half  *g_mxCB_,  // B*C*G*Q*Q
    //  const float *g_mxD_,   // H
    //  const half  *g_mxXBC_, // B*L*(H*P+2*G*N)
    half const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
    bool removePadding_, int const* lastTokenIdsPtr_);

typedef void (*ChunkCumsumKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_,
    //  const bf16  *g_mxY_,  // B*L*H*P
    //  const bf16  *g_mxOs_, // B*C*H*N*P
    //  const bf16  *g_mxFs_, // B  *H*N*P
    //  const float *g_mxSt_, // B*C*H*N*P
    float* g_mxdc_,       // B*C*H*Q
    float* g_mxdA_,       // B*C*H*Q
    bf16 const* g_mxdt_,  // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
    float const* g_mxdb_, // H
    float const* g_mxA_,  // H
    //  const bf16  *g_mxCB_,  // B*C*G*Q*Q
    //  const float *g_mxD_,   // H
    //  const bf16  *g_mxXBC_, // B*L*(H*P+2*G*N)
    bf16 const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
    bool removePadding_, int const* lastTokenIdsPtr_);
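// Usage sketch (illustrative only; host-side names such as `stream` and the `g_*` buffers are
// assumptions, not part of this header). chunk_cumsum_kernel writes, for every chunk of Q tokens,
// the per-token dt values (bias added, optionally softplus-activated) into g_mxdc_, and the running
// in-chunk cumulative sum of dt scaled by A into g_mxdA_, both with layout B*C*H*Q. The launch
// configuration comes from the getChunkCumsumKernel* helpers defined below:
//
//   dim3 blockDims, threadDims;
//   int sharedMem = 0;
//   ChunkCumsumKernelFuncFp16 kernel
//       = getChunkCumsumKernelFp16(B, L, H, /*Q_=*/128, /*dtSoftPlus_=*/true, &blockDims, &threadDims, &sharedMem);
//   kernel<<<blockDims, threadDims, sharedMem, stream>>>(B, L, H, P, G, N, g_mxdc, g_mxdA, g_mxdt, g_mxdb, g_mxA,
//       g_mxZ, removePadding, lastTokenIdsPtr);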
template <int Q_, int tileH_, int warpH_, bool dtSoftplus_, class Tp_, class Wt_ = float>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, bf16>> chunk_cumsum_kernel(
    int B_, int L_, int H_, int P_, int G_, int N_,
    //  const Tp_   *g_mxY_,  // B*L*H*P
    //  const Tp_   *g_mxOs_, // B*C*H*N*P
    //  const Tp_   *g_mxFs_, // B  *H*N*P
    //  const float *g_mxSt_, // B*C*H*N*P
    float* g_mxdc_,     // B*C*H*Q
    float* g_mxdA_,     // B*C*H*Q
    Tp_ const* g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
    Wt_ const* g_mxdb_, // H
    Wt_ const* g_mxA_,  // H
    //  const Tp_   *g_mxCB_,  // B*C*G*Q*Q
    //  const Wt_   *g_mxD_,   // H
    //  const Tp_   *g_mxXBC_, // B*L*(H*P+2*G*N)
    Tp_ const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
    bool removePadding_, int const* lastTokenIdsPtr_)
{
    using namespace tensorrt_llm::common;

    auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
    auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
    auto blockIdx_z = Rn<ID>{int(blockIdx.z)};

    auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
    auto threadIdx_y = Rn<ID, warpH_>{int(threadIdx.y)};

    //  auto B = Rn<ID>{B_};
    auto L = Rn<ID>{L_};
    auto H = Rn<ID>{H_};
    auto P = Rn<ID>{P_};
    auto G = Rn<ID>{G_};
    auto N = Rn<ID>{N_};
    auto Q = cn<Q_>;
    auto C = Rn<ID>{div_up(L.var, Q_)};

    auto dt_dim = g_mxZ_ ? Rn<ID>{2 * H_ * P_ + 2 * G_ * N_ + H_} : Rn<ID>{H_ * P_ + 2 * G_ * N_ + H_};

    auto aStart = blockIdx_z * L;
    auto cStart = blockIdx_z * C;

    if (removePadding_)
    {
        aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
        cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
        L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
        C = Rn<ID>{div_up(L.var, Q_)};
    }
    else
    {
        L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
        C = Rn<ID>{div_up(L.var, Q_)};
    }

    if (blockIdx_y * Q >= L)
        return;

    auto thread = [=](auto iStep) { return iStep * cn<warpH_ * 32> + threadIdx_y * cn<32> + threadIdx_x; };

#pragma unroll
    for (Rn<UNROLL, div_up(tileH_, warpH_ * 32)> iStep; iStep.var < iStep.size; iStep.var++)
    {
        float r_A = 0.f, r_db = 0.f, sum = 0.f;

        if (thread(iStep) < cn<tileH_>)
            r_A = g_mxA_[get(blockIdx_x * cn<tileH_> + thread(iStep))];
        if (thread(iStep) < cn<tileH_> && g_mxdb_)
            r_db = g_mxdb_[get(blockIdx_x * cn<tileH_> + thread(iStep))];

#pragma unroll
        for (Rn<UNROLL, Q_> iQ; iQ.var < iQ.size; iQ.var++)
        {
            float r_dt = 0.f;

            if (thread(iStep) < cn<tileH_> && blockIdx_y * Q + iQ < L)
            {
                r_dt = float(g_mxdt_[get((aStart + blockIdx_y * Q + iQ) * dt_dim + dt_dim - H
                           + blockIdx_x * cn<tileH_> + thread(iStep))])
                    + r_db;

                if (dtSoftplus_)
                    r_dt = r_dt > 32.f ? r_dt : log1p(expf(r_dt));

                sum += r_dt;
            }

            if (thread(iStep) < cn<tileH_>)
            {
                g_mxdc_[get((cStart + blockIdx_y) * H * Q + (blockIdx_x * cn<tileH_> + thread(iStep)) * Q + iQ)]
                    = r_dt;
                g_mxdA_[get((cStart + blockIdx_y) * H * Q + (blockIdx_x * cn<tileH_> + thread(iStep)) * Q + iQ)]
                    = sum * r_A;
            }
        }
    }
}

ChunkCumsumKernelFuncFp16 getChunkCumsumKernelFp16(
    int B_, int L_, int H_, int Q_, bool dtSoftPlus_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    int H = H_;
    //  int P = P_;
    //  int G = G_;
    //  int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileH = 1;
    int warpH = 1;

    auto sharedMem = 0;

    *blockDims_ = dim3(H / tileH, C, B);
    *threadDims_ = dim3(32, warpH);
    *sharedMem_ = sharedMem;

    if (dtSoftPlus_)
    {
        if (Q_ == 128)
            return chunk_cumsum_kernel<128, 1, 1, true, half>;
        else if (Q_ == 256)
            return chunk_cumsum_kernel<256, 1, 1, true, half>;
        else
            return nullptr;
    }
    else
    {
        if (Q_ == 128)
            return chunk_cumsum_kernel<128, 1, 1, false, half>;
        else if (Q_ == 256)
            return chunk_cumsum_kernel<256, 1, 1, false, half>;
        else
            return nullptr;
    }
}

ChunkCumsumKernelFuncBf16 getChunkCumsumKernelBf16(
    int B_, int L_, int H_, int Q_, bool dtSoftPlus_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    int B = B_;
    int L = L_;
    int H = H_;
    //  int P = P_;
    //  int G = G_;
    //  int N = N_;
    int Q = Q_;
    int C = div_up(L, Q);

    int tileH = 1;
    int warpH = 1;

    auto sharedMem = 0;

    *blockDims_ = dim3(H / tileH, C, B);
    *threadDims_ = dim3(32, warpH);
    *sharedMem_ = sharedMem;

    if (dtSoftPlus_)
    {
        if (Q_ == 128)
            return chunk_cumsum_kernel<128, 1, 1, true, bf16>;
        else if (Q_ == 256)
            return chunk_cumsum_kernel<256, 1, 1, true, bf16>;
        else
            return nullptr;
    }
    else
    {
        if (Q_ == 128)
            return chunk_cumsum_kernel<128, 1, 1, false, bf16>;
        else if (Q_ == 256)
            return chunk_cumsum_kernel<256, 1, 1, false, bf16>;
        else
            return nullptr;
    }
}

} // namespace kernels
} // namespace tensorrt_llm

// vim: ts=2 sw=2 sts=2 et sta