TensorRT-LLMs/cpp/tensorrt_llm/kernels/chunkScan/chunkscan.h
Kaiyu Xie 2d234357c6
Update TensorRT-LLM (#1954)
* Update TensorRT-LLM

---------

Co-authored-by: Altair-Alpha <62340011+Altair-Alpha@users.noreply.github.com>
2024-07-16 15:30:25 +08:00

635 lines
28 KiB
C++

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp8.h>
#include <mma.h>
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "Common.h"
#include "Poly.h"
namespace tensorrt_llm
{
namespace kernels
{
// Function-pointer type for the fp16 (half) instantiations of chunk_scan_kernel returned by
// getChunkScanKernelFp16. The parameter list mirrors chunk_scan_kernel<..., half>; trailing comments give
// the flattened extent of each buffer (B batch, L length, H heads, P head dim, G groups, N state dim,
// C chunks, Q chunk size).
using ChunkScanKernelFuncFp16 = void (*)(int B_, int L_, int H_, int P_, int G_, int N_,
    half* g_mxY_,          // out: B*L*H*P
    half const* g_mxOs_,   // B*C*H*N*P
    float const* g_mxdc_,  // B*C*H*Q
    float const* g_mxdA_,  // B*C*H*Q
    half const* g_mxCB_,   // B*C*G*Q*Q
    float const* g_mxD_,   // H; may be nullptr (the kernel then skips the D term)
    half const* g_mxXBC_,  // B*L*(H*P+2*G*N)
    half const* g_mxZ_,    // B*L*(2*H*P+2*G*N+H); may be nullptr (the kernel then skips the Z gate)
    bool removePadding_, int const* lastTokenIdsPtr_);
// Function-pointer type for the bf16 instantiations of chunk_scan_kernel returned by
// getChunkScanKernelBf16. The parameter list mirrors chunk_scan_kernel<..., bf16>; trailing comments give
// the flattened extent of each buffer (B batch, L length, H heads, P head dim, G groups, N state dim,
// C chunks, Q chunk size).
using ChunkScanKernelFuncBf16 = void (*)(int B_, int L_, int H_, int P_, int G_, int N_,
    bf16* g_mxY_,          // out: B*L*H*P
    bf16 const* g_mxOs_,   // B*C*H*N*P
    float const* g_mxdc_,  // B*C*H*Q
    float const* g_mxdA_,  // B*C*H*Q
    bf16 const* g_mxCB_,   // B*C*G*Q*Q
    float const* g_mxD_,   // H; may be nullptr (the kernel then skips the D term)
    bf16 const* g_mxXBC_,  // B*L*(H*P+2*G*N)
    bf16 const* g_mxZ_,    // B*L*(2*H*P+2*G*N+H); may be nullptr (the kernel then skips the Z gate)
    bool removePadding_, int const* lastTokenIdsPtr_);
// chunk_scan_kernel: computes one tileM_ x tileN_ tile of the chunked-scan output Y.
//
// Work decomposition (matches the launch configuration computed by getChunkScanKernelFp16/Bf16):
//   blockIdx.z  batch / sequence index b
//   blockIdx.y  chunk index inside the sequence (chunks of Q_ positions)
//   blockIdx.x  (h * (Q_ / tileM_) + mTile) * (P_ / tileN_) + nTile, i.e. head h, a tileM_-row slice of the
//               chunk and a tileN_-column slice of that head's P_ channels
//   block       (32, warpN_, warpM_): threadIdx.x is the lane, threadIdx.y / threadIdx.z select the warp
//
// Computation, with m = row inside the chunk, p = channel, c = absolute chunk, g = h / (H_ / G_):
//   phase 1, k in [0, N_):  acc[m, p]  = sum_k Cmat[m, k] * Os[c, h, k, p]
//                           Cmat = columns [H*P + G*N + g*N, +N_) of g_mxXBC_ (rows past the valid length are 0)
//                           afterwards acc[m, p] *= expf(dA[m])
//   phase 2, k in [0, Q_):  acc[m, p] += sum_{k <= m} CB[c, g, m, k] * expf(dA[m] - dA[k]) * dc[k] * X[k, p]
//                           X = columns [h*P, +P_) of g_mxXBC_ (rows past the valid length are 0)
//   if g_mxD_ != nullptr:   acc[m, p] += D[h] * X[m, p]
//   if g_mxZ_ != nullptr:   acc[m, p] *= silu(Z[m, p]), Z = columns [h*P, +P_) of g_mxZ_ (z itself used when z > 32)
//   Y[m, h, p] = Tp_(acc[m, p]) for rows inside the valid length.
// dA / dc are the Q_ floats of g_mxdA_ / g_mxdc_ for (chunk, head). NOTE(review): presumably the cumulative
// dt*A and dt of the SSM — confirm against the kernel that produces them.
//
// Both phases stream tileK_-wide K slices through a pipeS_-stage cp.async pipeline and accumulate in fp32
// with mma.sync m16n8k16 (Tp_ is half or __nv_bfloat16, enforced by enable_if; Wt_ is the type of g_mxD_).
//
// Dynamic shared memory required (bytes), following the layout below:
//   max((tileM_ + tileN_) * tileK_ * pipeS_ * sizeof(Tp_) + 2 * Q_ * sizeof(float), tileM_ * tileN_ * sizeof(Tp_))
//
// Sequence layout:
//   removePadding_ == false: sequence b starts at row b * L_ and lastTokenIdsPtr_[b] is used as its valid length.
//   removePadding_ == true:  sequences are packed; sequence b covers rows
//                            [lastTokenIdsPtr_[b - 1], lastTokenIdsPtr_[b]) (starting at 0 for b == 0).
//
// Preconditions (not checked here; presumably guaranteed by the caller — TODO confirm):
//   P_ % tileN_ == 0, Q_ % tileM_ == 0, N_ % tileK_ == 0, Q_ % tileK_ == 0, H_ % G_ == 0, and the global rows
//   accessed through 16-byte copies are 16-byte aligned.
// The body is only compiled for __CUDA_ARCH__ >= 800; on older architectures the kernel does nothing.
// B_ is accepted but unused (the batch index comes from blockIdx.z).
template <int Q_, int tileM_, int tileN_, int tileK_, // smem size, per sm
int wmmaM_, int wmmaN_, int wmmaK_, // wmma size, per instruction
int warpM_, int warpN_, // warp number
int pipeS_, class Tp_, class Wt_ = float>
__global__ std::enable_if_t<std::is_same_v<Tp_, half> || std::is_same_v<Tp_, __nv_bfloat16>> chunk_scan_kernel(int B_,
int L_, int H_, int P_, int G_, int N_,
Tp_* g_mxY_, // B*L*H*P
Tp_ const* g_mxOs_, // B*C*H*N*P
// const Tp_ *g_mxFs_, // B *H*N*P
// const float *g_mxSt_, // B*C*H*N*P
float const* g_mxdc_, // B*C*H*Q
float const* g_mxdA_, // B*C*H*Q
// const Tp_ *g_mxdt_, // B*L*(2*H*P+2*G*N+H) or B*L*(H*P+2*G*N+H)
// const Wt_ *g_mxdb_, // H
// const Wt_ *g_mxA_, // H
Tp_ const* g_mxCB_, // B*C*G*Q*Q
Wt_ const* g_mxD_, // H
Tp_ const* g_mxXBC_, // B*L*(H*P+2*G*N)
Tp_ const* g_mxZ_, // B*L*(2*H*P+2*G*N+H)
bool removePadding_, int const* lastTokenIdsPtr_)
{
#if __CUDA_ARCH__ >= 800
using namespace tensorrt_llm::common;
// Block/thread coordinates wrapped in the Rn<> index helpers from Common.h.
// threadIdx.x is the lane (0..31); threadIdx.y / threadIdx.z select the warp along N / M.
auto blockIdx_x = Rn<ID>{int(blockIdx.x)};
auto blockIdx_y = Rn<ID>{int(blockIdx.y)};
auto blockIdx_z = Rn<ID>{int(blockIdx.z)};
auto threadIdx_x = Rn<ID, 32>{int(threadIdx.x)};
auto threadIdx_y = Rn<ID, warpN_>{int(threadIdx.y)};
auto threadIdx_z = Rn<ID, warpM_>{int(threadIdx.z)};
// Problem sizes; C is the number of Q_-sized chunks covering L.
// auto B = Rn<ID>{B_};
auto L = Rn<ID>{L_};
auto H = Rn<ID>{H_};
auto P = Rn<ID>{P_};
auto G = Rn<ID>{G_};
auto N = Rn<ID>{N_};
auto Q = cn<Q_>;
auto C = Rn<ID>{div_up(L.var, Q_)};
// Row strides (elements) of the g_mxXBC_ and g_mxZ_ rows, and the column where the C slice starts inside an
// xBC row (after H*P columns of X and G*N further columns — presumably B; confirm).
auto xbcDim = Rn<ID>{H_ * P_ + 2 * G_ * N_};
auto zdtDim = Rn<ID>{2 * H_ * P_ + 2 * G_ * N_ + H_};
auto cOffset = Rn<ID>{H_ * P_ + G_ * N_};
// aStart: first token row of this sequence; cStart: first chunk index of this sequence (padded layout).
auto aStart = blockIdx_z * L;
auto cStart = blockIdx_z * C;
// Packed layout: lastTokenIdsPtr_ holds end offsets, so the sequence starts at the previous entry.
// Padded layout: lastTokenIdsPtr_[b] is used directly as the valid length. In both cases L / C are
// replaced by the per-sequence values. NOTE(review): the packed chunk offset formula is kept as-is;
// confirm it matches the producer of g_mxOs_ / g_mxdc_ / g_mxdA_ / g_mxCB_.
if (removePadding_)
{
aStart = Rn<ID>{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)};
cStart = Rn<ID>{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)};
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z] - aStart.var};
C = Rn<ID>{div_up(L.var, Q_)};
}
else
{
L = Rn<ID>{lastTokenIdsPtr_[blockIdx.z]};
C = Rn<ID>{div_up(L.var, Q_)};
}
// Chunks starting at or past the valid length have nothing to write. The condition depends only on
// per-block values, so the whole block returns together before any __syncthreads().
if (blockIdx_y * Q >= L)
return;
// Decode blockIdx.x = (h * (Q_ / tileM_) + mTile) * (P_ / tileN_) + nTile; gStart is the group of head h.
auto hStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) / (Q / cn<tileM_>) };
auto mStart = Rn<ID>{blockIdx_x.var / (P_ / cn<tileN_>) % (Q / cn<tileM_>) };
auto nStart = Rn<ID>{blockIdx_x.var % (P_ / cn<tileN_>) };
auto gStart = Rn<ID>{hStart.var / (H_ / G_)};
// Dynamic shared memory layout:
//   [pipeS_ stages of tileM_ x tileK_ Tp_ : C (phase 1) / CB (phase 2)]
//   [pipeS_ stages of tileK_ x tileN_ Tp_ : Os (phase 1) / X (phase 2)]
//   [Q_ floats dc][Q_ floats dA]
// s_mxY aliases the start of the buffer and is reused for output staging after the main loop.
// b_* are the same locations as 32-bit shared-space addresses for the inline PTX.
extern __shared__ float smem[];
Tp_* s_mxC = (Tp_*) smem;
Tp_* s_mxOs = (Tp_*) smem + tileM_ * tileK_ * pipeS_;
Tp_* s_mxY = (Tp_*) smem;
float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2;
float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_;
unsigned b_base = __nvvm_get_smem_pointer(smem);
unsigned b_mxC = b_base;
unsigned b_mxOs = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_);
unsigned b_mxY = b_base;
// Per-thread fp32 accumulators (zero-initialised) and A/B mma fragments.
using std::array;
register array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxY
= array<array<array<float, wmmaM_ * wmmaN_ / 32>, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>();
register array<array<unsigned, wmmaM_ * wmmaK_ / 64>, tileM_ / wmmaM_ / warpM_> r_mxC;
register array<array<unsigned, wmmaK_ * wmmaN_ / 64>, tileN_ / wmmaN_ / warpN_> r_mxOs;
// step, together with x1 / y1 in the MMA loop, schedules which A (r_mxC) and B (r_mxOs) fragments are
// loaded at which (y, x) iteration; the schedule is kept exactly as written.
constexpr int step = std::max(
1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_));
// Element offsets of pipeline stage iK % pipeS_ inside the C/CB and Os/X regions.
auto baseC = [](auto iK) { return iK % cn<pipeS_> * cn<tileM_> * cn<tileK_>; };
auto baseOs = [](auto iK) { return iK % cn<pipeS_> * cn<tileN_> * cn<tileK_>; };
// Flat element index owned by this thread at step iStep: each thread handles 8 consecutive elements
// (16 bytes of Tp_), and consecutive steps advance by the whole block's coverage.
auto thread = [=](auto iStep)
{
return iStep * cn<warpM_ * warpN_ * 256> + threadIdx_z * cn<warpN_ * 256> + threadIdx_y * cn<256>
+ threadIdx_x * cn<8>;
};
// Load the Q_ dc and dA values of (chunk, head) into shared memory, 8 floats per thread (two int4 copies).
#pragma unroll
for (Rn<UNROLL, div_up(Q_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<Q_>)
{
#pragma unroll
for (int i = 0; i < 8; i += 4)
{
*(int4*) (s_mxdc + get(thread(iStep)) + i)
= *(int4*) (g_mxdc_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
*(int4*) (s_mxdA + get(thread(iStep)) + i)
= *(int4*) (g_mxdA_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i);
}
}
// Pipeline prologue: issue the first pipeS_ phase-1 stages (C rows from g_mxXBC_, Os rows), one commit
// group per stage. C rows past the valid length are zero-filled instead of loaded.
#pragma unroll
for (Rn<UNROLL, pipeS_> iK; iK.var < iK.size; iK.var++)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>),
g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
+ cOffset + gStart * N + iK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(iK) * cn<2>))
= int4{0, 0, 0, 0};
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_>)
cp_shared_global<16>(
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(iK) * cn<2>),
g_mxOs_
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
+ (iK * cn<tileK_> + thread(iStep) / cn<tileN_>) *P + nStart * cn<tileN_>
+ thread(iStep) % cn<tileN_>));
cp_commit_group();
}
// Wait until at most pipeS_ - 1 groups are pending (the oldest stage has landed), then make it block-visible.
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
__syncthreads();
// Main loop. Iteration iK consumes stage iK % pipeS_, which holds K offset (iK - pipeS_) * tileK_
// (phase 1 while that offset is < N_, phase 2 afterwards), then refills the same stage with K offset
// iK * tileK_. NOTE(review): the phase switch below tests for equality with N_, so N_ is assumed to be
// a multiple of tileK_ — not checked here.
for (int iK = pipeS_; iK < (N_ + Q_) / tileK_ + pipeS_; iK++)
{
auto jK = Rn<>{iK};
// Phase boundary: scale the phase-1 accumulators row-wise by expf(dA[m]). Each thread owns rows
// lane / 4 and lane / 4 + 8 of every 16-row mma tile (accumulator elements [0..1] and [2..3]).
if ((iK - pipeS_) * cn<tileK_> == N_)
{
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
float2 tmp2 = float2{expf(s_mxdA[get(mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)]),
expf(s_mxdA[get(mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>)])};
r_mxY[y][x][0] *= tmp2.x;
r_mxY[y][x][1] *= tmp2.x;
r_mxY[y][x][2] *= tmp2.y;
r_mxY[y][x][3] *= tmp2.y;
}
}
// Phase 2: rewrite the CB stage in shared memory in place as
// CB[m, k] * expf(dA[m] - dA[k]) * dc[k], with entries where m < k zeroed (causal mask).
if ((iK - pipeS_) * cn<tileK_> >= N_)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>)
{
register Tp_ tmpCB[8];
*(int4*) &tmpCB[0] = *(int4*) ((char*) s_mxC
+ swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>));
#pragma unroll
for (int i = 0; i < 8; i += 2)
{
float2 tmp2 = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmpCB[i])
: bf1622float2(*(bf162*) &tmpCB[i]);
int kStart = (iK - pipeS_) * cn<tileK_> - N_;
tmp2.x *= expf(s_mxdA[get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)]
- s_mxdA[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i})])
* s_mxdc[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i})];
tmp2.y *= expf(s_mxdA[get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)]
- s_mxdA[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1})])
* s_mxdc[kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1})];
if (get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)
< kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i}))
tmp2.x = 0;
if (get(mStart * cn<tileM_> + thread(iStep) / cn<tileK_>)
< kStart + get(thread(iStep) % cn<tileK_> + Rn<UNROLL>{i + 1}))
tmp2.y = 0;
if (std::is_same_v<Tp_, half>)
*(half2*) &tmpCB[i] = __float22half2_rn(tmp2);
else
*(bf162*) &tmpCB[i] = __float22bfloat162_rn(tmp2);
}
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
= *(int4*) &tmpCB[0];
}
__syncthreads();
}
// Tensor-core product of the current stage: ldmatrix the A fragments (C / CB rows) and the transposed
// B fragments (Os / X), then mma.sync m16n8k16 with fp32 accumulation into r_mxY. Only the
// wmmaK_ == 16 configuration emits instructions.
#pragma unroll
for (int k = 0; k < tileK_ / wmmaK_; k++)
{
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0)
{
int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step;
int y1 = x1 - tileN_ / wmmaN_ / warpN_
+ (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1);
if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_)
{
if (wmmaK_ == 16)
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r_mxC[y1][0]), "=r"(r_mxC[y1][1]), "=r"(r_mxC[y1][2]), "=r"(r_mxC[y1][3])
: "r"(b_mxC + iK % pipeS_ * (tileM_ * tileK_ * 2)
+ 2
* swz<tileK_ * 2, tileK_>(y1 * warpM_ * wmmaM_ * tileK_ + k * wmmaK_
+ threadIdx.z * wmmaM_ * tileK_ + threadIdx.x % 16 * tileK_
+ threadIdx.x / 16 * 8)));
}
if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_)
{
if (wmmaK_ == 16 && x1 % 2 == 0)
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(r_mxOs[x1][0]), "=r"(r_mxOs[x1][1]), "=r"(r_mxOs[x1 + 1][0]),
"=r"(r_mxOs[x1 + 1][1])
: "r"(b_mxOs + iK % pipeS_ * (tileK_ * tileN_ * 2)
+ 2
* swz<tileN_ * 2, tileN_>(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_
+ threadIdx.y * wmmaN_ + threadIdx.x % wmmaK_ * tileN_
+ threadIdx.x / wmmaK_ * warpN_ * wmmaN_)));
}
}
}
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if (wmmaK_ == 16)
{
if (std::is_same_v<Tp_, half>)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5, %6, %7}, \n"
" {%8, %9}, \n"
" {%0, %1, %2, %3}; \n"
: "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]), "+f"(r_mxY[y][x][3])
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
"r"(r_mxOs[x][0]), "r"(r_mxOs[x][1]));
else
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5, %6, %7}, \n"
" {%8, %9}, \n"
" {%0, %1, %2, %3}; \n"
: "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]), "+f"(r_mxY[y][x][3])
: "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]),
"r"(r_mxOs[x][0]), "r"(r_mxOs[x][1]));
}
}
}
// Every warp has finished reading this stage; it can now be refilled.
__syncthreads();
// Prefetch for K offset iK * tileK_: phase-1 data (C rows from g_mxXBC_, Os rows) while still below N_.
if (iK * cn<tileK_> < N_)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>
&& thread(iStep) / cn<tileK_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
cp_shared_global<16>(
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *xbcDim
+ cOffset + gStart * N + jK * cn<tileK_> + thread(iStep) % cn<tileK_>));
else if (thread(iStep) < cn<tileM_ * tileK_>)
*(int4*) ((char*) s_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>))
= int4{0, 0, 0, 0};
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_>)
cp_shared_global<16>(
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
g_mxOs_
+ get((cStart + blockIdx_y) * H * N * P + hStart * N * P
+ (jK * cn<tileK_> + thread(iStep) / cn<tileN_>) *P + nStart * cn<tileN_>
+ thread(iStep) % cn<tileN_>));
}
// Otherwise prefetch phase-2 data: CB rows of this chunk/group and X rows (zero-filled past the valid length).
else if (iK * cn<tileK_> < N_ + Q_)
{
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileK_>)
cp_shared_global<16>(
b_mxC + swizzle<tileK_ * 2, tileK_ * 2>(thread(iStep) * cn<2>, baseC(jK) * cn<2>),
g_mxCB_
+ get((cStart + blockIdx_y) * G * Q * Q + gStart * Q * Q
+ (mStart * cn<tileM_> + thread(iStep) / cn<tileK_>) *Q + jK * cn<tileK_>
- N + thread(iStep) % cn<tileK_>));
#pragma unroll
for (Rn<UNROLL, div_up(tileN_ * tileK_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileN_ * tileK_>
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - jK * cn<tileK_> + N)
cp_shared_global<16>(
b_mxOs + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>),
g_mxXBC_
+ get((aStart + blockIdx_y * Q + jK * cn<tileK_> - N + thread(iStep) / cn<tileN_>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + thread(iStep) % cn<tileN_>));
else if (thread(iStep) < cn<tileN_ * tileK_>)
*(int4*) ((char*) s_mxOs
+ swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>, baseOs(jK) * cn<2>))
= int4{0, 0, 0, 0};
}
// Commit this iteration's copies (an empty group once the K range is exhausted, which keeps the group
// count uniform), then wait for the next stage to land.
asm volatile("cp.async.commit_group;\n" ::);
asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1));
__syncthreads();
}
// Optional skip term: acc[m, p] += D[h] * X[m, p] for rows inside the valid length.
// Each thread reads the two adjacent channels it owns in the m16n8 accumulator layout.
if (g_mxD_)
{
float r_D = g_mxD_[hStart.var];
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
Tp_ tmp16[4] = {0};
float tmp32[4] = {0};
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[0] = *(int*) (g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
*(float2*) &tmp32[0] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[0])
: bf1622float2(*(bf162*) &tmp16[0]);
r_mxY[y][x][0] += r_D * tmp32[0];
r_mxY[y][x][1] += r_D * tmp32[1];
}
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[2] = *(int*) (g_mxXBC_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *xbcDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
*(float2*) &tmp32[2] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[2])
: bf1622float2(*(bf162*) &tmp16[2]);
r_mxY[y][x][2] += r_D * tmp32[2];
r_mxY[y][x][3] += r_D * tmp32[3];
}
}
}
// Optional gate: acc[m, p] *= silu(z) with z = Z[m, p]; silu(z) = z / (1 + expf(-z)),
// and z itself is used when z > 32.
if (g_mxZ_)
{
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
Tp_ tmp16[4] = {0};
float tmp32[4] = {0};
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[0] = *(int*) (g_mxZ_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *zdtDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
*(float2*) &tmp32[0] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[0])
: bf1622float2(*(bf162*) &tmp16[0]);
r_mxY[y][x][0] *= tmp32[0] > 32.f ? tmp32[0] : tmp32[0] / (1.f + expf(-tmp32[0]));
r_mxY[y][x][1] *= tmp32[1] > 32.f ? tmp32[1] : tmp32[1] / (1.f + expf(-tmp32[1]));
}
if (blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_> + cn<8>
+ threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>
< L)
{
*(int*) &tmp16[2] = *(int*) (g_mxZ_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + Rn<UNROLL>{y} * cn<warpM_ * wmmaM_>
+ cn<8> + threadIdx_z * cn<wmmaM_> + threadIdx_x / cn<4>) *zdtDim
+ hStart * P + nStart * cn<tileN_> + Rn<UNROLL>{x} * cn<warpN_ * wmmaN_>
+ threadIdx_y * cn<wmmaN_> + threadIdx_x % cn<4> * cn<2>));
*(float2*) &tmp32[2] = std::is_same_v<Tp_, half> ? __half22float2(*(half2*) &tmp16[2])
: bf1622float2(*(bf162*) &tmp16[2]);
r_mxY[y][x][2] *= tmp32[2] > 32.f ? tmp32[2] : tmp32[2] / (1.f + expf(-tmp32[2]));
r_mxY[y][x][3] *= tmp32[3] > 32.f ? tmp32[3] : tmp32[3] / (1.f + expf(-tmp32[3]));
}
}
}
// Convert the fp32 accumulators to Tp_ pairs, packed in place into the storage of r_mxY[y][x][0] and [2].
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
if (std::is_same_v<Tp_, half>)
{
*(half2*) &r_mxY[y][x][0] = __floats2half2_rn(r_mxY[y][x][0], r_mxY[y][x][1]);
*(half2*) &r_mxY[y][x][2] = __floats2half2_rn(r_mxY[y][x][2], r_mxY[y][x][3]);
}
else
{
*(bf162*) &r_mxY[y][x][0] = __floats2bfloat162_rn(r_mxY[y][x][0], r_mxY[y][x][1]);
*(bf162*) &r_mxY[y][x][2] = __floats2bfloat162_rn(r_mxY[y][x][2], r_mxY[y][x][3]);
}
}
// Stage the packed tile (tileM_ x tileN_ Tp_, swizzled) in shared memory at s_mxY, which aliases the start
// of the C/CB stage region.
#pragma unroll
for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++)
#pragma unroll
for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++)
{
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY
+ 2
* swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + x * warpN_ * wmmaN_
+ (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
+ (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
"r"(*(unsigned*) &r_mxY[y][x][0]));
asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY
+ 2
* swz<tileN_ * 2, tileN_>(y * warpM_ * wmmaM_ * tileN_ + 8 * tileN_
+ x * warpN_ * wmmaN_ + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_
+ (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))),
"r"(*(unsigned*) &r_mxY[y][x][2]));
}
__syncthreads();
// Write the staged tile to g_mxY_ (row stride H*P) in 16-byte chunks, skipping rows past the valid length.
#pragma unroll
for (Rn<UNROLL, div_up(tileM_ * tileN_, warpM_ * warpN_ * 256)> iStep; iStep.var < iStep.size; iStep.var++)
if (thread(iStep) < cn<tileM_ * tileN_>
&& thread(iStep) / cn<tileN_> < L - blockIdx_y * Q - mStart * cn<tileM_>)
*(int4*) (g_mxY_
+ get((aStart + blockIdx_y * Q + mStart * cn<tileM_> + thread(iStep) / cn<tileN_>) *H * P + hStart * P
+ nStart * cn<tileN_> + thread(iStep) % cn<tileN_>))
= *(int4*) ((char*) s_mxY + swizzle<tileN_ * 2, tileN_ * 2>(thread(iStep) * cn<2>));
// Drain any outstanding cp.async groups before the kernel exits.
asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
#endif
}
/*!
 * Select the fp16 chunk_scan_kernel instantiation for chunk size Q_ and report its launch configuration.
 *
 * @param B_          batch size; becomes the grid z dimension.
 * @param L_          (padded) sequence length; the grid y dimension is div_up(L_, Q_) chunks.
 * @param H_, P_      number of heads and channels per head; grid x = H_ * P_ / tileN * Q_ / tileM.
 * @param G_, N_      unused here (the kernel receives them as launch arguments).
 * @param Q_          chunk size; only 128 and 256 have instantiations.
 * @param blockDims_  [out] grid dimensions.
 * @param threadDims_ [out] block dimensions (32, warpN, warpM).
 * @param sharedMem_  [out] dynamic shared memory in bytes.
 * @return the kernel, or nullptr for an unsupported Q_ (the output parameters are still written).
 *
 * Declared inline because this non-template function is defined in a header. Without it, including the
 * header from more than one translation unit produces multiple-definition link errors (ODR violation).
 * The tile configuration lives in constexpr constants that feed both the launch-size arithmetic and the
 * template arguments, so the two cannot drift apart.
 */
