/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 DeepSeek
 * SPDX-License-Identifier: MIT
 *
 * Licensed under the MIT License.
 * You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/MIT
 *
 *
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Tile schedulers for the deep_gemm kernels.
//
// Every scheduler hands out (m_block_idx, n_block_idx) output tiles via get_next_block() and translates
// block indices into global coordinates via the get_global_*_idx() helpers. The schedulers differ in how
// the M dimension is split into groups (see GemmType). The names kNumNBlocks, kNumNBlocksPerGroup,
// kNumTMAMulticast, kNumGroups, BLOCK_M, BLOCK_K, SHAPE_K and GT used below are not declared in the
// bodies; they are expected to be compile-time template parameters of the enclosing templates.

#pragma once

// NOTE(review): the first #include below has no header name in this copy of the file -- confirm the
// intended header.
#include
#include "utils.cuh"

namespace deep_gemm
{

// Compile-time scheduling mode. Each scheduler struct below advertises its mode through a
// `static constexpr GemmType gemm_type` member, and SchedulerSelector maps a GemmType to a scheduler.
enum class GemmType
{
    Normal,
    GroupedContiguous,
    GroupedMasked,
    GroupedWithOffset,
    StridedBatched
};

// Several scheduler members below are assigned in constructors or in get_next_block() instead of having
// in-class initializers; silence the IDE's member-initialization lint for this region.
#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"

// Converts a linear tile index into (m_block_idx, n_block_idx) coordinates.
//
// Tiles are arranged in bands of kNumNBlocksPerGroup consecutive N-blocks that span all `num_m_blocks`
// M-blocks. Inside a band, consecutive indices advance along N first, then along M. The final band is
// narrower when kNumNBlocks is not a multiple of kNumNBlocksPerGroup (the min() below).
//
// Compile-time constants referenced:
//   kNumNBlocks         - total number of N-blocks
//   kNumNBlocksPerGroup - band width in N-blocks; must be a multiple of kNumTMAMulticast
//   kNumTMAMulticast    - used here only by the static assertion
// Params:
//   num_m_blocks - number of M-blocks in the (sub)problem being tiled
//   block_idx    - linear tile index within that (sub)problem; callers in this file pass
//                  non-negative (unsigned) values
// Outputs:
//   m_block_idx, n_block_idx - tile coordinates, in units of blocks
template __device__ __forceinline__ void get_swizzled_block_idx(
    const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx)
{
    DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

    // Swizzle for better L2 usage: walk the tiles band by band instead of row by row.
    auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
    auto group_idx = block_idx / num_blocks_per_group;
    auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
    // The last band may hold fewer than kNumNBlocksPerGroup N-blocks.
    auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
    auto in_group_idx = block_idx % num_blocks_per_group;
    // N varies fastest inside a band.
    m_block_idx = in_group_idx / num_n_blocks_in_group;
    n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
}

// Scheduler for a single, ungrouped GEMM: ceil_div(shape_m, BLOCK_M) M-blocks times kNumNBlocks N-blocks.
template struct NormalScheduler
{
    static constexpr GemmType gemm_type = GemmType::Normal;

    // Count of get_next_block() calls made so far minus one (-1 before the first call).
    int current_iter = -1;
    // Number of M-blocks covering shape_m rows.
    uint32_t num_aligned_m_blocks;
    // Total number of tiles to hand out.
    uint32_t num_blocks;

    struct Input
    {
        uint32_t shape_m;
        int* grouped_layout; // no use: not read by this scheduler
    };

    NormalScheduler() {}

    __device__ __forceinline__ NormalScheduler(Input& input)
    {
        num_aligned_m_blocks = ceil_div(input.shape_m, BLOCK_M);
        num_blocks = num_aligned_m_blocks * kNumNBlocks;
    }

    // First row of M-block `block_idx`.
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return block_idx * BLOCK_M;
    }

    // Offset of block `block_idx` given `block_size` elements per block.
    // `shape_dim` and `m_block_idx` are not used by this scheduler.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return block_idx * block_size;
    }

    // Index into the A-side scales: the block index itself (no group offset).
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return block_idx;
    }

    // Same formula as get_global_n_idx; `shape_dim` and `m_block_idx` are unused.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return block_idx * block_size;
    }

    // Fetches this CTA's next tile. Successive calls examine linear tile indices
    // blockIdx.x, blockIdx.x + gridDim.x, blockIdx.x + 2 * gridDim.x, ...
    // Returns false once the index reaches num_blocks; otherwise writes the swizzled tile
    // coordinates into m_block_idx / n_block_idx and returns true.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;

        if (next_block_idx >= num_blocks)
        {
            return false;
        }

        get_swizzled_block_idx(
            num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
        return true;
    }
};

