diff --git a/cpp/tensorrt_llm/common/cudaDriverWrapper.cpp b/cpp/tensorrt_llm/common/cudaDriverWrapper.cpp
index d298fb42d6..8f57407a40 100644
--- a/cpp/tensorrt_llm/common/cudaDriverWrapper.cpp
+++ b/cpp/tensorrt_llm/common/cudaDriverWrapper.cpp
@@ -28,11 +28,14 @@
 #define dllGetSym(handle, name) dlsym(handle, name)
 #endif // defined(_WIN32)
 
-#include "cudaDriverWrapper.h"
 #include "tensorrt_llm/common/assert.h"
-#include <cstdio>
+#include "tensorrt_llm/common/cudaDriverWrapper.h"
+
 #include <cuda.h>
+#include <memory>
+#include <mutex>
+
 namespace tensorrt_llm::common
 {
@@ -46,7 +49,7 @@ std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
         return result;
     }
 
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::mutex> const lock(mutex);
     result = instance.lock();
     if (!result)
     {
@@ -69,7 +72,7 @@ CUDADriverWrapper::CUDADriverWrapper()
     };
 
     *reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
-    *reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
+    *reinterpret_cast<void**>(&_cuGetErrorString) = load_sym(handle, "cuGetErrorString");
     *reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
     *reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
     *reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
@@ -98,9 +101,9 @@ CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) co
     return (*_cuGetErrorName)(error, pStr);
 }
 
-CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
+CUresult CUDADriverWrapper::cuGetErrorString(CUresult error, char const** pStr) const
 {
-    return (*_cuGetErrorMessage)(error, pStr);
+    return (*_cuGetErrorString)(error, pStr);
 }
 
 CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
diff --git a/cpp/tensorrt_llm/common/cudaDriverWrapper.h b/cpp/tensorrt_llm/common/cudaDriverWrapper.h
index 80605896dc..affad6634a 100644
--- a/cpp/tensorrt_llm/common/cudaDriverWrapper.h
+++ b/cpp/tensorrt_llm/common/cudaDriverWrapper.h
@@ -17,11 +17,13 @@
 #ifndef CUDA_DRIVER_WRAPPER_H
 #define CUDA_DRIVER_WRAPPER_H
 
-#include "tensorrt_llm/common/assert.h"
-#include <cstdio>
+#include "tensorrt_llm/common/stringUtils.h"
+#include "tensorrt_llm/common/tllmException.h"
+
 #include <cuda.h>
+
+#include <cstdio>
 #include <memory>
-#include <mutex>
 
 namespace tensorrt_llm::common
 {
@@ -39,7 +41,7 @@ public:
 
     CUresult cuGetErrorName(CUresult error, char const** pStr) const;
 
-    CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
+    CUresult cuGetErrorString(CUresult error, char const** pStr) const;
 
     CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
 
@@ -88,7 +90,7 @@ private:
     CUDADriverWrapper();
 
     CUresult (*_cuGetErrorName)(CUresult, char const**);
-    CUresult (*_cuGetErrorMessage)(CUresult, char const**);
+    CUresult (*_cuGetErrorString)(CUresult, char const**);
     CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
     CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
     CUresult (*_cuModuleUnload)(CUmodule);
@@ -121,11 +123,11 @@ void checkDriver(
     if (result)
     {
         char const* errorName = nullptr;
-        char const* errorMsg = nullptr;
+        char const* errorString = nullptr;
         wrap.cuGetErrorName(result, &errorName);
-        wrap.cuGetErrorMessage(result, &errorMsg);
+        wrap.cuGetErrorString(result, &errorString);
         throw TllmException(
-            file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
+            file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s.", func, errorName, errorString));
     }
 }
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/cubinObjRegistry.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/cubinObjRegistry.h
index ee6c419459..468cd77bc1 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/cubinObjRegistry.h
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/cubinObjRegistry.h
@@ -19,14 +19,12 @@
 #include "compileEngine.h"
 #include "serializationUtils.h"
 #include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
+
 #include <cuda.h>
+#include <memory>
 #include <mutex>
 
-namespace tensorrt_llm
-{
-namespace kernels
-{
-namespace jit
+namespace tensorrt_llm::kernels::jit
 {
 
 // A thread-safe collection of CubinObjs, with caching functionality.
@@ -39,19 +37,19 @@ public:
     CubinObjRegistryTemplate(void const* buffer_, size_t buffer_size)
     {
         size_t remaining_buffer_size = buffer_size;
-        uint8_t const* buffer = static_cast<uint8_t const*>(buffer_);
+        auto const* buffer = static_cast<uint8_t const*>(buffer_);
         // First 4 bytes: num of elements.
-        uint32_t n = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
+        auto const n = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
 
         for (uint32_t i = 0; i < n; ++i)
         {
-            uint32_t key_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
+            auto key_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
             TLLM_CHECK(key_size <= remaining_buffer_size);
             Key key(buffer, key_size);
             buffer += key_size;
             remaining_buffer_size -= key_size;
 
-            uint32_t obj_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
+            auto obj_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
             TLLM_CHECK(obj_size <= remaining_buffer_size);
             CubinObj obj(buffer, obj_size);
             buffer += obj_size;
@@ -89,7 +87,7 @@ public:
     {
         std::lock_guard<std::mutex> lock(mMutex);
         size_t remaining_buffer_size = buffer_size;
-        uint8_t* buffer = static_cast<uint8_t*>(buffer_);
+        auto* buffer = static_cast<uint8_t*>(buffer_);
         uint32_t n = mMap.size();
         writeToBuffer(n, buffer, remaining_buffer_size);
         for (auto&& p : mMap)
@@ -131,7 +129,6 @@ public:
         {
             obj.initialize();
         }
         mMap.insert({key, std::move(obj)});
-        return;
     }
 
     CubinObj* getCubin(Key const& key)
@@ -142,10 +139,8 @@
         {
             return &iter->second;
         }
-        else
-        {
-            return nullptr;
-        }
+
+        return nullptr;
     }
 
     // When initialize is true, initialize cubins.
@@ -178,6 +173,4 @@ using CubinObjKey = XQAKernelFullHashKey;
 using CubinObjHasher = XQAKernelFullHasher;
 using CubinObjRegistry = CubinObjRegistryTemplate<CubinObjKey, CubinObjHasher>;
 
-} // namespace jit
-} // namespace kernels
-} // namespace tensorrt_llm
+} // namespace tensorrt_llm::kernels::jit
diff --git a/cpp/tests/unit_tests/common/CMakeLists.txt b/cpp/tests/unit_tests/common/CMakeLists.txt
index 7e06f28584..0bc5c953cb 100644
--- a/cpp/tests/unit_tests/common/CMakeLists.txt
+++ b/cpp/tests/unit_tests/common/CMakeLists.txt
@@ -18,3 +18,4 @@ add_gtest(stlUtilsTest stlUtilsTest.cpp)
 add_gtest(stringUtilsTest stringUtilsTest.cpp)
 add_gtest(timestampUtilsTest timestampUtilsTest.cpp)
 add_gtest(tllmExceptionTest tllmExceptionTest.cpp)
+add_gtest(cudaDriverWrapperTest cudaDriverWrapperTest.cpp)
diff --git a/cpp/tests/unit_tests/common/cudaDriverWrapperTest.cpp b/cpp/tests/unit_tests/common/cudaDriverWrapperTest.cpp
new file mode 100644
index 0000000000..2ba48cd9fb
--- /dev/null
+++ b/cpp/tests/unit_tests/common/cudaDriverWrapperTest.cpp
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "tensorrt_llm/common/cudaDriverWrapper.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+
+TEST(TestCudaDriverWrapper, TllmCuCheckSucceedingWithValidParametersDoesNotThrow)
+{
+    auto const deviceCount = tensorrt_llm::common::getDeviceCount();
+    if (deviceCount == 0)
+    {
+        GTEST_SKIP() << "No CUDA devices found";
+    }
+    CUmemGenericAllocationHandle handle{};
+    CUmemAllocationProp const prop{CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED,
+        CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_NONE,
+        CUmemLocation{
+            CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
+            0,
+        },
+        nullptr};
+    auto const granularity = tensorrt_llm::common::getAllocationGranularity();
+    ASSERT_NO_THROW(TLLM_CU_CHECK(cuMemCreate(&handle, granularity * 16, &prop, 0)));
+    ASSERT_NO_THROW(TLLM_CU_CHECK(cuMemRelease(handle)));
+}
+
+TEST(TestCudaDriverWrapper, TllmCuCheckFailingWithInvalidParametersThrows)
+{
+    auto const deviceCount = tensorrt_llm::common::getDeviceCount();
+    if (deviceCount == 0)
+    {
+        GTEST_SKIP() << "No CUDA devices found";
+    }
+    CUmemGenericAllocationHandle handle{};
+    CUmemAllocationProp const prop{CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED,
+        CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_NONE,
+        CUmemLocation{
+            CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
+            0,
+        },
+        nullptr};
+    // A size of -1 wraps to SIZE_MAX, which is not a multiple of the allocation
+    // granularity, so cuMemCreate fails and TLLM_CU_CHECK must throw.
+    ASSERT_THROW(TLLM_CU_CHECK(cuMemCreate(&handle, -1, &prop, 0ULL)), tensorrt_llm::common::TllmException);
+}
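Reviewer note: the sketch below is not part of the patch; it is a minimal, self-contained illustration of what the rename changes at a call site. The `main` harness and the deliberately invalid `cuDeviceGet` call are assumptions for demonstration only, and the exact message text depends on the installed driver.

```cpp
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/tllmException.h"

#include <cuda.h>

#include <iostream>

int main()
{
    try
    {
        // Device ordinal -1 is always invalid, so the driver call fails and
        // TLLM_CU_CHECK forwards the CUresult to checkDriver(), which now looks
        // up both cuGetErrorName() and cuGetErrorString() to build the message.
        CUdevice device{};
        TLLM_CU_CHECK(cuDeviceGet(&device, -1));
    }
    catch (tensorrt_llm::common::TllmException const& e)
    {
        // With an initialized driver this prints something like:
        // [TensorRT-LLM][ERROR] CUDA driver error in cuDeviceGet(&device, -1):
        // CUDA_ERROR_INVALID_DEVICE: invalid device ordinal.
        std::cout << e.what() << '\n';
        return 0;
    }
    return 1; // Unreachable if the failed check throws as expected.
}
```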