/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "cuda_hint.cuh"
#include "mha_stdheaders.cuh"
// NOTE(review): the three angle-bracket include targets below were missing from the
// reviewed text and have been reconstructed from the types this header uses
// (half / __nv_bfloat16 / __nv_fp8_e4m3) - confirm against the build.
#ifndef __CUDACC__
#include <cuda_runtime.h>
#endif
#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Thin wrappers around the warp-level `mma.sync.aligned` PTX instructions.
// Every wrapper accumulates in place: acc = a * b + acc, with fp32 accumulators.
// Each uint32_t in `a` / `b` is passed to the instruction as one 32-bit "r" register operand.
// See the PTX ISA documentation for the per-thread fragment layout and for the
// warp-convergence requirements of `.sync.aligned` instructions.

// Issues a single MMA instruction selected at compile time by InputElem:
//   half           -> mma.sync.aligned.m16n8k16 ... f16.f16
//   __nv_bfloat16  -> mma.sync.aligned.m16n8k16 ... bf16.bf16
//   __nv_fp8_e4m3  -> mma.sync.aligned.m16n8k32 ... e4m3.e4m3
//
// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N
// acc is used as both input and output.
//
// Operand order handed to the instruction:
//   C/D: acc[0][0], acc[0][1], acc[1][0], acc[1][1]
//   A  : a[0][0], a[0][1], a[1][0], a[1][1]
//   B  : b[0][0], b[1][0]
//
// NOTE(review): the template parameter list and the is_same_v arguments were missing
// from the reviewed text; the element types are inferred from the PTX mnemonics and
// the parameter name is a placeholder - confirm against callers.
// NOTE(review): bf16 and e4m3 MMA forms are only available on sufficiently recent
// GPU architectures / PTX versions - confirm the minimum SM in the PTX ISA docs.
template <typename InputElem>
__device__ inline void mma(float (&acc)[2][2], uint32_t const (&a)[2][2], uint32_t const (&b)[2][1])
{
    static_assert(mha::is_same_v<InputElem, half> || mha::is_same_v<InputElem, __nv_bfloat16>
            || mha::is_same_v<InputElem, __nv_fp8_e4m3>,
        "not implemented");
    if constexpr (mha::is_same_v<InputElem, half>)
    {
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
            "    {%0, %1, %2, %3}, \n"
            "    {%4, %5, %6, %7}, \n"
            "    {%8, %9}, \n"
            "    {%0, %1, %2, %3}; \n"
            : "+f"(acc[0][0]), "+f"(acc[0][1]), "+f"(acc[1][0]), "+f"(acc[1][1])
            : "r"(a[0][0]), "r"(a[0][1]), "r"(a[1][0]), "r"(a[1][1]), "r"(b[0][0]), "r"(b[1][0]));
    }
    else if constexpr (mha::is_same_v<InputElem, __nv_bfloat16>)
    {
        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n"
            "    {%0, %1, %2, %3}, \n"
            "    {%4, %5, %6, %7}, \n"
            "    {%8, %9}, \n"
            "    {%0, %1, %2, %3}; \n"
            : "+f"(acc[0][0]), "+f"(acc[0][1]), "+f"(acc[1][0]), "+f"(acc[1][1])
            : "r"(a[0][0]), "r"(a[0][1]), "r"(a[1][0]), "r"(a[1][1]), "r"(b[0][0]), "r"(b[1][0]));
    }
    else if constexpr (mha::is_same_v<InputElem, __nv_fp8_e4m3>)
    {
        asm("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \n"
            "    {%0, %1, %2, %3}, \n"
            "    {%4, %5, %6, %7}, \n"
            "    {%8, %9}, \n"
            "    {%0, %1, %2, %3}; \n"
            : "+f"(acc[0][0]), "+f"(acc[0][1]), "+f"(acc[1][0]), "+f"(acc[1][1])
            : "r"(a[0][0]), "r"(a[0][1]), "r"(a[1][0]), "r"(a[1][1]), "r"(b[0][0]), "r"(b[1][0]));
    }
    else
    {
        // The static_assert above already rejects every other type at compile time;
        // the trap is kept as a last-resort guard should that assertion ever be relaxed.
        asm volatile("trap;");
    }
}

// Issues one e4m3 MMA with the m16n8k16 shape: acc = a * b + acc.
// Compared with the e4m3 path of mma(), this form takes half as many input registers:
// two for A (a[0], a[1]) and one for B (b).
// NOTE(review): the k16 e4m3 shape may need a newer PTX ISA / toolkit than the k32
// shape - confirm in the PTX ISA docs.
__device__ inline void mmaF8_k16(float (&acc)[2][2], uint32_t const (&a)[2], uint32_t const b)
{
    asm("mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 \n"
        "    {%0, %1, %2, %3}, \n"
        "    {%4, %5}, \n"
        "    {%6}, \n"
        "    {%0, %1, %2, %3}; \n"
        : "+f"(acc[0][0]), "+f"(acc[0][1]), "+f"(acc[1][0]), "+f"(acc[1][1])
        : "r"(a[0]), "r"(a[1]), "r"(b));
}

// Takes the same operand shapes as mma() but issues two mmaF8_k16 instructions instead
// of one m16n8k32 instruction: step i consumes a[i] and b[i][0], and both steps
// accumulate into the same acc.
// NOTE(review): presumably a drop-in alternative to the single-instruction e4m3 path;
// accumulation order differs, so results may not be bit-identical - confirm.
__device__ inline void mmaF8_k32_2inst(float (&acc)[2][2], uint32_t const (&a)[2][2], uint32_t const (&b)[2][1])
{
    // Fixed two-trip loop: unroll explicitly so a[i] / b[i][0] are compile-time
    // selections rather than dynamically indexed array accesses.
#pragma unroll
    for (uint32_t i = 0; i < 2; i++)
    {
        mmaF8_k16(acc, a[i], b[i][0]);
    }
}

// Dimensions (M, N, K) of an MMA tile.
struct mmaShape
{
    uint32_t m;
    uint32_t n;
    uint32_t k;
};

// 16 x 8 x 32: the same M/N/K as the m16n8k32 instruction used by the e4m3 path of mma().
inline constexpr mmaShape qmmaShape = {16, 8, 32};