mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
586 lines
18 KiB
C++
586 lines
18 KiB
C++
/*
|
|
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "tensorrt_llm/common/ncclUtils.h"
|
|
|
|
#if ENABLE_MULTI_DEVICE
|
|
|
|
#include "tensorrt_llm/common/assert.h"
|
|
#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tensorrt_llm/common/logger.h"
|
|
#include <limits>
|
|
#include <stdexcept>
|
|
|
|
namespace tensorrt_llm::common::nccl_util
|
|
{
|
|
|
|
//==============================================================================
|
|
// NcclCommResourceManager Implementation
|
|
//==============================================================================
|
|
|
|
NcclCommResourceManager& NcclCommResourceManager::getInstance() noexcept
|
|
{
|
|
static NcclCommResourceManager instance;
|
|
return instance;
|
|
}
|
|
|
|
void NcclCommResourceManager::registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName)
|
|
{
|
|
if (!comm)
|
|
{
|
|
TLLM_LOG_WARNING("[NCCLUtil] Attempted to register resource for null NCCL comm");
|
|
return;
|
|
}
|
|
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
auto& resources = mCommResources[comm];
|
|
resources.emplace_back(std::move(cleanup), debugName ? debugName : "unnamed");
|
|
|
|
TLLM_LOG_TRACE("[NCCLUtil] Registered resource '%s' for NCCL comm %p (total: %zu)",
|
|
debugName ? debugName : "unnamed", static_cast<void*>(comm), resources.size());
|
|
}
|
|
|
|
void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
|
|
{
|
|
if (!comm)
|
|
{
|
|
return;
|
|
}
|
|
|
|
std::vector<ResourceEntry> resourcesToClean;
|
|
|
|
{
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
auto it = mCommResources.find(comm);
|
|
if (it == mCommResources.end())
|
|
{
|
|
// Nothing registered for this comm, nothing to clean up
|
|
return;
|
|
}
|
|
|
|
// Move resources out (preserves order) and remove from map
|
|
resourcesToClean = std::move(it->second);
|
|
mCommResources.erase(it);
|
|
|
|
TLLM_LOG_TRACE(
|
|
"[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(), static_cast<void*>(comm));
|
|
}
|
|
|
|
// Clean up outside the lock to avoid deadlocks if cleanup functions try to access the manager
|
|
// Order is preserved: resources are cleaned up in registration order
|
|
for (auto& [cleanup, name] : resourcesToClean)
|
|
{
|
|
try
|
|
{
|
|
TLLM_LOG_TRACE(
|
|
"[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
|
|
cleanup();
|
|
}
|
|
catch (std::exception const& e)
|
|
{
|
|
TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s", name.c_str(),
|
|
static_cast<void*>(comm), e.what());
|
|
}
|
|
catch (...)
|
|
{
|
|
TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
|
|
name.c_str(), static_cast<void*>(comm));
|
|
}
|
|
}
|
|
}
|
|
|
|
bool NcclCommResourceManager::hasResources(ncclComm_t comm) const noexcept
|
|
{
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
return mCommResources.find(comm) != mCommResources.end();
|
|
}
|
|
|
|
size_t NcclCommResourceManager::getResourceCount(ncclComm_t comm) const noexcept
|
|
{
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
auto it = mCommResources.find(comm);
|
|
return it != mCommResources.end() ? it->second.size() : 0;
|
|
}
|
|
|
|
//==============================================================================
|
|
// NCCLHelper Implementation
|
|
//==============================================================================
|
|
|
|
NCCLHelper& NCCLHelper::getInstance()
|
|
{
|
|
static NCCLHelper instance;
|
|
return instance;
|
|
}
|
|
|
|
NCCLHelper::NCCLHelper()
|
|
: mLibraryHandle(nullptr)
|
|
, mNCCLCommWindowRegister(nullptr)
|
|
, mNCCLMemAlloc(nullptr)
|
|
, mIsLoaded(false)
|
|
{
|
|
loadNCCLLibrary();
|
|
}
|
|
|
|
NCCLHelper::~NCCLHelper()
|
|
{
|
|
if (mLibraryHandle)
|
|
{
|
|
#ifdef _WIN32
|
|
FreeLibrary(mLibraryHandle);
|
|
#else
|
|
dlclose(mLibraryHandle);
|
|
#endif
|
|
mLibraryHandle = nullptr;
|
|
}
|
|
}
|
|
|
|
void NCCLHelper::loadNCCLLibrary()
|
|
{
|
|
try
|
|
{
|
|
#ifdef _WIN32
|
|
char const* libraryNames[] = {"nccl.dll"};
|
|
#else
|
|
char const* libraryNames[] = {"libnccl.so"};
|
|
#endif
|
|
|
|
for (auto const* name : libraryNames)
|
|
{
|
|
mLibraryHandle = loadLibraryHandle(name);
|
|
if (mLibraryHandle)
|
|
{
|
|
TLLM_LOG_INFO("Successfully loaded NCCL library: %s", name);
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!mLibraryHandle)
|
|
{
|
|
TLLM_LOG_WARNING("Failed to load NCCL library");
|
|
return;
|
|
}
|
|
|
|
// Load the required symbols
|
|
mNCCLCommWindowRegister
|
|
= reinterpret_cast<ncclCommWindowRegisterFunc>(getSymbolAddress(mLibraryHandle, "ncclCommWindowRegister"));
|
|
|
|
mNCCLMemAlloc = reinterpret_cast<ncclMemAllocFunc>(getSymbolAddress(mLibraryHandle, "ncclMemAlloc"));
|
|
|
|
if (mNCCLCommWindowRegister == nullptr)
|
|
{
|
|
TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported.");
|
|
}
|
|
|
|
if (mNCCLMemAlloc == nullptr)
|
|
{
|
|
TLLM_LOG_WARNING("Failed to load ncclMemAlloc symbol, NCCL symmetric will not be supported.");
|
|
}
|
|
|
|
if (mNCCLCommWindowRegister != nullptr && mNCCLMemAlloc != nullptr)
|
|
{
|
|
mIsLoaded = true;
|
|
}
|
|
else
|
|
{
|
|
TLLM_LOG_WARNING(
|
|
"Failed to load required NCCL symbols (both ncclCommWindowRegister and ncclMemAlloc are required)");
|
|
}
|
|
}
|
|
catch (std::exception const& e)
|
|
{
|
|
TLLM_LOG_WARNING("Exception while loading NCCL library: %s", e.what());
|
|
}
|
|
}
|
|
|
|
void* NCCLHelper::loadLibraryHandle(char const* libName)
|
|
{
|
|
#ifdef _WIN32
|
|
return LoadLibraryA(libName);
|
|
#else
|
|
return dlopen(libName, RTLD_LAZY | RTLD_GLOBAL);
|
|
#endif
|
|
}
|
|
|
|
void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName)
|
|
{
|
|
if (!handle)
|
|
{
|
|
return nullptr;
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
return GetProcAddress(static_cast<HMODULE>(handle), symbolName);
|
|
#else
|
|
return dlsym(handle, symbolName);
|
|
#endif
|
|
}
|
|
|
|
NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister()
|
|
{
|
|
return mNCCLCommWindowRegister;
|
|
}
|
|
|
|
NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc()
|
|
{
|
|
return mNCCLMemAlloc;
|
|
}
|
|
|
|
bool NCCLHelper::isLoaded() const
|
|
{
|
|
return mIsLoaded;
|
|
}
|
|
|
|
//==============================================================================
|
|
// NCCLWindowAllocator Implementation
|
|
//==============================================================================
|
|
|
|
NCCLWindowAllocator& NCCLWindowAllocator::getInstance()
|
|
{
|
|
static NCCLWindowAllocator instance;
|
|
return instance;
|
|
}
|
|
|
|
NCCLWindowBuffer NCCLWindowAllocator::requestBuffer(ncclComm_t comm, size_t size)
|
|
{
|
|
TLLM_CHECK_WITH_INFO(comm != nullptr, "NCCL communicator cannot be null");
|
|
TLLM_CHECK_WITH_INFO(size > 0, "Buffer size must be greater than 0");
|
|
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
|
|
// Register cleanup callback for this communicator if not already registered
|
|
// This is cheap even if no buffers exist yet - cleanup will just return early
|
|
registerBufferCleanup(comm);
|
|
|
|
// Check if we have an available buffer of at least the requested size for this communicator
|
|
// Use best-fit: find the smallest buffer that's >= requested size
|
|
auto& commBuffers = mBufferPool[comm];
|
|
auto bestFit = commBuffers.end();
|
|
size_t bestFitSize = std::numeric_limits<size_t>::max();
|
|
|
|
for (auto it = commBuffers.begin(); it != commBuffers.end(); ++it)
|
|
{
|
|
if (!it->inUse && it->buffer.size >= size && it->buffer.size < bestFitSize)
|
|
{
|
|
bestFit = it;
|
|
bestFitSize = it->buffer.size;
|
|
}
|
|
}
|
|
|
|
if (bestFit != commBuffers.end())
|
|
{
|
|
bestFit->inUse = true;
|
|
TLLM_LOG_TRACE(
|
|
"[NCCLUtil] Reusing NCCL window buffer for comm %p: handle=%d, ptr=%p, size=%zu (requested: %zu)",
|
|
static_cast<void*>(comm), bestFit->buffer.handle, bestFit->buffer.ptr, bestFit->buffer.size, size);
|
|
return bestFit->buffer;
|
|
}
|
|
|
|
// No available buffer found, allocate a new one
|
|
TLLM_LOG_TRACE(
|
|
"[NCCLUtil] Allocating new NCCL window buffer for comm %p, size=%zu", static_cast<void*>(comm), size);
|
|
int handle = static_cast<int>(commBuffers.size());
|
|
NCCLWindowBuffer buffer = allocateAndRegisterBuffer(comm, size, handle);
|
|
commBuffers.push_back({buffer, true});
|
|
|
|
return buffer;
|
|
}
|
|
|
|
NCCLWindowBuffer NCCLWindowAllocator::searchBuffer(ncclComm_t comm, void* ptr) const
|
|
{
|
|
if (!comm || !ptr)
|
|
{
|
|
return NCCLWindowBuffer();
|
|
}
|
|
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
return searchBufferLocked(comm, ptr);
|
|
}
|
|
|
|
void NCCLWindowAllocator::releaseBuffer(ncclComm_t comm, void* ptr)
|
|
{
|
|
if (!comm || !ptr)
|
|
{
|
|
return;
|
|
}
|
|
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
auto commIt = mBufferPool.find(comm);
|
|
if (commIt == mBufferPool.end())
|
|
{
|
|
TLLM_LOG_WARNING(
|
|
"[NCCLUtil] Attempted to release buffer %p for unknown comm %p", ptr, static_cast<void*>(comm));
|
|
return;
|
|
}
|
|
|
|
for (auto& entry : commIt->second)
|
|
{
|
|
if (entry.buffer.ptr == ptr)
|
|
{
|
|
entry.inUse = false;
|
|
TLLM_LOG_TRACE("[NCCLUtil] Released NCCL window buffer for comm %p: ptr=%p", static_cast<void*>(comm), ptr);
|
|
return;
|
|
}
|
|
}
|
|
|
|
TLLM_LOG_WARNING("[NCCLUtil] Attempted to release unknown buffer %p for comm %p", ptr, static_cast<void*>(comm));
|
|
}
|
|
|
|
ncclWindow_t NCCLWindowAllocator::getWindow(ncclComm_t comm, void* ptr) const
|
|
{
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
NCCLWindowBuffer buffer = searchBufferLocked(comm, ptr);
|
|
return buffer.isValid() ? buffer.window : nullptr;
|
|
}
|
|
|
|
size_t NCCLWindowAllocator::getSize(ncclComm_t comm, void* ptr) const
|
|
{
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
NCCLWindowBuffer buffer = searchBufferLocked(comm, ptr);
|
|
return buffer.isValid() ? buffer.size : 0;
|
|
}
|
|
|
|
NCCLWindowBuffer NCCLWindowAllocator::getBufferInfo(ncclComm_t comm, void* ptr) const
|
|
{
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
return searchBufferLocked(comm, ptr);
|
|
}
|
|
|
|
size_t NCCLWindowAllocator::getBufferCount(ncclComm_t comm) const
|
|
{
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
auto commIt = mBufferPool.find(comm);
|
|
return commIt != mBufferPool.end() ? commIt->second.size() : 0;
|
|
}
|
|
|
|
size_t NCCLWindowAllocator::getBufferInUseCount(ncclComm_t comm) const
|
|
{
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
auto commIt = mBufferPool.find(comm);
|
|
if (commIt == mBufferPool.end())
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
size_t count = 0;
|
|
for (auto const& entry : commIt->second)
|
|
{
|
|
if (entry.inUse)
|
|
{
|
|
++count;
|
|
}
|
|
}
|
|
return count;
|
|
}
|
|
|
|
bool NCCLWindowAllocator::isCommValid(ncclComm_t comm) const noexcept
|
|
{
|
|
// Simply check for null - all non-null comms are valid
|
|
// We don't track cleaned-up comms because NCCL can reuse memory addresses,
|
|
// making pointer-based tracking unreliable. New comms will be registered when used.
|
|
return comm != nullptr;
|
|
}
|
|
|
|
NCCLWindowBuffer NCCLWindowAllocator::allocateAndRegisterBuffer(ncclComm_t comm, size_t size, int handle)
|
|
{
|
|
NCCLWindowBuffer buffer;
|
|
buffer.handle = handle;
|
|
|
|
// Get NCCL helper for dynamic symbol loading
|
|
auto& ncclHelper = NCCLHelper::getInstance();
|
|
if (!ncclHelper.isLoaded())
|
|
{
|
|
TLLM_THROW("NCCL library could not be loaded for dynamic symbol access");
|
|
}
|
|
|
|
auto ncclMemAllocFunc = ncclHelper.getNCCLMemAlloc();
|
|
auto ncclCommWindowRegisterFunc = ncclHelper.getNCCLCommWindowRegister();
|
|
|
|
// Defensive checks: both function pointers must be non-null
|
|
if (ncclMemAllocFunc == nullptr)
|
|
{
|
|
TLLM_THROW("ncclMemAlloc function pointer is null, cannot allocate NCCL window buffer");
|
|
}
|
|
|
|
if (ncclCommWindowRegisterFunc == nullptr)
|
|
{
|
|
TLLM_THROW("ncclCommWindowRegister function pointer is null, cannot register NCCL window buffer");
|
|
}
|
|
|
|
// Allocate device memory using ncclMemAlloc
|
|
ncclResult_t allocResult = ncclMemAllocFunc(&buffer.ptr, size);
|
|
if (allocResult != ncclSuccess)
|
|
{
|
|
TLLM_THROW("ncclMemAlloc failed with error: %d", allocResult);
|
|
}
|
|
buffer.size = size;
|
|
|
|
// Register the buffer with NCCL as a window
|
|
ncclResult_t regResult
|
|
= ncclCommWindowRegisterFunc(comm, buffer.ptr, size, &buffer.window, NCCL_WIN_COLL_SYMMETRIC);
|
|
if (regResult != ncclSuccess)
|
|
{
|
|
ncclMemFree(buffer.ptr);
|
|
TLLM_THROW("ncclCommWindowRegister failed with error: %d", regResult);
|
|
}
|
|
|
|
TLLM_LOG_TRACE("[NCCLUtil] Allocated and registered NCCL window buffer: handle=%d, ptr=%p, size=%zu, window=%p",
|
|
handle, buffer.ptr, size, static_cast<void*>(buffer.window));
|
|
|
|
return buffer;
|
|
}
|
|
|
|
NCCLWindowBuffer NCCLWindowAllocator::searchBufferLocked(ncclComm_t comm, void* ptr) const
|
|
{
|
|
auto commIt = mBufferPool.find(comm);
|
|
if (commIt == mBufferPool.end())
|
|
{
|
|
return NCCLWindowBuffer();
|
|
}
|
|
|
|
for (auto const& entry : commIt->second)
|
|
{
|
|
if (entry.buffer.ptr == ptr)
|
|
{
|
|
return entry.buffer;
|
|
}
|
|
}
|
|
|
|
return NCCLWindowBuffer();
|
|
}
|
|
|
|
void NCCLWindowAllocator::registerBufferCleanup(ncclComm_t comm)
|
|
{
|
|
// Don't register if already registered
|
|
if (mRegisteredComms.find(comm) != mRegisteredComms.end())
|
|
{
|
|
return;
|
|
}
|
|
|
|
mRegisteredComms.insert(comm);
|
|
|
|
// Register cleanup with the resource manager
|
|
NcclCommResourceManager::getInstance().registerResource(
|
|
comm, [this, comm]() { this->cleanupBuffersForComm(comm); }, "NCCLWindowAllocator");
|
|
}
|
|
|
|
void NCCLWindowAllocator::cleanupBuffersForComm(ncclComm_t comm) noexcept
|
|
{
|
|
if (!comm)
|
|
{
|
|
return;
|
|
}
|
|
|
|
// Synchronize CUDA to ensure all operations using these buffers are complete
|
|
// before we deregister windows and free memory
|
|
cudaError_t cudaErr = cudaDeviceSynchronize();
|
|
if (cudaErr != cudaSuccess)
|
|
{
|
|
TLLM_LOG_WARNING("[NCCLUtil] cudaDeviceSynchronize failed with error: %d before cleanup for comm %p", cudaErr,
|
|
static_cast<void*>(comm));
|
|
// Continue anyway - the sync failure might be from a previous error
|
|
}
|
|
|
|
std::lock_guard<std::mutex> lock(mMutex);
|
|
|
|
// Check if we've already cleaned up this communicator
|
|
if (mRegisteredComms.find(comm) == mRegisteredComms.end())
|
|
{
|
|
// Already cleaned up or never registered
|
|
return;
|
|
}
|
|
|
|
auto commIt = mBufferPool.find(comm);
|
|
if (commIt == mBufferPool.end())
|
|
{
|
|
// No buffers to clean up, but mark as cleaned
|
|
mRegisteredComms.erase(comm);
|
|
return;
|
|
}
|
|
|
|
TLLM_LOG_TRACE(
|
|
"[NCCLUtil] Cleaning up %zu NCCL window buffers for comm %p", commIt->second.size(), static_cast<void*>(comm));
|
|
|
|
// Check for buffers still in use - this shouldn't happen if cleanup is called properly,
|
|
// but we log a warning if it does
|
|
size_t inUseCount = 0;
|
|
for (auto const& entry : commIt->second)
|
|
{
|
|
if (entry.inUse)
|
|
{
|
|
++inUseCount;
|
|
}
|
|
}
|
|
if (inUseCount > 0)
|
|
{
|
|
TLLM_LOG_WARNING(
|
|
"[NCCLUtil] Cleaning up %zu buffers still marked as in-use for comm %p. "
|
|
"This may indicate buffers weren't properly released before cleanup.",
|
|
inUseCount, static_cast<void*>(comm));
|
|
}
|
|
|
|
for (auto& entry : commIt->second)
|
|
{
|
|
if (entry.buffer.isValid())
|
|
{
|
|
// Deregister the window - the communicator is still valid at this point
|
|
// (cleanup happens before ncclCommDestroy), but we need to be careful
|
|
// if buffers are still in use by active operations
|
|
if (entry.buffer.window && comm)
|
|
{
|
|
// Note: Even if buffer is marked inUse, we must deregister since
|
|
// the communicator is being destroyed. The communicator is valid,
|
|
// but we should handle potential errors gracefully.
|
|
ncclResult_t result = ncclCommWindowDeregister(comm, entry.buffer.window);
|
|
if (result != ncclSuccess)
|
|
{
|
|
TLLM_LOG_WARNING(
|
|
"[NCCLUtil] ncclCommWindowDeregister failed with error: %d for comm %p, "
|
|
"window %p (buffer inUse: %d)",
|
|
result, static_cast<void*>(comm), static_cast<void*>(entry.buffer.window), entry.inUse);
|
|
}
|
|
}
|
|
|
|
// Free device memory using ncclMemFree
|
|
// This should be safe even if deregister failed
|
|
if (entry.buffer.ptr)
|
|
{
|
|
try
|
|
{
|
|
ncclResult_t ncclResult = ncclMemFree(entry.buffer.ptr);
|
|
if (ncclResult != ncclSuccess)
|
|
{
|
|
TLLM_LOG_WARNING("[NCCLUtil] ncclMemFree failed with error: %d", ncclResult);
|
|
}
|
|
}
|
|
catch (...)
|
|
{
|
|
TLLM_LOG_ERROR("[NCCLUtil] Exception during ncclMemFree for ptr %p", entry.buffer.ptr);
|
|
}
|
|
}
|
|
|
|
TLLM_LOG_TRACE(
|
|
"[NCCLUtil] Freed NCCL window buffer: ptr=%p, size=%zu", entry.buffer.ptr, entry.buffer.size);
|
|
}
|
|
}
|
|
|
|
mBufferPool.erase(commIt);
|
|
mRegisteredComms.erase(comm);
|
|
}
|
|
|
|
} // namespace tensorrt_llm::common::nccl_util
|
|
|
|
#endif // ENABLE_MULTI_DEVICE
|