TensorRT-LLMs/cpp/include/tensorrt_llm/deep_gemm/scheduler.cuh
yunruis 30c5b4183a
refactoring: port customized kernels with public cutlass version (#5027)
Signed-off-by: yunruis 

Merge this to unblock others since the full CI has been run through
2025-06-13 16:19:31 +08:00

752 lines
26 KiB
Plaintext

/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#include <cstdint>
#endif
#include "utils.cuh"
namespace deep_gemm
{
// Identifies the tile-scheduling strategy. Every scheduler struct in this file publishes its kind as
// `gemm_type`, and SchedulerSelector / SchedulerSelectorSwapAB pick a scheduler from this value.
enum class GemmType
{
// Reported by NormalScheduler and NormalSchedulerSwapAB.
Normal,
// Reported by GroupedContiguousScheduler.
GroupedContiguous,
// Reported by GroupedMaskedScheduler.
GroupedMasked,
// Reported by GroupedWithOffsetScheduler and GroupedWithOffsetSchedulerSwapAB.
GroupedWithOffset,
// Reported by StridedBatchedScheduler and StridedBatchedSchedulerSwapAB.
StridedBatched
};
#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
/// Converts a linear tile index into an (m, n) block coordinate.
///
/// The N blocks are partitioned into groups of `kNumNBlocksPerGroup` columns (the last group may be
/// narrower). A group holds `num_m_blocks * kNumNBlocksPerGroup` linear indices; inside a group the
/// index advances along N first and then along M. The original annotation states the purpose as
/// "swizzle for better L2 usage".
///
/// \param num_m_blocks  number of blocks along M; must be non-zero because it is used as a divisor
/// \param block_idx     linear tile index (converted to unsigned for the arithmetic below)
/// \param m_block_idx   [out] block coordinate along M
/// \param n_block_idx   [out] block coordinate along N
template <uint32_t kNumTMAMulticast, uint32_t kNumNBlocks, uint32_t kNumNBlocksPerGroup>
__device__ __forceinline__ void get_swizzled_block_idx(
    const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx)
{
    DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

    uint32_t const linear_idx = static_cast<uint32_t>(block_idx);
    uint32_t const tiles_per_group = num_m_blocks * kNumNBlocksPerGroup;

    // First N block covered by the group this tile falls into.
    uint32_t const group_first_n = (linear_idx / tiles_per_group) * kNumNBlocksPerGroup;
    // The trailing group may contain fewer than kNumNBlocksPerGroup columns.
    uint32_t const group_width = min(kNumNBlocksPerGroup, kNumNBlocks - group_first_n);
    uint32_t const idx_in_group = linear_idx % tiles_per_group;

    m_block_idx = idx_in_group / group_width;
    n_block_idx = group_first_n + idx_in_group % group_width;
}
// Input consumed by NormalScheduler (it is NormalScheduler::Input).
struct NormalSchedulerInput
{
// M extent of the problem; NormalScheduler ceil-divides it by BLOCK_M to count the M blocks.
uint32_t shape_m;
int* grouped_layout; // not read by NormalScheduler
};
// Input consumed by NormalSchedulerSwapAB (it is NormalSchedulerSwapAB::Input).
struct NormalSchedulerInputSwapAB
{
// N extent of the problem; NormalSchedulerSwapAB ceil-divides it by BLOCK_N to count the N blocks.
uint32_t shape_n;
int* grouped_layout; // not read by NormalSchedulerSwapAB
};
// Input consumed by GroupedContiguousScheduler (it is GroupedContiguousScheduler::Input).
struct GroupedContiguousSchedulerInput
{
// Total M extent; ceil-divided by BLOCK_M to count the M blocks.
uint32_t shape_m;
// Read on the device with __ldg at index `m_block_idx * BLOCK_M`; the loaded value is multiplied by
// the N-side extent to offset into B and its scales.
// NOTE(review): presumably one group id per row of A - confirm against the caller.
int* grouped_layout;
};
// Input consumed by GroupedMaskedScheduler (it is GroupedMaskedScheduler::Input).
struct GroupedMaskedSchedulerInput
{
// Row distance between consecutive groups: the scheduler computes a global M index as
// `curr_group_idx * shape_m + block_idx * BLOCK_M`.
uint32_t shape_m;
// Read on the device with __ldg at index `curr_group_idx`; the value is ceil-divided by BLOCK_M to
// obtain the number of M blocks scheduled for that group.
int* grouped_layout;
};
// Input consumed by GroupedWithOffsetScheduler (it is GroupedWithOffsetScheduler::Input).
struct GroupedWithOffsetSchedulerInput
{
// Not read by GroupedWithOffsetScheduler in this file.
uint32_t shape_m;
// Row offsets per group. For group g the scheduler loads entries [g] and [g + 1] with __ldg and
// treats their difference as the group's M extent, so entries 0..kNumGroups may be accessed.
int64_t* problem_m_offsets;
};
// Input consumed by GroupedWithOffsetSchedulerSwapAB (it is GroupedWithOffsetSchedulerSwapAB::Input).
struct GroupedWithOffsetSchedulerInputSwapAB
{
// Not read by GroupedWithOffsetSchedulerSwapAB in this file.
uint32_t shape_m;
// Offsets per group along N. For group g the scheduler loads entries [g] and [g + 1] with __ldg and
// treats their difference as the group's N extent, so entries 0..kNumGroups may be accessed.
int64_t* problem_n_offsets;
};
// Input consumed by StridedBatchedScheduler (it is StridedBatchedScheduler::Input).
struct StridedBatchedSchedulerInput
{
// M extent of every batch entry; ceil-divided by BLOCK_M to count the M blocks per batch.
uint32_t shape_m;
// The scheduler offsets the global M index by `stride_a / ld_a * curr_group_idx`
// (it assumes stride_a % ld_a == 0 and stride_a >= ld_a).
uint64_t ld_a;
uint64_t stride_a;
// The scheduler offsets the global N index by `stride_b / ld_b * curr_group_idx`
// (it assumes stride_b % ld_b == 0 and stride_b >= ld_b).
uint64_t ld_b;
uint64_t stride_b;
// ld_d / stride_d are not read by the schedulers in this file.
uint64_t ld_d;
uint64_t stride_d;
};
// Input consumed by StridedBatchedSchedulerSwapAB (it is StridedBatchedSchedulerSwapAB::Input).
struct StridedBatchedSchedulerInputSwapAB
{
// N extent of every batch entry; ceil-divided by BLOCK_N to count the N blocks per batch.
uint32_t shape_n;
// The scheduler offsets the global M index by `stride_a / ld_a * curr_group_idx`
// (it assumes stride_a % ld_a == 0 and stride_a >= ld_a).
uint64_t ld_a;
uint64_t stride_a;
// The scheduler offsets the global N index by `stride_b / ld_b * curr_group_idx`
// (it assumes stride_b % ld_b == 0 and stride_b >= ld_b).
uint64_t ld_b;
uint64_t stride_b;
// ld_d / stride_d are not read by the schedulers in this file.
uint64_t ld_d;
uint64_t stride_d;
};
/// Tile scheduler for a single, non-grouped GEMM.
///
/// On every get_next_block() call the CUDA block takes the linear tile index
/// `current_iter * gridDim.x + blockIdx.x` and maps it to an (m, n) block coordinate with
/// get_swizzled_block_idx. `SHAPE_N`, `BLOCK_N` and `kNumGroups` only feed the default template
/// arguments or are unused in the body.
template <uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
    uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N), uint32_t kNumNBlocksPerGroup = 16>
struct NormalScheduler
{
    static constexpr GemmType gemm_type = GemmType::Normal;

    int current_iter = -1;         // becomes 0 on the first get_next_block() call
    uint32_t num_aligned_m_blocks; // ceil(shape_m / BLOCK_M)
    uint32_t num_blocks;           // total number of tiles: num_aligned_m_blocks * kNumNBlocks

    using Input = NormalSchedulerInput;
    Input input; // never assigned by this scheduler

    NormalScheduler() {}

    /// Derives the tile counts from `input.shape_m`; the `input` member itself is left untouched.
    __device__ __forceinline__ NormalScheduler(Input& input)
        : num_aligned_m_blocks(ceil_div(input.shape_m, BLOCK_M))
        , num_blocks(num_aligned_m_blocks * kNumNBlocks)
    {
    }

    /// Global M index of the first row of M block `block_idx`.
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return BLOCK_M * block_idx;
    }

    /// Global N index of block `block_idx`; `shape_dim` and `m_block_idx` are ignored here.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return block_size * block_idx;
    }

    /// Scale index for A: the block index is returned unchanged.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return block_idx;
    }

    /// Scale index for B; `shape_dim` and `m_block_idx` are ignored here.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return block_size * block_idx;
    }

    /// Advances to this CUDA block's next tile.
    /// \return false (outputs untouched) once the linear index reaches `num_blocks`, true otherwise.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        current_iter += 1;
        auto const linear_idx = current_iter * gridDim.x + blockIdx.x;
        bool const has_tile = linear_idx < num_blocks;
        if (has_tile)
        {
            get_swizzled_block_idx<kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>(
                num_aligned_m_blocks, linear_idx, m_block_idx, n_block_idx);
        }
        return has_tile;
    }
};
/// Swap-AB counterpart of NormalScheduler: the compile-time extent is M and the runtime extent is N.
///
/// On every get_next_block() call the CUDA block takes the linear tile index
/// `current_iter * gridDim.x + blockIdx.x` and maps it through get_swizzled_block_idx with the roles
/// of the m and n outputs exchanged.
template <uint32_t SHAPE_M, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
    uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M), uint32_t kNumMBlocksPerGroup = 16>
