TensorRT-LLM/cpp/kernels/fmha_v2/src/fmha/mask.h
/*
* SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "fmha/traits.h"
namespace fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
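// The FMHA_VERSION parameter selects the mask flavor (see the specializations below):
//   1: packed mask bits loaded from global memory (one 32-bit word per thread per MMA),
//   2: dense/padding mask (a position is valid while col < actual_seqlen),
//   3: causal mask, i.e. the lower triangle (e.g. for GPT),
//   4: causal mask restricted to a sliding window,
//   5: custom packed mask streamed from global memory.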
template <typename Traits, typename Cta_tile, int FMHA_VERSION>
struct Mask
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, typename Cta_tile>
struct Mask<Traits, Cta_tile, 1>
{
// The shape of the MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// The number of MMAs in each dimension.
enum
{
MMAS_M = Mma_tile::MMAS_M
};
enum
{
MMAS_N = Mma_tile::MMAS_N
};
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx)
{
// The pointer.
packed_mask_ptr_ = reinterpret_cast<char const*>(params.packed_mask_ptr);
// Take the head into account.
packed_mask_ptr_ += block_info.bidb * params.packed_mask_stride_in_bytes;
// The thread inside the CTA.
packed_mask_ptr_ += tidx * sizeof(uint32_t);
}
// Load the mask into registers (and expand).
inline __device__ void load(int it)
{
// One 32-bit integer per MMA.
uint32_t packed_mask[MMAS_M];
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi)
{
int offset = (it * MMAS_M + mi) * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t);
fmha::ldg(packed_mask[mi], packed_mask_ptr_ + offset);
}
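// Each 8-bit group of the packed word covers the 2x4 element footprint that this
// thread owns for one MMA in the N dimension: bits 0..3 map to columns 0/1 of the
// thread's two rows, and bits 4..7 map to columns 2/3 of the same rows (see below).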
// Expand the mask.
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi)
{
#pragma unroll
for (int ni = 0; ni < MMAS_N; ++ni)
{
mask_[2 * mi + 0][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 0));
mask_[2 * mi + 0][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 1));
mask_[2 * mi + 1][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 2));
mask_[2 * mi + 1][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 3));
mask_[2 * mi + 0][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 4));
mask_[2 * mi + 0][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 5));
mask_[2 * mi + 1][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 6));
mask_[2 * mi + 1][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 7));
}
}
}
// Is a given position valid?
inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const
{
return mask_[mi * 2 + ii][ni * 4 + jj];
}
// The pointer to the mask.
char const* packed_mask_ptr_;
// The mask after expansion.
bool mask_[MMAS_M * 2][MMAS_N * 4];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Mask<Volta_hmma_fp16_traits, Cta_tile, 1>
{
// The instruction traits.
using Traits = Volta_hmma_fp16_traits;
// The shape of the MMA tile.
using Mma_tile = typename Traits::Mma_tile<Cta_tile>;
// The number of MMAs in each dimension.
enum
{
MMAS_M = Mma_tile::MMAS_M
};
enum
{
MMAS_N = Mma_tile::MMAS_N
};
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx)
{
// The pointer.
packed_mask_ptr_ = reinterpret_cast<char const*>(params.packed_mask_ptr);
// Take the head into account.
packed_mask_ptr_ += block_info.bidb * params.packed_mask_stride_in_bytes;
// The thread inside the CTA.
packed_mask_ptr_ += tidx * sizeof(uint32_t);
}
// Load the mask into registers (and expand).
inline __device__ void load(int it)
{
// One 32-bit integer per MMA.
uint32_t packed_mask[MMAS_M];
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi)
{
int offset = (it * MMAS_M + mi) * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t);
fmha::ldg(packed_mask[mi], packed_mask_ptr_ + offset);
}
// Expand the mask.
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi)
{
#pragma unroll
for (int ii = 0; ii < MMAS_N * 8; ++ii)
{
mask_[mi][ii] = packed_mask[mi] & (1u << ii);
}
}
}
// Is a given position valid?
inline __device__ bool is_valid(int mi, int ni, int, int jj) const
{
return mask_[mi][ni * 8 + jj];
}
// The pointer to the mask.
char const* packed_mask_ptr_;
// The mask after expansion.
bool mask_[MMAS_M][MMAS_N * 8];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, typename Cta_tile>
struct Mask<Traits, Cta_tile, 2>
{
// This implementation works only when WARPS_K is 1.
static_assert(Cta_tile::WARPS_K == 1, "");
// The shape of the MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx)
: seqlen_(block_info.actual_seqlen)
, col_loop_step_(0)
{
// The decomposition of the thread index into warp/lane.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// The position of the warp.
int warp_n = warp / Cta_tile::WARPS_M;
// The position of the thread.
col_ = block_info.bidn * Cta_tile::N + warp_n * 16 + lane % 4 * 2;
col_init_ = col_;
}
// Is a given position valid?
inline __device__ bool is_valid(int, int ni, int, int jj) const
{
// The position of the thread in the sequence.
int offset = this->col_ + this->col_loop_step_ * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA;
// The position inside the MMA.
offset += (jj & 0x02) * 4 + (jj & 0x1);
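// For jj in {0, 1, 2, 3} this yields column offsets {0, 1, 8, 9}: the thread owns two
// adjacent columns in each 8-column half of the MMA tile.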
// Is it a valid position in the sequence?
return offset < seqlen_;
}
// BERT Mask: if upper left is invalid, none are valid
inline __device__ bool any_valid(int mi, int ni) const
{
return is_valid(mi, ni, 0, 0);
}
// Move the mask to the next tile (flash attention)
inline __device__ void move()
{
this->col_ += Cta_tile::N;
}
// Move the mask col by an offset (flash attention)
inline __device__ void move_to_offset(int offset)
{
this->col_ = col_init_ + offset;
}
// Reset mask to the initial col
inline __device__ void reset()
{
col_ = col_init_;
}
// Load the mask... Nothing to do for real.
inline __device__ void load(int) {}
// Load the mask... we use it to keep track of the row/col (flash attention).
inline __device__ void load(int, int col_loop_step)
{
col_loop_step_ = col_loop_step;
}
// The length of the sequence.
int seqlen_;
// The left-most position of the thread in the sequence.
int col_, col_init_;
// The current col iteration
int col_loop_step_;
};
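// A minimal usage sketch (hypothetical, not part of this header): a flash-attention style
// kernel constructs the mask once and advances it per K/V tile, e.g.
//
//   Mask<Traits, Cta_tile, 2> mask(params, block_info, tidx);
//   for (int step = 0; step < steps; ++step) {
//       mask.load(0, step); // track the current col iteration
//       // ... compute P = Q x K^T for this tile, then mask out invalid positions:
//       for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi)
//           for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
//               for (int ii = 0; ii < 2; ++ii)
//                   for (int jj = 0; jj < 4; ++jj)
//                       if (!mask.is_valid(mi, ni, ii, jj)) { /* set that P element to -inf */ }
//   }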
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, typename Cta_tile>
struct Mask<Traits, Cta_tile, 3> : public Mask<Traits, Cta_tile, 2>
{
// V3 mask is the causal mask (e.g. for GPT) and extends V2 masks (self-attention).
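// Valid (v) vs. masked (x) positions, a 4x4 example:
// v x x x
// v v x x
// v v v x
// v v v v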
using Base = Mask<Traits, Cta_tile, 2>;
// The shape of the MMA tile.
using Mma_tile = typename Base::Mma_tile;
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx)
: Base(params, block_info, tidx)
, row_loop_step_(0)
{
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// The position of the warp.
int warp_m = warp % Cta_tile::WARPS_M;
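// Within the warp, thread t covers rows (t / 4) and (t / 4) + 8 of the 16-row MMA tile
// (the ii index in get_row_col() selects between the two).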
row_ = warp_m * 16 + lane / 4;
}
inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const
{
// The position of the thread in the sequence.
row = this->row_ + this->row_loop_step_ + mi * Mma_tile::M_PER_MMA_PER_CTA;
// The position inside the MMA.
row += ii * 8;
// The position of the thread in the sequence.
col = this->col_ + this->col_loop_step_ * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA;
// The position inside the MMA.
col += (jj & 0x02) * 4 + (jj & 0x1);
}
// Is a given position valid?
inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const
{
int row, col;
get_row_col(row, col, mi, ni, ii, jj);
// Is it a valid position in the sequence?
return is_valid(row, col);
}
// Is a given position valid?
inline __device__ bool is_valid(int row, int col) const
{
// Is it a valid position in the sequence, i.e. are we in the lower triangle?
return (row >= col);
}
// GPT Mask: if lower left is invalid, none are valid
inline __device__ bool any_valid(int mi, int ni) const
{
return is_valid(mi, ni, 1, 0);
}
// Load the mask... we use it to keep track of the row.
inline __device__ void load(int row_loop_step)
{
row_loop_step_ = row_loop_step;
}
// Load the mask... we use it to keep track of the row/col (flash attention).
inline __device__ void load(int row_loop_step, int col_loop_step)
{
row_loop_step_ = row_loop_step;
this->col_loop_step_ = col_loop_step;
}
// The upper-most position of the thread in the sequence.
int row_;
// Current row step offset.
int row_loop_step_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Specialized mask for MTP (multi-token prediction used in MLA).
template <typename Traits, typename Cta_tile>
struct MtpMask : public Mask<Traits, Cta_tile, 2>
{
// The MTP mask (a causal mask) extends the V2 (dense) mask (self-attention).
using Base = Mask<Traits, Cta_tile, 2>;
// The shape of the MMA tile.
using Mma_tile = typename Base::Mma_tile;
// Ctor.
template <typename Params, typename Block_info>
inline __device__ MtpMask(Params const& params, Block_info const& block_info, int tidx)
: Base(params, block_info, tidx)
, num_grouped_heads_(params.num_grouped_heads)
, row_loop_step_(0)
{
// Update the seqlen (excluding all MTP draft tokens).
this->seqlen_ = this->seqlen_ - (block_info.actual_q_seqlen / params.num_grouped_heads) + 1;
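// E.g. with D draft tokens (D = actual_q_seqlen / num_grouped_heads), is_valid() below
// lets the rows of draft token i attend to the first (seqlen_ + i) keys, so each later
// draft token sees exactly one more key and the last one sees the full sequence.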
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// The position of the warp.
int warp_m = warp % Cta_tile::WARPS_M;
row_ = warp_m * 16 + lane / 4;
}
inline __device__ int get_row(int mi, int ii) const
{
// The position of the thread in the sequence.
int row = this->row_ + this->row_loop_step_ + mi * Mma_tile::M_PER_MMA_PER_CTA;
// The position inside the MMA.
row += ii * 8;
return row;
}
inline __device__ int get_col(int ni, int jj) const
{
// The position of the thread in the sequence.
int col = this->col_ + this->col_loop_step_ * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA;
// The position inside the MMA.
col += (jj & 0x02) * 4 + (jj & 0x1);
return col;
}
inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const
{
row = get_row(mi, ii);
col = get_col(ni, jj);
}
// Is a given position valid?
inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const
{
int col = get_col(ni, jj);
// Is it a valid position in the sequence?
return col < (this->seqlen_ + mtp_token_idx_[mi][ii]);
}
// Is a given position valid?
inline __device__ bool is_valid(int row, int col) const
{
// Is it a valid position in the sequence, i.e. are we in the lower triangle?
return (row >= col);
}
// Load the mask... we use it to keep track of the row.
inline __device__ void load(int row_loop_step)
{
row_loop_step_ = row_loop_step;
// Update the MTP token index.
#pragma unroll
for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi)
{
#pragma unroll
for (int ii = 0; ii < 2; ++ii)
{
mtp_token_idx_[mi][ii] = get_row(mi, ii) / num_grouped_heads_;
}
}
}
// The number of grouped heads in the row dimension.
int num_grouped_heads_;
// The corresponding MTP token index for each row.
// FIXME: currently we assume 2 rows per thread (volta/hopper-gmma traits are not supported yet).
int mtp_token_idx_[Mma_tile::MMAS_M][2];
// The upper-most position of the thread in the sequence.
int row_;
// The current row step offset.
int row_loop_step_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// The lower-triangle attention matrix.
// We assume each position only attends to the previous sliding-window-size tokens.
// v x x x x x x x x
// v v x x x x x x x
// v v v x x x x x x
// v v v v x x x x x
// v v v v v x x x x
// x v v v v v x x x
// x x v v v v v x x
// x x x v v v v v x
// x x x x v v v v v
template <typename Traits, typename Cta_tile>
struct Mask<Traits, Cta_tile, 4> : public Mask<Traits, Cta_tile, 3>
{
// V4 mask is the causal mask (e.g. for GPT) plus the sliding-window feature.
using Base = Mask<Traits, Cta_tile, 3>;
// The shape of the MMA tile.
using Mma_tile = typename Base::Mma_tile;
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx)
: Base(params, block_info, tidx)
, sliding_window_size_(params.sliding_window_size)
{
}
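// E.g. with sliding_window_size_ = 4, row 6 may attend to columns 2..6 (inclusive),
// matching the 5-wide band of valid positions in the diagram above.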
// Is a given position valid?
inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const
{
int row, col;
this->get_row_col(row, col, mi, ni, ii, jj);
// Is it a valid position in the sequence?
return is_valid(row, col);
}
// Is a given position valid?
inline __device__ bool is_valid(int row, int col) const
{
// Is it a valid position in the sequence, i.e. are we in the lower triangle?
return (row >= col) && (col >= max(0, row - sliding_window_size_));
}
// The sliding window size.
int sliding_window_size_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// The custom mask (from global memory).
template <typename Traits, typename Cta_tile>
struct Mask<Traits, Cta_tile, 5> : public Mask<Traits, Cta_tile, 3>
{
using Base = Mask<Traits, Cta_tile, 3>;
// The shape of the MMA tile.
using Mma_tile = typename Base::Mma_tile;
// The number of MMAs in each dimension.
enum
{
MMAS_M = Mma_tile::MMAS_M
};
enum
{
MMAS_N = Mma_tile::MMAS_N
};
// One 32-bit packed mask holds 4 MMAs in the N dimension as one group.
enum
{
MMA_GROUPS_N = fmha::Div_up<MMAS_N, 4>::VALUE
};
// The number of MMAs in each group.
enum
{
MMAS_N_IN_GROUP = fmha::Min<MMAS_N, 4>::VALUE
};
// Whether MMAS_N fills complete 32-bit packed masks.
enum
{
FULL_PACKED_MASK = (MMAS_N % 4 == 0)
};
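// E.g. MMAS_N = 2 gives MMA_GROUPS_N = 1, MMAS_N_IN_GROUP = 2 and FULL_PACKED_MASK = 0:
// one loaded 32-bit word then covers two consecutive col iterations, and local_ni in
// load_mask() selects which half of the word is in use.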
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx)
: Base(params, block_info, tidx)
, packed_mask_ptr_(reinterpret_cast<char const*>(params.packed_mask_ptr))
, params_packed_mask_stride_in_bytes_(params.packed_mask_stride_in_bytes)
, row_offset_(0)
{
// Add the thread offset in bytes.
packed_mask_ptr_ += (block_info.sum_mask_row * params_packed_mask_stride_in_bytes_ + tidx * sizeof(uint32_t));
}
// Load the mask... we use it to keep track of the row offset.
inline __device__ void load(int row_offset)
{
row_offset_ = row_offset;
}
// Load the mask into registers (and expand).
inline __device__ void load_mask(int col_offset)
{
// The packed mask offset in the col (N) dimension.
int mask_col_offset
= int(col_offset / (Mma_tile::N_PER_MMA_PER_CTA * 4)) * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t);
// When MMAS_N < 4, one loaded packed_mask can be expanded to boolean masks
// of multiple iterations.
int local_col = FULL_PACKED_MASK ? 0 : (col_offset % (Mma_tile::N_PER_MMA_PER_CTA * 4));
// The local MMA index (ni) when MMAS_N < 4.
int local_ni = local_col / 16;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi)
{
// The M dimension offset.
int offset = (row_offset_ + mi * Mma_tile::M_PER_MMA_PER_CTA) * params_packed_mask_stride_in_bytes_;
// The N dimension offset.
offset += mask_col_offset;
// Set predicate to true only when next 32-bit packed mask is needed.
bool pred = local_col == 0;
#pragma unroll
for (int ni = 0; ni < MMA_GROUPS_N; ++ni)
{
// The MMAS_N group offset.
if (pred)
{
fmha::ldg(packed_mask_[mi][ni],
packed_mask_ptr_ + offset + ni * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t));
}
}
}
// Expand the mask.
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi)
{
#pragma unroll
for (int ni = 0; ni < MMA_GROUPS_N; ++ni)
{
#pragma unroll
for (int nni = 0; nni < MMAS_N_IN_GROUP; ++nni)
{
mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 0]
= packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 0));
mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 1]
= packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 1));
mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 0]
= packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 2));
mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 1]
= packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 3));
mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 2]
= packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 4));
mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 3]
= packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 5));
mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 2]
= packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 6));
mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 3]
= packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 7));
}
}
}
}
// Move the mask col by an offset (flash attention)
inline __device__ void move_to_offset(int col_offset)
{
load_mask(col_offset);
}
// Is a given position valid?
inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const
{
return mask_[mi * 2 + ii][ni * 4 + jj];
}
// Current row step offset.
int row_offset_;
// The pointer to the mask.
char const* packed_mask_ptr_;
// The stride in the n dimension.
int64_t const params_packed_mask_stride_in_bytes_;
// The packed mask (one 32-bit integer per MMA GROUP, MMAS_M * 2 rows, MMA_GROUPS_N * 16 cols).
uint32_t packed_mask_[MMAS_M][MMA_GROUPS_N];
// The mask after expansion.
bool mask_[MMAS_M * 2][MMAS_N * 4];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Mask<Volta_hmma_fp16_traits, Cta_tile, 2>
{
// The instruction traits.
using Traits = Volta_hmma_fp16_traits;
// The shape of the MMA tile.
using Mma_tile = typename Traits::Mma_tile<Cta_tile>;
// This implementation works only when WARPS_K is 1.
static_assert(Cta_tile::WARPS_K == 1, "");
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx)
: seqlen_(block_info.actual_seqlen)
{
// The decomposition of the thread index into warp/lane.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// The position of the warp.
int warp_n = warp / Cta_tile::WARPS_M;
// The position of the thread.
col_ = block_info.bidn * Cta_tile::N + warp_n * 16 + (lane & 0x08) / 2;
col_init_ = col_;
}
// Is a given position valid?
inline __device__ bool is_valid(int, int ni, int, int jj) const
{
// The position of the thread in the sequence.
int offset = this->col_ + ni * Mma_tile::N_PER_MMA_PER_CTA;
// The position inside the MMA.
offset += (jj & 0x04) * 2 + (jj & 0x03);
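// For jj in 0..7 this yields offsets {0..3, 8..11}: the thread owns four adjacent
// columns in each 8-column half of the Volta HMMA tile.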
// Is it a valid position in the sequence?
return offset < seqlen_;
}
// Load the mask... Nothing to do for real.
inline __device__ void load(int) {}
// Reset mask to the initial col
inline __device__ void reset()
{
col_ = col_init_;
}
// Move the mask to the next tile (flash attention)
inline __device__ void move()
{
this->col_ += Cta_tile::N;
}
// Move the mask col by an offset (flash attention)
inline __device__ void move_to_offset(int offset)
{
this->col_ = col_init_ + offset;
}
// The length of the sequence.
int const seqlen_;
// The left-most position of the thread in the sequence.
int col_, col_init_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Mask<Volta_hmma_fp16_traits, Cta_tile, 3> : public Mask<Volta_hmma_fp16_traits, Cta_tile, 2>
{
// V3 mask is the causal mask (e.g. for GPT) and extends V2 masks (self-attention).
using Base = Mask<Volta_hmma_fp16_traits, Cta_tile, 2>;
// The shape of the MMA tile.
using Mma_tile = typename Base::Mma_tile;
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx)
: Base(params, block_info, tidx)
, loop_step_(0)
{
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// The position of the warp.
int warp_m = warp % Cta_tile::WARPS_M;
row_ = warp_m * 16 + (lane & 0x07) + (lane & 0x10) / 2;
}
inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const
{
// The position of the thread in the sequence.
row = this->row_ + this->loop_step_ + mi * Mma_tile::M_PER_MMA_PER_CTA;
// The position of the thread in the sequence.
col = this->col_ + ni * Mma_tile::N_PER_MMA_PER_CTA;
// The position inside the MMA.
col += (jj & 0x04) * 2 + (jj & 0x03);
}
// Is a given position valid?
inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const
{
int row, col;
get_row_col(row, col, mi, ni, ii, jj);
// Is it a valid position in the sequence?
return is_valid(row, col);
}
// Is a given position valid?
inline __device__ bool is_valid(int row, int col) const
{
// Is it a valid position in the sequence, i.e. are we in the lower triangle?
return (row >= col) && (col < this->seqlen_);
}
// GPT Mask: if lower left is invalid, none are valid
inline __device__ bool any_valid(int mi, int ni) const
{
return is_valid(mi, ni, 0, 0);
}
// Load the mask... we use it to keep track of the row.
inline __device__ void load(int loop_step)
{
loop_step_ = loop_step;
}
// The upper-most position of the thread in the sequence.
int row_;
// Current iteration.
int loop_step_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, typename Cta_tile, int FMHA_VERSION, bool IS_MTP>
struct Mask_dispatcher
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, typename Cta_tile, int FMHA_VERSION>
struct Mask_dispatcher<Traits, Cta_tile, FMHA_VERSION, false> : public Mask<Traits, Cta_tile, FMHA_VERSION>
{
using Base = Mask<Traits, Cta_tile, FMHA_VERSION>;
template <typename Params, typename Block_info>
inline __device__ Mask_dispatcher(Params const& params, Block_info const& block_info, int tidx)
: Base(params, block_info, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, typename Cta_tile, int FMHA_VERSION>
struct Mask_dispatcher<Traits, Cta_tile, FMHA_VERSION, true> : public MtpMask<Traits, Cta_tile>
{
using Base = MtpMask<Traits, Cta_tile>;
template <typename Params, typename Block_info>
inline __device__ Mask_dispatcher(Params const& params, Block_info const& block_info, int tidx)
: Base(params, block_info, tidx)
{
}
};
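// Usage note (hypothetical instantiation): Mask_dispatcher<Traits, Cta_tile, FMHA_VERSION, false>
// behaves exactly like Mask<Traits, Cta_tile, FMHA_VERSION>, while IS_MTP = true ignores
// FMHA_VERSION and selects MtpMask for MLA multi-token prediction.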
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, typename Cta_tile, int FMHA_VERSION>
struct Mask_hopper
{
// The shape of the MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask_hopper(Params const& params, Block_info const& block_info, int tidx)
: seqlen_(block_info.actual_seqlen)
{
// For Hopper the warp distribution is always 4x1 within a warpgroup,
// so there may be assumptions/optimizations to be made here.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int warp_n = warp / 4;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
col_ = warp_n * Mma_tile::N_PER_WARP_GROUP + (lane % 4) * 2;
}
// Is a given position valid?
inline __device__ bool is_valid(int, int ni, int, int jj) const
{
// The position of the thread in the sequence.
int offset = this->col_ + ni * Mma_tile::N_PER_MMA;
// The position inside the MMA.
offset += (jj / 2) * 8 + (jj % 2);
// Is it a valid position in the sequence?
return offset < seqlen_;
}
// Load the mask... Nothing to do for real.
inline __device__ void load(int) {}
// The length of the sequence.
int const seqlen_;
// The left-most position of the thread in the sequence.
int col_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, typename Cta_tile>
struct Mask_hopper<Traits, Cta_tile, 3>
{
// The shape of the MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask_hopper(Params const& params, Block_info const& block_info, int tidx)
{
// For Hopper the warp distribution is always 4x1 within a warpgroup,
// so there may be assumptions/optimizations to be made here.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int warp_n = warp / 4;
int warp_m = warp % 4;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
col_ = warp_n * Mma_tile::N_PER_WARP_GROUP + (lane % 4) * 2;
row_base_ = warp_m * 16 + lane / 4;
row_ = row_base_;
}
inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const
{
// The row position of the thread in the sequence.
row = row_ + mi * Mma_tile::M_PER_MMA + ii * 8;
// The position of the thread in the sequence.
col = this->col_ + ni * Mma_tile::N_PER_MMA;
// The position inside the MMA.
col += (jj / 2) * 8 + (jj % 2);
}
// Is a given position valid?
inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const
{
int row, col;
get_row_col(row, col, mi, ni, ii, jj);
// Is it a valid position in the sequence?
return is_valid(row, col);
}
// Is a given position valid?
inline __device__ bool is_valid(int row, int col) const
{
// Is it a valid position in the sequence?
return col <= row;
}
// Load the mask... Nothing to do for real.
inline __device__ void load(int loop_step)
{
row_ = row_base_ + loop_step * Cta_tile::M;
}
// The row and left-most column positions of the thread in the sequence.
int row_, row_base_, col_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, typename Cta_tile>
struct Mask_hopper<Traits, Cta_tile, 4> : public Mask_hopper<Traits, Cta_tile, 3>
{
// V4 mask is the causal mask (e.g. for GPT) plus the sliding-window feature.
using Base = Mask_hopper<Traits, Cta_tile, 3>;
// The shape of the MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// Ctor.
template <typename Params, typename Block_info>
inline __device__ Mask_hopper(Params const& params, Block_info const& block_info, int tidx)
: Base(params, block_info, tidx)
, sliding_window_size_(params.sliding_window_size)
{
}
// Is a given position valid?
inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const
{
int row, col;
this->get_row_col(row, col, mi, ni, ii, jj);
// Is it a valid position in the sequence?
return is_valid(row, col);
}
// Is a given position valid?
inline __device__ bool is_valid(int row, int col) const
{
// Is it a valid position in the sequence?
return col <= row && col >= max(0, row - sliding_window_size_);
}
// The sliding window size for attention.
int sliding_window_size_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha