[None] [feat] Add test script and raster M for gather fc1 kernel (#10429)

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Zongfei Jing 2026-01-07 09:31:49 +08:00 committed by GitHub
parent bb6a3973aa
commit bb2f883296
3 changed files with 1555 additions and 31 deletions


@@ -1864,10 +1864,13 @@ if IS_CUTLASS_DSL_AVAILABLE:
mma_tiler_mn_candidates = [(self.tile_size, 128),
(self.tile_size, 256)]
cluster_shape_mn_candidates = [(self.tile_size // 128, 1)]
# TODO: Add raster_along_m=True if we find it more performant in some cases.
raster_along_m_candidates = [False]
valid_tactics = []
- for mma_tiler_mn, cluster_shape_mn in itertools.product(
- mma_tiler_mn_candidates, cluster_shape_mn_candidates):
+ for mma_tiler_mn, cluster_shape_mn, raster_along_m in itertools.product(
+ mma_tiler_mn_candidates, cluster_shape_mn_candidates,
+ raster_along_m_candidates):
if self.__class__.kernel_class.can_implement(
ab_dtype=cutlass.Float4E2M1FN,
sf_dtype=cutlass.Float8E4M3FN,
@@ -1883,7 +1886,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
b_major="k",
c_major="n",
):
- valid_tactics.append((mma_tiler_mn, cluster_shape_mn))
+ valid_tactics.append(
+ (mma_tiler_mn, cluster_shape_mn, raster_along_m))
return valid_tactics
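
For reference, a minimal standalone sketch of the tactic tuples this hunk now enumerates; the can_implement filter is omitted and tile_size = 128 is an assumed example value:

import itertools

tile_size = 128
mma_tiler_mn_candidates = [(tile_size, 128), (tile_size, 256)]
cluster_shape_mn_candidates = [(tile_size // 128, 1)]
raster_along_m_candidates = [False]  # True may be added later, per the TODO above

valid_tactics = [
    (mma_tiler_mn, cluster_shape_mn, raster_along_m)
    for mma_tiler_mn, cluster_shape_mn, raster_along_m in itertools.product(
        mma_tiler_mn_candidates, cluster_shape_mn_candidates, raster_along_m_candidates)
]
print(valid_tactics)  # [((128, 128), (1, 1), False), ((128, 256), (1, 1), False)]
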
@@ -2013,15 +2017,16 @@ if IS_CUTLASS_DSL_AVAILABLE:
stream = cuda.CUstream(torch_stream.cuda_stream)
if isinstance(tactic, tuple):
- mma_tiler_mn, cluster_shape_mn = tactic
+ mma_tiler_mn, cluster_shape_mn, raster_along_m = tactic
else:
mma_tiler_mn = (self.tile_size, 128)
cluster_shape_mn = (self.tile_size // 128, 1)
raster_along_m = False
assert mma_tiler_mn[
0] == self.tile_size, f"Tactic ({tactic}) is incompatible with tile size ({self.tile_size})"
cache_key = (self.scaling_vector_size, self.tile_size, self.top_k,
- mma_tiler_mn, cluster_shape_mn)
+ mma_tiler_mn, cluster_shape_mn, raster_along_m)
if cache_key not in self.__class__.kernel_cache:
gemm = self.__class__.kernel_class(
sf_vec_size=self.scaling_vector_size,
@@ -2029,6 +2034,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
cluster_shape_mn=cluster_shape_mn,
vectorized_f32=True,
topk=self.top_k,
raster_along_m=raster_along_m,
)
# Compute max active clusters on current device
hardware_info = cutlass.utils.HardwareInfo()


@@ -37,6 +37,7 @@ import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass._mlir.dialects import math
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cutlass_dsl import Int32
from .custom_pipeline import PipelineCpAsyncUmma
from .utils import (
@@ -154,6 +155,144 @@ CUDA Graph Support:
"""
# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.4 is released.
def hooked_PersistentTileSchedulerParams_init(
self,
problem_shape_ntile_mnl: cute.Shape,
cluster_shape_mnk: cute.Shape,
swizzle_size: int = 1,
raster_along_m: bool = True,
*,
loc=None,
ip=None,
):
if cluster_shape_mnk[2] != 1:
raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}")
if swizzle_size < 1:
raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}")
self.problem_shape_ntile_mnl = problem_shape_ntile_mnl
# cluster_shape_mnk is kept for reconstruction
self._cluster_shape_mnk = cluster_shape_mnk
self.cluster_shape_mn = cluster_shape_mnk[:2]
self.swizzle_size = swizzle_size
self._raster_along_m = raster_along_m
self._loc = loc
# Apply swizzle if swizzle_size > 1
if swizzle_size > 1:
problem_shape_ncluster_mnl = cute.round_up(
self.problem_layout_ncluster_mnl.shape,
(1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1),
)
if raster_along_m:
self.problem_layout_ncluster_mnl = cute.make_layout(
(
problem_shape_ncluster_mnl[0],
(swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size),
problem_shape_ncluster_mnl[2],
),
stride=(
swizzle_size,
(1, swizzle_size * problem_shape_ncluster_mnl[0]),
problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
),
loc=loc,
ip=ip,
)
else:
self.problem_layout_ncluster_mnl = cute.make_layout(
(
(swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size),
problem_shape_ncluster_mnl[1],
problem_shape_ncluster_mnl[2],
),
stride=(
(1, swizzle_size * problem_shape_ncluster_mnl[1]),
swizzle_size,
problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
),
loc=loc,
ip=ip,
)
# Create FastDivmod divisors (only when swizzle_size == 1 for correctness)
# FastDivmod assumes simple col-major/row-major layout, incompatible with swizzled layouts
if swizzle_size == 1:
problem_shape_ncluster_mnl = cute.ceil_div(
self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip
)
if raster_along_m:
self.problem_layout_ncluster_mnl = cute.make_layout(
problem_shape_ncluster_mnl,
stride=(
1,
problem_shape_ncluster_mnl[0],
problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
),
loc=loc,
ip=ip,
)
else:
self.problem_layout_ncluster_mnl = cute.make_layout(
problem_shape_ncluster_mnl,
stride=(
problem_shape_ncluster_mnl[1],
1,
problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
),
loc=loc,
ip=ip,
)
problem_layout_size = cute.size(self.problem_layout_ncluster_mnl, loc=loc, ip=ip)
cluster_count_m = self.problem_layout_ncluster_mnl.shape[0]
cluster_count_n = self.problem_layout_ncluster_mnl.shape[1]
# batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling)
self.batch_fdd = cute.fast_divmod_create_divisor(problem_layout_size, loc=loc, ip=ip)
# cluster_shape_m_fdd: Used to decode work_unit_id to cluster coordinates
self.cluster_shape_m_fdd = cute.fast_divmod_create_divisor(cluster_count_m, loc=loc, ip=ip)
# cluster_shape_n_fdd: Used for the second level decomposition
self.cluster_shape_n_fdd = cute.fast_divmod_create_divisor(cluster_count_n, loc=loc, ip=ip)
else:
# FastDivmod not applicable with swizzling, set to None
self.batch_fdd = None
self.cluster_shape_m_fdd = None
self.cluster_shape_n_fdd = None
def hooked_get_cluster_work_idx_with_fastdivmod(
self, current_work_linear_idx: Int32, *, loc=None, ip=None
) -> Tuple[Int32, Int32, Int32]:
work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd)
if self.params._raster_along_m:
# raster_along_m=True means column major (m is fastest)
# First, get cluster_m using cluster_shape_m_fdd
cluster_n_batch, cluster_m = divmod(work_unit_id, self.params.cluster_shape_m_fdd)
# Then decode cluster_n_batch to get cluster_n and batch_l using FastDivmod
batch_l, cluster_n = divmod(cluster_n_batch, self.params.cluster_shape_n_fdd)
else:
# raster_along_m=False means row major (n is fastest)
# First, get cluster_n using cluster_shape_n_fdd
cluster_m_batch, cluster_n = divmod(work_unit_id, self.params.cluster_shape_n_fdd)
# Then decode cluster_m_batch to get cluster_m and batch_l using FastDivmod
batch_l, cluster_m = divmod(cluster_m_batch, self.params.cluster_shape_m_fdd)
return (cluster_m, cluster_n, batch_l)
cutlass.utils.PersistentTileSchedulerParams.__init__ = hooked_PersistentTileSchedulerParams_init
cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmod = (
hooked_get_cluster_work_idx_with_fastdivmod
)
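# Standalone illustration (not part of this file), using plain integers in place of the
# FastDivmod divisors: the decode above walks the cluster grid column major (m fastest)
# when raster_along_m=True and row major (n fastest) otherwise.
def _decode_sketch(linear_idx, cluster_count_m, cluster_count_n, batch_count, raster_along_m):
    # batch_fdd in the real code corresponds to the full layout size used here
    _, work_unit_id = divmod(linear_idx, cluster_count_m * cluster_count_n * batch_count)
    if raster_along_m:
        cluster_n_batch, cluster_m = divmod(work_unit_id, cluster_count_m)
        batch_l, cluster_n = divmod(cluster_n_batch, cluster_count_n)
    else:
        cluster_m_batch, cluster_n = divmod(work_unit_id, cluster_count_n)
        batch_l, cluster_m = divmod(cluster_m_batch, cluster_count_m)
    return cluster_m, cluster_n, batch_l
# With a 3x2 cluster grid and a single batch:
# [_decode_sketch(i, 3, 2, 1, True)[:2] for i in range(4)]  -> [(0, 0), (1, 0), (2, 0), (0, 1)]
# [_decode_sketch(i, 3, 2, 1, False)[:2] for i in range(4)] -> [(0, 0), (0, 1), (1, 0), (1, 1)]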
class BlockScaledContiguousGatherGroupedGemmKernel:
"""This class implements contiguous grouped matrix multiplication with gather operation and SwiGLU fusion
for FC1 layer computation (C = up * silu(gate), where up/gate come from interleaved GEMM result).
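
For the fusion named in this docstring, a hedged PyTorch reference sketch of C = up * silu(gate); the pairwise column interleaving of gate/up in the raw GEMM output is an assumption for illustration, not taken from this diff:

import torch
import torch.nn.functional as F

def swiglu_fc1_reference(gemm_out: torch.Tensor) -> torch.Tensor:
    # gemm_out: [num_tokens, 2 * intermediate] with gate/up interleaved column pairs (assumed)
    gate = gemm_out[:, 0::2]
    up = gemm_out[:, 1::2]
    return up * F.silu(gate)
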
@@ -245,6 +384,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
cluster_shape_mn: Tuple[int, int],
vectorized_f32: bool,
topk: cutlass.Int64,
raster_along_m: bool = False,
):
"""Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with
gather operation and SwiGLU fusion.
@@ -289,6 +429,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
self.cluster_shape_mn = cluster_shape_mn
# K dimension is deferred in _setup_attributes
self.mma_tiler = (*mma_tiler_mn, 1)
self.raster_along_m = raster_along_m
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
@@ -743,7 +884,11 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
# Compute grid size
self.tile_sched_params, grid = self._compute_grid(
- c, self.cta_tile_shape_mnk_c, self.cluster_shape_mn, max_active_clusters
+ c,
+ self.cta_tile_shape_mnk_c,
+ self.cluster_shape_mn,
+ max_active_clusters,
+ self.raster_along_m,
)
self.buffer_align_bytes = 1024
@@ -1254,34 +1399,69 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
pipeline.PipelineUserType.Producer, self.num_tile_stage
)
- while work_tile.is_valid_tile:
- cur_tile_coord = work_tile.tile_idx
- mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
- if mma_tile_coord_m < num_non_exiting_tiles[0]:
- tile_info_pipeline.producer_acquire(tile_info_producer_state)
+ num_non_exiting_tiles_value = num_non_exiting_tiles[0]
+ if cutlass.const_expr(self.raster_along_m):
+ while work_tile.is_valid_tile:
+ cur_tile_coord = work_tile.tile_idx
- expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
- mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
- with cute.arch.elect_one():
- sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
- sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
- sInfo[(2, tile_info_producer_state.index)] = expert_idx
- sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
- work_tile.is_valid_tile
+ mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+ if mma_tile_coord_m < num_non_exiting_tiles_value:
+ tile_info_pipeline.producer_acquire(tile_info_producer_state)
+ cur_tile_coord = work_tile.tile_idx
+ expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+ mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+ with cute.arch.elect_one():
+ sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+ sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+ sInfo[(2, tile_info_producer_state.index)] = expert_idx
+ sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+ work_tile.is_valid_tile
+ )
+ sInfo[(4, tile_info_producer_state.index)] = mn_limit
+ # fence view async shared
+ cute.arch.fence_proxy(
+ cute.arch.ProxyKind.async_shared,
+ space=cute.arch.SharedSpace.shared_cta,
+ )
- sInfo[(4, tile_info_producer_state.index)] = mn_limit
- # fence view async shared
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared,
- space=cute.arch.SharedSpace.shared_cta,
- )
+ self.sched_sync_barrier.arrive_and_wait()
+ tile_info_pipeline.producer_commit(tile_info_producer_state)
+ tile_info_producer_state.advance()
- self.sched_sync_barrier.arrive_and_wait()
- tile_info_pipeline.producer_commit(tile_info_producer_state)
- tile_info_producer_state.advance()
+ tile_sched.advance_to_next_work()
+ work_tile = tile_sched.get_current_work()
- tile_sched.advance_to_next_work()
- work_tile = tile_sched.get_current_work()
+ else:
+ is_continue = cutlass.Boolean(1)
+ while work_tile.is_valid_tile and is_continue:
+ cur_tile_coord = work_tile.tile_idx
+ mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+ if mma_tile_coord_m < num_non_exiting_tiles_value:
+ tile_info_pipeline.producer_acquire(tile_info_producer_state)
+ cur_tile_coord = work_tile.tile_idx
+ expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+ mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+ with cute.arch.elect_one():
+ sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+ sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+ sInfo[(2, tile_info_producer_state.index)] = expert_idx
+ sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+ work_tile.is_valid_tile
+ )
+ sInfo[(4, tile_info_producer_state.index)] = mn_limit
+ # fence view async shared
+ cute.arch.fence_proxy(
+ cute.arch.ProxyKind.async_shared,
+ space=cute.arch.SharedSpace.shared_cta,
+ )
+ self.sched_sync_barrier.arrive_and_wait()
+ tile_info_pipeline.producer_commit(tile_info_producer_state)
+ tile_info_producer_state.advance()
+ else:
+ is_continue = cutlass.Boolean(0)
+ tile_sched.advance_to_next_work()
+ work_tile = tile_sched.get_current_work()
tile_info_pipeline.producer_acquire(tile_info_producer_state)
with cute.arch.elect_one():
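
A standalone sketch of why the is_continue early exit only appears in the non-raster-along-M branch above: when n is fastest, every tile with mma_tile_coord_m below the limit sits in a contiguous prefix of the linear sweep, so the first out-of-range tile ends the scan; when m is fastest, in-range m values keep reappearing and the loop must visit every tile.

def linear_order(m_tiles, n_tiles, raster_along_m):
    if raster_along_m:  # m fastest
        return [(m, n) for n in range(n_tiles) for m in range(m_tiles)]
    return [(m, n) for m in range(m_tiles) for n in range(n_tiles)]

limit = 2  # tiles with m >= limit carry no work
print(linear_order(3, 2, raster_along_m=False))  # m < 2 forms a prefix: early exit is safe
print(linear_order(3, 2, raster_along_m=True))   # m == 2 interleaves with m < 2: must scan all
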
@@ -2781,6 +2961,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
cta_tile_shape_mnk: Tuple[int, int, int],
cluster_shape_mn: Tuple[int, int],
max_active_clusters: cutlass.Constexpr,
raster_along_m: bool = False,
) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
"""Use persistent tile scheduler to compute the grid size for the output tensor C.
@@ -2803,7 +2984,9 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
num_ctas_mnl = gc[(0, (None, None, None))].shape
cluster_shape_mnl = (*cluster_shape_mn, 1)
- tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl)
+ tile_sched_params = utils.PersistentTileSchedulerParams(
+ num_ctas_mnl, cluster_shape_mnl, raster_along_m=raster_along_m
+ )
grid = utils.StaticPersistentTileScheduler.get_grid_shape(
tile_sched_params, max_active_clusters
)
@@ -3209,3 +3392,33 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
stream=stream,
epilogue_op=epilogue_op,
)
@cute.jit
def cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
sf_ref_tensor: cute.Tensor,
sf_mma_tensor: cute.Tensor,
):
"""Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout"""
# sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l)
# group to ((32, 4, rest_m), (4, rest_k), l)
sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3)
sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3)
for i in cutlass.range(cute.size(sf_ref_tensor)):
mkl_coord = sf_ref_tensor.layout.get_hier_coord(i)
sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord]
@cute.jit
def cvt_sf_M32x4xrm_K4xrk_L_to_MKL(
sf_swizzled_tensor: cute.Tensor,
sf_unswizzled_tensor: cute.Tensor,
):
"""Convert scale factor tensor from mma specification M(32x4xrest_m)xK(4xrest_k)xL layout to MKL layout"""
# sf_swizzled_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l)
# group to ((32, 4, rest_m), (4, rest_k), l)
sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 0, 3)
sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 1, 3)
for i in cutlass.range(cute.size(sf_unswizzled_tensor)):
mkl_coord = sf_unswizzled_tensor.layout.get_hier_coord(i)
sf_unswizzled_tensor[mkl_coord] = sf_swizzled_tensor[mkl_coord]
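
A plain-integer sketch of the index bookkeeping behind these two helpers, assuming the colexicographic split of m over (32, 4, rest_m) and k over (4, rest_k) that the flattened (32, 4, rest_m, 4, rest_k, l) shape in the comments suggests:

def mkl_to_blocked(m, k, l):
    # m splits into (m % 32, (m // 32) % 4, m // 128); k splits into (k % 4, k // 4)
    return ((m % 32, (m // 32) % 4, m // 128), (k % 4, k // 4), l)

print(mkl_to_blocked(m=200, k=7, l=0))  # ((8, 2, 1), (3, 1), 0)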