[https://nvbugs/5355128][fix] Add missing wgmma intrinsic for starcoder (#7643)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Pengbo Wang 2025-09-23 10:38:58 +08:00 committed by GitHub
parent 126cd707e3
commit 08cc7a041f
2 changed files with 464 additions and 4 deletions


@@ -23,12 +23,15 @@ namespace gmma
{
// cog template. Do code generation with: pip install cogapp; cog -r $filename
// clang-format off
/*[[[cog
import cog
reg_list = lambda beg,end: ", ".join([f"%{i}" for i in range(beg, end)])
acc_placeholder = lambda n: "{%s}" % reg_list(0, n//2)
acc_registers = lambda n: "\n , ".join([f'"+f"(acc[{i}][0][0]), "+f"(acc[{i}][0][1]), "+f"(acc[{i}][1][0]), "+f"(acc[{i}][1][1])' for i in range(n//8)])
ptx_eol = "\\n"
n_list = [8, 16, 32, 64, 128, 256]
for n in n_list:
acc_registers = lambda n: "\n , ".join([f'"+f"(acc[{i}][0][0]), "+f"(acc[{i}][0][1]), "+f"(acc[{i}][1][0]), "+f"(acc[{i}][1][1])' for i in range(n//8)])
ptx_eol = "\\n"
n_list = [8, 16, 24, 32, 64, 128, 256]
for n in n_list:
cog.outl(f'''
template<>
__device__ inline void mma_async_shmA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA,
@@ -139,6 +142,7 @@ const&>(descB)), "n"(false));
}}
''')
]]]*/
// clang-format on
template <>
__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 8, false, false>(
@@ -260,6 +264,72 @@ __device__ inline void mma_async_regA<__nv_fp8_e4m3, 16, false, false>(
}
}
template <>
__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 24, false, false>(
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_regA<__nv_fp8_e4m3, 24, false, false>(
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
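// The N=24 specializations above keep the accumulator convention used by every other
// generated size: per thread, acc is float[N/8][2][2], i.e. N/2 scalar f32 registers,
// which matches the twelve "d" operands (%0..%11) bound by the m64n24 wgmma
// instructions. Minimal compile-time sketch of that bookkeeping; accRegCount is an
// assumed helper name for illustration, not part of the generated interface:
template <int N>
constexpr int accRegCount()
{
    // acc[N/8][2][2] floats per thread -> N/2 scalar registers in the "d" list.
    return (N / 8) * 2 * 2;
}
static_assert(accRegCount<24>() == 12, "m64n24 wgmma binds 12 f32 accumulator registers per thread");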
template <>
__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 32, false, false>(
float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
@@ -1424,6 +1494,398 @@ __device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 1>(
}
}
template <>
__device__ inline void mma_async_shmA<half, 24, 0, 0>(
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 0, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 0, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_regA<half, 24, 0, 0>(
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 0>(
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 0, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 0, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 0>(
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_shmA<half, 24, 0, 1>(
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 0, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 0, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_regA<half, 24, 0, 1>(
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 1>(
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 0, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 0, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 1>(
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"{%12, %13, %14, %15},\n" // a
"%16,\n" // b-desc
"%17, 1, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_shmA<half, 24, 1, 0>(
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 1, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 1, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 0>(
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 1, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 1, 0;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_shmA<half, 24, 1, 1>(
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
template <>
__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 1>(
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
{
if (accHasVal)
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
}
else
{
asm volatile(
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
"%12,\n" // a-desc
"%13,\n" // b-desc
"%14, 1, 1, 1, 1;\n"
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
}
}
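// In the fp16/bf16 N=24 variants above, the last two template arguments appear as the
// final immediates of each wgmma (PTX imm-trans-a / imm-trans-b in the shmA form, and
// only imm-trans-b in the regA form, since A already sits in registers); the earlier
// e4m3 variants carry no transpose immediates. Illustrative call shape only, with
// acc/descA/descB assumed to be set up elsewhere as in the existing callers:
//
//   gmma::mma_async_shmA<__nv_bfloat16, 24, /*transA=*/0, /*transB=*/1>(acc, descA, descB, accHasVal);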
template <>
__device__ inline void mma_async_shmA<half, 32, 0, 0>(
float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)


@@ -226,8 +226,6 @@ examples/test_llama.py::test_llm_llama_1gpu_streaming_llm[ailab-deepseek-coder-6
test_e2e.py::test_openai_multinodes_chat_tp16pp1 SKIP (https://nvbugs/5112075)
examples/test_qwen.py::test_llm_hf_qwen_quantization_1gpu[qwen2_vl_7b_instruct-fp8-bfloat16] SKIP (https://nvbugs/5322488)
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5234043)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (https://nvbugs/5355128)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
full:L40S/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5375620)
full:L20/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5375620)
test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-405B-FP8-llama-3.1-model/Llama-3.1-405B-Instruct-FP8] SKIP (https://nvbugs/5380570)