fix: segfault in cudaDriverWrapper (#3017)

* fix segmentation fault in cudaDriverWrapper

Signed-off-by: jdebache <jdebache@nvidia.com>

* replace cuGetErrorMessage with cuGetErrorString and added tests

Signed-off-by: jdebache <jdebache@nvidia.com>

---------

Signed-off-by: jdebache <jdebache@nvidia.com>
This commit is contained in:
Julien Debache 2025-04-02 08:55:19 +02:00 committed by GitHub
parent 8d48b96545
commit 76a6a62073
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 89 additions and 32 deletions

View File

@ -28,11 +28,14 @@
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include <cuda.h>
#include <cstdio>
#include <mutex>
namespace tensorrt_llm::common
{
@ -46,7 +49,7 @@ std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
return result;
}
std::lock_guard<std::mutex> lock(mutex);
std::lock_guard<std::mutex> const lock(mutex);
result = instance.lock();
if (!result)
{
@ -69,7 +72,7 @@ CUDADriverWrapper::CUDADriverWrapper()
};
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
*reinterpret_cast<void**>(&_cuGetErrorString) = load_sym(handle, "cuGetErrorString");
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
@ -98,9 +101,9 @@ CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) co
return (*_cuGetErrorName)(error, pStr);
}
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
CUresult CUDADriverWrapper::cuGetErrorString(CUresult error, char const** pStr) const
{
return (*_cuGetErrorMessage)(error, pStr);
return (*_cuGetErrorString)(error, pStr);
}
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const

View File

@ -17,11 +17,13 @@
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include <cuda.h>
#include <cstdio>
#include <memory>
#include <mutex>
namespace tensorrt_llm::common
{
@ -39,7 +41,7 @@ public:
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
CUresult cuGetErrorString(CUresult error, char const** pStr) const;
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
@ -88,7 +90,7 @@ private:
CUDADriverWrapper();
CUresult (*_cuGetErrorName)(CUresult, char const**);
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
CUresult (*_cuGetErrorString)(CUresult, char const**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule);
@ -121,11 +123,11 @@ void checkDriver(
if (result)
{
char const* errorName = nullptr;
char const* errorMsg = nullptr;
char const* errorString = nullptr;
wrap.cuGetErrorName(result, &errorName);
wrap.cuGetErrorMessage(result, &errorMsg);
wrap.cuGetErrorString(result, &errorString);
throw TllmException(
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s.", func, errorName, errorString));
}
}

View File

@ -19,14 +19,12 @@
#include "compileEngine.h"
#include "serializationUtils.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
#include <functional>
#include <mutex>
#include <unordered_map>
namespace tensorrt_llm
{
namespace kernels
{
namespace jit
namespace tensorrt_llm::kernels::jit
{
// A thread-safe collection of CubinObjs, with caching functionality.
@ -39,19 +37,19 @@ public:
CubinObjRegistryTemplate(void const* buffer_, size_t buffer_size)
{
size_t remaining_buffer_size = buffer_size;
uint8_t const* buffer = static_cast<uint8_t const*>(buffer_);
auto const* buffer = static_cast<uint8_t const*>(buffer_);
// First 4 bytes: num of elements.
uint32_t n = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
auto const n = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
for (uint32_t i = 0; i < n; ++i)
{
uint32_t key_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
auto key_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
TLLM_CHECK(key_size <= remaining_buffer_size);
Key key(buffer, key_size);
buffer += key_size;
remaining_buffer_size -= key_size;
uint32_t obj_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
auto obj_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
TLLM_CHECK(obj_size <= remaining_buffer_size);
CubinObj obj(buffer, obj_size);
buffer += obj_size;
@ -89,7 +87,7 @@ public:
{
std::lock_guard<std::mutex> lock(mMutex);
size_t remaining_buffer_size = buffer_size;
uint8_t* buffer = static_cast<uint8_t*>(buffer_);
auto* buffer = static_cast<uint8_t*>(buffer_);
uint32_t n = mMap.size();
writeToBuffer<uint32_t>(n, buffer, remaining_buffer_size);
for (auto&& p : mMap)
@ -131,7 +129,6 @@ public:
obj.initialize();
}
mMap.insert({key, std::move(obj)});
return;
}
CubinObj* getCubin(Key const& key)
@ -142,10 +139,8 @@ public:
{
return &iter->second;
}
else
{
return nullptr;
}
return nullptr;
}
// When initialize is true, initialize cubins.
@ -178,6 +173,4 @@ using CubinObjKey = XQAKernelFullHashKey;
using CubinObjHasher = XQAKernelFullHasher;
using CubinObjRegistry = CubinObjRegistryTemplate<CubinObjKey, CubinObjHasher>;
} // namespace jit
} // namespace kernels
} // namespace tensorrt_llm
} // namespace tensorrt_llm::kernels::jit

View File

@ -18,3 +18,4 @@ add_gtest(stlUtilsTest stlUtilsTest.cpp)
add_gtest(stringUtilsTest stringUtilsTest.cpp)
add_gtest(timestampUtilsTest timestampUtilsTest.cpp)
add_gtest(tllmExceptionTest tllmExceptionTest.cpp)
add_gtest(cudaDriverWrapperTest cudaDriverWrapperTest.cpp)

View File

@ -0,0 +1,58 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
// TLLM_CU_CHECK wrapping driver calls that succeed (cuMemCreate/cuMemRelease
// with valid arguments) must not throw.
// NOTE(review): the test name says "Failing" but this is the success-path
// test — name kept for filter compatibility.
TEST(TestCudaDriverWrapper, TllmCuCheckFailingWithValidParametersDoesNotThrow)
{
    if (tensorrt_llm::common::getDeviceCount() == 0)
    {
        GTEST_SKIP() << "No CUDA devices found";
    }
    // Pinned device allocation on device 0, no exportable handle type.
    CUmemAllocationProp const allocProp{CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED,
        CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_NONE,
        CUmemLocation{
            CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
            0,
        },
        nullptr};
    // cuMemCreate requires a size that is a multiple of the allocation
    // granularity, hence the granularity * 16 request.
    auto const allocSize = tensorrt_llm::common::getAllocationGranularity() * 16;
    CUmemGenericAllocationHandle allocHandle{};
    ASSERT_NO_THROW(TLLM_CU_CHECK(cuMemCreate(&allocHandle, allocSize, &allocProp, 0)));
    ASSERT_NO_THROW(TLLM_CU_CHECK(cuMemRelease(allocHandle)));
}
// TLLM_CU_CHECK wrapping a driver call that fails (cuMemCreate with an
// invalid size) must surface the CUresult as a TllmException.
TEST(TestCudaDriverWrapper, TllmCuCheckFailingWithInvalidParametersThrows)
{
    if (tensorrt_llm::common::getDeviceCount() == 0)
    {
        GTEST_SKIP() << "No CUDA devices found";
    }
    // Same pinned device-0 allocation properties as the success-path test.
    CUmemAllocationProp const allocProp{CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED,
        CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_NONE,
        CUmemLocation{
            CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
            0,
        },
        nullptr};
    CUmemGenericAllocationHandle allocHandle{};
    // -1 converts to SIZE_MAX for the size_t parameter — presumably no device
    // can satisfy it, so cuMemCreate reports an error and the macro throws.
    ASSERT_THROW(TLLM_CU_CHECK(cuMemCreate(&allocHandle, -1, &allocProp, 0ULL)), tensorrt_llm::common::TllmException);
}