[TRTLLM-9831][perf] Use TMA.RED to improve effective memory bandwidth (#10987)

Signed-off-by: zhichen jiang <zhichenj@NVIDIA.com>
ZhichenJiang 2026-01-27 16:15:32 +08:00 committed by GitHub
parent 6b251cc7fa
commit fae4985797
3 changed files with 209 additions and 47 deletions
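This change targets the finalize epilogue of the block-scaled contiguous grouped GEMM: each CTA adds its alpha- and routing-scale-weighted tile into the shared output tensor. The baseline path does that with per-element or short-vector global atomic adds issued straight from registers; the new use_blkred path first stages the converted tile in shared memory and then flushes each row with a single TMA bulk reduction (cp.reduce.async.bulk ... add), which moves the same data in much wider transactions. A reference-semantics sketch of the two paths (illustrative Python only, not kernel code; both compute out_row += alpha * acc_row):

    def finalize_row_atomics(out_row, acc_row, alpha):
        # baseline: one (vectorized) atomic add per element, issued from registers
        for j in range(len(acc_row)):
            out_row[j] += alpha * acc_row[j]

    def finalize_row_blkred(out_row, acc_row, alpha, smem_row):
        # use_blkred: stage the scaled row in shared memory first (R2S copy) ...
        for j in range(len(acc_row)):
            smem_row[j] = alpha * acc_row[j]
        # ... then a single bulk add-reduce merges the whole row into global memory
        for j in range(len(smem_row)):
            out_row[j] += smem_row[j]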


@ -1305,6 +1305,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
sf_vec_size=self.scaling_vector_size,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
use_blkred=True,
raster_along_m=raster_along_m,
)
# Compute max active clusters on current device


@ -40,6 +40,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
from .utils import (
TRTLLM_ENABLE_PDL,
atomic_add_func,
blk_reduce_bf16,
blk_reduce_fp16,
blk_reduce_fp32,
griddepcontrol_launch_dependents,
griddepcontrol_wait,
is_power_of_2,
@ -341,6 +344,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
sf_vec_size: int,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
use_blkred: bool = False,
raster_along_m: bool = False,
):
"""Initializes the configuration for a Blackwell blockscaled dense GEMM kernel.
@ -371,6 +375,9 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
# Block reduce configuration
self.use_blkred = use_blkred
self.occupancy = 1
self.epilog_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
@ -528,12 +535,12 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
self.a_dtype,
self.b_dtype,
self.out_dtype,
self.cta_tile_shape_mnk,
self.sf_dtype,
self.sf_vec_size,
self.num_smem_capacity,
self.occupancy,
self.use_blkred,
)
# Compute A/B/C/Scale shared memory layout
@ -562,12 +569,16 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
self.num_ab_stage,
)
swizzled_pad = 16 // (self.out_dtype.width // 8)
self.c_smem_layout_staged = cute.make_layout(
(self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], self.num_c_stage),
stride=(
self.cta_tile_shape_mnk[1] + swizzled_pad,
1,
self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + swizzled_pad),
),
)
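The padding added here is exactly 16 bytes per staged row (16 // element_size elements); _compute_stages below builds the same padded single-stage layout and comments it as satisfying 16B alignment for the output tensor. A worked example of the arithmetic, assuming a BF16 output and a 128x256 CTA tile (assumed example values, not taken from this diff):

    out_dtype_bytes = 2                           # BFloat16
    swizzled_pad = 16 // out_dtype_bytes          # 8 elements = 16 bytes of padding per row
    cta_m, cta_n = 128, 256
    row_stride = cta_n + swizzled_pad             # 264 elements per staged row
    stage_bytes = cta_m * row_stride * out_dtype_bytes
    print(stage_bytes)                            # 67584 bytes, roughly 66 KiB of smem per C stage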
# Overlap and double buffer accumulator when num_acc_stage == 1 for cta_tile_n = 256 case
self.overlapping_accum = self.num_acc_stage == 1
@ -622,8 +633,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
:type sfa: cute.Tensor
:param sfb: Scale factor tensor B
:type sfb: cute.Tensor
:param tile_idx_to_expert_idx: Mapping from tile index to expert ID,
shape (permuted_m/cta_tile_m,) where cta_tile_m is the CTA tile M size
:type tile_idx_to_expert_idx: cute.Tensor
:param num_non_exiting_tiles: Number of valid tiles (valid_m/cta_tile_m), shape (1,)
:type num_non_exiting_tiles: cute.Tensor
@ -781,6 +792,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
epi_tile_size = epi_tile_m * epi_tile_n
num_epilogue_threads = 32 * len(self.epilog_warp_id)
self.ttr_racc_size = epi_tile_size // num_epilogue_threads
self.copy_size = self.cta_tile_shape_mnk[1] * (self.out_dtype.width // 8)
if cutlass.const_expr(self.out_dtype == cutlass.BFloat16):
# 8-element vectorization for BF16
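copy_size is the per-row byte count later handed to the blk_reduce_* wrappers as the bulk-transfer size. With the same assumed 256-wide BF16 tile as above (and assuming the usual PTX constraint that bulk copy/reduce sizes are multiples of 16 bytes):

    sizeof_bf16 = 2
    copy_size = 256 * sizeof_bf16    # cta_tile_n * element size = 512 bytes per output row
    assert copy_size % 16 == 0       # bulk transfers move whole 16-byte chunks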
@ -804,7 +816,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
# Define shared storage for kernel
@cute.struct
class SharedStorage:
# (bidx, bidy, bidz, valid, mn_limit)
sInfo: cute.struct.Align[
cute.struct.MemRange[cutlass.Int32, 5 * self.num_tile_stage],
# 1 byte alignment
@ -836,6 +848,12 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
self.buffer_align_bytes,
]
if cutlass.const_expr(self.use_blkred):
sC: cute.struct.Align[
cute.struct.MemRange[self.out_dtype, cute.cosize(self.c_smem_layout_staged)],
self.buffer_align_bytes,
]
self.shared_storage = SharedStorage
# Launch the kernel synchronously
@ -863,8 +881,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
self.b_smem_layout_staged,
self.sfa_smem_layout_staged,
self.sfb_smem_layout_staged,
self.c_smem_layout_staged,
self.epi_tile,
self.epi_layout,
self.topK,
self.tile_sched_params,
epilogue_op,
).launch(
@ -946,8 +966,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
b_smem_layout_staged: cute.ComposedLayout,
sfa_smem_layout_staged: cute.Layout,
sfb_smem_layout_staged: cute.Layout,
c_smem_layout_staged: cute.Layout,
epi_tile: cute.Tile,
epi_layout: cute.Layout,
topK: cutlass.Int32,
tile_sched_params: utils.PersistentTileSchedulerParams,
epilogue_op: cutlass.Constexpr,
):
@ -1060,6 +1082,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
# ((granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage)
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
if cutlass.const_expr(self.use_blkred):
sC = storage.sC.get_tensor(c_smem_layout_staged)
# (bidx, bidy, bidz, valid, mn_limit)
info_layout = cute.make_layout((5, self.num_tile_stage), stride=(1, 5))
sInfo = storage.sInfo.get_tensor(info_layout)
@ -1731,6 +1757,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
)
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.out_dtype)
if cutlass.const_expr(self.use_blkred):
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
epi_tidx, tTR_rC, sC, tiled_copy_t2r
)
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
@ -1801,17 +1831,11 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
if is_valid_row:
token_idx = expanded_idx // topK
topk_idx = expanded_idx % topK
token_scale = token_final_scales[(token_idx, topk_idx)]
alpha_val = alpha_val * token_scale
for subtile_idx in cutlass.range(subtile_cnt):
real_subtile_idx = subtile_idx
if cutlass.const_expr(self.overlapping_accum):
@ -1839,27 +1863,52 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
acc_vec = tTR_rAcc.load()
acc_vec_final = alpha_val * acc_vec
if cutlass.const_expr(self.use_blkred):
tRS_rC.store(acc_vec_final.to(self.out_dtype))
if is_valid_row:
cute.copy(
tiled_copy_r2s,
tRS_rC,
tRS_sC[(None, None, real_subtile_idx, None)],
)
else:
tTR_rC.store(acc_vec_final.to(self.out_dtype))
if is_valid_row:
rOut_epi = cute.make_tensor(tTR_rC.iterator, epi_layout)
base_coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[
1
] + real_subtile_idx * cute.size(tTR_rC)
scatter_out = cute.domain_offset(
(token_idx, 0, 0),
out, # Use original tensor to get real pointer
)
for index in cutlass.range(self.epi_loop_size, unroll_full=True):
coord_n = base_coord_n + index * self.element_offset
scatter_out_offset = cute.domain_offset(
(0, coord_n, 0), scatter_out
)
if cutlass.const_expr(self.out_dtype == cutlass.BFloat16):
rOut_epi_packed = rOut_epi[index, None, None]
vectorized_atomic_add_bf16x8(
rOut_epi_packed, scatter_out_offset
)
elif cutlass.const_expr(self.out_dtype == cutlass.Float32):
rOut_epi_packed = rOut_epi[index, None]
vectorized_atomic_add_fp32x2(
rOut_epi_packed, scatter_out_offset
)
else:
rOut_epi_packed = rOut_epi[index]
atomic_add_func(rOut_epi_packed, scatter_out_offset)
if cutlass.const_expr(self.use_blkred):
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
#
# Async arrive accumulator buffer empty
#
@ -1869,7 +1918,34 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
#
if cutlass.const_expr(self.use_blkred):
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
if is_valid_row:
coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[1]
scatter_out_offset = cute.domain_offset((token_idx, coord_n, 0), out)
if cutlass.const_expr(self.out_dtype == cutlass.BFloat16):
blk_reduce_bf16(
scatter_out_offset,
sC[epi_tidx, None, 0],
cutlass.Int32(self.copy_size),
)
elif cutlass.const_expr(self.out_dtype == cutlass.Float32):
blk_reduce_fp32(
scatter_out_offset,
sC[epi_tidx, None, 0],
cutlass.Int32(self.copy_size),
)
elif cutlass.const_expr(self.out_dtype == cutlass.Float16):
blk_reduce_fp16(
scatter_out_offset,
sC[epi_tidx, None, 0],
cutlass.Int32(self.copy_size),
)
self.epilog_sync_barrier.arrive_and_wait()
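Putting the use_blkred pieces together, the per-tile epilogue flow reduces to roughly the sketch below. It is condensed from the fragments above; the CTA-level synchronization between the proxy fence and the reduce is elided in the hunks shown here, so its exact placement is an assumption:

    # condensed, illustrative ordering only
    tRS_rC.store(acc_vec_final.to(self.out_dtype))                # scale and convert accumulators in registers
    cute.copy(tiled_copy_r2s, tRS_rC,                             # registers -> staged smem tile (per subtile)
              tRS_sC[(None, None, real_subtile_idx, None)])
    cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared,       # order generic-proxy smem writes before
                          space=cute.arch.SharedSpace.shared_cta) # the async (TMA) proxy reads them
    # ... CTA-level sync so every row of sC is fully written ...
    blk_reduce_bf16(scatter_out_offset,                           # one bulk add-reduce per output row:
                    sC[epi_tidx, None, 0],                        # out[token_idx, coord_n : coord_n + N] += sC row
                    cutlass.Int32(self.copy_size))
    self.epilog_sync_barrier.arrive_and_wait()                    # keep sC live until all reduces are issued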
# Advance to next tile
#
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
@ -1954,6 +2030,28 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
def epilog_smem_copy_and_partition(
self,
tidx: cutlass.Int32,
tTR_rC: cute.Tensor,
sC: cute.Tensor,
tiled_copy_t2r: cute.TiledCopy,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Create the register-to-shared-memory (R2S) tiled copy and thread partitions for the staged C tile, derived from the TMEM-to-register tiled copy so each thread stores the accumulator fragment it already holds.
"""
atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.out_dtype,
)
tiled_copy_r2s = cute.make_tiled_copy_D(atom, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
tRS_rC = tiled_copy_r2s.retile(tTR_rC)
return tiled_copy_r2s, tRS_rC, tRS_sC
@staticmethod
def _compute_stages(
tiled_mma: cute.TiledMma,
@ -1961,12 +2059,12 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
out_dtype: Type[cutlass.Numeric],
cta_tile: cute.Tile,
sf_dtype: Type[cutlass.Numeric],
sf_vec_size: int,
num_smem_capacity: int,
occupancy: int,
use_blkred: bool,
) -> Tuple[int, int, int]:
"""Computes the number of stages for A/B/C operands based on heuristics.
@ -1978,12 +2076,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
:type a_dtype: type[cutlass.Numeric]
:param b_dtype: Data type of operand B.
:type b_dtype: type[cutlass.Numeric]
:param out_dtype: Data type of operand C (output).
:type out_dtype: type[cutlass.Numeric]
:param cta_tile: The CTA tile shape.
:type cta_tile: cute.Tile
:param sf_dtype: Data type of scale factor.
:type sf_dtype: type[cutlass.Numeric]
:param sf_vec_size: Vector size of scale factor.
@ -1992,6 +2088,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
:type num_smem_capacity: int
:param occupancy: Target number of CTAs per SM (occupancy).
:type occupancy: int
:param use_blkred: Whether to use block reduce.
:type use_blkred: bool
:return: A tuple containing the computed number of stages for:
(ACC stages, A/B operand stages, C stages)
@ -2001,7 +2099,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
# Default C stages
num_c_stage = 1
# Default Tile info stages
num_tile_stage = 2
@ -2034,6 +2132,12 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
1, # a tmp 1 stage is provided
)
# satisfy 16B alignment for the output tensor
swizzled_pad = 16 // (out_dtype.width // 8)
c_smem_layout_staged_one = cute.make_layout(
(cta_tile[0], cta_tile[1]), stride=(cta_tile[1] + swizzled_pad, 1)
)
ab_bytes_per_stage = (
cute.size_in_bytes(a_dtype, a_smem_layout_stage_one)
+ cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
@ -2043,16 +2147,22 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
# 1024B alignment for mbar
mbar_helpers_bytes = 1024
c_bytes_per_stage = cute.size_in_bytes(out_dtype, c_smem_layout_staged_one)
c_bytes = c_bytes_per_stage * num_c_stage
# Calculate A/B stages:
# Start with total smem per CTA (capacity / occupancy)
# Subtract reserved bytes (plus the staged C tile bytes when use_blkred is set)
# Divide remaining by bytes needed per A/B stage
if cutlass.const_expr(use_blkred):
num_ab_stage = (
num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
else:
num_ab_stage = (
num_smem_capacity // occupancy - mbar_helpers_bytes
) // ab_bytes_per_stage
# Refine epilogue stages:
# Calculate remaining smem after allocating for A/B stages and reserved bytes
# Add remaining unused smem to epilogue
num_c_stage = 2
return num_acc_stage, num_ab_stage, num_c_stage, num_tile_stage
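The cost of reserving the staged C tile shows up directly in the A/B pipeline depth. A worked example with assumed round numbers (illustrative only; none of these values are taken from this diff):

    num_smem_capacity = 227 * 1024     # assume ~227 KiB of usable smem per CTA
    occupancy = 1
    mbar_helpers_bytes = 1024
    ab_bytes_per_stage = 36 * 1024     # assumed A/B plus scale-factor footprint per stage
    c_bytes = 128 * (256 + 8) * 2      # one BF16 C stage of the padded 128x256 tile = 67584 B

    without_blkred = (num_smem_capacity // occupancy - mbar_helpers_bytes) // ab_bytes_per_stage
    with_blkred = (num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)) // ab_bytes_per_stage
    print(without_blkred, with_blkred)  # 6 vs 4: the staged C tile costs two A/B stages here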
@staticmethod


@ -306,6 +306,57 @@ def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
)
@dsl_user_op
def blk_reduce_bf16(dst_gemm, src_smem, size, loc=None, ip=None):
llvm.inline_asm(
None,
[
dst_gemm.iterator.llvm_ptr,
src_smem.iterator.llvm_ptr,
size.ir_value(),
],
"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.bf16 [$0], [$1], $2;",
"l,l,r",
has_side_effects=True,
loc=loc,
ip=ip,
)
@dsl_user_op
def blk_reduce_fp32(dst_gemm, src_smem, size, loc=None, ip=None):
llvm.inline_asm(
None,
[
dst_gemm.iterator.llvm_ptr,
src_smem.iterator.llvm_ptr,
size.ir_value(),
],
"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
"l,l,r",
has_side_effects=True,
loc=loc,
ip=ip,
)
@dsl_user_op
def blk_reduce_fp16(dst_gemm, src_smem, size, loc=None, ip=None):
llvm.inline_asm(
None,
[
dst_gemm.iterator.llvm_ptr,
src_smem.iterator.llvm_ptr,
size.ir_value(),
],
"cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 [$0], [$1], $2;",
"l,l,r",
has_side_effects=True,
loc=loc,
ip=ip,
)
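All three wrappers use the bulk_group completion mechanism, so the reductions complete asynchronously and, like other bulk TMA operations, are normally committed and then waited on as a group before the source shared memory is reused; the completion handling for this kernel is not visible in the hunks above. A hypothetical companion helper in the same inline-asm style, shown only to illustrate that pairing (it is not part of this commit):

    @dsl_user_op
    def blk_reduce_commit_and_wait(*, loc=None, ip=None):
        # Commit outstanding bulk operations of this thread into one group and wait
        # until their shared-memory reads have completed (illustrative sketch only).
        llvm.inline_asm(
            None,
            [],
            "cp.async.bulk.commit_group; cp.async.bulk.wait_group.read 0;",
            "",
            has_side_effects=True,
            loc=loc,
            ip=ip,
        )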
@dsl_user_op
def griddepcontrol_wait(*, loc=None, ip=None) -> None:
"""