From fae4985797b1b4bdb7683d281c19b6ff56f414f9 Mon Sep 17 00:00:00 2001 From: ZhichenJiang Date: Tue, 27 Jan 2026 16:15:32 +0800 Subject: [PATCH] [TRTLLM-9831][perf] Use TMA.RED to improve effective memory bandwidth (#10987) Signed-off-by: zhichen jiang --- .../_torch/custom_ops/cute_dsl_custom_ops.py | 1 + ...contiguous_grouped_gemm_finalize_fusion.py | 204 ++++++++++++++---- .../cute_dsl_kernels/blackwell/utils.py | 51 +++++ 3 files changed, 209 insertions(+), 47 deletions(-) diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index ff37aa7d95..a75b9aeddf 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -1305,6 +1305,7 @@ if IS_CUTLASS_DSL_AVAILABLE: sf_vec_size=self.scaling_vector_size, mma_tiler_mn=mma_tiler_mn, cluster_shape_mn=cluster_shape_mn, + use_blkred=True, raster_along_m=raster_along_m, ) # Compute max active clusters on current device diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index bc2856acbb..50d36beff8 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -40,6 +40,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05 from .utils import ( TRTLLM_ENABLE_PDL, atomic_add_func, + blk_reduce_bf16, + blk_reduce_fp16, + blk_reduce_fp32, griddepcontrol_launch_dependents, griddepcontrol_wait, is_power_of_2, @@ -341,6 +344,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: sf_vec_size: int, mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], + use_blkred: bool = False, raster_along_m: bool = False, ): """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel. 
@@ -371,6 +375,9 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + # Block reduce configuration + self.use_blkred = use_blkred + self.occupancy = 1 self.epilog_warp_id = (0, 1, 2, 3) self.mma_warp_id = 4 @@ -528,12 +535,12 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: self.a_dtype, self.b_dtype, self.out_dtype, - self.gemm_output_layout, - self.epi_tile, + self.cta_tile_shape_mnk, self.sf_dtype, self.sf_vec_size, self.num_smem_capacity, self.occupancy, + self.use_blkred, ) # Compute A/B/C/Scale shared memory layout @@ -562,12 +569,16 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: self.num_ab_stage, ) - self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( - self.out_dtype, - self.gemm_output_layout, - self.epi_tile, - self.num_c_stage, + swizzled_pad = 16 // (self.out_dtype.width // 8) + self.c_smem_layout_staged = cute.make_layout( + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], self.num_c_stage), + stride=( + self.cta_tile_shape_mnk[1] + swizzled_pad, + 1, + self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + 8), + ), ) + # Overlap and double buffer accumulator when num_acc_stage == 1 for cta_tile_n = 256 case self.overlapping_accum = self.num_acc_stage == 1 @@ -622,8 +633,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: :type sfa: cute.Tensor :param sfb: Scale factor tensor B :type sfb: cute.Tensor - :param tile_idx_to_expert_idx: Mapping from tile index to expert ID, shape (permuted_m/cta_tile_m,) where - cta_tile_m is the CTA tile M size + :param tile_idx_to_expert_idx: Mapping from tile index to expert ID, + shape (permuted_m/cta_tile_m,) where cta_tile_m is the CTA tile M size :type tile_idx_to_expert_idx: cute.Tensor :param num_non_exiting_tiles: Number of valid tiles (valid_m/cta_tile_m), shape (1,) :type num_non_exiting_tiles: cute.Tensor @@ -781,6 +792,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: epi_tile_size = epi_tile_m * epi_tile_n num_epilogue_threads = 32 * len(self.epilog_warp_id) self.ttr_racc_size = epi_tile_size // num_epilogue_threads + self.copy_size = self.cta_tile_shape_mnk[1] * (self.out_dtype.width // 8) if cutlass.const_expr(self.out_dtype == cutlass.BFloat16): # 8-element vectorization for BF16 @@ -804,7 +816,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: # Define shared storage for kernel @cute.struct class SharedStorage: - # (bidx, bidy, bidz, valid) + # (bidx, bidy, bidz, valid, mn_limit) sInfo: cute.struct.Align[ cute.struct.MemRange[cutlass.Int32, 5 * self.num_tile_stage], # 1 byte alignment @@ -836,6 +848,12 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: self.buffer_align_bytes, ] + if cutlass.const_expr(self.use_blkred): + sC: cute.struct.Align[ + cute.struct.MemRange[self.out_dtype, cute.cosize(self.c_smem_layout_staged)], + self.buffer_align_bytes, + ] + self.shared_storage = SharedStorage # Launch the kernel synchronously @@ -863,8 +881,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: self.b_smem_layout_staged, self.sfa_smem_layout_staged, self.sfb_smem_layout_staged, + self.c_smem_layout_staged, self.epi_tile, self.epi_layout, + self.topK, self.tile_sched_params, epilogue_op, ).launch( @@ -946,8 +966,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: b_smem_layout_staged: cute.ComposedLayout, sfa_smem_layout_staged: cute.Layout, 
sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: cute.Layout, epi_tile: cute.Tile, epi_layout: cute.Layout, + topK: cutlass.Int32, tile_sched_params: utils.PersistentTileSchedulerParams, epilogue_op: cutlass.Constexpr, ): @@ -1060,6 +1082,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + + if cutlass.const_expr(self.use_blkred): + sC = storage.sC.get_tensor(c_smem_layout_staged) + # (bidx, bidy, bidz, valid) info_layout = cute.make_layout((5, self.num_tile_stage), stride=(1, 5)) sInfo = storage.sInfo.get_tensor(info_layout) @@ -1731,6 +1757,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: ) tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.out_dtype) + if cutlass.const_expr(self.use_blkred): + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + epi_tidx, tTR_rC, sC, tiled_copy_t2r + ) acc_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_acc_stage @@ -1801,17 +1831,11 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) if is_valid_row: - topK = token_final_scales.shape[1] token_idx = expanded_idx // topK topk_idx = expanded_idx % topK token_scale = token_final_scales[(token_idx, topk_idx)] alpha_val = alpha_val * token_scale - scatter_out = cute.domain_offset( - (token_idx, 0, 0), - out, # Use original tensor to get real pointer - ) - for subtile_idx in cutlass.range(subtile_cnt): real_subtile_idx = subtile_idx if cutlass.const_expr(self.overlapping_accum): @@ -1839,27 +1863,52 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: acc_vec = tTR_rAcc.load() acc_vec_final = alpha_val * acc_vec - tTR_rC.store(acc_vec_final.to(self.out_dtype)) + if cutlass.const_expr(self.use_blkred): + tRS_rC.store(acc_vec_final.to(self.out_dtype)) + if is_valid_row: + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, real_subtile_idx, None)], + ) + else: + tTR_rC.store(acc_vec_final.to(self.out_dtype)) + if is_valid_row: + rOut_epi = cute.make_tensor(tTR_rC.iterator, epi_layout) - if is_valid_row: - rOut_epi = cute.make_tensor(tTR_rC.iterator, epi_layout) + base_coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[ + 1 + ] + real_subtile_idx * cute.size(tTR_rC) - base_coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[ - 1 - ] + real_subtile_idx * cute.size(tTR_rC) + scatter_out = cute.domain_offset( + (token_idx, 0, 0), + out, # Use original tensor to get real pointer + ) - for index in cutlass.range(self.epi_loop_size, unroll_full=True): - coord_n = base_coord_n + index * self.element_offset - scatter_out_offset = cute.domain_offset((0, coord_n, 0), scatter_out) - if cutlass.const_expr(self.out_dtype == cutlass.BFloat16): - rOut_epi_packed = rOut_epi[index, None, None] - vectorized_atomic_add_bf16x8(rOut_epi_packed, scatter_out_offset) - elif cutlass.const_expr(self.out_dtype == cutlass.Float32): - rOut_epi_packed = rOut_epi[index, None] - vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset) - else: - rOut_epi_packed = rOut_epi[index] - atomic_add_func(rOut_epi_packed, scatter_out_offset) + for index in cutlass.range(self.epi_loop_size, unroll_full=True): + coord_n = base_coord_n + index * self.element_offset + scatter_out_offset = cute.domain_offset( + (0, coord_n, 0), scatter_out + ) + if 
cutlass.const_expr(self.out_dtype == cutlass.BFloat16): + rOut_epi_packed = rOut_epi[index, None, None] + vectorized_atomic_add_bf16x8( + rOut_epi_packed, scatter_out_offset + ) + elif cutlass.const_expr(self.out_dtype == cutlass.Float32): + rOut_epi_packed = rOut_epi[index, None] + vectorized_atomic_add_fp32x2( + rOut_epi_packed, scatter_out_offset + ) + else: + rOut_epi_packed = rOut_epi[index] + atomic_add_func(rOut_epi_packed, scatter_out_offset) + + if cutlass.const_expr(self.use_blkred): + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) # # Async arrive accumulator buffer empty # @@ -1869,7 +1918,34 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: acc_pipeline.consumer_release(acc_consumer_state) acc_consumer_state.advance() - # + if cutlass.const_expr(self.use_blkred): + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + if is_valid_row: + coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[1] + scatter_out_offset = cute.domain_offset((token_idx, coord_n, 0), out) + if cutlass.const_expr(self.out_dtype == cutlass.BFloat16): + blk_reduce_bf16( + scatter_out_offset, + sC[epi_tidx, None, 0], + cutlass.Int32(self.copy_size), + ) + elif cutlass.const_expr(self.out_dtype == cutlass.Float32): + blk_reduce_fp32( + scatter_out_offset, + sC[epi_tidx, None, 0], + cutlass.Int32(self.copy_size), + ) + elif cutlass.const_expr(self.out_dtype == cutlass.Float16): + blk_reduce_fp16( + scatter_out_offset, + sC[epi_tidx, None, 0], + cutlass.Int32(self.copy_size), + ) + self.epilog_sync_barrier.arrive_and_wait() + # Advance to next tile # tile_info_pipeline.consumer_wait(tile_info_consumer_state) @@ -1954,6 +2030,28 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + def epilog_smem_copy_and_partition( + self, + tidx: cutlass.Int32, + tTR_rC: cute.Tensor, + sC: cute.Tensor, + tiled_copy_t2r: cute.TiledCopy, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Create tiled copy for register to shared memory (R2S). + """ + atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.out_dtype, + ) + + tiled_copy_r2s = cute.make_tiled_copy_D(atom, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + @staticmethod def _compute_stages( tiled_mma: cute.TiledMma, @@ -1961,12 +2059,12 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: a_dtype: Type[cutlass.Numeric], b_dtype: Type[cutlass.Numeric], out_dtype: Type[cutlass.Numeric], - gemm_output_layout: utils.LayoutEnum, - epi_tile: cute.Tile, + cta_tile: cute.Tile, sf_dtype: Type[cutlass.Numeric], sf_vec_size: int, num_smem_capacity: int, occupancy: int, + use_blkred: bool, ) -> Tuple[int, int, int]: """Computes the number of stages for A/B/C operands based on heuristics. @@ -1978,12 +2076,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: :type a_dtype: type[cutlass.Numeric] :param b_dtype: Data type of operand B. :type b_dtype: type[cutlass.Numeric] - :param epi_tile: The epilogue tile shape. - :type epi_tile: cute.Tile :param out_dtype: Data type of operand C (output). :type out_dtype: type[cutlass.Numeric] - :param gemm_output_layout: Layout of operand C. - :type gemm_output_layout: utils.LayoutEnum + :param cta_tile: The CTA tile shape. 
+ :type cta_tile: cute.Tile :param sf_dtype: Data type of scale factor. :type sf_dtype: type[cutlass.Numeric] :param sf_vec_size: Vector size of scale factor. @@ -1992,6 +2088,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: :type num_smem_capacity: int :param occupancy: Target number of CTAs per SM (occupancy). :type occupancy: int + :param use_blkred: Whether to use block reduce. + :type use_blkred: bool :return: A tuple containing the computed number of stages for: (ACC stages, A/B operand stages, C stages) @@ -2001,7 +2099,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 # Default C stages - num_c_stage = 2 + num_c_stage = 1 # Default Tile info stages num_tile_stage = 2 @@ -2034,6 +2132,12 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: 1, # a tmp 1 stage is provided ) + # satisfy 16B alignment for the output tensor + swizzled_pad = 16 // (out_dtype.width // 8) + c_smem_layout_staged_one = cute.make_layout( + (cta_tile[0], cta_tile[1]), stride=(cta_tile[1] + swizzled_pad, 1) + ) + ab_bytes_per_stage = ( cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) @@ -2043,16 +2147,22 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: # 1024B alignment for mbar mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(out_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + # Calculate A/B stages: # Start with total smem per CTA (capacity / occupancy) # Subtract reserved bytes and initial C stages bytes # Divide remaining by bytes needed per A/B stage - num_ab_stage = (num_smem_capacity // occupancy - mbar_helpers_bytes) // ab_bytes_per_stage + if cutlass.const_expr(use_blkred): + num_ab_stage = ( + num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + else: + num_ab_stage = ( + num_smem_capacity // occupancy - mbar_helpers_bytes + ) // ab_bytes_per_stage - # Refine epilogue stages: - # Calculate remaining smem after allocating for A/B stages and reserved bytes - # Add remaining unused smem to epilogue - num_c_stage = 2 return num_acc_stage, num_ab_stage, num_c_stage, num_tile_stage @staticmethod diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py index b1b9026b8a..98f9294d1d 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py @@ -306,6 +306,57 @@ def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None): ) +@dsl_user_op +def blk_reduce_bf16(dst_gemm, src_smem, size, loc=None, ip=None): + llvm.inline_asm( + None, + [ + dst_gemm.iterator.llvm_ptr, + src_smem.iterator.llvm_ptr, + size.ir_value(), + ], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.bf16 [$0], [$1], $2;", + "l,l,r", + has_side_effects=True, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def blk_reduce_fp32(dst_gemm, src_smem, size, loc=None, ip=None): + llvm.inline_asm( + None, + [ + dst_gemm.iterator.llvm_ptr, + src_smem.iterator.llvm_ptr, + size.ir_value(), + ], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,l,r", + has_side_effects=True, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def blk_reduce_fp16(dst_gemm, src_smem, size, loc=None, ip=None): + llvm.inline_asm( + None, + [ + dst_gemm.iterator.llvm_ptr, + src_smem.iterator.llvm_ptr, + size.ir_value(), + 
],
+        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16 [$0], [$1], $2;",
+        "l,l,r",
+        has_side_effects=True,
+        loc=loc,
+        ip=ip,
+    )
+
+
 @dsl_user_op
 def griddepcontrol_wait(*, loc=None, ip=None) -> None:
     """