Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][feat] CuteDSL MOE FC1 Enhancement (#10088)
Signed-off-by: Yuhan Li <51736452+liyuhannnnn@users.noreply.github.com>
This commit is contained in:
Parent: 77712ed4ab
Commit: 6b8ae6fa81
@ -1518,9 +1518,9 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
|
||||
)
|
||||
|
||||
if self.tile_size not in (128, ):
|
||||
if self.tile_size not in (128, 256):
|
||||
raise ValueError(
|
||||
f"{self.__class__.kernel_class.__name__} supports tile_size (MMA tile M dimension) 128 only, but got {self.tile_size}"
|
||||
f"{self.__class__.kernel_class.__name__} supports tile_size (MMA tile M dimension) 128 and 256 only, but got {self.tile_size}"
|
||||
)
|
||||
|
||||
def unique_id(self):
|
||||
|
||||
@ -171,8 +171,9 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
- Token ID mapping enables efficient gather operation during A/SFA load
|
||||
- SwiGLU activation fusion in epilogue (up * silu(gate) with interleaved weights)
|
||||
- Optional quantization fusion for Float4E2M1FN output with scale factor generation
|
||||
- Warp specialization: Scheduler (warp 10), LDGSTS A/SFA (warps 4-7), TMA B/SFB (warp 9),
|
||||
MMA (warp 8), Epilogue (warps 0-3)
|
||||
- Warp specialization: Scheduler (warp 10), A Sync Transform (warp 11, only used when
|
||||
use_2cta_instrs is True), LDGSTS A/SFA (warps 4-7), TMA B/SFB (warp 9), MMA (warp 8),
|
||||
Epilogue (warps 0-3)
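As a rough map of the warp-specialized layout described above (warp IDs follow the attributes set in __init__ later in this diff; the dispatch below is an illustrative plain-Python sketch, not the kernel's actual control flow):

.. code-block:: python

    # Hedged sketch: role assignment per warp, mirroring the IDs documented above.
    # Real dispatch happens inside the @cute.jit kernel body.
    EPILOG_WARPS = (0, 1, 2, 3)
    LDGSTS_A_WARPS = (4, 5, 6, 7)
    MMA_WARP, TMA_B_WARP, SCHED_WARP, SYNC_TRANSFORM_WARP = 8, 9, 10, 11

    def warp_role(warp_idx: int, use_2cta_instrs: bool) -> str:
        if warp_idx in EPILOG_WARPS:
            return "epilogue"
        if warp_idx in LDGSTS_A_WARPS:
            return "ldgsts_a_sfa"
        if warp_idx == MMA_WARP:
            return "mma"
        if warp_idx == TMA_B_WARP:
            return "tma_b_sfb"
        if warp_idx == SCHED_WARP:
            return "scheduler"
        if warp_idx == SYNC_TRANSFORM_WARP and use_2cta_instrs:
            return "a_sync_transform"
        return "idle"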
|
||||
|
||||
:param sf_vec_size: Scalefactor vector size (16 for NVF4, 32 for MXF4/MXF8).
|
||||
:type sf_vec_size: int
|
||||
@ -302,6 +303,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
self.mma_warp_id = 8
|
||||
self.tma_b_warp_id = 9
|
||||
self.sched_warp_id = 10
|
||||
self.sync_transform_warp_id = 11
|
||||
self.threads_per_warp = 32
|
||||
self.threads_per_cta = self.threads_per_warp * len(
|
||||
(
|
||||
@ -310,16 +312,30 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
self.tma_b_warp_id,
|
||||
*self.epilog_warp_id,
|
||||
self.sched_warp_id,
|
||||
self.sync_transform_warp_id,
|
||||
)
|
||||
)
|
||||
self.threads_wo_sched = self.threads_per_warp * len(
|
||||
(
|
||||
*self.epilog_warp_id,
|
||||
self.mma_warp_id,
|
||||
self.tma_b_warp_id,
|
||||
*self.ldgsts_a_warp_id,
|
||||
self.warps_wo_sched = (
|
||||
len(
|
||||
(
|
||||
*self.epilog_warp_id,
|
||||
self.mma_warp_id,
|
||||
self.tma_b_warp_id,
|
||||
self.sync_transform_warp_id,
|
||||
*self.ldgsts_a_warp_id,
|
||||
)
|
||||
)
|
||||
if self.use_2cta_instrs
|
||||
else len(
|
||||
(
|
||||
*self.epilog_warp_id,
|
||||
self.mma_warp_id,
|
||||
self.tma_b_warp_id,
|
||||
*self.ldgsts_a_warp_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
self.threads_wo_sched = self.threads_per_warp * self.warps_wo_sched
|
||||
|
||||
# Set barrier for cta sync, epilogue sync and tmem ptr sync
|
||||
self.cta_sync_barrier = pipeline.NamedBarrier(
|
||||
@ -338,9 +354,11 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
barrier_id=4,
|
||||
num_threads=self.threads_per_warp,
|
||||
)
|
||||
|
||||
self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
SM100_TMEM_CAPACITY_COLUMNS = 512
|
||||
self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
|
||||
|
||||
self.vectorized_f32 = vectorized_f32
|
||||
|
||||
def _setup_attributes(self):
|
||||
@ -410,6 +428,12 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
mma_inst_shape_k * mma_inst_tile_k,
|
||||
)
|
||||
|
||||
self.mma_tiler_c = (
|
||||
self.mma_inst_shape_mn[0],
|
||||
self.mma_inst_shape_mn[1] // 2,
|
||||
mma_inst_shape_k * mma_inst_tile_k,
|
||||
)
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler[1],
|
||||
@ -422,10 +446,10 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
self.mma_tiler_sfa[2],
|
||||
)
|
||||
|
||||
self.mma_tiler_c = (
|
||||
self.mma_inst_shape_mn[0],
|
||||
self.mma_inst_shape_mn[1] // 2,
|
||||
mma_inst_shape_k * mma_inst_tile_k,
|
||||
self.cta_tile_shape_mnk_sfb = (
|
||||
self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler_sfb[1],
|
||||
self.mma_tiler_sfb[2],
|
||||
)
|
||||
|
||||
self.cta_tile_shape_mnk_c = (
|
||||
@ -509,8 +533,23 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
self.num_c_stage,
|
||||
)
|
||||
|
||||
# Compute the number of tensor memory allocation columns
|
||||
self.num_tmem_alloc_cols = 512
|
||||
# Overlap and double buffer accumulator when num_acc_stage == 1 for cta_tile_n = 256 case
|
||||
self.overlapping_accum = self.num_acc_stage == 1
|
||||
|
||||
# Compute number of TMEM columns for SFA/SFB/Accumulator
|
||||
sf_atom_mn = 32
|
||||
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
|
||||
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
|
||||
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
|
||||
self.num_accumulator_tmem_cols = (
|
||||
self.cta_tile_shape_mnk[1] * self.num_acc_stage
|
||||
if not self.overlapping_accum
|
||||
else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
|
||||
)
|
||||
|
||||
self.epi_tile_n_required = 2 * cute.size(self.epi_tile[1])
|
||||
# Only when overlapping_accum is enabled, we need to release accumulator buffer early in epilogue
|
||||
self.iter_acc_early_release_in_epilogue = self.num_sf_tmem_cols // self.epi_tile_n_required
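A hand-checked arithmetic sketch of the TMEM budget computed above, under one assumed configuration (cta_tile_m=128, cta_tile_n=256, sf_atom_mn=32, mma_inst_tile_k=4, epi_tile_n=64, num_acc_stage=1); the values are illustrative, not read from this diff:

.. code-block:: python

    cta_tile_m, cta_tile_n = 128, 256
    sf_atom_mn, mma_inst_tile_k = 32, 4
    epi_tile_n, num_acc_stage = 64, 1

    num_sfa_tmem_cols = (cta_tile_m // sf_atom_mn) * mma_inst_tile_k   # 16
    num_sfb_tmem_cols = (cta_tile_n // sf_atom_mn) * mma_inst_tile_k   # 32
    num_sf_tmem_cols = num_sfa_tmem_cols + num_sfb_tmem_cols           # 48

    overlapping_accum = num_acc_stage == 1
    num_accumulator_tmem_cols = (
        cta_tile_n * 2 - num_sf_tmem_cols                              # 464
        if overlapping_accum
        else cta_tile_n * num_acc_stage
    )
    # For these numbers the budget exactly fills the 512 TMEM columns reserved
    # via SM100_TMEM_CAPACITY_COLUMNS in __init__.
    assert num_accumulator_tmem_cols + num_sf_tmem_cols == 512

    epi_tile_n_required = 2 * epi_tile_n                               # 128
    iter_acc_early_release_in_epilogue = num_sf_tmem_cols // epi_tile_n_required  # 0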
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
@ -551,6 +590,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
5. Launch the kernel synchronously with warp specialization:
|
||||
- Scheduler warp: Dispatches tile information
|
||||
- LDGSTS warps: Load A and SFA with gather
|
||||
- A Sync Transform warps: Transform the sync signal of A and SFA from global to
|
||||
shared memory when use_2cta_instrs is True
|
||||
- TMA warp: Load B and SFB with multicast
|
||||
- MMA warp: Perform matrix multiply-accumulate
|
||||
- Epilogue warps: Apply SwiGLU activation, optional quantization, and store results
|
||||
@ -606,10 +647,12 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# Set up attributes that depend on the GEMM inputs
|
||||
self._setup_attributes()
|
||||
|
||||
# Setup sfb tensor by filling B tensor to scale factor atom layout
|
||||
# ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
|
||||
sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size)
|
||||
sfb = cute.make_tensor(sfb.iterator, sfb_layout)
|
||||
|
||||
# Setup sfc tensor by filling C tensor to scale factor atom layout
|
||||
self.generate_sfc = sfc_tensor is not None and norm_const_tensor is not None
|
||||
if cutlass.const_expr(self.generate_sfc):
|
||||
sfc_layout = blockscaled_utils.tile_atom_to_shape_SF(c.shape, self.sf_vec_size)
|
||||
@ -707,7 +750,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
|
||||
# Define shared storage for kernel
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
class SharedStorage1cta:
|
||||
# (bidx, bidy, bidz, valid, mn_limit)
|
||||
sInfo: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Int32, 5 * self.num_tile_stage],
|
||||
@ -749,7 +792,53 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
|
||||
self.shared_storage = SharedStorage
|
||||
@cute.struct
|
||||
class SharedStorage2cta:
|
||||
# (bidx, bidy, bidz, valid, mn_limit)
|
||||
sInfo: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Int32, 5 * self.num_tile_stage],
|
||||
# 1 byte alignment
|
||||
1,
|
||||
]
|
||||
a_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
a_sync_transform_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
b_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tile_info_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_tile_stage * 2]
|
||||
tmem_dealloc_mbar_ptr: cutlass.Int64
|
||||
tmem_holding_buf: cutlass.Int32
|
||||
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
||||
sC: cute.struct.Align[
|
||||
cute.struct.MemRange[
|
||||
self.c_dtype,
|
||||
cute.cosize(self.c_smem_layout_staged.outer),
|
||||
],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
# (MMA, MMA_M, MMA_K, STAGE)
|
||||
sA: cute.struct.Align[
|
||||
cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
# (MMA, MMA_N, MMA_K, STAGE)
|
||||
sB: cute.struct.Align[
|
||||
cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
# ((granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage)
|
||||
sSFA: cute.struct.Align[
|
||||
cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
# ((granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage)
|
||||
sSFB: cute.struct.Align[
|
||||
cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
|
||||
self.shared_storage = (
|
||||
SharedStorage2cta if cutlass.const_expr(self.use_2cta_instrs) else SharedStorage1cta
|
||||
)
|
||||
|
||||
# Launch the kernel synchronously
|
||||
self.kernel(
|
||||
@ -875,9 +964,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# Prefetch tma desc
|
||||
#
|
||||
if warp_idx == self.tma_b_warp_id:
|
||||
# cpasync.prefetch_descriptor(tma_atom_a)
|
||||
cpasync.prefetch_descriptor(tma_atom_b)
|
||||
# cpasync.prefetch_descriptor(tma_atom_sfa)
|
||||
cpasync.prefetch_descriptor(tma_atom_sfb)
|
||||
cpasync.prefetch_descriptor(tma_atom_c)
|
||||
|
||||
@ -911,10 +998,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# Consumer: MMA warp for consuming A/SFA data
|
||||
a_pipeline_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
128
|
||||
* cute.size(
|
||||
cluster_layout_vmnk, mode=[0]
|
||||
), # 4 warps * 32 threads per warp = 128 threads
|
||||
self.threads_per_warp * 4,
|
||||
)
|
||||
|
||||
a_pipeline = PipelineCpAsyncUmma.create(
|
||||
@ -924,9 +1008,25 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
defer_sync=True,
|
||||
enable_cp_async=(not self.use_2cta_instrs),
|
||||
)
|
||||
|
||||
# Pipeline Init: Initialize A SYNC Transform pipeline when use_2cta_instrs is True
|
||||
# Producer: 1 warp (warp 11) for LDGSTS SYNC transformation operations
|
||||
# Consumer: MMA warp for consuming A/SFA data
|
||||
if cutlass.const_expr(self.use_2cta_instrs):
|
||||
a_sync_transform_pipeline_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
32 * cute.size(cluster_layout_vmnk, mode=[0]),
|
||||
)
|
||||
a_sync_transform_pipeline = pipeline.PipelineAsyncUmma.create(
|
||||
barrier_storage=storage.a_sync_transform_mbar_ptr.data_ptr(),
|
||||
num_stages=self.num_ab_stage,
|
||||
producer_group=a_sync_transform_pipeline_producer_group,
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
defer_sync=True,
|
||||
)
|
||||
|
||||
# Pipeline Init: Initialize B pipeline for TMA operations
|
||||
# Using PipelineTmaUmma for B/SFB since they use TMA load with multicast support
|
||||
# Producer: TMA B/SFB warp (warp 9) - 1 warp issuing TMA operations
|
||||
@ -959,7 +1059,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
)
|
||||
|
||||
# Pipeline Init: Tensor memory dealloc barrier init
|
||||
# Pipeline Init: Initialize tile info pipeline (barrier) and states
|
||||
tile_info_pipeline_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
self.threads_per_warp * 1,
|
||||
@ -1001,16 +1101,14 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
|
||||
# ((granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage)
|
||||
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
|
||||
# (bidx, bidy, bidz, valid)
|
||||
# (bidx, bidy, bidz, valid, mn_limit)
|
||||
info_layout = cute.make_layout((5, self.num_tile_stage), stride=(1, 5))
|
||||
sInfo = storage.sInfo.get_tensor(info_layout)
|
||||
|
||||
#
|
||||
# Compute multicast mask for A/B buffer full
|
||||
#
|
||||
# a_full_mcast_mask = None
|
||||
b_full_mcast_mask = None
|
||||
# sfa_full_mcast_mask = None
|
||||
sfb_full_mcast_mask = None
|
||||
if cutlass.const_expr(self.is_b_mcast or use_2cta_instrs):
|
||||
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
||||
@ -1107,7 +1205,27 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
num_acc_stage_overlapped = 2
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(
|
||||
cute.append(acc_shape, num_acc_stage_overlapped)
|
||||
)
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_fake = cute.make_tensor(
|
||||
tCtAcc_fake.iterator,
|
||||
cute.make_layout(
|
||||
tCtAcc_fake.shape,
|
||||
stride=(
|
||||
tCtAcc_fake.stride[0],
|
||||
tCtAcc_fake.stride[1],
|
||||
tCtAcc_fake.stride[2],
|
||||
(256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1],
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
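One way to read the overlapped stride above (a stage offset of 256 - num_sf_tmem_cols accumulator columns): the two accumulator stages share exactly num_sf_tmem_cols TMEM columns, which is what lets two 256-wide buffers plus the SFA/SFB columns fit in 512 columns. A hand-checked sketch with assumed numbers (cta_tile_n = 256, num_sf_tmem_cols = 48):

.. code-block:: python

    # Hedged sketch: TMEM column ranges of the two overlapped accumulator stages.
    cta_tile_n = 256
    num_sf_tmem_cols = 48                                    # assumed, illustrative

    stage_stride_cols = 256 - num_sf_tmem_cols               # 208, as in the stride above
    stage0 = range(0, cta_tile_n)                            # columns [0, 256)
    stage1 = range(stage_stride_cols, stage_stride_cols + cta_tile_n)  # columns [208, 464)

    shared = set(stage0) & set(stage1)
    assert len(shared) == num_sf_tmem_cols                   # stages overlap by 48 columns
    assert stage1[-1] + 1 + num_sf_tmem_cols == 512          # accumulators + SF fill TMEM

Presumably this overlap is also why the epilogue releases the accumulator early (iter_acc_early_release_in_epilogue) and walks the subtiles of one phase in reverse, so the shared columns are freed before the MMA producer needs them again.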
|
||||
|
||||
#
|
||||
# Cluster wait before tensor memory alloc
|
||||
@ -1120,7 +1238,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
griddepcontrol_wait()
|
||||
|
||||
#
|
||||
# Specialized Schedule warp
|
||||
# Specialized Schedule Warp
|
||||
#
|
||||
if warp_idx == self.sched_warp_id:
|
||||
#
|
||||
@ -1187,7 +1305,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# with gather/permutation capability enabled by token_id_mapping
|
||||
#
|
||||
if warp_idx <= self.ldgsts_a_warp_id[-1] and warp_idx >= self.ldgsts_a_warp_id[0]:
|
||||
# cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
|
||||
#
|
||||
# Setup LDGSTS copy atoms for A and SFA
|
||||
# A: 8x LDGSTS.128 per thread with swizzle_128B for A matrix (32 elements per thread)
|
||||
@ -1274,9 +1391,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
tile_info_consumer_state.advance()
|
||||
|
||||
while is_valid_tile:
|
||||
# Get tile coord from tile scheduler
|
||||
# cur_tile_coord = work_tile.tile_idx
|
||||
|
||||
# Load token IDs for gather operation
|
||||
# For A matrix: each thread loads 8 token offsets (for 8 LDGSTS.128 operations)
|
||||
# For SFA matrix: each thread loads 1 token offset (for 4 LDGSTS.32 operations)
|
||||
@ -1409,10 +1523,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
sfa_atom_copy, tAgSFA_slice, tAsSFA_slice, pred=sfa_predicate_tensor
|
||||
)
|
||||
|
||||
# Signal the completion of async
|
||||
if cutlass.const_expr(self.use_2cta_instrs):
|
||||
cute.arch.cp_async_commit_group()
|
||||
cute.arch.cp_async_wait_group(0)
|
||||
a_pipeline.producer_commit(a_producer_state)
|
||||
|
||||
# Peek (try_wait) A buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
|
||||
@ -1440,6 +1550,83 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
#
|
||||
a_pipeline.producer_tail(a_producer_state)
|
||||
|
||||
#
|
||||
# Specialized A/SFA Sync Transform Warp (warp 11) when use_2cta_instrs is True
|
||||
# This warp serves as the sync transform for A and SFA
|
||||
#
|
||||
if warp_idx == self.sync_transform_warp_id:
|
||||
if cutlass.const_expr(self.use_2cta_instrs):
|
||||
#
|
||||
# Persistent tile scheduling loop
|
||||
#
|
||||
tile_sched = utils.StaticPersistentTileScheduler.create(
|
||||
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
||||
)
|
||||
# First tile
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
|
||||
a_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
||||
)
|
||||
a_sync_transform_producer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.num_ab_stage
|
||||
)
|
||||
tile_info_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_tile_stage
|
||||
)
|
||||
|
||||
# Get the first tile info
|
||||
valid_tile_info = cute.make_rmem_tensor((1,), cutlass.Int32)
|
||||
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
|
||||
valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)]
|
||||
is_valid_tile = valid_tile_info[0] == 1
|
||||
cute.arch.fence_proxy(
|
||||
cute.arch.ProxyKind.async_shared,
|
||||
space=cute.arch.SharedSpace.shared_cta,
|
||||
)
|
||||
tile_info_pipeline.consumer_release(tile_info_consumer_state)
|
||||
tile_info_consumer_state.advance()
|
||||
|
||||
while is_valid_tile:
|
||||
# Peek (try_wait) A buffer full for k_tile = 0
|
||||
a_consumer_state.reset_count()
|
||||
peek_a_full_status = cutlass.Boolean(1)
|
||||
if a_consumer_state.count < k_tile_cnt:
|
||||
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
|
||||
# Peek (try_wait) a sync transform buffer empty
|
||||
a_sync_transform_producer_state.reset_count()
|
||||
|
||||
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
|
||||
# Conditionally wait for A buffer full
|
||||
a_pipeline.consumer_wait(a_consumer_state, peek_a_full_status)
|
||||
|
||||
a_sync_transform_pipeline.producer_commit(a_sync_transform_producer_state)
|
||||
a_sync_transform_producer_state.advance()
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
||||
a_consumer_state.advance()
|
||||
peek_a_full_status = cutlass.Boolean(1)
|
||||
if a_consumer_state.count < k_tile_cnt:
|
||||
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
|
||||
|
||||
#
|
||||
# Advance to next tile
|
||||
#
|
||||
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
|
||||
valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)]
|
||||
is_valid_tile = valid_tile_info[0] == 1
|
||||
cute.arch.fence_proxy(
|
||||
cute.arch.ProxyKind.async_shared,
|
||||
space=cute.arch.SharedSpace.shared_cta,
|
||||
)
|
||||
tile_info_pipeline.consumer_release(tile_info_consumer_state)
|
||||
tile_info_consumer_state.advance()
|
||||
|
||||
#
|
||||
# Wait A sync transform buffer empty
|
||||
#
|
||||
a_sync_transform_pipeline.producer_tail(a_sync_transform_producer_state)
|
||||
|
||||
#
|
||||
# Specialized TMA B/SFB load warp (warp 9)
|
||||
# This warp uses TMA instructions to load B and SFB from global to shared memory
|
||||
@ -1488,9 +1675,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# ((atom_v, rest_v), loopK)
|
||||
tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])]
|
||||
|
||||
# ((atom_v, rest_v), RestK)
|
||||
# tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, 0)]
|
||||
|
||||
# Apply SFB slicing hack when cta_tile_shape_n=64
|
||||
slice_n = mma_tile_coord_mnl[1]
|
||||
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
|
||||
@ -1578,7 +1762,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
|
||||
# Make SFA tmem tensor
|
||||
sfa_tmem_ptr = cute.recast_ptr(
|
||||
acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base),
|
||||
acc_tmem_ptr + self.num_accumulator_tmem_cols,
|
||||
dtype=self.sf_dtype,
|
||||
)
|
||||
# (MMA, MMA_M, MMA_K)
|
||||
@ -1592,9 +1776,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
|
||||
# Make SFB tmem tensor
|
||||
sfb_tmem_ptr = cute.recast_ptr(
|
||||
acc_tmem_ptr
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA),
|
||||
acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols,
|
||||
dtype=self.sf_dtype,
|
||||
)
|
||||
# (MMA, MMA_N, MMA_K)
|
||||
@ -1627,9 +1809,14 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
)
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
|
||||
if cutlass.const_expr(self.use_2cta_instrs):
|
||||
a_sync_transform_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
||||
)
|
||||
a_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
||||
)
|
||||
|
||||
b_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
||||
)
|
||||
@ -1656,12 +1843,24 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
|
||||
while is_valid_tile:
|
||||
# Peek (try_wait) AB buffer full for k_tile = 0
|
||||
a_consumer_state.reset_count()
|
||||
if cutlass.const_expr(self.use_2cta_instrs):
|
||||
a_sync_transform_consumer_state.reset_count()
|
||||
peek_a_sync_transform_full_status = cutlass.Boolean(1)
|
||||
if a_sync_transform_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||||
peek_a_sync_transform_full_status = (
|
||||
a_sync_transform_pipeline.consumer_try_wait(
|
||||
a_sync_transform_consumer_state
|
||||
)
|
||||
)
|
||||
a_consumer_state.reset_count()
|
||||
else:
|
||||
a_consumer_state.reset_count()
|
||||
peek_a_full_status = cutlass.Boolean(1)
|
||||
if a_consumer_state.count < k_tile_cnt:
|
||||
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
|
||||
|
||||
b_consumer_state.reset_count()
|
||||
peek_a_full_status = cutlass.Boolean(1)
|
||||
peek_b_full_status = cutlass.Boolean(1)
|
||||
if a_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||||
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
|
||||
if b_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||||
peek_b_full_status = b_pipeline.consumer_try_wait(b_consumer_state)
|
||||
|
||||
@ -1671,11 +1870,16 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
tile_info[2],
|
||||
)
|
||||
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
|
||||
# Get accumulator stage index
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
acc_stage_index = acc_producer_state.phase ^ 1
|
||||
else:
|
||||
acc_stage_index = acc_producer_state.index
|
||||
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)]
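The stage selection above is the producer half of the overlapped double buffer; the epilogue consumer later picks the phase bit without the XOR and also decides whether to walk subtiles in reverse. A plain-Python restatement of just this selection logic (not of the pipeline itself):

.. code-block:: python

    def acc_stage_index(phase: int, index: int, overlapping_accum: bool, is_producer: bool) -> int:
        # Mirrors the diff: producer uses phase ^ 1, consumer uses phase when
        # overlapping_accum is enabled; otherwise both fall back to the stage index.
        if not overlapping_accum:
            return index
        return (phase ^ 1) if is_producer else phase

    # With overlapping enabled, the two sides ping-pong between buffers 0 and 1
    # as the pipeline phase flips:
    assert acc_stage_index(0, 0, True, is_producer=True) == 1
    assert acc_stage_index(1, 0, True, is_producer=True) == 0
    assert acc_stage_index(0, 0, True, is_producer=False) == 0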
|
||||
|
||||
# Apply TMEM pointer offset hack when cta_tile_shape_n=192 or
|
||||
# cta_tile_shape_n=64
|
||||
|
||||
tCtSFB_mma = tCtSFB
|
||||
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
|
||||
# If this is an ODD tile, shift the TMEM start address for
|
||||
@ -1686,8 +1890,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
)
|
||||
shifted_ptr = cute.recast_ptr(
|
||||
acc_tmem_ptr
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA)
|
||||
+ self.num_accumulator_tmem_cols
|
||||
+ self.num_sfa_tmem_cols
|
||||
+ offset,
|
||||
dtype=self.sf_dtype,
|
||||
)
|
||||
@ -1697,8 +1901,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
|
||||
shifted_ptr = cute.recast_ptr(
|
||||
acc_tmem_ptr
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA)
|
||||
+ self.num_accumulator_tmem_cols
|
||||
+ self.num_sfa_tmem_cols
|
||||
+ offset,
|
||||
dtype=self.sf_dtype,
|
||||
)
|
||||
@ -1723,7 +1927,12 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
|
||||
if is_leader_cta:
|
||||
# Conditionally wait for AB buffer full
|
||||
a_pipeline.consumer_wait(a_consumer_state, peek_a_full_status)
|
||||
if cutlass.const_expr(self.use_2cta_instrs):
|
||||
a_sync_transform_pipeline.consumer_wait(
|
||||
a_sync_transform_consumer_state, peek_a_sync_transform_full_status
|
||||
)
|
||||
else:
|
||||
a_pipeline.consumer_wait(a_consumer_state, peek_a_full_status)
|
||||
b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status)
|
||||
|
||||
# Copy SFA/SFB from smem to tmem
|
||||
@ -1781,16 +1990,31 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
|
||||
# Async arrive AB buffer empty
|
||||
a_pipeline.consumer_release(a_consumer_state)
|
||||
if cutlass.const_expr(self.use_2cta_instrs):
|
||||
a_sync_transform_pipeline.consumer_release(
|
||||
a_sync_transform_consumer_state
|
||||
)
|
||||
b_pipeline.consumer_release(b_consumer_state)
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
||||
a_consumer_state.advance()
|
||||
b_consumer_state.advance()
|
||||
peek_a_full_status = cutlass.Boolean(1)
|
||||
if a_consumer_state.count < k_tile_cnt:
|
||||
if is_leader_cta:
|
||||
if cutlass.const_expr(self.use_2cta_instrs):
|
||||
a_sync_transform_consumer_state.advance()
|
||||
peek_a_sync_transform_full_status = cutlass.Boolean(1)
|
||||
if a_sync_transform_consumer_state.count < k_tile_cnt:
|
||||
if is_leader_cta:
|
||||
peek_a_sync_transform_full_status = (
|
||||
a_sync_transform_pipeline.consumer_try_wait(
|
||||
a_sync_transform_consumer_state
|
||||
)
|
||||
)
|
||||
a_consumer_state.advance()
|
||||
else:
|
||||
a_consumer_state.advance()
|
||||
peek_a_full_status = cutlass.Boolean(1)
|
||||
if a_consumer_state.count < k_tile_cnt:
|
||||
peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state)
|
||||
|
||||
b_consumer_state.advance()
|
||||
peek_b_full_status = cutlass.Boolean(1)
|
||||
if b_consumer_state.count < k_tile_cnt:
|
||||
if is_leader_cta:
|
||||
@ -1959,9 +2183,18 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
)
|
||||
]
|
||||
|
||||
# Get accumulator stage index
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
acc_stage_index = acc_consumer_state.phase
|
||||
reverse_subtile = (
|
||||
cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False)
|
||||
)
|
||||
else:
|
||||
acc_stage_index = acc_consumer_state.index
|
||||
|
||||
# Set tensor memory buffer for current tile
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
|
||||
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)]
|
||||
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
|
||||
|
||||
if cutlass.const_expr(self.generate_sfc):
|
||||
# (T2R, T2R_M, T2R_N, RestM, RestN)
|
||||
@ -1990,18 +2223,36 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# up * silu(gate)
|
||||
#
|
||||
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
||||
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
|
||||
|
||||
for subtile_idx in cutlass.range(0, subtile_cnt, 2):
|
||||
real_subtile_idx = subtile_idx // 2
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
if reverse_subtile:
|
||||
real_subtile_idx = (
|
||||
self.cta_tile_shape_mnk[1] // self.epi_tile_n_required
|
||||
- 1
|
||||
- subtile_idx // 2
|
||||
)
|
||||
#
|
||||
# Load accumulator from tensor memory buffer to register
|
||||
#
|
||||
tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, subtile_idx)]
|
||||
tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, subtile_idx + 1)]
|
||||
tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, real_subtile_idx * 2)]
|
||||
tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, real_subtile_idx * 2 + 1)]
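With overlapping_accum enabled, the traversal order of the up/gate pairs depends on reverse_subtile; a hedged, plain-Python restatement of the index arithmetic above (4 accumulator subtiles, i.e. 2 pairs, assumed purely for illustration):

.. code-block:: python

    def pair_order(subtile_cnt: int, reverse_subtile: bool) -> list[tuple[int, int]]:
        pairs = []
        num_pairs = subtile_cnt // 2
        for subtile_idx in range(0, subtile_cnt, 2):
            real = subtile_idx // 2
            if reverse_subtile:
                real = num_pairs - 1 - subtile_idx // 2
            pairs.append((real * 2, real * 2 + 1))   # (up, gate) subtile indices
        return pairs

    assert pair_order(4, reverse_subtile=False) == [(0, 1), (2, 3)]
    assert pair_order(4, reverse_subtile=True) == [(2, 3), (0, 1)]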
|
||||
|
||||
cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up)
|
||||
cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate)
|
||||
|
||||
#
|
||||
# Async arrive accumulator buffer empty earlier when overlapping_accum is enabled
|
||||
#
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
if subtile_idx // 2 == self.iter_acc_early_release_in_epilogue:
|
||||
# Fence for TMEM load
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
with cute.arch.elect_one():
|
||||
acc_pipeline.consumer_release(acc_consumer_state)
|
||||
acc_consumer_state.advance()
|
||||
|
||||
acc_vec_up = tTR_rAcc_up.load()
|
||||
acc_vec_gate = tTR_rAcc_gate.load()
|
||||
|
||||
@ -2075,7 +2326,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# Assume subtile partitioning always happens on the n dimension
|
||||
sfc_subtile_idx_mn = (
|
||||
tile_info[0] * self.epi_tile_cnt[0],
|
||||
tile_info[1] * self.epi_tile_cnt[1] + subtile_idx // 2,
|
||||
tile_info[1] * self.epi_tile_cnt[1] + real_subtile_idx,
|
||||
)
|
||||
tCgSFC = tCgSFC_mn[
|
||||
(
|
||||
@ -2195,7 +2446,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# Store C to shared memory
|
||||
#
|
||||
num_prev_subtiles = num_prev_subtiles + 1
|
||||
c_buffer = (num_prev_subtiles + subtile_idx // 2) % self.num_c_stage
|
||||
c_buffer = num_prev_subtiles % self.num_c_stage
|
||||
|
||||
cute.copy(
|
||||
tiled_copy_r2s,
|
||||
@ -2215,7 +2466,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
cute.copy(
|
||||
tma_atom_c,
|
||||
bSG_sC[(None, c_buffer)],
|
||||
bSG_gC[(None, subtile_idx // 2)],
|
||||
bSG_gC[(None, real_subtile_idx)],
|
||||
)
|
||||
# Fence and barrier to make sure shared memory store is visible to TMA store
|
||||
c_pipeline.producer_commit()
|
||||
@ -2225,9 +2476,10 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
#
|
||||
# Async arrive accumulator buffer empty
|
||||
#
|
||||
with cute.arch.elect_one():
|
||||
acc_pipeline.consumer_release(acc_consumer_state)
|
||||
acc_consumer_state.advance()
|
||||
if cutlass.const_expr(not self.overlapping_accum):
|
||||
with cute.arch.elect_one():
|
||||
acc_pipeline.consumer_release(acc_consumer_state)
|
||||
acc_consumer_state.advance()
|
||||
|
||||
#
|
||||
# Advance to next tile
|
||||
@ -2509,10 +2761,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
# Start with total smem per CTA (capacity / occupancy)
|
||||
# Subtract reserved bytes and initial C stages bytes
|
||||
# Divide remaining by bytes needed per A/B stage
|
||||
# cute.printf("num_smem_capacity: {}, occupancy: {}, "
|
||||
# "mbar_helpers_bytes: {}, c_bytes: {}",
|
||||
# num_smem_capacity, occupancy, mbar_helpers_bytes, c_bytes)
|
||||
# cute.printf("ab_bytes_per_stage: {}", ab_bytes_per_stage)
|
||||
num_ab_stage = (
|
||||
num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
|
||||
) // ab_bytes_per_stage
|
||||
@ -2695,15 +2943,12 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
|
||||
if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
|
||||
is_valid = False
|
||||
# TODO: Currently we don't support m major output for Float4E2M1FN,
|
||||
# Need to support it in the future.
|
||||
if c_dtype is cutlass.Float4E2M1FN and c_major == "m":
|
||||
is_valid = False
|
||||
return is_valid
|
||||
|
||||
@staticmethod
|
||||
def is_valid_mma_tiler_and_cluster_shape(
|
||||
use_2cta_instrs: bool,
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
) -> bool:
|
||||
@ -2723,19 +2968,18 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
is_valid = True
|
||||
|
||||
# Skip invalid mma tile shape
|
||||
if not (
|
||||
(not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
|
||||
or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
|
||||
):
|
||||
if mma_tiler_mn[0] not in (128, 256):
|
||||
is_valid = False
|
||||
# Skip invalid mma tile n
|
||||
# Needs to have even iterations with Epi Tile N 64 for swiGeLU fusion
|
||||
# SwiGLU fusion requires an even epi_tile count; with epi_tile_n = 64,
# only mma_tiler_n = 128 and 256 are supported (see the illustration after this function)
|
||||
if mma_tiler_mn[1] not in (128, 256):
|
||||
is_valid = False
|
||||
|
||||
# Skip illegal cluster shape
|
||||
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
|
||||
if (mma_tiler_mn[0] // cluster_shape_mn[0]) != 128:
|
||||
is_valid = False
|
||||
# Skip invalid cluster shape
|
||||
|
||||
if (
|
||||
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
|
||||
or cluster_shape_mn[0] <= 0
|
||||
@ -2748,15 +2992,11 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
or not is_power_of_2(cluster_shape_mn[1])
|
||||
):
|
||||
is_valid = False
|
||||
cluster_tiler_m = (cluster_shape_mn[0] // (2 if use_2cta_instrs else 1)) * mma_tiler_mn[0]
|
||||
|
||||
# Skip invalid cluster tiler shape since contiguous layout can't handle oob access
|
||||
# The contiguous layout means the aligned data is stored in a contiguous manner.
|
||||
# It can't handle runtime oob when alignment is not align with the tile_M,
|
||||
# since the problem shape of TMA store can't be changed at runtime.
|
||||
if cluster_tiler_m not in [64, 128, 256]:
|
||||
# We only support cluster shape n = 1 for now
|
||||
# TODO: Support cluster shape n > 1
|
||||
if cluster_shape_mn[1] != 1:
|
||||
is_valid = False
|
||||
|
||||
return is_valid
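To illustrate the comment above about even epi-tile counts (this sketches only that one condition, not the full validator; epi_tile_n = 64 is the value the comment assumes):

.. code-block:: python

    EPI_TILE_N = 64

    def swiglu_epi_tile_count_ok(mma_tiler_n: int) -> bool:
        # Up/gate subtiles must come in pairs, so the epi-tile count must be even.
        return (mma_tiler_n // EPI_TILE_N) % 2 == 0

    assert swiglu_epi_tile_count_ok(128) and swiglu_epi_tile_count_ok(256)
    assert not swiglu_epi_tile_count_ok(192)     # odd number of 64-wide epi tiles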
|
||||
|
||||
@staticmethod
|
||||
@ -2812,8 +3052,9 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
is_valid = False
|
||||
return is_valid
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls,
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
sf_dtype: Type[cutlass.Numeric],
|
||||
sf_vec_size: int,
|
||||
@ -2839,8 +3080,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
:type sf_vec_size: int
|
||||
:param c_dtype: The data type of the output tensor
|
||||
:type c_dtype: Type[cutlass.Numeric]
|
||||
:param use_2cta_instrs: Whether to use 2 CTA groups
|
||||
:type use_2cta_instrs: bool
|
||||
:param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
|
||||
:type mma_tiler_mn: Tuple[int, int]
|
||||
:param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
|
||||
@ -2865,25 +3104,20 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
|
||||
"""
|
||||
can_implement = True
|
||||
# Skip unsupported types
|
||||
if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
|
||||
if not cls.is_valid_dtypes_and_scale_factor_vec_size(
|
||||
ab_dtype, sf_dtype, sf_vec_size, c_dtype
|
||||
):
|
||||
can_implement = False
|
||||
|
||||
# Skip unsupported layouts
|
||||
if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_layouts(
|
||||
ab_dtype, c_dtype, a_major, b_major, c_major
|
||||
):
|
||||
if not cls.is_valid_layouts(ab_dtype, c_dtype, a_major, b_major, c_major):
|
||||
can_implement = False
|
||||
|
||||
# Skip invalid mma tile shape and cluster shape
|
||||
use_2cta_instrs = mma_tiler_mn[0] == 256
|
||||
if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape(
|
||||
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
|
||||
):
|
||||
if not cls.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn):
|
||||
can_implement = False
|
||||
# Skip illegal problem shape for load/store alignment
|
||||
if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_tensor_alignment(
|
||||
if not cls.is_valid_tensor_alignment(
|
||||
m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
|
||||
):
|
||||
can_implement = False
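The @staticmethod to @classmethod change above (and the cls.… calls that follow) is the standard pattern for letting subclasses plug in their own validators; a minimal, hypothetical sketch of why it matters, with made-up names:

.. code-block:: python

    class BaseKernel:
        @staticmethod
        def is_valid_layouts(a_major: str, b_major: str) -> bool:
            return a_major == "k" and b_major == "k"

        @classmethod
        def can_implement(cls, a_major: str, b_major: str) -> bool:
            # Dispatch through cls so an overriding subclass is honored.
            return cls.is_valid_layouts(a_major, b_major)

    class RelaxedKernel(BaseKernel):
        @staticmethod
        def is_valid_layouts(a_major: str, b_major: str) -> bool:
            return True    # picked up via cls., unlike a hard-coded class name

    assert not BaseKernel.can_implement("m", "k")
    assert RelaxedKernel.can_implement("m", "k")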
|
||||
|
||||
@ -2244,6 +2244,10 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
|
||||
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
|
||||
or cluster_shape_mn[0] <= 0
|
||||
or cluster_shape_mn[1] <= 0
|
||||
# Special cluster shape check for scale factor multicasts.
|
||||
# Due to limited size of scale factors, we can't multicast among more than 4 CTAs.
|
||||
or cluster_shape_mn[0] > 4
|
||||
or cluster_shape_mn[1] > 4
|
||||
or not is_power_of_2(cluster_shape_mn[0])
|
||||
or not is_power_of_2(cluster_shape_mn[1])
|
||||
):
|
||||
@ -2304,8 +2308,9 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
|
||||
is_valid = False
|
||||
return is_valid
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls,
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
sf_dtype: Type[cutlass.Numeric],
|
||||
sf_vec_size: int,
|
||||
@ -2355,24 +2360,20 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
|
||||
"""
|
||||
can_implement = True
|
||||
# Skip unsupported types
|
||||
if not Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel.is_valid_dtypes_and_scale_factor_vec_size(
|
||||
if not cls.is_valid_dtypes_and_scale_factor_vec_size(
|
||||
ab_dtype, sf_dtype, sf_vec_size, out_dtype
|
||||
):
|
||||
can_implement = False
|
||||
|
||||
# Skip unsupported layouts
|
||||
if not Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel.is_valid_layouts(
|
||||
ab_dtype, out_dtype, a_major, b_major, out_major
|
||||
):
|
||||
if not cls.is_valid_layouts(ab_dtype, out_dtype, a_major, b_major, out_major):
|
||||
can_implement = False
|
||||
|
||||
# Skip invalid mma tile shape and cluster shape
|
||||
if not Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel.is_valid_mma_tiler_and_cluster_shape(
|
||||
mma_tiler_mn, cluster_shape_mn
|
||||
):
|
||||
if not cls.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn):
|
||||
can_implement = False
|
||||
# Skip illegal problem shape for load/store alignment
|
||||
if not Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel.is_valid_tensor_alignment(
|
||||
if not cls.is_valid_tensor_alignment(
|
||||
m, n, k, l, ab_dtype, out_dtype, a_major, b_major, out_major
|
||||
):
|
||||
can_implement = False
|
||||
|
||||
@ -82,15 +82,15 @@ This GEMM kernel supports the following features:
|
||||
|
||||
This GEMM works as follows:
|
||||
1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
|
||||
2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA
|
||||
operations.
|
||||
2. MMA warp:
|
||||
- Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction.
|
||||
- Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
|
||||
3. EPILOGUE warp:
|
||||
- Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
||||
- Apply alpha and update the final accumulator Final = alpha * acc
|
||||
- Type convert Final matrix to output type.
|
||||
5. EPILOGUE warps:
|
||||
- Load two accumulator subtiles (up and gate) from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
||||
- Apply alpha scaling: up_scaled = alpha * up, gate_scaled = alpha * gate
|
||||
- Compute SwiGLU activation: output = up_scaled * silu(gate_scaled), where silu(x) = x * sigmoid(x)
|
||||
- If c_dtype is Float4E2M1FN: generate scale factor C (SFC) and quantize output
|
||||
- Type convert output to c_dtype.
|
||||
- Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations.
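For reference, the epilogue math listed above reduces to the following scalar computation (a hedged sketch in plain Python; the kernel applies it element-wise to accumulator subtiles held in registers):

.. code-block:: python

    import math

    def silu(x: float) -> float:
        return x / (1.0 + math.exp(-x))           # x * sigmoid(x)

    def swiglu(up: float, gate: float, alpha: float = 1.0) -> float:
        return (alpha * up) * silu(alpha * gate)

    assert abs(swiglu(2.0, 0.0)) < 1e-12           # silu(0) == 0
    assert abs(swiglu(1.0, 10.0) - silu(10.0)) < 1e-9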
|
||||
|
||||
SM100 tcgen05.mma.kind.block_scale instructions operate as follows:
|
||||
@ -101,24 +101,6 @@ SM100 tcgen05.mma.kind.block_scale instructions operate as follows:
|
||||
- Write accumulator to TMEM
|
||||
The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/blackwell/contiguous_blockscaled_grouped_gemm.py \
|
||||
--ab_dtype Float4E2M1FN --c_dtype BFloat16 \
|
||||
--sf_dtype Float8E4M3FN --sf_vec_size 16 \
|
||||
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
|
||||
--mnkl 256,4096,7168,1
|
||||
|
||||
To collect performance with NCU profiler:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
ncu python examples/blackwell/contiguous_blockscaled_grouped_gemm.py \
|
||||
--ab_dtype Float4E2M1FN --c_dtype BFloat16 \
|
||||
--sf_dtype Float8E4M3FN --sf_vec_size 16 \
|
||||
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
|
||||
--mnkl 256,4096,7168,1
|
||||
|
||||
Constraints:
|
||||
* Supported input data types: mxf8, mxf4, nvf4
|
||||
see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation
|
||||
@ -138,14 +120,14 @@ CUDA Graph Support:
|
||||
- A matrix: padded to permuted_m rows (padding rows contain dummy data)
|
||||
- C matrix: padded to permuted_m rows (output buffer for cuda_graph)
|
||||
- Scale factor A: padded to match A matrix dimensions
|
||||
* Kernel handling of padding (similar to masked_grouped_gemm.py):
|
||||
* Kernel handling of padding:
|
||||
- Scheduler warp checks if tile_idx >= num_non_exiting_tiles to exit
|
||||
- Only valid tiles (tile_idx < num_non_exiting_tiles) are written to tile_info pipeline
|
||||
- When no more valid tiles exist, outer loop exits and calls producer_tail()
|
||||
- Consumer warps process only valid tiles from pipeline
|
||||
- No deadlock or synchronization issues
|
||||
* Consumer warps check initial tile against num_non_exiting_tiles and set is_valid_tile=False if
|
||||
tile_idx >= num_non_exiting_tiles
|
||||
* Consumer warps check initial tile against num_non_exiting_tiles and set
|
||||
is_valid_tile=False if tile_idx >= num_non_exiting_tiles
|
||||
* Only rows within (aligned_groupm[0]+aligned_groupm[1]+...) contain valid data
|
||||
* Padding rows in C matrix will not be written by the kernel
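A minimal sketch of the validity rule described above, assuming tiles are indexed 0..total-1 and only the first num_non_exiting_tiles carry real (non-padding) rows:

.. code-block:: python

    def valid_tile_indices(total_tiles: int, num_non_exiting_tiles: int) -> list[int]:
        # Only these tiles are written into the tile_info pipeline; the rest are padding.
        return [t for t in range(total_tiles) if t < num_non_exiting_tiles]

    assert valid_tile_indices(6, 4) == [0, 1, 2, 3]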
|
||||
"""
|
||||
@ -199,23 +181,34 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
vectorized_f32: bool,
|
||||
):
|
||||
"""Initializes the configuration for a Blackwell blockscaled dense GEMM kernel.
|
||||
"""Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with SwiGLU fusion.
|
||||
|
||||
This configuration includes several key aspects:
|
||||
|
||||
1. MMA Instruction Settings (tcgen05):
|
||||
- acc_dtype: Data types for MMA accumulator.
|
||||
- mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
|
||||
- use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant
|
||||
with cta_group=2 should be used.
|
||||
- use_2cta_instrs: Automatically inferred from mma_tiler_mn[0]
|
||||
(True when M=256, False when M=128).
|
||||
|
||||
2. Cluster Shape:
|
||||
- cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
|
||||
|
||||
3. Scale Factor Configuration:
|
||||
- sf_vec_size: Vector size for block-scaled quantization.
|
||||
|
||||
4. Performance Optimization:
|
||||
- vectorized_f32: Enable vectorized f32x2 operations.
|
||||
|
||||
:param sf_vec_size: Vector size for scale factors (16 for NVF4, 32 for MXF4/MXF8).
|
||||
:type sf_vec_size: int
|
||||
:param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
|
||||
use_2cta_instrs is automatically set based on M (True if M=256, False if M=128).
|
||||
:type mma_tiler_mn: Tuple[int, int]
|
||||
:param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
|
||||
:type cluster_shape_mn: Tuple[int, int]
|
||||
:param vectorized_f32: Enable vectorized f32x2 operations for better performance.
|
||||
:type vectorized_f32: bool
|
||||
"""
|
||||
|
||||
self.sf_vec_size = sf_vec_size
|
||||
@ -248,10 +241,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
self.tma_warp_id,
|
||||
)
|
||||
)
|
||||
# TODO: Do we need to reallocate register?
|
||||
# self.num_regs_uniform_warps = 64
|
||||
# self.num_regs_sched_warps = 64
|
||||
# self.num_regs_epilogue_warps = 216
|
||||
|
||||
# Set barrier for cta sync, epilogue sync and tmem ptr sync
|
||||
self.cta_sync_barrier = pipeline.NamedBarrier(
|
||||
@ -270,6 +259,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
barrier_id=4,
|
||||
num_threads=self.threads_per_warp,
|
||||
)
|
||||
|
||||
self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
SM100_TMEM_CAPACITY_COLUMNS = 512
|
||||
self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
|
||||
@ -337,22 +327,29 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
mma_inst_shape_k * mma_inst_tile_k,
|
||||
)
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler[1],
|
||||
self.mma_tiler[2],
|
||||
)
|
||||
|
||||
self.mma_tiler_c = (
|
||||
self.mma_inst_shape_mn[0],
|
||||
self.mma_inst_shape_mn[1] // 2,
|
||||
mma_inst_shape_k * mma_inst_tile_k,
|
||||
)
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler[1],
|
||||
self.mma_tiler[2],
|
||||
)
|
||||
self.cta_tile_shape_mnk_sfb = (
|
||||
self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler_sfb[1],
|
||||
self.mma_tiler_sfb[2],
|
||||
)
|
||||
|
||||
self.cta_tile_shape_mnk_c = (
|
||||
self.mma_tiler_c[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler_c[1],
|
||||
self.mma_tiler_c[2],
|
||||
)
|
||||
|
||||
# Compute cluster layout
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(
|
||||
cute.make_layout((*self.cluster_shape_mn, 1)),
|
||||
@ -436,6 +433,24 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
self.num_c_stage,
|
||||
)
|
||||
|
||||
# Overlap and double buffer accumulator when num_acc_stage == 1 for cta_tile_n = 256 case
|
||||
self.overlapping_accum = self.num_acc_stage == 1
|
||||
|
||||
# Compute number of TMEM columns for SFA/SFB/Accumulator
|
||||
sf_atom_mn = 32
|
||||
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
|
||||
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
|
||||
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
|
||||
self.num_accumulator_tmem_cols = (
|
||||
self.cta_tile_shape_mnk[1] * self.num_acc_stage
|
||||
if not self.overlapping_accum
|
||||
else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
|
||||
)
|
||||
|
||||
self.epi_tile_n_required = 2 * cute.size(self.epi_tile[1])
|
||||
# Only when overlapping_accum is enabled, we need to release accumulator buffer early in epilogue
|
||||
self.iter_acc_early_release_in_epilogue = self.num_sf_tmem_cols // self.epi_tile_n_required
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self,
|
||||
@ -530,6 +545,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
self.mma_inst_shape_mn,
|
||||
)
|
||||
|
||||
# For 2CTA blockscaled kernels, SFB needs to be replicated across peer CTAs.
|
||||
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
||||
self.a_dtype,
|
||||
self.a_major_mode,
|
||||
@ -591,7 +607,9 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
internal_type=cutlass.Int16,
|
||||
)
|
||||
|
||||
if cutlass.const_expr(self.cta_tile_shape_mnk_c[1] == 192):
|
||||
# This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF))
|
||||
# logical blocks for SFB when cta_tile_shape_n=192.
|
||||
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
|
||||
x = tma_tensor_sfb.stride[0][1]
|
||||
y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4)
|
||||
|
||||
@ -727,8 +745,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
tSF: cute.Tensor,
|
||||
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
||||
"""
|
||||
Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and
|
||||
tensor memory (destination).
|
||||
Make tiledCopy for smem to tmem load for scale factor tensor, then use it to
|
||||
partition smem memory (source) and tensor memory (destination).
|
||||
|
||||
:param sSF: The scale factor tensor in smem
|
||||
:type sSF: cute.Tensor
|
||||
@ -779,7 +797,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
mSFB_nkl: cute.Tensor,
|
||||
tma_atom_c: cute.CopyAtom,
|
||||
mC_mnl: cute.Tensor,
|
||||
mSFD_mnl: Optional[cute.Tensor],
|
||||
mSFC_mnl: Optional[cute.Tensor],
|
||||
norm_const_tensor: Optional[cute.Tensor],
|
||||
tile_idx_to_expert_idx: cute.Tensor,
|
||||
num_non_exiting_tiles: cute.Tensor,
|
||||
@ -1046,7 +1064,27 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
num_acc_stage_overlapped = 2
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(
|
||||
cute.append(acc_shape, num_acc_stage_overlapped)
|
||||
)
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_fake = cute.make_tensor(
|
||||
tCtAcc_fake.iterator,
|
||||
cute.make_layout(
|
||||
tCtAcc_fake.shape,
|
||||
stride=(
|
||||
tCtAcc_fake.stride[0],
|
||||
tCtAcc_fake.stride[1],
|
||||
tCtAcc_fake.stride[2],
|
||||
(256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1],
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
||||
|
||||
#
|
||||
# Cluster wait before tensor memory alloc
|
||||
@ -1062,7 +1100,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
# Specialized Schedule warp
|
||||
#
|
||||
if warp_idx == self.sched_warp_id:
|
||||
# cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps)
|
||||
#
|
||||
# Persistent tile scheduling loop
|
||||
#
|
||||
@ -1078,10 +1115,11 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
|
||||
while work_tile.is_valid_tile:
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
if cur_tile_coord[0] < num_non_exiting_tiles[0]:
|
||||
mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
|
||||
if mma_tile_coord_m < num_non_exiting_tiles[0]:
|
||||
tile_info_pipeline.producer_acquire(tile_info_producer_state)
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
expert_idx = tile_idx_to_expert_idx[cur_tile_coord[0]]
|
||||
expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
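A hedged sketch of the index change above: the per-CTA tile index is divided by the number of CTAs cooperating on one MMA tile before the validity check and the expert lookup, so peer CTAs agree on the expert. The mapping contents below are made up for illustration:

.. code-block:: python

    def mma_tile_coord_m(cta_tile_idx_m: int, num_peer_ctas: int) -> int:
        return cta_tile_idx_m // num_peer_ctas

    tile_idx_to_expert_idx = [0, 0, 1, 2]       # one entry per MMA tile (assumed)
    num_non_exiting_tiles = 3
    routed = {
        cta_tile: tile_idx_to_expert_idx[mma_tile_coord_m(cta_tile, 2)]
        for cta_tile in range(8)
        if mma_tile_coord_m(cta_tile, 2) < num_non_exiting_tiles
    }
    # Peer CTAs 0/1, 2/3, 4/5 resolve to experts 0, 0, 1; CTA tiles 6/7 are padding.
    assert routed == {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1}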
|
||||
with cute.arch.elect_one():
|
||||
sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
|
||||
sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
|
||||
@ -1121,7 +1159,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
# Specialized TMA load warp
|
||||
#
|
||||
if warp_idx == self.tma_warp_id:
|
||||
# cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
|
||||
#
|
||||
# Persistent tile scheduling loop
|
||||
#
|
||||
@ -1169,6 +1206,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
# ((atom_v, rest_v), RestK)
|
||||
tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, 0)]
|
||||
|
||||
# Apply SFB slicing hack when cta_tile_shape_n=64
|
||||
slice_n = mma_tile_coord_mnl[1]
|
||||
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
|
||||
slice_n = mma_tile_coord_mnl[1] // 2
|
||||
@ -1272,7 +1310,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
|
||||
# Make SFA tmem tensor
|
||||
sfa_tmem_ptr = cute.recast_ptr(
|
||||
acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base),
|
||||
acc_tmem_ptr + self.num_accumulator_tmem_cols,
|
||||
dtype=self.sf_dtype,
|
||||
)
|
||||
# (MMA, MMA_M, MMA_K)
|
||||
@ -1286,9 +1324,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
|
||||
# Make SFB tmem tensor
|
||||
sfb_tmem_ptr = cute.recast_ptr(
|
||||
acc_tmem_ptr
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA),
|
||||
acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols,
|
||||
dtype=self.sf_dtype,
|
||||
)
|
||||
# (MMA, MMA_N, MMA_K)
|
||||
@ -1352,31 +1388,34 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||||
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
|
||||
|
||||
# Peek (try_wait) Acc buffer empty for k_tile = 0
|
||||
acc_producer_state.reset_count()
|
||||
peek_acc_empty_status = cutlass.Boolean(1)
|
||||
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||||
peek_acc_empty_status = acc_pipeline.producer_try_acquire(acc_producer_state)
|
||||
|
||||
mma_tile_coord_mnl = (
|
||||
tile_info[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
tile_info[1],
|
||||
tile_info[2],
|
||||
)
|
||||
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
|
||||
# Get accumulator stage index
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
acc_stage_index = acc_producer_state.phase ^ 1
|
||||
else:
|
||||
acc_stage_index = acc_producer_state.index
|
||||
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)]
|
||||
|
||||
# Apply TMEM pointer offset hack when cta_tile_shape_n=192 or
|
||||
# cta_tile_shape_n=64
|
||||
tCtSFB_mma = tCtSFB
|
||||
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
|
||||
# If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words
|
||||
# If this is an ODD tile, shift the TMEM start address for
|
||||
# cta_tile_shape_n=192 case by two words
|
||||
# (ignores first 64 columns of SFB)
|
||||
offset = (
|
||||
cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0)
|
||||
)
|
||||
shifted_ptr = cute.recast_ptr(
|
||||
acc_tmem_ptr
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA)
|
||||
+ self.num_accumulator_tmem_cols
|
||||
+ self.num_sfa_tmem_cols
|
||||
+ offset,
|
||||
dtype=self.sf_dtype,
|
||||
)
|
||||
@ -1386,8 +1425,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
|
||||
offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
|
||||
shifted_ptr = cute.recast_ptr(
|
||||
acc_tmem_ptr
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
|
||||
+ tcgen05.find_tmem_tensor_col_offset(tCtSFA)
|
||||
+ self.num_accumulator_tmem_cols
|
||||
+ self.num_sfa_tmem_cols
|
||||
+ offset,
|
||||
dtype=self.sf_dtype,
|
||||
)
@ -1396,7 +1435,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Wait for accumulator buffer empty
#
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state, peek_acc_empty_status)
acc_pipeline.producer_acquire(acc_producer_state)
#
# Mma mainloop
#
@ -1406,7 +1445,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
#
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)

for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
for k_tile in cutlass.range(k_tile_cnt):
# Set tensor memory buffer for current tile
# (MMA, MMA_M, MMA_N)

@ -1485,11 +1524,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:

# Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1
acc_producer_state.advance()
if acc_producer_state.count < k_tile_cnt:
if is_leader_cta:
peek_acc_empty_status = acc_pipeline.producer_try_acquire(
acc_producer_state
)

#
# Advance to next tile
@ -1533,7 +1567,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
#
# Partition for epilogue
#
epi_tidx = tidx
epi_tidx = tidx % 128
(
tiled_copy_t2r,
tTR_tAcc_base,
@ -1562,17 +1596,18 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
if cutlass.const_expr(self.generate_sfc):
norm_const = norm_const_tensor[0]
# (EPI_TILE_M, EPI_TILE_N, RestM, RestN, RestL)
gSFD_mnl = cute.local_tile(mSFD_mnl, epi_tile, (None, None, None))
gSFC_mnl = cute.local_tile(mSFC_mnl, epi_tile, (None, None, None))

thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
# (T2R, T2R_M, T2R_N, RestM, RestN, RestL)
tCgSFD_mnl = thr_copy_t2r.partition_D(gSFD_mnl)
tCgSFD_mnl = cute.filter_zeros(tCgSFD_mnl)
tCgSFC_mnl = thr_copy_t2r.partition_D(gSFC_mnl)
tCgSFC_mnl = cute.filter_zeros(tCgSFC_mnl)
# (T2R, T2R_M, T2R_N)
tCrSFD = cute.make_rmem_tensor(
tCgSFD_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype
tCrSFC = cute.make_rmem_tensor(
tCgSFC_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype
)
tCrSFD_pvscale = cute.make_rmem_tensor_like(tCrSFD, cutlass.Float32)
tCrSFC_pvscale = cute.make_rmem_tensor_like(tCrSFC, cutlass.Float32)

#
# Persistent tile scheduling loop
#
@ -1614,6 +1649,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()

num_prev_subtiles = cutlass.Int32(0)
while is_valid_tile:
mma_tile_coord_mnl = (
tile_info[0] // cute.size(tiled_mma.thr_id.shape),
@ -1643,13 +1679,22 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
)
]

# Get accumulator stage index
if cutlass.const_expr(self.overlapping_accum):
acc_stage_index = acc_consumer_state.phase
reverse_subtile = (
cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False)
)
else:
acc_stage_index = acc_consumer_state.index

# Set tensor memory buffer for current tile
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)]
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]

if cutlass.const_expr(self.generate_sfc):
# (T2R, T2R_M, T2R_N, RestM, RestN)
tCgSFD_mn = tCgSFD_mnl[
tCgSFC_mn = tCgSFC_mnl[
(
None,
None,
@ -1669,28 +1714,54 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))

#
# Store accumulator to global memory in sub-tiles
# Process accumulator subtiles with SwiGLU fusion and store to global memory
# Each iteration processes a pair of subtiles (up, gate) and computes
# up * silu(gate)
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt

for subtile_idx in cutlass.range(0, subtile_cnt, 2):
real_subtile_idx = subtile_idx // 2
if cutlass.const_expr(self.overlapping_accum):
if reverse_subtile:
real_subtile_idx = (
self.cta_tile_shape_mnk[1] // self.epi_tile_n_required
- 1
- subtile_idx // 2
)
#
# Load accumulator from tensor memory buffer to register
#
tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, subtile_idx)]
tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, subtile_idx + 1)]
tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, real_subtile_idx * 2)]
tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, real_subtile_idx * 2 + 1)]

cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up)
cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate)

#
# Async arrive accumulator buffer empty earlier when overlapping_accum is enabled
#
if cutlass.const_expr(self.overlapping_accum):
if subtile_idx // 2 == self.iter_acc_early_release_in_epilogue:
# Fence for TMEM load
cute.arch.fence_view_async_tmem_load()
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()

acc_vec_up = tTR_rAcc_up.load()
acc_vec_gate = tTR_rAcc_gate.load()

# SwiGlu
#
# SwiGLU activation: output = up * silu(gate)
# where silu(x) = x * sigmoid(x)
# up and gate are extracted from interleaved accumulator subtiles
#
tCompute = cute.make_rmem_tensor(acc_vec_gate.shape, self.acc_dtype)
if cutlass.const_expr(self.vectorized_f32):
# SwiGlu Packed Version
# SwiGLU Packed Version: uses f32x2 packed operations for better performance
# Computes: output = (alpha * up) * silu(alpha * gate)
# where silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
LOG2_E = cutlass.Float32(1.4426950408889634)
for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_up), 2):
acc_vec_up_alpha = cute.arch.mul_packed_f32x2(
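
With overlapping_accum enabled, the epilogue above derives acc_stage_index from the pipeline phase and walks the (up, gate) subtile pairs in reverse order on alternating phases (real_subtile_idx = count - 1 - subtile_idx // 2). A rough, illustrative sketch of that traversal order, where num_pairs stands in for cta_tile_shape_mnk[1] // epi_tile_n_required:

def subtile_pair_order(num_pairs: int, reverse_subtile: bool) -> list:
    # Forward walk on one accumulator phase, reversed walk on the other
    # (mirrors the real_subtile_idx computation in the hunk above).
    order = list(range(num_pairs))
    return order[::-1] if reverse_subtile else order

assert subtile_pair_order(2, False) == [0, 1]
assert subtile_pair_order(2, True) == [1, 0]
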
@ -1731,7 +1802,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
(acc_vec_up_alpha[0], acc_vec_up_alpha[1]),
)
else:
# SwiGlu Unpacked Version
# SwiGLU Unpacked Version: scalar operations
# Computes: output = (alpha * up) * silu(alpha * gate)
for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)):
acc_vec_up_alpha = acc_vec_up[i] * cutlass.Float32(alpha_val)
acc_vec_gate_alpha = acc_vec_gate[i] * cutlass.Float32(alpha_val)
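
Both branches above implement the same math; a minimal NumPy reference of the fused activation is handy when checking the packed f32x2 path against the scalar path (illustrative only; alpha mirrors alpha_val in the kernel):

import numpy as np

def swiglu_reference(up: np.ndarray, gate: np.ndarray, alpha: float) -> np.ndarray:
    # output = (alpha * up) * silu(alpha * gate), with silu(x) = x * sigmoid(x)
    g = alpha * gate
    return (alpha * up) * (g / (1.0 + np.exp(-g)))
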
@ -1740,12 +1812,19 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
)

if cutlass.const_expr(self.generate_sfc):
#
# Quantization path for Float4E2M1FN output:
# 1. Compute per-vector absolute max from SwiGLU result
# 2. Generate scale factor C (SFC) based on max values
# 3. Store SFC to global memory
# 4. Quantize output by scaling with reciprocal of SFC
#
# Assume subtile partitioned always happens on n dimension
sfc_subtile_idx_mn = (
tile_info[0] * self.epi_tile_cnt[0],
tile_info[1] * self.epi_tile_cnt[1] + subtile_idx // 2,
tile_info[1] * self.epi_tile_cnt[1] + real_subtile_idx,
)
tCgSFD = tCgSFD_mn[
tCgSFC = tCgSFC_mn[
(
None,
None,
@ -1769,30 +1848,30 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:

if cutlass.const_expr(self.vectorized_f32):
for vi in cutlass.range_constexpr(abs_acc_frg.shape[1]):
tCrSFD_pvscale[vi] = abs_acc_frg[None, vi].reduce(
tCrSFC_pvscale[vi] = abs_acc_frg[None, vi].reduce(
cute.ReductionOp.MAX,
cutlass.Float32(0.0),
0, # Use 0.0 as init for abs values
)
for vi in cutlass.range_constexpr(0, abs_acc_frg.shape[1], 2):
tCrSFD_pvscale[vi], tCrSFD_pvscale[vi + 1] = (
tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1] = (
cute.arch.mul_packed_f32x2(
(tCrSFD_pvscale[vi], tCrSFD_pvscale[vi + 1]),
(tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1]),
(
self.get_dtype_rcp_limits(self.c_dtype),
self.get_dtype_rcp_limits(self.c_dtype),
),
)
)
tCrSFD_pvscale[vi], tCrSFD_pvscale[vi + 1] = (
tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1] = (
cute.arch.mul_packed_f32x2(
(tCrSFD_pvscale[vi], tCrSFD_pvscale[vi + 1]),
(tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1]),
(norm_const, norm_const),
)
)
else:
for vi in cutlass.range_constexpr(abs_acc_frg.shape[1]):
tCrSFD_pvscale[vi] = (
tCrSFC_pvscale[vi] = (
abs_acc_frg[None, vi].reduce(
cute.ReductionOp.MAX,
cutlass.Float32(0.0),
@ -1803,27 +1882,27 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
)

# TODO: need to add f32x2 -> f8x2 conversion
tCrSFD.store(tCrSFD_pvscale.load().to(self.sf_dtype))
tCrSFC.store(tCrSFC_pvscale.load().to(self.sf_dtype))

#
# Store SFC to global memory
#
# TODO: Need to think about predicate on it
# if cute.elem_less():
cute.autovec_copy(tCrSFD, tCgSFD)
cute.autovec_copy(tCrSFC, tCgSFC)

#
# Compute quantized output values and convert to C type
#
# TODO: need to add f8x2 -> f32x2 conversion
tCrSFD_qpvscale_up = tCrSFD.load().to(cutlass.Float32)
tCrSFC_qpvscale_up = tCrSFC.load().to(cutlass.Float32)
fp32_max = cutlass.Float32(3.40282346638528859812e38)
if cutlass.const_expr(self.vectorized_f32):
for vi in cutlass.range_constexpr(0, cute.size(tCrSFD), 2):
for vi in cutlass.range_constexpr(0, cute.size(tCrSFC), 2):
acc_scale = cute.arch.mul_packed_f32x2(
(
cute.arch.rcp_approx(tCrSFD_qpvscale_up[vi]),
cute.arch.rcp_approx(tCrSFD_qpvscale_up[vi + 1]),
cute.arch.rcp_approx(tCrSFC_qpvscale_up[vi]),
cute.arch.rcp_approx(tCrSFC_qpvscale_up[vi + 1]),
),
(norm_const, norm_const),
)
@ -1838,10 +1917,10 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
(acc_scale_min0, acc_scale_min1),
)
else:
for vi in cutlass.range_constexpr(cute.size(tCrSFD)):
for vi in cutlass.range_constexpr(cute.size(tCrSFC)):
# TODO:Need to add E8M0 rcp approximation
acc_scale = norm_const * cute.arch.rcp_approx(
tCrSFD_qpvscale_up[vi]
tCrSFC_qpvscale_up[vi]
)
acc_scale = fmin(acc_scale, fp32_max, nan=True)
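
The generate_sfc branch above follows the four steps listed in its comment: per-vector absolute max, scale-factor generation, SFC store, then quantization by the reciprocal. A rough NumPy sketch of that math, under the assumption that get_dtype_rcp_limits returns 1 / (output-dtype max); the real kernel also rounds the scale factor to sf_dtype before taking the reciprocal:

import numpy as np

FP32_MAX = np.finfo(np.float32).max

def quantize_with_sfc(vec: np.ndarray, norm_const: float, c_dtype_max: float):
    amax = float(np.max(np.abs(vec)))               # 1. per-vector absolute max
    sfc = amax * (1.0 / c_dtype_max) * norm_const   # 2. scale factor C
    # 3. sfc is written to global memory (in sf_dtype) by the kernel
    rcp = 1.0 / sfc if sfc != 0.0 else FP32_MAX     # 4. quantize by the reciprocal
    acc_scale = min(norm_const * rcp, FP32_MAX)
    return vec * acc_scale, sfc
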
@ -1862,7 +1941,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
#
# Store C to shared memory
#
c_buffer = (num_prev_subtiles + subtile_idx // 2) % self.num_c_stage
num_prev_subtiles = num_prev_subtiles + 1
c_buffer = num_prev_subtiles % self.num_c_stage

cute.copy(
tiled_copy_r2s,
@ -1882,7 +1962,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
cute.copy(
tma_atom_c,
bSG_sC[(None, c_buffer)],
bSG_gC[(None, subtile_idx // 2)],
bSG_gC[(None, real_subtile_idx)],
)
# Fence and barrier to make sure shared memory store is visible to TMA store
c_pipeline.producer_commit()
@ -1892,9 +1972,10 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
#
# Async arrive accumulator buffer empty
#
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
if cutlass.const_expr(not self.overlapping_accum):
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()

#
# Advance to next tile
@ -1931,8 +2012,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
use_2cta_instrs: Union[cutlass.Boolean, bool],
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]:
"""
Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array
(destination).
Make tiledCopy for tensor memory load, then use it to partition tensor memory
(source) and register array (destination).

:param tidx: The thread index in epilogue warp groups
:type tidx: cutlass.Int32
@ -1998,8 +2079,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
sC: cute.Tensor,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory
(destination).
Make tiledCopy for shared memory store, then use it to partition register
array (source) and shared memory (destination).

:param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
:type tiled_copy_t2r: cute.TiledCopy
@ -2120,8 +2201,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Default ACC stages
num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2

# num_acc_stage = 1

# Default C stages
num_c_stage = 2

@ -2178,9 +2257,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
# Start with total smem per CTA (capacity / occupancy)
# Subtract reserved bytes and initial C stages bytes
# Divide remaining by bytes needed per A/B stage
# cute.printf("num_smem_capacity: {}, occupancy: {}, mbar_helpers_bytes: {}, c_bytes: {}", num_smem_capacity,
# occupancy, mbar_helpers_bytes, c_bytes)
# cute.printf("ab_bytes_per_stage: {}", ab_bytes_per_stage)
num_ab_stage = (
num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
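
The stage count above is a straightforward shared-memory budget: per-CTA capacity divided by occupancy, minus the barrier helpers and the initial C stages, divided by the per-stage A/B footprint. Restated as a small helper for reference (names mirror the variables above; purely illustrative):

def compute_num_ab_stage(num_smem_capacity: int, occupancy: int,
                         mbar_helpers_bytes: int, c_bytes: int,
                         ab_bytes_per_stage: int) -> int:
    # Remaining bytes after reserved helpers and initial C stages,
    # divided by the bytes one A/B pipeline stage needs.
    return (num_smem_capacity // occupancy
            - (mbar_helpers_bytes + c_bytes)) // ab_bytes_per_stage
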
@ -2369,7 +2445,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:

@staticmethod
def is_valid_mma_tiler_and_cluster_shape(
use_2cta_instrs: bool,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
) -> bool:
@ -2389,19 +2464,18 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
is_valid = True

# Skip invalid mma tile shape
if not (
(not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
):
if mma_tiler_mn[0] not in (128, 256):
is_valid = False
# Skip invalid mma tile n
# Needs to have even iterations with Epi Tile N 64 for swiGeLU fusion
# SwiGlu Fusion requires even epi_tile counts,
# based on epi_tile_n = 64, only mma_tiler_n = 128 and 256 are supported
if mma_tiler_mn[1] not in (128, 256):
is_valid = False

# Skip illegal cluster shape
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
if (mma_tiler_mn[0] // cluster_shape_mn[0]) != 128:
is_valid = False
# Skip invalid cluster shape

if (
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
or cluster_shape_mn[0] <= 0
@ -2414,14 +2488,6 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
or not is_power_of_2(cluster_shape_mn[1])
):
is_valid = False
cluster_tiler_m = (cluster_shape_mn[0] // (2 if use_2cta_instrs else 1)) * mma_tiler_mn[0]

# Skip invalid cluster tiler shape since contiguous layout can't handle oob access
# The contiguous layout means the aligned data is stored in a contiguous manner.
# It can't handle runtime oob when alignment is not align with the tile_M,
# since the problem shape of TMA store can't be changed at runtime.
if cluster_tiler_m not in [64, 128, 256]:
is_valid = False

return is_valid
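
For reference, the visible checks of the updated validator condensed into plain Python. The elided middle of the hunk may carry additional cluster-shape conditions, so treat this as a sketch of the rules shown above, not the full predicate:

def is_power_of_2(x: int) -> bool:
    return x > 0 and (x & (x - 1)) == 0

def mma_tiler_and_cluster_ok(mma_tiler_mn, cluster_shape_mn) -> bool:
    m, n = mma_tiler_mn
    cm, cn = cluster_shape_mn
    return (
        m in (128, 256)                  # MMA tile M
        and n in (128, 256)              # even epi-tile count (epi_tile_n = 64) for SwiGLU fusion
        and cm > 0 and cn > 0
        and m // cm == 128               # 128 rows of the MMA tile per cluster CTA in M
        and cm * cn <= 16
        and is_power_of_2(cm)
        and is_power_of_2(cn)
    )
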

@ -2540,10 +2606,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
can_implement = False

# Skip invalid mma tile shape and cluster shape
use_2cta_instrs = mma_tiler_mn[0] == 256
if not cls.is_valid_mma_tiler_and_cluster_shape(
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
):
if not cls.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn):
can_implement = False
# Skip illegal problem shape for load/store alignment
if not cls.is_valid_tensor_alignment(
@ -2630,3 +2693,18 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
stream=stream,
epilogue_op=epilogue_op,
)


@cute.jit
def cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
sf_ref_tensor: cute.Tensor,
sf_mma_tensor: cute.Tensor,
):
"""Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout"""
# sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l)
# group to ((32, 4, rest_m), (4, rest_k), l)
sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3)
sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3)
for i in cutlass.range(cute.size(sf_ref_tensor)):
mkl_coord = sf_ref_tensor.layout.get_hier_coord(i)
sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord]
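
The helper above scatters an (M, K, L) scale-factor tensor into the MMA-friendly ((32, 4, rest_m), (4, rest_k), L) layout by indexing the grouped tensor with hierarchical coordinates. A rough NumPy analogue of the index mapping, assuming CuTe's colexicographic coordinate decomposition (leftmost modes fastest) and that M and K are multiples of 128 and 4 respectively:

import numpy as np

def cvt_sf_mkl_to_mma_ref(sf_ref: np.ndarray) -> np.ndarray:
    M, K, L = sf_ref.shape
    sf_mma = np.zeros((32, 4, M // 128, 4, K // 4, L), dtype=sf_ref.dtype)
    for m in range(M):
        for k in range(K):
            # flat (m, k) decomposed onto ((32, 4, rest_m), (4, rest_k))
            sf_mma[m % 32, (m // 32) % 4, m // 128, k % 4, k // 4, :] = sf_ref[m, k, :]
    return sf_mma
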

@ -443,7 +443,6 @@ class PipelineCpAsyncUmma(PipelineAsync):
barrier_storage: cute.Pointer = None,
cta_layout_vmnk: Optional[cute.Layout] = None,
defer_sync: bool = False,
enable_cp_async: bool = False,
):
"""Creates and initializes a new PipelineCpAsyncUmma instance.

@ -459,8 +458,6 @@ class PipelineCpAsyncUmma(PipelineAsync):
:type cta_layout_vmnk: cute.Layout, optional
:param defer_sync: Whether to defer the sync
:type defer_sync: bool, optional
:param enable_cp_async: Whether to enable cp.async instructions
:type enable_cp_async: bool, optional
:raises ValueError: If barrier_storage is not a cute.Pointer instance
:return: A new PipelineCpAsyncUmma instance configured with the provided parameters
:rtype: PipelineCpAsyncUmma
@ -470,7 +467,7 @@ class PipelineCpAsyncUmma(PipelineAsync):
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
)

producer_type = PipelineOp.AsyncLoad if enable_cp_async else PipelineOp.AsyncThread
producer_type = PipelineOp.AsyncLoad
consumer_type = PipelineOp.TCGen05Mma

producer = (producer_type, producer_group)

@ -349,7 +349,6 @@ def run(
):
"""Prepare A/B/C tensors, launch GPU kernel, and reference checking."""
m_aligned = mma_tiler_mn[0]
use_2cta_instrs = mma_tiler_mn[0] == 256 and cluster_shape_mn[0] % 2 == 0

print("Running Blackwell Persistent Dense Contiguous Grouped GEMM test with:")
print(f"nkl: {nkl}")
@ -362,8 +361,6 @@ def run(
print(f"Padded M (CUDA graph support): {permuted_m}")
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
print(f"Use TMA Store: {'True'}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
@ -394,7 +391,7 @@ def run(
):
raise TypeError(
f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype},"
f"{use_2cta_instrs},{mma_tiler_mn}, {cluster_shape_mn}, {n}, {k}, {l},"
f"{mma_tiler_mn}, {cluster_shape_mn}, {n}, {k}, {l},"
f"{a_major}, {b_major}, {c_major}, {m_aligned}"
)

@ -727,7 +724,9 @@ if __name__ == "__main__":
f"Invalid benchmark argument format. Expected file path, 'MxNxKxL', or '[m0,m1,...]xNxK'. Got: {arg}"
)

parser = argparse.ArgumentParser(description="Example of Dense Persistent GEMM on Blackwell.")
parser = argparse.ArgumentParser(
description="Example of BlockScaled Contiguous grouped GEMM kernel on Blackwell."
)

parser.add_argument(
"--nkl",

@ -532,7 +532,6 @@ def run(
Defaults to False.
"""
m_aligned = mma_tiler_mn[0]
use_2cta_instrs = mma_tiler_mn[0] == 256 and cluster_shape_mn[0] % 2 == 0

print("Running Blackwell Persistent Dense Contiguous Grouped GEMM test with:")
print(f"nkl: {nkl}")
@ -547,8 +546,6 @@ def run(
print(f"Sequence length: {seq_len}")
print(f"Matrix majors - A: {a_major}, B: {b_major}, Out: {out_major}")
print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
print(f"Use TMA Store: {'True'}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
@ -581,7 +578,7 @@ def run(
):
raise TypeError(
f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {out_dtype}, "
f"{use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {n}, {k}, {l}, "
f"{mma_tiler_mn}, {cluster_shape_mn}, {n}, {k}, {l}, "
f"{a_major}, {b_major}, {out_major}, {m_aligned}"
)

@ -957,7 +954,9 @@ if __name__ == "__main__":
f"Invalid benchmark argument format. Expected file path, 'MxNxKxL', or '[m0,m1,...]xNxK'. Got: {arg}"
)

parser = argparse.ArgumentParser(description="Example of Dense Persistent GEMM on Blackwell.")
parser = argparse.ArgumentParser(
description="Example of BlockScaled Contiguous grouped GEMM finalize fusion kernel on Blackwell."
)

parser.add_argument(
"--nkl",

File diff suppressed because it is too large
@ -594,7 +594,7 @@ def test_nvfp4_grouped_gemm_finalize_blackwell(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128])
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024, 8192])