[None][fix] Fix amax to avoid NaN issue in fp8_blockscale_gemm_kernel. (#11256)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
parent 7d235cfb23
commit d3d951d837
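Why a clamp instead of the old zero check: a plausible reading (my inference; the commit message only names the NaN) is that amax can be nonzero yet small enough that 448.f / amax overflows to +inf, after which quantizing a zero element computes 0 * inf, which is NaN. A minimal host-side sketch of both paths, assuming plain float for InputType:

// Sketch only, not repo code: demonstrates the failure mode this commit
// guards against, and the effect of the new clamp.
#include <algorithm>
#include <cstdio>

int main()
{
    float amax = 1e-42f; // denormal: nonzero, so it passes the old guard

    float old_scale = (amax != 0.f) ? 448.f / amax : 1.f;           // +inf
    printf("old: scale=%g, 0*scale=%g\n", old_scale, 0.f * old_scale); // nan

    // New behavior: clamp amax away from zero first, as the kernels now
    // do with tensorrt_llm::common::cuda_max(amax, InputType(1e-10f)).
    float new_scale = 448.f / std::max(amax, 1e-10f);               // finite
    printf("new: scale=%g, 0*scale=%g\n", new_scale, 0.f * new_scale); // 0
    return 0;
}

With the clamp, the quant scale is bounded above by 448.f / 1e-10f, so every product stays finite.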
@@ -32,6 +32,7 @@
 #include "fp8_blockscale_tma_utils.cuh"
 #include "sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh"
 #include "tensorrt_llm/common/config.h"
+#include "tensorrt_llm/common/cudaTypeUtils.cuh"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/deep_gemm/fp8_gemm.cuh"
@@ -225,7 +226,8 @@ __global__ void scale_1x128_kernel(
     }

     InputType amax = find_max_elem_in_warp(input_amax);
-    ScaleType quant_scale = amax != InputType(0.f) ? 448.f / ScaleType(amax) : 1.f;
+    amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
+    ScaleType quant_scale = 448.f / ScaleType(amax);
     ScaleType dequant_scale;

     if constexpr (USE_UE8M0)
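The hunk ends at the USE_UE8M0 branch, whose body is outside the diff context. For orientation only, a hypothetical sketch of what a UE8M0 dequant path typically computes (the helper below is my illustration, not repo code): the dequant scale is rounded up to a power of two so it fits an 8-bit-exponent, zero-mantissa encoding.

// Hypothetical sketch, not the repo's USE_UE8M0 branch.
__device__ float ue8m0_dequant_scale(float quant_scale)
{
    float dequant = 1.f / quant_scale;   // plain reciprocal dequant scale
    return exp2f(ceilf(log2f(dequant))); // round up to a power of two
}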
@@ -370,13 +372,14 @@ __global__ void scale_1x128_kernel(OutputType* output, float* scales, InputType

     InputType amax = kernel_utils::warpReduceSum(max(max(fabs(float(input_frag[0])), fabs(float(input_frag[1]))),
         max(fabs(float(input_frag[2])), fabs(float(input_frag[3])))));
+    amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));

     // Half seems to be slower, probably because we need float values below
     // anyway. InputType amax = kernel_utils::warpReduceSum(
     //     __hmax(__hmax(__habs(input_frag[0]), __habs(input_frag[1])),
     //         __hmax(__habs(input_frag[2]), __habs(input_frag[3]))));

-    float scale = amax != InputType(0.f) ? 448.f / float(amax) : 1.f;
+    float scale = 448.f / float(amax);

     if (kernel_utils::elect_one_sync(lane_id))
     {
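Both scale_1x128_kernel variants first reduce per-lane absolute maxima across the warp (find_max_elem_in_warp above, kernel_utils::warpReduceSum over per-thread maxes here). A minimal sketch of such a butterfly reduction, assuming a full 32-lane warp; the helper name is mine, not the repo's:

// Sketch of a butterfly warp reduction: after the loop, every lane
// holds the warp-wide maximum of the values passed in.
__device__ float warp_reduce_max(float v)
{
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, offset));
    }
    return v;
}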
@@ -449,7 +452,8 @@ __global__ void scale_1x128_reshape_kernel(
     }

     InputType amax = find_max_elem_in_warp(input_amax);
-    ScaleType scale = amax != InputType(0.f) ? 448.f / ScaleType(amax) : 1.f;
+    amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
+    ScaleType scale = 448.f / ScaleType(amax);

     if (lane_id == 0)
     {
@@ -522,7 +526,8 @@ __global__ void scale_128x128_kernel(
     }

     InputType amax = find_max_elem_in_warp(input_amax);
-    ScaleType scale = amax != InputType(0.f) ? 448.f / ScaleType(amax) : 1.f;
+    amax = tensorrt_llm::common::cuda_max(amax, InputType(1e-10f));
+    ScaleType scale = 448.f / ScaleType(amax);

     if (lane_id == 0)
     {
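All four hunks apply the same two-line pattern: clamp amax, then divide unconditionally. For context on where the scales flow, a hedged reconstruction (quantize_block and its signature are hypothetical) of the per-block quantization these scale kernels implement: values are scaled into the FP8 E4M3 range, whose largest finite value is 448, and the reciprocal is stored for dequantization in the GEMM.

// Hedged sketch, not repo code: per-block FP8 quantization around the fix.
#include <cuda_fp8.h>

__device__ void quantize_block(
    float const* in, __nv_fp8_e4m3* out, float* scales, int block_idx, float amax)
{
    amax = fmaxf(amax, 1e-10f);       // the fix: keep the scale finite
    float quant_scale = 448.f / amax; // map the block into E4M3 range
    for (int i = 0; i < 128; ++i)
    {
        out[i] = __nv_fp8_e4m3(in[i] * quant_scale);
    }
    scales[block_idx] = amax / 448.f; // dequant scale for this block
}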