struct NormalSchedulerSwapAB
{
    static constexpr GemmType gemm_type = GemmType::Normal;

    int current_iter = -1;         // becomes 0 on the first get_next_block() call
    uint32_t num_aligned_n_blocks; // ceil(shape_n / BLOCK_N)
    uint32_t num_blocks;           // total number of tiles: num_aligned_n_blocks * kNumMBlocks

    using Input = NormalSchedulerInputSwapAB;
    Input input; // never assigned by this scheduler

    NormalSchedulerSwapAB() {}

    /// Derives the tile counts from `input.shape_n`; the `input` member itself is left untouched.
    __device__ __forceinline__ NormalSchedulerSwapAB(Input& input)
        : num_aligned_n_blocks(ceil_div(input.shape_n, BLOCK_N))
        , num_blocks(num_aligned_n_blocks * kNumMBlocks)
    {
    }

    /// Weight side: global M index of block `block_idx`; `shape_dim` and `n_block_idx` are ignored.
    __device__ __forceinline__ uint32_t get_global_m_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        return block_size * block_idx;
    }

    /// Activation side: global N index of the first column of N block `block_idx`.
    __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
    {
        return BLOCK_N * block_idx;
    }

    /// Activation scales: the block index is returned unchanged.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
    {
        return block_idx;
    }

    /// Weight scales; `shape_dim` and `n_block_idx` are ignored.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        return block_size * block_idx;
    }

    /// Advances to this CUDA block's next tile.
    /// \return false (outputs untouched) once the linear index reaches `num_blocks`, true otherwise.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        current_iter += 1;
        auto const linear_idx = current_iter * gridDim.x + blockIdx.x;
        bool const has_tile = linear_idx < num_blocks;
        if (has_tile)
        {
            // m and n trade places: the swizzle walks N blocks where the regular scheduler walks M.
            get_swizzled_block_idx<kNumTMAMulticast, kNumMBlocks, kNumMBlocksPerGroup>(
                num_aligned_n_blocks, linear_idx, n_block_idx, m_block_idx);
        }
        return has_tile;
    }
};
/// Tile scheduler for grouped GEMMs whose A rows are laid out contiguously.
///
/// Tiles are enumerated exactly as in the non-grouped case; the group only influences the N-side
/// indices, which are offset by the value stored in `grouped_layout` at the first row of the M block.
template <uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
    uint32_t kNumNBlocks, uint32_t kNumNBlocksPerGroup>