// Scheduler for groups laid out contiguously along M. Tiling of M is identical to NormalScheduler; the
// group of each M-block is looked up in `grouped_layout` when computing N-side / B-scale offsets.
template struct GroupedContiguousScheduler
{
    static constexpr GemmType gemm_type = GemmType::GroupedContiguous;

    // Count of get_next_block() calls made so far minus one (-1 before the first call).
    int current_iter = -1;
    uint32_t num_aligned_m_blocks;
    // Per-row array read with __ldg at row m_block_idx * BLOCK_M; presumably holds each row's group
    // index -- confirm against the producer of this layout.
    int* grouped_layout;
    uint32_t num_blocks;
    // Stored by the constructor but not read elsewhere in this struct.
    uint32_t shape_m;

    struct Input
    {
        uint32_t shape_m;
        int* grouped_layout;
    };

    GroupedContiguousScheduler() {}

    __device__ __forceinline__ GroupedContiguousScheduler(Input& input)
    {
        num_aligned_m_blocks = ceil_div(input.shape_m, BLOCK_M);
        num_blocks = num_aligned_m_blocks * kNumNBlocks;
        this->shape_m = input.shape_m;
        this->grouped_layout = input.grouped_layout;
    }

    // First row of M-block `block_idx`.
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return block_idx * BLOCK_M;
    }

    // Offset of block `block_idx` inside the slice selected by grouped_layout[m_block_idx * BLOCK_M]
    // (that value times `shape_dim`). Callers must pass the real m_block_idx: the default of 0 always
    // reads the first row's entry.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return __ldg(grouped_layout + m_block_idx * BLOCK_M) * shape_dim + block_idx * block_size;
    }

    // Index into the A-side scales: the block index itself (no group offset).
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return block_idx;
    }

    // Same formula as get_global_n_idx, used for the B-side scales.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return __ldg(grouped_layout + m_block_idx * BLOCK_M) * shape_dim + block_idx * block_size;
    }

    // Fetches this CTA's next tile, stepping the linear tile index by gridDim.x per call (starting at
    // blockIdx.x). Returns false once the index reaches num_blocks.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;

        if (next_block_idx >= num_blocks)
        {
            return false;
        }

        get_swizzled_block_idx(
            num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
        return true;
    }
};

// Scheduler for kNumGroups groups, each reserving `shape_m` rows, of which only grouped_layout[g] rows of
// group g are scheduled. Tiles of all groups are numbered consecutively (group 0's tiles first).
template struct GroupedMaskedScheduler
{
    static constexpr GemmType gemm_type = GemmType::GroupedMasked;

    // Count of get_next_block() calls made so far minus one (-1 before the first call).
    int current_iter = -1;
    // Computed by the constructor; get_next_block() uses per-group counts instead of these.
    uint32_t num_blocks;
    uint32_t num_aligned_m_blocks;
    // Group containing the most recently returned tile. Only ever increases.
    uint32_t curr_group_idx;
    // Total M-blocks of all groups before curr_group_idx.
    uint32_t curr_cumsum;
    // Rows reserved per group; used as the per-group row stride in get_global_m_idx.
    uint32_t shape_m;
    // grouped_layout[g] = number of rows of group g to schedule (read per group in get_next_block).
    int* grouped_layout;

    struct Input
    {
        uint32_t shape_m;
        int* grouped_layout;
    };

    GroupedMaskedScheduler() {}

    __device__ __forceinline__ GroupedMaskedScheduler(Input& input)
    {
        num_aligned_m_blocks = ceil_div(input.shape_m, BLOCK_M);
        num_blocks = num_aligned_m_blocks * kNumNBlocks;
        this->shape_m = input.shape_m;
        this->grouped_layout = input.grouped_layout;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    // First row of M-block `block_idx` of the current group (groups are shape_m rows apart).
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return curr_group_idx * shape_m + block_idx * BLOCK_M;
    }

    // Offset of block `block_idx` inside the current group's slice of size `shape_dim`.
    // `m_block_idx` is unused.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    // Index into the A-side scales: each group owns ceil_div(SHAPE_K, BLOCK_K) consecutive entries.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx;
    }

    // Same formula as get_global_n_idx, used for the B-side scales.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    // Fetches this CTA's next tile and updates curr_group_idx / curr_cumsum to the group that owns it.
    // The group walk resumes from where the previous call stopped; this is valid because next_block_idx
    // strictly increases from call to call. Groups with zero rows contribute no tiles and are skipped.
    // Returns false once every group has been passed.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;

        uint32_t num_m_blocks;
        while (true)
        {
            // End of the task: all groups exhausted
            if (curr_group_idx == kNumGroups)
                return false;

            // Within current group: does next_block_idx fall before the end of this group's tiles?
            num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
            auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
            if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                break;

            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_m_block_cumsum;
        }

        // Rebase the index to the start of the current group before swizzling.
        get_swizzled_block_idx(
            num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        return true;
    }
};

