[https://nvbugs/5433581][fix] DeepGEMM installation on SBSA (#6588)
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Commit 0ff8df95b7 (parent 907c180eb2)

This change vendors DeepGEMM as a 3rdparty git submodule, builds the deep_gemm_cpp_tllm pybind11 extension and a bundled tensorrt_llm.deep_gemm Python package into the wheel, and drops the previous deep_gemm pip dependency from the requirements.
.gitignore (vendored), 3 changes

@@ -43,6 +43,9 @@ tensorrt_llm/bindings/**/*.pyi
 tensorrt_llm/deep_ep/
 tensorrt_llm/deep_ep_cpp_tllm.*.so
 tensorrt_llm/deep_ep_cpp_tllm.pyi
+tensorrt_llm/deep_gemm/
+tensorrt_llm/deep_gemm_cpp_tllm.*.so
+tensorrt_llm/deep_gemm_cpp_tllm.pyi
 *docs/cpp_docs*
 *docs/source/_cpp_gen*
 docs/source/**/*.rst
.gitmodules (vendored), 3 changes

@@ -26,3 +26,6 @@
 [submodule "3rdparty/cppzmq"]
     path = 3rdparty/cppzmq
     url = https://github.com/zeromq/cppzmq.git
+[submodule "3rdparty/DeepGEMM"]
+    path = 3rdparty/DeepGEMM
+    url = https://github.com/deepseek-ai/DeepGEMM.git
3rdparty/DeepGEMM (new vendored submodule), 1 change

@@ -0,0 +1 @@
+Subproject commit 7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c
@@ -31,6 +31,7 @@ option(BUILD_PYT "Build in PyTorch TorchScript class mode" ON)
 option(BUILD_TESTS "Build Google tests" ON)
 option(BUILD_BENCHMARKS "Build benchmarks" ON)
 option(BUILD_DEEP_EP "Build the Deep EP module" ON)
+option(BUILD_DEEP_GEMM "Build the DeepGEMM module" ON)
 option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF)
 option(NVTX_DISABLE "Disable all NVTX features" ON)
 option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
@@ -199,7 +200,9 @@ set(TRT_LIB TensorRT::NvInfer)
 get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)

 set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
-if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
+if(BINDING_TYPE STREQUAL "pybind"
+   OR BUILD_DEEP_EP
+   OR BUILD_DEEP_GEMM)
   add_subdirectory(${3RDPARTY_DIR}/pybind11
                    ${CMAKE_CURRENT_BINARY_DIR}/pybind11)
 endif()
@@ -218,7 +221,9 @@ include_directories(
   ${3RDPARTY_DIR}/cutlass/tools/util/include
   ${3RDPARTY_DIR}/NVTX/include
   ${3RDPARTY_DIR}/json/include)
-if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
+if(BINDING_TYPE STREQUAL "pybind"
+   OR BUILD_DEEP_EP
+   OR BUILD_DEEP_GEMM)
   include_directories(${3RDPARTY_DIR}/pybind11/include)
 endif()
 if(BINDING_TYPE STREQUAL "nanobind")
@@ -314,4 +314,8 @@ if(BUILD_DEEP_EP)
   add_subdirectory(deep_ep)
 endif()

+if(BUILD_DEEP_GEMM)
+  add_subdirectory(deep_gemm)
+endif()
+
 add_subdirectory(plugins)
cpp/tensorrt_llm/deep_gemm/CMakeLists.txt (new file), 126 lines

@@ -0,0 +1,126 @@
add_custom_target(deep_gemm)

if(WIN32)
  return()
endif()

# Prepare files
# =============

# Use DeepGEMM submodule
set(DEEP_GEMM_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../3rdparty/DeepGEMM)
get_filename_component(DEEP_GEMM_SOURCE_DIR ${DEEP_GEMM_SOURCE_DIR} ABSOLUTE)

if(NOT EXISTS ${DEEP_GEMM_SOURCE_DIR})
  message(
    FATAL_ERROR
      "DeepGEMM submodule not found at ${DEEP_GEMM_SOURCE_DIR}. Please run: git submodule update --init --recursive"
  )
endif()

# Check if submodules are initialized
if(NOT EXISTS ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include)
  message(
    FATAL_ERROR
      "DeepGEMM submodules not initialized. Please run: git submodule update --init --recursive"
  )
endif()

# Copy and update python files
set(DEEP_GEMM_PYTHON_DEST ${CMAKE_CURRENT_BINARY_DIR}/python/deep_gemm)
file(REMOVE_RECURSE ${DEEP_GEMM_PYTHON_DEST})
file(MAKE_DIRECTORY ${DEEP_GEMM_PYTHON_DEST})

# Copy all files from deep_gemm directory
file(GLOB_RECURSE DEEP_GEMM_ALL_FILES ${DEEP_GEMM_SOURCE_DIR}/deep_gemm/*)
configure_file(${DEEP_GEMM_SOURCE_DIR}/LICENSE ${DEEP_GEMM_PYTHON_DEST}/LICENSE
               COPYONLY)
foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES})
  file(RELATIVE_PATH REL_PATH ${DEEP_GEMM_SOURCE_DIR}/deep_gemm ${SOURCE_FILE})
  get_filename_component(REL_DIR ${REL_PATH} DIRECTORY)
  file(MAKE_DIRECTORY ${DEEP_GEMM_PYTHON_DEST}/${REL_DIR})

  # Check if it's a Python file that needs import renaming
  get_filename_component(FILE_EXT ${SOURCE_FILE} EXT)
  if(FILE_EXT STREQUAL ".py")
    # Read file content and replace module imports for Python files
    file(READ ${SOURCE_FILE} _content)
    string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content
                   "${_content}")

    # Add adaptation header
    string(
      PREPEND
      _content
      "# Adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/${REL_PATH}\n"
    )

    # Write modified content
    set(_dst "${DEEP_GEMM_PYTHON_DEST}/${REL_PATH}")
    file(WRITE ${_dst} "${_content}")
  else()
    # Copy non-Python files as-is
    set(_dst "${DEEP_GEMM_PYTHON_DEST}/${REL_PATH}")
    file(COPY ${SOURCE_FILE} DESTINATION ${DEEP_GEMM_PYTHON_DEST}/${REL_DIR})
  endif()

  # Add dependency tracking
  set_property(
    DIRECTORY
    APPEND
    PROPERTY CMAKE_CONFIGURE_DEPENDS ${SOURCE_FILE})
endforeach()

# Copy third-party includes (cutlass and fmt) to the include directory
set(DEEP_GEMM_INCLUDE_DEST ${DEEP_GEMM_PYTHON_DEST}/include)
file(MAKE_DIRECTORY ${DEEP_GEMM_INCLUDE_DEST})
file(COPY ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include/cute
     DESTINATION ${DEEP_GEMM_INCLUDE_DEST})
file(COPY ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include/cutlass
     DESTINATION ${DEEP_GEMM_INCLUDE_DEST})

# Find torch_python
find_library(TORCH_PYTHON_LIB torch_python REQUIRED
             HINTS ${TORCH_INSTALL_PREFIX}/lib)

# Build deep_gemm_cpp_tllm extension (matching deep_gemm's setup.py)
set(DEEP_GEMM_SOURCES ${DEEP_GEMM_SOURCE_DIR}/csrc/python_api.cpp)

pybind11_add_module(deep_gemm_cpp_tllm ${DEEP_GEMM_SOURCES})
set_target_properties(
  deep_gemm_cpp_tllm
  PROPERTIES CXX_STANDARD_REQUIRED ON
             CXX_STANDARD 17
             CXX_SCAN_FOR_MODULES OFF
             CUDA_SEPARABLE_COMPILATION ON
             LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/deep_gemm_cpp_tllm.version
             INSTALL_RPATH "${TORCH_INSTALL_PREFIX}/lib"
             BUILD_WITH_INSTALL_RPATH TRUE)

target_compile_options(deep_gemm_cpp_tllm PRIVATE ${TORCH_CXX_FLAGS} -std=c++17
                                                  -O3 -fPIC -Wno-psabi)

# Extension name definition
target_compile_definitions(deep_gemm_cpp_tllm
                           PRIVATE TORCH_EXTENSION_NAME=deep_gemm_cpp_tllm)

# Include directories matching deep_gemm setup.py
target_include_directories(
  deep_gemm_cpp_tllm
  PRIVATE ${CUDA_INCLUDE_DIRS} ${DEEP_GEMM_SOURCE_DIR}/deep_gemm/include
          ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include
          ${DEEP_GEMM_SOURCE_DIR}/third-party/fmt/include)

# Link libraries (matching deep_gemm setup.py: cuda, cudart + torch)
target_link_libraries(
  deep_gemm_cpp_tllm PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIB}
                             CUDA::cuda_driver CUDA::cudart)

# Link directories
target_link_directories(
  deep_gemm_cpp_tllm PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/lib64
  ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

# Set targets
# ===========
add_dependencies(deep_gemm deep_gemm_cpp_tllm)
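The heart of the new CMakeLists.txt is the per-file copy-and-rewrite loop: every DeepGEMM Python source is copied into the build tree, its references to the upstream deep_gemm_cpp extension are retargeted to tensorrt_llm.deep_gemm_cpp_tllm, and an "Adapted from" header is prepended. A rough Python sketch of that configure-time loop follows, for illustration only; the destination path and the standalone-script form are assumptions, and the real work is done by CMake as shown above.

# Illustrative Python sketch of the CMake foreach() loop above; paths are
# assumptions, and the real rewrite happens in CMake at configure time.
from pathlib import Path

src_root = Path("3rdparty/DeepGEMM/deep_gemm")   # DeepGEMM submodule sources
dst_root = Path("build/python/deep_gemm")        # hypothetical destination

for src in src_root.rglob("*"):
    if src.is_dir():
        continue
    rel = src.relative_to(src_root)
    dst = dst_root / rel
    dst.parent.mkdir(parents=True, exist_ok=True)
    if src.suffix == ".py":
        text = src.read_text()
        # Retarget imports of the upstream extension to the renamed TRT-LLM one.
        text = text.replace("deep_gemm_cpp", "tensorrt_llm.deep_gemm_cpp_tllm")
        header = ("# Adapted from https://github.com/deepseek-ai/DeepGEMM/"
                  f"blob/main/deep_gemm/{rel}\n")
        dst.write_text(header + text)
    else:
        # Non-Python files (headers and other assets) are copied verbatim.
        dst.write_bytes(src.read_bytes())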
cpp/tensorrt_llm/deep_gemm/deep_gemm_cpp_tllm.version (new file), 4 lines

@@ -0,0 +1,4 @@
{
  global: PyInit_deep_gemm_cpp_tllm;
  local: *;
};
@@ -61,6 +61,5 @@ etcd3
 blake3
 llguidance==0.7.29
 soundfile
-deep_gemm @ git+https://github.com/zongfeijing/DeepGEMM.git@a9d538ef4dff0326fe521c6ca0bfde115703b56a
 triton==3.3.1; platform_machine == "x86_64"
 blobfile
@@ -448,10 +448,12 @@ def main(*,
     if cpp_only:
         build_pyt = "OFF"
         build_deep_ep = "OFF"
+        build_deep_gemm = "OFF"
     else:
-        targets.extend(["th_common", "bindings", "deep_ep"])
+        targets.extend(["th_common", "bindings", "deep_ep", "deep_gemm"])
         build_pyt = "ON"
         build_deep_ep = "ON"
+        build_deep_gemm = "ON"

     if benchmarks:
         targets.append("benchmarks")
@@ -490,7 +492,7 @@ def main(*,
     )
     cmake_def_args = " ".join(cmake_def_args)
     cmake_configure_command = (
-        f'cmake -DCMAKE_BUILD_TYPE="{build_type}" -DBUILD_PYT="{build_pyt}" -DBINDING_TYPE="{binding_type}" -DBUILD_DEEP_EP="{build_deep_ep}"'
+        f'cmake -DCMAKE_BUILD_TYPE="{build_type}" -DBUILD_PYT="{build_pyt}" -DBINDING_TYPE="{binding_type}" -DBUILD_DEEP_EP="{build_deep_ep}" -DBUILD_DEEP_GEMM="{build_deep_gemm}"'
         f' -DNVTX_DISABLE="{disable_nvtx}" -DBUILD_MICRO_BENCHMARKS={build_micro_benchmarks}'
         f' -DBUILD_WHEEL_TARGETS="{";".join(targets)}"'
         f' -DPython_EXECUTABLE={venv_python} -DPython3_EXECUTABLE={venv_python}'
@@ -637,6 +639,14 @@ def main(*,
             clear_folder(deep_ep_dir)
             deep_ep_dir.rmdir()

+        # Handle deep_gemm installation
+        deep_gemm_dir = pkg_dir / "deep_gemm"
+        if deep_gemm_dir.is_symlink():
+            deep_gemm_dir.unlink()
+        elif deep_gemm_dir.is_dir():
+            clear_folder(deep_gemm_dir)
+            deep_gemm_dir.rmdir()
+
         bin_dir = pkg_dir / "bin"
         if bin_dir.exists():
             clear_folder(bin_dir)
@@ -684,6 +694,14 @@ def main(*,
                 build_dir /
                 "tensorrt_llm/deep_ep/nvshmem-build/src/lib/nvshmem_transport_ibgda.so.103",
                 lib_dir / "nvshmem")
+
+        install_file(get_binding_lib("deep_gemm", "deep_gemm_cpp_tllm"),
+                     pkg_dir)
+        install_tree(build_dir / "tensorrt_llm" / "deep_gemm" / "python" /
+                     "deep_gemm",
+                     deep_gemm_dir,
+                     dirs_exist_ok=True)
+
     if not skip_stubs:
         with working_directory(project_dir):
             if binding_type == "nanobind":
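Together with the deep_gemm_dir cleanup in the previous hunk, these two calls stage the DeepGEMM artifacts into the wheel: the compiled deep_gemm_cpp_tllm extension is copied next to the tensorrt_llm package root, and the rewritten Python sources produced by the CMake step go into tensorrt_llm/deep_gemm. A minimal sketch of the intended end state follows; the concrete paths and the shutil-based stand-ins for install_file/install_tree (whose real implementations live elsewhere in build_wheel.py) are assumptions.

# Minimal sketch of the staging step above, with shutil stand-ins for the
# install_file/install_tree helpers (not part of this diff). Paths are illustrative.
import shutil
from pathlib import Path

build_dir = Path("cpp/build")              # hypothetical CMake build directory
pkg_dir = Path("tensorrt_llm")             # wheel staging dir for the package
deep_gemm_dir = pkg_dir / "deep_gemm"

# Compiled pybind11 extension, copied next to the package root.
for so in (build_dir / "tensorrt_llm" / "deep_gemm").glob("deep_gemm_cpp_tllm*.so"):
    shutil.copy2(so, pkg_dir)

# Rewritten DeepGEMM Python sources produced by the CMake copy-and-rewrite step.
shutil.copytree(build_dir / "tensorrt_llm" / "deep_gemm" / "python" / "deep_gemm",
                deep_gemm_dir, dirs_exist_ok=True)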
@@ -757,6 +775,9 @@ def main(*,
                 build_run(
                     f"\"{venv_python}\" -m pybind11_stubgen -o . deep_ep_cpp_tllm --exit-code",
                     env=env_ld)
+                build_run(
+                    f"\"{venv_python}\" -m pybind11_stubgen -o . deep_gemm_cpp_tllm --exit-code",
+                    env=env_ld)

     if not skip_building_wheel:
         if dist_dir is None:
setup.py, 3 changes

@@ -107,7 +107,8 @@ else:
     'libs/libdecoder_attention_1.so', 'libs/nvshmem/License.txt',
     'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
     'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so',
-    'deep_ep/LICENSE', 'deep_ep_cpp_tllm.*.so', "include/**/*"
+    'deep_ep/LICENSE', 'deep_ep_cpp_tllm.*.so', "include/**/*",
+    'deep_gemm/LICENSE', 'deep_gemm/include/**/*', 'deep_gemm_cpp_tllm.*.so'
 ]

 package_data += [
@@ -1,12 +1,12 @@
 from typing import List, Optional, Union

-import deep_gemm
 import torch
 import torch.nn.functional as F
 import triton
 import triton.language as tl

 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
+from tensorrt_llm import deep_gemm
 from tensorrt_llm._utils import nvtx_range

 from ...distributed import allgather
@@ -573,7 +573,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
         assert input.dtype == torch.bfloat16

         if get_sm_version() == 100:
-            import deep_gemm
+            from tensorrt_llm import deep_gemm
             a, a_sf = fp8_utils.per_token_quant_and_transform(input)
             output = torch.empty((input.shape[0], module.weight.shape[0]),
                                  device=input.device,
@@ -50,7 +50,7 @@ def test_fp8_block_scale_deep_gemm(dtype, m, k, n):
     act_b_fp8, act_b_sf = per_block_cast_to_fp8_e8m0(b)

     output_expected = a @ b.t()
-    import deep_gemm
+    from tensorrt_llm import deep_gemm
     output = torch.empty((act_a_fp8.shape[0], act_b_fp8.shape[0]),
                          device=act_a_fp8.device,
                          dtype=torch.bfloat16)
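Across these last three hunks the pattern is the same: call sites stop importing the standalone deep_gemm pip package and instead pull the vendored copy out of the tensorrt_llm namespace. A minimal usage sketch, assuming a wheel built with BUILD_DEEP_GEMM=ON:

# Minimal sketch, assuming a TensorRT-LLM wheel built with BUILD_DEEP_GEMM=ON.
from tensorrt_llm import deep_gemm              # vendored DeepGEMM Python package
import tensorrt_llm.deep_gemm_cpp_tllm as _cpp  # renamed compiled extension

# Both now live inside the installed tensorrt_llm package rather than in a
# separately pip-installed deep_gemm distribution.
print(deep_gemm.__file__)
print(_cpp.__file__)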