From be88fe33be8610af6589e8cd047300f96038f868 Mon Sep 17 00:00:00 2001 From: Bo Deng Date: Tue, 10 Feb 2026 18:09:30 +0800 Subject: [PATCH] [None][fix] fix tinygemm accuracy (#11411) Signed-off-by: Bo Deng --- .../kernels/tinygemm2/tinygemm2_kernel.cuh | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh b/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh index 6c1d72c353..ca95f6849b 100644 --- a/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh +++ b/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh @@ -236,6 +236,7 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output, if (!weight_warp) { cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); } for (int ki = 0; ki < K_LOOPS_DMA; ki++) @@ -301,6 +302,17 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output, phase ^= 1; } } + // Wait for pending loads to be consumed before exiting, to avoid race + for (int i = 0; i < (STAGES / 4) - 1; i++) + { + bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); + stage += 4; + if (stage >= STAGES) + { + stage = warp_id % 4; + phase ^= 1; + } + } } // Compute threads else if (warp_id < 4) @@ -410,8 +422,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output, __syncthreads(); - cudaTriggerProgrammaticLaunchCompletion(); - if (warp_id == 0) { @@ -443,7 +453,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output, if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0) profile[blockIdx.x].complete = gclock64(); } - __syncthreads(); } #endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) }