Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

commit bb2f883296 (parent: bb6a3973aa)

[None] [feat] Add test script and raster M for gather fc1 kernel (#10429)
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
@@ -1864,10 +1864,13 @@ if IS_CUTLASS_DSL_AVAILABLE:
             mma_tiler_mn_candidates = [(self.tile_size, 128),
                                        (self.tile_size, 256)]
             cluster_shape_mn_candidates = [(self.tile_size // 128, 1)]
+            # TODO: Add raster_along_m=True if we find it more performant in some cases.
+            raster_along_m_candidates = [False]
 
             valid_tactics = []
-            for mma_tiler_mn, cluster_shape_mn in itertools.product(
-                    mma_tiler_mn_candidates, cluster_shape_mn_candidates):
+            for mma_tiler_mn, cluster_shape_mn, raster_along_m in itertools.product(
+                    mma_tiler_mn_candidates, cluster_shape_mn_candidates,
+                    raster_along_m_candidates):
                 if self.__class__.kernel_class.can_implement(
                         ab_dtype=cutlass.Float4E2M1FN,
                         sf_dtype=cutlass.Float8E4M3FN,
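For illustration only (not part of the diff): a minimal, self-contained sketch of how the tactic space above expands once raster_along_m_candidates joins the product. The candidate lists are copied from the hunk; tile_size = 128 is an assumed example value standing in for self.tile_size.

import itertools

tile_size = 128  # assumed example value; the runner uses self.tile_size

mma_tiler_mn_candidates = [(tile_size, 128), (tile_size, 256)]
cluster_shape_mn_candidates = [(tile_size // 128, 1)]
raster_along_m_candidates = [False]  # the TODO above may later add True

# Each tactic is now a 3-tuple: (mma_tiler_mn, cluster_shape_mn, raster_along_m).
tactics = list(
    itertools.product(mma_tiler_mn_candidates, cluster_shape_mn_candidates,
                      raster_along_m_candidates))
print(tactics)
# [((128, 128), (1, 1), False), ((128, 256), (1, 1), False)]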
@@ -1883,7 +1886,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
                         b_major="k",
                         c_major="n",
                 ):
-                    valid_tactics.append((mma_tiler_mn, cluster_shape_mn))
+                    valid_tactics.append(
+                        (mma_tiler_mn, cluster_shape_mn, raster_along_m))
 
             return valid_tactics
 
@@ -2013,15 +2017,16 @@ if IS_CUTLASS_DSL_AVAILABLE:
             stream = cuda.CUstream(torch_stream.cuda_stream)
 
             if isinstance(tactic, tuple):
-                mma_tiler_mn, cluster_shape_mn = tactic
+                mma_tiler_mn, cluster_shape_mn, raster_along_m = tactic
             else:
                 mma_tiler_mn = (self.tile_size, 128)
                 cluster_shape_mn = (self.tile_size // 128, 1)
+                raster_along_m = False
             assert mma_tiler_mn[
                 0] == self.tile_size, f"Tactic ({tactic}) is incompatible with tile size ({self.tile_size})"
 
             cache_key = (self.scaling_vector_size, self.tile_size, self.top_k,
-                         mma_tiler_mn, cluster_shape_mn)
+                         mma_tiler_mn, cluster_shape_mn, raster_along_m)
             if cache_key not in self.__class__.kernel_cache:
                 gemm = self.__class__.kernel_class(
                     sf_vec_size=self.scaling_vector_size,
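For illustration only (not part of the diff): a minimal sketch of the caching pattern this hunk extends. The tactic tuple is unpacked with a default fallback, and the compiled kernel is looked up by a config tuple that now also carries raster_along_m. The names below (get_or_build, kernel_cache) are illustrative stand-ins, not the module's API.

kernel_cache = {}

def get_or_build(scaling_vector_size, tile_size, top_k, tactic=None):
    if isinstance(tactic, tuple):
        mma_tiler_mn, cluster_shape_mn, raster_along_m = tactic
    else:
        # Fallback default, mirroring the diff.
        mma_tiler_mn = (tile_size, 128)
        cluster_shape_mn = (tile_size // 128, 1)
        raster_along_m = False

    cache_key = (scaling_vector_size, tile_size, top_k,
                 mma_tiler_mn, cluster_shape_mn, raster_along_m)
    if cache_key not in kernel_cache:
        kernel_cache[cache_key] = ("compiled kernel for", cache_key)  # stand-in
    return kernel_cache[cache_key]

Folding raster_along_m into the key matters because each flag value corresponds to a differently specialized compiled kernel.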
@@ -2029,6 +2034,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
                     cluster_shape_mn=cluster_shape_mn,
                     vectorized_f32=True,
                     topk=self.top_k,
+                    raster_along_m=raster_along_m,
                 )
                 # Compute max active clusters on current device
                 hardware_info = cutlass.utils.HardwareInfo()
@@ -37,6 +37,7 @@ import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass._mlir.dialects import math
 from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass.cutlass_dsl import Int32
 
 from .custom_pipeline import PipelineCpAsyncUmma
 from .utils import (
@@ -154,6 +155,144 @@ CUDA Graph Support:
 """
 
 
+# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.4 is released.
+def hooked_PersistentTileSchedulerParams_init(
+    self,
+    problem_shape_ntile_mnl: cute.Shape,
+    cluster_shape_mnk: cute.Shape,
+    swizzle_size: int = 1,
+    raster_along_m: bool = True,
+    *,
+    loc=None,
+    ip=None,
+):
+    if cluster_shape_mnk[2] != 1:
+        raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}")
+    if swizzle_size < 1:
+        raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}")
+
+    self.problem_shape_ntile_mnl = problem_shape_ntile_mnl
+    # cluster_shape_mnk is kept for reconstruction
+    self._cluster_shape_mnk = cluster_shape_mnk
+    self.cluster_shape_mn = cluster_shape_mnk[:2]
+    self.swizzle_size = swizzle_size
+    self._raster_along_m = raster_along_m
+    self._loc = loc
+
+    # Apply swizzle if swizzle_size > 1
+    if swizzle_size > 1:
+        problem_shape_ncluster_mnl = cute.round_up(
+            self.problem_layout_ncluster_mnl.shape,
+            (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1),
+        )
+
+        if raster_along_m:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                (
+                    problem_shape_ncluster_mnl[0],
+                    (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size),
+                    problem_shape_ncluster_mnl[2],
+                ),
+                stride=(
+                    swizzle_size,
+                    (1, swizzle_size * problem_shape_ncluster_mnl[0]),
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+        else:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                (
+                    (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size),
+                    problem_shape_ncluster_mnl[1],
+                    problem_shape_ncluster_mnl[2],
+                ),
+                stride=(
+                    (1, swizzle_size * problem_shape_ncluster_mnl[1]),
+                    swizzle_size,
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+
+    # Create FastDivmod divisors (only when swizzle_size == 1 for correctness)
+    # FastDivmod assumes simple col-major/row-major layout, incompatible with swizzled layouts
+    if swizzle_size == 1:
+        problem_shape_ncluster_mnl = cute.ceil_div(
+            self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip
+        )
+        if raster_along_m:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                problem_shape_ncluster_mnl,
+                stride=(
+                    1,
+                    problem_shape_ncluster_mnl[0],
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+        else:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                problem_shape_ncluster_mnl,
+                stride=(
+                    problem_shape_ncluster_mnl[1],
+                    1,
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+        problem_layout_size = cute.size(self.problem_layout_ncluster_mnl, loc=loc, ip=ip)
+        cluster_count_m = self.problem_layout_ncluster_mnl.shape[0]
+        cluster_count_n = self.problem_layout_ncluster_mnl.shape[1]
+
+        # batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling)
+        self.batch_fdd = cute.fast_divmod_create_divisor(problem_layout_size, loc=loc, ip=ip)
+
+        # cluster_shape_m_fdd: Used to decode work_unit_id to cluster coordinates
+        self.cluster_shape_m_fdd = cute.fast_divmod_create_divisor(cluster_count_m, loc=loc, ip=ip)
+
+        # cluster_shape_n_fdd: Used for the second level decomposition
+        self.cluster_shape_n_fdd = cute.fast_divmod_create_divisor(cluster_count_n, loc=loc, ip=ip)
+    else:
+        # FastDivmod not applicable with swizzling, set to None
+        self.batch_fdd = None
+        self.cluster_shape_m_fdd = None
+        self.cluster_shape_n_fdd = None
+
+
+def hooked_get_cluster_work_idx_with_fastdivmod(
+    self, current_work_linear_idx: Int32, *, loc=None, ip=None
+) -> Tuple[Int32, Int32, Int32]:
+    work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd)
+
+    if self.params._raster_along_m:
+        # raster_along_m=True means column major (m is fastest)
+        # First, get cluster_m using cluster_shape_m_fdd
+        cluster_n_batch, cluster_m = divmod(work_unit_id, self.params.cluster_shape_m_fdd)
+
+        # Then decode cluster_n_batch to get cluster_n and batch_l using FastDivmod
+        batch_l, cluster_n = divmod(cluster_n_batch, self.params.cluster_shape_n_fdd)
+    else:
+        # raster_along_m=False means row major (n is fastest)
+        # First, get cluster_n using cluster_shape_n_fdd
+        cluster_m_batch, cluster_n = divmod(work_unit_id, self.params.cluster_shape_n_fdd)
+
+        # Then decode cluster_m_batch to get cluster_m and batch_l using FastDivmod
+        batch_l, cluster_m = divmod(cluster_m_batch, self.params.cluster_shape_m_fdd)
+
+    return (cluster_m, cluster_n, batch_l)
+
+
+cutlass.utils.PersistentTileSchedulerParams.__init__ = hooked_PersistentTileSchedulerParams_init
+cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmod = (
+    hooked_get_cluster_work_idx_with_fastdivmod
+)
+
+
 class BlockScaledContiguousGatherGroupedGemmKernel:
     """This class implements contiguous grouped matrix multiplication with gather operation and SwiGLU fusion
     for FC1 layer computation (C = up * silu(gate), where up/gate come from interleaved GEMM result).
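For illustration only (not part of the diff): a pure-Python sketch, with no CuTe dependency, of the tile-to-cluster rasterization order that the hooked scheduler encodes through its layout strides and FastDivmod divisors in the swizzle_size == 1 case. raster_along_m=True makes m the fastest-varying cluster coordinate (column major), raster_along_m=False makes n fastest (row major). The batch/L handling is simplified to a per-batch modulo here.

def decode_cluster_idx(linear_idx, cluster_count_m, cluster_count_n, raster_along_m):
    """Mimics the two-level divmod decode above (pure Python, swizzle_size == 1)."""
    num_clusters = cluster_count_m * cluster_count_n  # per batch; batch_fdd also folds in L
    work_unit_id = linear_idx % num_clusters
    if raster_along_m:
        # column major: m is fastest
        cluster_n, cluster_m = divmod(work_unit_id, cluster_count_m)
    else:
        # row major: n is fastest
        cluster_m, cluster_n = divmod(work_unit_id, cluster_count_n)
    return cluster_m, cluster_n

# With 2 x 3 clusters, the first few work units visit:
#   raster_along_m=True : (0,0) (1,0) (0,1) (1,1) (0,2) (1,2)   -> m sweeps first
#   raster_along_m=False: (0,0) (0,1) (0,2) (1,0) (1,1) (1,2)   -> n sweeps first
for order in (True, False):
    print([decode_cluster_idx(i, 2, 3, order) for i in range(6)])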
@@ -245,6 +384,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
         cluster_shape_mn: Tuple[int, int],
         vectorized_f32: bool,
         topk: cutlass.Int64,
+        raster_along_m: bool = False,
     ):
         """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with
         gather operation and SwiGLU fusion.
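For illustration only (not part of the diff): a small NumPy sketch of the epilogue math named in the class docstring, C = up * silu(gate), where the GEMM result interleaves up and gate along N. The even/odd column interleaving below is an assumption made for the sketch; the kernel's actual interleaving follows its own tile layout.

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

# Assumed illustration: even columns hold "up", odd columns hold "gate".
gemm_out = np.random.randn(4, 8).astype(np.float32)   # (tokens, 2 * intermediate)
up, gate = gemm_out[:, 0::2], gemm_out[:, 1::2]
fc1_out = up * silu(gate)                              # (tokens, intermediate)
print(fc1_out.shape)                                   # (4, 4)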
@@ -289,6 +429,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
         self.cluster_shape_mn = cluster_shape_mn
         # K dimension is deferred in _setup_attributes
         self.mma_tiler = (*mma_tiler_mn, 1)
+        self.raster_along_m = raster_along_m
 
         self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
 
@@ -743,7 +884,11 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
 
         # Compute grid size
         self.tile_sched_params, grid = self._compute_grid(
-            c, self.cta_tile_shape_mnk_c, self.cluster_shape_mn, max_active_clusters
+            c,
+            self.cta_tile_shape_mnk_c,
+            self.cluster_shape_mn,
+            max_active_clusters,
+            self.raster_along_m,
         )
 
         self.buffer_align_bytes = 1024
@@ -1254,34 +1399,69 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
             pipeline.PipelineUserType.Producer, self.num_tile_stage
         )
 
-        while work_tile.is_valid_tile:
-            cur_tile_coord = work_tile.tile_idx
-            mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
-            if mma_tile_coord_m < num_non_exiting_tiles[0]:
-                tile_info_pipeline.producer_acquire(tile_info_producer_state)
-                expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
-                mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
-                with cute.arch.elect_one():
-                    sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
-                    sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
-                    sInfo[(2, tile_info_producer_state.index)] = expert_idx
-                    sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
-                        work_tile.is_valid_tile
-                    )
-                    sInfo[(4, tile_info_producer_state.index)] = mn_limit
-                    # fence view async shared
-                    cute.arch.fence_proxy(
-                        cute.arch.ProxyKind.async_shared,
-                        space=cute.arch.SharedSpace.shared_cta,
-                    )
-
-                self.sched_sync_barrier.arrive_and_wait()
-                tile_info_pipeline.producer_commit(tile_info_producer_state)
-                tile_info_producer_state.advance()
-
-            tile_sched.advance_to_next_work()
-            work_tile = tile_sched.get_current_work()
+        num_non_exiting_tiles_value = num_non_exiting_tiles[0]
+
+        if cutlass.const_expr(self.raster_along_m):
+            while work_tile.is_valid_tile:
+                cur_tile_coord = work_tile.tile_idx
+                mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+                if mma_tile_coord_m < num_non_exiting_tiles_value:
+                    tile_info_pipeline.producer_acquire(tile_info_producer_state)
+                    cur_tile_coord = work_tile.tile_idx
+                    expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+                    mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+                    with cute.arch.elect_one():
+                        sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+                        sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+                        sInfo[(2, tile_info_producer_state.index)] = expert_idx
+                        sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+                            work_tile.is_valid_tile
+                        )
+                        sInfo[(4, tile_info_producer_state.index)] = mn_limit
+                        # fence view async shared
+                        cute.arch.fence_proxy(
+                            cute.arch.ProxyKind.async_shared,
+                            space=cute.arch.SharedSpace.shared_cta,
+                        )
+
+                    self.sched_sync_barrier.arrive_and_wait()
+                    tile_info_pipeline.producer_commit(tile_info_producer_state)
+                    tile_info_producer_state.advance()
+
+                tile_sched.advance_to_next_work()
+                work_tile = tile_sched.get_current_work()
+        else:
+            is_continue = cutlass.Boolean(1)
+            while work_tile.is_valid_tile and is_continue:
+                cur_tile_coord = work_tile.tile_idx
+                mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+                if mma_tile_coord_m < num_non_exiting_tiles_value:
+                    tile_info_pipeline.producer_acquire(tile_info_producer_state)
+                    cur_tile_coord = work_tile.tile_idx
+                    expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+                    mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+                    with cute.arch.elect_one():
+                        sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+                        sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+                        sInfo[(2, tile_info_producer_state.index)] = expert_idx
+                        sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+                            work_tile.is_valid_tile
+                        )
+                        sInfo[(4, tile_info_producer_state.index)] = mn_limit
+                        # fence view async shared
+                        cute.arch.fence_proxy(
+                            cute.arch.ProxyKind.async_shared,
+                            space=cute.arch.SharedSpace.shared_cta,
+                        )
+
+                    self.sched_sync_barrier.arrive_and_wait()
+                    tile_info_pipeline.producer_commit(tile_info_producer_state)
+                    tile_info_producer_state.advance()
+                else:
+                    is_continue = cutlass.Boolean(0)
+
+                tile_sched.advance_to_next_work()
+                work_tile = tile_sched.get_current_work()
 
         tile_info_pipeline.producer_acquire(tile_info_producer_state)
         with cute.arch.elect_one():
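For illustration only (not part of the diff): a pure-Python sketch of why the producer can exit early only in the raster_along_m=False branch. With n fastest (row-major rasterization) the tile's m coordinate is non-decreasing over the linear work order, so once mma_tile_coord_m reaches the non-exiting-tile count every later tile is also out of range; with m fastest the m coordinate keeps wrapping, so the loop must visit every tile. The tile counts below are made-up example values, and persistent-CTA striding is ignored.

def visited_tiles(tiles_m, tiles_n, num_non_exiting_tiles, raster_along_m):
    """Count loop iterations the producer performs before it can stop."""
    visited = 0
    for linear_idx in range(tiles_m * tiles_n):
        if raster_along_m:                      # m fastest
            m = linear_idx % tiles_m
        else:                                   # n fastest -> m is non-decreasing
            m = linear_idx // tiles_n
        if m >= num_non_exiting_tiles and not raster_along_m:
            break                               # early exit, like is_continue = 0
        visited += 1
    return visited

# 8 x 4 tiles, only the first 3 rows of tiles carry real tokens:
print(visited_tiles(8, 4, 3, raster_along_m=True))   # 32 -> must scan everything
print(visited_tiles(8, 4, 3, raster_along_m=False))  # 12 -> stops after 3 rows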
@@ -2781,6 +2961,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
         cta_tile_shape_mnk: Tuple[int, int, int],
         cluster_shape_mn: Tuple[int, int],
         max_active_clusters: cutlass.Constexpr,
+        raster_along_m: bool = False,
     ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
         """Use persistent tile scheduler to compute the grid size for the output tensor C.
 
@@ -2803,7 +2984,9 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
         num_ctas_mnl = gc[(0, (None, None, None))].shape
         cluster_shape_mnl = (*cluster_shape_mn, 1)
 
-        tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl)
+        tile_sched_params = utils.PersistentTileSchedulerParams(
+            num_ctas_mnl, cluster_shape_mnl, raster_along_m=raster_along_m
+        )
         grid = utils.StaticPersistentTileScheduler.get_grid_shape(
             tile_sched_params, max_active_clusters
         )
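For illustration only (not part of the diff): a generic pure-Python sketch of how a persistent tile scheduler typically sizes the grid from the quantities passed above, i.e. it launches at most max_active_clusters clusters and lets each loop over several work tiles. This is not the CuTe StaticPersistentTileScheduler.get_grid_shape implementation, and the numbers are example placeholders.

import math

def persistent_grid(num_tiles_mnl, cluster_shape_mn, max_active_clusters):
    """Rough sketch of persistent-scheduler grid sizing (not the CuTe implementation)."""
    clusters_m = math.ceil(num_tiles_mnl[0] / cluster_shape_mn[0])
    clusters_n = math.ceil(num_tiles_mnl[1] / cluster_shape_mn[1])
    total_clusters = clusters_m * clusters_n * num_tiles_mnl[2]
    # Launch at most max_active_clusters clusters; each loops over several work tiles.
    launched_clusters = min(total_clusters, max_active_clusters)
    return launched_clusters * cluster_shape_mn[0] * cluster_shape_mn[1]

print(persistent_grid((64, 8, 1), (2, 1), max_active_clusters=148))  # 296 CTAs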
@@ -3209,3 +3392,33 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
             stream=stream,
             epilogue_op=epilogue_op,
         )
+
+
+@cute.jit
+def cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
+    sf_ref_tensor: cute.Tensor,
+    sf_mma_tensor: cute.Tensor,
+):
+    """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout"""
+    # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l)
+    # group to ((32, 4, rest_m), (4, rest_k), l)
+    sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3)
+    sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3)
+    for i in cutlass.range(cute.size(sf_ref_tensor)):
+        mkl_coord = sf_ref_tensor.layout.get_hier_coord(i)
+        sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord]
+
+
+@cute.jit
+def cvt_sf_M32x4xrm_K4xrk_L_to_MKL(
+    sf_swizzled_tensor: cute.Tensor,
+    sf_unswizzled_tensor: cute.Tensor,
+):
+    """Convert scale factor tensor from mma specification M(32x4xrest_m)xK(4xrest_k)xL layout to MKL layout"""
+    # sf_swizzled_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l)
+    # group to ((32, 4, rest_m), (4, rest_k), l)
+    sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 0, 3)
+    sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 1, 3)
+    for i in cutlass.range(cute.size(sf_unswizzled_tensor)):
+        mkl_coord = sf_unswizzled_tensor.layout.get_hier_coord(i)
+        sf_unswizzled_tensor[mkl_coord] = sf_swizzled_tensor[mkl_coord]
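For illustration only (not part of the diff): a pure-Python sketch of the coordinate factoring these scale-factor converters rely on. Grouping the mma-layout tensor to the ((32, 4, rest_m), (4, rest_k), l) profile lets a plain (m, k, l) coordinate address it, because m de-linearizes into (m % 32, (m // 32) % 4, m // 128) and k into (k % 4, k // 4), leftmost mode fastest. Only the logical mapping is shown; the physical strides of the swizzled layout are not asserted here.

def mkl_to_hier(m, k, l):
    """Factor an (m, k, l) scale-factor coordinate into the ((32, 4, rest_m), (4, rest_k), l) profile."""
    m32, m4, rest_m = m % 32, (m // 32) % 4, m // 128
    k4, rest_k = k % 4, k // 4
    return ((m32, m4, rest_m), (k4, rest_k), l)

print(mkl_to_hier(200, 9, 0))  # ((8, 2, 1), (1, 2), 0)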
File diff suppressed because it is too large