inline ChunkScanKernelFuncFp16 getChunkScanKernelFp16(
    int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    // Tile configuration baked into the instantiations returned below.
    constexpr int tileM = 128;
    constexpr int tileN = 64;
    constexpr int tileK = 32;
    constexpr int wmmaM = 16;
    constexpr int wmmaN = 8;
    constexpr int wmmaK = 16;
    constexpr int warpM = 4;
    constexpr int warpN = 1;
    constexpr int pipeS = 2;

    int B = B_;
    int L = L_;
    int H = H_;
    int P = P_;
    int Q = Q_;
    int C = div_up(L, Q);

    // pipeS stages of C/CB and Os/X tiles (2-byte elements) plus Q floats each of dc and dA, or the output
    // staging tile if that is larger (it aliases the same buffer).
    auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2);

    *blockDims_ = dim3(H * P / tileN * Q / tileM, C, B);
    *threadDims_ = dim3(32, warpN, warpM);
    *sharedMem_ = sharedMem;

    if (Q_ == 128)
        return chunk_scan_kernel<128, tileM, tileN, tileK, wmmaM, wmmaN, wmmaK, warpM, warpN, pipeS, half>;
    else if (Q_ == 256)
        return chunk_scan_kernel<256, tileM, tileN, tileK, wmmaM, wmmaN, wmmaK, warpM, warpN, pipeS, half>;
    else
        return nullptr;
}
/*!
 * Select the bf16 chunk_scan_kernel instantiation for chunk size Q_ and report its launch configuration.
 *
 * @param B_          batch size; becomes the grid z dimension.
 * @param L_          (padded) sequence length; the grid y dimension is div_up(L_, Q_) chunks.
 * @param H_, P_      number of heads and channels per head; grid x = H_ * P_ / tileN * Q_ / tileM.
 * @param G_, N_      unused here (the kernel receives them as launch arguments).
 * @param Q_          chunk size; only 128 and 256 have instantiations.
 * @param blockDims_  [out] grid dimensions.
 * @param threadDims_ [out] block dimensions (32, warpN, warpM).
 * @param sharedMem_  [out] dynamic shared memory in bytes.
 * @return the kernel, or nullptr for an unsupported Q_ (the output parameters are still written).
 *
 * Declared inline because this non-template function is defined in a header. Without it, including the
 * header from more than one translation unit produces multiple-definition link errors (ODR violation).
 * The tile configuration lives in constexpr constants that feed both the launch-size arithmetic and the
 * template arguments, so the two cannot drift apart.
 */
