diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index 011b41e33e..9dc70b88f5 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -2256,7 +2256,7 @@ if IS_CUTLASS_DSL_AVAILABLE: # Define candidates together mma_tiler_mn_candidates = [(128, 128), (128, 256)] - cluster_shape_mn_candidates = [(1, 1)] + cluster_shape_mn_candidates = [(1, 1), (1, 2), (1, 4)] # Map torch dtype to cutlass dtype c_cutlass_dtype = { @@ -2619,8 +2619,8 @@ if IS_CUTLASS_DSL_AVAILABLE: l = 1 # dense GEMM # Define candidates together - mma_tiler_mn_candidates = [(128, 128), (128, 256)] - cluster_shape_mn_candidates = [(1, 1)] + mma_tiler_mn_candidates = [(128, 64), (128, 128), (128, 256)] + cluster_shape_mn_candidates = [(1, 1), (1, 2), (1, 4)] # Map torch dtype to cutlass dtype c_cutlass_dtype = {