Signed-off-by: Bo Deng <deemod@nvidia.com>
This commit is contained in:
Bo Deng 2026-01-27 00:02:22 +08:00 committed by GitHub
parent a30d3b7419
commit 6c694f85ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 4 deletions

View File

@ -177,7 +177,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
@ -212,7 +212,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config,
@ -387,7 +387,7 @@ void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const*
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(

View File

@ -777,6 +777,7 @@ quantize_with_block_size(
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)).
// This value is prepared by model, no need to be protected by ACKBULK
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
// Is it swizzled layout?

View File

@ -236,7 +236,6 @@ __global__ __launch_bounds__(384, 1) void kernel(__nv_bfloat16* output, __nv_bfl
if (!weight_warp)
{
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
for (int ki = 0; ki < K_LOOPS_DMA; ki++)
@ -411,6 +410,8 @@ __global__ __launch_bounds__(384, 1) void kernel(__nv_bfloat16* output, __nv_bfl
__syncthreads();
cudaTriggerProgrammaticLaunchCompletion();
if (warp_id == 0)
{
@ -442,6 +443,7 @@ __global__ __launch_bounds__(384, 1) void kernel(__nv_bfloat16* output, __nv_bfl
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
profile[blockIdx.x].complete = gclock64();
}
__syncthreads();
}
#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}