perf: Optimize quantization kernels used in DeepSeek on Hopper (#3466)

Signed-off-by: jiahanc <jiahanc@nvidia.com>
This commit is contained in:
jiahanc 2025-04-15 02:49:57 -07:00 committed by GitHub
parent 5cfa927132
commit 1d3b98b920
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -979,7 +979,7 @@ __global__ void scale_1x128_kernel(
size_t scales_along_dim_x = div_up(dim_x, 128); size_t scales_along_dim_x = div_up(dim_x, 128);
size_t scales_along_dim_y = div_up(dim_y, 1); size_t scales_along_dim_y = div_up(dim_y, 1);
size_t stride_scale_dim_y = div_up(dim_y, 4) * 4; size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32; for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
warp_idx < scales_along_dim_x * scales_along_dim_y; warp_idx += gridDim.x * blockDim.x / 32) warp_idx < scales_along_dim_x * scales_along_dim_y; warp_idx += gridDim.x * blockDim.x / 32)
{ {
@ -988,21 +988,34 @@ __global__ void scale_1x128_kernel(
InputType const* input_line = input + (size_t) scales_idx_y * dim_x + scales_idx_x * 128; InputType const* input_line = input + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
InputType input_amax = InputType(0); InputType input_amax = InputType(0);
int lane_id = threadIdx.x % 32; // Each thread reads 2 elements from input_line
InputType input_frag[4] = {0}; int lane_id = threadIdx.x % 32 * 2;
for (int i = 0; i < 4; i++) Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
#pragma unroll
for (int i = 0; i < 2; i++)
{ {
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x) if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{ {
break; break;
} }
else else
{ {
input_frag[i] = input_line[lane_id]; input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i])))); }
input_line += 64;
}
#pragma unroll
for (int i = 0; i < 2; i++)
{
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{
break;
}
else
{
input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
} }
input_line += 32;
} }
InputType amax = find_max_elem_in_warp(input_amax); InputType amax = find_max_elem_in_warp(input_amax);
@ -1014,18 +1027,21 @@ __global__ void scale_1x128_kernel(
} }
OutputType* output_line = output + (size_t) scales_idx_y * dim_x + scales_idx_x * 128; OutputType* output_line = output + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
for (int i = 0; i < 4; i++) #pragma unroll
for (int i = 0; i < 2; i++)
{ {
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x) if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{ {
break; break;
} }
else else
{ {
ScaleType value = ScaleType(input_frag[i]) * scale; ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
output_line[lane_id] = OutputType(value); ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
output_line[lane_id] = OutputType(value_1);
output_line[lane_id + 1] = OutputType(value_2);
} }
output_line += 32; output_line += 64;
} }
} }
#endif #endif
@ -1245,7 +1261,7 @@ __global__ void scale_1x128_reshape_kernel(
size_t scales_along_dim_y = div_up(dim_y, 1); size_t scales_along_dim_y = div_up(dim_y, 1);
size_t scales_along_dim_h = div_up(dim_h, 1); size_t scales_along_dim_h = div_up(dim_h, 1);
size_t stride_scale_dim_y = div_up(dim_y, 4) * 4; size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32; for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
warp_idx < scales_along_dim_x * scales_along_dim_y * scales_along_dim_h; warp_idx < scales_along_dim_x * scales_along_dim_y * scales_along_dim_h;
warp_idx += gridDim.x * blockDim.x / 32) warp_idx += gridDim.x * blockDim.x / 32)
@ -1257,21 +1273,33 @@ __global__ void scale_1x128_reshape_kernel(
InputType const* input_line InputType const* input_line
= input + (size_t) scales_idx_y * stride_x * dim_h + (size_t) scales_idx_h * stride_x + scales_idx_x * 128; = input + (size_t) scales_idx_y * stride_x * dim_h + (size_t) scales_idx_h * stride_x + scales_idx_x * 128;
InputType input_amax = InputType(0); InputType input_amax = InputType(0);
int lane_id = threadIdx.x % 32; int lane_id = threadIdx.x % 32 * 2;
InputType input_frag[4] = {0};
for (int i = 0; i < 4; i++) Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
#pragma unroll
for (int i = 0; i < 2; i++)
{ {
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x) if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{ {
break; break;
} }
else else
{ {
input_frag[i] = input_line[lane_id]; input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i])))); }
input_line += 64;
}
#pragma unroll
for (int i = 0; i < 2; i++)
{
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{
break;
}
else
{
input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
} }
input_line += 32;
} }
InputType amax = find_max_elem_in_warp(input_amax); InputType amax = find_max_elem_in_warp(input_amax);
@ -1286,18 +1314,21 @@ __global__ void scale_1x128_reshape_kernel(
OutputType* output_line OutputType* output_line
= output + (size_t) scales_idx_h * dim_y * dim_x + (size_t) scales_idx_y * dim_x + scales_idx_x * 128; = output + (size_t) scales_idx_h * dim_y * dim_x + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
for (int i = 0; i < 4; i++) #pragma unroll
for (int i = 0; i < 2; i++)
{ {
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x) if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
{ {
break; break;
} }
else else
{ {
ScaleType value = ScaleType(input_frag[i]) * scale; ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
output_line[lane_id] = OutputType(value); ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
output_line[lane_id] = OutputType(value_1);
output_line[lane_id + 1] = OutputType(value_2);
} }
output_line += 32; output_line += 64;
} }
} }
#endif #endif
@ -1425,7 +1456,7 @@ void fp8_1x128_cs(
{ {
kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
} }
scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(mat_quant, scales, mat, shape_x, shape_y); scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(mat_quant, scales, mat, shape_x, shape_y);
} }
void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, int shape_h, void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, int shape_h,
@ -1435,7 +1466,7 @@ void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16
{ {
kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
} }
scale_1x128_reshape_kernel<<<kNumDeviceSMs, 256, 0, stream>>>( scale_1x128_reshape_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(
mat_quant, scales, mat, shape_x, shape_h, shape_y, stride_x); mat_quant, scales, mat, shape_x, shape_h, shape_y, stride_x);
} }
@ -1642,7 +1673,7 @@ void fp8_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, int ld_a
if (internal_quantize_a) if (internal_quantize_a)
{ {
scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m); scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m);
} }
if (internal_quantize_b) if (internal_quantize_b)
{ {
@ -1792,7 +1823,7 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
} }
if (internal_quantize_a) if (internal_quantize_a)
{ {
scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>( scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(
fp8_mat_a, scales_a, mat_a, shape_k, shape_m * num_problems); fp8_mat_a, scales_a, mat_a, shape_k, shape_m * num_problems);
} }
if (internal_quantize_b) if (internal_quantize_b)