[https://nvbugs/5381276][fix] fix warning for fused_a_gemm (#6402)

Signed-off-by: yunruis <205571022+yunruis@users.noreply.github.com>
This commit is contained in:
yunruis 2025-08-01 21:37:21 +08:00 committed by GitHub
parent 7447d6ed85
commit a20ab5cbdb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -35,7 +35,6 @@ namespace tensorrt_llm::kernels::dsv3MinLatencyKernels
__device__ void hmma_16_8_16_f32acc_bf16ab(
float (&d_reg)[4], const bf16_t (&a_reg)[8], const bf16_t (&b_reg)[4], float const (&c_reg)[4])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
uint32_t a0 = *reinterpret_cast<uint32_t const*>(a_reg + 0);
uint32_t a1 = *reinterpret_cast<uint32_t const*>(a_reg + 2);
uint32_t a2 = *reinterpret_cast<uint32_t const*>(a_reg + 4);
@ -51,7 +50,6 @@ __device__ void hmma_16_8_16_f32acc_bf16ab(
: "=f"(d_reg[0]), "=f"(d_reg[1]), "=f"(d_reg[2]), "=f"(d_reg[3])
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "f"(d_reg[0]), "f"(d_reg[1]), "f"(d_reg[2]),
"f"(d_reg[3]));
#endif
}
extern "C"
@ -72,11 +70,9 @@ __device__ void ldgsts_128(void const* gPtr, void* sPtr, uint32_t pred)
__device__ void ldsm_x4(void* smem_ptr, uint32_t* reg_ptr)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(reg_ptr[0]), "=r"(reg_ptr[1]), "=r"(reg_ptr[2]), "=r"(reg_ptr[3])
: "r"(__nvvm_get_smem_pointer(smem_ptr)));
#endif
}
template <class Type>
@ -90,20 +86,18 @@ __device__ int apply_swizzle_343_on_elem_row_col(int row_idx_, int col_idx_)
return *reinterpret_cast<int*>(&col_idx);
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
__device__ void initialize_barrier(uint64_t* smem_barrier, // 64 bits user-manged barrier in smem
int thread_count = 1) // Thread count expected to arrive/wait on this barrier
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_barrier);
asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;\n" ::"r"(smem_int_ptr), "r"(thread_count));
#endif
}
// Barrier wait
__device__ void wait_barrier(uint64_t* smem_barrier, // 64 bits user-manged barrier in smem
int phase_bit) // Current phase bit the barrier waiting to flip
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_barrier);
asm volatile(
"{\n"
@ -115,12 +109,10 @@ __device__ void wait_barrier(uint64_t* smem_barrier, // 64 bits user-manged barr
"DONE:\n"
"}\n" ::"r"(smem_int_ptr),
"r"(phase_bit));
#endif
}
__device__ bool try_wait_barrier(uint64_t* smem_ptr, int phase_bit)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
uint32_t wait_complete;
uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_ptr);
asm volatile(
@ -132,29 +124,25 @@ __device__ bool try_wait_barrier(uint64_t* smem_ptr, int phase_bit)
: "=r"(wait_complete)
: "r"(smem_int_ptr), "r"(phase_bit));
return static_cast<bool>(wait_complete);
#endif
}
// Barrier arrive
__device__ void arrive_barrier(uint64_t* smem_barrier) // 64 bits user-manged barrier in smem
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_barrier);
asm volatile(
"{\n"
".reg .b64 state; \n"
"mbarrier.arrive.shared::cta.b64 state, [%0];\n"
"}\n" ::"r"(smem_int_ptr));
#endif
}
__device__ void ldgsts_arrive(uint64_t* smem_barrier)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_barrier);
asm volatile("cp.async.mbarrier.arrive.noinc.shared.b64 [%0];" : : "r"(smem_int_ptr));
#endif
}
#endif
template <int gemm_k, int tile_m, int tile_k, int stage_cnt>
struct GmemLoaderA