[None][fix] fix tinygemm accuracy (#11411)

Signed-off-by: Bo Deng <deemod@nvidia.com>
This commit is contained in:
Bo Deng 2026-02-10 18:09:30 +08:00 committed by GitHub
parent adc0d82500
commit be88fe33be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -236,6 +236,7 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
if (!weight_warp)
{
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
for (int ki = 0; ki < K_LOOPS_DMA; ki++)
@ -301,6 +302,17 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
phase ^= 1;
}
}
// Wait for pending loads to be consumed before exiting, to avoid race
for (int i = 0; i < (STAGES / 4) - 1; i++)
{
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
stage += 4;
if (stage >= STAGES)
{
stage = warp_id % 4;
phase ^= 1;
}
}
}
// Compute threads
else if (warp_id < 4)
@ -410,8 +422,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
__syncthreads();
cudaTriggerProgrammaticLaunchCompletion();
if (warp_id == 0)
{
@ -443,7 +453,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
profile[blockIdx.x].complete = gclock64();
}
__syncthreads();
}
#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}