diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh
index 4747fb25c2..9234883db0 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh
@@ -979,7 +979,7 @@ __global__ void scale_1x128_kernel(
     size_t scales_along_dim_x = div_up(dim_x, 128);
     size_t scales_along_dim_y = div_up(dim_y, 1);
     size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
-
+    using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
     for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
          warp_idx < scales_along_dim_x * scales_along_dim_y; warp_idx += gridDim.x * blockDim.x / 32)
     {
@@ -988,21 +988,34 @@ __global__ void scale_1x128_kernel(
 
         InputType const* input_line = input + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
         InputType input_amax = InputType(0);
-        int lane_id = threadIdx.x % 32;
-        InputType input_frag[4] = {0};
+        // Each thread reads 2 elements from input_line
+        int lane_id = threadIdx.x % 32 * 2;
 
-        for (int i = 0; i < 4; i++)
+        Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
+#pragma unroll
+        for (int i = 0; i < 2; i++)
         {
-            if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
             {
                 break;
             }
             else
             {
-                input_frag[i] = input_line[lane_id];
-                input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i]))));
+                input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
+            }
+            input_line += 64;
+        }
+#pragma unroll
+        for (int i = 0; i < 2; i++)
+        {
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
+            {
+                break;
+            }
+            else
+            {
+                input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
             }
-            input_line += 32;
         }
 
         InputType amax = find_max_elem_in_warp(input_amax);
@@ -1014,18 +1027,21 @@ __global__ void scale_1x128_kernel(
         }
 
         OutputType* output_line = output + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
-        for (int i = 0; i < 4; i++)
+#pragma unroll
+        for (int i = 0; i < 2; i++)
         {
-            if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
             {
                 break;
             }
             else
             {
-                ScaleType value = ScaleType(input_frag[i]) * scale;
-                output_line[lane_id] = OutputType(value);
+                ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
+                ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
+                output_line[lane_id] = OutputType(value_1);
+                output_line[lane_id + 1] = OutputType(value_2);
             }
-            output_line += 32;
+            output_line += 64;
         }
     }
 #endif
@@ -1245,7 +1261,7 @@ __global__ void scale_1x128_reshape_kernel(
     size_t scales_along_dim_y = div_up(dim_y, 1);
     size_t scales_along_dim_h = div_up(dim_h, 1);
     size_t stride_scale_dim_y = div_up(dim_y, 4) * 4;
-
+    using Input2Type = typename std::conditional<std::is_same<InputType, half>::value, half2, __nv_bfloat162>::type;
     for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
          warp_idx < scales_along_dim_x * scales_along_dim_y * scales_along_dim_h; warp_idx += gridDim.x * blockDim.x / 32)
     {
@@ -1257,21 +1273,33 @@ __global__ void scale_1x128_reshape_kernel(
         InputType const* input_line
             = input + (size_t) scales_idx_y * stride_x * dim_h + (size_t) scales_idx_h * stride_x + scales_idx_x * 128;
         InputType input_amax = InputType(0);
-        int lane_id = threadIdx.x % 32;
-        InputType input_frag[4] = {0};
+        int lane_id = threadIdx.x % 32 * 2;
 
-        for (int i = 0; i < 4; i++)
+        Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)};
+#pragma unroll
+        for (int i = 0; i < 2; i++)
        {
-            if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
             {
                 break;
             }
             else
             {
-                input_frag[i] = input_line[lane_id];
-                input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_frag[i]))));
+                input_frag2[i] = *((Input2Type*) (input_line) + lane_id / 2);
+            }
+            input_line += 64;
+        }
+#pragma unroll
+        for (int i = 0; i < 2; i++)
+        {
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
+            {
+                break;
+            }
+            else
+            {
+                input_amax = InputType(__hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y))));
             }
-            input_line += 32;
         }
 
         InputType amax = find_max_elem_in_warp(input_amax);
@@ -1286,18 +1314,21 @@ __global__ void scale_1x128_reshape_kernel(
 
         OutputType* output_line
             = output + (size_t) scales_idx_h * dim_y * dim_x + (size_t) scales_idx_y * dim_x + scales_idx_x * 128;
-        for (int i = 0; i < 4; i++)
+#pragma unroll
+        for (int i = 0; i < 2; i++)
         {
-            if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x)
+            if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x)
             {
                 break;
             }
             else
             {
-                ScaleType value = ScaleType(input_frag[i]) * scale;
-                output_line[lane_id] = OutputType(value);
+                ScaleType value_1 = ScaleType(input_frag2[i].x) * scale;
+                ScaleType value_2 = ScaleType(input_frag2[i].y) * scale;
+                output_line[lane_id] = OutputType(value_1);
+                output_line[lane_id + 1] = OutputType(value_2);
             }
-            output_line += 32;
+            output_line += 64;
         }
     }
 #endif
@@ -1425,7 +1456,7 @@ void fp8_1x128_cs(
     {
         kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
     }
-    scale_1x128_kernel<<>>(mat_quant, scales, mat, shape_x, shape_y);
+    scale_1x128_kernel<<>>(mat_quant, scales, mat, shape_x, shape_y);
 }
 
 void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, int shape_h,
@@ -1435,7 +1466,7 @@ void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16
     {
         kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
     }
-    scale_1x128_reshape_kernel<<>>(
+    scale_1x128_reshape_kernel<<>>(
         mat_quant, scales, mat, shape_x, shape_h, shape_y, stride_x);
 }
 
@@ -1642,7 +1673,7 @@ void fp8_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, int ld_a
 
     if (internal_quantize_a)
     {
-        scale_1x128_kernel<<>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m);
+        scale_1x128_kernel<<>>(fp8_mat_a, scales_a, mat_a, shape_k, shape_m);
     }
     if (internal_quantize_b)
     {
@@ -1792,7 +1823,7 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
     }
     if (internal_quantize_a)
     {
-        scale_1x128_kernel<<>>(
+        scale_1x128_kernel<<>>(
            fp8_mat_a, scales_a, mat_a, shape_k, shape_m * num_problems);
     }
     if (internal_quantize_b)
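
For reference, a minimal standalone sketch (not part of the patch) of the access pattern the diff switches to: each lane of a warp loads adjacent element pairs of a 128-wide block as a packed half2 and reduces the absolute maximum, instead of four scalar loads per lane. The names (warp_segment_absmax, row, seg_start) are illustrative only, and the half2 loads assume the row segment is at least 4-byte aligned.

#include <cuda_fp16.h>

// Sketch only: per-warp absolute maximum over a 128-element segment using
// paired (half2) loads, mirroring the first loop of the patched kernels.
__device__ float warp_segment_absmax(__half const* row, int dim_x, int seg_start)
{
    int lane_id = threadIdx.x % 32 * 2; // each lane owns two adjacent elements per iteration
    float amax = 0.f;
#pragma unroll
    for (int i = 0; i < 2; i++) // 2 iterations * 32 lanes * 2 elements = 128 columns
    {
        int col = seg_start + i * 64 + lane_id;
        if (col >= dim_x)
        {
            break;
        }
        __half2 v = *((__half2 const*) (row + col)); // vectorized load, assumes 4-byte alignment
        amax = fmaxf(amax, fmaxf(fabsf(__half2float(v.x)), fabsf(__half2float(v.y))));
    }
    // butterfly reduction across the warp (stand-in for find_max_elem_in_warp)
    for (int offset = 16; offset > 0; offset /= 2)
    {
        amax = fmaxf(amax, __shfl_xor_sync(0xffffffffu, amax, offset));
    }
    return amax;
}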