// Scheduler for kNumGroups variable-size groups along M. Group g covers rows
// [problem_m_offsets[g], problem_m_offsets[g + 1]), so problem_m_offsets is read at indices 0..kNumGroups
// and must hold kNumGroups + 1 entries. problem_m_padded_offsets[g] supplies group g's start index for
// the A-side scales.
template struct GroupedWithOffsetScheduler
{
    static constexpr GemmType gemm_type = GemmType::GroupedWithOffset;

    // Count of get_next_block() calls made so far minus one (-1 before the first call).
    int current_iter = -1;
    // Group containing the most recently returned tile. Only ever increases.
    uint32_t curr_group_idx;
    // Total M-blocks of all groups before curr_group_idx.
    uint32_t curr_cumsum;
    // Current group's first row, scale start index and one-past-last row. These are assigned only inside
    // get_next_block(), so they are uninitialized until its first successful call; get_global_m_idx and
    // get_global_scales_a_idx must not be used before that.
    int64_t m_offset;
    int64_t m_padded_4_offset;
    int64_t m_boundary;
    int64_t* problem_m_offsets;
    int64_t* problem_m_padded_offsets;

    struct Input
    {
        // Not read by this scheduler's constructor.
        uint32_t shape_m;
        int64_t* problem_m_offsets;
        int64_t* problem_m_padded_offsets;
    };

    GroupedWithOffsetScheduler() {}

    __device__ __forceinline__ GroupedWithOffsetScheduler(Input& input)
    {
        this->problem_m_offsets = input.problem_m_offsets;
        this->problem_m_padded_offsets = input.problem_m_padded_offsets;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    // First row of M-block `block_idx` of the current group.
    // NOTE(review): the int64_t result is narrowed to uint32_t on return.
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return m_offset + block_idx * BLOCK_M;
    }

    // Offset of block `block_idx` inside the current group's slice of size `shape_dim`.
    // `m_block_idx` is unused.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    // Index into the A-side scales, starting at problem_m_padded_offsets[curr_group_idx].
    // NOTE(review): the int64_t result is narrowed to uint32_t on return.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return m_padded_4_offset + block_idx * BLOCK_M;
    }

    // Same formula as get_global_n_idx, used for the B-side scales.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    // Fetches this CTA's next tile. It refreshes m_offset / m_boundary / m_padded_4_offset for every
    // group it inspects, so on success they describe the group that owns the returned tile. The group
    // walk resumes from the previous call because next_block_idx strictly increases.
    // Returns false once every group has been passed.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;

        uint32_t num_m_blocks;
        while (true)
        {
            // End of the task: all groups exhausted
            if (curr_group_idx == kNumGroups)
                return false;

            m_padded_4_offset = __ldg(problem_m_padded_offsets + curr_group_idx);
            m_offset = __ldg(problem_m_offsets + curr_group_idx);
            m_boundary = __ldg(problem_m_offsets + curr_group_idx + 1);
            // Number of rows in the current group.
            auto m = m_boundary - m_offset;

            // Within current group: does next_block_idx fall before the end of this group's tiles?
            num_m_blocks = ceil_div(m, static_cast(BLOCK_M));
            auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
            if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                break;

            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_m_block_cumsum;
        }

        // Rebase the index to the start of the current group before swizzling.
        get_swizzled_block_idx(
            num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        return true;
    }
};

