mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5355128][fix] Add missing wgmma intrinsic for starcoder (#7643)
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
This commit is contained in:
parent
126cd707e3
commit
08cc7a041f
@ -23,12 +23,15 @@ namespace gmma
|
||||
{
|
||||
// cog template. Do code generation with: pip install cogapp; cog -r $filename
|
||||
|
||||
// clang-format off
|
||||
/*[[[cog
|
||||
import cog
|
||||
reg_list = lambda beg,end: ", ".join([f"%{i}" for i in range(beg, end)])
|
||||
acc_placeholder = lambda n: "{%s}" % reg_list(0, n//2)
|
||||
acc_registers = lambda n: "\n , ".join([f'"+f"(acc[{i}][0][0]), "+f"(acc[{i}][0][1]), "+f"(acc[{i}][1][0]),
|
||||
"+f"(acc[{i}][1][1])' for i in range(n//8)]) ptx_eol = "\\n" n_list = [8, 16, 32, 64, 128, 256] for n in n_list:
|
||||
acc_registers = lambda n: "\n , ".join([f'"+f"(acc[{i}][0][0]), "+f"(acc[{i}][0][1]), "+f"(acc[{i}][1][0]), "+f"(acc[{i}][1][1])' for i in range(n//8)])
|
||||
ptx_eol = "\\n"
|
||||
n_list = [8, 16, 24, 32, 64, 128, 256]
|
||||
for n in n_list:
|
||||
cog.outl(f'''
|
||||
template<>
|
||||
__device__ inline void mma_async_shmA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA,
|
||||
@ -139,6 +142,7 @@ const&>(descB)), "n"(false));
|
||||
}}
|
||||
''')
|
||||
]]]*/
|
||||
// clang-format on
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 8, false, false>(
|
||||
@ -260,6 +264,72 @@ __device__ inline void mma_async_regA<__nv_fp8_e4m3, 16, false, false>(
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 24, false, false>(
|
||||
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_regA<__nv_fp8_e4m3, 24, false, false>(
|
||||
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 32, false, false>(
|
||||
float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
@ -1424,6 +1494,398 @@ __device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 1>(
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<half, 24, 0, 0>(
|
||||
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 0, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 0, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_regA<half, 24, 0, 0>(
|
||||
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 0>(
|
||||
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 0, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 0, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 0>(
|
||||
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<half, 24, 0, 1>(
|
||||
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 0, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 0, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_regA<half, 24, 0, 1>(
|
||||
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 1>(
|
||||
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 0, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 0, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 1>(
|
||||
float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"{%12, %13, %14, %15},\n" // a
|
||||
"%16,\n" // b-desc
|
||||
"%17, 1, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
|
||||
"l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<half, 24, 1, 0>(
|
||||
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 1, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 1, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 0>(
|
||||
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 1, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 1, 0;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<half, 24, 1, 1>(
|
||||
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 1>(
|
||||
float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
{
|
||||
if (accHasVal)
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
|
||||
}
|
||||
else
|
||||
{
|
||||
asm volatile(
|
||||
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
|
||||
"%12,\n" // a-desc
|
||||
"%13,\n" // b-desc
|
||||
"%14, 1, 1, 1, 1;\n"
|
||||
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
|
||||
"+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
|
||||
"+f"(acc[2][1][0]), "+f"(acc[2][1][1])
|
||||
: "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void mma_async_shmA<half, 32, 0, 0>(
|
||||
float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
|
||||
|
||||
@ -226,8 +226,6 @@ examples/test_llama.py::test_llm_llama_1gpu_streaming_llm[ailab-deepseek-coder-6
|
||||
test_e2e.py::test_openai_multinodes_chat_tp16pp1 SKIP (https://nvbugs/5112075)
|
||||
examples/test_qwen.py::test_llm_hf_qwen_quantization_1gpu[qwen2_vl_7b_instruct-fp8-bfloat16] SKIP (https://nvbugs/5322488)
|
||||
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5234043)
|
||||
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (https://nvbugs/5355128)
|
||||
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
|
||||
full:L40S/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5375620)
|
||||
full:L20/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5375620)
|
||||
test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-405B-FP8-llama-3.1-model/Llama-3.1-405B-Instruct-FP8] SKIP (https://nvbugs/5380570)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user