inline ChunkScanKernelFuncBf16 getChunkScanKernelBf16(
    int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_)
{
    // Tile configuration baked into the instantiations returned below.
    constexpr int tileM = 128;
    constexpr int tileN = 64;
    constexpr int tileK = 32;
    constexpr int wmmaM = 16;
    constexpr int wmmaN = 8;
    constexpr int wmmaK = 16;
    constexpr int warpM = 4;
    constexpr int warpN = 1;
    constexpr int pipeS = 2;

    int B = B_;
    int L = L_;
    int H = H_;
    int P = P_;
    int Q = Q_;
    int C = div_up(L, Q);

    // pipeS stages of C/CB and Os/X tiles (2-byte elements) plus Q floats each of dc and dA, or the output
    // staging tile if that is larger (it aliases the same buffer).
    auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2);

    *blockDims_ = dim3(H * P / tileN * Q / tileM, C, B);
    *threadDims_ = dim3(32, warpN, warpM);
    *sharedMem_ = sharedMem;

    if (Q_ == 128)
        return chunk_scan_kernel<128, tileM, tileN, tileK, wmmaM, wmmaN, wmmaK, warpM, warpN, pipeS, bf16>;
    else if (Q_ == 256)
        return chunk_scan_kernel<256, tileM, tileN, tileK, wmmaM, wmmaN, wmmaK, warpM, warpN, pipeS, bf16>;
    else
        return nullptr;
}
} // namespace kernels
} // namespace tensorrt_llm
// vim: ts=2 sw=2 sts=2 et sta