TensorRT-LLMs/cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Fp8Bf16GemmPerBlockTemplate.cuh

/*
* Copyright (c) 2025-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Utils.cuh"
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <stdexcept>
TRTLLM_NAMESPACE_BEGIN
namespace kernels::llama4_min_latency::llama4_fp8_bf16_gemm
{
// The grid has (num_tokens / TILE_TOKEN) * (hidden_out / TILE_OUT) blocks.
// Each block computes TILE_TOKEN tokens x TILE_OUT output rows.
// Within a block, threads step through hidden_in in chunks of BLOCK_SIZE * VEC_SIZE elements.
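// Illustrative launch shape implied by the indexing below (the actual launch lives in the host-side
// dispatch code, not in this header):
//   dim3 grid(hidden_out / TILE_OUT, num_tokens / TILE_TOKEN);
//   dim3 block(BLOCK_SIZE);
// so blockIdx.x selects the output-row tile and blockIdx.y selects the token tile.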
template <int HIDDEN_IN, int TILE_TOKEN, int TILE_OUT, bool ALIGNED = true>
__launch_bounds__(BLOCK_SIZE) __global__ void llama4_fp8_bf16_gemm_per_block_kernel(
__nv_fp8_e4m3 const* __restrict__ A, // Input tensor [num_tokens][hidden_in]
__nv_fp8_e4m3 const* __restrict__ B, // Input tensor [hidden_out][hidden_in]
__nv_bfloat16* __restrict__ C, // Output tensor [num_tokens][hidden_out]
float* __restrict__ scaling_factor, // Scale read from device memory and applied to all outputs
int num_tokens, int hidden_in, int hidden_out)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
// Shared memory for block reduction
__shared__ float reduce_buffer[TILE_TOKEN][TILE_OUT][BLOCK_SIZE];
// Each thread accumulates its partial sum
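// The accumulator is float2 so each ffma2 call folds two FP8 products into the running sum;
// the two lanes are added together during the final reduction.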
float2 thread_sum[TILE_TOKEN][TILE_OUT];
#pragma unroll
for (int tile_token_idx = 0; tile_token_idx < TILE_TOKEN; tile_token_idx++)
{
#pragma unroll
for (int tile_out_idx = 0; tile_out_idx < TILE_OUT; tile_out_idx++)
{
thread_sum[tile_token_idx][tile_out_idx].x = 0.0f;
thread_sum[tile_token_idx][tile_out_idx].y = 0.0f;
}
}
int const token_idx = blockIdx.y;
int const row_idx = blockIdx.x;
int const tid = threadIdx.x;
int chunk = 0;
#if ENABLE_ACQBULK && ENABLE_PREFETCH
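// With programmatic dependent launch (PDL), this grid may start before the producer kernel has
// finished. Use that window to prefetch the B weight rows this block will read into L2; the data
// is only consumed after cudaGridDependencySynchronize() below.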
#pragma unroll 9
for (; chunk < ((HIDDEN_IN > 0 ? HIDDEN_IN : hidden_in) / BLOCK_SIZE / VEC_SIZE - 1); chunk++)
{
#pragma unroll
for (int tile_out_idx = 0; tile_out_idx < TILE_OUT; tile_out_idx++)
{
int current_row = row_idx * TILE_OUT + tile_out_idx;
int base_idx = chunk * BLOCK_SIZE + tid;
asm volatile("prefetch.global.L2 [%0];" ::"l"(&B[current_row * hidden_in / VEC_SIZE + base_idx])
: "memory");
}
}
{
#pragma unroll
for (int tile_out_idx = 0; tile_out_idx < TILE_OUT; tile_out_idx++)
{
int current_row = row_idx * TILE_OUT + tile_out_idx;
int base_idx = chunk * BLOCK_SIZE + tid;
if (ALIGNED || base_idx * VEC_SIZE < hidden_in)
{
asm volatile("prefetch.global.L2 [%0];" ::"l"(&B[current_row * hidden_in / VEC_SIZE + base_idx])
: "memory");
}
}
}
#endif
#if ENABLE_ACQBULK
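// Block until the preceding kernel's writes are visible (typically the activations in A) before
// issuing any dependent global loads.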
cudaGridDependencySynchronize();
#endif
// Main loop: each thread processes one aligned_fp8x8 vector (VEC_SIZE = 8 FP8 values) per
// iteration; only the final chunk can be partial.
chunk = 0;
#pragma unroll 9
for (; chunk < ((HIDDEN_IN > 0 ? HIDDEN_IN : hidden_in) / BLOCK_SIZE / VEC_SIZE - 1); chunk++)
{
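// base_idx indexes aligned_fp8x8 vectors: each thread loads one VEC_SIZE-wide chunk per row/token,
// and consecutive threads read consecutive chunks of the same row.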
int base_idx = chunk * BLOCK_SIZE + tid;
// Load values from tensor B
aligned_fp8x8 b_vec[TILE_OUT];
#pragma unroll
for (int tile_out_idx = 0; tile_out_idx < TILE_OUT; tile_out_idx++)
{
int current_row = row_idx * TILE_OUT + tile_out_idx;
b_vec[tile_out_idx]
= reinterpret_cast<aligned_fp8x8 const*>(B)[current_row * hidden_in / VEC_SIZE + base_idx];
}
// Load values from tensor A
aligned_fp8x8 a_vec[TILE_TOKEN];
#pragma unroll
for (int tile_token_idx = 0; tile_token_idx < TILE_TOKEN; tile_token_idx++)
{
int current_token = token_idx * TILE_TOKEN + tile_token_idx;
a_vec[tile_token_idx]
= reinterpret_cast<aligned_fp8x8 const*>(A)[current_token * hidden_in / VEC_SIZE + base_idx];
}
// Compute partial sum
#pragma unroll
for (int tile_out_idx = 0; tile_out_idx < TILE_OUT; tile_out_idx++)
{
#pragma unroll
for (int tile_token_idx = 0; tile_token_idx < TILE_TOKEN; tile_token_idx++)
{
aligned_fp8x8 a_vec_current = a_vec[tile_token_idx];
aligned_fp8x8 b_vec_current = b_vec[tile_out_idx];
#pragma unroll
for (int i = 0; i < VEC_SIZE / 4; i++)
{
float4 a_val = float4(a_vec_current.data[i]);
float4 b_val = float4(b_vec_current.data[i]);
thread_sum[tile_token_idx][tile_out_idx] = ffma2(make_float2(a_val.x, a_val.y),
make_float2(b_val.x, b_val.y), thread_sum[tile_token_idx][tile_out_idx]);
thread_sum[tile_token_idx][tile_out_idx] = ffma2(make_float2(a_val.z, a_val.w),
make_float2(b_val.z, b_val.w), thread_sum[tile_token_idx][tile_out_idx]);
}
}
}
}
// The last chunk may be partial.
{
int base_idx = chunk * BLOCK_SIZE + tid;
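// When ALIGNED is false (hidden_in not a multiple of BLOCK_SIZE * VEC_SIZE), threads whose vector
// starts at or past hidden_in skip the tail load. The guard works on whole vectors, so hidden_in
// is still expected to be a multiple of VEC_SIZE.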
if (ALIGNED || base_idx * VEC_SIZE < hidden_in)
{
// Load values from tensor B
aligned_fp8x8 b_vec[TILE_OUT];
#pragma unroll
for (int tile_out_idx = 0; tile_out_idx < TILE_OUT; tile_out_idx++)
{
int current_row = row_idx * TILE_OUT + tile_out_idx;
b_vec[tile_out_idx]
= reinterpret_cast<aligned_fp8x8 const*>(B)[current_row * hidden_in / VEC_SIZE + base_idx];
}
// Load values from tensor A
aligned_fp8x8 a_vec[TILE_TOKEN];
#pragma unroll
for (int tile_token_idx = 0; tile_token_idx < TILE_TOKEN; tile_token_idx++)
{
int current_token = token_idx * TILE_TOKEN + tile_token_idx;
a_vec[tile_token_idx]
= reinterpret_cast<aligned_fp8x8 const*>(A)[current_token * hidden_in / VEC_SIZE + base_idx];
}
// Compute partial sum
#pragma unroll
for (int tile_out_idx = 0; tile_out_idx < TILE_OUT; tile_out_idx++)
{
#pragma unroll
for (int tile_token_idx = 0; tile_token_idx < TILE_TOKEN; tile_token_idx++)
{
aligned_fp8x8 a_vec_current = a_vec[tile_token_idx];
aligned_fp8x8 b_vec_current = b_vec[tile_out_idx];
#pragma unroll
for (int i = 0; i < VEC_SIZE / 4; i++)
{
float4 a_val = float4(a_vec_current.data[i]);
float4 b_val = float4(b_vec_current.data[i]);
thread_sum[tile_token_idx][tile_out_idx] = ffma2(make_float2(a_val.x, a_val.y),
make_float2(b_val.x, b_val.y), thread_sum[tile_token_idx][tile_out_idx]);
thread_sum[tile_token_idx][tile_out_idx] = ffma2(make_float2(a_val.z, a_val.w),
make_float2(b_val.z, b_val.w), thread_sum[tile_token_idx][tile_out_idx]);
}
}
}
}
}
// Reduce partial sums: first within each warp via shuffles, then across warps through shared memory.
#pragma unroll
for (int tile_out_idx = 0; tile_out_idx < TILE_OUT; tile_out_idx++)
{
#pragma unroll
for (int tile_token_idx = 0; tile_token_idx < TILE_TOKEN; tile_token_idx++)
{
float warp_sum = thread_sum[tile_token_idx][tile_out_idx].x + thread_sum[tile_token_idx][tile_out_idx].y;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1)
{
warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
}
// First thread in each warp writes to shared memory
if (tid % WARP_SIZE == 0)
{
reduce_buffer[tile_token_idx][tile_out_idx][tid / WARP_SIZE] = warp_sum;
}
}
}
__syncthreads();
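// Thread 0 completes the reduction across warps and writes the BF16 results, applying the single
// scale factor read from device memory.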
if (tid == 0)
{
#pragma unroll
for (int tile_token_idx = 0; tile_token_idx < TILE_TOKEN; tile_token_idx++)
{
#pragma unroll
for (int tile_out_idx = 0; tile_out_idx < TILE_OUT; tile_out_idx++)
{
float block_sum = 0.0f;
#pragma unroll
for (int i = 0; i < BLOCK_SIZE / WARP_SIZE; i++)
{
block_sum += reduce_buffer[tile_token_idx][tile_out_idx][i];
}
int current_row = row_idx * TILE_OUT + tile_out_idx;
int current_token = token_idx * TILE_TOKEN + tile_token_idx;
C[current_token * hidden_out + current_row] = __float2bfloat16(block_sum * scaling_factor[0]);
}
}
}
#if ENABLE_PREEXIT
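// All global writes have been issued, so signal programmatic completion early and let a dependent
// kernel begin launching without waiting for this grid to fully retire.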
cudaTriggerProgrammaticLaunchCompletion();
#endif
#endif
}
#define DISPATCH_PER_BLOCK_FC_FP8_BF16_TILE_TOKEN(HIDDEN_IN, TILE_TOKEN, TILE_OUT, ALIGNED) \
do \
{ \
if (TILE_TOKEN == 1) \
{ \
return reinterpret_cast<void*>(llama4_fp8_bf16_gemm_per_block_kernel<HIDDEN_IN, 1, TILE_OUT, ALIGNED>); \
} \
if (TILE_TOKEN == 2) \
{ \
return reinterpret_cast<void*>(llama4_fp8_bf16_gemm_per_block_kernel<HIDDEN_IN, 2, TILE_OUT, ALIGNED>); \
} \
if (TILE_TOKEN == 3) \
{ \
return reinterpret_cast<void*>(llama4_fp8_bf16_gemm_per_block_kernel<HIDDEN_IN, 3, TILE_OUT, ALIGNED>); \
} \
if (TILE_TOKEN == 4) \
{ \
return reinterpret_cast<void*>(llama4_fp8_bf16_gemm_per_block_kernel<HIDDEN_IN, 4, TILE_OUT, ALIGNED>); \
} \
if (TILE_TOKEN == 8) \
{ \
return reinterpret_cast<void*>(llama4_fp8_bf16_gemm_per_block_kernel<HIDDEN_IN, 8, TILE_OUT, ALIGNED>); \
} \
throw std::invalid_argument("Invalid tile token"); \
} while (0)
#define DISPATCH_PER_BLOCK_FC_FP8_BF16_TILE_OUT(HIDDEN_IN, TILE_TOKEN, TILE_OUT, ALIGNED) \
do \
{ \
if (TILE_OUT == 1) \
{ \
DISPATCH_PER_BLOCK_FC_FP8_BF16_TILE_TOKEN(HIDDEN_IN, TILE_TOKEN, 1, ALIGNED); \
} \
if (TILE_OUT == 2) \
{ \
DISPATCH_PER_BLOCK_FC_FP8_BF16_TILE_TOKEN(HIDDEN_IN, TILE_TOKEN, 2, ALIGNED); \
} \
if (TILE_OUT == 3) \
{ \
DISPATCH_PER_BLOCK_FC_FP8_BF16_TILE_TOKEN(HIDDEN_IN, TILE_TOKEN, 3, ALIGNED); \
} \
if (TILE_OUT == 4) \
{ \
DISPATCH_PER_BLOCK_FC_FP8_BF16_TILE_TOKEN(HIDDEN_IN, TILE_TOKEN, 4, ALIGNED); \
} \
if (TILE_OUT == 8) \
{ \
DISPATCH_PER_BLOCK_FC_FP8_BF16_TILE_TOKEN(HIDDEN_IN, TILE_TOKEN, 8, ALIGNED); \
} \
throw std::invalid_argument("Invalid tile out"); \
} while (0)
#define DEFINE_GET_PER_BLOCK_FUNC_PTR(HIDDEN_IN, ALIGNED) \
void* get_per_block_func_ptr_aligned_##ALIGNED##_##HIDDEN_IN##_(int tile_token, int tile_out) \
{ \
DISPATCH_PER_BLOCK_FC_FP8_BF16_TILE_OUT(HIDDEN_IN, tile_token, tile_out, ALIGNED); \
}
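// Illustrative use (the concrete HIDDEN_IN values are chosen by the .cu translation units that
// include this header, so the 5120 below is only an example):
//   DEFINE_GET_PER_BLOCK_FUNC_PTR(5120, true)
// expands to get_per_block_func_ptr_aligned_true_5120_(int tile_token, int tile_out), which returns
// a void* to the matching kernel instantiation for the caller to launch.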
} // namespace kernels::llama4_min_latency::llama4_fp8_bf16_gemm
TRTLLM_NAMESPACE_END