From b1c6f6a568423e3c503cda0c7be16580fbeb8007 Mon Sep 17 00:00:00 2001 From: Xiwen Yu Date: Tue, 26 Aug 2025 22:43:39 -0700 Subject: [PATCH] update cutlass and DeepGEMM Signed-off-by: Xiwen Yu --- .gitmodules | 3 +-- 3rdparty/DeepGEMM | 2 +- 3rdparty/cutlass | 2 +- .../moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl | 4 ++-- .../moe_gemm/moe_tma_warp_specialized_traits.h | 12 ------------ .../cutlass_kernels/python/generate_kernels.py | 13 +++++++++++++ .../auto_deploy/custom_ops/flashinfer_rope.py | 2 +- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/.gitmodules b/.gitmodules index 6eca578800..45d99f8fe4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -28,5 +28,4 @@ url = https://github.com/zeromq/cppzmq.git [submodule "3rdparty/DeepGEMM"] path = 3rdparty/DeepGEMM - url = https://github.com/VALLIS-NERIA/DeepGEMM.git - branch = cu13_and_sm100f + url = https://github.com/deepseek-ai/DeepGEMM.git diff --git a/3rdparty/DeepGEMM b/3rdparty/DeepGEMM index 4a55b52e0d..89b4089d24 160000 --- a/3rdparty/DeepGEMM +++ b/3rdparty/DeepGEMM @@ -1 +1 @@ -Subproject commit 4a55b52e0d0ae99a9a646f66bd42c22dae059547 +Subproject commit 89b4089d24216c107f8f805d931a817abb241850 diff --git a/3rdparty/cutlass b/3rdparty/cutlass index a1aaf2300a..a49a78ffef 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit a1aaf2300a8fc3a8106a05436e1a2abad0930443 +Subproject commit a49a78ffefc86a87160dfe0ccc3a3a2d1622c918 diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index e96132773b..cb1f84b6bc 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -396,8 +396,8 @@ using SafeBF16 = void; \ /* TRT-LLM uses vector size 16 for block scaled */ \ using KernelScheduleSM103 = std::conditional_t; \ + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103, \ + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103>; \ \ using KernelScheduleSM100 = std::conditional_t #endif -#if !defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) -namespace cutlass::arch -{ -using Sm103 = Sm100; -} - -namespace cutlass::gemm -{ -using KernelPtrArrayTmaWarpSpecialized1SmBlockScaled3xOmmaVs16Sm103 = void; -using KernelPtrArrayTmaWarpSpecialized2SmBlockScaled3xOmmaVs16Sm103 = void; -} // namespace cutlass::gemm -#endif namespace tensorrt_llm::kernels::cutlass_kernels { diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py index 05e4bd33e9..371c1cdbfc 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py @@ -3,6 +3,19 @@ import enum import os from itertools import chain, product +file_to_patch = os.path.abspath( + os.path.join( + os.path.dirname(__file__), + "../../../../../3rdparty/cutlass/python/cutlass_library/heuristics_provider.py" + )) +# replace "from library import" to "from cutlass_library.library import" +with open(file_to_patch, "r") as f: + file_contents = f.read() +with open(file_to_patch, "w") as f: + f.write( + file_contents.replace("from library import", + "from cutlass_library.library import")) + from cutlass_library import * diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py index dd65701ec4..e4f329eeec 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py @@ -38,7 +38,7 @@ def apply_rope_with_input_pos_flashinfer( k_shape = k.shape head_dim = cos_sin_cache.shape[-1] - position_ids = position_ids.view(-1).to(q.device) + position_ids = position_ids.view(-1).to(q.device).int() # flashinfer requires int num_nnz = position_ids.shape[0] q_flat = q.view(num_nnz, -1)