sync internal cutlass kernel changes (#3968)

Signed-off-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com>
Pamela Peng 2025-04-29 20:57:28 -04:00 committed by GitHub
parent 99929e724b
commit f98a80f9d9


@@ -393,7 +393,7 @@ def is_grouped_gemm_op_valid(op):
 def is_op_valid(op):
-    if op.arch >= 100 and op.arch < 120:
+    if op.arch >= 100:
         return is_gemm_op_valid_sm100(op)
     if op.gemm_kind == GemmKind.Gemm:
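
The hunk above only touches the dispatch in is_op_valid: ops targeting Blackwell-class architectures (arch >= 100, which with this change covers the SM120 grouped GEMM ops added below) are validated by the SM100 path rather than the per-gemm-kind validators. A minimal sketch of that dispatch shape; the Op fields and validator bodies are placeholders, not the script's real checks:

    # Sketch only: mirrors the arch-first dispatch of is_op_valid.
    from dataclasses import dataclass

    @dataclass
    class Op:
        arch: int
        gemm_kind: str  # "gemm" or "grouped" in this sketch

    def is_gemm_op_valid_sm100(op):
        return True  # placeholder for the real SM100+ shape/dtype checks

    def is_grouped_gemm_op_valid(op):
        return True  # placeholder for the pre-SM100 grouped GEMM checks

    def is_op_valid(op):
        if op.arch >= 100:  # SM100 and SM120 ops both take this path
            return is_gemm_op_valid_sm100(op)
        if op.gemm_kind == "grouped":
            return is_grouped_gemm_op_valid(op)
        return True

    assert is_op_valid(Op(arch=120, gemm_kind="grouped"))
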
@@ -526,6 +526,57 @@ def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype):
     return cta_shape_mn + (cta_shape_k, )
+
+
+def generate_sm120_grouped_gemm_operations(is_arch_enabled):
+    if not is_arch_enabled:
+        return []
+    arch = 120
+    supported_dtypes = [e2m1]
+    quant_ops = [TrtLlm_QuantOp.none]
+    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
+    cta_shapes_mnk = [[128, 128, 128], [128, 128, 256], [256, 128, 128],
+                      [128, 256, 128]]
+    warp_shape = [0, 0, 0]  # ignored except for naming
+    stages = 0  # auto
+    epi_fusions = [
+        TrtLlm_EpilogueFusion.epilogue_fusion_none,
+        # TrtLlm_EpilogueFusion.epilogue_fusion_finalize
+    ]
+    cga_shapes = [[1, 1, 1]]
+
+    partial_args = product(supported_dtypes, quant_ops, epi_tags, epi_fusions,
+                           cta_shapes_mnk, cga_shapes)
+
+    operations = list()
+    for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
+        cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)
+
+        # Ignored
+        mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
+        epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized
+
+        otypes = [dtype]
+        if dtype in [DataType.e4m3, e2m1]:
+            otypes = [DataType.f16, DataType.bf16]
+
+        for otype in otypes:
+            moe_gemm_operation = TrtLlm_GemmLauncher(
+                GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, otype,
+                quant_op, epi_tag, cga_tile_shape_mnk, warp_shape, stages,
+                cga_shape, mainloop_schedule, epi_schedule, epi_fusion)
+            operations.append(moe_gemm_operation)
+    return operations
+
+
+def generate_sm120_operations(is_arch_enabled):
+    operations = generate_sm120_grouped_gemm_operations(is_arch_enabled)
+    return operations
+
+
 def generate_sm100_grouped_gemm_operations(is_arch_enabled):
     if not is_arch_enabled:
         return []
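
generate_sm120_grouped_gemm_operations enumerates the cross product of its parameter lists: one dtype (e2m1), one quant op, one epilogue tag, one epilogue fusion, four CTA shapes, and one CGA shape, with each e2m1 input expanded to f16 and bf16 outputs, giving 4 x 2 = 8 grouped GEMM launchers. A rough standalone sketch of that enumeration, with plain strings and tuples standing in for the script's enum types and TrtLlm_GemmLauncher:

    # Sketch: reproduces the SM120 enumeration count with plain Python types.
    from itertools import product

    supported_dtypes = ["e2m1"]
    cta_shapes_mnk = [(128, 128, 128), (128, 128, 256), (256, 128, 128),
                      (128, 256, 128)]
    cga_shapes = [(1, 1, 1)]

    ops = []
    for dtype, cta, cga in product(supported_dtypes, cta_shapes_mnk, cga_shapes):
        # CGA tile shape = elementwise product of CTA shape and CGA shape.
        cga_tile = tuple(c * g for c, g in zip(cta, cga))
        # e2m1 (like e4m3) inputs emit both f16 and bf16 output variants.
        for otype in ("f16", "bf16"):
            ops.append(("Grouped", 120, dtype, otype, cga_tile))

    assert len(ops) == 8  # 4 CTA shapes x 2 output dtypes
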
@@ -655,6 +706,7 @@ if __name__ == "__main__":
         (GemmKind.Gemm, 90): [fpA_intB_inl],
         (GemmKind.Grouped, 90): [moe_gemm_inl],
         (GemmKind.Grouped, 100): [moe_gemm_inl],
+        (GemmKind.Grouped, 120): [moe_gemm_inl],
         (GemmKind.Grouped, 80): [sm80_moe_gemm_inl]
     }
@@ -664,10 +716,10 @@ if __name__ == "__main__":
     # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads.
     # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.
     operations = []
+    operations += generate_sm120_operations(has_arch(120))
     operations += generate_sm100_operations(has_arch(100))
     operations += generate_sm90_operations(has_arch(90))
-    operations += generate_sm80_operations(
-        has_arch(80) or has_arch(89) or has_arch(120))
+    operations += generate_sm80_operations(has_arch(80) or has_arch(89))
 
     def should_skip(op):
         is_internal = op.gemm_kind == GemmKind.Grouped
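
The final hunk wires the new generator into the top-level assembly: SM120 grouped GEMM kernels now come from generate_sm120_operations, so an SM120-only build no longer pulls in the SM80 instantiations (the has_arch(120) term is dropped from the generate_sm80_operations call). A hedged sketch of that assembly, with toy stand-ins for has_arch and the per-arch generators:

    # Sketch of the per-arch assembly after this change; the generators below
    # return labels instead of real TrtLlm_GemmLauncher objects.
    ENABLED_ARCHS = {80, 90, 100, 120}  # example build configuration

    def has_arch(arch):
        return arch in ENABLED_ARCHS

    def generate_sm120_operations(enabled):
        return ["sm120 grouped gemm"] if enabled else []

    def generate_sm100_operations(enabled):
        return ["sm100 ops"] if enabled else []

    def generate_sm90_operations(enabled):
        return ["sm90 ops"] if enabled else []

    def generate_sm80_operations(enabled):
        return ["sm80 ops"] if enabled else []

    operations = []
    operations += generate_sm120_operations(has_arch(120))
    operations += generate_sm100_operations(has_arch(100))
    operations += generate_sm90_operations(has_arch(90))
    # SM89 still triggers the SM80 instantiations; SM120 no longer does.
    operations += generate_sm80_operations(has_arch(80) or has_arch(89))
    print(operations)
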