struct GroupedContiguousScheduler
{
    static constexpr GemmType gemm_type = GemmType::GroupedContiguous;

    int current_iter = -1;         // becomes 0 on the first get_next_block() call
    uint32_t num_aligned_m_blocks; // ceil(shape_m / BLOCK_M)
    int* grouped_layout;           // copied from the input
    uint32_t num_blocks;           // total number of tiles: num_aligned_m_blocks * kNumNBlocks
    uint32_t shape_m;              // copied from the input

    using Input = GroupedContiguousSchedulerInput;
    Input input; // never assigned by this scheduler

    GroupedContiguousScheduler() {}

    /// Copies the fields needed on the device and derives the tile counts from `input.shape_m`.
    __device__ __forceinline__ GroupedContiguousScheduler(Input& input)
        : num_aligned_m_blocks(ceil_div(input.shape_m, BLOCK_M))
        , grouped_layout(input.grouped_layout)
        , num_blocks(num_aligned_m_blocks * kNumNBlocks)
        , shape_m(input.shape_m)
    {
    }

    /// Global M index of the first row of M block `block_idx`.
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return BLOCK_M * block_idx;
    }

    /// Global N index: the layout value at the first row of `m_block_idx`, scaled by `shape_dim`,
    /// plus the in-group block offset.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        auto const layout_value = __ldg(grouped_layout + m_block_idx * BLOCK_M);
        return layout_value * shape_dim + block_idx * block_size;
    }

    /// Scale index for A: the block index is returned unchanged.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return block_idx;
    }

    /// Scale index for B, computed with the same formula as get_global_n_idx.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        auto const layout_value = __ldg(grouped_layout + m_block_idx * BLOCK_M);
        return layout_value * shape_dim + block_idx * block_size;
    }

    /// Advances to this CUDA block's next tile.
    /// \return false (outputs untouched) once the linear index reaches `num_blocks`, true otherwise.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        current_iter += 1;
        auto const linear_idx = current_iter * gridDim.x + blockIdx.x;
        bool const has_tile = linear_idx < num_blocks;
        if (has_tile)
        {
            get_swizzled_block_idx<kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>(
                num_aligned_m_blocks, linear_idx, m_block_idx, n_block_idx);
        }
        return has_tile;
    }
};
/// Tile scheduler for grouped GEMMs where every group has its own row count.
///
/// `grouped_layout[g]` is read as the row count of group g. The scheduler keeps a running group
/// cursor (`curr_group_idx`) and the number of M blocks in all groups before it (`curr_cumsum`), and
/// advances the cursor until the requested linear tile index falls inside the current group.
template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t kNumGroups,
    uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N), uint32_t kNumNBlocksPerGroup = 16>
