[None][feat] CuteDSL MOE FC1 Enhancement (#10088)

Signed-off-by: Yuhan Li <51736452+liyuhannnnn@users.noreply.github.com>
alel 2026-01-06 09:30:43 +08:00 committed by GitHub
parent 77712ed4ab
commit 6b8ae6fa81
9 changed files with 1629 additions and 267 deletions

View File

@ -1518,9 +1518,9 @@ if IS_CUTLASS_DSL_AVAILABLE:
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
)
if self.tile_size not in (128, ):
if self.tile_size not in (128, 256):
raise ValueError(
f"{self.__class__.kernel_class.__name__} supports tile_size (MMA tile M dimension) 128 only, but got {self.tile_size}"
f"{self.__class__.kernel_class.__name__} supports tile_size (MMA tile M dimension) 128 and 256 only, but got {self.tile_size}"
)
def unique_id(self):

View File

@ -171,8 +171,9 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
- Token ID mapping enables efficient gather operation during A/SFA load
- SwiGLU activation fusion in epilogue (up * silu(gate) with interleaved weights)
- Optional quantization fusion for Float4E2M1FN output with scale factor generation
- Warp specialization: Scheduler (warp 10), LDGSTS A/SFA (warps 4-7), TMA B/SFB (warp 9),
MMA (warp 8), Epilogue (warps 0-3)
- Warp specialization: Scheduler (warp 10), A Sync Transform (warp 11, only used when
use_2cta_instrs is True), LDGSTS A/SFA (warps 4-7), TMA B/SFB (warp 9), MMA (warp 8),
Epilogue (warps 0-3)
:param sf_vec_size: Scale factor vector size (16 for NVF4, 32 for MXF4/MXF8).
:type sf_vec_size: int
@ -302,6 +303,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
self.mma_warp_id = 8
self.tma_b_warp_id = 9
self.sched_warp_id = 10
self.sync_transform_warp_id = 11
self.threads_per_warp = 32
self.threads_per_cta = self.threads_per_warp * len(
(
@ -310,16 +312,30 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
self.tma_b_warp_id,
*self.epilog_warp_id,
self.sched_warp_id,
self.sync_transform_warp_id,
)
)
self.threads_wo_sched = self.threads_per_warp * len(
(
*self.epilog_warp_id,
self.mma_warp_id,
self.tma_b_warp_id,
*self.ldgsts_a_warp_id,
self.warps_wo_sched = (
len(
(
*self.epilog_warp_id,
self.mma_warp_id,
self.tma_b_warp_id,
self.sync_transform_warp_id,
*self.ldgsts_a_warp_id,
)
)
if self.use_2cta_instrs
else len(
(
*self.epilog_warp_id,
self.mma_warp_id,
self.tma_b_warp_id,
*self.ldgsts_a_warp_id,
)
)
)
self.threads_wo_sched = self.threads_per_warp * self.warps_wo_sched
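As a standalone illustration (not part of the diff) of the warp/thread accounting above, here is a minimal Python sketch using the warp IDs listed in the class docstring; it is only an arithmetic check of the counts, not kernel code.
# Illustrative sketch: warp/thread accounting with the documented warp IDs.
THREADS_PER_WARP = 32
epilog_warp_id = (0, 1, 2, 3)
ldgsts_a_warp_id = (4, 5, 6, 7)
mma_warp_id, tma_b_warp_id, sched_warp_id, sync_transform_warp_id = 8, 9, 10, 11
def thread_counts(use_2cta_instrs: bool):
    all_warps = (*ldgsts_a_warp_id, mma_warp_id, tma_b_warp_id,
                 *epilog_warp_id, sched_warp_id, sync_transform_warp_id)
    threads_per_cta = THREADS_PER_WARP * len(all_warps)        # 32 * 12 = 384
    # The scheduler warp is always excluded; the sync-transform warp only counts in 2-CTA mode.
    warps_wo_sched = len(all_warps) - 1 - (0 if use_2cta_instrs else 1)
    return threads_per_cta, warps_wo_sched, THREADS_PER_WARP * warps_wo_sched
print(thread_counts(True))   # (384, 11, 352)
print(thread_counts(False))  # (384, 10, 320)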
# Set barrier for cta sync, epilogue sync and tmem ptr sync
self.cta_sync_barrier = pipeline.NamedBarrier(
@ -338,9 +354,11 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
barrier_id=4,
num_threads=self.threads_per_warp,
)
self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
SM100_TMEM_CAPACITY_COLUMNS = 512
self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
self.vectorized_f32 = vectorized_f32
def _setup_attributes(self):
@ -410,6 +428,12 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
mma_inst_shape_k * mma_inst_tile_k,
)
self.mma_tiler_c = (
self.mma_inst_shape_mn[0],
self.mma_inst_shape_mn[1] // 2,
mma_inst_shape_k * mma_inst_tile_k,
)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
@ -422,10 +446,10 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
self.mma_tiler_sfa[2],
)
self.mma_tiler_c = (
self.mma_inst_shape_mn[0],
self.mma_inst_shape_mn[1] // 2,
mma_inst_shape_k * mma_inst_tile_k,
self.cta_tile_shape_mnk_sfb = (
self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler_sfb[1],
self.mma_tiler_sfb[2],
)
self.cta_tile_shape_mnk_c = (
@ -509,8 +533,23 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
self.num_c_stage,
)
# Compute the number of tensor memory allocation columns
self.num_tmem_alloc_cols = 512
# Overlap and double-buffer the accumulator when num_acc_stage == 1 (the cta_tile_n = 256 case)
self.overlapping_accum = self.num_acc_stage == 1
# Compute number of TMEM columns for SFA/SFB/Accumulator
sf_atom_mn = 32
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
self.num_accumulator_tmem_cols = (
self.cta_tile_shape_mnk[1] * self.num_acc_stage
if not self.overlapping_accum
else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
)
self.epi_tile_n_required = 2 * cute.size(self.epi_tile[1])
# The accumulator buffer is released early in the epilogue only when overlapping_accum is enabled
self.iter_acc_early_release_in_epilogue = self.num_sf_tmem_cols // self.epi_tile_n_required
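The TMEM budget implied by the lines above can be checked with a small standalone sketch; the concrete numbers below (cta_tile_m = 128, cta_tile_n = 256, SFB CTA tile n = 256, mma_inst_tile_k = 4, epi_tile_n = 64, num_acc_stage = 1) are illustrative assumptions, not values taken from this diff.
# Illustrative TMEM column budget (assumed parameter values, see note above).
SF_ATOM_MN = 32
cta_tile_m, cta_tile_n, sfb_tile_n = 128, 256, 256
mma_inst_tile_k, epi_tile_n, num_acc_stage = 4, 64, 1
num_sfa_tmem_cols = (cta_tile_m // SF_ATOM_MN) * mma_inst_tile_k      # 16
num_sfb_tmem_cols = (sfb_tile_n // SF_ATOM_MN) * mma_inst_tile_k      # 32
num_sf_tmem_cols = num_sfa_tmem_cols + num_sfb_tmem_cols              # 48
overlapping_accum = num_acc_stage == 1
num_accumulator_tmem_cols = (cta_tile_n * num_acc_stage if not overlapping_accum
                             else cta_tile_n * 2 - num_sf_tmem_cols)  # 464
assert num_accumulator_tmem_cols + num_sf_tmem_cols == 512            # fits the 512 TMEM columns
epi_tile_n_required = 2 * epi_tile_n                                  # 128: one (up, gate) subtile pair
iter_acc_early_release = num_sf_tmem_cols // epi_tile_n_required      # 0: release during the first pair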
@cute.jit
def __call__(
@ -551,6 +590,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
5. Launch the kernel synchronously with warp specialization:
- Scheduler warp: Dispatches tile information
- LDGSTS warps: Load A and SFA with gather
- A Sync Transform warp: Converts the completion signal of the A/SFA global-to-shared
loads into a barrier arrival the MMA warp can wait on, when use_2cta_instrs is True
- TMA warp: Load B and SFB with multicast
- MMA warp: Perform matrix multiply-accumulate
- Epilogue warps: Apply SwiGLU activation, optional quantization, and store results
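A schematic Python sketch of the warp-specialized dispatch listed above (warp IDs from the class docstring; an illustration, not kernel code):
# Schematic dispatch by warp index (illustration only).
def warp_role(warp_idx: int, use_2cta_instrs: bool) -> str:
    if warp_idx == 10:                  # scheduler warp: dispatches tile information
        return "scheduler"
    if 4 <= warp_idx <= 7:              # LDGSTS warps: gather-load A and SFA
        return "ldgsts_a_sfa"
    if warp_idx == 11:                  # A sync-transform warp, active only in 2-CTA mode
        return "a_sync_transform" if use_2cta_instrs else "idle"
    if warp_idx == 9:                   # TMA warp: load B and SFB with multicast
        return "tma_b_sfb"
    if warp_idx == 8:                   # MMA warp
        return "mma"
    return "epilogue"                   # warps 0-3: SwiGLU, optional quantization, store
assert warp_role(11, use_2cta_instrs=False) == "idle"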
@ -606,10 +647,12 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Set up attributes that depend on gemm inputs
self._setup_attributes()
# Setup sfb tensor by filling B tensor to scale factor atom layout
# ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size)
sfb = cute.make_tensor(sfb.iterator, sfb_layout)
# Setup sfc tensor by filling C tensor to scale factor atom layout
self.generate_sfc = sfc_tensor is not None and norm_const_tensor is not None
if cutlass.const_expr(self.generate_sfc):
sfc_layout = blockscaled_utils.tile_atom_to_shape_SF(c.shape, self.sf_vec_size)
@ -707,7 +750,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Define shared storage for kernel
@cute.struct
class SharedStorage:
class SharedStorage1cta:
# (bidx, bidy, bidz, valid, mn_limit)
sInfo: cute.struct.Align[
cute.struct.MemRange[cutlass.Int32, 5 * self.num_tile_stage],
@ -749,7 +792,53 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
self.buffer_align_bytes,
]
self.shared_storage = SharedStorage
@cute.struct
class SharedStorage2cta:
# (bidx, bidy, bidz, valid, mn_limit)
sInfo: cute.struct.Align[
cute.struct.MemRange[cutlass.Int32, 5 * self.num_tile_stage],
# 1 byte alignment
1,
]
a_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
a_sync_transform_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
b_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tile_info_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_tile_stage * 2]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_holding_buf: cutlass.Int32
# (EPI_TILE_M, EPI_TILE_N, STAGE)
sC: cute.struct.Align[
cute.struct.MemRange[
self.c_dtype,
cute.cosize(self.c_smem_layout_staged.outer),
],
self.buffer_align_bytes,
]
# (MMA, MMA_M, MMA_K, STAGE)
sA: cute.struct.Align[
cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
self.buffer_align_bytes,
]
# (MMA, MMA_N, MMA_K, STAGE)
sB: cute.struct.Align[
cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
self.buffer_align_bytes,
]
# ((granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage)
sSFA: cute.struct.Align[
cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)],
self.buffer_align_bytes,
]
# ((granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage)
sSFB: cute.struct.Align[
cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)],
self.buffer_align_bytes,
]
self.shared_storage = (
SharedStorage2cta if cutlass.const_expr(self.use_2cta_instrs) else SharedStorage1cta
)
# Launch the kernel synchronously
self.kernel(
@ -875,9 +964,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Prefetch tma desc
#
if warp_idx == self.tma_b_warp_id:
# cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
# cpasync.prefetch_descriptor(tma_atom_sfa)
cpasync.prefetch_descriptor(tma_atom_sfb)
cpasync.prefetch_descriptor(tma_atom_c)
@ -911,10 +998,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Consumer: MMA warp for consuming A/SFA data
a_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
128
* cute.size(
cluster_layout_vmnk, mode=[0]
), # 4 warps * 32 threads per warp = 128 threads
self.threads_per_warp * 4,
)
a_pipeline = PipelineCpAsyncUmma.create(
@ -924,9 +1008,25 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
enable_cp_async=(not self.use_2cta_instrs),
)
# Pipeline Init: Initialize A SYNC Transform pipeline when use_2cta_instrs is True
# Producer: 1 warp (warp 11) for LDGSTS SYNC transformation operations
# Consumer: MMA warp for consuming A/SFA data
if cutlass.const_expr(self.use_2cta_instrs):
a_sync_transform_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * cute.size(cluster_layout_vmnk, mode=[0]),
)
a_sync_transform_pipeline = pipeline.PipelineAsyncUmma.create(
barrier_storage=storage.a_sync_transform_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=a_sync_transform_pipeline_producer_group,
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# Pipeline Init: Initialize B pipeline for TMA operations
# Using PipelineTmaUmma for B/SFB since they use TMA load with multicast support
# Producer: TMA B/SFB warp (warp 9) - 1 warp issuing TMA operations
@ -959,7 +1059,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
cta_layout_vmnk=cluster_layout_vmnk,
)
# Pipeline Init: Tensor memory dealloc barrier init
# Pipeline Init: Initialize tile info pipeline (barrier) and states
tile_info_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.threads_per_warp * 1,
@ -1001,16 +1101,14 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
# ((granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage)
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
# (bidx, bidy, bidz, valid)
# (bidx, bidy, bidz, valid, mn_limit)
info_layout = cute.make_layout((5, self.num_tile_stage), stride=(1, 5))
sInfo = storage.sInfo.get_tensor(info_layout)
#
# Compute multicast mask for A/B buffer full
#
# a_full_mcast_mask = None
b_full_mcast_mask = None
# sfa_full_mcast_mask = None
sfb_full_mcast_mask = None
if cutlass.const_expr(self.is_b_mcast or use_2cta_instrs):
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
@ -1107,7 +1205,27 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# (MMA, MMA_M, MMA_N)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
if cutlass.const_expr(self.overlapping_accum):
num_acc_stage_overlapped = 2
tCtAcc_fake = tiled_mma.make_fragment_C(
cute.append(acc_shape, num_acc_stage_overlapped)
)
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = cute.make_tensor(
tCtAcc_fake.iterator,
cute.make_layout(
tCtAcc_fake.shape,
stride=(
tCtAcc_fake.stride[0],
tCtAcc_fake.stride[1],
tCtAcc_fake.stride[2],
(256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1],
),
),
)
else:
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
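The (256 - num_sf_tmem_cols) stage stride above implies that the two overlapped accumulator stages share a band of TMEM columns; a plausible reading, sketched with the same illustrative numbers as before (cta_tile_n = 256, num_sf_tmem_cols = 48) and consistent with the reverse-subtile and early-release logic later in the epilogue:
# Illustrative column ranges of the overlapped accumulator layout (assumed numbers).
cta_tile_n, num_sf_tmem_cols = 256, 48
stage_stride = 256 - num_sf_tmem_cols                         # 208 columns between the two stages
stage0_cols = range(0, cta_tile_n)                            # [0, 256)
stage1_cols = range(stage_stride, stage_stride + cta_tile_n)  # [208, 464)
shared_cols = range(stage_stride, cta_tile_n)                 # [208, 256): overlapped by both stages
# SFA/SFB occupy the remaining columns [464, 512); the epilogue must drain (and release)
# the overlapped band of one stage before the MMA warp starts writing the other stage.
assert len(shared_cols) == num_sf_tmem_cols                   # overlap width equals the SF column count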
#
# Cluster wait before tensor memory alloc
@ -1120,7 +1238,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
griddepcontrol_wait()
#
# Specialized Schedule warp
# Specialized Schedule Warp
#
if warp_idx == self.sched_warp_id:
#
@ -1187,7 +1305,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# with gather/permutation capability enabled by token_id_mapping
#
if warp_idx <= self.ldgsts_a_warp_id[-1] and warp_idx >= self.ldgsts_a_warp_id[0]:
# cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
#
# Setup LDGSTS copy atoms for A and SFA
# A: 8x LDGSTS.128 per thread with swizzle_128B for A matrix (32 elements per thread)
@ -1274,9 +1391,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
tile_info_consumer_state.advance()
while is_valid_tile:
# Get tile coord from tile scheduler
# cur_tile_coord = work_tile.tile_idx
# Load token IDs for gather operation
# For A matrix: each thread loads 8 token offsets (for 8 LDGSTS.128 operations)
# For SFA matrix: each thread loads 1 token offset (for 4 LDGSTS.32 operations)
@ -1409,10 +1523,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
sfa_atom_copy, tAgSFA_slice, tAsSFA_slice, pred=sfa_predicate_tensor
)
# Signal the completion of async
if cutlass.const_expr(self.use_2cta_instrs):
cute.arch.cp_async_commit_group()
cute.arch.cp_async_wait_group(0)
a_pipeline.producer_commit(a_producer_state)
# Peek (try_wait) A buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
@ -1440,6 +1550,83 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
#
a_pipeline.producer_tail(a_producer_state)
#
# Specialized A/SFA Sync Transform Warp (warp 11) when use_2cta_instrs is True
# This warp serves as the sync transform for A and SFA
#
if warp_idx == self.sync_transform_warp_id:
if cutlass.const_expr(self.use_2cta_instrs):
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
)
# First tile
work_tile = tile_sched.initial_work_tile_info()
a_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage
)
a_sync_transform_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_ab_stage
)
tile_info_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_tile_stage
)
# Get the first tile info
valid_tile_info = cute.make_rmem_tensor((1,), cutlass.Int32)
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)]
is_valid_tile = valid_tile_info[0] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
while is_valid_tile:
# Peek (try_wait) A buffer full for k_tile = 0
a_consumer_state.reset_count()
peek_a_full_status = cutlass.Boolean(1)
if a_consumer_state.count < k_tile_cnt:
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
# Peek (try_wait) a sync transform buffer empty
a_sync_transform_producer_state.reset_count()
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
# Conditionally wait for A buffer full
a_pipeline.consumer_wait(a_consumer_state, peek_a_full_status)
a_sync_transform_pipeline.producer_commit(a_sync_transform_producer_state)
a_sync_transform_producer_state.advance()
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
a_consumer_state.advance()
peek_a_full_status = cutlass.Boolean(1)
if a_consumer_state.count < k_tile_cnt:
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
#
# Advance to next tile
#
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)]
is_valid_tile = valid_tile_info[0] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
#
# Wait A sync transform buffer empty
#
a_sync_transform_pipeline.producer_tail(a_sync_transform_producer_state)
#
# Specialized TMA B/SFB load warp (warp 9)
# This warp uses TMA instructions to load B and SFB from global to shared memory
@ -1488,9 +1675,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# ((atom_v, rest_v), loopK)
tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])]
# ((atom_v, rest_v), RestK)
# tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, 0)]
# Apply SFB slicing hack when cta_tile_shape_n=64
slice_n = mma_tile_coord_mnl[1]
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
@ -1578,7 +1762,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Make SFA tmem tensor
sfa_tmem_ptr = cute.recast_ptr(
acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base),
acc_tmem_ptr + self.num_accumulator_tmem_cols,
dtype=self.sf_dtype,
)
# (MMA, MMA_M, MMA_K)
@ -1592,9 +1776,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Make SFB tmem tensor
sfb_tmem_ptr = cute.recast_ptr(
acc_tmem_ptr
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA),
acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols,
dtype=self.sf_dtype,
)
# (MMA, MMA_N, MMA_K)
@ -1627,9 +1809,14 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
)
work_tile = tile_sched.initial_work_tile_info()
if cutlass.const_expr(self.use_2cta_instrs):
a_sync_transform_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage
)
a_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage
)
b_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage
)
@ -1656,12 +1843,24 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
while is_valid_tile:
# Peek (try_wait) AB buffer full for k_tile = 0
a_consumer_state.reset_count()
if cutlass.const_expr(self.use_2cta_instrs):
a_sync_transform_consumer_state.reset_count()
peek_a_sync_transform_full_status = cutlass.Boolean(1)
if a_sync_transform_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_a_sync_transform_full_status = (
a_sync_transform_pipeline.consumer_try_wait(
a_sync_transform_consumer_state
)
)
a_consumer_state.reset_count()
else:
a_consumer_state.reset_count()
peek_a_full_status = cutlass.Boolean(1)
if a_consumer_state.count < k_tile_cnt:
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
b_consumer_state.reset_count()
peek_a_full_status = cutlass.Boolean(1)
peek_b_full_status = cutlass.Boolean(1)
if a_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
if b_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_b_full_status = b_pipeline.consumer_try_wait(b_consumer_state)
@ -1671,11 +1870,16 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
tile_info[2],
)
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
# Get accumulator stage index
if cutlass.const_expr(self.overlapping_accum):
acc_stage_index = acc_producer_state.phase ^ 1
else:
acc_stage_index = acc_producer_state.index
tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)]
# Apply TMEM pointer offset hack when cta_tile_shape_n=192 or
# cta_tile_shape_n=64
tCtSFB_mma = tCtSFB
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
# If this is an ODD tile, shift the TMEM start address for
@ -1686,8 +1890,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
)
shifted_ptr = cute.recast_ptr(
acc_tmem_ptr
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA)
+ self.num_accumulator_tmem_cols
+ self.num_sfa_tmem_cols
+ offset,
dtype=self.sf_dtype,
)
@ -1697,8 +1901,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
shifted_ptr = cute.recast_ptr(
acc_tmem_ptr
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA)
+ self.num_accumulator_tmem_cols
+ self.num_sfa_tmem_cols
+ offset,
dtype=self.sf_dtype,
)
@ -1723,7 +1927,12 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
if is_leader_cta:
# Conditionally wait for AB buffer full
a_pipeline.consumer_wait(a_consumer_state, peek_a_full_status)
if cutlass.const_expr(self.use_2cta_instrs):
a_sync_transform_pipeline.consumer_wait(
a_sync_transform_consumer_state, peek_a_sync_transform_full_status
)
else:
a_pipeline.consumer_wait(a_consumer_state, peek_a_full_status)
b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status)
# Copy SFA/SFB from smem to tmem
@ -1781,16 +1990,31 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Async arrive AB buffer empty
a_pipeline.consumer_release(a_consumer_state)
if cutlass.const_expr(self.use_2cta_instrs):
a_sync_transform_pipeline.consumer_release(
a_sync_transform_consumer_state
)
b_pipeline.consumer_release(b_consumer_state)
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
a_consumer_state.advance()
b_consumer_state.advance()
peek_a_full_status = cutlass.Boolean(1)
if a_consumer_state.count < k_tile_cnt:
if is_leader_cta:
if cutlass.const_expr(self.use_2cta_instrs):
a_sync_transform_consumer_state.advance()
peek_a_sync_transform_full_status = cutlass.Boolean(1)
if a_sync_transform_consumer_state.count < k_tile_cnt:
if is_leader_cta:
peek_a_sync_transform_full_status = (
a_sync_transform_pipeline.consumer_try_wait(
a_sync_transform_consumer_state
)
)
a_consumer_state.advance()
else:
a_consumer_state.advance()
peek_a_full_status = cutlass.Boolean(1)
if a_consumer_state.count < k_tile_cnt:
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
b_consumer_state.advance()
peek_b_full_status = cutlass.Boolean(1)
if b_consumer_state.count < k_tile_cnt:
if is_leader_cta:
@ -1959,9 +2183,18 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
)
]
# Get accumulator stage index
if cutlass.const_expr(self.overlapping_accum):
acc_stage_index = acc_consumer_state.phase
reverse_subtile = (
cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False)
)
else:
acc_stage_index = acc_consumer_state.index
# Set tensor memory buffer for current tile
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)]
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
if cutlass.const_expr(self.generate_sfc):
# (T2R, T2R_M, T2R_N, RestM, RestN)
@ -1990,18 +2223,36 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# up * silu(gate)
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
for subtile_idx in cutlass.range(0, subtile_cnt, 2):
real_subtile_idx = subtile_idx // 2
if cutlass.const_expr(self.overlapping_accum):
if reverse_subtile:
real_subtile_idx = (
self.cta_tile_shape_mnk[1] // self.epi_tile_n_required
- 1
- subtile_idx // 2
)
#
# Load accumulator from tensor memory buffer to register
#
tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, subtile_idx)]
tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, subtile_idx + 1)]
tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, real_subtile_idx * 2)]
tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, real_subtile_idx * 2 + 1)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up)
cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate)
#
# Async arrive accumulator buffer empty earlier when overlapping_accum is enabled
#
if cutlass.const_expr(self.overlapping_accum):
if subtile_idx // 2 == self.iter_acc_early_release_in_epilogue:
# Fence for TMEM load
cute.arch.fence_view_async_tmem_load()
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
acc_vec_up = tTR_rAcc_up.load()
acc_vec_gate = tTR_rAcc_gate.load()
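A standalone sketch of the subtile visiting order produced by the reversal above, with illustrative sizes (cta_tile_n = 256, epi_tile_n_required = 128, i.e. two (up, gate) pairs per tile):
# Illustrative epilogue subtile order under overlapping_accum (assumed sizes).
cta_tile_n, epi_tile_n_required = 256, 128
num_pairs = cta_tile_n // epi_tile_n_required               # 2
subtile_cnt = 2 * num_pairs                                 # interleaved up/gate subtiles
def visit_order(reverse_subtile: bool):
    order = []
    for subtile_idx in range(0, subtile_cnt, 2):
        real = num_pairs - 1 - subtile_idx // 2 if reverse_subtile else subtile_idx // 2
        order.append((real * 2, real * 2 + 1))              # (up, gate) subtile indices in TMEM
    return order
print(visit_order(False))  # [(0, 1), (2, 3)]  accumulator stage read forward
print(visit_order(True))   # [(2, 3), (0, 1)]  stage with phase 0 read in reverse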
@ -2075,7 +2326,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Assume subtile partitioned always happens on n dimension
sfc_subtile_idx_mn = (
tile_info[0] * self.epi_tile_cnt[0],
tile_info[1] * self.epi_tile_cnt[1] + subtile_idx // 2,
tile_info[1] * self.epi_tile_cnt[1] + real_subtile_idx,
)
tCgSFC = tCgSFC_mn[
(
@ -2195,7 +2446,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Store C to shared memory
#
num_prev_subtiles = num_prev_subtiles + 1
c_buffer = (num_prev_subtiles + subtile_idx // 2) % self.num_c_stage
c_buffer = num_prev_subtiles % self.num_c_stage
cute.copy(
tiled_copy_r2s,
@ -2215,7 +2466,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
cute.copy(
tma_atom_c,
bSG_sC[(None, c_buffer)],
bSG_gC[(None, subtile_idx // 2)],
bSG_gC[(None, real_subtile_idx)],
)
# Fence and barrier to make sure shared memory store is visible to TMA store
c_pipeline.producer_commit()
@ -2225,9 +2476,10 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
#
# Async arrive accumulator buffer empty
#
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
if cutlass.const_expr(not self.overlapping_accum):
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
#
# Advance to next tile
@ -2509,10 +2761,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Start with total smem per CTA (capacity / occupancy)
# Subtract reserved bytes and initial C stages bytes
# Divide remaining by bytes needed per A/B stage
# cute.printf("num_smem_capacity: {}, occupancy: {}, "
# "mbar_helpers_bytes: {}, c_bytes: {}",
# num_smem_capacity, occupancy, mbar_helpers_bytes, c_bytes)
# cute.printf("ab_bytes_per_stage: {}", ab_bytes_per_stage)
num_ab_stage = (
num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
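A worked example of the stage-count budgeting above; every number below is a placeholder chosen only to show the arithmetic (the real values come from the kernel's shared-memory layouts):
# Placeholder-number sketch of the A/B stage budgeting (illustration only).
num_smem_capacity = 232448       # assumed SM100 smem bytes per CTA (placeholder)
occupancy = 1
mbar_helpers_bytes = 1024        # placeholder: barriers + tile-info storage
c_bytes = 2 * 16384              # placeholder: two C stages
ab_bytes_per_stage = 40960       # placeholder: A + SFA + B + SFB for one stage
num_ab_stage = (num_smem_capacity // occupancy
                - (mbar_helpers_bytes + c_bytes)) // ab_bytes_per_stage
print(num_ab_stage)              # 4 with these placeholder numbers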
@ -2695,15 +2943,12 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
is_valid = False
# TODO: Currently we don't support m major output for Float4E2M1FN,
# Need to support it in the future.
if c_dtype is cutlass.Float4E2M1FN and c_major == "m":
is_valid = False
return is_valid
@staticmethod
def is_valid_mma_tiler_and_cluster_shape(
use_2cta_instrs: bool,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
) -> bool:
@ -2723,19 +2968,18 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
is_valid = True
# Skip invalid mma tile shape
if not (
(not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
):
if mma_tiler_mn[0] not in (128, 256):
is_valid = False
# Skip invalid mma tile n
# Needs to have even iterations with Epi Tile N 64 for swiGeLU fusion
# SwiGLU fusion requires an even epi_tile count;
# with epi_tile_n = 64, only mma_tiler_n = 128 and 256 are supported
if mma_tiler_mn[1] not in (128, 256):
is_valid = False
# Skip illegal cluster shape
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
if (mma_tiler_mn[0] // cluster_shape_mn[0]) != 128:
is_valid = False
# Skip invalid cluster shape
if (
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
or cluster_shape_mn[0] <= 0
@ -2748,15 +2992,11 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
or not is_power_of_2(cluster_shape_mn[1])
):
is_valid = False
cluster_tiler_m = (cluster_shape_mn[0] // (2 if use_2cta_instrs else 1)) * mma_tiler_mn[0]
# Skip invalid cluster tiler shape since contiguous layout can't handle oob access
# The contiguous layout means the aligned data is stored in a contiguous manner.
# It can't handle runtime oob when alignment is not align with the tile_M,
# since the problem shape of TMA store can't be changed at runtime.
if cluster_tiler_m not in [64, 128, 256]:
# We only support cluster shape n = 1 for now
# TODO: Support cluster shape n > 1
if cluster_shape_mn[1] != 1:
is_valid = False
return is_valid
@staticmethod
@ -2812,8 +3052,9 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
is_valid = False
return is_valid
@staticmethod
@classmethod
def can_implement(
cls,
ab_dtype: Type[cutlass.Numeric],
sf_dtype: Type[cutlass.Numeric],
sf_vec_size: int,
@ -2839,8 +3080,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
:type sf_vec_size: int
:param c_dtype: The data type of the output tensor
:type c_dtype: Type[cutlass.Numeric]
:param use_2cta_instrs: Whether to use 2 CTA groups
:type use_2cta_instrs: bool
:param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
:type mma_tiler_mn: Tuple[int, int]
:param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
@ -2865,25 +3104,20 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
"""
can_implement = True
# Skip unsupported types
if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
if not cls.is_valid_dtypes_and_scale_factor_vec_size(
ab_dtype, sf_dtype, sf_vec_size, c_dtype
):
can_implement = False
# Skip unsupported layouts
if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_layouts(
ab_dtype, c_dtype, a_major, b_major, c_major
):
if not cls.is_valid_layouts(ab_dtype, c_dtype, a_major, b_major, c_major):
can_implement = False
# Skip invalid mma tile shape and cluster shape
use_2cta_instrs = mma_tiler_mn[0] == 256
if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape(
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
):
if not cls.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn):
can_implement = False
# Skip illegal problem shape for load/store alignment
if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_tensor_alignment(
if not cls.is_valid_tensor_alignment(
m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
):
can_implement = False

View File

@ -2244,6 +2244,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
or cluster_shape_mn[0] <= 0
or cluster_shape_mn[1] <= 0
# Special cluster shape check for scale factor multicasts.
# Due to limited size of scale factors, we can't multicast among more than 4 CTAs.
or cluster_shape_mn[0] > 4
or cluster_shape_mn[1] > 4
or not is_power_of_2(cluster_shape_mn[0])
or not is_power_of_2(cluster_shape_mn[1])
):
@ -2304,8 +2308,9 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
is_valid = False
return is_valid
@staticmethod
@classmethod
def can_implement(
cls,
ab_dtype: Type[cutlass.Numeric],
sf_dtype: Type[cutlass.Numeric],
sf_vec_size: int,
@ -2355,24 +2360,20 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
"""
can_implement = True
# Skip unsupported types
if not Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel.is_valid_dtypes_and_scale_factor_vec_size(
if not cls.is_valid_dtypes_and_scale_factor_vec_size(
ab_dtype, sf_dtype, sf_vec_size, out_dtype
):
can_implement = False
# Skip unsupported layouts
if not Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel.is_valid_layouts(
ab_dtype, out_dtype, a_major, b_major, out_major
):
if not cls.is_valid_layouts(ab_dtype, out_dtype, a_major, b_major, out_major):
can_implement = False
# Skip invalid mma tile shape and cluster shape
if not Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel.is_valid_mma_tiler_and_cluster_shape(
mma_tiler_mn, cluster_shape_mn
):
if not cls.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn):
can_implement = False
# Skip illegal problem shape for load/store alignment
if not Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel.is_valid_tensor_alignment(
if not cls.is_valid_tensor_alignment(
m, n, k, l, ab_dtype, out_dtype, a_major, b_major, out_major
):
can_implement = False

View File

@ -82,15 +82,15 @@ This GEMM kernel supports the following features:
This GEMM works as follows:
1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA
operations.
2. MMA warp:
- Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction.
- Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
3. EPILOGUE warp:
- Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
- Apply alpha and update the final accumulator Final = alpha * acc
- Type convert Final matrix to output type.
5. EPILOGUE warps:
- Load two accumulator subtiles (up and gate) from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
- Apply alpha scaling: up_scaled = alpha * up, gate_scaled = alpha * gate
- Compute SwiGLU activation: output = up_scaled * silu(gate_scaled), where silu(x) = x * sigmoid(x)
- If c_dtype is Float4E2M1FN: generate scale factor C (SFC) and quantize output
- Type convert output to c_dtype.
- Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations.
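A scalar Python reference for the fused epilogue math described above (alpha scaling followed by SwiGLU; the optional quantization step is omitted); a minimal sketch, not the kernel code:
# Scalar reference for the SwiGLU epilogue: output = (alpha * up) * silu(alpha * gate).
import math
def swiglu_epilogue(up: float, gate: float, alpha: float) -> float:
    up_scaled = alpha * up
    gate_scaled = alpha * gate
    silu = gate_scaled / (1.0 + math.exp(-gate_scaled))   # silu(x) = x * sigmoid(x)
    return up_scaled * silu
print(swiglu_epilogue(up=1.5, gate=-0.25, alpha=1.0))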
SM100 tcgen05.mma.kind.block_scale instructions operate as follows:
@ -101,24 +101,6 @@ SM100 tcgen05.mma.kind.block_scale instructions operate as follows:
- Write accumulator to TMEM
The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
.. code-block:: bash
python examples/blackwell/contiguous_blockscaled_grouped_gemm.py \
--ab_dtype Float4E2M1FN --c_dtype BFloat16 \
--sf_dtype Float8E4M3FN --sf_vec_size 16 \
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
--mnkl 256,4096,7168,1
To collect performance with NCU profiler:
.. code-block:: bash
ncu python examples/blackwell/contiguous_blockscaled_grouped_gemm.py \
--ab_dtype Float4E2M1FN --c_dtype BFloat16 \
--sf_dtype Float8E4M3FN --sf_vec_size 16 \
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
--mnkl 256,4096,7168,1
Constraints:
* Supported input data types: mxf8, mxf4, nvf4
see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation
@ -138,14 +120,14 @@ CUDA Graph Support:
- A matrix: padded to permuted_m rows (padding rows contain dummy data)
- C matrix: padded to permuted_m rows (output buffer for cuda_graph)
- Scale factor A: padded to match A matrix dimensions
* Kernel handling of padding (similar to masked_grouped_gemm.py):
* Kernel handling of padding:
- Scheduler warp checks if tile_idx >= num_non_exiting_tiles to exit
- Only valid tiles (tile_idx < num_non_exiting_tiles) are written to tile_info pipeline
- When no more valid tiles exist, outer loop exits and calls producer_tail()
- Consumer warps process only valid tiles from pipeline
- No deadlock or synchronization issues
* Consumer warps check initial tile against num_non_exiting_tiles and set is_valid_tile=False if
tile_idx >= num_non_exiting_tiles
* Consumer warps check initial tile against num_non_exiting_tiles and set
is_valid_tile=False if tile_idx >= num_non_exiting_tiles
* Only rows within (aligned_groupm[0]+aligned_groupm[1]+...) contain valid data
* Padding rows in C matrix will not be written by the kernel
"""
@ -199,23 +181,34 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
cluster_shape_mn: Tuple[int, int],
vectorized_f32: bool,
):
"""Initializes the configuration for a Blackwell blockscaled dense GEMM kernel.
"""Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with SwiGLU fusion.
This configuration includes several key aspects:
1. MMA Instruction Settings (tcgen05):
- acc_dtype: Data types for MMA accumulator.
- mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
- use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant
with cta_group=2 should be used.
- use_2cta_instrs: Automatically inferred from mma_tiler_mn[0]
(True when M=256, False when M=128).
2. Cluster Shape:
- cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
3. Scale Factor Configuration:
- sf_vec_size: Vector size for block-scaled quantization.
4. Performance Optimization:
- vectorized_f32: Enable vectorized f32x2 operations.
:param sf_vec_size: Vector size for scale factors (16 for NVF4, 32 for MXF4/MXF8).
:type sf_vec_size: int
:param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
use_2cta_instrs is automatically set based on M (True if M=256, False if M=128).
:type mma_tiler_mn: Tuple[int, int]
:param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
:type cluster_shape_mn: Tuple[int, int]
:param vectorized_f32: Enable vectorized f32x2 operations for better performance.
:type vectorized_f32: bool
"""
self.sf_vec_size = sf_vec_size
@ -248,10 +241,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
self.tma_warp_id,
)
)
# TODO: Do we need to reallocate register?
# self.num_regs_uniform_warps = 64
# self.num_regs_sched_warps = 64
# self.num_regs_epilogue_warps = 216
# Set barrier for cta sync, epilogue sync and tmem ptr sync
self.cta_sync_barrier = pipeline.NamedBarrier(
@ -270,6 +259,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
barrier_id=4,
num_threads=self.threads_per_warp,
)
self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
SM100_TMEM_CAPACITY_COLUMNS = 512
self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
@ -337,22 +327,29 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
mma_inst_shape_k * mma_inst_tile_k,
)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.mma_tiler_c = (
self.mma_inst_shape_mn[0],
self.mma_inst_shape_mn[1] // 2,
mma_inst_shape_k * mma_inst_tile_k,
)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cta_tile_shape_mnk_sfb = (
self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler_sfb[1],
self.mma_tiler_sfb[2],
)
self.cta_tile_shape_mnk_c = (
self.mma_tiler_c[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler_c[1],
self.mma_tiler_c[2],
)
# Compute cluster layout
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((*self.cluster_shape_mn, 1)),
@ -436,6 +433,24 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
self.num_c_stage,
)
# Overlap and double-buffer the accumulator when num_acc_stage == 1 (the cta_tile_n = 256 case)
self.overlapping_accum = self.num_acc_stage == 1
# Compute number of TMEM columns for SFA/SFB/Accumulator
sf_atom_mn = 32
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
self.num_accumulator_tmem_cols = (
self.cta_tile_shape_mnk[1] * self.num_acc_stage
if not self.overlapping_accum
else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
)
self.epi_tile_n_required = 2 * cute.size(self.epi_tile[1])
# The accumulator buffer is released early in the epilogue only when overlapping_accum is enabled
self.iter_acc_early_release_in_epilogue = self.num_sf_tmem_cols // self.epi_tile_n_required
@cute.jit
def __call__(
self,
@ -530,6 +545,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
self.mma_inst_shape_mn,
)
# For 2CTA blockscaled kernels, SFB needs to be replicated across peer CTAs.
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
self.a_dtype,
self.a_major_mode,
@ -591,7 +607,9 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
internal_type=cutlass.Int16,
)
if cutlass.const_expr(self.cta_tile_shape_mnk_c[1] == 192):
# This modifies the layout to handle overlapping 256 x nNSF (number of scale factors for a
# single column of B) logical blocks for SFB when cta_tile_shape_n=192.
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
x = tma_tensor_sfb.stride[0][1]
y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4)
@ -727,8 +745,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
tSF: cute.Tensor,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and
tensor memory (destination).
Make tiledCopy for smem to tmem load for scale factor tensor, then use it to
partition smem memory (source) and tensor memory (destination).
:param sSF: The scale factor tensor in smem
:type sSF: cute.Tensor
@ -779,7 +797,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
mSFB_nkl: cute.Tensor,
tma_atom_c: cute.CopyAtom,
mC_mnl: cute.Tensor,
mSFD_mnl: Optional[cute.Tensor],
mSFC_mnl: Optional[cute.Tensor],
norm_const_tensor: Optional[cute.Tensor],
tile_idx_to_expert_idx: cute.Tensor,
num_non_exiting_tiles: cute.Tensor,
@ -1046,7 +1064,27 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# (MMA, MMA_M, MMA_N)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
if cutlass.const_expr(self.overlapping_accum):
num_acc_stage_overlapped = 2
tCtAcc_fake = tiled_mma.make_fragment_C(
cute.append(acc_shape, num_acc_stage_overlapped)
)
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = cute.make_tensor(
tCtAcc_fake.iterator,
cute.make_layout(
tCtAcc_fake.shape,
stride=(
tCtAcc_fake.stride[0],
tCtAcc_fake.stride[1],
tCtAcc_fake.stride[2],
(256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1],
),
),
)
else:
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
#
# Cluster wait before tensor memory alloc
@ -1062,7 +1100,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Specialized Schedule warp
#
if warp_idx == self.sched_warp_id:
# cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps)
#
# Persistent tile scheduling loop
#
@ -1078,10 +1115,11 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
while work_tile.is_valid_tile:
cur_tile_coord = work_tile.tile_idx
if cur_tile_coord[0] < num_non_exiting_tiles[0]:
mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
if mma_tile_coord_m < num_non_exiting_tiles[0]:
tile_info_pipeline.producer_acquire(tile_info_producer_state)
cur_tile_coord = work_tile.tile_idx
expert_idx = tile_idx_to_expert_idx[cur_tile_coord[0]]
expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
with cute.arch.elect_one():
sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
@ -1121,7 +1159,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Specialized TMA load warp
#
if warp_idx == self.tma_warp_id:
# cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
#
# Persistent tile scheduling loop
#
@ -1169,6 +1206,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# ((atom_v, rest_v), RestK)
tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, 0)]
# Apply SFB slicing hack when cta_tile_shape_n=64
slice_n = mma_tile_coord_mnl[1]
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
slice_n = mma_tile_coord_mnl[1] // 2
@ -1272,7 +1310,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Make SFA tmem tensor
sfa_tmem_ptr = cute.recast_ptr(
acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base),
acc_tmem_ptr + self.num_accumulator_tmem_cols,
dtype=self.sf_dtype,
)
# (MMA, MMA_M, MMA_K)
@ -1286,9 +1324,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Make SFB tmem tensor
sfb_tmem_ptr = cute.recast_ptr(
acc_tmem_ptr
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA),
acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols,
dtype=self.sf_dtype,
)
# (MMA, MMA_N, MMA_K)
@ -1352,31 +1388,34 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
# Peek (try_wait) Acc buffer empty for k_tile = 0
acc_producer_state.reset_count()
peek_acc_empty_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_acc_empty_status = acc_pipeline.producer_try_acquire(acc_producer_state)
mma_tile_coord_mnl = (
tile_info[0] // cute.size(tiled_mma.thr_id.shape),
tile_info[1],
tile_info[2],
)
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
# Get accumulator stage index
if cutlass.const_expr(self.overlapping_accum):
acc_stage_index = acc_producer_state.phase ^ 1
else:
acc_stage_index = acc_producer_state.index
tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)]
# Apply TMEM pointer offset hack when cta_tile_shape_n=192 or
# cta_tile_shape_n=64
tCtSFB_mma = tCtSFB
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
# If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words
# If this is an ODD tile, shift the TMEM start address for
# cta_tile_shape_n=192 case by two words
# (ignores first 64 columns of SFB)
offset = (
cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0)
)
shifted_ptr = cute.recast_ptr(
acc_tmem_ptr
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA)
+ self.num_accumulator_tmem_cols
+ self.num_sfa_tmem_cols
+ offset,
dtype=self.sf_dtype,
)
@ -1386,8 +1425,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
shifted_ptr = cute.recast_ptr(
acc_tmem_ptr
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA)
+ self.num_accumulator_tmem_cols
+ self.num_sfa_tmem_cols
+ offset,
dtype=self.sf_dtype,
)
@ -1396,7 +1435,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Wait for accumulator buffer empty
#
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state, peek_acc_empty_status)
acc_pipeline.producer_acquire(acc_producer_state)
#
# Mma mainloop
#
@ -1406,7 +1445,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
#
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
for k_tile in cutlass.range(k_tile_cnt):
# Set tensor memory buffer for current tile
# (MMA, MMA_M, MMA_N)
@ -1485,11 +1524,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1
acc_producer_state.advance()
if acc_producer_state.count < k_tile_cnt:
if is_leader_cta:
peek_acc_empty_status = acc_pipeline.producer_try_acquire(
acc_producer_state
)
#
# Advance to next tile
@ -1533,7 +1567,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
#
# Partition for epilogue
#
epi_tidx = tidx
epi_tidx = tidx % 128
(
tiled_copy_t2r,
tTR_tAcc_base,
@ -1562,17 +1596,18 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
if cutlass.const_expr(self.generate_sfc):
norm_const = norm_const_tensor[0]
# (EPI_TILE_M, EPI_TILE_N, RestM, RestN, RestL)
gSFD_mnl = cute.local_tile(mSFD_mnl, epi_tile, (None, None, None))
gSFC_mnl = cute.local_tile(mSFC_mnl, epi_tile, (None, None, None))
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
# (T2R, T2R_M, T2R_N, RestM, RestN, RestL)
tCgSFD_mnl = thr_copy_t2r.partition_D(gSFD_mnl)
tCgSFD_mnl = cute.filter_zeros(tCgSFD_mnl)
tCgSFC_mnl = thr_copy_t2r.partition_D(gSFC_mnl)
tCgSFC_mnl = cute.filter_zeros(tCgSFC_mnl)
# (T2R, T2R_M, T2R_N)
tCrSFD = cute.make_rmem_tensor(
tCgSFD_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype
tCrSFC = cute.make_rmem_tensor(
tCgSFC_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype
)
tCrSFD_pvscale = cute.make_rmem_tensor_like(tCrSFD, cutlass.Float32)
tCrSFC_pvscale = cute.make_rmem_tensor_like(tCrSFC, cutlass.Float32)
#
# Persistent tile scheduling loop
#
@ -1614,6 +1649,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
num_prev_subtiles = cutlass.Int32(0)
while is_valid_tile:
mma_tile_coord_mnl = (
tile_info[0] // cute.size(tiled_mma.thr_id.shape),
@ -1643,13 +1679,22 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
)
]
# Get accumulator stage index
if cutlass.const_expr(self.overlapping_accum):
acc_stage_index = acc_consumer_state.phase
reverse_subtile = (
cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False)
)
else:
acc_stage_index = acc_consumer_state.index
# Set tensor memory buffer for current tile
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)]
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
if cutlass.const_expr(self.generate_sfc):
# (T2R, T2R_M, T2R_N, RestM, RestN)
tCgSFD_mn = tCgSFD_mnl[
tCgSFC_mn = tCgSFC_mnl[
(
None,
None,
@ -1669,28 +1714,54 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
#
# Store accumulator to global memory in sub-tiles
# Process accumulator subtiles with SwiGLU fusion and store to global memory
# Each iteration processes a pair of subtiles (up, gate) and computes
# up * silu(gate)
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
for subtile_idx in cutlass.range(0, subtile_cnt, 2):
real_subtile_idx = subtile_idx // 2
if cutlass.const_expr(self.overlapping_accum):
if reverse_subtile:
real_subtile_idx = (
self.cta_tile_shape_mnk[1] // self.epi_tile_n_required
- 1
- subtile_idx // 2
)
#
# Load accumulator from tensor memory buffer to register
#
tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, subtile_idx)]
tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, subtile_idx + 1)]
tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, real_subtile_idx * 2)]
tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, real_subtile_idx * 2 + 1)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up)
cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate)
#
# Async arrive accumulator buffer empty earlier when overlapping_accum is enabled
#
if cutlass.const_expr(self.overlapping_accum):
if subtile_idx // 2 == self.iter_acc_early_release_in_epilogue:
# Fence for TMEM load
cute.arch.fence_view_async_tmem_load()
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
acc_vec_up = tTR_rAcc_up.load()
acc_vec_gate = tTR_rAcc_gate.load()
# SwiGlu
#
# SwiGLU activation: output = up * silu(gate)
# where silu(x) = x * sigmoid(x)
# up and gate are extracted from interleaved accumulator subtiles
#
tCompute = cute.make_rmem_tensor(acc_vec_gate.shape, self.acc_dtype)
if cutlass.const_expr(self.vectorized_f32):
# SwiGlu Packed Version
# SwiGLU Packed Version: uses f32x2 packed operations for better performance
# Computes: output = (alpha * up) * silu(alpha * gate)
# where silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
LOG2_E = cutlass.Float32(1.4426950408889634)
for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_up), 2):
acc_vec_up_alpha = cute.arch.mul_packed_f32x2(
@ -1731,7 +1802,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
(acc_vec_up_alpha[0], acc_vec_up_alpha[1]),
)
else:
# SwiGlu Unpacked Version
# SwiGLU Unpacked Version: scalar operations
# Computes: output = (alpha * up) * silu(alpha * gate)
for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)):
acc_vec_up_alpha = acc_vec_up[i] * cutlass.Float32(alpha_val)
acc_vec_gate_alpha = acc_vec_gate[i] * cutlass.Float32(alpha_val)
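The LOG2_E constant above is presumably used to evaluate the sigmoid with a base-2 exponential; a standalone check of that identity (plain floats, not the packed f32x2 path):
# Check: exp(-x) == 2**(-x * log2(e)), so silu can be built from exp2.
import math
LOG2_E = 1.4426950408889634
def silu_exp2(x: float) -> float:
    return x / (1.0 + 2.0 ** (-x * LOG2_E))
def silu_ref(x: float) -> float:
    return x / (1.0 + math.exp(-x))
for x in (-3.0, -0.5, 0.0, 0.7, 4.0):
    assert abs(silu_exp2(x) - silu_ref(x)) < 1e-6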
@ -1740,12 +1812,19 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
)
if cutlass.const_expr(self.generate_sfc):
#
# Quantization path for Float4E2M1FN output:
# 1. Compute per-vector absolute max from SwiGLU result
# 2. Generate scale factor C (SFC) based on max values
# 3. Store SFC to global memory
# 4. Quantize output by scaling with reciprocal of SFC
#
# Assume subtile partitioned always happens on n dimension
sfc_subtile_idx_mn = (
tile_info[0] * self.epi_tile_cnt[0],
tile_info[1] * self.epi_tile_cnt[1] + subtile_idx // 2,
tile_info[1] * self.epi_tile_cnt[1] + real_subtile_idx,
)
tCgSFD = tCgSFD_mn[
tCgSFC = tCgSFC_mn[
(
None,
None,
@ -1769,30 +1848,30 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
if cutlass.const_expr(self.vectorized_f32):
for vi in cutlass.range_constexpr(abs_acc_frg.shape[1]):
tCrSFD_pvscale[vi] = abs_acc_frg[None, vi].reduce(
tCrSFC_pvscale[vi] = abs_acc_frg[None, vi].reduce(
cute.ReductionOp.MAX,
cutlass.Float32(0.0),
0, # Use 0.0 as init for abs values
)
for vi in cutlass.range_constexpr(0, abs_acc_frg.shape[1], 2):
tCrSFD_pvscale[vi], tCrSFD_pvscale[vi + 1] = (
tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1] = (
cute.arch.mul_packed_f32x2(
(tCrSFD_pvscale[vi], tCrSFD_pvscale[vi + 1]),
(tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1]),
(
self.get_dtype_rcp_limits(self.c_dtype),
self.get_dtype_rcp_limits(self.c_dtype),
),
)
)
tCrSFD_pvscale[vi], tCrSFD_pvscale[vi + 1] = (
tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1] = (
cute.arch.mul_packed_f32x2(
(tCrSFD_pvscale[vi], tCrSFD_pvscale[vi + 1]),
(tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1]),
(norm_const, norm_const),
)
)
else:
for vi in cutlass.range_constexpr(abs_acc_frg.shape[1]):
tCrSFD_pvscale[vi] = (
tCrSFC_pvscale[vi] = (
abs_acc_frg[None, vi].reduce(
cute.ReductionOp.MAX,
cutlass.Float32(0.0),
@ -1803,27 +1882,27 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
)
# TODO: need to add f32x2 -> f8x2 conversion
tCrSFD.store(tCrSFD_pvscale.load().to(self.sf_dtype))
tCrSFC.store(tCrSFC_pvscale.load().to(self.sf_dtype))
#
# Store SFC to global memory
#
# TODO: Need to think about predicate on it
# if cute.elem_less():
cute.autovec_copy(tCrSFD, tCgSFD)
cute.autovec_copy(tCrSFC, tCgSFC)
#
# Compute quantized output values and convert to C type
#
# TODO: need to add f8x2 -> f32x2 conversion
tCrSFD_qpvscale_up = tCrSFD.load().to(cutlass.Float32)
tCrSFC_qpvscale_up = tCrSFC.load().to(cutlass.Float32)
fp32_max = cutlass.Float32(3.40282346638528859812e38)
if cutlass.const_expr(self.vectorized_f32):
for vi in cutlass.range_constexpr(0, cute.size(tCrSFD), 2):
for vi in cutlass.range_constexpr(0, cute.size(tCrSFC), 2):
acc_scale = cute.arch.mul_packed_f32x2(
(
cute.arch.rcp_approx(tCrSFD_qpvscale_up[vi]),
cute.arch.rcp_approx(tCrSFD_qpvscale_up[vi + 1]),
cute.arch.rcp_approx(tCrSFC_qpvscale_up[vi]),
cute.arch.rcp_approx(tCrSFC_qpvscale_up[vi + 1]),
),
(norm_const, norm_const),
)
@ -1838,10 +1917,10 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
(acc_scale_min0, acc_scale_min1),
)
else:
for vi in cutlass.range_constexpr(cute.size(tCrSFD)):
for vi in cutlass.range_constexpr(cute.size(tCrSFC)):
# TODO: Need to add E8M0 rcp approximation
acc_scale = norm_const * cute.arch.rcp_approx(
tCrSFD_qpvscale_up[vi]
tCrSFC_qpvscale_up[vi]
)
acc_scale = fmin(acc_scale, fp32_max, nan=True)
@ -1862,7 +1941,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
#
# Store C to shared memory
#
c_buffer = (num_prev_subtiles + subtile_idx // 2) % self.num_c_stage
num_prev_subtiles = num_prev_subtiles + 1
c_buffer = num_prev_subtiles % self.num_c_stage
cute.copy(
tiled_copy_r2s,
@ -1882,7 +1962,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
cute.copy(
tma_atom_c,
bSG_sC[(None, c_buffer)],
bSG_gC[(None, subtile_idx // 2)],
bSG_gC[(None, real_subtile_idx)],
)
# Fence and barrier to make sure shared memory store is visible to TMA store
c_pipeline.producer_commit()
@ -1892,9 +1972,10 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
#
# Async arrive accumulator buffer empty
#
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
if cutlass.const_expr(not self.overlapping_accum):
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
#
# Advance to next tile
@ -1931,8 +2012,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
use_2cta_instrs: Union[cutlass.Boolean, bool],
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]:
"""
Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array
(destination).
Make tiledCopy for tensor memory load, then use it to partition tensor memory
(source) and register array (destination).
:param tidx: The thread index in epilogue warp groups
:type tidx: cutlass.Int32
@ -1998,8 +2079,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
sC: cute.Tensor,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory
(destination).
Make tiledCopy for shared memory store, then use it to partition register
array (source) and shared memory (destination).
:param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
:type tiled_copy_t2r: cute.TiledCopy
@ -2120,8 +2201,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Default ACC stages
num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
# num_acc_stage = 1
# Default C stages
num_c_stage = 2
@ -2178,9 +2257,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Start with total smem per CTA (capacity / occupancy)
# Subtract reserved bytes and initial C stages bytes
# Divide remaining by bytes needed per A/B stage
# cute.printf("num_smem_capacity: {}, occupancy: {}, mbar_helpers_bytes: {}, c_bytes: {}", num_smem_capacity,
# occupancy, mbar_helpers_bytes, c_bytes)
# cute.printf("ab_bytes_per_stage: {}", ab_bytes_per_stage)
num_ab_stage = (
num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
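To make the stage-count formula above concrete, here is a worked example with made-up byte counts (the real values depend on dtypes, tile shapes, and per-stage sizes computed elsewhere in this class):

# Illustrative numbers only; none of these are the kernel's actual sizes.
num_smem_capacity = 232448    # assumed usable smem bytes per SM
occupancy = 1                 # CTAs per SM
mbar_helpers_bytes = 1024     # assumed barrier/bookkeeping bytes
c_bytes = 2 * 32768           # assumed num_c_stage * bytes per C stage
ab_bytes_per_stage = 40960    # assumed A + SFA + B + SFB bytes per K stage

num_ab_stage = (
    num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
# (232448 - 1024 - 65536) // 40960 == 4, i.e. a 4-deep A/B pipeline under these assumptions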
@ -2369,7 +2445,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
@staticmethod
def is_valid_mma_tiler_and_cluster_shape(
use_2cta_instrs: bool,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
) -> bool:
@ -2389,19 +2464,18 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
is_valid = True
# Skip invalid mma tile shape
if not (
(not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
):
if mma_tiler_mn[0] not in (128, 256):
is_valid = False
# Skip invalid mma tile n
# Needs to have even iterations with Epi Tile N 64 for swiGeLU fusion
# SwiGLU fusion requires an even epi_tile count;
# with epi_tile_n = 64, only mma_tiler_n = 128 and 256 are supported
if mma_tiler_mn[1] not in (128, 256):
is_valid = False
# Skip illegal cluster shape
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
if (mma_tiler_mn[0] // cluster_shape_mn[0]) != 128:
is_valid = False
# Skip invalid cluster shape
if (
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
or cluster_shape_mn[0] <= 0
@ -2414,14 +2488,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
or not is_power_of_2(cluster_shape_mn[1])
):
is_valid = False
cluster_tiler_m = (cluster_shape_mn[0] // (2 if use_2cta_instrs else 1)) * mma_tiler_mn[0]
# Skip invalid cluster tiler shape since contiguous layout can't handle oob access
# The contiguous layout packs each group's aligned data back-to-back in memory.
# It can't handle runtime OOB when the alignment does not match tile_M,
# since the problem shape of the TMA store can't be changed at runtime.
if cluster_tiler_m not in [64, 128, 256]:
is_valid = False
return is_valid
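The updated checks boil down to a few rules: MMA tile M and N must each be 128 or 256, the cluster must give every CTA a 128-wide M slice, and the cluster shape must be a power-of-two pair with at most 16 CTAs. A hedged restatement as a standalone predicate (the helper name is ours, and it only mirrors the checks visible in this hunk):

def is_valid_tiler_and_cluster(mma_tiler_mn, cluster_shape_mn) -> bool:
    """Mirror of the validity rules above, for illustration only."""
    def is_power_of_2(x: int) -> bool:
        return x > 0 and (x & (x - 1)) == 0

    tile_m, tile_n = mma_tiler_mn
    cluster_m, cluster_n = cluster_shape_mn
    if tile_m not in (128, 256):                      # MMA tile M restriction
        return False
    if tile_n not in (128, 256):                      # even epi-tile count for SwiGLU fusion
        return False
    if cluster_m <= 0 or cluster_n <= 0:
        return False
    if tile_m // cluster_m != 128:                    # each CTA owns a 128-wide M slice
        return False
    if cluster_m * cluster_n > 16:                    # cluster size limit
        return False
    return is_power_of_2(cluster_m) and is_power_of_2(cluster_n)

# is_valid_tiler_and_cluster((256, 128), (2, 1)) -> True
# is_valid_tiler_and_cluster((256, 128), (1, 1)) -> False (256 // 1 != 128)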
@ -2540,10 +2606,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
can_implement = False
# Skip invalid mma tile shape and cluster shape
use_2cta_instrs = mma_tiler_mn[0] == 256
if not cls.is_valid_mma_tiler_and_cluster_shape(
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
):
if not cls.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn):
can_implement = False
# Skip illegal problem shape for load/store alignment
if not cls.is_valid_tensor_alignment(
@ -2630,3 +2693,18 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
stream=stream,
epilogue_op=epilogue_op,
)
@cute.jit
def cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
sf_ref_tensor: cute.Tensor,
sf_mma_tensor: cute.Tensor,
):
"""Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout"""
# sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l)
# group to ((32, 4, rest_m), (4, rest_k), l)
sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3)
sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3)
for i in cutlass.range(cute.size(sf_ref_tensor)):
mkl_coord = sf_ref_tensor.layout.get_hier_coord(i)
sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord]
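As a rough illustration of what the element-wise copy above achieves: if the MMA layout follows the usual 32x4 (M) by 4 (K) scale-factor atom, the coordinate remap corresponds to splitting m into (m % 32, (m // 32) % 4, m // 128) and k into (k % 4, k // 4). A NumPy sketch under that assumption; the kernel's actual physical ordering is whatever sf_mma_tensor's CuTe layout defines:

import numpy as np

def sf_mkl_to_mma_blocks(sf_ref: np.ndarray) -> np.ndarray:
    """Reindex an (M, K, L) scale-factor tensor as (32, 4, rest_m, 4, rest_k, L).

    Assumes M is a multiple of 128 and K (counted in scale-factor blocks) is a
    multiple of 4; this is only a coordinate-level analogue of the copy above.
    """
    M, K, L = sf_ref.shape
    assert M % 128 == 0 and K % 4 == 0
    rest_m, rest_k = M // 128, K // 4
    # m = rm*128 + j*32 + i  ->  (i, j, rm);  k = rk*4 + kk  ->  (kk, rk)
    return sf_ref.reshape(rest_m, 4, 32, rest_k, 4, L).transpose(2, 1, 0, 4, 3, 5)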

View File

@ -443,7 +443,6 @@ class PipelineCpAsyncUmma(PipelineAsync):
barrier_storage: cute.Pointer = None,
cta_layout_vmnk: Optional[cute.Layout] = None,
defer_sync: bool = False,
enable_cp_async: bool = False,
):
"""Creates and initializes a new PipelineCpAsyncUmma instance.
@ -459,8 +458,6 @@ class PipelineCpAsyncUmma(PipelineAsync):
:type cta_layout_vmnk: cute.Layout, optional
:param defer_sync: Whether to defer the sync
:type defer_sync: bool, optional
:param enable_cp_async: Whether to enable cp.async instructions
:type enable_cp_async: bool, optional
:raises ValueError: If barrier_storage is not a cute.Pointer instance
:return: A new PipelineCpAsyncUmma instance configured with the provided parameters
:rtype: PipelineCpAsyncUmma
@ -470,7 +467,7 @@ class PipelineCpAsyncUmma(PipelineAsync):
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
)
producer_type = PipelineOp.AsyncLoad if enable_cp_async else PipelineOp.AsyncThread
producer_type = PipelineOp.AsyncLoad
consumer_type = PipelineOp.TCGen05Mma
producer = (producer_type, producer_group)

View File

@ -349,7 +349,6 @@ def run(
):
"""Prepare A/B/C tensors, launch GPU kernel, and reference checking."""
m_aligned = mma_tiler_mn[0]
use_2cta_instrs = mma_tiler_mn[0] == 256 and cluster_shape_mn[0] % 2 == 0
print("Running Blackwell Persistent Dense Contiguous Grouped GEMM test with:")
print(f"nkl: {nkl}")
@ -362,8 +361,6 @@ def run(
print(f"Padded M (CUDA graph support): {permuted_m}")
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
print(f"Use TMA Store: {'True'}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
@ -394,7 +391,7 @@ def run(
):
raise TypeError(
f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype},"
f"{use_2cta_instrs},{mma_tiler_mn}, {cluster_shape_mn}, {n}, {k}, {l},"
f"{mma_tiler_mn}, {cluster_shape_mn}, {n}, {k}, {l},"
f"{a_major}, {b_major}, {c_major}, {m_aligned}"
)
@ -727,7 +724,9 @@ if __name__ == "__main__":
f"Invalid benchmark argument format. Expected file path, 'MxNxKxL', or '[m0,m1,...]xNxK'. Got: {arg}"
)
parser = argparse.ArgumentParser(description="Example of Dense Persistent GEMM on Blackwell.")
parser = argparse.ArgumentParser(
description="Example of BlockScaled Contiguous grouped GEMM kernel on Blackwell."
)
parser.add_argument(
"--nkl",

View File

@ -532,7 +532,6 @@ def run(
Defaults to False.
"""
m_aligned = mma_tiler_mn[0]
use_2cta_instrs = mma_tiler_mn[0] == 256 and cluster_shape_mn[0] % 2 == 0
print("Running Blackwell Persistent Dense Contiguous Grouped GEMM test with:")
print(f"nkl: {nkl}")
@ -547,8 +546,6 @@ def run(
print(f"Sequence length: {seq_len}")
print(f"Matrix majors - A: {a_major}, B: {b_major}, Out: {out_major}")
print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
print(f"Use TMA Store: {'True'}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
@ -581,7 +578,7 @@ def run(
):
raise TypeError(
f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {out_dtype}, "
f"{use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {n}, {k}, {l}, "
f"{mma_tiler_mn}, {cluster_shape_mn}, {n}, {k}, {l}, "
f"{a_major}, {b_major}, {out_major}, {m_aligned}"
)
@ -957,7 +954,9 @@ if __name__ == "__main__":
f"Invalid benchmark argument format. Expected file path, 'MxNxKxL', or '[m0,m1,...]xNxK'. Got: {arg}"
)
parser = argparse.ArgumentParser(description="Example of Dense Persistent GEMM on Blackwell.")
parser = argparse.ArgumentParser(
description="Example of BlockScaled Contiguous grouped GEMM finalize fusion kernel on Blackwell."
)
parser.add_argument(
"--nkl",

View File

@ -594,7 +594,7 @@ def test_nvfp4_grouped_gemm_finalize_blackwell(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128])
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024, 8192])