mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][fix] [fix] Make NCCL resource manager destructor exception-safe (#10166)
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
This commit is contained in:
parent
865992b86b
commit
59045a0e41
@ -37,6 +37,46 @@ NcclCommResourceManager& NcclCommResourceManager::getInstance() noexcept
|
||||
return instance;
|
||||
}
|
||||
|
||||
/// Destructor: proactively runs every cleanup callback that is still registered.
///
/// Sets mIsDestroying before doing anything else so that cleanupResources() calls issued
/// while the singleton is going away return early instead of touching the manager.
///
/// This function never lets an exception escape: destructors are implicitly noexcept, so a
/// throw from the mutex, from a container operation, or from a cleanup callback would
/// otherwise end in std::terminate.
///
/// Note: ncclCommDestroy is not called here - that remains the responsibility of the
/// communicator's shared_ptr deleter. Only the registered resources are cleaned up.
NcclCommResourceManager::~NcclCommResourceManager()
{
    // Mark that we're in destruction to prevent cleanup attempts from deleters
    // that may run during static destruction.
    mIsDestroying.store(true, std::memory_order_release);

    try
    {
        // Take ownership of every registered resource by swapping the whole map out.
        // Swapping (instead of copying entries into a freshly allocated vector) needs no
        // allocation while the lock is held and keeps the map's iteration order.
        decltype(mCommResources) pendingResources;
        {
            std::lock_guard<std::mutex> lock(mMutex);
            pendingResources.swap(mCommResources);
        }

        // Run the callbacks outside the lock so a callback that calls back into the
        // manager cannot deadlock on mMutex.
        for (auto& commEntry : pendingResources)
        {
            for (auto& resource : commEntry.second)
            {
                auto& cleanup = resource.first;
                try
                {
                    cleanup();
                }
                catch (...)
                {
                    // Ignore exceptions during destruction; keep going so the remaining
                    // resources still get their cleanup attempt.
                }
            }
        }
    }
    catch (...)
    {
        // Locking the mutex or handling the container failed (possible during static
        // destruction). Give up quietly rather than terminating the process, mirroring
        // how cleanupResources() treats a failing mutex.
    }
}
|
||||
|
||||
void NcclCommResourceManager::registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName)
|
||||
{
|
||||
if (!comm)
|
||||
@ -60,23 +100,56 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if we're in the process of being destroyed
|
||||
// If so, skip cleanup - the destructor will handle it proactively
|
||||
if (mIsDestroying.load(std::memory_order_acquire))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<ResourceEntry> resourcesToClean;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
auto it = mCommResources.find(comm);
|
||||
if (it == mCommResources.end())
|
||||
// During static destruction, mutex and logging may not be safe.
|
||||
// Use try-catch to handle any issues gracefully.
|
||||
try
|
||||
{
|
||||
// Nothing registered for this comm, nothing to clean up
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
|
||||
// Double-check after acquiring lock (destruction may have started)
|
||||
if (mIsDestroying.load(std::memory_order_acquire))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto it = mCommResources.find(comm);
|
||||
if (it == mCommResources.end())
|
||||
{
|
||||
// Nothing registered for this comm, nothing to clean up
|
||||
return;
|
||||
}
|
||||
|
||||
// Move resources out (preserves order) and remove from map
|
||||
resourcesToClean = std::move(it->second);
|
||||
mCommResources.erase(it);
|
||||
|
||||
// Logging may fail during static destruction, so wrap in try-catch
|
||||
try
|
||||
{
|
||||
TLLM_LOG_TRACE("[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(),
|
||||
static_cast<void*>(comm));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// If mutex access fails during static destruction, just return.
|
||||
// This prevents segfaults when the singleton is being destroyed.
|
||||
return;
|
||||
}
|
||||
|
||||
// Move resources out (preserves order) and remove from map
|
||||
resourcesToClean = std::move(it->second);
|
||||
mCommResources.erase(it);
|
||||
|
||||
TLLM_LOG_TRACE(
|
||||
"[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(), static_cast<void*>(comm));
|
||||
}
|
||||
|
||||
// Clean up outside the lock to avoid deadlocks if cleanup functions try to access the manager
|
||||
@ -85,19 +158,41 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
|
||||
{
|
||||
try
|
||||
{
|
||||
TLLM_LOG_TRACE(
|
||||
"[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
|
||||
// Logging may fail during static destruction, so wrap in try-catch
|
||||
try
|
||||
{
|
||||
TLLM_LOG_TRACE(
|
||||
"[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
cleanup();
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s", name.c_str(),
|
||||
static_cast<void*>(comm), e.what());
|
||||
try
|
||||
{
|
||||
TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s",
|
||||
name.c_str(), static_cast<void*>(comm), e.what());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
|
||||
name.c_str(), static_cast<void*>(comm));
|
||||
try
|
||||
{
|
||||
TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
|
||||
name.c_str(), static_cast<void*>(comm));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
@ -139,12 +140,13 @@ public:
|
||||
|
||||
private:
|
||||
NcclCommResourceManager() = default;
|
||||
~NcclCommResourceManager() = default;
|
||||
~NcclCommResourceManager();
|
||||
|
||||
using ResourceEntry = std::pair<ResourceCleanupFunc, std::string>;
|
||||
|
||||
mutable std::mutex mMutex;
|
||||
std::unordered_map<ncclComm_t, std::vector<ResourceEntry>> mCommResources;
|
||||
std::atomic<bool> mIsDestroying{false};
|
||||
};
|
||||
|
||||
// RAII helper to register a resource with a NCCL communicator.
|
||||
|
||||
@ -123,13 +123,24 @@ std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
|
||||
if (*comm)
|
||||
{
|
||||
// Clean up all registered resources FIRST
|
||||
// The cleanupResources function uses a destruction guard to safely handle
|
||||
// static destruction order issues - it will return early if the singleton
|
||||
// is being destroyed (in which case the destructor handles cleanup proactively)
|
||||
tensorrt_llm::common::nccl_util::NcclCommResourceManager::getInstance().cleanupResources(*comm);
|
||||
|
||||
// Now destroy the NCCL communicator
|
||||
ncclResult_t result = ncclCommDestroy(*comm);
|
||||
if (result != ncclSuccess)
|
||||
{
|
||||
TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
|
||||
// Logging may fail during static destruction, so wrap in try-catch
|
||||
try
|
||||
{
|
||||
TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the communicator value before freeing the pointer
|
||||
|
||||
Loading…
Reference in New Issue
Block a user