/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#ifndef NVRTC_JIT_COMPILATION
#include <cstdint>
#endif

#include "utils.cuh"

namespace deep_gemm
{

enum class GemmType
{
    Normal,
    GroupedContiguous,
    GroupedMasked,
    GroupedWithOffset,
    StridedBatched
};

#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"

template <uint32_t kNumNBlocks, uint32_t kNumNBlocksPerGroup, uint32_t kNumTMAMulticast>
__device__ __forceinline__ void get_swizzled_block_idx(
    const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx)
{
    DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

    // Swizzle for better L2 usage
    auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
    auto group_idx = block_idx / num_blocks_per_group;
    auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
    auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
    auto in_group_idx = block_idx % num_blocks_per_group;
    m_block_idx = in_group_idx / num_n_blocks_in_group;
    n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
}

struct NormalSchedulerInput
{
    uint32_t shape_m;
    int* grouped_layout; // unused
};

struct NormalSchedulerInputSwapAB
{
    uint32_t shape_n;
    int* grouped_layout; // unused
};

struct GroupedContiguousSchedulerInput
{
    uint32_t shape_m;
    int* grouped_layout;
};

struct GroupedMaskedSchedulerInput
{
    uint32_t shape_m;
    int* grouped_layout;
};

struct GroupedWithOffsetSchedulerInput
{
    uint32_t shape_m;
    int64_t* problem_m_offsets;
};

struct GroupedWithOffsetSchedulerInputSwapAB
{
    uint32_t shape_m;
    int64_t* problem_n_offsets;
};

struct StridedBatchedSchedulerInput
{
    uint32_t shape_m;
    uint64_t ld_a;
    uint64_t stride_a;
    uint64_t ld_b;
    uint64_t stride_b;
    uint64_t ld_d;
    uint64_t stride_d;
};

struct StridedBatchedSchedulerInputSwapAB
{
    uint32_t shape_n;
    uint64_t ld_a;
    uint64_t stride_a;
    uint64_t ld_b;
    uint64_t stride_b;
    uint64_t ld_d;
    uint64_t stride_d;
};

template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
    uint32_t kNumNBlocksPerGroup = 16>
struct NormalScheduler
{
    static constexpr GemmType gemm_type = GemmType::Normal;

    int current_iter = -1;
    uint32_t num_aligned_m_blocks;
    uint32_t num_blocks;

    using Input = NormalSchedulerInput;
    Input input;

    NormalScheduler() {}

    __device__ __forceinline__ NormalScheduler(Input& input)
    {
        num_aligned_m_blocks = ceil_div(input.shape_m, BLOCK_M);
        num_blocks = num_aligned_m_blocks * kNumNBlocks;
    }

    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return block_idx * BLOCK_M;
    }

    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return block_idx * block_size;
    }

    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return block_idx;
    }

    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return block_idx * block_size;
    }
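    // get_next_block() below implements a persistent-CTA loop over the flattened tile space:
    // tile k is handled by CTA (k % gridDim.x) on its (k / gridDim.x)-th call. For example
    // (hypothetical sizes), with gridDim.x = 132 and num_blocks = 300, CTA 5 handles tiles
    // 5, 137 and 269, and its fourth call returns false.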
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        if (next_block_idx >= num_blocks)
        {
            return false;
        }
        get_swizzled_block_idx<kNumNBlocks, kNumNBlocksPerGroup, kNumTMAMulticast>(
            num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
        return true;
    }
};

template <uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M),
    uint32_t kNumMBlocksPerGroup = 16>
struct NormalSchedulerSwapAB
{
    static constexpr GemmType gemm_type = GemmType::Normal;

    int current_iter = -1;
    uint32_t num_aligned_n_blocks;
    uint32_t num_blocks;

    using Input = NormalSchedulerInputSwapAB;
    Input input;

    NormalSchedulerSwapAB() {}

    __device__ __forceinline__ NormalSchedulerSwapAB(Input& input)
    {
        num_aligned_n_blocks = ceil_div(input.shape_n, BLOCK_N);
        num_blocks = num_aligned_n_blocks * kNumMBlocks;
    }

    // weight
    __device__ __forceinline__ uint32_t get_global_m_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        return block_idx * block_size;
    }

    // act
    __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
    {
        return block_idx * BLOCK_N;
    }

    // act scales
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
    {
        return block_idx;
    }

    // weight scales
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        return block_idx * block_size;
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        if (next_block_idx >= num_blocks)
        {
            return false;
        }
        // Note: m and n roles are swapped here
        get_swizzled_block_idx<kNumMBlocks, kNumMBlocksPerGroup, kNumTMAMulticast>(
            num_aligned_n_blocks, next_block_idx, n_block_idx, m_block_idx);
        return true;
    }
};

template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
    uint32_t kNumNBlocksPerGroup = 16>
struct GroupedContiguousScheduler
{
    static constexpr GemmType gemm_type = GemmType::GroupedContiguous;

    int current_iter = -1;
    uint32_t num_aligned_m_blocks;
    int* grouped_layout;
    uint32_t num_blocks;
    uint32_t shape_m;

    using Input = GroupedContiguousSchedulerInput;
    Input input;

    GroupedContiguousScheduler() {}

    __device__ __forceinline__ GroupedContiguousScheduler(Input& input)
    {
        num_aligned_m_blocks = ceil_div(input.shape_m, BLOCK_M);
        num_blocks = num_aligned_m_blocks * kNumNBlocks;
        this->shape_m = input.shape_m;
        this->grouped_layout = input.grouped_layout;
    }

    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return block_idx * BLOCK_M;
    }

    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return __ldg(grouped_layout + m_block_idx * BLOCK_M) * shape_dim + block_idx * block_size;
    }

    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return block_idx;
    }

    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return __ldg(grouped_layout + m_block_idx * BLOCK_M) * shape_dim + block_idx * block_size;
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        if (next_block_idx >= num_blocks)
        {
            return false;
        }
        get_swizzled_block_idx<kNumNBlocks, kNumNBlocksPerGroup, kNumTMAMulticast>(
            num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
        return true;
    }
};
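// Note on the contiguous-grouped layout (inferred from the index math above, not stated in the
// original source): grouped_layout is read per row of A, and entry grouped_layout[r] is treated
// as the group index owning row r. get_global_n_idx()/get_global_scales_b_idx() only consult the
// first row of an m-block, so a BLOCK_M-sized m-block is assumed to never straddle two groups.
// E.g. with BLOCK_M = 64 and grouped_layout[128] == 2, the m-block starting at row 128 reads B
// and its scales at a base offset of 2 * shape_dim.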
template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
    uint32_t kNumNBlocksPerGroup = 16>
struct GroupedMaskedScheduler
{
    static constexpr GemmType gemm_type = GemmType::GroupedMasked;

    int current_iter = -1;
    uint32_t num_blocks;
    uint32_t num_aligned_m_blocks;
    uint32_t curr_group_idx;
    uint32_t curr_cumsum;
    uint32_t shape_m;
    int* grouped_layout;

    using Input = GroupedMaskedSchedulerInput;
    Input input;

    GroupedMaskedScheduler() {}

    __device__ __forceinline__ GroupedMaskedScheduler(Input& input)
    {
        num_aligned_m_blocks = ceil_div(input.shape_m, BLOCK_M);
        num_blocks = num_aligned_m_blocks * kNumNBlocks;
        this->shape_m = input.shape_m;
        this->grouped_layout = input.grouped_layout;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return curr_group_idx * shape_m + block_idx * BLOCK_M;
    }

    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx;
    }

    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        uint32_t num_m_blocks;
        while (true)
        {
            // End of the task
            if (curr_group_idx == kNumGroups)
                return false;

            // Within the current group
            num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
            auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
            if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                break;

            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_m_block_cumsum;
        }
        get_swizzled_block_idx<kNumNBlocks, kNumNBlocksPerGroup, kNumTMAMulticast>(
            num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        return true;
    }
};

// Need to keep this consistent with the one in tests/unittest/_torch/thop/deep_gemm_tests.py
template <typename T_offset, typename T_index>
__host__ __device__ __forceinline__ T_offset compute_padded_offset(T_offset offset, T_index problem_idx)
{
    // This formulation ensures that padded_offset[i + 1] - padded_offset[i] >= offset[i + 1] - offset[i].
    constexpr T_offset alignment = 32;
    return (offset + problem_idx * (alignment - 1)) / alignment * alignment;
}
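// Worked example for compute_padded_offset (hypothetical offsets, alignment = 32):
//   offsets     = {0, 5, 37}
//   padded[i]   = (offsets[i] + i * 31) / 32 * 32  ->  {0, 32, 96}
//   padded gaps = {32, 64}, raw gaps = {5, 32}
// Every padded gap is a multiple of 32 and at least as large as the raw gap, so each problem's
// padded region is big enough for its rows and the regions of different problems never overlap.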
template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
    uint32_t kNumNBlocksPerGroup = 16>
struct GroupedWithOffsetScheduler
{
    static constexpr GemmType gemm_type = GemmType::GroupedWithOffset;

    int current_iter = -1;
    uint32_t curr_group_idx;
    uint32_t curr_cumsum;
    int64_t m_offset;
    int64_t m_padded_4_offset;
    int64_t m_boundary;
    int64_t* problem_m_offsets;

    using Input = GroupedWithOffsetSchedulerInput;
    Input input;

    GroupedWithOffsetScheduler() {}

    __device__ __forceinline__ GroupedWithOffsetScheduler(Input& input)
    {
        this->problem_m_offsets = input.problem_m_offsets;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        return m_offset + block_idx * BLOCK_M;
    }

    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return m_padded_4_offset + block_idx * BLOCK_M;
    }

    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        uint32_t num_m_blocks;
        while (true)
        {
            // End of the task
            if (curr_group_idx == kNumGroups)
                return false;

            m_offset = __ldg(problem_m_offsets + curr_group_idx);
            m_boundary = __ldg(problem_m_offsets + curr_group_idx + 1);
            m_padded_4_offset = compute_padded_offset(m_offset, curr_group_idx);
            auto m = m_boundary - m_offset;

            // Within the current group
            num_m_blocks = ceil_div(m, static_cast<int64_t>(BLOCK_M));
            auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
            if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                break;

            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_m_block_cumsum;
        }
        get_swizzled_block_idx<kNumNBlocks, kNumNBlocksPerGroup, kNumTMAMulticast>(
            num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        return true;
    }
};
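// Expected problem_m_offsets layout (hypothetical values for illustration): for kNumGroups = 3
// problems with m sizes {100, 64, 1}, problem_m_offsets holds the kNumGroups + 1 prefix offsets
// {0, 100, 164, 165}. With BLOCK_M = 64 the groups contribute ceil_div(m, BLOCK_M) = {2, 1, 1}
// m-blocks, and get_next_block() above walks them group by group while accumulating curr_cumsum.
// The scales for group i start at compute_padded_offset(m_offset, i), i.e. at a 32-aligned slot.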
template <uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M),
    uint32_t kNumMBlocksPerGroup = 16>
struct GroupedWithOffsetSchedulerSwapAB
{
    static constexpr GemmType gemm_type = GemmType::GroupedWithOffset;

    int current_iter = -1;
    uint32_t curr_group_idx;
    uint32_t curr_cumsum;
    int64_t n_offset;
    int64_t n_padded_4_offset;
    int64_t n_boundary;
    int64_t* problem_n_offsets;

    using Input = GroupedWithOffsetSchedulerInputSwapAB;
    Input input;

    GroupedWithOffsetSchedulerSwapAB() {}

    __device__ __forceinline__ GroupedWithOffsetSchedulerSwapAB(Input& input)
    {
        this->problem_n_offsets = input.problem_n_offsets;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    // weight
    __device__ __forceinline__ uint32_t get_global_m_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    // act
    __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
    {
        return n_offset + block_idx * BLOCK_N;
    }

    // act scales
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
    {
        return n_padded_4_offset + block_idx * BLOCK_N;
    }

    // weight scales
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        uint32_t num_n_blocks;
        while (true)
        {
            // End of the task
            if (curr_group_idx == kNumGroups)
                return false;

            n_offset = __ldg(problem_n_offsets + curr_group_idx);
            n_boundary = __ldg(problem_n_offsets + curr_group_idx + 1);
            n_padded_4_offset = compute_padded_offset(n_offset, curr_group_idx);
            auto n = n_boundary - n_offset;

            // Within the current group
            num_n_blocks = ceil_div(n, static_cast<int64_t>(BLOCK_N));
            auto current_n_block_cumsum = curr_cumsum + num_n_blocks;
            if (next_block_idx < current_n_block_cumsum * kNumMBlocks)
                break;

            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_n_block_cumsum;
        }
        // Note: m and n roles are swapped here
        get_swizzled_block_idx<kNumMBlocks, kNumMBlocksPerGroup, kNumTMAMulticast>(
            num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx);
        return true;
    }
};

template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
    uint32_t kNumNBlocksPerGroup = 16>
struct StridedBatchedScheduler
{
    static constexpr GemmType gemm_type = GemmType::StridedBatched;

    int current_iter = -1;
    uint32_t curr_group_idx;
    uint32_t curr_cumsum;
    int64_t m_offset;
    int64_t m_boundary;

    using Input = StridedBatchedSchedulerInput;
    Input input;

    StridedBatchedScheduler() {}

    __device__ __forceinline__ StridedBatchedScheduler(Input& input)
    {
        this->input = input;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
    {
        // Assuming stride_a % ld_a == 0 && stride_a >= ld_a
        return input.stride_a / input.ld_a * curr_group_idx + block_idx * BLOCK_M;
    }

    __device__ __forceinline__ uint32_t get_global_n_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        // Assuming stride_b % ld_b == 0 && stride_b >= ld_b
        return input.stride_b / input.ld_b * curr_group_idx + block_idx * block_size;
    }

    __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
    {
        return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx;
    }

    __device__ __forceinline__ uint32_t get_global_scales_b_idx(
        uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        uint32_t num_m_blocks;
        while (true)
        {
            // End of the task
            if (curr_group_idx == kNumGroups)
                return false;

            m_offset = curr_group_idx * input.shape_m;
            m_boundary = (curr_group_idx + 1) * input.shape_m;

            // Within the current group
            num_m_blocks = ceil_div(input.shape_m, BLOCK_M);
            auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
            if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                break;

            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_m_block_cumsum;
        }
        get_swizzled_block_idx<kNumNBlocks, kNumNBlocksPerGroup, kNumTMAMulticast>(
            num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        return true;
    }
};
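// Illustration of the batch-to-row mapping used by StridedBatchedScheduler (hypothetical layout,
// not from the original source): for a row-major A of shape [batch, M, K] with ld_a = K and
// stride_a = M * K, the ratio stride_a / ld_a equals M, so get_global_m_idx() places batch b at
// global row b * M and its m_block_idx-th tile at row b * M + m_block_idx * BLOCK_M. The same
// ratio is applied to B via stride_b / ld_b, which is why stride % ld == 0 is assumed above.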
template <uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M),
    uint32_t kNumMBlocksPerGroup = 16>
struct StridedBatchedSchedulerSwapAB
{
    static constexpr GemmType gemm_type = GemmType::StridedBatched;

    int current_iter = -1;
    uint32_t curr_group_idx;
    uint32_t curr_cumsum;
    int64_t n_offset;
    int64_t n_boundary;

    using Input = StridedBatchedSchedulerInputSwapAB;
    Input input;

    StridedBatchedSchedulerSwapAB() {}

    __device__ __forceinline__ StridedBatchedSchedulerSwapAB(Input& input)
    {
        this->input = input;
        curr_group_idx = 0;
        curr_cumsum = 0;
    }

    // weight
    __device__ __forceinline__ uint32_t get_global_m_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        // Assuming stride_a % ld_a == 0 && stride_a >= ld_a
        return input.stride_a / input.ld_a * curr_group_idx + block_idx * block_size;
    }

    // act
    __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
    {
        // Assuming stride_b % ld_b == 0 && stride_b >= ld_b
        return input.stride_b / input.ld_b * curr_group_idx + block_idx * BLOCK_N;
    }

    // act scales
    __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
    {
        return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx;
    }

    // weight scales
    __device__ __forceinline__ uint32_t get_global_scales_a_idx(
        const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
    {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
    {
        ++current_iter;
        auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
        uint32_t num_n_blocks;
        while (true)
        {
            // End of the task
            if (curr_group_idx == kNumGroups)
                return false;

            n_offset = curr_group_idx * input.shape_n;
            n_boundary = (curr_group_idx + 1) * input.shape_n;

            // Within the current group
            num_n_blocks = ceil_div(input.shape_n, BLOCK_N);
            auto current_n_block_cumsum = curr_cumsum + num_n_blocks;
            if (next_block_idx < current_n_block_cumsum * kNumMBlocks)
                break;

            // Move to check the next group
            curr_group_idx++;
            curr_cumsum = current_n_block_cumsum;
        }
        // Note: here, the m and n roles are swapped
        get_swizzled_block_idx<kNumMBlocks, kNumMBlocksPerGroup, kNumTMAMulticast>(
            num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx);
        return true;
    }
};

template <GemmType GT, uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
    uint32_t kNumNBlocksPerGroup = 16>
struct SchedulerSelector
{
    static constexpr auto select_type()
    {
        if constexpr (GT == GemmType::Normal)
            return NormalScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumTMAMulticast,
                kNumNBlocks, kNumNBlocksPerGroup>();
        if constexpr (GT == GemmType::GroupedContiguous)
            return GroupedContiguousScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups,
                kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>();
        if constexpr (GT == GemmType::GroupedMasked)
            return GroupedMaskedScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumTMAMulticast,
                kNumNBlocks, kNumNBlocksPerGroup>();
        if constexpr (GT == GemmType::GroupedWithOffset)
            return GroupedWithOffsetScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups,
                kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>();
        if constexpr (GT == GemmType::StridedBatched)
            return StridedBatchedScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumTMAMulticast,
                kNumNBlocks, kNumNBlocksPerGroup>();
    }

    using type = decltype(select_type());
};

template <GemmType GT, uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
    uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M),
    uint32_t kNumMBlocksPerGroup = 16>
struct SchedulerSelectorSwapAB
{
    static constexpr auto select_type()
    {
        static_assert(GT == GemmType::GroupedWithOffset || GT == GemmType::Normal,
            "Only GroupedWithOffset and Normal are supported for SwapAB");
        if constexpr (GT == GemmType::Normal)
            return NormalSchedulerSwapAB<SHAPE_M, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumTMAMulticast,
                kNumMBlocks, kNumMBlocksPerGroup>();
        if constexpr (GT == GemmType::GroupedWithOffset)
            return GroupedWithOffsetSchedulerSwapAB<SHAPE_M, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups,
                kNumTMAMulticast, kNumMBlocks, kNumMBlocksPerGroup>();
    }

    using type = decltype(select_type());
};

#pragma clang diagnostic pop

} // namespace deep_gemm
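
// Minimal usage sketch (illustrative only; the kernel name, tile sizes and launch logic below are
// assumptions, not part of this header): a persistent kernel picks a scheduler type through
// SchedulerSelector and drains tiles with get_next_block().
//
//   using Scheduler = typename deep_gemm::SchedulerSelector<deep_gemm::GemmType::Normal,
//       /*SHAPE_N=*/4096, /*SHAPE_K=*/7168, /*BLOCK_M=*/64, /*BLOCK_N=*/128, /*BLOCK_K=*/128,
//       /*kNumGroups=*/1, /*kNumTMAMulticast=*/1>::type;
//
//   __global__ void example_gemm_kernel(Scheduler::Input input)
//   {
//       auto scheduler = Scheduler(input);
//       uint32_t m_block_idx, n_block_idx;
//       while (scheduler.get_next_block(m_block_idx, n_block_idx))
//       {
//           // Issue TMA loads and MMAs for tile (m_block_idx, n_block_idx) here.
//       }
//   }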