diff --git a/cpp/tensorrt_llm/kernels/quantization.cuh b/cpp/tensorrt_llm/kernels/quantization.cuh index 6589cc67d5..89b96b288b 100644 --- a/cpp/tensorrt_llm/kernels/quantization.cuh +++ b/cpp/tensorrt_llm/kernels/quantization.cuh @@ -778,6 +778,7 @@ quantize_with_block_size( // Get the global scaling factor, which will be applied to the SF. // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)). + // This value is prepared by model, no need to be protected by ACKBULK float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; // Is it swizzled layout? diff --git a/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh b/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh index 7f60e787bf..6c1d72c353 100644 --- a/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh +++ b/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh @@ -410,6 +410,8 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output, __syncthreads(); + cudaTriggerProgrammaticLaunchCompletion(); + if (warp_id == 0) { @@ -440,10 +442,8 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output, if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0) profile[blockIdx.x].complete = gclock64(); - - if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) - cudaTriggerProgrammaticLaunchCompletion(); } + __syncthreads(); } #endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) }