struct GroupedMaskedScheduler
{
    static constexpr GemmType gemm_type = GemmType::GroupedMasked;

    int current_iter = -1;         // becomes 0 on the first get_next_block() call
    uint32_t num_blocks;           // num_aligned_m_blocks * kNumNBlocks
    uint32_t num_aligned_m_blocks; // ceil(shape_m / BLOCK_M)
    uint32_t curr_group_idx;       // group the cursor currently points at
    uint32_t curr_cumsum;          // M blocks contained in all groups before curr_group_idx
    uint32_t shape_m;              // copied from the input
    int* grouped_layout;           // copied from the input

    using Input = GroupedMaskedSchedulerInput;
    Input input; // never assigned by this scheduler

    GroupedMaskedScheduler() {}

    /// Copies the fields needed on the device and resets the group cursor.
    __device__ __forceinline__ GroupedMaskedScheduler(Input& input)
    {
        auto const m_blocks = ceil_div(input.shape_m, BLOCK_M);
        num_aligned_m_blocks = m_blocks;
        num_blocks = num_aligned_m_blocks * kNumNBlocks;
        shape_m = input.shape_m;
        grouped_layout = input.grouped_layout;
        curr_cumsum = 0;
        curr_group_idx = 0;
    }

    /// Global M index: groups are `shape_m` rows apart.
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        uint32_t const group_base = curr_group_idx * shape_m;
        return group_base + block_idx * BLOCK_M;
    }

    /// Global N index: groups are `shape_dim` apart; `m_block_idx` is ignored.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        uint32_t const group_base = curr_group_idx * shape_dim;
        return group_base + block_idx * block_size;
    }

    /// Scale index for A: each group occupies ceil(SHAPE_K / BLOCK_K) entries.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        constexpr auto kNumKBlocks = ceil_div(SHAPE_K, BLOCK_K);
        return curr_group_idx * kNumKBlocks + block_idx;
    }

    /// Scale index for B, computed with the same formula as get_global_n_idx.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        uint32_t const group_base = curr_group_idx * shape_dim;
        return group_base + block_idx * block_size;
    }

    /// Advances to this CUDA block's next tile, moving the group cursor forward as needed.
    /// \return false once the cursor has passed the last group, true otherwise.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        current_iter += 1;
        auto const linear_idx = current_iter * gridDim.x + blockIdx.x;

        for (; curr_group_idx != kNumGroups; ++curr_group_idx)
        {
            uint32_t const group_m_blocks
                = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
            uint32_t const cumsum_after_group = curr_cumsum + group_m_blocks;

            if (linear_idx < cumsum_after_group * kNumNBlocks)
            {
                // The tile lies in the current group: swizzle its index relative to the group start.
                get_swizzled_block_idx<kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>(
                    group_m_blocks, linear_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
                return true;
            }

            // Not in this group: account for its blocks before the cursor moves on.
            curr_cumsum = cumsum_after_group;
        }

        // All groups exhausted.
        return false;
    }
};
// Must stay in sync with the copy in tests/unittest/_torch/thop/deep_gemm_tests.py.
//
// Returns `offset + problem_idx * 31` rounded down to a multiple of 32 (integer division).
// This formulation guarantees padded_offset[i + 1] - padded_offset[i] >= offset[i + 1] - offset[i].
template <typename T_offset, typename T_index>
__host__ __device__ __forceinline__ T_offset compute_padded_offset(T_offset offset, T_index problem_idx)
{
    constexpr T_offset alignment = 32;
    // Reserve (alignment - 1) elements of slack for every preceding problem, then align downwards.
    auto const shifted = offset + problem_idx * (alignment - 1);
    return shifted / alignment * alignment;
}
/// Tile scheduler for grouped GEMMs whose groups are described by an offset array.
///
/// Group g covers rows [problem_m_offsets[g], problem_m_offsets[g + 1]). The scheduler keeps a
/// running group cursor (`curr_group_idx`) and the number of M blocks in all groups before it
/// (`curr_cumsum`), and advances the cursor until the requested linear tile index falls inside the
/// current group.
template <uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
    uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N), uint32_t kNumNBlocksPerGroup = 16>
