
Notes on the Masked Multihead Attention Kernel

Types

  • T: type of the q, k, v, kv_cache inputs and outputs: float, uint16_t, __nv_bfloat16, __nv_fp8_e4m3
  • Tk: compute type internal to the kernel, mapped from T as follows (sketched after this list):
    • float -> float
    • uint16_t -> uint16_t (i.e. half)
    • __nv_bfloat16 -> __nv_bfloat16
    • __nv_fp8_e4m3 -> float
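
As a rough sketch, this mapping can be expressed as a compile-time trait; the trait name below is illustrative and not necessarily the identifier used in the kernel:

    #include <cuda_fp8.h>

    // Illustrative T -> Tk mapping: every type computes in its own precision,
    // except FP8, which is converted to float inside the kernel.
    template <typename T>
    struct kernel_type_t
    {
        using Type = T; // float, uint16_t (half), __nv_bfloat16
    };

    template <>
    struct kernel_type_t<__nv_fp8_e4m3>
    {
        using Type = float; // FP8 storage, float compute
    };

    // Usage: using Tk = typename kernel_type_t<T>::Type;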

Constraints

  • THREADS_PER_BLOCK: in {64, 128, 256}
  • Dh: 32 <= Dh <= 256

Constants

  • Dh_MAX: round Dh up to the next power of 2
  • THREADS_PER_KEY: 256 / THREADS_PER_BLOCK in {1, 2, 4}
  • THREADS_PER_VALUE: Dh_MAX * sizeof(T) / 16, except for FP8 where sizeof(T) is assumed to be 4.

Note that THREADS_PER_KEY is currently computed by the simple heuristic above, which seems to work fine for the moment.
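
A minimal constexpr sketch of how these constants could be derived (names and the rounding helper are illustrative; as noted above, the element size is treated as 4 bytes for FP8):

    #include <type_traits>
    #include <cuda_fp8.h>

    // Illustrative compile-time helpers, not the kernel's actual code.
    constexpr int next_pow2(int v)
    {
        int p = 1;
        while (p < v)
            p <<= 1;
        return p;
    }

    template <typename T, int Dh, int THREADS_PER_BLOCK>
    struct mmha_constants
    {
        static constexpr int Dh_MAX = next_pow2(Dh);                    // 32 <= Dh <= 256, rounded up to a power of 2
        static constexpr int THREADS_PER_KEY = 256 / THREADS_PER_BLOCK; // {4, 2, 1} for {64, 128, 256} threads
        // FP8 elements are treated as 4 bytes because the compute type is float.
        static constexpr int ELT_SIZE = std::is_same<T, __nv_fp8_e4m3>::value ? 4 : (int) sizeof(T);
        static constexpr int THREADS_PER_VALUE = Dh_MAX * ELT_SIZE / 16;
    };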

Auxiliary vector types

  • Qk_vec_m: vector for Q/K elements with memory precision depending on T and Dh_MAX in (32, 64, 128, 256):
    • float: (float, float, float2, float4) with sizes (4, 4, 8, 16)
    • uint16_t: (uint32_t, uint32_t, uint2, uint4) with sizes (4, 4, 8, 16)
    • __nv_bfloat16: (__nv_bfloat162, __nv_bfloat162, bf16_4_t, bf16_8_t) with sizes (4, 4, 8, 16)
    • __nv_fp8_e4m3: (fp8_4_t, fp8_4_t, fp8_4_t, fp8_4_t) with sizes (4, 4, 4, 4)
  • Qk_vec_k: vector for Q/K elements with kernel precision depending on T and Dh_MAX in (32, 64, 128, 256):
    • __nv_fp8_e4m3: (float4, float4, float4, float4) with sizes (16, 16, 16, 16)
    • other types: same as Qk_vec_m

Associated constants are:

  • QK_VEC_SIZE: sizeof(Qk_vec_m) / sizeof(T) in {1, 2, 4} depending on T and Dh_MAX in (32, 64, 128, 256)
    • float, uint16_t, __nv_bfloat16 : (1, 1, 2, 4)
    • __nv_fp8_e4m3: (4, 4, 4, 4)
  • QK_VECS_PER_Dh_MAX: Dh_MAX / QK_VEC_SIZE in {8, 16, 32, 64} depending on T and Dh_MAX in (32, 64, 128, 256)
    • float, uint16_t, __nv_bfloat16: (32, 64, 64, 64)
    • __nv_fp8_e4m3: (8, 16, 32, 64)
  • QK_ELTS_IN_16B: 16 / sizeof(T) in {16, 8, 4}
  • QK_VECS_IN_16B: 16 / sizeof(Qk_vec_m) in {1, 2, 4} and <= QK_ELTS_IN_16B

Note that QK_ELTS_IN_16B / QK_VECS_IN_16B == QK_VEC_SIZE.
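
These relations follow directly from the definitions; a small self-contained sketch (the struct and its template parameters are illustrative):

    #include <cuda_runtime.h> // for float2 in the example instantiation

    // T is the storage type, QK_VEC_M the Qk_vec_m type resolved for a given (T, Dh_MAX).
    template <typename T, typename QK_VEC_M, int Dh_MAX>
    struct qk_constants
    {
        static constexpr int QK_VEC_SIZE        = sizeof(QK_VEC_M) / sizeof(T); // elements per Q/K vector
        static constexpr int QK_VECS_PER_Dh_MAX = Dh_MAX / QK_VEC_SIZE;         // vectors covering one head
        static constexpr int QK_ELTS_IN_16B     = 16 / sizeof(T);               // elements of T in a 16B chunk
        static constexpr int QK_VECS_IN_16B     = 16 / sizeof(QK_VEC_M);        // Q/K vectors in a 16B chunk
        // The identity stated in the note above.
        static_assert(QK_ELTS_IN_16B / QK_VECS_IN_16B == QK_VEC_SIZE, "QK_ELTS_IN_16B / QK_VECS_IN_16B == QK_VEC_SIZE");
    };

    // Example from the float row of the table above: Dh_MAX == 128 uses float2, so QK_VEC_SIZE == 2.
    static_assert(qk_constants<float, float2, 128>::QK_VEC_SIZE == 2, "");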

Similarly, we have:

  • k_vec_m: vector for K elements with memory precision depending on T and THREADS_PER_KEY in (1, 2, 4):
    • float: (float4, float2, float) with sizes (16, 8, 4)
    • uint16_t: (uint4, uint2, uint32_t) with sizes (16, 8, 4)
    • __nv_bfloat16: (bf16_8_t, bf16_4_t, __nv_bfloat162) with sizes (16, 8, 4)
    • __nv_fp8_e4m3: (fp8_4_t, fp8_4_t, fp8_4_t) with sizes (4, 4, 4)
  • k_vec_k: vector for K elements with kernel precision depending on T and THREADS_PER_KEY in (1, 2, 4):
    • __nv_fp8_e4m3: (float4, float4, float4) with sizes (16, 16, 16)
    • other types: same as k_vec_m

Associated constants are:

  • K_VEC_SIZE: sizeof(k_vec_m) / sizeof(T) in {1, 2, 4, 8} depending on T and THREADS_PER_KEY in (1, 2, 4)
    • float : (4, 2, 1)
    • uint16_t: (8, 4, 2)
    • __nv_bfloat16: (8, 4, 2)
    • __nv_fp8_e4m3: (4, 4, 4)
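
For context on how these vectors are used: the Q*K dot product for one timestep is split across a group of THREADS_PER_KEY threads, each covering Dh / THREADS_PER_KEY channels in chunks of K_VEC_SIZE elements, and the partial sums are then reduced across the group. A simplified scalar model of that split (the real kernel works on the vector types above; how chunks are assigned to threads within the group is an implementation detail, so the contiguous split here is only for illustration):

    // Simplified host-side model of the per-timestep Q.K split (illustrative only).
    // Each of the THREADS_PER_KEY threads in a group accumulates a partial dot product
    // over its share of the Dh channels; the kernel then reduces the partials across the group.
    float qk_partial(const float* q, const float* k, int Dh, int threads_per_key, int lane_in_group)
    {
        int elts_per_thread = Dh / threads_per_key; // covered with K_VEC_SIZE-wide vectors in the kernel
        float acc = 0.f;
        for (int i = 0; i < elts_per_thread; ++i)
        {
            int idx = lane_in_group * elts_per_thread + i; // contiguous split, for illustration only
            acc += q[idx] * k[idx];
        }
        return acc;
    }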

Memory Layout

Notation:

  • B: Batch size (number of sequences),
  • L: Sequence length,
  • D: Hidden dimension,
  • H: Number of heads,
  • Dh: Hidden dimension per head - Dh = D / H.

k_cache

The k_cache stores elements of type T.

Layout: [B, H, Dh/x, L, x] where x == QK_ELTS_IN_16B, i.e., x == 16 (FP8), x == 8 (FP16), x == 4 (FP32)

Each thread writes QK_VEC_SIZE elements.
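
A sketch of the element offset implied by this layout (variable names are illustrative; b = sequence, h = head, l = timestep, d = channel within the head):

    #include <cstddef>

    // Offset, in elements of T, of (b, h, l, d) in a k_cache laid out as [B, H, Dh/x, L, x].
    inline size_t k_cache_offset(int b, int h, int l, int d, int H, int L, int Dh, int x)
    {
        int d_chunk = d / x; // which group of x channels
        int d_in    = d % x; // position inside the 16-byte chunk
        return ((((size_t) b * H + h) * (Dh / x) + d_chunk) * L + l) * (size_t) x + d_in;
    }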

v_cache

The v_cache stores elements of type T.

Layout: [B, H, L, Dh]
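
Unlike the k_cache, this is plain row-major, so the Dh channels of one timestep are contiguous. The analogous offset sketch (same illustrative conventions as above):

    #include <cstddef>

    // Offset, in elements of T, of (b, h, l, d) in a v_cache laid out as [B, H, L, Dh].
    inline size_t v_cache_offset(int b, int h, int l, int d, int H, int L, int Dh)
    {
        return (((size_t) b * H + h) * L + l) * (size_t) Dh + d;
    }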

QKV buffer

The qkv buffer stores elements of type T.

Layout: [B, H, Dh]

Shared memory

Dynamic size of shared memory smem_: the maximum over several expressions, since the memory is reused in different contexts.

Notes on GEMMs in the context of the MMHA kernel

GEMM in DecoderSelfAttentionLayer.cc

            cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                  CUBLAS_OP_N,
                                  3 * local_hidden_units_,  // n
                                  batch_size,  // m
                                  d_model_,  // k
                                  attention_weights->query_weight.kernel,
                                  3 * local_hidden_units_,  // n
                                  attention_input,
                                  d_model_,  // k
                                  qkv_buf_,
                                  3 * local_hidden_units_ /* n */);
  • A: query, key, value weights with shape d_model_ x 3 * local_hidden_units_
  • B: attention input with shape batch_size x d_model_
  • C: output (qkv_buf_) with shape batch_size x 3 * local_hidden_units_ (see the reference loop below)
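
In row-major terms this computes the fused QKV projection qkv_buf_ = attention_input * W_qkv. A naive reference loop (illustrative, not the library code) makes the shapes explicit:

    // Reference for the fused QKV GEMM above:
    //   qkv_buf[b][j] = sum_i attention_input[b][i] * W_qkv[i][j]
    // with W_qkv stored row-major with shape [d_model_, 3 * local_hidden_units_].
    void qkv_projection_ref(const float* W_qkv, const float* attention_input, float* qkv_buf,
                            int batch_size, int d_model, int qkv_dim /* 3 * local_hidden_units_ */)
    {
        for (int b = 0; b < batch_size; ++b)
        {
            for (int j = 0; j < qkv_dim; ++j)
            {
                float acc = 0.f;
                for (int i = 0; i < d_model; ++i)
                    acc += attention_input[b * d_model + i] * W_qkv[i * qkv_dim + j];
                qkv_buf[b * qkv_dim + j] = acc;
            }
        }
    }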