// ---------------------------------------------------------------------------
// NOTE(review): The text of this file appears to have been mangled by a lossy
// extraction step: angle-bracketed spans beginning with an identifier were
// stripped out (empty `#include` directives below, `template` headers with no
// parameter lists, `Rn{...}` / `cn` with missing template arguments, and a
// truncated `std::enable_if_t ||` condition), while purely numeric spans such
// as `cn<2>` or `cp_shared_global<16>` survived. The original line breaks
// were also collapsed, so `//` comments now swallow code that followed them
// on the same physical line. Every token below is preserved exactly as found;
// this file must be restored from its upstream project before it can compile.
// The comments added in this review only annotate the visible structure.
//
// Visible intent (from the surviving tokens): a chunk-state kernel for a
// chunked scan. It stages tiles of g_mxX_ into shared memory with cp.async
// (`cp_shared_global`), rescales the left operand in place by
// expf(dA_last - dA_k) * dc_k (see the `tmp2.x *= expf(...)` lines further
// down), accumulates with tensor-core `ldmatrix`/`mma`, and stores float
// accumulators to g_mxSt_. The body is guarded by __CUDA_ARCH__ >= 800
// (cp.async requires SM80+). TODO confirm all annotations against upstream.
// ---------------------------------------------------------------------------
/* * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" #include "tensorrt_llm/common/cudaDriverWrapper.h" #include "Common.h" #include "CudaType.h" #include "Poly.h" namespace tensorrt_llm { namespace kernels { typedef void (*ChunkStateKernelFunc)(int B_, int L_, int H_, int P_, int G_, int N_, // const void *g_mxY_, // Tp_ B*L*H*P // const void *g_mxOs_, // Tp_ B*C*H*N*P // const void *g_mxFs_, // Tp_ B *H*N*P void* g_mxSt_, // float B*C*H*N*P void const* g_mxdc_, // float B*C*H*Q void const* g_mxdA_, // float B*C*H*Q // const void *g_mxdt_, // Tp_ B*L*((g_mxZ?2:1)*H*P+2*G+round_up(H,8)) // const void *g_mxdb_, // Wt_ H // const void *g_mxA_, // Wt_ H // const void *g_mxCB_, // Tp_ B*C*G*Q*Q // const void *g_mxD_, // Wt_ H void const* g_mxX_, // Tp_ B*L*(H*P+2*G*N) // const void *g_mxZ_, // g_mxdt_ or nullptr bool removePadding_, int const* lastTokenIdsPtr_); template __global__ std::enable_if_t || std::is_same_v> chunk_state_kernel(int B_, int L_, int H_, int P_, int G_, int N_, // const void *g_mxY_, // Tp_ B*L*H*P // const void *g_mxOs_, // Tp_ B*C*H*N*P // const void *g_mxFs_, // Tp_ B *H*N*P void* g_mxSt_, // float B*C*H*N*P void const* g_mxdc_, // float B*C*H*Q void const* g_mxdA_, // float B*C*H*Q // const void *g_mxdt_, // Tp_ B*L*((g_mxZ?2:1)*H*P+2*G+round_up(H,8)) // const void *g_mxdb_, // Wt_ H // const void 
// (cont.) Remainder of the mangled chunk_state_kernel signature. The live
// arguments are g_mxSt_ (float output state), g_mxdc_/g_mxdA_ (per-chunk
// cumsum inputs), g_mxX_ (input activations), and removePadding_ /
// lastTokenIdsPtr_ (variable-length batch support); the `// const void *...`
// entries are commented-out placeholders carried over from the upstream
// signature. The template parameter list (tile sizes etc.) was stripped.
*g_mxA_, // Wt_ H // const void *g_mxCB_, // Tp_ B*C*G*Q*Q // const void *g_mxD_, // Wt_ H void const* g_mxX_, // Tp_ B*L*(H*P+2*G*N) // const void *g_mxZ_, // g_mxdt_ or nullptr bool removePadding_, int const* lastTokenIdsPtr_) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 using namespace tensorrt_llm::common; constexpr int wmmaM_ = 16; // wmma size, per instruction constexpr int wmmaN_ = 8; constexpr int wmmaK_ = 16; auto blockIdx_x = Rn{int(blockIdx.x)}; auto blockIdx_y = Rn{int(blockIdx.y)}; auto blockIdx_z = Rn{int(blockIdx.z)}; auto threadIdx_x = Rn{int(threadIdx.x)}; auto threadIdx_y = Rn{int(threadIdx.y)}; auto threadIdx_z = Rn{int(threadIdx.z)}; // auto B = Rn{B_}; auto L = Rn{L_}; auto H = Rn{H_}; auto P = Rn{P_}; auto G = Rn{G_}; auto N = Rn{N_}; auto Q = cn; auto C = Rn{div_up(L.var, Q_)}; auto X_stride = Rn{H_ * P_ + 2 * G_ * N_}; auto aStart = blockIdx_z * L; auto cStart = blockIdx_z * C; if (removePadding_) { aStart = Rn{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)}; cStart = Rn{int(blockIdx.z ? 
// (cont.) Per-CTA index setup via the Poly.h `Rn`/`cn` index DSL (template
// arguments stripped here). The removePadding_ branch recomputes this
// sequence's token start and length from lastTokenIdsPtr_; the early return
// below skips CTAs whose chunk lies beyond the (possibly shortened) length L.
// hStart/mStart/nStart/gStart decompose blockIdx.x into head / M-tile /
// N-tile / group coordinates, and the g_mx* pointers are then offset to this
// chunk with int64_t arithmetic (guards against 32-bit offset overflow).
div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)}; L = Rn{lastTokenIdsPtr_[blockIdx.z] - aStart.var}; C = Rn{div_up(L.var, Q_)}; } else { L = Rn{lastTokenIdsPtr_[blockIdx.z]}; C = Rn{div_up(L.var, Q_)}; } if (blockIdx_y * Q >= L) return; int numBlocks_m = H_ * (N_ / tileM_); int numBlocks_n = div_up(P_, tileN_); int numBlocks_group = numBlocks_n * group_; int blockIdx_0 = blockIdx.x / numBlocks_group * group_; int group_m = std::min(group_, numBlocks_m - blockIdx_0); int mStartInt = blockIdx.x % numBlocks_group % group_m + blockIdx_0; int nStartInt = blockIdx.x % numBlocks_group / group_m; auto hStart = Rn{mStartInt / (N_ / tileM_)}; auto mStart = Rn{mStartInt % (N_ / tileM_)}; auto nStart = Rn{nStartInt}; auto gStart = Rn{hStart.var / (H_ / G_)}; // const Tp_ *g_mxY = (const Tp_ *)g_mxY_; // const Tp_ *g_mxOs = (const Tp_ *)g_mxOs_; // const Tp_ *g_mxFs = (const Tp_ *)g_mxFs_; float* g_mxSt = (float*) g_mxSt_ + int64_t(get(cStart + blockIdx_y)) * get(H * N * P); float const* g_mxdc = (float const*) g_mxdc_ + int64_t(get(cStart + blockIdx_y)) * get(H * Q); float const* g_mxdA = (float const*) g_mxdA_ + int64_t(get(cStart + blockIdx_y)) * get(H * Q); // const Tp_ *g_mxdt = (const Tp_ *)g_mxdt_; // const Wt_ *g_mxdb = (const Wt_ *)g_mxdb_; // const Wt_ *g_mxA = (const Wt_ *)g_mxA_; // const Tp_ *g_mxCB = (const Tp_ *)g_mxCB_; // const Wt_ *g_mxD = (const Wt_ *)g_mxD_; Tp_ const* g_mxX = (Tp_ const*) g_mxX_ + int64_t(get(aStart + blockIdx_y * Q)) * get(X_stride); // const Tp_ *g_mxZ = (const Tp_ *)g_mxZ_; extern __shared__ float smem[]; Tp_* s_mxL = (Tp_*) smem; Tp_* s_mxR = (Tp_*) smem + tileM_ * tileK_ * pipeS_; float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2; float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_; unsigned b_base = __nvvm_get_smem_pointer(smem); unsigned b_mxL = b_base; unsigned b_mxR = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_); using std::array; array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxAcc = 
// (cont.) Accumulator and operand fragment register arrays (the std::array
// element types were lost with the stripped template arguments), the
// swizzled shared-memory base-offset lambdas baseL/baseR, and the flattened
// per-thread byte-offset helper `thread(iStep)`.
array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>(); array, tileM_ / wmmaM_ / warpM_>, pipeR_> r_mxL; array, tileN_ / wmmaN_ / warpN_>, pipeR_> r_mxR; if (pipeR_ == 2) { r_mxL = array, tileM_ / wmmaM_ / warpM_>, pipeR_>(); r_mxR = array, tileN_ / wmmaN_ / warpN_>, pipeR_>(); } constexpr int step = std::max( 1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_)); auto baseL = [](auto iK) { return iK % cn * cn * cn; }; auto baseR = [](auto iK) { return iK % cn * cn * cn; }; auto thread = [=](auto iStep) { return iStep * cn + threadIdx_z * cn + threadIdx_y * cn<256> + threadIdx_x * cn<8>; }; #pragma unroll for (Rn iStep; iStep.var < iStep.size; iStep.var++) if (thread(iStep) < cn) { #pragma unroll for (int i = 0; i < 8; i += 4) { // dc & dA should be set to zero out of seq length, but this is done to // dc by chunkcumsum, so no need here. *(int4*) (s_mxdc + get(thread(iStep)) + i) = *(int4*) (g_mxdc + get(hStart * Q + thread(iStep)) + i); *(int4*) (s_mxdA + get(thread(iStep)) + i) = *(int4*) (g_mxdA + get(hStart * Q + thread(iStep)) + i); } } #pragma unroll for (Rn iK; iK.var < iK.size; iK.var++) { #pragma unroll for (Rn iStep; iStep.var < iStep.size; iStep.var++) if (thread(iStep) < cn && thread(iStep) / cn < L - blockIdx_y * Q - iK * cn) cp_shared_global<16>(b_mxL + swizzle(thread(iStep) * cn<2>, baseL(iK) * cn<2>), g_mxX + get((iK * cn + thread(iStep) / cn) *X_stride + H * P + gStart * N + mStart * cn + thread(iStep) % cn)); else if (thread(iStep) < cn) *(int4*) ((char*) s_mxL + swizzle(thread(iStep) * cn<2>, baseL(iK) * cn<2>)) = int4{0, 0, 0, 0}; #pragma unroll for (Rn iStep; iStep.var < iStep.size; iStep.var++) if (thread(iStep) < cn && thread(iStep) / cn < L - blockIdx_y * Q - iK * cn && (!bnChk_ || nStart * cn + thread(iStep) % cn < P)) cp_shared_global<16>(b_mxR + swizzle(thread(iStep) * cn<2>, baseR(iK) * cn<2>), g_mxX + get((iK * cn + thread(iStep) / cn) *X_stride + hStart * P + nStart * cn 
// (cont.) Prologue above: dc/dA rows for this head are staged into shared
// memory, then the first pipeS_ rounds of 16-byte cp.async tile loads are
// issued for both operands, zero-filling int4{0,0,0,0} for rows beyond the
// sequence tail (the bnChk_ flag additionally bounds-checks the N dimension).
+ thread(iStep) % cn)); else if (thread(iStep) < cn) *(int4*) ((char*) s_mxR + swizzle(thread(iStep) * cn<2>, baseR(iK) * cn<2>)) = int4{0, 0, 0, 0}; cp_commit_group(); } cp_wait_group(); __syncthreads(); for (int iK = pipeS_; iK < Q_ / tileK_ + pipeS_; iK++) { auto jK = Rn<>{iK}; #pragma unroll for (Rn iStep; iStep.var < iStep.size; iStep.var++) if (thread(iStep) < cn) { Tp_ tmpL[8]; *(int4*) &tmpL[0] = *( int4*) ((char*) s_mxL + swizzle(thread(iStep) * cn<2>, baseL(jK) * cn<2>)); #pragma unroll for (int i = 0; i < 8; i += 2) { float2 tmp2 = std::is_same_v ? __half22float2(*(half2*) &tmpL[i]) : bf1622float2(*(bf162*) &tmpL[i]); int kStart = (iK - pipeS_) * cn; // dc & dA should be set to zero out of seq length, but this is done to // dc by chunkcumsum, so no need here. tmp2.x *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn)]) * s_mxdc[kStart + get(thread(iStep) / cn)]; tmp2.y *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn)]) * s_mxdc[kStart + get(thread(iStep) / cn)]; if (std::is_same_v) *(half2*) &tmpL[i] = __float22half2_rn(tmp2); else *(bf162*) &tmpL[i] = __float22bfloat162_rn(tmp2); } *(int4*) ((char*) s_mxL + swizzle(thread(iStep) * cn<2>, baseL(jK) * cn<2>)) = *(int4*) &tmpL[0]; } __syncthreads(); #pragma unroll for (int k = 0; k < tileK_ / wmmaK_; k++) { #pragma unroll for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) #pragma unroll for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) { if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0) { int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step; int y1 = x1 - tileN_ / wmmaN_ / warpN_ + (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1); if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_) { if (wmmaK_ == 16) ldmatrix_x4(r_mxL[k % pipeR_][y1][0], r_mxL[k % pipeR_][y1][1], r_mxL[k % pipeR_][y1][2], r_mxL[k % pipeR_][y1][3], b_mxL + iK % pipeS_ * (tileM_ * tileK_ * 2) + 2 * swz(y1 * warpM_ * wmmaM_ + k * wmmaK_ * tileM_ + threadIdx.z * wmmaM_ + threadIdx.x % 8 * tileM_ + 
// (cont.) Main software-pipelined loop above/below: the staged left tile is
// rescaled in place by expf(s_mxdA[Q_-1] - s_mxdA[k]) * s_mxdc[k]
// (half/bf16 pairs round-tripped through float2), then ldmatrix_x4 fragment
// loads interleave with mma accumulation, followed by the cp.async refill of
// the next K tile for both operands.
threadIdx.x / 8 % 2 * 8 + threadIdx.x / wmmaK_ * 8 * tileM_)); } if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_) { if (wmmaK_ == 16 && x1 % 2 == 0) ldmatrix_x4(r_mxR[k % pipeR_][x1][0], r_mxR[k % pipeR_][x1][1], r_mxR[k % pipeR_][x1 + 1][0], r_mxR[k % pipeR_][x1 + 1][1], b_mxR + iK % pipeS_ * (tileK_ * tileN_ * 2) + 2 * swz(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_ + threadIdx.y * wmmaN_ + threadIdx.x % wmmaK_ * tileN_ + threadIdx.x / wmmaK_ * warpN_ * wmmaN_)); } } if (pipeR_ == 2) { if (wmmaK_ == 16) mma(r_mxAcc[y][x], r_mxL[(k + 1) % pipeR_][y], r_mxR[(k + 1) % pipeR_][x]); } } #pragma unroll for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) #pragma unroll for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) { if (pipeR_ == 1) { if (wmmaK_ == 16) mma(r_mxAcc[y][x], r_mxL[(k + 1) % pipeR_][y], r_mxR[(k + 1) % pipeR_][x]); } } } __syncthreads(); #pragma unroll for (Rn iStep; iStep.var < iStep.size; iStep.var++) if (thread(iStep) < cn && thread(iStep) / cn < L - blockIdx_y * Q - jK * cn && jK * cn < Q) cp_shared_global<16>(b_mxL + swizzle(thread(iStep) * cn<2>, baseL(jK) * cn<2>), g_mxX + get((jK * cn + thread(iStep) / cn) *X_stride + H * P + gStart * N + mStart * cn + thread(iStep) % cn)); else if (thread(iStep) < cn && jK * cn < Q) *(int4*) ((char*) s_mxL + swizzle(thread(iStep) * cn<2>, baseL(jK) * cn<2>)) = int4{0, 0, 0, 0}; #pragma unroll for (Rn iStep; iStep.var < iStep.size; iStep.var++) if (thread(iStep) < cn && thread(iStep) / cn < L - blockIdx_y * Q - jK * cn && (!bnChk_ || nStart * cn + thread(iStep) % cn < P) && jK * cn < Q) cp_shared_global<16>(b_mxR + swizzle(thread(iStep) * cn<2>, baseR(jK) * cn<2>), g_mxX + get((jK * cn + thread(iStep) / cn) *X_stride + hStart * P + nStart * cn + thread(iStep) % cn)); else if (thread(iStep) < cn && jK * cn < Q) *(int4*) ((char*) s_mxR + swizzle(thread(iStep) * cn<2>, baseR(jK) * cn<2>)) = int4{0, 0, 0, 0}; cp_commit_group(); cp_wait_group(); __syncthreads(); } cp_wait_group<0>(); #pragma unroll for (int y = 0; y < 
// ---------------------------------------------------------------------------
// NOTE(review): This file's text has been mangled by a lossy extraction step
// (identifier-leading `<...>` spans stripped, original newlines collapsed so
// `//` comments swallow trailing code). Every token below is preserved
// exactly as found; restore this file from its upstream project before
// compiling. Review comments only annotate the visible structure.
//
// This span holds (a) the epilogue of the preceding SM80 kernel — guarded
// float2 stores of the float accumulators to g_mxSt — and (b)
// chunk_state_hopper, an SM90-only variant (guarded by __CUDA_ARCH__ >= 900
// && __CUDA_ARCH_FEAT_SM90_ALL) that replaces per-thread cp.async staging
// with TMA bulk-tensor copies (cp.async.bulk.tensor.3d driven by CUtensorMap
// descriptors), mbarrier synchronization, and warpgroup `wgmma` MMA.
// TODO confirm all annotations against upstream.
// ---------------------------------------------------------------------------
tileM_ / wmmaM_ / warpM_; y++) #pragma unroll for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) { if (pipeR_ == 2) { if (wmmaK_ == 16) mma(r_mxAcc[y][x], r_mxL[1][y], r_mxR[1][x]); } if (!bnChk_ || nStart * cn + Rn{x} * cn + threadIdx_y * cn + threadIdx_x % cn<4> * cn<2> < P) { *(float2*) (g_mxSt + get(hStart * N * P + (mStart * cn + Rn{y} * cn + threadIdx_z * cn + threadIdx_x / cn<4>) *P + nStart * cn + Rn{x} * cn + threadIdx_y * cn + threadIdx_x % cn<4> * cn<2>)) = *(float2*) &r_mxAcc[y][x][0]; *(float2*) (g_mxSt + get(hStart * N * P + (mStart * cn + Rn{y} * cn + cn<8> + threadIdx_z * cn + threadIdx_x / cn<4>) *P + nStart * cn + Rn{x} * cn + threadIdx_y * cn + threadIdx_x % cn<4> * cn<2>)) = *(float2*) &r_mxAcc[y][x][2]; } } #endif } template __global__ std::enable_if_t || std::is_same_v> chunk_state_hopper(int B_, int L_, int H_, int P_, int G_, int N_, // const void *g_mxY_, // Tp_ B*L*H*P // const void *g_mxOs_, // Tp_ B*C*H*N*P // const void *g_mxFs_, // Tp_ B *H*N*P void* g_mxSt_, // float B*C*H*N*P void const* g_mxdc_, // float B*C*H*Q void const* g_mxdA_, // float B*C*H*Q // const void *g_mxdt_, // Tp_ B*L*((g_mxZ?2:1)*H*P+2*G+round_up(H,8)) // const void *g_mxdb_, // Wt_ H // const void *g_mxA_, // Wt_ H // const void *g_mxCB_, // Tp_ B*C*G*Q*Q // const void *g_mxD_, // Wt_ H void const* g_mxX_, // Tp_ B*L*(H*P+2*G*N) // const void *g_mxZ_, // g_mxdt_ or nullptr bool removePadding_, int const* lastTokenIdsPtr_) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) using namespace tensorrt_llm::common; constexpr int wmmaM_ = 16; // wmma size, per instruction constexpr int wmmaN_ = 8; constexpr int wmmaK_ = 16; auto blockIdx_x = Rn{int(blockIdx.x)}; auto blockIdx_y = Rn{int(blockIdx.y)}; auto blockIdx_z = Rn{int(blockIdx.z)}; auto threadIdx_x = Rn{int(threadIdx.x)}; auto threadIdx_y = Rn{int(threadIdx.y)}; auto threadIdx_z = Rn{int(threadIdx.z)}; // auto B = Rn{B_}; auto L = Rn{L_}; auto H = Rn{H_}; auto P = Rn{P_}; 
// (cont.) Hopper-variant index setup — same Poly.h Rn/cn DSL (template
// arguments stripped) and the same removePadding_ / lastTokenIdsPtr_
// variable-length handling as the SM80 path above.
auto G = Rn{G_}; auto N = Rn{N_}; auto Q = cn; auto C = Rn{div_up(L.var, Q_)}; auto X_stride = Rn{H_ * P_ + 2 * G_ * N_}; auto aStart = blockIdx_z * L; auto cStart = blockIdx_z * C; if (removePadding_) { aStart = Rn{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)}; cStart = Rn{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)}; L = Rn{lastTokenIdsPtr_[blockIdx.z] - aStart.var}; C = Rn{div_up(L.var, Q_)}; } else { L = Rn{lastTokenIdsPtr_[blockIdx.z]}; C = Rn{div_up(L.var, Q_)}; } if (blockIdx_y * Q >= L) return; int numBlocks_m = H_ * (N_ / tileM_); int numBlocks_n = div_up(P_, tileN_); int numBlocks_group = numBlocks_n * group_; int blockIdx_0 = blockIdx.x / numBlocks_group * group_; int group_m = std::min(group_, numBlocks_m - blockIdx_0); int mStartInt = blockIdx.x % numBlocks_group % group_m + blockIdx_0; int nStartInt = blockIdx.x % numBlocks_group / group_m; auto hStart = Rn{mStartInt / (N_ / tileM_)}; auto mStart = Rn{mStartInt % (N_ / tileM_)}; auto nStart = Rn{nStartInt}; auto gStart = Rn{hStart.var / (H_ / G_)}; // const Tp_ *g_mxY = (const Tp_ *)g_mxY_; // const Tp_ *g_mxOs = (const Tp_ *)g_mxOs_; // const Tp_ *g_mxFs = (const Tp_ *)g_mxFs_; float* g_mxSt = (float*) g_mxSt_ + int64_t(get(cStart + blockIdx_y)) * get(H * N * P); float const* g_mxdc = (float const*) g_mxdc_ + int64_t(get(cStart + blockIdx_y)) * get(H * Q); float const* g_mxdA = (float const*) g_mxdA_ + int64_t(get(cStart + blockIdx_y)) * get(H * Q); // const Tp_ *g_mxdt = (const Tp_ *)g_mxdt_; // const Wt_ *g_mxdb = (const Wt_ *)g_mxdb_; // const Wt_ *g_mxA = (const Wt_ *)g_mxA_; // const Tp_ *g_mxCB = (const Tp_ *)g_mxCB_; // const Wt_ *g_mxD = (const Wt_ *)g_mxD_; // const Tp_ *g_mxX = (const Tp_ *)g_mxX_; // const Tp_ *g_mxZ = (const Tp_ *)g_mxZ_; extern __shared__ float smem[]; Tp_* s_mxL = (Tp_*) smem + 512; Tp_* s_mxR = (Tp_*) smem + 512 + tileM_ * tileK_ * pipeS_; float* s_mxdc = smem + 256 + (tileM_ + tileN_) * tileK_ * pipeS_ / 2; float* s_mxdA = smem + 256 + 
// (cont.) Hopper shared-memory layout: the operand tiles start 1024 bytes in
// (s_mxL/s_mxR at smem + 512 Tp_ elements, b_mxL at b_base + 1024); the
// leading bytes (b_mbar = b_base) hold the mbarrier array. The uint64_t
// d_mxL/d_mxR values are wgmma shared-memory matrix descriptors with
// format / base-offset / stride / leading-dimension / start-address fields
// bit-packed at shifts 62/49/32/16/0.
(tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_; unsigned b_base = __nvvm_get_smem_pointer(smem); unsigned b_mxL = b_base + 1024; unsigned b_mxR = b_base + 1024 + tileM_ * tileK_ * pipeS_ * sizeof(Tp_); unsigned b_mbar = b_base; uint32_t formatL = 1; uint32_t formatR = tileN_ >= 64 ? 1 : (tileN_ == 32 ? 2 : 3); uint32_t baseOffsetL = 0; uint32_t baseOffsetR = 0; uint32_t strideDimL = 64; uint32_t strideDimR = tileN_ >= 64 ? 64 : tileN_; uint32_t leadingDimL = tileK_ * 8 * warpM_ / 4; uint32_t leadingDimR = tileN_ >= 64 ? tileK_ * 8 : 0; uint32_t startAddrL = (b_mxL + threadIdx.y / 4 * tileK_ * 128) / 16; uint32_t startAddrR = (b_mxR + threadIdx.z * tileN_ / warpN_ % strideDimR * 2 + threadIdx.z * tileN_ / warpN_ / strideDimR * tileK_ * strideDimR * 2) / 16; uint64_t d_mxL = ((uint64_t) formatL << 62) + ((uint64_t) baseOffsetL << 49) + ((uint64_t) strideDimL << 32) + ((uint64_t) leadingDimL << 16) + ((uint64_t) startAddrL); uint64_t d_mxR = ((uint64_t) formatR << 62) + ((uint64_t) baseOffsetR << 49) + ((uint64_t) strideDimR << 32) + ((uint64_t) leadingDimR << 16) + ((uint64_t) startAddrR); using std::array; array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxAcc = array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>(); auto baseL = [](auto iK) { return iK % cn * cn * cn; }; auto baseR = [](auto iK) { return iK % cn * cn * cn; }; auto thread = [=](auto iStep) { return iStep * cn + threadIdx_y * cn + threadIdx_z * cn<256> + threadIdx_x * cn<8>; }; asm volatile( ".reg .pred elected_one;\n" ".reg .pred done;\n" "elect.sync _|elected_one, 0xFFFFFFFF;\n" "@!elected_one setp.eq.b32 done, 0, 0;\n" ::); if (threadIdx.y == 0 && threadIdx.z == 0) { #pragma unroll for (Rn iK; iK.var < iK.size; iK.var++) asm volatile("@elected_one mbarrier.init.shared.b64 [%0], %1;\n" ::"r"(b_mbar + iK.var * 8), "r"(1)); } __syncthreads(); #pragma unroll for (Rn iStep; iStep.var < iStep.size; iStep.var++) if (thread(iStep) < cn) { #pragma unroll for (int i = 0; i < 8; i += 4) { 
// (cont.) One elected lane per warp (elect.sync) of the first warp initializes
// the mbarriers; dc/dA rows are then staged into shared memory as in the
// SM80 path. Prologue TMA loads below: the elected lane posts expect_tx on
// each mbarrier and issues cp.async.bulk.tensor.3d copies for both operands
// (the second CUtensorMap, `g_mxX_ + 1`, addresses the right-operand view).
// dc & dA should be set to zero out of seq length, but this is done to // dc by chunkcumsum, so no need here. *(int4*) (s_mxdc + get(thread(iStep)) + i) = *(int4*) (g_mxdc + get(hStart * Q + thread(iStep)) + i); *(int4*) (s_mxdA + get(thread(iStep)) + i) = *(int4*) (g_mxdA + get(hStart * Q + thread(iStep)) + i); } } #pragma unroll for (Rn iK; iK.var < iK.size; iK.var++) { if (threadIdx.y == 0 && threadIdx.z == 0) { asm volatile("@elected_one mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n" ::"r"(b_mbar + iK.var * 8), "r"((tileM_ + tileN_) * tileK_ * 2)); #pragma unroll for (int iM = 0; iM < tileM_; iM += 64) asm volatile( "@elected_one cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes" " [%0], [%1, {%2, %3, %4}], [%5];\n" ::"r"( b_mxL + iK.var % pipeS_ * (tileM_ * tileK_ * 2) + iM * tileK_ * 2), "l"((CUtensorMap const*) g_mxX_), "r"(get(mStart * cn) + iM), "r"(gStart.var), "r"(get(aStart + blockIdx_y * Q + iK * cn)), "r"(b_mbar + iK.var * 8)); #pragma unroll for (int iN = 0; iN < tileN_; iN += 64) asm volatile( "@elected_one cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes" " [%0], [%1, {%2, %3, %4}], [%5];\n" ::"r"( b_mxR + iK.var % pipeS_ * (tileN_ * tileK_ * 2) + iN * tileK_ * 2), "l"((CUtensorMap const*) g_mxX_ + 1), "r"(get(nStart * cn) + iN), "r"(hStart.var), "r"(get(aStart + blockIdx_y * Q + iK * cn)), "r"(b_mbar + iK.var * 8)); } } __syncthreads(); for (int iK = pipeS_; iK < Q_ / tileK_ + pipeS_; iK++) { uint64_t d_mxL0 = d_mxL + iK % pipeS_ * (tileM_ * tileK_ * 2) / 16; uint64_t d_mxR0 = d_mxR + iK % pipeS_ * (tileN_ * tileK_ * 2) / 16; asm volatile( "{\n" "TRY_AGAIN:\n" "mbarrier.try_wait.parity.shared.b64 done, [%0], %1;\n" "@!done nanosleep.u32 %2;\n" "@!done bra TRY_AGAIN;\n" "}\n" ::"r"(b_mbar + iK % pipeS_ * 8), "r"(1 - iK / pipeS_ % 2), "r"(threadIdx.x % 6)); auto jK = Rn<>{iK}; #pragma unroll for (Rn iStep; iStep.var < iStep.size; iStep.var++) if (thread(iStep) < cn) { Tp_ tmpL[8]; *(int4*) 
// (cont.) Main loop: after the mbarrier parity wait (with nanosleep backoff),
// the staged left tile is rescaled in place by
// expf(s_mxdA[Q_-1] - s_mxdA[k]) * s_mxdc[k] through the 128-byte swizzle,
// then fence.proxy.async / wgmma.fence precede the warpgroup wgmma
// accumulation over the K sub-tiles.
&tmpL[0] = *(int4*) ((char*) s_mxL + swizzle<128, 128>( thread(iStep) * cn<2> / cn * cn<128> + thread(iStep) * cn<2> % cn<128>, baseL(jK) * cn<2> + thread(iStep) * cn<2> % cn / cn<128> * cn<128 * tileK_>)); #pragma unroll for (int i = 0; i < 8; i += 2) { float2 tmp2 = std::is_same_v ? __half22float2(*(half2*) &tmpL[i]) : bf1622float2(*(bf162*) &tmpL[i]); int kStart = (iK - pipeS_) * cn; // dc & dA should be set to zero out of seq length, but this is done to // dc by chunkcumsum, so no need here. tmp2.x *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn)]) * s_mxdc[kStart + get(thread(iStep) / cn)]; tmp2.y *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn)]) * s_mxdc[kStart + get(thread(iStep) / cn)]; if (std::is_same_v) *(half2*) &tmpL[i] = __float22half2_rn(tmp2); else *(bf162*) &tmpL[i] = __float22bfloat162_rn(tmp2); } *(int4*) ((char*) s_mxL + swizzle<128, 128>( thread(iStep) * cn<2> / cn * cn<128> + thread(iStep) * cn<2> % cn<128>, baseL(jK) * cn<2> + thread(iStep) * cn<2> % cn / cn<128> * cn<128 * tileK_>)) = *(int4*) &tmpL[0]; } __syncthreads(); asm volatile("fence.proxy.async.shared::cta; \n"); asm volatile("wgmma.fence.sync.aligned; \n"); #pragma unroll for (int k = 0; k < tileK_ / wmmaK_; k++) { #pragma unroll for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) { wgmma(r_mxAcc[y], d_mxL0 + k * 128 + y * tileK_ * 128 * warpM_ / 4 / 16, d_mxR0 + k * (tileN_ >= 64 ? 
// (cont.) Loop tail: wgmma.commit_group / wait_group drain the MMA, the
// elected lane prefetches the next K tile via TMA while iK*tileK_ < Q_,
// the mbarriers are invalidated after the loop, and the float accumulators
// are stored to g_mxSt with guarded float2 writes (note threadIdx_y and
// threadIdx_z roles are swapped relative to the SM80 epilogue).
128 : 64)); } } asm volatile("wgmma.commit_group.sync.aligned;"); asm volatile("wgmma.wait_group.sync.aligned %0;" ::"n"(0)); __syncthreads(); if (iK * tileK_ < Q_) { if (threadIdx.y == 0 && threadIdx.z == 0) { asm volatile("@elected_one mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n" ::"r"( b_mbar + jK.var % pipeS_ * 8), "r"((tileM_ + tileN_) * tileK_ * 2)); #pragma unroll for (int iM = 0; iM < tileM_; iM += 64) asm volatile( "@elected_one cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes" " [%0], [%1, {%2, %3, %4}], [%5];\n" ::"r"( b_mxL + jK.var % pipeS_ * (tileM_ * tileK_ * 2) + iM * tileK_ * 2), "l"((CUtensorMap const*) g_mxX_), "r"(get(mStart * cn) + iM), "r"(gStart.var), "r"(get(aStart + blockIdx_y * Q + jK * cn)), "r"(b_mbar + jK.var % pipeS_ * 8)); #pragma unroll for (int iN = 0; iN < tileN_; iN += 64) asm volatile( "@elected_one cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes" " [%0], [%1, {%2, %3, %4}], [%5];\n" ::"r"( b_mxR + jK.var % pipeS_ * (tileN_ * tileK_ * 2) + iN * tileK_ * 2), "l"((CUtensorMap const*) g_mxX_ + 1), "r"(get(nStart * cn) + iN), "r"(hStart.var), "r"(get(aStart + blockIdx_y * Q + jK * cn)), "r"(b_mbar + jK.var % pipeS_ * 8)); } } } if (threadIdx.y == 0 && threadIdx.z == 0) { #pragma unroll for (Rn iK; iK.var < iK.size; iK.var++) asm volatile("@elected_one mbarrier.inval.shared.b64 [%0];\n" ::"r"(b_mbar + iK.var * 8)); } #pragma unroll for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) #pragma unroll for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) if (!bnChk_ || nStart * cn + Rn{x} * cn + threadIdx_z * cn + threadIdx_x % cn<4> * cn<2> < P) { *(float2*) (g_mxSt + get(hStart * N * P + (mStart * cn + Rn{y} * cn + threadIdx_y * cn + threadIdx_x / cn<4>) *P + nStart * cn + Rn{x} * cn + threadIdx_z * cn + threadIdx_x % cn<4> * cn<2>)) = *(float2*) &r_mxAcc[y][x][0]; *(float2*) (g_mxSt + get(hStart * N * P + (mStart * cn + Rn{y} * cn + cn<8> + threadIdx_y * cn + threadIdx_x / 
cn<4>) *P + nStart * cn + Rn{x} * cn + threadIdx_z * cn + threadIdx_x % cn<4> * cn<2>)) = *(float2*) &r_mxAcc[y][x][2]; } #endif } typedef ChunkStateKernelFunc (*GetChunkStateKernelFunc)(int B_, int L_, int H_, int P_, int G_, int N_, int Q_, int numTokens_, int isHopper_, tensorrt_llm::common::CUDADriverWrapper* cudaDriver_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_, int* useTma_, CUtensorMap* descs_); template ChunkStateKernelFunc getChunkStateKernel(int B_, int L_, int H_, int P_, int G_, int N_, int Q_, int numTokens_, int isHopper_, tensorrt_llm::common::CUDADriverWrapper* cudaDriver_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_, int* useTma_, CUtensorMap* descs_) { int B = B_; int L = L_; int H = H_; int P = P_; int G = G_; int N = N_; int Q = Q_; int C = div_up(L, Q); int64_t compute = int64_t(numTokens_) * H * N * P_ * Q; auto set = [&](int tileM, int tileN, int tileK, int warpM, int warpN, int pipeS, int useTma, ChunkStateKernelFunc func) { auto sharedMem = useTma * 1024 + (tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8; *blockDims_ = dim3(H * div_up(P_, tileN) * N / tileM, C, B); *threadDims_ = dim3(32, useTma ? warpM : warpN, useTma ? warpN : warpM); *sharedMem_ = sharedMem; *useTma_ = useTma; if (useTma) { { std::array tensorStrides{2uL * N, 2uL * (H * P_ + 2 * G * N)}; std::array tensorSizes{1uL * N, 1uL * G, 1uL * numTokens_}; std::array traveralStrides{1u, 1u, 1u}; std::array boxSizes{64u, 1u, (uint32_t) tileK}; cudaDriver_->cuTensorMapEncodeTiled(&descs_[0], CU_TENSOR_MAP_DATA_TYPE_FLOAT16, 3, nullptr, &tensorSizes[0], &tensorStrides[0], &boxSizes[0], &traveralStrides[0], CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); } { std::array tensorStrides{2uL * P_, 2uL * (H * P_ + 2 * G * N)}; std::array tensorSizes{1uL * P_, 1uL * H, 1uL * numTokens_}; std::array traveralStrides{1u, 1u, 1u}; std::array boxSizes{tileN >= 64 ? 
64u : (uint32_t) tileN, 1u, (uint32_t) tileK}; cudaDriver_->cuTensorMapEncodeTiled(&descs_[1], CU_TENSOR_MAP_DATA_TYPE_FLOAT16, 3, nullptr, &tensorSizes[0], &tensorStrides[0], &boxSizes[0], &traveralStrides[0], CU_TENSOR_MAP_INTERLEAVE_NONE, tileN >= 64 ? CU_TENSOR_MAP_SWIZZLE_128B : CU_TENSOR_MAP_SWIZZLE_64B, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); } } return func; }; if (Q == 256 && isHopper_ == false) { #ifndef FAST_BUILD if (P == 32 && N % 256 == 0) { if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); } if (P == 32 && N % 128 == 0) { if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); } if (P == 32 && N % 64 == 0) { if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, 
Tp_>); } if (P % 256 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 
1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); } if (P % 128 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= 
(1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P % 128 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P % 128 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P % 64 == 0 && N % 256 == 0) { if (compute >= (1LL << 39)) return set(64, 
64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); } if (P % 64 == 0 && N % 128 == 0) { if (compute >= (1LL << 39)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); } #endif if (P % 64 == 0 && N % 64 == 0) { if (compute >= (1LL << 39)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 
64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); } if (round_up(P, 32) == 32) P = 32; else P = round_up(P, 64); #ifndef FAST_BUILD if (P == 32 && N % 256 == 0) { if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); } if (P == 32 && N % 128 == 0) { if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); } if (P == 32 && N % 64 == 0) { if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); } if (P % 256 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, 
chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 
128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); } if (P % 128 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 41)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, 
Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); } if (P % 128 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 41)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else return set(64, 64, 
32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); } if (P % 128 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 41)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 2, 1, 2, 1, Tp_>); else return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); } if (P % 64 == 0 && N % 256 == 0) { if (compute >= (1LL << 39)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 32, 2, 2, 
2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 30)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); } if (P % 64 == 0 && N % 128 == 0) { if (compute >= (1LL << 39)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 30)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); } #endif if (P % 64 == 0 && N % 64 == 0) { if (compute >= (1LL << 39)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); else if (compute >= (1LL << 30)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 
2, 1, 1, Tp_>); else return set(64, 64, 32, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 2, 2, 2, 2, 2, 1, Tp_>); } } else if (Q == 256 && isHopper_) { #ifndef FAST_BUILD if (P == 32 && N % 256 == 0) { if (compute >= (1LL << 39)) return set(64, 32, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 32, 64, 4, 1, 2, 0, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P == 32 && N % 128 == 0) { if (compute >= (1LL << 39)) return set(64, 32, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 32, 64, 4, 1, 2, 0, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P == 32 && N % 64 == 0) { if (compute >= (1LL << 39)) return set(64, 32, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 32, 64, 4, 1, 2, 0, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P % 256 == 0 && N % 256 == 0) { if (compute >= (1LL << 41)) return set(64, 256, 32, 4, 1, 2, 1, chunk_state_hopper<256, 64, 256, 32, 4, 1, 2, 0, 4, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 
1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); } if (P % 256 == 0 && N % 128 == 0) { if (compute >= (1LL << 41)) return set(64, 256, 32, 4, 1, 2, 1, chunk_state_hopper<256, 64, 256, 32, 4, 1, 2, 0, 4, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); } if (P % 256 == 0 && N % 64 == 0) { if (compute >= (1LL << 41)) return set(64, 256, 32, 4, 1, 2, 1, chunk_state_hopper<256, 64, 256, 32, 4, 1, 2, 0, 4, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); } if (P % 128 == 0 && N % 256 == 0) { if (compute >= (1LL << 41)) return set(64, 128, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 128, 64, 4, 1, 2, 0, 2, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if 
(compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); } if (P % 128 == 0 && N % 128 == 0) { if (compute >= (1LL << 41)) return set(64, 128, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 128, 64, 4, 1, 2, 0, 2, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); } if (P % 128 == 0 && N % 64 == 0) { if (compute >= (1LL << 41)) return set(64, 128, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 128, 64, 4, 1, 2, 0, 2, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); } if (P % 64 == 0 && N % 256 == 0) { if (compute >= (1LL << 40)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<256, 64, 64, 32, 4, 1, 3, 0, 4, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return 
set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); } if (P % 64 == 0 && N % 128 == 0) { if (compute >= (1LL << 40)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<256, 64, 64, 32, 4, 1, 3, 0, 4, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); } #endif if (P % 64 == 0 && N % 64 == 0) { if (compute >= (1LL << 40)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<256, 64, 64, 32, 4, 1, 3, 0, 4, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); } if (round_up(P, 32) == 32) P = 32; else P = round_up(P, 64); #ifndef FAST_BUILD if (P == 32 && N % 256 == 0) { if (compute >= (1LL << 39)) return set(64, 32, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 32, 64, 4, 1, 2, 0, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if 
(compute >= (1LL << 36)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P == 32 && N % 128 == 0) { if (compute >= (1LL << 39)) return set(64, 32, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 32, 64, 4, 1, 2, 0, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P == 32 && N % 64 == 0) { if (compute >= (1LL << 39)) return set(64, 32, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 32, 64, 4, 1, 2, 0, 1, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<256, 64, 32, 32, 2, 1, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 32, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 2, 1, 8, 1, Tp_>); else return set(64, 32, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 32, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P % 256 == 0 && N % 256 == 0) { if (compute >= (1LL << 41)) return set(64, 256, 32, 4, 1, 2, 1, chunk_state_hopper<256, 64, 256, 32, 4, 1, 2, 0, 4, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); } if (P % 256 == 0 && N % 128 == 0) { if (compute >= (1LL << 41)) return set(64, 256, 
32, 4, 1, 2, 1, chunk_state_hopper<256, 64, 256, 32, 4, 1, 2, 0, 4, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); } if (P % 256 == 0 && N % 64 == 0) { if (compute >= (1LL << 41)) return set(64, 256, 32, 4, 1, 2, 1, chunk_state_hopper<256, 64, 256, 32, 4, 1, 2, 0, 4, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); } if (P % 128 == 0 && N % 256 == 0) { if (compute >= (1LL << 40)) return set(64, 128, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 128, 64, 4, 1, 2, 0, 2, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 3, 0, 
chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P % 128 == 0 && N % 128 == 0) { if (compute >= (1LL << 40)) return set(64, 128, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 128, 64, 4, 1, 2, 0, 2, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P % 128 == 0 && N % 64 == 0) { if (compute >= (1LL << 40)) return set(64, 128, 64, 4, 1, 2, 1, chunk_state_hopper<256, 64, 128, 64, 4, 1, 2, 0, 2, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 3, 0, chunk_state_kernel<256, 64, 128, 32, 1, 4, 3, 1, 1, 1, Tp_>); else return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); } 
if (P % 64 == 0 && N % 256 == 0) { if (compute >= (1LL << 40)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<256, 64, 64, 32, 4, 1, 3, 0, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); } if (P % 64 == 0 && N % 128 == 0) { if (compute >= (1LL << 40)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<256, 64, 64, 32, 4, 1, 3, 0, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); } #endif if (P % 64 == 0 && N % 64 == 0) { if (compute >= (1LL << 40)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<256, 64, 64, 32, 4, 1, 3, 0, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<256, 64, 64, 32, 1, 2, 2, 1, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else if (compute >= 
(1LL << 34)) return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 64, 2, 2, 2, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 2, 1, 8, 1, Tp_>); else return set(64, 64, 64, 2, 2, 4, 0, chunk_state_kernel<256, 64, 64, 64, 2, 2, 4, 2, 1, 1, Tp_>); } } if (Q == 128 && isHopper_ == false) { #ifndef FAST_BUILD if (P == 32 && N % 256 == 0) { if (compute >= (1LL << 34)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (P == 32 && N % 128 == 0) { if (compute >= (1LL << 34)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (P == 32 && N % 64 == 0) { if (compute >= (1LL << 34)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (P % 256 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if 
(compute >= (1LL << 34)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 2, 1, 8, 0, Tp_>); else return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 2, 1, 8, 0, Tp_>); else return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 2, 1, 8, 0, Tp_>); else return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 128 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 33)) 
return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 128 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 128 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 64 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 2, 1, 8, 0, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 32, 1, 2, 4, 
0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 2, 1, 8, 0, Tp_>); } if (P % 64 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 2, 1, 8, 0, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 2, 1, 8, 0, Tp_>); } #endif if (P % 64 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 2, 1, 8, 0, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 2, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 2, 1, 8, 0, Tp_>); } if (round_up(P, 32) == 32) P = 32; else P = round_up(P, 64); #ifndef FAST_BUILD if (P == 32 && N % 256 == 0) { if (compute >= (1LL << 34)) 
return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (P == 32 && N % 128 == 0) { if (compute >= (1LL << 34)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (P == 32 && N % 64 == 0) { if (compute >= (1LL << 34)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (P % 256 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, 
chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 128 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); } if (P % 128 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 2, 0, 
chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); } if (P % 128 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 128, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 128, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); } if (P % 64 == 0 && N % 256 == 0) { if (compute >= (1LL << 44)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 30)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 27)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 
2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (P % 64 == 0 && N % 128 == 0) { if (compute >= (1LL << 44)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 30)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 27)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } #endif if (P % 64 == 0 && N % 64 == 0) { if (compute >= (1LL << 44)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 30)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 
27)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } } else if (Q == 128 && isHopper_) { #ifndef FAST_BUILD if (P == 32 && N % 256 == 0) { if (compute >= (1LL << 37)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P == 32 && N % 128 == 0) { if (compute >= (1LL << 37)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P == 32 && N % 64 == 0) { if (compute >= (1LL << 37)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 32)) 
return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 256 == 0) { if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 128 == 0) { if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 
2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 64 == 0) { if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 128 == 0 && N % 256 == 0) { if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 
64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); } if (P % 128 == 0 && N % 128 == 0) { if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); } if (P % 128 == 0 && N % 64 == 0) { if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 
8, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); } if (P % 64 == 0 && N % 256 == 0) { if (compute >= (1LL << 40)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<128, 64, 64, 32, 4, 1, 3, 0, 8, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (P % 64 == 0 && N % 128 == 0) { if (compute >= (1LL << 40)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<128, 64, 64, 32, 4, 1, 3, 0, 8, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } #endif if (P % 64 == 0 && N % 64 == 0) { if (compute >= (1LL << 40)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<128, 64, 64, 32, 4, 1, 3, 0, 8, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 
32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (round_up(P, 32) == 32) P = 32; else P = round_up(P, 64); #ifndef FAST_BUILD if (P == 32 && N % 256 == 0) { if (compute >= (1LL << 37)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P == 32 && N % 128 == 0) { if (compute >= (1LL << 37)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P == 32 && N % 64 == 0) { if (compute >= (1LL << 37)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 34)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 32, 32, 2, 1, 2, 0, chunk_state_kernel<128, 64, 32, 32, 2, 1, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 32)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else 
return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 256 == 0) { if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 128 == 0) { if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) 
return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 256 == 0 && N % 64 == 0) { if (compute >= (1LL << 43)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 40)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 38)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 37)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 33)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); } if (P % 128 == 0 && N % 256 == 0) { if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 41)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 29)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 128, 32, 1, 4, 2, 0, 
chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); } if (P % 128 == 0 && N % 128 == 0) { if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 41)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 29)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); } if (P % 128 == 0 && N % 64 == 0) { if (compute >= (1LL << 42)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 41)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 39)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else if (compute >= (1LL << 36)) return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); else if (compute >= (1LL << 29)) return set(64, 128, 32, 1, 4, 4, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 4, 2, 4, 1, Tp_>); else return set(64, 128, 32, 1, 4, 2, 0, chunk_state_kernel<128, 64, 128, 32, 1, 4, 2, 1, 8, 1, Tp_>); } if (P % 64 == 0 && N % 256 == 0) { if (compute >= (1LL << 39)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<128, 64, 64, 32, 4, 1, 3, 0, 8, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 64, 32, 2, 2, 3, 
0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 30)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else if (compute >= (1LL << 27)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } if (P % 64 == 0 && N % 128 == 0) { if (compute >= (1LL << 39)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<128, 64, 64, 32, 4, 1, 3, 0, 8, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 30)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else if (compute >= (1LL << 27)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } #endif if (P % 64 == 0 && N % 64 == 0) { if (compute >= (1LL << 39)) return set(64, 64, 32, 4, 1, 3, 1, chunk_state_hopper<128, 64, 64, 32, 4, 1, 3, 0, 8, 1, Tp_>); else if (compute >= (1LL << 35)) return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); else if (compute >= (1LL << 31)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else if (compute >= (1LL << 30)) return set(64, 32, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 32, 32, 1, 2, 4, 1, 2, 1, Tp_>); else if (compute >= (1LL << 27)) return set(64, 64, 32, 2, 2, 3, 0, chunk_state_kernel<128, 64, 64, 32, 2, 2, 3, 1, 4, 1, Tp_>); else return set(64, 64, 32, 1, 2, 4, 0, chunk_state_kernel<128, 64, 64, 32, 1, 2, 4, 2, 1, 1, Tp_>); } } return 
nullptr;
}

// Per-precision kernel getters; each is defined in its own explicitly
// instantiated translation unit.
extern GetChunkStateKernelFunc getChunkStateKernel_fp16;
extern GetChunkStateKernelFunc getChunkStateKernel_bf16;

// Dispatches to the precision-specific chunk-state kernel getter based on tp_.
// All launch-configuration outputs (blockDims_, threadDims_, sharedMem_,
// useTma_, descs_) are filled in by the selected getter. Returns the chosen
// kernel function pointer, or nullptr when tp_ is not a supported element type.
static inline ChunkStateKernelFunc getChunkStateKernel(int B_, int L_, int H_, int P_, int G_, int N_, int Q_,
    int numTokens_, int isHopper_, tensorrt_llm::common::CUDADriverWrapper* cudaDriver_, dim3* blockDims_,
    dim3* threadDims_, int* sharedMem_, int* useTma_, CUtensorMap* descs_, CudaType tp_ = CT_FP16)
{
    switch (tp_)
    {
    case CT_FP16:
        return getChunkStateKernel_fp16(B_, L_, H_, P_, G_, N_, Q_, numTokens_, isHopper_, cudaDriver_, blockDims_,
            threadDims_, sharedMem_, useTma_, descs_);
    case CT_BF16:
        return getChunkStateKernel_bf16(B_, L_, H_, P_, G_, N_, Q_, numTokens_, isHopper_, cudaDriver_, blockDims_,
            threadDims_, sharedMem_, useTma_, descs_);
    default:
        // Unsupported element type: no kernel available.
        return nullptr;
    }
}

} // namespace kernels
} // namespace tensorrt_llm

// vim: ts=2 sw=2 sts=2 et sta