mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-24 04:33:04 +08:00
Merge branch 'user/xiweny/update_cutlass_4.2' into 'feat/b300_cu13'
update cutlass and DeepGEMM See merge request ftp/tekit!9678 Signed-off-by: Xiwen Yu <xiweny@nvidia.com>
This commit is contained in:
commit
9ad68de159
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -28,5 +28,4 @@
|
||||
url = https://github.com/zeromq/cppzmq.git
|
||||
[submodule "3rdparty/DeepGEMM"]
|
||||
path = 3rdparty/DeepGEMM
|
||||
url = https://github.com/VALLIS-NERIA/DeepGEMM.git
|
||||
branch = cu13_and_sm100f
|
||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
||||
|
||||
2
3rdparty/DeepGEMM
vendored
2
3rdparty/DeepGEMM
vendored
@ -1 +1 @@
|
||||
Subproject commit 4a55b52e0d0ae99a9a646f66bd42c22dae059547
|
||||
Subproject commit 89b4089d24216c107f8f805d931a817abb241850
|
||||
2
3rdparty/cutlass
vendored
2
3rdparty/cutlass
vendored
@ -1 +1 @@
|
||||
Subproject commit a1aaf2300a8fc3a8106a05436e1a2abad0930443
|
||||
Subproject commit a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
|
||||
@ -396,8 +396,8 @@ using SafeBF16 = void;
|
||||
\
|
||||
/* TRT-LLM uses vector size 16 for block scaled */ \
|
||||
using KernelScheduleSM103 = std::conditional_t<Is2SM, \
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaled3xOmmaVs16Sm103, \
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaled3xOmmaVs16Sm103>; \
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103, \
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103>; \
|
||||
\
|
||||
using KernelScheduleSM100 = std::conditional_t<Is2SM, \
|
||||
std::conditional_t<IsBlockScaled, KernelSchedule2SmSm100BlockScaled, \
|
||||
|
||||
@ -24,18 +24,6 @@
|
||||
#include <cuda_fp4.h>
|
||||
#endif
|
||||
|
||||
#if !defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
|
||||
namespace cutlass::arch
|
||||
{
|
||||
using Sm103 = Sm100;
|
||||
}
|
||||
|
||||
namespace cutlass::gemm
|
||||
{
|
||||
using KernelPtrArrayTmaWarpSpecialized1SmBlockScaled3xOmmaVs16Sm103 = void;
|
||||
using KernelPtrArrayTmaWarpSpecialized2SmBlockScaled3xOmmaVs16Sm103 = void;
|
||||
} // namespace cutlass::gemm
|
||||
#endif
|
||||
namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
{
|
||||
|
||||
|
||||
@ -3,6 +3,19 @@ import enum
|
||||
import os
|
||||
from itertools import chain, product
|
||||
|
||||
file_to_patch = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"../../../../../3rdparty/cutlass/python/cutlass_library/heuristics_provider.py"
|
||||
))
|
||||
# replace "from library import" to "from cutlass_library.library import"
|
||||
with open(file_to_patch, "r") as f:
|
||||
file_contents = f.read()
|
||||
with open(file_to_patch, "w") as f:
|
||||
f.write(
|
||||
file_contents.replace("from library import",
|
||||
"from cutlass_library.library import"))
|
||||
|
||||
from cutlass_library import *
|
||||
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ def apply_rope_with_input_pos_flashinfer(
|
||||
k_shape = k.shape
|
||||
head_dim = cos_sin_cache.shape[-1]
|
||||
|
||||
position_ids = position_ids.view(-1).to(q.device)
|
||||
position_ids = position_ids.view(-1).to(q.device).int() # flashinfer requires int
|
||||
num_nnz = position_ids.shape[0]
|
||||
|
||||
q_flat = q.view(num_nnz, -1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user