struct GroupedWithOffsetScheduler
{
    static constexpr GemmType gemm_type = GemmType::GroupedWithOffset;

    int current_iter = -1;      // becomes 0 on the first get_next_block() call
    uint32_t curr_group_idx;    // group the cursor currently points at
    uint32_t curr_cumsum;       // M blocks contained in all groups before curr_group_idx
    int64_t m_offset;           // problem_m_offsets[curr_group_idx]
    int64_t m_padded_4_offset;  // compute_padded_offset(m_offset, curr_group_idx)
    int64_t m_boundary;         // problem_m_offsets[curr_group_idx + 1]
    int64_t* problem_m_offsets; // copied from the input

    using Input = GroupedWithOffsetSchedulerInput;
    Input input; // never assigned by this scheduler

    GroupedWithOffsetScheduler() {}

    /// Stores the offset array and resets the group cursor. The m_* members stay unset until the
    /// first get_next_block() call.
    __device__ __forceinline__ GroupedWithOffsetScheduler(Input& input)
    {
        problem_m_offsets = input.problem_m_offsets;
        curr_cumsum = 0;
        curr_group_idx = 0;
    }

    /// Global M index relative to the current group's first row (narrowed to 32 bits on return).
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return static_cast<uint32_t>(m_offset + block_idx * BLOCK_M);
    }

    /// Global N index: groups are `shape_dim` apart; `m_block_idx` is ignored.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        uint32_t const group_base = curr_group_idx * shape_dim;
        return group_base + block_idx * block_size;
    }

    /// Scale index for A, relative to the padded offset of the current group (narrowed to 32 bits).
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return static_cast<uint32_t>(m_padded_4_offset + block_idx * BLOCK_M);
    }

    /// Scale index for B, computed with the same formula as get_global_n_idx.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        uint32_t const group_base = curr_group_idx * shape_dim;
        return group_base + block_idx * block_size;
    }

    /// Advances to this CUDA block's next tile, moving the group cursor forward as needed.
    /// \return false once the cursor has passed the last group, true otherwise.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        current_iter += 1;
        auto const linear_idx = current_iter * gridDim.x + blockIdx.x;

        for (; curr_group_idx != kNumGroups; ++curr_group_idx)
        {
            // Refresh the per-group state for the group under the cursor.
            m_offset = __ldg(problem_m_offsets + curr_group_idx);
            m_boundary = __ldg(problem_m_offsets + curr_group_idx + 1);
            m_padded_4_offset = compute_padded_offset(m_offset, curr_group_idx);

            uint32_t const group_m_blocks = ceil_div(m_boundary - m_offset, static_cast<int64_t>(BLOCK_M));
            uint32_t const cumsum_after_group = curr_cumsum + group_m_blocks;

            if (linear_idx < cumsum_after_group * kNumNBlocks)
            {
                // The tile lies in the current group: swizzle its index relative to the group start.
                get_swizzled_block_idx<kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>(
                    group_m_blocks, linear_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
                return true;
            }

            // Not in this group: account for its blocks before the cursor moves on.
            curr_cumsum = cumsum_after_group;
        }

        // All groups exhausted.
        return false;
    }
};
/// Swap-AB counterpart of GroupedWithOffsetScheduler: groups are described by offsets along N.
///
/// Group g covers [problem_n_offsets[g], problem_n_offsets[g + 1]) along N. The scheduler keeps a
/// running group cursor (`curr_group_idx`) and the number of N blocks in all groups before it
/// (`curr_cumsum`), and advances the cursor until the requested linear tile index falls inside the
/// current group.
template <uint32_t SHAPE_M, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
    uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M), uint32_t kNumMBlocksPerGroup = 16>
