[https://nvbugs/5433581][fix] DeepGEMM installation on SBSA (#6588)

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Zongfei Jing authored 2025-08-06 16:44:21 +08:00, committed by GitHub
parent 907c180eb2
commit 0ff8df95b7
13 changed files with 176 additions and 9 deletions

.gitignore (vendored, 3 changes)

@@ -43,6 +43,9 @@ tensorrt_llm/bindings/**/*.pyi
 tensorrt_llm/deep_ep/
 tensorrt_llm/deep_ep_cpp_tllm.*.so
 tensorrt_llm/deep_ep_cpp_tllm.pyi
+tensorrt_llm/deep_gemm/
+tensorrt_llm/deep_gemm_cpp_tllm.*.so
+tensorrt_llm/deep_gemm_cpp_tllm.pyi
 *docs/cpp_docs*
 *docs/source/_cpp_gen*
 docs/source/**/*.rst

.gitmodules (vendored, 3 changes)

@@ -26,3 +26,6 @@
 [submodule "3rdparty/cppzmq"]
     path = 3rdparty/cppzmq
     url = https://github.com/zeromq/cppzmq.git
+[submodule "3rdparty/DeepGEMM"]
+    path = 3rdparty/DeepGEMM
+    url = https://github.com/deepseek-ai/DeepGEMM.git

3rdparty/DeepGEMM (vendored submodule, 1 change)

@@ -0,0 +1 @@
Subproject commit 7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c


@@ -31,6 +31,7 @@ option(BUILD_PYT "Build in PyTorch TorchScript class mode" ON)
 option(BUILD_TESTS "Build Google tests" ON)
 option(BUILD_BENCHMARKS "Build benchmarks" ON)
 option(BUILD_DEEP_EP "Build the Deep EP module" ON)
+option(BUILD_DEEP_GEMM "Build the DeepGEMM module" ON)
 option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF)
 option(NVTX_DISABLE "Disable all NVTX features" ON)
 option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
@@ -199,7 +200,9 @@ set(TRT_LIB TensorRT::NvInfer)
 get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)
 set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)

-if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
+if(BINDING_TYPE STREQUAL "pybind"
+   OR BUILD_DEEP_EP
+   OR BUILD_DEEP_GEMM)
   add_subdirectory(${3RDPARTY_DIR}/pybind11
                    ${CMAKE_CURRENT_BINARY_DIR}/pybind11)
 endif()
@@ -218,7 +221,9 @@ include_directories(
   ${3RDPARTY_DIR}/cutlass/tools/util/include
   ${3RDPARTY_DIR}/NVTX/include
   ${3RDPARTY_DIR}/json/include)
-if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
+if(BINDING_TYPE STREQUAL "pybind"
+   OR BUILD_DEEP_EP
+   OR BUILD_DEEP_GEMM)
   include_directories(${3RDPARTY_DIR}/pybind11/include)
 endif()
 if(BINDING_TYPE STREQUAL "nanobind")


@@ -314,4 +314,8 @@ if(BUILD_DEEP_EP)
   add_subdirectory(deep_ep)
 endif()
+
+if(BUILD_DEEP_GEMM)
+  add_subdirectory(deep_gemm)
+endif()

 add_subdirectory(plugins)

cpp/tensorrt_llm/deep_gemm/CMakeLists.txt (new file)

@@ -0,0 +1,126 @@
add_custom_target(deep_gemm)

if(WIN32)
  return()
endif()

# Prepare files
# =============

# Use DeepGEMM submodule
set(DEEP_GEMM_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../3rdparty/DeepGEMM)
get_filename_component(DEEP_GEMM_SOURCE_DIR ${DEEP_GEMM_SOURCE_DIR} ABSOLUTE)

if(NOT EXISTS ${DEEP_GEMM_SOURCE_DIR})
  message(
    FATAL_ERROR
      "DeepGEMM submodule not found at ${DEEP_GEMM_SOURCE_DIR}. Please run: git submodule update --init --recursive"
  )
endif()

# Check if submodules are initialized
if(NOT EXISTS ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include)
  message(
    FATAL_ERROR
      "DeepGEMM submodules not initialized. Please run: git submodule update --init --recursive"
  )
endif()

# Copy and update python files
set(DEEP_GEMM_PYTHON_DEST ${CMAKE_CURRENT_BINARY_DIR}/python/deep_gemm)
file(REMOVE_RECURSE ${DEEP_GEMM_PYTHON_DEST})
file(MAKE_DIRECTORY ${DEEP_GEMM_PYTHON_DEST})

# Copy all files from deep_gemm directory
file(GLOB_RECURSE DEEP_GEMM_ALL_FILES ${DEEP_GEMM_SOURCE_DIR}/deep_gemm/*)
configure_file(${DEEP_GEMM_SOURCE_DIR}/LICENSE ${DEEP_GEMM_PYTHON_DEST}/LICENSE
               COPYONLY)

foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES})
  file(RELATIVE_PATH REL_PATH ${DEEP_GEMM_SOURCE_DIR}/deep_gemm ${SOURCE_FILE})
  get_filename_component(REL_DIR ${REL_PATH} DIRECTORY)
  file(MAKE_DIRECTORY ${DEEP_GEMM_PYTHON_DEST}/${REL_DIR})

  # Check if it's a Python file that needs import renaming
  get_filename_component(FILE_EXT ${SOURCE_FILE} EXT)
  if(FILE_EXT STREQUAL ".py")
    # Read file content and replace module imports for Python files
    file(READ ${SOURCE_FILE} _content)
    string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content
                   "${_content}")
    # Add adaptation header
    string(
      PREPEND
      _content
      "# Adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/${REL_PATH}\n"
    )
    # Write modified content
    set(_dst "${DEEP_GEMM_PYTHON_DEST}/${REL_PATH}")
    file(WRITE ${_dst} "${_content}")
  else()
    # Copy non-Python files as-is
    set(_dst "${DEEP_GEMM_PYTHON_DEST}/${REL_PATH}")
    file(COPY ${SOURCE_FILE} DESTINATION ${DEEP_GEMM_PYTHON_DEST}/${REL_DIR})
  endif()

  # Add dependency tracking
  set_property(
    DIRECTORY
    APPEND
    PROPERTY CMAKE_CONFIGURE_DEPENDS ${SOURCE_FILE})
endforeach()

# Copy third-party includes (cutlass and fmt) to the include directory
set(DEEP_GEMM_INCLUDE_DEST ${DEEP_GEMM_PYTHON_DEST}/include)
file(MAKE_DIRECTORY ${DEEP_GEMM_INCLUDE_DEST})
file(COPY ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include/cute
     DESTINATION ${DEEP_GEMM_INCLUDE_DEST})
file(COPY ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include/cutlass
     DESTINATION ${DEEP_GEMM_INCLUDE_DEST})

# Find torch_python
find_library(TORCH_PYTHON_LIB torch_python REQUIRED
             HINTS ${TORCH_INSTALL_PREFIX}/lib)

# Build deep_gemm_cpp_tllm extension (matching deep_gemm's setup.py)
set(DEEP_GEMM_SOURCES ${DEEP_GEMM_SOURCE_DIR}/csrc/python_api.cpp)

pybind11_add_module(deep_gemm_cpp_tllm ${DEEP_GEMM_SOURCES})

set_target_properties(
  deep_gemm_cpp_tllm
  PROPERTIES CXX_STANDARD_REQUIRED ON
             CXX_STANDARD 17
             CXX_SCAN_FOR_MODULES OFF
             CUDA_SEPARABLE_COMPILATION ON
             LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/deep_gemm_cpp_tllm.version
             INSTALL_RPATH "${TORCH_INSTALL_PREFIX}/lib"
             BUILD_WITH_INSTALL_RPATH TRUE)

target_compile_options(deep_gemm_cpp_tllm PRIVATE ${TORCH_CXX_FLAGS} -std=c++17
                                                  -O3 -fPIC -Wno-psabi)

# Extension name definition
target_compile_definitions(deep_gemm_cpp_tllm
                           PRIVATE TORCH_EXTENSION_NAME=deep_gemm_cpp_tllm)

# Include directories matching deep_gemm setup.py
target_include_directories(
  deep_gemm_cpp_tllm
  PRIVATE ${CUDA_INCLUDE_DIRS} ${DEEP_GEMM_SOURCE_DIR}/deep_gemm/include
          ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include
          ${DEEP_GEMM_SOURCE_DIR}/third-party/fmt/include)

# Link libraries (matching deep_gemm setup.py: cuda, cudart + torch)
target_link_libraries(
  deep_gemm_cpp_tllm PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIB}
                             CUDA::cuda_driver CUDA::cudart)

# Link directories
target_link_directories(
  deep_gemm_cpp_tllm PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/lib64
                             ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

# Set targets
# ===========
add_dependencies(deep_gemm deep_gemm_cpp_tllm)

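The foreach() loop above vendors DeepGEMM's Python sources and retargets their imports so the copied package resolves against the extension built here instead of the upstream deep_gemm_cpp module. A minimal Python sketch of that per-file rewrite (the file name "utils.py" and the function fp8_gemm are hypothetical, not from this diff):

    # Sketch of the transformation the CMake loop applies to each .py file.
    rel_path = "utils.py"  # hypothetical REL_PATH
    src = 'import deep_gemm_cpp\nout = deep_gemm_cpp.fp8_gemm(a, b)\n'
    dst = src.replace("deep_gemm_cpp", "tensorrt_llm.deep_gemm_cpp_tllm")
    dst = ("# Adapted from https://github.com/deepseek-ai/DeepGEMM"
           f"/blob/main/deep_gemm/{rel_path}\n") + dst
    print(dst)

A plain string replace suffices because tensorrt_llm.deep_gemm_cpp_tllm remains reachable through the same dotted attribute chain the original references used.
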
cpp/tensorrt_llm/deep_gemm/deep_gemm_cpp_tllm.version (new file)

@@ -0,0 +1,4 @@
{
  global: PyInit_deep_gemm_cpp_tllm;
  local: *;
};

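This linker version script makes every symbol in the extension local except PyInit_deep_gemm_cpp_tllm, the one entry point the CPython importer looks up, which keeps the shared library from leaking or clashing over statically linked pybind11, cutlass, and fmt symbols. One hedged way to verify the built artifact, sketched in Python (the .so filename is an assumption; it varies with Python version and ABI):

    # List dynamic symbols and confirm only the module init is exported.
    import subprocess
    so = "tensorrt_llm/deep_gemm_cpp_tllm.cpython-312-aarch64-linux-gnu.so"  # assumed path
    out = subprocess.run(["nm", "-D", "--defined-only", so],
                         capture_output=True, text=True, check=True).stdout
    assert "PyInit_deep_gemm_cpp_tllm" in out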

@@ -61,6 +61,5 @@ etcd3
 blake3
 llguidance==0.7.29
 soundfile
-deep_gemm @ git+https://github.com/zongfeijing/DeepGEMM.git@a9d538ef4dff0326fe521c6ca0bfde115703b56a
 triton==3.3.1; platform_machine == "x86_64"
 blobfile


@@ -448,10 +448,12 @@ def main(*,
     if cpp_only:
         build_pyt = "OFF"
         build_deep_ep = "OFF"
+        build_deep_gemm = "OFF"
     else:
-        targets.extend(["th_common", "bindings", "deep_ep"])
+        targets.extend(["th_common", "bindings", "deep_ep", "deep_gemm"])
         build_pyt = "ON"
         build_deep_ep = "ON"
+        build_deep_gemm = "ON"

     if benchmarks:
         targets.append("benchmarks")
@@ -490,7 +492,7 @@ def main(*,
     )
     cmake_def_args = " ".join(cmake_def_args)
     cmake_configure_command = (
-        f'cmake -DCMAKE_BUILD_TYPE="{build_type}" -DBUILD_PYT="{build_pyt}" -DBINDING_TYPE="{binding_type}" -DBUILD_DEEP_EP="{build_deep_ep}"'
+        f'cmake -DCMAKE_BUILD_TYPE="{build_type}" -DBUILD_PYT="{build_pyt}" -DBINDING_TYPE="{binding_type}" -DBUILD_DEEP_EP="{build_deep_ep}" -DBUILD_DEEP_GEMM="{build_deep_gemm}"'
         f' -DNVTX_DISABLE="{disable_nvtx}" -DBUILD_MICRO_BENCHMARKS={build_micro_benchmarks}'
         f' -DBUILD_WHEEL_TARGETS="{";".join(targets)}"'
         f' -DPython_EXECUTABLE={venv_python} -DPython3_EXECUTABLE={venv_python}'
@@ -637,6 +639,14 @@ def main(*,
         clear_folder(deep_ep_dir)
         deep_ep_dir.rmdir()

+    # Handle deep_gemm installation
+    deep_gemm_dir = pkg_dir / "deep_gemm"
+    if deep_gemm_dir.is_symlink():
+        deep_gemm_dir.unlink()
+    elif deep_gemm_dir.is_dir():
+        clear_folder(deep_gemm_dir)
+        deep_gemm_dir.rmdir()
+
     bin_dir = pkg_dir / "bin"
     if bin_dir.exists():
         clear_folder(bin_dir)
@@ -684,6 +694,14 @@ def main(*,
         build_dir /
         "tensorrt_llm/deep_ep/nvshmem-build/src/lib/nvshmem_transport_ibgda.so.103",
         lib_dir / "nvshmem")
+    install_file(get_binding_lib("deep_gemm", "deep_gemm_cpp_tllm"),
+                 pkg_dir)
+    install_tree(build_dir / "tensorrt_llm" / "deep_gemm" / "python" /
+                 "deep_gemm",
+                 deep_gemm_dir,
+                 dirs_exist_ok=True)

     if not skip_stubs:
         with working_directory(project_dir):
             if binding_type == "nanobind":
@@ -757,6 +775,9 @@ def main(*,
                 build_run(
                     f"\"{venv_python}\" -m pybind11_stubgen -o . deep_ep_cpp_tllm --exit-code",
                     env=env_ld)
+                build_run(
+                    f"\"{venv_python}\" -m pybind11_stubgen -o . deep_gemm_cpp_tllm --exit-code",
+                    env=env_ld)

     if not skip_building_wheel:
         if dist_dir is None:


@@ -107,7 +107,8 @@ else:
         'libs/libdecoder_attention_1.so', 'libs/nvshmem/License.txt',
         'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
         'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so',
-        'deep_ep/LICENSE', 'deep_ep_cpp_tllm.*.so', "include/**/*"
+        'deep_ep/LICENSE', 'deep_ep_cpp_tllm.*.so', "include/**/*",
+        'deep_gemm/LICENSE', 'deep_gemm/include/**/*', 'deep_gemm_cpp_tllm.*.so'
     ]
     package_data += [

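With these package_data globs, the vendored LICENSE, the copied cute/cutlass headers, and the compiled extension all ride along in the wheel. A quick spot check, sketched under the assumption of a locally built wheel (the filename below is illustrative, not from this diff):

    # Confirm the DeepGEMM assets actually landed in the wheel.
    import zipfile
    whl = "build/tensorrt_llm-1.0.0-cp312-cp312-linux_aarch64.whl"  # assumed name
    names = zipfile.ZipFile(whl).namelist()
    assert any(n.startswith("tensorrt_llm/deep_gemm/") for n in names)
    assert any("deep_gemm_cpp_tllm" in n and n.endswith(".so") for n in names)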

@@ -1,12 +1,12 @@
 from typing import List, Optional, Union

-import deep_gemm
 import torch
 import torch.nn.functional as F
 import triton
 import triton.language as tl

 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
+from tensorrt_llm import deep_gemm
 from tensorrt_llm._utils import nvtx_range

 from ...distributed import allgather


@@ -573,7 +573,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
         assert input.dtype == torch.bfloat16

         if get_sm_version() == 100:
-            import deep_gemm
+            from tensorrt_llm import deep_gemm
             a, a_sf = fp8_utils.per_token_quant_and_transform(input)
             output = torch.empty((input.shape[0], module.weight.shape[0]),
                                  device=input.device,


@@ -50,7 +50,7 @@ def test_fp8_block_scale_deep_gemm(dtype, m, k, n):
     act_b_fp8, act_b_sf = per_block_cast_to_fp8_e8m0(b)
     output_expected = a @ b.t()

-    import deep_gemm
+    from tensorrt_llm import deep_gemm
     output = torch.empty((act_a_fp8.shape[0], act_b_fp8.shape[0]),
                          device=act_a_fp8.device,
                          dtype=torch.bfloat16)
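Taken together, these changes replace the git-sourced deep_gemm pip requirement, which appears to be what broke installation on SBSA, with a DeepGEMM build vendored into the tensorrt_llm namespace, so callers now import it as shown in the diffs above. A minimal sketch of the new import path, assuming a successful build and install:

    # Post-install, both the vendored Python package and the pybind11
    # extension resolve under the tensorrt_llm namespace.
    from tensorrt_llm import deep_gemm               # vendored python shims
    import tensorrt_llm.deep_gemm_cpp_tllm as ext    # compiled extension
    print(deep_gemm.__file__)  # .../site-packages/tensorrt_llm/deep_gemm/__init__.py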