perf: Optimize quantization kernels used in DeepSeek on Hopper (#3466)
Signed-off-by: jiahanc <jiahanc@nvidia.com>
parent 5cfa927132
commit 1d3b98b920
@@ -979,7 +979,7 @@ __global__ void scale_1x128_kernel(
     size_t scales_along_dim_x = div_up(dim_x, 128);
     size_t scales_along_dim_y = div_up(dim_y, 1);
     size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
-
+    using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
     for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
          warp_idx < scales_along_dim_x * scales_along_dim_y; warp_idx += gridDim.x * blockDim.x / 32)
     {
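The heart of the change is the new Input2Type alias: it maps each 16-bit input type to its packed two-lane counterpart (half to half2, __nv_bfloat16 to __nv_bfloat162), so the loops below can move two elements per thread with a single 32-bit load. A minimal sketch of the selection trick; only the alias itself comes from the patch, the demo kernel and its names are illustrative:

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>

template <typename InputType>
__global__ void paired_copy_demo(InputType const* in, InputType* out)
{
    // Pick the packed pair type that matches the scalar input type.
    using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
    // One 32-bit transaction moves two 16-bit elements.
    Input2Type v = ((Input2Type const*) in)[threadIdx.x];
    ((Input2Type*) out)[threadIdx.x] = v;
}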
@@ -988,21 +988,34 @@ __global__ void scale_1x128_kernel(
 
         InputType const* input_line = input + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
         InputType input_amax = InputType(0);
-        int lane_id = threadIdx.x % 32;
-        InputType input_frag[4] = {0};
+        // Each thread reads 2 elements from input_line
+        int lane_id = threadIdx.x % 32 * 2;
 
-        for (int i = 0; i < 4; i++)
+        Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
+#pragma unroll
+        for (int i = 0; i < 2; i++)
         {
-            if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
             {
                 break;
             }
             else
             {
-                input_frag[i] = input_line[lane_id];
-                input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i]))));
+                input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
             }
-            input_line += 32;
+            input_line += 64;
+        }
+#pragma unroll
+        for (int i = 0; i < 2; i++)
+        {
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
+            {
+                break;
+            }
+            else
+            {
+                input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
+            }
         }
 
         InputType amax = find_max_elem_in_warp(input_amax);
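Rewritten this way, a warp covers one 1x128 block with two packed loads per thread, and the absolute maximum is reduced in 16-bit arithmetic with __habs/__hmax instead of round-tripping through float. A self-contained sketch of the same amax pattern for half input, targeting Hopper-class architectures; warp_amax stands in for find_max_elem_in_warp, whose actual definition lives elsewhere in the file:

#include <cuda_fp16.h>

// Butterfly shuffle reduction over the 32 lanes of a warp.
__device__ half warp_amax(half v)
{
    for (int offset = 16; offset > 0; offset /= 2)
        v = __hmax(v, __shfl_xor_sync(0xffffffff, v, offset));
    return v;
}

// Per-warp absolute maximum of one 128-element line. Assumes a full line for
// brevity; the real kernel also bounds-checks a ragged tail against dim_x.
__device__ half line_amax_128(half const* line, int lane /* 0..31 */)
{
    half amax = __float2half(0.0f);
#pragma unroll
    for (int i = 0; i < 2; i++)
    {
        half2 v = ((half2 const*) (line + i * 64))[lane]; // two elements per load
        amax = __hmax(amax, __hmax(__habs(v.x), __habs(v.y)));
    }
    return warp_amax(amax);
}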
@@ -1014,18 +1027,21 @@ __global__ void scale_1x128_kernel(
         }
 
         OutputType* output_line = output + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
-        for (int i = 0; i < 4; i++)
+#pragma unroll
+        for (int i = 0; i < 2; i++)
         {
-            if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
             {
                 break;
             }
             else
             {
-                ScaleType value = ScaleType(input_frag[i]) * scale;
-                output_line[lane_id] = OutputType(value);
+                ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
+                ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
+                output_line[lane_id] = OutputType(value_1);
+                output_line[lane_id + 1] = OutputType(value_2);
             }
-            output_line += 32;
+            output_line += 64;
         }
     }
 #endif
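The store loop mirrors the load loop: the two cached elements are scaled and written as two adjacent FP8 bytes, so consecutive lanes still cover a contiguous span and the stores coalesce across the warp. A hypothetical condensed version of the per-pair quantize step; the scale derivation (for example 448.f / amax for E4M3) happens outside this hunk and is an assumption here:

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Scale one half2 fragment and narrow both halves to FP8 E4M3.
__device__ void quantize_pair(half2 v, float scale, __nv_fp8_e4m3* dst, int lane_id)
{
    float value_1 = __half2float(v.x) * scale;
    float value_2 = __half2float(v.y) * scale;
    dst[lane_id] = __nv_fp8_e4m3(value_1);
    dst[lane_id + 1] = __nv_fp8_e4m3(value_2);
}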
@@ -1245,7 +1261,7 @@ __global__ void scale_1x128_reshape_kernel(
     size_t scales_along_dim_y = div_up(dim_y, 1);
     size_t scales_along_dim_h = div_up(dim_h, 1);
     size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
-
+    using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
     for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
          warp_idx < scales_along_dim_x * scales_along_dim_y * scales_along_dim_h;
          warp_idx += gridDim.x * blockDim.x / 32)
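Both kernels distribute work with the same warp-granular grid-stride loop: each warp owns one 1x128 scale tile per iteration, then jumps ahead by the total number of warps in the grid. That is why the 8x larger grid later in the diff is a pure tuning knob rather than a correctness change. The bare pattern, as a sketch:

__global__ void warp_stride_demo(size_t num_tiles)
{
    size_t warps_per_grid = gridDim.x * blockDim.x / 32;
    for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
         warp_idx < num_tiles; warp_idx += warps_per_grid)
    {
        // All 32 lanes of this warp cooperate on tile `warp_idx`.
    }
}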
@@ -1257,21 +1273,33 @@ __global__ void scale_1x128_reshape_kernel(
         InputType const* input_line
             = input + (size_t) scales_idx_y * stride_x * dim_h + (size_t) scales_idx_h * stride_x + scales_idx_x * 128;
         InputType input_amax = InputType(0);
-        int lane_id = threadIdx.x % 32;
-        InputType input_frag[4] = {0};
+        int lane_id = threadIdx.x % 32 * 2;
 
-        for (int i = 0; i < 4; i++)
+        Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
+#pragma unroll
+        for (int i = 0; i < 2; i++)
         {
-            if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
             {
                 break;
             }
             else
             {
-                input_frag[i] = input_line[lane_id];
-                input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i]))));
+                input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
             }
-            input_line += 32;
+            input_line += 64;
+        }
+#pragma unroll
+        for (int i = 0; i < 2; i++)
+        {
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
+            {
+                break;
+            }
+            else
+            {
+                input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
+            }
         }
 
         InputType amax = find_max_elem_in_warp(input_amax);
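One caveat worth noting for the reinterpret casts above: reading through Input2Type* requires each 128-element line to start on a 4-byte boundary, which holds when the base pointer is at least 4-byte aligned and the row stride (in elements) is even, as it is for the 128-blocked shapes these kernels target. A host-side guard one could add (hypothetical, not in the patch):

#include <cstddef>
#include <cstdint>

// True when rows of 16-bit elements spaced `row_stride` apart from `base`
// can be read as packed 2-element (32-bit) values.
inline bool rows_are_pair_aligned(void const* base, size_t row_stride)
{
    return reinterpret_cast<std::uintptr_t>(base) % 4 == 0 && row_stride % 2 == 0;
}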
@@ -1286,18 +1314,21 @@ __global__ void scale_1x128_reshape_kernel(
 
         OutputType* output_line
             = output + (size_t) scales_idx_h * dim_y * dim_x + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
-        for (int i = 0; i < 4; i++)
+#pragma unroll
+        for (int i = 0; i < 2; i++)
         {
-            if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
             {
                 break;
             }
             else
             {
-                ScaleType value = ScaleType(input_frag[i]) * scale;
-                output_line[lane_id] = OutputType(value);
+                ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
+                ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
+                output_line[lane_id] = OutputType(value_1);
+                output_line[lane_id + 1] = OutputType(value_2);
             }
-            output_line += 32;
+            output_line += 64;
         }
     }
 #endif
@@ -1425,7 +1456,7 @@ void fp8_1x128_cs(
     {
         kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
     }
-    scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(mat_quant, scales, mat, shape_x, shape_y);
+    scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(mat_quant, scales, mat, shape_x, shape_y);
 }
 
 void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, int shape_h,
@@ -1435,7 +1466,7 @@ void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16
     {
         kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
    }
-    scale_1x128_reshape_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(
+    scale_1x128_reshape_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(
         mat_quant, scales, mat, shape_x, shape_h, shape_y, stride_x);
 }
 
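The second half of the patch is launch sizing: every call site goes from one block per SM to eight. With 256 threads per block that is 64 resident warps per Hopper SM instead of 8, which these memory-bound kernels need to cover load latency, and the grid-stride loop absorbs the extra warps without any indexing change. For reference, a sketch of what tensorrt_llm::common::getMultiProcessorCount presumably wraps; its actual implementation is outside this diff:

#include <cuda_runtime.h>

static int multiProcessorCount()
{
    int device = 0;
    cudaGetDevice(&device);
    int sm_count = 0;
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
    return sm_count;
}

// Launch sizing after the patch: 8 blocks of 256 threads per SM, e.g.
// scale_1x128_kernel<<<multiProcessorCount() * 8, 256, 0, stream>>>(...);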
@@ -1642,7 +1673,7 @@ void fp8_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, int ld_a
 
     if (internal_quantize_a)
     {
-        scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m);
+        scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m);
     }
     if (internal_quantize_b)
     {
@@ -1792,7 +1823,7 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
     }
     if (internal_quantize_a)
     {
-        scale_1x128_kernel<<<kNumDeviceSMs, 256, 0, stream>>>(
+        scale_1x128_kernel<<<kNumDeviceSMs * 8, 256, 0, stream>>>(
             fp8_mat_a, scales_a, mat_a, shape_k, shape_m * num_problems);
     }
     if (internal_quantize_b)