struct GroupedWithOffsetSchedulerSwapAB
{
    static constexpr GemmType gemm_type = GemmType::GroupedWithOffset;

    int current_iter = -1;      // becomes 0 on the first get_next_block() call
    uint32_t curr_group_idx;    // group the cursor currently points at
    uint32_t curr_cumsum;       // N blocks contained in all groups before curr_group_idx
    int64_t n_offset;           // problem_n_offsets[curr_group_idx]
    int64_t n_padded_4_offset;  // compute_padded_offset(n_offset, curr_group_idx)
    int64_t n_boundary;         // problem_n_offsets[curr_group_idx + 1]
    int64_t* problem_n_offsets; // copied from the input

    using Input = GroupedWithOffsetSchedulerInputSwapAB;
    Input input; // never assigned by this scheduler

    GroupedWithOffsetSchedulerSwapAB() {}

    /// Stores the offset array and resets the group cursor. The n_* members stay unset until the
    /// first get_next_block() call.
    __device__ __forceinline__ GroupedWithOffsetSchedulerSwapAB(Input& input)
    {
        problem_n_offsets = input.problem_n_offsets;
        curr_cumsum = 0;
        curr_group_idx = 0;
    }

    /// Weight side: groups are `shape_dim` apart; `n_block_idx` is ignored.
    __device__ __forceinline__ uint32_t get_global_m_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        uint32_t const group_base = curr_group_idx * shape_dim;
        return group_base + block_idx * block_size;
    }

    /// Activation side: global N index relative to the current group's start (narrowed to 32 bits).
    __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
    {
        return static_cast<uint32_t>(n_offset + block_idx * BLOCK_N);
    }

    /// Activation scales, relative to the padded offset of the current group (narrowed to 32 bits).
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
    {
        return static_cast<uint32_t>(n_padded_4_offset + block_idx * BLOCK_N);
    }

    /// Weight scales, computed with the same formula as get_global_m_idx.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        uint32_t const group_base = curr_group_idx * shape_dim;
        return group_base + block_idx * block_size;
    }

    /// Advances to this CUDA block's next tile, moving the group cursor forward as needed.
    /// \return false once the cursor has passed the last group, true otherwise.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        current_iter += 1;
        auto const linear_idx = current_iter * gridDim.x + blockIdx.x;

        for (; curr_group_idx != kNumGroups; ++curr_group_idx)
        {
            // Refresh the per-group state for the group under the cursor.
            n_offset = __ldg(problem_n_offsets + curr_group_idx);
            n_boundary = __ldg(problem_n_offsets + curr_group_idx + 1);
            n_padded_4_offset = compute_padded_offset(n_offset, curr_group_idx);

            uint32_t const group_n_blocks = ceil_div(n_boundary - n_offset, static_cast<int64_t>(BLOCK_N));
            uint32_t const cumsum_after_group = curr_cumsum + group_n_blocks;

            if (linear_idx < cumsum_after_group * kNumMBlocks)
            {
                // m and n trade places: the swizzle walks N blocks where the regular scheduler walks M.
                get_swizzled_block_idx<kNumTMAMulticast, kNumMBlocks, kNumMBlocksPerGroup>(
                    group_n_blocks, linear_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx);
                return true;
            }

            // Not in this group: account for its blocks before the cursor moves on.
            curr_cumsum = cumsum_after_group;
        }

        // All groups exhausted.
        return false;
    }
};
/// Tile scheduler for a batch of `kNumGroups` equally-shaped GEMMs addressed through strides.
///
/// Every batch entry contributes ceil(shape_m / BLOCK_M) M blocks. The scheduler keeps a running
/// batch cursor (`curr_group_idx`) and the number of M blocks in all entries before it
/// (`curr_cumsum`), and advances the cursor until the requested linear tile index falls inside the
/// current entry.
template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t kNumGroups,
    uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N), uint32_t kNumNBlocksPerGroup = 16>
