[None][fix] Fix amax to avoid NaN issue in fp8_blockscale_gemm_kernel. (#11256)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Yuxian Qiu 2026-02-06 00:28:29 +08:00 committed by GitHub
parent 7d235cfb23
commit d3d951d837

@@ -32,6 +32,7 @@
 #include "fp8_blockscale_tma_utils.cuh"
 #include "sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh"
 #include "tensorrt_llm/common/config.h"
+#include "tensorrt_llm/common/cudaTypeUtils.cuh"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/deep_gemm/fp8_gemm.cuh"
@@ -225,7 +226,8 @@ __global__ void scale_1x128_kernel(
 }
 InputType amax = find_max_elem_in_warp(input_amax);
-ScaleType quant_scale = amax != InputType(0.f) ? 448.f / ScaleType(amax) : 1.f;
+amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
+ScaleType quant_scale = 448.f / ScaleType(amax);
 ScaleType dequant_scale;
 if constexpr (USE_UE8M0)
@@ -370,13 +372,14 @@ __global__ void scale_1x128_kernel(OutputType* output, float* scales, InputType
 InputType amax = kernel_utils::warpReduceSum(max(max(fabs(float(input_frag[0])), fabs(float(input_frag[1]))),
 max(fabs(float(input_frag[2])), fabs(float(input_frag[3])))));
+amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
 // Half seems to be slower, probably because we need float values below
 // anyway. InputType amax = kernel_utils::warpReduceSum(
 // __hmax(__hmax(__habs(input_frag[0]), __habs(input_frag[1])),
 // __hmax(__habs(input_frag[2]), __habs(input_frag[3]))));
-float scale = amax != InputType(0.f) ? 448.f / float(amax) : 1.f;
+float scale = 448.f / float(amax);
 if (kernel_utils::elect_one_sync(lane_id))
 {
@@ -449,7 +452,8 @@ __global__ void scale_1x128_reshape_kernel(
 }
 InputType amax = find_max_elem_in_warp(input_amax);
-ScaleType scale = amax != InputType(0.f) ? 448.f / ScaleType(amax) : 1.f;
+amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
+ScaleType scale = 448.f / ScaleType(amax);
 if (lane_id == 0)
 {
@@ -522,7 +526,8 @@ __global__ void scale_128x128_kernel(
 }
 InputType amax = find_max_elem_in_warp(input_amax);
-ScaleType scale = amax != InputType(0.f) ? 448.f / ScaleType(amax) : 1.f;
+amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
+ScaleType scale = 448.f / ScaleType(amax);
 if (lane_id == 0)
 {