From d3d951d837b77ea6ac66bafc77aba86a4e24c9cd Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Fri, 6 Feb 2026 00:28:29 +0800
Subject: [PATCH] [None][fix] Fix amax to avoid NaN issue in fp8_blockscale_gemm_kernel. (#11256)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 .../fp8_blockscale_gemm_kernel.cuh | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh
index 6d253e25c6..e3dbcbae93 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh
@@ -32,6 +32,7 @@
 #include "fp8_blockscale_tma_utils.cuh"
 #include "sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh"
 #include "tensorrt_llm/common/config.h"
+#include "tensorrt_llm/common/cudaTypeUtils.cuh"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/deep_gemm/fp8_gemm.cuh"
@@ -225,7 +226,8 @@ __global__ void scale_1x128_kernel(
     }
 
     InputType amax = find_max_elem_in_warp(input_amax);
-    ScaleType quant_scale = amax != InputType(0.f) ? 448.f / ScaleType(amax) : 1.f;
+    amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
+    ScaleType quant_scale = 448.f / ScaleType(amax);
     ScaleType dequant_scale;
 
     if constexpr (USE_UE8M0)
@@ -370,13 +372,14 @@ __global__ void scale_1x128_kernel(OutputType* output, float* scales, InputType
 
         InputType amax = kernel_utils::warpReduceSum(max(max(fabs(float(input_frag[0])), fabs(float(input_frag[1]))),
             max(fabs(float(input_frag[2])), fabs(float(input_frag[3])))));
+        amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
 
         // Half seems to be slower, probably because we need float values below
         // anyway.
         // InputType amax = kernel_utils::warpReduceSum(
         //     __hmax(__hmax(__habs(input_frag[0]), __habs(input_frag[1])),
         //         __hmax(__habs(input_frag[2]), __habs(input_frag[3]))));
-        float scale = amax != InputType(0.f) ? 448.f / float(amax) : 1.f;
+        float scale = 448.f / float(amax);
 
         if (kernel_utils::elect_one_sync(lane_id))
         {
@@ -449,7 +452,8 @@ __global__ void scale_1x128_reshape_kernel(
     }
 
     InputType amax = find_max_elem_in_warp(input_amax);
-    ScaleType scale = amax != InputType(0.f) ? 448.f / ScaleType(amax) : 1.f;
+    amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
+    ScaleType scale = 448.f / ScaleType(amax);
 
     if (lane_id == 0)
     {
@@ -522,7 +526,8 @@ __global__ void scale_128x128_kernel(
     }
 
     InputType amax = find_max_elem_in_warp(input_amax);
-    ScaleType scale = amax != InputType(0.f) ? 448.f / ScaleType(amax) : 1.f;
+    amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
+    ScaleType scale = 448.f / ScaleType(amax);
 
     if (lane_id == 0)
     {
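
The patch drops the `amax != 0` special case and instead floors `amax` at 1e-10f before computing the 448 / amax quantization scale, so the scale stays finite even for an all-zero block. Below is a minimal host-side sketch of that pattern, illustrative only: `compute_quant_scale` is a made-up helper name, and it assumes `fmaxf` mirrors the NaN-discarding behavior relied on from `tensorrt_llm::common::cuda_max`.

#include <cassert>
#include <cmath>
#include <cstdio>

// Sketch of the clamping pattern from the patch (hypothetical helper, not
// TensorRT-LLM code). FP8 E4M3 tops out at 448, so the per-block quantization
// scale is 448 / amax. Flooring amax keeps the division finite for an
// all-zero block; fmaxf also returns the non-NaN operand, so a NaN amax is
// replaced by the floor instead of propagating, unlike the old
// `amax != 0 ? 448 / amax : 1` branch.
static float compute_quant_scale(float amax)
{
    amax = fmaxf(amax, 1e-10f);
    return 448.f / amax;
}

int main()
{
    assert(std::isfinite(compute_quant_scale(0.f)));  // all-zero block
    assert(std::isfinite(compute_quant_scale(NAN)));  // NaN amax no longer propagates
    assert(compute_quant_scale(448.f) == 1.f);        // amax at the FP8 E4M3 max
    std::printf("scale for zero block: %g\n", compute_quant_scale(0.f));
    return 0;
}

The same floor is applied in all four scaling kernels touched by the patch (scale_1x128_kernel, its float-scale variant, scale_1x128_reshape_kernel, and scale_128x128_kernel).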