// Scheduler for kNumGroups equally sized batches of `shape_m` rows each. Per-batch offsets for A and B
// are derived from the batch strides and leading dimensions in `input`.
template struct StridedBatchedScheduler
{
    static constexpr GemmType gemm_type = GemmType::StridedBatched;

    // Count of get_next_block() calls made so far minus one (-1 before the first call).
    int current_iter = -1;
    // Batch containing the most recently returned tile. Only ever increases.
    uint32_t curr_group_idx;
    // Total M-blocks of all batches before curr_group_idx.
    uint32_t curr_cumsum;
    // Current batch's first row and one-past-last row, assigned inside get_next_block().
    int64_t m_offset;
    int64_t m_boundary;

    // Launch parameters, copied by value. ld_d / stride_d are not read by this struct.
    struct Input
    {
        uint32_t shape_m;
        uint64_t ld_a;
        uint64_t stride_a;
        uint64_t ld_b;
        uint64_t stride_b;
        uint64_t ld_d;
        uint64_t stride_d;
    } input;

    StridedBatchedScheduler() {}

    __device__ __forceinline__ StridedBatchedScheduler(Input& input)
    {
        // The parameter shadows the member of the same name, hence this->.
        this->input = input;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    // First row of M-block `block_idx` of the current batch: the batch's first row is
    // (stride_a / ld_a) * curr_group_idx.
    // NOTE(review): the uint64_t expression is narrowed to uint32_t on return.
    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        // Assuming stride_a % ld_a == 0 && stride_a >= ld_a
        return input.stride_a / input.ld_a * curr_group_idx + block_idx * BLOCK_M;
    }

    // Offset of block `block_idx` inside the current batch: the batch starts at
    // (stride_b / ld_b) * curr_group_idx. `shape_dim` and `m_block_idx` are unused.
    // NOTE(review): the uint64_t expression is narrowed to uint32_t on return.
    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        // Assuming stride_b % ld_b == 0 && stride_b >= ld_b
        return input.stride_b / input.ld_b * curr_group_idx + block_idx * block_size;
    }

    // Index into the A-side scales: each batch owns ceil_div(SHAPE_K, BLOCK_K) consecutive entries.
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx;
    }

    // Offset of block `block_idx` inside the current batch's slice of size `shape_dim`, used for the
    // B-side scales. `m_block_idx` is unused.
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    // Fetches this CTA's next tile and updates curr_group_idx / m_offset / m_boundary to the batch that
    // owns it. Every batch has ceil_div(shape_m, BLOCK_M) M-blocks. Returns false once every batch has
    // been passed.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;

        uint32_t num_m_blocks;
        while (true)
        {
            // End of the task: all batches exhausted
            if (curr_group_idx == kNumGroups)
                return false;

            // NOTE(review): both products below are uint32_t * uint32_t, computed in 32 bits and only
            // then widened to int64_t; they wrap if kNumGroups * shape_m exceeds UINT32_MAX.
            m_offset = curr_group_idx * input.shape_m;
            m_boundary = (curr_group_idx + 1) * input.shape_m;

            // Within current group: does next_block_idx fall before the end of this batch's tiles?
            num_m_blocks = ceil_div(input.shape_m, BLOCK_M);
            auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
            if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                break;

            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_m_block_cumsum;
        }

        // Rebase the index to the start of the current batch before swizzling.
        get_swizzled_block_idx(
            num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        return true;
    }
};

// Maps the compile-time GemmType `GT` to its scheduler type. Exactly one `if constexpr` branch applies,
// so select_type()'s deduced return type -- and therefore `type` -- is that scheduler. select_type() is
// only used inside decltype (an unevaluated operand), so no scheduler is actually constructed here.
template struct SchedulerSelector
{
    static constexpr auto select_type()
    {
        if constexpr (GT == GemmType::Normal)
            return NormalScheduler();
        if constexpr (GT == GemmType::GroupedContiguous)
            return GroupedContiguousScheduler();
        if constexpr (GT == GemmType::GroupedMasked)
            return GroupedMaskedScheduler();
        if constexpr (GT == GemmType::GroupedWithOffset)
            return GroupedWithOffsetScheduler();
        if constexpr (GT == GemmType::StridedBatched)
            return StridedBatchedScheduler();
    }

    using type = decltype(select_type());
};

#pragma clang diagnostic pop

} // namespace deep_gemm