Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
sync internal cutlass kernel changes (#3968)
Signed-off-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com>
parent 99929e724b
commit f98a80f9d9
@@ -393,7 +393,7 @@ def is_grouped_gemm_op_valid(op):
 
 
 def is_op_valid(op):
-    if op.arch >= 100 and op.arch < 120:
+    if op.arch >= 100:
         return is_gemm_op_valid_sm100(op)
 
     if op.gemm_kind == GemmKind.Gemm:
@@ -526,6 +526,57 @@ def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype):
     return cta_shape_mn + (cta_shape_k, )
 
 
+def generate_sm120_grouped_gemm_operations(is_arch_enabled):
+
+    if not is_arch_enabled:
+        return []
+    arch = 120
+    supported_dtypes = [e2m1]
+    quant_ops = [TrtLlm_QuantOp.none]
+    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
+    cta_shapes_mnk = [[128, 128, 128], [128, 128, 256], [256, 128, 128],
+                      [128, 256, 128]]
+
+    warp_shape = [0, 0, 0]  # ignored except for naming
+    stages = 0  # auto
+
+    epi_fusions = [
+        TrtLlm_EpilogueFusion.epilogue_fusion_none,
+        # TrtLlm_EpilogueFusion.epilogue_fusion_finalize
+    ]
+
+    cga_shapes = [[1, 1, 1]]
+
+    partial_args = product(supported_dtypes, quant_ops, epi_tags, epi_fusions,
+                           cta_shapes_mnk, cga_shapes)
+
+    operations = list()
+    for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
+        cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)
+
+        # Ignored
+        mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
+        epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized
+
+        otypes = [dtype]
+        if dtype in [DataType.e4m3, e2m1]:
+            otypes = [DataType.f16, DataType.bf16]
+
+        for otype in otypes:
+            moe_gemm_operation = TrtLlm_GemmLauncher(
+                GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, otype,
+                quant_op, epi_tag, cga_tile_shape_mnk, warp_shape, stages,
+                cga_shape, mainloop_schedule, epi_schedule, epi_fusion)
+
+            operations.append(moe_gemm_operation)
+    return operations
+
+
+def generate_sm120_operations(is_arch_enabled):
+    operations = generate_sm120_grouped_gemm_operations(is_arch_enabled)
+    return operations
+
+
 def generate_sm100_grouped_gemm_operations(is_arch_enabled):
     if not is_arch_enabled:
         return []
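The new SM120 generator reuses two helpers that already appear elsewhere in this script: product (presumably itertools.product) to enumerate the configuration space, and elementwise to scale the CTA tile by the CGA shape. Below is a minimal, self-contained sketch of that combination, assuming elementwise simply applies a binary operator component-wise; the helper's real body is not part of this diff.

from itertools import product
from operator import mul

# Assumed stand-in for the script's elementwise helper: combine two
# equal-length shape lists with a binary operator, component by component.
def elementwise(a, b, op):
    return [op(x, y) for x, y in zip(a, b)]

# With the single CGA shape [1, 1, 1] used for SM120 above, the CGA tile
# shape comes out identical to the CTA tile shape.
for cta_shape_mnk, cga_shape in product([[128, 128, 128], [128, 128, 256],
                                         [256, 128, 128], [128, 256, 128]],
                                        [[1, 1, 1]]):
    print(elementwise(cta_shape_mnk, cga_shape, mul))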
@@ -655,6 +706,7 @@ if __name__ == "__main__":
         (GemmKind.Gemm, 90): [fpA_intB_inl],
         (GemmKind.Grouped, 90): [moe_gemm_inl],
         (GemmKind.Grouped, 100): [moe_gemm_inl],
+        (GemmKind.Grouped, 120): [moe_gemm_inl],
         (GemmKind.Grouped, 80): [sm80_moe_gemm_inl]
     }
 
@@ -664,10 +716,10 @@ if __name__ == "__main__":
     # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads.
     # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.
    operations = []
+    operations += generate_sm120_operations(has_arch(120))
     operations += generate_sm100_operations(has_arch(100))
     operations += generate_sm90_operations(has_arch(90))
-    operations += generate_sm80_operations(
-        has_arch(80) or has_arch(89) or has_arch(120))
+    operations += generate_sm80_operations(has_arch(80) or has_arch(89))
 
     def should_skip(op):
         is_internal = op.gemm_kind == GemmKind.Grouped
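The generator calls above are gated by has_arch, which is defined earlier in the script and not shown in this diff. A self-contained sketch, under the assumption that it simply reports whether a given SM version was requested when the script was invoked (the hard-coded set below is only a placeholder for whatever the real helper reads from the script's arguments):

# Assumed stand-in for has_arch: check a requested SM version against the
# set of architectures the script was asked to generate for.
requested_archs = {100, 120}

def has_arch(sm: int) -> bool:
    return sm in requested_archs

# With the change above, requesting SM120 enables the dedicated SM120
# grouped-GEMM generator and no longer pulls in the SM80 MoE kernels.
print(has_arch(120))                   # True
print(has_arch(80) or has_arch(89))    # False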