perf: Optimize quantization kernels used in DeepSeek on Hopper (#3466)

Signed-off-by: jiahanc <jiahanc@nvidia.com>
This commit is contained in:
jiahanc 2025-04-15 02:49:57 -07:00 committed by GitHub
parent 5cfa927132
commit 1d3b98b920
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -979,7 +979,7 @@ __global__ void scale_1x128_kernel(
size_t scales_along_dim_x = div_up(dim_x, 128);
size_t scales_along_dim_y = div_up(dim_y, 1);
size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
warp_idx < scales_along_dim_x * scales_along_dim_y; warp_idx += gridDim.x * blockDim.x / 32)
{
@ -988,21 +988,34 @@ __global__ void scale_1x128_kernel(
InputType const* input_line = input + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
InputType input_amax = InputType(0);
int lane_id = threadIdx.x % 32;
InputType input_frag[4] = {0};
// Each thread reads 2 elements from input_line
int lane_id = threadIdx.x % 32 * 2;
for (int i = 0; i < 4; i++)
Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
#pragma unroll
for (int i = 0; i < 2; i++)
{
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{
break;
}
else
{
input_frag[i] = input_line[lane_id];
input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i]))));
input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
}
input_line += 64;
}
#pragma unroll
for (int i = 0; i < 2; i++)
{
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{
break;
}
else
{
input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
}
input_line += 32;
}
InputType amax = find_max_elem_in_warp(input_amax);
@ -1014,18 +1027,21 @@ __global__ void scale_1x128_kernel(
}
OutputType* output_line = output + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
for (int i = 0; i < 4; i++)
#pragma unroll
for (int i = 0; i < 2; i++)
{
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{
break;
}
else
{
ScaleType value = ScaleType(input_frag[i]) * scale;
output_line[lane_id] = OutputType(value);
ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
output_line[lane_id] = OutputType(value_1);
output_line[lane_id + 1] = OutputType(value_2);
}
output_line += 32;
output_line += 64;
}
}
#endif
@ -1245,7 +1261,7 @@ __global__ void scale_1x128_reshape_kernel(
size_t scales_along_dim_y = div_up(dim_y, 1);
size_t scales_along_dim_h = div_up(dim_h, 1);
size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
warp_idx < scales_along_dim_x * scales_along_dim_y * scales_along_dim_h;
warp_idx += gridDim.x * blockDim.x / 32)
@ -1257,21 +1273,33 @@ __global__ void scale_1x128_reshape_kernel(
InputType const* input_line
= input + (size_t) scales_idx_y * stride_x * dim_h + (size_t) scales_idx_h * stride_x + scales_idx_x * 128;
InputType input_amax = InputType(0);
int lane_id = threadIdx.x % 32;
InputType input_frag[4] = {0};
int lane_id = threadIdx.x % 32 * 2;
for (int i = 0; i < 4; i++)
Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
#pragma unroll
for (int i = 0; i < 2; i++)
{
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{
break;
}
else
{
input_frag[i] = input_line[lane_id];
input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i]))));
input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
}
input_line += 64;
}
#pragma unroll
for (int i = 0; i < 2; i++)
{
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{
break;
}
else
{
input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
}
input_line += 32;
}
InputType amax = find_max_elem_in_warp(input_amax);
@ -1286,18 +1314,21 @@ __global__ void scale_1x128_reshape_kernel(
OutputType* output_line
= output + (size_t) scales_idx_h * dim_y * dim_x + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
for (int i = 0; i < 4; i++)
#pragma unroll
for (int i = 0; i < 2; i++)
{
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{
break;
}
else
{
ScaleType value = ScaleType(input_frag[i]) * scale;
output_line[lane_id] = OutputType(value);
ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
output_line[lane_id] = OutputType(value_1);
output_line[lane_id + 1] = OutputType(value_2);
}
output_line += 32;
output_line += 64;
}
}
#endif
@ -1425,7 +1456,7 @@ void fp8_1x128_cs(
{
kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
}
scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(mat_quant, scales, mat, shape_x, shape_y);
scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(mat_quant, scales, mat, shape_x, shape_y);
}
void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, int shape_h,
@ -1435,7 +1466,7 @@ void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16
{
kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
}
scale_1x128_reshape_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(
scale_1x128_reshape_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(
mat_quant, scales, mat, shape_x, shape_h, shape_y, stride_x);
}
@ -1642,7 +1673,7 @@ void fp8_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, int ld_a
if (internal_quantize_a)
{
scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m);
scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m);
}
if (internal_quantize_b)
{
@ -1792,7 +1823,7 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
}
if (internal_quantize_a)
{
scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(
scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(
fp8_mat_a, scales_a, mat_a, shape_k, shape_m * num_problems);
}
if (internal_quantize_b)