mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)

[None][fix] fix tinygemm accuracy (#11411)

Signed-off-by: Bo Deng <deemod@nvidia.com>

commit be88fe33be (parent adc0d82500)
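In brief, the patch makes two changes to tinygemm_kernel: it moves cudaTriggerProgrammaticLaunchCompletion() out of the kernel epilogue to just after cudaGridDependencySynchronize() in the non-weight warps, and it adds a drain loop so the DMA warps wait for their pending loads to be consumed before exiting ("to avoid race", per the new comment).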
@@ -236,6 +236,7 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
         if (!weight_warp)
         {
             cudaGridDependencySynchronize();
+            cudaTriggerProgrammaticLaunchCompletion();
         }
 
         for (int ki = 0; ki < K_LOOPS_DMA; ki++)
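For background, cudaGridDependencySynchronize() and cudaTriggerProgrammaticLaunchCompletion() are the two halves of the programmatic dependent launch (PDL) handshake on Hopper: the first blocks until the upstream grid's writes are visible, the second releases the downstream dependent grid. A minimal sketch of the pattern; the kernel name, buffers, and launch helper (pdl_consumer, launch_pdl) are illustrative, not from this patch:

#include <cuda_runtime.h>

// Hypothetical PDL consumer kernel: waits for the producer grid's writes,
// then signals early that the next dependent grid may begin launching.
__global__ void pdl_consumer(const float* in, float* out, int n)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    cudaGridDependencySynchronize();            // producer's writes now visible
    cudaTriggerProgrammaticLaunchCompletion();  // unblock the dependent launch
#endif
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = in[i] * 2.0f;
}

// Host side: opt the launch into PDL so it can overlap with its predecessor.
void launch_pdl(const float* in, float* out, int n, cudaStream_t stream)
{
    cudaLaunchAttribute attr = {};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1;

    cudaLaunchConfig_t cfg = {};
    cfg.gridDim = dim3((n + 255) / 256);
    cfg.blockDim = dim3(256);
    cfg.stream = stream;
    cfg.attrs = &attr;
    cfg.numAttrs = 1;
    cudaLaunchKernelEx(&cfg, pdl_consumer, in, out, n);
}

Triggering completion early is what makes the added drain loop below necessary: once the dependent grid is released, any buffers this kernel is still reading can start being overwritten.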
@@ -301,6 +302,17 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
                 phase ^= 1;
             }
         }
+        // Wait for pending loads to be consumed before exiting, to avoid race
+        for (int i = 0; i < (STAGES / 4) - 1; i++)
+        {
+            bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
+            stage += 4;
+            if (stage >= STAGES)
+            {
+                stage = warp_id % 4;
+                phase ^= 1;
+            }
+        }
     }
     // Compute threads
     else if (warp_id < 4)
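This drain loop is the accuracy fix: a DMA warp could previously exit while compute warps were still consuming stages it had filled, so each of the 4 DMA warps now waits on the "data consumed" barrier of every stage it owns, flipping mbarrier phase parity as the stage index wraps around the ring. A self-contained sketch of the same parity-wait idiom; bar_wait is TensorRT-LLM's helper, and mbarrier_wait_parity / drain_owned_stages below are illustrative stand-ins built on sm_90 PTX:

#include <cstdint>

// Illustrative stand-in for bar_wait(): spin on an sm_90 mbarrier until it
// completes the phase with the given parity bit. 'bar' is a shared-memory
// address already converted with __cvta_generic_to_shared().
__device__ void mbarrier_wait_parity(uint32_t bar, uint32_t parity)
{
    uint32_t done = 0;
    while (!done)
        asm volatile(
            "{\n"
            " .reg .pred p;\n"
            " mbarrier.try_wait.parity.shared.b64 p, [%1], %2;\n"
            " selp.u32 %0, 1, 0, p;\n"
            "}\n"
            : "=r"(done)
            : "r"(bar), "r"(parity));
}

// Drain loop mirroring the patch: stages are dealt round-robin to the 4 DMA
// warps (warp w owns stages w, w+4, w+8, ...), and the phase parity flips
// each time the stage index wraps past the end of the ring.
// bar_data_consumed must point to mbarriers in __shared__ memory.
template <int STAGES>
__device__ void drain_owned_stages(uint64_t* bar_data_consumed, int stage, int phase, int warp_id)
{
    for (int i = 0; i < (STAGES / 4) - 1; i++)
    {
        mbarrier_wait_parity((uint32_t) __cvta_generic_to_shared(&bar_data_consumed[stage]),
                             phase ^ 1);
        stage += 4;
        if (stage >= STAGES)
        {
            stage = warp_id % 4;
            phase ^= 1;
        }
    }
}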
@@ -410,8 +422,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
 
     __syncthreads();
 
-    cudaTriggerProgrammaticLaunchCompletion();
-
     if (warp_id == 0)
     {
 
@@ -443,7 +453,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
         if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
             profile[blockIdx.x].complete = gclock64();
     }
-    __syncthreads();
 }
 #endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
 }
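gclock64() in the profiling epilogue is TensorRT-LLM's device timestamp helper. A plausible stand-in (an assumption, not the repo's definition) reads the GPU's 64-bit global nanosecond timer, which is what makes per-block .complete stamps comparable across SMs:

#include <cstdint>

// Hypothetical stand-in for gclock64(): read the 64-bit %globaltimer
// register (nanosecond-resolution, globally consistent across the GPU).
__device__ uint64_t gclock64_sketch()
{
    uint64_t t;
    asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t));
    return t;
}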