Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[https://nvbugs/5769425][fix] add syncthreads for tinygemm to resolve intermittent accuracy problem (#10873)
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
commit 6c2ecad2fe
parent 8fd22ac72d
@@ -778,6 +778,7 @@ quantize_with_block_size(

     // Get the global scaling factor, which will be applied to the SF.
     // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)).
+    // This value is prepared by model, no need to be protected by ACKBULK
     float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

     // Is it swizzled layout?
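For context on the new comment: under programmatic dependent launch (PDL), a kernel can start running before its producer grid has finished, so reads of producer-written data must be ordered behind cudaGridDependencySynchronize(). SFScale is prepared by the model ahead of time, which is why the hunk above says it needs no such protection. A minimal CUDA sketch of that distinction (kernel and buffer names are hypothetical, not from TensorRT-LLM):

#include <cuda_runtime.h>

// Hypothetical dependent kernel, sketching which reads need PDL protection.
__global__ void dependent_kernel(
    float const* prevGridOutput, float const* hostPreparedScale, float* out)
{
    // Like SFScale above: written by the host/model before either grid was
    // launched, so it is stable and safe to read immediately.
    float const scale = hostPreparedScale[0];

    // Unlike SFScale: produced by the predecessor grid, so wait until that
    // grid's writes are visible before reading.
    cudaGridDependencySynchronize();
    out[threadIdx.x] = prevGridOutput[threadIdx.x] * scale;
}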
@@ -410,6 +410,8 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,

+    __syncthreads();
+    cudaTriggerProgrammaticLaunchCompletion();

     if (warp_id == 0)
     {
@@ -440,10 +442,8 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,

         if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
             profile[blockIdx.x].complete = gclock64();

-        if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
-            cudaTriggerProgrammaticLaunchCompletion();
     }
     __syncthreads();
 }
 #endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
 }
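Taken together, the two hunks above move cudaTriggerProgrammaticLaunchCompletion() from a single guarded thread at the end of the kernel to a point right after a block-wide barrier. A minimal sketch of the corrected ordering, assuming the intermittent accuracy problem came from the dependent grid launching while some warps were still using data it may overwrite (kernel and buffer names are hypothetical):

#include <cuda_runtime.h>

// Hypothetical producer kernel, sketching the barrier-then-trigger pattern.
__global__ void producer_kernel(float* workspace)
{
    // ... all warps read from / compute with `workspace` here ...

    // Barrier first: guarantee no thread in this block is still using data
    // that the dependent grid could clobber once it is allowed to launch.
    __syncthreads();

    // Only then signal that dependent grids may be scheduled; the barrier
    // ensures no warp fires the signal while another is still working.
    cudaTriggerProgrammaticLaunchCompletion();

    // Any remaining epilogue must not touch memory the dependent grid reuses.
}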