struct StridedBatchedScheduler
{
    static constexpr GemmType gemm_type = GemmType::StridedBatched;

    int current_iter = -1;   // becomes 0 on the first get_next_block() call
    uint32_t curr_group_idx; // batch entry the cursor currently points at
    uint32_t curr_cumsum;    // M blocks contained in all entries before curr_group_idx
    int64_t m_offset;        // curr_group_idx * shape_m
    int64_t m_boundary;      // m_offset + shape_m

    using Input = StridedBatchedSchedulerInput;
    Input input; // full copy of the constructor argument

    StridedBatchedScheduler() {}

    /// Copies the input and resets the batch cursor. m_offset / m_boundary stay unset until the
    /// first get_next_block() call.
    __device__ __forceinline__ StridedBatchedScheduler(Input& input)
    {
        this->input = input;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    /// Global M index of M block `block_idx` in the current batch entry (narrowed to 32 bits).
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        // Assuming stride_a % ld_a == 0 && stride_a >= ld_a
        return input.stride_a / input.ld_a * curr_group_idx + block_idx * BLOCK_M;
    }

    /// Global N index of block `block_idx` in the current batch entry (narrowed to 32 bits);
    /// `shape_dim` and `m_block_idx` are ignored.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        // Assuming stride_b % ld_b == 0 && stride_b >= ld_b
        return input.stride_b / input.ld_b * curr_group_idx + block_idx * block_size;
    }

    /// Scale index for A: each batch entry occupies ceil(SHAPE_K / BLOCK_K) entries.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx;
    }

    /// Scale index for B: batch entries are `shape_dim` apart; `m_block_idx` is ignored.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    /// Advances to this CUDA block's next tile, moving the batch cursor forward as needed.
    /// \return false once the cursor has passed the last batch entry, true otherwise.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        // Every batch entry has the same M extent, so the block count does not depend on the cursor.
        uint32_t const num_m_blocks = ceil_div(input.shape_m, BLOCK_M);
        while (true)
        {
            // End of the task
            if (curr_group_idx == kNumGroups)
                return false;
            // Widen before multiplying: a 32-bit product of group index and shape_m would wrap
            // before being stored in the 64-bit members.
            m_offset = static_cast<int64_t>(curr_group_idx) * input.shape_m;
            m_boundary = m_offset + input.shape_m;
            // Within current group
            auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
            if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                break;
            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_m_block_cumsum;
        }
        get_swizzled_block_idx<kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>(
            num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        return true;
    }
};
/// Swap-AB counterpart of StridedBatchedScheduler: the runtime extent of every batch entry is N.
///
/// Every batch entry contributes ceil(shape_n / BLOCK_N) N blocks. The scheduler keeps a running
/// batch cursor (`curr_group_idx`) and the number of N blocks in all entries before it
/// (`curr_cumsum`), and advances the cursor until the requested linear tile index falls inside the
/// current entry.
template <uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t kNumGroups,
    uint32_t kNumTMAMulticast, uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M), uint32_t kNumMBlocksPerGroup = 16>
struct StridedBatchedSchedulerSwapAB
{
    static constexpr GemmType gemm_type = GemmType::StridedBatched;

    int current_iter = -1;   // becomes 0 on the first get_next_block() call
    uint32_t curr_group_idx; // batch entry the cursor currently points at
    uint32_t curr_cumsum;    // N blocks contained in all entries before curr_group_idx
    int64_t n_offset;        // curr_group_idx * shape_n
    int64_t n_boundary;      // n_offset + shape_n

    using Input = StridedBatchedSchedulerInputSwapAB;
    Input input; // full copy of the constructor argument

    StridedBatchedSchedulerSwapAB() {}

    /// Copies the input and resets the batch cursor. n_offset / n_boundary stay unset until the
    /// first get_next_block() call.
    __device__ __forceinline__ StridedBatchedSchedulerSwapAB(Input& input)
    {
        this->input = input;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    /// Weight side: global M index in the current batch entry (narrowed to 32 bits);
    /// `shape_dim` and `n_block_idx` are ignored.
    __device__ __forceinline__ uint32_t get_global_m_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        // Assuming stride_a % ld_a == 0 && stride_a >= ld_a
        return input.stride_a / input.ld_a * curr_group_idx + block_idx * block_size;
    }

