TensorRT-LLM/cpp/kernels/fmha_v2/src/fmha/traits.h

/*
* SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "fmha/numeric_types.h"
#include <fmha/utils.h>
#define FMHA_DIV_UP(m, n) (((m) + (n) -1) / (n))
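// Illustrative sanity check (not part of the original header): FMHA_DIV_UP computes
// the ceiling of m / n for positive integers.
static_assert(FMHA_DIV_UP(40, 8) == 5, "exact division");
static_assert(FMHA_DIV_UP(41, 8) == 6, "rounds up");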
namespace fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// Trait class for heuristically determining the tile sizes.
template <bool GRANULAR, int STEP, int S, int D, int DV, int K_PER_MMA>
struct Traits_tile_size;
template <int STEP, int S, int D, int DV, int K_PER_MMA>
struct Traits_tile_size</* GRANULAR = */ false, STEP, S, D, DV, K_PER_MMA>
{
enum
{
CTA_P_TILE_M = STEP,
CTA_P_TILE_N = S,
CTA_P_TILE_K = D,
CTA_O_TILE_M = CTA_P_TILE_M,
CTA_O_TILE_N = DV,
CTA_O_TILE_K = S
};
};
template <int STEP, int S, int D, int DV, int K_PER_MMA>
struct Traits_tile_size</* GRANULAR = */ true, STEP, S, D, DV, K_PER_MMA>
{
enum
{
CTA_P_TILE_M = STEP,
CTA_P_TILE_N = S,
// D < 32: CTA_P_TILE_K = 16
// D < 64: CTA_P_TILE_K = 32
// D >= 64: CTA_P_TILE_K = 64
CTA_P_TILE_K = D < 32 ? 16 : (D < 64 ? 32 : 64),
CTA_O_TILE_M = CTA_P_TILE_M,
// DV > 256: CTA_O_TILE_N = 256
// DV <= 256: CTA_O_TILE_N = DV
CTA_O_TILE_N = DV > 256 ? 256 : DV,
// DV > 256: CTA_O_TILE_K = 16
// 128 < DV <= 256: CTA_O_TILE_K = 32
// DV <= 128: CTA_O_TILE_K = 64
CTA_O_TILE_K = std::max(K_PER_MMA, DV > 256 ? 16 : (DV > 128 ? 32 : 64))
};
};
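// Illustrative sanity check (not part of the original header): with GRANULAR tiling,
// STEP = 64, S = 256 and D = DV = 128, the first GEMM splits its K dimension into
// 64-wide chunks while the second GEMM keeps the full head size as its N tile.
using Tile_size_example_ = Traits_tile_size</* GRANULAR = */ true, 64, 256, 128, 128, 16>;
static_assert(Tile_size_example_::CTA_P_TILE_K == 64, "D >= 64 clamps CTA_P_TILE_K to 64");
static_assert(Tile_size_example_::CTA_O_TILE_N == 128, "DV <= 256 keeps the full head size");
static_assert(Tile_size_example_::CTA_O_TILE_K == 64, "DV <= 128 selects CTA_O_TILE_K = 64");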
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The GPU architecture.
typename Gpu_arch,
// The number of rows in the CTA tile.
int M_,
// The number of cols in the CTA tile.
int N_,
// The number of elements in the K dimension of the GEMM loop.
int K_,
// The number of valid cols in the CTA tile.
int VALID_N_,
// The number of valid elements in the K dimension of the GEMM loop.
int VALID_K_,
// The number of rows of warps.
int WARPS_M_,
// The number of cols of warps.
int WARPS_N_,
// The number of warps in the K dimension of the GEMM loop.
int WARPS_K_>
struct Cta_tile_
{
enum
{
M = M_,
N = N_,
K = K_,
VALID_N = VALID_N_,
VALID_K = VALID_K_
};
// The number of warps.
enum
{
WARPS_M = WARPS_M_,
WARPS_N = WARPS_N_,
WARPS_K = WARPS_K_
};
// The number of warps per CTA.
enum
{
WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K
};
// The number of threads per warp.
enum
{
THREADS_PER_WARP = Gpu_arch::THREADS_PER_WARP
};
// The number of threads per CTA.
enum
{
THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The GPU architecture.
typename Gpu_arch_,
// The type of the elements of A.
typename A_type_,
// The type of the elements of B.
typename B_type_,
// The type of the elements of C.
typename C_type_,
// The type of the elements of the accumulators.
typename Accumulator_type_,
// The type of the elements of the epilogue.
typename Epilogue_type_>
struct Traits
{
// The architecture.
using Gpu_arch = Gpu_arch_;
// The data type for A elements.
using A_type = A_type_;
// The data type for B elements.
using B_type = B_type_;
// The data type for C elements.
using C_type = C_type_;
// The data type for accumulators.
using Accumulator_type = Accumulator_type_;
// The data type of the math in the epilogue.
using Epilogue_type = Epilogue_type_;
// Create the description of the CTA tile from a configuration.
template <int M, int N, int K, int VALID_N, int VALID_K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<Gpu_arch, M, N, K, VALID_N, VALID_K, WARPS_M, WARPS_N, WARPS_K>;
// The number of bits per element of A.
enum
{
BITS_PER_ELEMENT_A = sizeof(A_type) * 8
};
// An offset in bytes for A.
static inline __host__ __device__ int64_t offset_in_bytes_a(int64_t offset)
{
return offset * static_cast<int64_t>(sizeof(A_type));
}
// The number of bits per element of B.
enum
{
BITS_PER_ELEMENT_B = sizeof(B_type) * 8
};
// An offset in bytes for B.
static inline __host__ __device__ int64_t offset_in_bytes_b(int64_t offset)
{
return offset * static_cast<int64_t>(sizeof(B_type));
}
// The number of bits per element of C.
enum
{
BITS_PER_ELEMENT_C = sizeof(C_type) * 8
};
// An offset in bytes for C.
static inline __host__ __device__ int64_t offset_in_bytes_c(int64_t offset)
{
return offset * static_cast<int64_t>(sizeof(C_type));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Gpu_arch_base
{
// By default, architectures have 32 threads per warp.
enum
{
THREADS_PER_WARP = 32
};
// By default, architectures do not support LDGSTS.
enum
{
HAS_LDGSTS = 0
};
// By default, architectures do not support super HMMA.
enum
{
HAS_SUPER_HMMA = 0
};
// By default, architectures do not support TMA.
enum
{
HAS_TMA = 0
};
// By default, architectures do not support GMMA.
enum
{
HAS_GMMA = 0
};
};
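// Illustrative sanity check (not part of the original header): a 128x64x64 CTA tile
// with a 4x1x1 warp layout runs 4 warps of 32 threads.
using Cta_tile_example_ = Cta_tile_<Gpu_arch_base, 128, 64, 64, 64, 64, 4, 1, 1>;
static_assert(Cta_tile_example_::WARPS_PER_CTA == 4, "4 x 1 x 1 warps");
static_assert(Cta_tile_example_::THREADS_PER_CTA == 128, "4 warps x 32 threads per warp");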
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits_, typename Cta_tile_>
using Cta_tile_with_k_with_padding = typename Traits_::template Cta_tile_extd<Cta_tile_::M, Cta_tile_::N,
Next_power_of_two<Cta_tile_::K>::VALUE, Cta_tile_::N, Next_power_of_two<Cta_tile_::K>::VALUE, Cta_tile_::WARPS_M,
Cta_tile_::WARPS_N, Cta_tile_::WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Volta : public Gpu_arch_base
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, int N_PER_MMA_ = 16, int K_PER_MMA_ = 8>
struct Volta_mma_tile
{
// The number of elements computed with a single warp-MMA.
enum
{
M_PER_MMA = 16,
N_PER_MMA = N_PER_MMA_,
K_PER_MMA = K_PER_MMA_
};
// The number of elements computed with a single CTA-MMA.
enum
{
M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K
};
// The number of MMAs needed to compute the GEMM.
enum
{
MMAS_M = (Cta_tile::M + M_PER_MMA_PER_CTA - 1) / M_PER_MMA_PER_CTA,
MMAS_N = (Cta_tile::N + N_PER_MMA_PER_CTA - 1) / N_PER_MMA_PER_CTA,
MMAS_K = (Cta_tile::K + K_PER_MMA_PER_CTA - 1) / K_PER_MMA_PER_CTA
};
// The number of valid MMAs (for Head Size)
enum
{
// tile o
VALID_MMAS_N = Div_up<Cta_tile::VALID_N, N_PER_MMA_PER_CTA>::VALUE,
// tile p
VALID_MMAS_K = Div_up<Cta_tile::VALID_K, K_PER_MMA_PER_CTA>::VALUE,
};
// The number of elements computed per warp.
enum
{
M_PER_WARP = MMAS_M * M_PER_MMA,
N_PER_WARP = MMAS_N * N_PER_MMA,
K_PER_WARP = MMAS_K * K_PER_MMA,
};
// Do we enable the fast path for LDS?
enum
{
ENABLE_LDS_FAST_PATH = 0
};
};
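// Illustrative sanity check (not part of the original header): with a head size of 40
// padded to K = 64, only ceil(40 / 8) = 5 of the 8 K-dimension MMAs touch valid data.
using Volta_cta_example_ = Cta_tile_<Volta, 16, 128, 64, 128, 40, 1, 4, 1>;
using Volta_mma_example_ = Volta_mma_tile<Volta_cta_example_>;
static_assert(Volta_mma_example_::MMAS_K == 8, "64 / (8 * 1) MMAs along K");
static_assert(Volta_mma_example_::VALID_MMAS_K == 5, "ceil(40 / 8) MMAs cover the valid head size");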
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Volta_hmma_fp16_traits : public Traits<Volta, uint16_t, uint16_t, uint16_t, uint16_t, uint16_t>
{
// The K_PER_MMA for Volta_hmma_fp16_traits is 8.
enum
{
K_PER_MMA = 8
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Volta_mma_tile<Cta_tile, 16, K_PER_MMA>;
};
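// Illustrative sanity check (not part of the original header, and assuming
// Next_power_of_two from fmha/utils.h rounds up): Cta_tile_with_k_with_padding
// pads the K dimension of a K = 40 tile to the next power of two.
using Volta_tile_unpadded_ = Volta_hmma_fp16_traits::Cta_tile_extd<16, 128, 40, 128, 40, 1, 4, 1>;
using Volta_tile_padded_ = Cta_tile_with_k_with_padding<Volta_hmma_fp16_traits, Volta_tile_unpadded_>;
static_assert(Volta_tile_padded_::K == 64, "40 rounds up to 64");
static_assert(Volta_tile_padded_::VALID_K == 64, "the alias also pads VALID_K");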
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Volta_hmma_fp16_16x16x16_traits : public Traits<Volta, uint16_t, uint16_t, uint16_t, uint16_t, uint16_t>
{
// The K_PER_MMA for Volta_hmma_fp16_16x16x16_traits is 16.
enum
{
K_PER_MMA = 16
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Volta_mma_tile<Cta_tile, 16, K_PER_MMA>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Volta_imma_int8_int32_traits : public Traits<Volta, int8_t, int8_t, int8_t, int32_t, float>
{
// The K_PER_MMA for Volta_imma_int8_int32_traits is 16.
enum
{
K_PER_MMA = 16
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Volta_mma_tile<Cta_tile, 16, K_PER_MMA>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Turing : public Gpu_arch_base
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, int K_PER_MMA_>
struct Turing_mma_tile
{
// The number of elements computed with a single warp-MMA.
enum
{
M_PER_MMA = 16,
N_PER_MMA = 16,
K_PER_MMA = K_PER_MMA_
};
// The number of elements computed with a single CTA-MMA.
enum
{
M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K
};
// The number of MMAs needed to compute the GEMM.
enum
{
MMAS_M = Div_up<Cta_tile::M, M_PER_MMA_PER_CTA>::VALUE,
MMAS_N = Div_up<Cta_tile::N, N_PER_MMA_PER_CTA>::VALUE,
MMAS_K = Div_up<Cta_tile::K, K_PER_MMA_PER_CTA>::VALUE,
};
// The number of valid MMAs (for Head Size)
enum
{
// tile o
VALID_MMAS_N = Div_up<Cta_tile::VALID_N, N_PER_MMA_PER_CTA>::VALUE,
// tile p
VALID_MMAS_K = Div_up<Cta_tile::VALID_K, K_PER_MMA_PER_CTA>::VALUE,
};
// The number of elements computed per warp.
enum
{
M_PER_WARP = MMAS_M * M_PER_MMA,
N_PER_WARP = MMAS_N * N_PER_MMA,
K_PER_WARP = MMAS_K * K_PER_MMA,
};
// The distribution of threads in the output tile.
enum
{
THREADS_PER_MMA_M = 8,
THREADS_PER_MMA_N = 4,
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Turing_hmma_tile : public Turing_mma_tile<Cta_tile, 8>
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Turing_hmma_fp16_traits : public Traits<Turing, uint16_t, uint16_t, uint16_t, uint16_t, uint16_t>
{
// The K_PER_MMA for Turing_hmma_fp16_traits is 8.
enum
{
K_PER_MMA = 8
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Turing_hmma_tile<Cta_tile>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Turing_hmma_fp32_traits : public Traits<Turing, uint16_t, uint16_t, uint16_t, float, float>
{
// The K_PER_MMA for Turing_hmma_fp32_traits is 8.
enum
{
K_PER_MMA = 8
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Turing_hmma_tile<Cta_tile>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Turing_imma_int8_tile : public Turing_mma_tile<Cta_tile, 16>
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Turing_imma_int8_int32_traits : public Traits<Turing, int8_t, int8_t, int8_t, int32_t, float>
{
// The K_PER_MMA for Turing_imma_int8_int32_traits is 16.
enum
{
K_PER_MMA = 16
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Turing_imma_int8_tile<Cta_tile>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Ampere : public Gpu_arch_base
{
// It has LDGSTS.
enum
{
HAS_LDGSTS = 1
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, int K_PER_MMA = 16>
struct Ampere_hmma_tile : public Turing_mma_tile<Cta_tile, K_PER_MMA>
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Ampere_hmma_fp16_traits : public Traits<Ampere, uint16_t, uint16_t, uint16_t, uint16_t, uint16_t>
{
// The K_PER_MMA for Ampere_hmma_fp16_traits is 16.
enum
{
K_PER_MMA = 16
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Ampere_hmma_tile<Cta_tile>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Ampere_hmma_fp32_traits : public Traits<Ampere, uint16_t, uint16_t, uint16_t, float, uint16_t>
{
// The K_PER_MMA for Ampere_hmma_fp32_traits is 16.
enum
{
K_PER_MMA = 16
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Ampere_hmma_tile<Cta_tile>;
};
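// Illustrative sanity check (not part of the original header): Ampere HMMA with FP32
// accumulation uses K_PER_MMA = 16, so a 64x256x64 CTA tile with a 4x1x1 warp layout
// issues 4 MMAs along K.
using Ampere_cta_example_ = Ampere_hmma_fp32_traits::Cta_tile_extd<64, 256, 64, 256, 64, 4, 1, 1>;
using Ampere_mma_example_ = Ampere_hmma_fp32_traits::Mma_tile<Ampere_cta_example_>;
static_assert(Ampere_mma_example_::K_PER_MMA == 16, "K_PER_MMA = 16 on Ampere");
static_assert(Ampere_mma_example_::MMAS_K == 4, "64 / (16 * 1) MMAs along K");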
////////////////////////////////////////////////////////////////////////////////////////////////////
// Used for Epilogue_type = bf16_t (similar to Ampere_hmma_fp16_traits).
struct Ampere_hmma_bf16_bf16_traits : public Traits<Ampere, bf16_t, bf16_t, bf16_t, bf16_t, bf16_t>
{
// The K_PER_MMA for Ampere_hmma_bf16_bf16_traits is 16.
enum
{
K_PER_MMA = 16
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Ampere_hmma_tile<Cta_tile>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Ampere_hmma_bf16_traits : public Traits<Ampere, bf16_t, bf16_t, bf16_t, float, bf16_t>
{
// The K_PER_MMA for Ampere_hmma_bf16_traits is 16.
enum
{
K_PER_MMA = 16
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Ampere_hmma_tile<Cta_tile>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Ampere_imma_int8_tile : public Turing_mma_tile<Cta_tile, 32>
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Ampere_imma_int8_int32_traits : public Traits<Ampere, int8_t, int8_t, int8_t, int32_t, float>
{
// The K_PER_MMA for Ampere_imma_int8_int32_traits is 32.
enum
{
K_PER_MMA = 32
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Ampere_imma_int8_tile<Cta_tile>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Ada : public Gpu_arch_base
{
// It has LDGSTS.
enum
{
HAS_LDGSTS = 1
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// The following partial traits are mapped to their Ampere counterparts (e.g. Ampere_hmma_fp16_traits) in
// fmha/kernel_traits.h.
//
// This makes setup.py easier to implement.
struct Ada_hmma_fp16_traits
{
};
struct Ada_hmma_fp32_traits
{
};
struct Ada_imma_int8_int32_traits
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Ada_qmma_fp8_tile : public Turing_mma_tile<Cta_tile, 32>
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Ada_qmma_e4m3_fp16_traits : public Traits<Ada, e4m3_t, e4m3_t, e4m3_t, uint16_t, uint16_t>
{
// The K_PER_MMA for Ada_qmma_e4m3_fp16_traits is 32.
enum
{
K_PER_MMA = 32
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Ada_qmma_fp8_tile<Cta_tile>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Ada_qmma_e4m3_fp32_traits : public Traits<Ada, e4m3_t, e4m3_t, e4m3_t, float, float>
{
// The K_PER_MMA for Ada_qmma_e4m3_fp32_traits is 32.
enum
{
K_PER_MMA = 32
};
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Ada_qmma_fp8_tile<Cta_tile>;
static constexpr float SOFTMAX_FP_QUANT_SCALE = Softmax_fp_quant_scale<Traits::A_type>();
static constexpr float SOFTMAX_FP_DEQUANT_SCALE = 1.f / SOFTMAX_FP_QUANT_SCALE;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Hopper : public Gpu_arch_base
{
// It has LDGSTS.
enum
{
HAS_LDGSTS = 1
};
// It has TMA.
enum
{
HAS_TMA = 1
};
// It has GMMA.
enum
{
HAS_GMMA = 1
};
// On Hopper, there are 4 warps per warp group.
enum
{
WARPS_PER_WARP_GROUP = 4
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Hopper-related code.
// TODO: should this be moved to a separate file?
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int HEIGHT_ = 1, int WIDTH_ = 1, int DEPTH_ = 1>
struct Hopper_cga_tile
{
// The size of the CGA in terms of CTAs.
enum
{
CLUSTER_HEIGHT = HEIGHT_
};
enum
{
CLUSTER_WIDTH = WIDTH_
};
enum
{
CLUSTER_DEPTH = DEPTH_
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gpu_arch_,
int M_, // CTA tile M
int N_, // CTA tile N
int K_, // CTA tile K
int VALID_N_, // CTA tile valid N
int VALID_K_, // CTA tile valid K
int WARP_GROUP_M_, // Number of warp group along M dim
int WARP_GROUP_N_, // Number of warp group along N dim
int WARP_GROUP_K_> // Number of warp group along K dim
struct Hopper_cta_tile
{
// GPU arch.
using Gpu_arch = Gpu_arch_;
// The size of the CTA tile.
// TODO: support D (not power of 2)
enum
{
M = M_,
N = N_,
K = K_,
VALID_N = VALID_N_,
VALID_K = VALID_K_
};
// The number of warp groups.
enum
{
WARP_GROUP_M = WARP_GROUP_M_,
WARP_GROUP_N = WARP_GROUP_N_,
WARP_GROUP_K = WARP_GROUP_K_
};
// The number of warps in a warp group.
enum
{
WARPS_M_PER_GROUP = 4,
WARPS_N_PER_GROUP = 1,
WARPS_K_PER_GROUP = 1,
};
// The number of warps in a CTA.
enum
{
WARPS_M = WARPS_M_PER_GROUP * WARP_GROUP_M_,
WARPS_N = WARPS_N_PER_GROUP * WARP_GROUP_N_,
WARPS_K = WARPS_K_PER_GROUP * WARP_GROUP_K_
};
// The number of warps per CTA.
enum
{
WARPS_PER_CTA = WARP_GROUP_M * WARP_GROUP_N * WARP_GROUP_K * Gpu_arch::WARPS_PER_WARP_GROUP
};
// The number of warps per warp group.
enum
{
WARPS_PER_WARP_GROUP = Gpu_arch::WARPS_PER_WARP_GROUP
};
// The number of threads per warp.
enum
{
THREADS_PER_WARP = Gpu_arch::THREADS_PER_WARP
};
// The number of threads per warp group.
enum
{
THREADS_PER_WARP_GROUP = THREADS_PER_WARP * WARPS_PER_WARP_GROUP
};
// The number of threads per CTA.
enum
{
THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP
};
enum
{
GROUPS_M = 1
};
enum
{
GROUPS_N = 1
};
enum
{
GROUPS_K = 1
};
};
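// Illustrative sanity check (not part of the original header): a single warp group
// gives 4 warps, i.e. 128 threads, per CTA.
using Hopper_cta_example_ = Hopper_cta_tile<Hopper, 64, 256, 64, 256, 64, 1, 1, 1>;
static_assert(Hopper_cta_example_::WARPS_PER_CTA == 4, "1 warp group of 4 warps");
static_assert(Hopper_cta_example_::THREADS_PER_CTA == 128, "4 warps x 32 threads per warp");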
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, int GMMA_M, int GMMA_N, int GMMA_K>
struct Hopper_gmma_tile
{
// The number of elements computed with a single warp group mma.
enum
{
M_PER_MMA = GMMA_M,
N_PER_MMA = GMMA_N,
K_PER_MMA = GMMA_K
};
// The number of warp groups.
enum
{
NUM_WARP_GROUPS = Cta_tile::WARP_GROUP_M * Cta_tile::WARP_GROUP_N * Cta_tile::WARP_GROUP_K
};
// The number of elements computed with a single CTA-MMA.
enum
{
M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARP_GROUP_M,
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARP_GROUP_N,
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARP_GROUP_K
};
// The number of MMAs needed to compute the GEMM.
enum
{
MMAS_M = (Cta_tile::M + M_PER_MMA_PER_CTA - 1) / M_PER_MMA_PER_CTA,
MMAS_N = (Cta_tile::N + N_PER_MMA_PER_CTA - 1) / N_PER_MMA_PER_CTA,
MMAS_K = (Cta_tile::K + K_PER_MMA_PER_CTA - 1) / K_PER_MMA_PER_CTA,
};
// The number of valid MMAs (for Head Size)
enum
{
// tile o
VALID_MMAS_N = Div_up<Cta_tile::VALID_N, N_PER_MMA_PER_CTA>::VALUE,
// tile p
VALID_MMAS_K = Div_up<Cta_tile::VALID_K, K_PER_MMA_PER_CTA>::VALUE,
};
// The number of elements computed per warp group.
enum
{
M_PER_WARP_GROUP = MMAS_M * M_PER_MMA,
N_PER_WARP_GROUP = MMAS_N * N_PER_MMA,
K_PER_WARP_GROUP = MMAS_K * K_PER_MMA,
};
// The size of the GMMA group, which is GMMA_M x GMMA_N x K-block (K-block = Cta_tile::K).
enum
{
M_PER_GMMA_GROUP = GMMA_M,
N_PER_GMMA_GROUP = GMMA_N,
K_PER_GMMA_GROUP = Cta_tile::K,
};
// The distribution of threads in the output tile.
// TODO
enum
{
THREADS_PER_MMA_M = 8,
THREADS_PER_MMA_N = 4,
};
// The number of core matrices per GMMA.
enum
{
CORES_M_PER_GROUP = 8 * Cta_tile::WARPS_M_PER_GROUP,
CORES_N_PER_GROUP = 8 * Cta_tile::WARPS_N_PER_GROUP,
CORES_M = GMMA_M / CORES_M_PER_GROUP,
CORES_N = GMMA_N / CORES_N_PER_GROUP,
};
// The number of logical rows/cols per thread.
enum
{
// A thread owns 1 row per core matrix.
ROWS_PER_THREAD = CORES_M,
// A thread owns 2 cols per core matrix.
COLS_PER_THREAD = CORES_N * 2,
};
static_assert(ROWS_PER_THREAD == 2);
static_assert(COLS_PER_THREAD == GMMA_N / 4);
};
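// Illustrative sanity check (not part of the original header), reusing Hopper_cta_example_
// from above: a 64x256x16 GMMA over a 64x256x64 CTA tile needs 4 GMMAs along K, and each
// thread owns GMMA_N / 4 accumulator columns.
using Hopper_gmma_example_ = Hopper_gmma_tile<Hopper_cta_example_, 64, 256, 16>;
static_assert(Hopper_gmma_example_::MMAS_K == 4, "64 / 16 GMMAs along K");
static_assert(Hopper_gmma_example_::COLS_PER_THREAD == 64, "GMMA_N / 4 accumulator columns per thread");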
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class Hopper_instructions
{
HGMMA_FP16,
HGMMA_BF16,
HGMMA_FP32,
IGMMA_INT32,
QGMMA_E4M3_FP32,
QGMMA_E5M2_FP32,
QGMMA_E4M3_FP16,
QGMMA_E5M2_FP16
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Hopper HGMMA FP16 Traits
template <int GMMA_M_, // GMMA instruction shape in M dim
int GMMA_N_, // GMMA instruction shape in N dim
int GMMA_K_, // GMMA instruction shape in K dim
bool GMMA_A_RF_, // GMMA A operand coming from RF?
bool GMMA_B_RF_ // GMMA B operand coming from RF?
>
struct Hopper_hgmma_fp16_traits : public Traits<Hopper, uint16_t, uint16_t, uint16_t, uint16_t, uint16_t>
{
// The GMMA shape.
enum
{
GMMA_M = GMMA_M_,
GMMA_N = GMMA_N_,
GMMA_K = GMMA_K_
};
// is A operand in RF for GMMA?
static constexpr bool GMMA_A_RF = GMMA_A_RF_;
// is B operand in RF for GMMA?
static constexpr bool GMMA_B_RF = GMMA_B_RF_;
// GMMA shape has certain requirements.
static_assert(GMMA_K == 16, "GMMA K must be 16; this might change");
static_assert(GMMA_M == 64, "GMMA M must be 64; this might change");
static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change");
static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change");
// GMMA does not allow both operands coming from RF.
static_assert((GMMA_A_RF && GMMA_B_RF) != true, "GMMA does not allow both operands coming from RF.");
// The Cta tile.
template <int M, int N, int K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_tile = Hopper_cta_tile<Hopper, M, N, K, N, K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The padded Cta tile.
template <int M, int N, int K, int VALID_N, int VALID_K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_padded_tile = Hopper_cta_tile<Hopper, M, N, K, VALID_N, VALID_K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The CGA Tile
template <int HEIGHT = 1, int WIDTH = 1, int DEPTH = 1>
using Cga_tile = Hopper_cga_tile<HEIGHT, WIDTH, DEPTH>;
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Hopper_gmma_tile<Cta_tile, GMMA_M, GMMA_N, GMMA_K>;
// The handle to differentiate instructions.
static constexpr fmha::Hopper_instructions HOPPER_INSTRUCTION = fmha::Hopper_instructions::HGMMA_FP16;
};
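// Illustrative usage (not part of the original header): a 64x128x16 HGMMA with both
// operands sourced from SMEM, tiled over a 64x128x64 CTA.
using Hgmma_fp16_example_ = Hopper_hgmma_fp16_traits<64, 128, 16, false, false>;
using Hgmma_cta_example_ = Hgmma_fp16_example_::Cta_tile<64, 128, 64, 1, 1, 1>;
static_assert(Hgmma_fp16_example_::Mma_tile<Hgmma_cta_example_>::MMAS_N == 1, "128 / 128 GMMAs along N");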
////////////////////////////////////////////////////////////////////////////////////////////////////
// Hopper HGMMA FP32 Traits
template <int GMMA_M_, // GMMA instruction shape in M dim
int GMMA_N_, // GMMA instruction shape in N dim
int GMMA_K_, // GMMA instruction shape in K dim
bool GMMA_A_RF_, // GMMA A operand coming from RF?
bool GMMA_B_RF_ // GMMA B operand coming from RF?
>
struct Hopper_hgmma_fp32_traits : public Traits<Hopper, uint16_t, uint16_t, uint16_t, float, uint16_t>
{
// The GMMA shape.
enum
{
GMMA_M = GMMA_M_,
GMMA_N = GMMA_N_,
GMMA_K = GMMA_K_
};
// is A operand in RF for GMMA?
static constexpr bool GMMA_A_RF = GMMA_A_RF_;
// is B operand in RF for GMMA?
static constexpr bool GMMA_B_RF = GMMA_B_RF_;
// GMMA shape has certain requirements.
static_assert(GMMA_K == 16, "GMMA K must be 16; this might change");
static_assert(GMMA_M == 64, "GMMA M must be 64; this might change");
static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change");
static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change");
// GMMA does not allow both operands coming from RF.
static_assert((GMMA_A_RF && GMMA_B_RF) != true, "GMMA does not allow both operands coming from RF.");
// The Cta tile.
template <int M, int N, int K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_tile = Hopper_cta_tile<Hopper, M, N, K, N, K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The padded Cta tile.
template <int M, int N, int K, int VALID_N, int VALID_K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_padded_tile = Hopper_cta_tile<Hopper, M, N, K, VALID_N, VALID_K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The CGA Tile
template <int HEIGHT = 1, int WIDTH = 1, int DEPTH = 1>
using Cga_tile = Hopper_cga_tile<HEIGHT, WIDTH, DEPTH>;
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Hopper_gmma_tile<Cta_tile, GMMA_M, GMMA_N, GMMA_K>;
// The handle to differentiate instructions.
static constexpr fmha::Hopper_instructions HOPPER_INSTRUCTION = fmha::Hopper_instructions::HGMMA_FP32;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Hopper BF16 HGMMA Traits
template <int GMMA_M_, // GMMA instruction shape in M dim
int GMMA_N_, // GMMA instruction shape in N dim
int GMMA_K_, // GMMA instruction shape in K dim
bool GMMA_A_RF_, // GMMA A operand coming from RF?
bool GMMA_B_RF_ // GMMA B operand coming from RF?
>
struct Hopper_hgmma_bf16_traits : public Traits<Hopper, bf16_t, bf16_t, bf16_t, float, bf16_t>
{
// The GMMA shape.
enum
{
GMMA_M = GMMA_M_,
GMMA_N = GMMA_N_,
GMMA_K = GMMA_K_
};
// is A operand in RF for GMMA?
static constexpr bool GMMA_A_RF = GMMA_A_RF_;
// is B operand in RF for GMMA?
static constexpr bool GMMA_B_RF = GMMA_B_RF_;
// GMMA shape has certain requirements.
static_assert(GMMA_K == 16, "GMMA K must be 16; this might change");
static_assert(GMMA_M == 64, "GMMA M must be 64; this might change");
static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change");
static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change");
// GMMA does not allow both operands coming from RF.
static_assert((GMMA_A_RF && GMMA_B_RF) != true, "GMMA does not allow both operands coming from RF.");
// The Cta tile.
template <int M, int N, int K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_tile = Hopper_cta_tile<Hopper, M, N, K, N, K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The padded Cta tile.
template <int M, int N, int K, int VALID_N, int VALID_K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_padded_tile = Hopper_cta_tile<Hopper, M, N, K, VALID_N, VALID_K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The CGA Tile
template <int HEIGHT = 1, int WIDTH = 1, int DEPTH = 1>
using Cga_tile = Hopper_cga_tile<HEIGHT, WIDTH, DEPTH>;
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Hopper_gmma_tile<Cta_tile, GMMA_M, GMMA_N, GMMA_K>;
// The handle to differentiate instructions.
static constexpr fmha::Hopper_instructions HOPPER_INSTRUCTION = fmha::Hopper_instructions::HGMMA_BF16;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Hopper IGMMA Traits
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int GMMA_M_, // GMMA instruction shape in M dim
int GMMA_N_, // GMMA instruction shape in N dim
int GMMA_K_, // GMMA instruction shape in K dim
bool GMMA_A_RF_, // GMMA A operand coming from RF?
bool GMMA_B_RF_ // GMMA B operand coming from RF?
>
struct Hopper_igmma_int8_int32_traits : public Traits<Hopper, int8_t, int8_t, int8_t, int32_t, float>
{
using Base = Traits<Hopper, int8_t, int8_t, int8_t, int32_t, float>;
// The GMMA shape
enum
{
GMMA_M = GMMA_M_
};
enum
{
GMMA_N = GMMA_N_
};
enum
{
GMMA_K = GMMA_K_
};
// is A operand in RF for GMMA?
static constexpr bool GMMA_A_RF = GMMA_A_RF_;
// is B operand in RF for GMMA?
static constexpr bool GMMA_B_RF = GMMA_B_RF_;
// GMMA shape has certain requirements.
static_assert(GMMA_K == 32, "GMMA K must be 32; this might change");
static_assert(GMMA_M == 64, "GMMA M must be 64; this might change");
static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change");
static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change");
// GMMA does not allow both operands coming from RF.
static_assert((GMMA_A_RF && GMMA_B_RF) != true, "GMMA does not allow both operands coming from RF.");
// The Cta tile.
template <int M, int N, int K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_tile = Hopper_cta_tile<Hopper, M, N, K, N, K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The padded Cta tile.
template <int M, int N, int K, int VALID_N, int VALID_K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_padded_tile = Hopper_cta_tile<Hopper, M, N, K, VALID_N, VALID_K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The CGA Tile
template <int HEIGHT = 1, int WIDTH = 1, int DEPTH = 1>
using Cga_tile = Hopper_cga_tile<HEIGHT, WIDTH, DEPTH>;
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Hopper_gmma_tile<Cta_tile, GMMA_M, GMMA_N, GMMA_K>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Hopper QGMMA Traits
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int GMMA_M_, // GMMA instruction shape in M dim
int GMMA_N_, // GMMA instruction shape in N dim
int GMMA_K_, // GMMA instruction shape in K dim
bool GMMA_A_RF_, // GMMA A operand coming from RF?
bool GMMA_B_RF_, // GMMA B operand coming from RF?
typename Input_type_A_ = e4m3_t, typename Input_type_B_ = e4m3_t, typename Output_type_ = e4m3_t>
struct Hopper_qgmma_fp8_fp32_traits : public Traits<Hopper, Input_type_A_, Input_type_B_, Output_type_, float, float>
{
using Base = Traits<Hopper, Input_type_A_, Input_type_B_, Output_type_, float, float>;
using Input_type_A = Input_type_A_;
using Input_type_B = Input_type_B_;
using Output_type = Output_type_;
// The GMMA shape
enum
{
GMMA_M = GMMA_M_
};
enum
{
GMMA_N = GMMA_N_
};
enum
{
GMMA_K = GMMA_K_
};
// is A operand in RF for GMMA?
static constexpr bool GMMA_A_RF = GMMA_A_RF_;
// is B operand in RF for GMMA?
static constexpr bool GMMA_B_RF = GMMA_B_RF_;
// GMMA shape has certain requirements.
static_assert(GMMA_K == 32, "GMMA K must be 32; this might change");
static_assert(GMMA_M == 64, "GMMA M must be 64; this might change");
static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change");
static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change");
// GMMA does not allow both operands coming from RF.
static_assert((GMMA_A_RF && GMMA_B_RF) != true, "GMMA does not allow both operands coming from RF.");
// The Cta tile.
template <int M, int N, int K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_tile = Hopper_cta_tile<Hopper, M, N, K, N, K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The padded Cta tile.
template <int M, int N, int K, int VALID_N, int VALID_K, int Warpgroup_M, int Warpgroup_N, int Warpgroup_K>
using Cta_padded_tile = Hopper_cta_tile<Hopper, M, N, K, VALID_N, VALID_K, Warpgroup_M, Warpgroup_N, Warpgroup_K>;
// The CGA Tile
template <int HEIGHT = 1, int WIDTH = 1, int DEPTH = 1>
using Cga_tile = Hopper_cga_tile<HEIGHT, WIDTH, DEPTH>;
// The MMA tile.
template <typename Cta_tile>
using Mma_tile = Hopper_gmma_tile<Cta_tile, GMMA_M, GMMA_N, GMMA_K>;
// Used by low precision floating point types (e4m3, e5m2, etc.)
static constexpr float SOFTMAX_FP_QUANT_SCALE = Softmax_fp_quant_scale<Input_type_A_>();
static constexpr float SOFTMAX_FP_DEQUANT_SCALE = 1.f / SOFTMAX_FP_QUANT_SCALE;
};
template <int GMMA_M, // GMMA instruction shape in M dim
int GMMA_N, // GMMA instruction shape in N dim
int GMMA_K, // GMMA instruction shape in K dim
bool GMMA_A_RF, // GMMA A operand coming from RF?
bool GMMA_B_RF // GMMA B operand coming from RF?
>
using Hopper_qgmma_e4m3_fp32_traits
= Hopper_qgmma_fp8_fp32_traits<GMMA_M, GMMA_N, GMMA_K, GMMA_A_RF, GMMA_B_RF, e4m3_t, e4m3_t, e4m3_t>;
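// Illustrative usage (not part of the original header): the common FP8 path
// instantiates the e4m3 alias with the 64xNx32 GMMA shape.
using Qgmma_e4m3_example_ = Hopper_qgmma_e4m3_fp32_traits<64, 256, 32, false, true>;
static_assert(Qgmma_e4m3_example_::GMMA_K == 32, "FP8 GMMA uses K = 32");
static_assert(Qgmma_e4m3_example_::SOFTMAX_FP_DEQUANT_SCALE == 1.f / Qgmma_e4m3_example_::SOFTMAX_FP_QUANT_SCALE,
    "dequant scale is the reciprocal of the quant scale");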
} // namespace fmha