mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
perf: Optimize quantization kernels used in DeepSeek on Hopper (#3466)
Signed-off-by: jiahanc <jiahanc@nvidia.com>
This commit is contained in:
parent
5cfa927132
commit
1d3b98b920
@ -979,7 +979,7 @@ __global__ void scale_1x128_kernel(
|
||||
size_t scales_along_dim_x = div_up(dim_x, 128);
|
||||
size_t scales_along_dim_y = div_up(dim_y, 1);
|
||||
size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
|
||||
|
||||
using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
|
||||
for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
|
||||
warp_idx < scales_along_dim_x * scales_along_dim_y; warp_idx += gridDim.x * blockDim.x / 32)
|
||||
{
|
||||
@ -988,21 +988,34 @@ __global__ void scale_1x128_kernel(
|
||||
|
||||
InputType const* input_line = input + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
|
||||
InputType input_amax = InputType(0);
|
||||
int lane_id = threadIdx.x % 32;
|
||||
InputType input_frag[4] = {0};
|
||||
// Each thread reads 2 elements from input_line
|
||||
int lane_id = threadIdx.x % 32 * 2;
|
||||
|
||||
for (int i = 0; i < 4; i++)
|
||||
Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
|
||||
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
input_frag[i] = input_line[lane_id];
|
||||
input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i]))));
|
||||
input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
|
||||
}
|
||||
input_line += 64;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
|
||||
}
|
||||
input_line += 32;
|
||||
}
|
||||
|
||||
InputType amax = find_max_elem_in_warp(input_amax);
|
||||
@ -1014,18 +1027,21 @@ __global__ void scale_1x128_kernel(
|
||||
}
|
||||
|
||||
OutputType* output_line = output + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
|
||||
for (int i = 0; i < 4; i++)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
|
||||
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
ScaleType value = ScaleType(input_frag[i]) * scale;
|
||||
output_line[lane_id] = OutputType(value);
|
||||
ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
|
||||
ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
|
||||
output_line[lane_id] = OutputType(value_1);
|
||||
output_line[lane_id + 1] = OutputType(value_2);
|
||||
}
|
||||
output_line += 32;
|
||||
output_line += 64;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@ -1245,7 +1261,7 @@ __global__ void scale_1x128_reshape_kernel(
|
||||
size_t scales_along_dim_y = div_up(dim_y, 1);
|
||||
size_t scales_along_dim_h = div_up(dim_h, 1);
|
||||
size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
|
||||
|
||||
using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
|
||||
for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
|
||||
warp_idx < scales_along_dim_x * scales_along_dim_y * scales_along_dim_h;
|
||||
warp_idx += gridDim.x * blockDim.x / 32)
|
||||
@ -1257,21 +1273,33 @@ __global__ void scale_1x128_reshape_kernel(
|
||||
InputType const* input_line
|
||||
= input + (size_t) scales_idx_y * stride_x * dim_h + (size_t) scales_idx_h * stride_x + scales_idx_x * 128;
|
||||
InputType input_amax = InputType(0);
|
||||
int lane_id = threadIdx.x % 32;
|
||||
InputType input_frag[4] = {0};
|
||||
int lane_id = threadIdx.x % 32 * 2;
|
||||
|
||||
for (int i = 0; i < 4; i++)
|
||||
Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
|
||||
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
input_frag[i] = input_line[lane_id];
|
||||
input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i]))));
|
||||
input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
|
||||
}
|
||||
input_line += 64;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
|
||||
}
|
||||
input_line += 32;
|
||||
}
|
||||
|
||||
InputType amax = find_max_elem_in_warp(input_amax);
|
||||
@ -1286,18 +1314,21 @@ __global__ void scale_1x128_reshape_kernel(
|
||||
|
||||
OutputType* output_line
|
||||
= output + (size_t) scales_idx_h * dim_y * dim_x + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
|
||||
for (int i = 0; i < 4; i++)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
|
||||
if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
ScaleType value = ScaleType(input_frag[i]) * scale;
|
||||
output_line[lane_id] = OutputType(value);
|
||||
ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
|
||||
ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
|
||||
output_line[lane_id] = OutputType(value_1);
|
||||
output_line[lane_id + 1] = OutputType(value_2);
|
||||
}
|
||||
output_line += 32;
|
||||
output_line += 64;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@ -1425,7 +1456,7 @@ void fp8_1x128_cs(
|
||||
{
|
||||
kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
|
||||
}
|
||||
scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(mat_quant, scales, mat, shape_x, shape_y);
|
||||
scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(mat_quant, scales, mat, shape_x, shape_y);
|
||||
}
|
||||
|
||||
void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, int shape_h,
|
||||
@ -1435,7 +1466,7 @@ void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16
|
||||
{
|
||||
kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
|
||||
}
|
||||
scale_1x128_reshape_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(
|
||||
scale_1x128_reshape_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(
|
||||
mat_quant, scales, mat, shape_x, shape_h, shape_y, stride_x);
|
||||
}
|
||||
|
||||
@ -1642,7 +1673,7 @@ void fp8_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, int ld_a
|
||||
|
||||
if (internal_quantize_a)
|
||||
{
|
||||
scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m);
|
||||
scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m);
|
||||
}
|
||||
if (internal_quantize_b)
|
||||
{
|
||||
@ -1792,7 +1823,7 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
|
||||
}
|
||||
if (internal_quantize_a)
|
||||
{
|
||||
scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(
|
||||
scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(
|
||||
fp8_mat_a, scales_a, mat_a, shape_k, shape_m * num_problems);
|
||||
}
|
||||
if (internal_quantize_b)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user