    /// Activation side: global N index in the current batch entry (narrowed to 32 bits).
    __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
    {
        // Assuming stride_b % ld_b == 0 && stride_b >= ld_b
        return input.stride_b / input.ld_b * curr_group_idx + block_idx * BLOCK_N;
    }

    /// Activation scales: each batch entry occupies ceil(SHAPE_K / BLOCK_K) entries.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
    {
        return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx;
    }

    /// Weight scales: batch entries are `shape_dim` apart; `n_block_idx` is ignored.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    /// Advances to this CUDA block's next tile, moving the batch cursor forward as needed.
    /// \return false once the cursor has passed the last batch entry, true otherwise.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        // Every batch entry has the same N extent, so the block count does not depend on the cursor.
        uint32_t const num_n_blocks = ceil_div(input.shape_n, BLOCK_N);
        while (true)
        {
            // End of the task
            if (curr_group_idx == kNumGroups)
                return false;
            // Widen before multiplying: a 32-bit product of group index and shape_n would wrap
            // before being stored in the 64-bit members.
            n_offset = static_cast<int64_t>(curr_group_idx) * input.shape_n;
            n_boundary = n_offset + input.shape_n;
            // Within current group
            auto current_n_block_cumsum = curr_cumsum + num_n_blocks;
            if (next_block_idx < current_n_block_cumsum * kNumMBlocks)
                break;
            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_n_block_cumsum;
        }
        // Note: Here, m and n roles are swapped
        get_swizzled_block_idx<kNumTMAMulticast, kNumMBlocks, kNumMBlocksPerGroup>(
            num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx);
        return true;
    }
};
/// Compile-time mapping from a GemmType to the matching scheduler struct.
///
/// `type` is the return type of select_type(), which is only ever used inside decltype; the
/// template arguments are forwarded to whichever scheduler is chosen.
template <GemmType GT, uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
    uint32_t kNumNBlocksPerGroup = 16>
struct SchedulerSelector
{
    static constexpr auto select_type()
    {
        if constexpr (GT == GemmType::Normal)
        {
            using Selected = NormalScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumNBlocks,
                kNumNBlocksPerGroup>;
            return Selected();
        }
        else if constexpr (GT == GemmType::GroupedContiguous)
        {
            using Selected = GroupedContiguousScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast,
                kNumNBlocks, kNumNBlocksPerGroup>;
            return Selected();
        }
        else if constexpr (GT == GemmType::GroupedMasked)
        {
            using Selected = GroupedMaskedScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups,
                kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>;
            return Selected();
        }
        else if constexpr (GT == GemmType::GroupedWithOffset)
        {
            using Selected = GroupedWithOffsetScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast,
                kNumNBlocks, kNumNBlocksPerGroup>;
            return Selected();
        }
        else if constexpr (GT == GemmType::StridedBatched)
        {
            using Selected = StridedBatchedScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups,
                kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>;
            return Selected();
        }
    }

    using type = decltype(select_type());
};
/// Compile-time mapping from a GemmType to the matching swap-AB scheduler struct.
///
/// Only GemmType::Normal and GemmType::GroupedWithOffset are accepted; any other value trips the
/// static_assert. `type` is the return type of select_type(), which is only used inside decltype.
template <GemmType GT, uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M),
    uint32_t kNumMBlocksPerGroup = 16>
struct SchedulerSelectorSwapAB
{
    static constexpr auto select_type()
    {
        static_assert(GT == GemmType::GroupedWithOffset || GT == GemmType::Normal,
            "Only GroupedWithOffset and Normal are supported for SwapAB");

        if constexpr (GT == GemmType::Normal)
        {
            using Selected = NormalSchedulerSwapAB<SHAPE_M, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast,
                kNumMBlocks, kNumMBlocksPerGroup>;
            return Selected();
        }
        else if constexpr (GT == GemmType::GroupedWithOffset)
        {
            using Selected = GroupedWithOffsetSchedulerSwapAB<SHAPE_M, BLOCK_M, BLOCK_N, kNumGroups,
                kNumTMAMulticast, kNumMBlocks, kNumMBlocksPerGroup>;
            return Selected();
        }
    }

    using type = decltype(select_type());
};
#pragma clang diagnostic pop
} // namespace deep_gemm