TensorRT-LLMs/cpp/tensorrt_llm/common/ncclUtils.cpp
Ludwig Schneider 41ce14ab04
[None][feat] Enable NCCL_SYMMETRIC as default fallback for AllReduce (#9314)
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
2025-12-07 09:43:26 -08:00

586 lines
18 KiB
C++

/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/ncclUtils.h"
#if ENABLE_MULTI_DEVICE
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <limits>
#include <stdexcept>
namespace tensorrt_llm::common::nccl_util
{
//==============================================================================
// NcclCommResourceManager Implementation
//==============================================================================
// Returns the single NcclCommResourceManager, which lives as a function-local static.
NcclCommResourceManager& NcclCommResourceManager::getInstance() noexcept
{
    static NcclCommResourceManager sManager;
    return sManager;
}
// Attaches a cleanup callback to an NCCL communicator.
//
// comm:      communicator the resource belongs to; a null comm is rejected with a warning.
// cleanup:   callback stored for later invocation by cleanupResources().
// debugName: optional label used in log messages; "unnamed" is substituted when it is null.
void NcclCommResourceManager::registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName)
{
    if (comm == nullptr)
    {
        TLLM_LOG_WARNING("[NCCLUtil] Attempted to register resource for null NCCL comm");
        return;
    }

    // Resolve the label once so the stored name and the trace message always agree.
    char const* label = debugName ? debugName : "unnamed";

    std::lock_guard<std::mutex> guard(mMutex);
    // operator[] creates the per-communicator list on the first registration.
    auto& entries = mCommResources[comm];
    entries.emplace_back(std::move(cleanup), label);
    TLLM_LOG_TRACE("[NCCLUtil] Registered resource '%s' for NCCL comm %p (total: %zu)", label,
        static_cast<void*>(comm), entries.size());
}
// Runs and discards every cleanup callback registered for `comm`.
//
// The communicator's entry is removed from the map under the lock; the callbacks themselves
// run afterwards, without the lock, in the order they were registered. Exceptions thrown by
// a callback are logged and do not stop the remaining callbacks from running.
// A null comm, or a comm with nothing registered, is a no-op.
void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
{
    if (comm == nullptr)
    {
        return;
    }

    std::vector<ResourceEntry> pending;
    {
        std::lock_guard<std::mutex> guard(mMutex);
        auto found = mCommResources.find(comm);
        if (found == mCommResources.end())
        {
            // No registration exists for this communicator.
            return;
        }
        // Take ownership of the list (order intact) and drop the map entry.
        pending = std::move(found->second);
        mCommResources.erase(found);
        TLLM_LOG_TRACE(
            "[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", pending.size(), static_cast<void*>(comm));
    }

    // The lock is released here so a callback may call back into the manager without deadlocking.
    for (auto& entry : pending)
    {
        auto& [cleanupFn, label] = entry;
        try
        {
            TLLM_LOG_TRACE(
                "[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", label.c_str(), static_cast<void*>(comm));
            cleanupFn();
        }
        catch (std::exception const& e)
        {
            TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s", label.c_str(),
                static_cast<void*>(comm), e.what());
        }
        catch (...)
        {
            TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
                label.c_str(), static_cast<void*>(comm));
        }
    }
}
// Reports whether the manager currently holds an entry for `comm`.
bool NcclCommResourceManager::hasResources(ncclComm_t comm) const noexcept
{
    std::lock_guard<std::mutex> guard(mMutex);
    return mCommResources.count(comm) != 0;
}
// Returns how many cleanup callbacks are registered for `comm`, or 0 when it has no entry.
size_t NcclCommResourceManager::getResourceCount(ncclComm_t comm) const noexcept
{
    std::lock_guard<std::mutex> guard(mMutex);
    auto const found = mCommResources.find(comm);
    if (found == mCommResources.end())
    {
        return 0;
    }
    return found->second.size();
}
//==============================================================================
// NCCLHelper Implementation
//==============================================================================
// Returns the single NCCLHelper, which lives as a function-local static.
// Its constructor attempts to load the NCCL library the first time this is called.
NCCLHelper& NCCLHelper::getInstance()
{
    static NCCLHelper sHelper;
    return sHelper;
}
NCCLHelper::NCCLHelper()
: mLibraryHandle(nullptr)
, mNCCLCommWindowRegister(nullptr)
, mNCCLMemAlloc(nullptr)
, mIsLoaded(false)
{
loadNCCLLibrary();
}
// Releases the dynamically loaded NCCL library, if one was opened.
NCCLHelper::~NCCLHelper()
{
    if (mLibraryHandle == nullptr)
    {
        return;
    }
#ifdef _WIN32
    // The handle is passed around as void* elsewhere in this file (see getSymbolAddress), and C++
    // has no implicit void* -> HMODULE conversion, so cast explicitly as getSymbolAddress does.
    FreeLibrary(static_cast<HMODULE>(mLibraryHandle));
#else
    dlclose(mLibraryHandle);
#endif
    mLibraryHandle = nullptr;
}
// Opens the NCCL shared library and resolves the symbols needed for NCCL symmetric windows.
//
// Each candidate library name is tried in order until one loads. mIsLoaded becomes true only
// when both ncclCommWindowRegister and ncclMemAlloc were resolved; every failure is reported
// as a warning and never thrown to the caller.
void NCCLHelper::loadNCCLLibrary()
{
    try
    {
#ifdef _WIN32
        char const* libraryNames[] = {"nccl.dll"};
#else
        // "libnccl.so" is the unversioned development symlink and is frequently absent on
        // runtime-only installations, so fall back to the versioned soname of NCCL 2.x.
        char const* libraryNames[] = {"libnccl.so", "libnccl.so.2"};
#endif
        for (auto const* name : libraryNames)
        {
            mLibraryHandle = loadLibraryHandle(name);
            if (mLibraryHandle)
            {
                TLLM_LOG_INFO("Successfully loaded NCCL library: %s", name);
                break;
            }
        }
        if (!mLibraryHandle)
        {
            TLLM_LOG_WARNING("Failed to load NCCL library");
            return;
        }

        // Resolve the optional entry points; older NCCL builds may not export them.
        mNCCLCommWindowRegister
            = reinterpret_cast<ncclCommWindowRegisterFunc>(getSymbolAddress(mLibraryHandle, "ncclCommWindowRegister"));
        mNCCLMemAlloc = reinterpret_cast<ncclMemAllocFunc>(getSymbolAddress(mLibraryHandle, "ncclMemAlloc"));

        if (mNCCLCommWindowRegister == nullptr)
        {
            TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported.");
        }
        if (mNCCLMemAlloc == nullptr)
        {
            TLLM_LOG_WARNING("Failed to load ncclMemAlloc symbol, NCCL symmetric will not be supported.");
        }

        if (mNCCLCommWindowRegister != nullptr && mNCCLMemAlloc != nullptr)
        {
            mIsLoaded = true;
        }
        else
        {
            TLLM_LOG_WARNING(
                "Failed to load required NCCL symbols (both ncclCommWindowRegister and ncclMemAlloc are required)");
        }
    }
    catch (std::exception const& e)
    {
        TLLM_LOG_WARNING("Exception while loading NCCL library: %s", e.what());
    }
}
// Opens the shared library `libName` with the platform loader and returns its handle.
// The loader's result is returned unchanged, so a failed load yields whatever it reports.
void* NCCLHelper::loadLibraryHandle(char const* libName)
{
#ifdef _WIN32
    void* const handle = LoadLibraryA(libName);
#else
    void* const handle = dlopen(libName, RTLD_LAZY | RTLD_GLOBAL);
#endif
    return handle;
}
// Looks up `symbolName` in the library identified by `handle`.
// Returns nullptr when `handle` is null; otherwise returns the platform loader's lookup result.
void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName)
{
    if (handle == nullptr)
    {
        return nullptr;
    }
#ifdef _WIN32
    // GetProcAddress yields a function pointer (FARPROC). Standard C++ has no implicit
    // function-pointer -> void* conversion, so make the cast explicit.
    return reinterpret_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), symbolName));
#else
    return dlsym(handle, symbolName);
#endif
}
// Returns the resolved ncclCommWindowRegister entry point, or nullptr if it was not found.
NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister()
{
    return this->mNCCLCommWindowRegister;
}
// Returns the resolved ncclMemAlloc entry point, or nullptr if it was not found.
NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc()
{
    return this->mNCCLMemAlloc;
}
bool NCCLHelper::isLoaded() const
{
return mIsLoaded;
}
//==============================================================================
// NCCLWindowAllocator Implementation
//==============================================================================
// Returns the single NCCLWindowAllocator, which lives as a function-local static.
NCCLWindowAllocator& NCCLWindowAllocator::getInstance()
{
    static NCCLWindowAllocator sAllocator;
    return sAllocator;
}
// Hands out a window buffer of at least `size` bytes for `comm`.
//
// An idle pooled buffer is reused when one is large enough (the smallest such buffer wins);
// otherwise a new buffer is allocated, registered and added to the pool. The returned buffer
// is marked in-use until releaseBuffer() is called for its pointer.
// Both arguments are validated with TLLM_CHECK_WITH_INFO: comm must be non-null, size non-zero.
NCCLWindowBuffer NCCLWindowAllocator::requestBuffer(ncclComm_t comm, size_t size)
{
    TLLM_CHECK_WITH_INFO(comm != nullptr, "NCCL communicator cannot be null");
    TLLM_CHECK_WITH_INFO(size > 0, "Buffer size must be greater than 0");

    std::lock_guard<std::mutex> guard(mMutex);

    // Make sure this communicator's buffers get torn down with it; a repeat call is a no-op.
    registerBufferCleanup(comm);

    auto& pool = mBufferPool[comm];

    // Best-fit scan: remember the smallest idle buffer that still satisfies the request.
    auto candidate = pool.end();
    size_t candidateSize = std::numeric_limits<size_t>::max();
    for (auto cur = pool.begin(); cur != pool.end(); ++cur)
    {
        bool const usable = !cur->inUse && cur->buffer.size >= size;
        if (usable && cur->buffer.size < candidateSize)
        {
            candidate = cur;
            candidateSize = cur->buffer.size;
        }
    }

    if (candidate == pool.end())
    {
        // Nothing suitable is idle: grow the pool. The handle is the buffer's index in the pool.
        TLLM_LOG_TRACE(
            "[NCCLUtil] Allocating new NCCL window buffer for comm %p, size=%zu", static_cast<void*>(comm), size);
        int handle = static_cast<int>(pool.size());
        NCCLWindowBuffer fresh = allocateAndRegisterBuffer(comm, size, handle);
        pool.push_back({fresh, true});
        return fresh;
    }

    candidate->inUse = true;
    TLLM_LOG_TRACE(
        "[NCCLUtil] Reusing NCCL window buffer for comm %p: handle=%d, ptr=%p, size=%zu (requested: %zu)",
        static_cast<void*>(comm), candidate->buffer.handle, candidate->buffer.ptr, candidate->buffer.size, size);
    return candidate->buffer;
}
// Finds the pooled buffer of `comm` whose device pointer equals `ptr`.
// Returns a default-constructed NCCLWindowBuffer when either argument is null or nothing matches.
NCCLWindowBuffer NCCLWindowAllocator::searchBuffer(ncclComm_t comm, void* ptr) const
{
    if (comm == nullptr || ptr == nullptr)
    {
        return NCCLWindowBuffer();
    }
    std::lock_guard<std::mutex> guard(mMutex);
    return searchBufferLocked(comm, ptr);
}
// Marks the pooled buffer identified by `ptr` as idle so requestBuffer() can hand it out again.
// Null arguments are ignored silently; an unknown communicator or pointer only logs a warning.
void NCCLWindowAllocator::releaseBuffer(ncclComm_t comm, void* ptr)
{
    if (comm == nullptr || ptr == nullptr)
    {
        return;
    }

    std::lock_guard<std::mutex> guard(mMutex);

    auto poolIt = mBufferPool.find(comm);
    if (poolIt == mBufferPool.end())
    {
        TLLM_LOG_WARNING(
            "[NCCLUtil] Attempted to release buffer %p for unknown comm %p", ptr, static_cast<void*>(comm));
        return;
    }

    // Walk the pool until the first entry with a matching pointer.
    auto& buffers = poolIt->second;
    auto hit = buffers.begin();
    while (hit != buffers.end() && hit->buffer.ptr != ptr)
    {
        ++hit;
    }

    if (hit == buffers.end())
    {
        TLLM_LOG_WARNING(
            "[NCCLUtil] Attempted to release unknown buffer %p for comm %p", ptr, static_cast<void*>(comm));
        return;
    }

    hit->inUse = false;
    TLLM_LOG_TRACE("[NCCLUtil] Released NCCL window buffer for comm %p: ptr=%p", static_cast<void*>(comm), ptr);
}
// Returns the NCCL window registered for the buffer at `ptr`, or nullptr when the lookup
// does not produce a valid buffer.
ncclWindow_t NCCLWindowAllocator::getWindow(ncclComm_t comm, void* ptr) const
{
    std::lock_guard<std::mutex> guard(mMutex);
    NCCLWindowBuffer found = searchBufferLocked(comm, ptr);
    if (!found.isValid())
    {
        return nullptr;
    }
    return found.window;
}
// Returns the recorded size of the buffer at `ptr`, or 0 when the lookup does not produce
// a valid buffer.
size_t NCCLWindowAllocator::getSize(ncclComm_t comm, void* ptr) const
{
    std::lock_guard<std::mutex> guard(mMutex);
    NCCLWindowBuffer found = searchBufferLocked(comm, ptr);
    if (!found.isValid())
    {
        return 0;
    }
    return found.size;
}
// Returns a copy of the pool record for the buffer at `ptr`, or a default-constructed
// NCCLWindowBuffer when no record matches. Takes the allocator lock for the lookup.
NCCLWindowBuffer NCCLWindowAllocator::getBufferInfo(ncclComm_t comm, void* ptr) const
{
    std::lock_guard<std::mutex> guard(mMutex);
    NCCLWindowBuffer record = searchBufferLocked(comm, ptr);
    return record;
}
// Returns the number of buffers (idle or in use) pooled for `comm`; 0 if it has no pool.
size_t NCCLWindowAllocator::getBufferCount(ncclComm_t comm) const
{
    std::lock_guard<std::mutex> guard(mMutex);
    auto const poolIt = mBufferPool.find(comm);
    if (poolIt == mBufferPool.end())
    {
        return 0;
    }
    return poolIt->second.size();
}
// Returns how many of the buffers pooled for `comm` are currently marked in-use;
// 0 if the communicator has no pool.
size_t NCCLWindowAllocator::getBufferInUseCount(ncclComm_t comm) const
{
    std::lock_guard<std::mutex> guard(mMutex);

    size_t busy = 0;
    auto const poolIt = mBufferPool.find(comm);
    if (poolIt != mBufferPool.end())
    {
        for (auto const& entry : poolIt->second)
        {
            busy += (entry.inUse ? 1 : 0);
        }
    }
    return busy;
}
// A communicator is considered valid exactly when it is non-null.
// Destroyed communicators are deliberately not tracked: NCCL may reuse an address for a new
// communicator, so remembering old pointers would be unreliable. A new communicator is
// registered the first time it is used.
bool NCCLWindowAllocator::isCommValid(ncclComm_t comm) const noexcept
{
    return nullptr != comm;
}
// Allocates `size` bytes through ncclMemAlloc and registers them with `comm` as an
// NCCL_WIN_COLL_SYMMETRIC window.
//
// comm:   communicator the window is registered with.
// size:   number of bytes to allocate and register.
// handle: identifier stored in the returned buffer record.
// Throws (TLLM_THROW) when the NCCL symbols are unavailable, when allocation fails, or when
// registration fails; in the last case the freshly allocated memory is freed first.
NCCLWindowBuffer NCCLWindowAllocator::allocateAndRegisterBuffer(ncclComm_t comm, size_t size, int handle)
{
    // Both entry points are resolved at runtime by NCCLHelper.
    auto& helper = NCCLHelper::getInstance();
    if (!helper.isLoaded())
    {
        TLLM_THROW("NCCL library could not be loaded for dynamic symbol access");
    }

    auto memAlloc = helper.getNCCLMemAlloc();
    auto windowRegister = helper.getNCCLCommWindowRegister();

    // Defensive: isLoaded() should already imply both pointers are set.
    if (memAlloc == nullptr)
    {
        TLLM_THROW("ncclMemAlloc function pointer is null, cannot allocate NCCL window buffer");
    }
    if (windowRegister == nullptr)
    {
        TLLM_THROW("ncclCommWindowRegister function pointer is null, cannot register NCCL window buffer");
    }

    NCCLWindowBuffer result;
    result.handle = handle;

    ncclResult_t allocStatus = memAlloc(&result.ptr, size);
    if (allocStatus != ncclSuccess)
    {
        TLLM_THROW("ncclMemAlloc failed with error: %d", allocStatus);
    }
    result.size = size;

    ncclResult_t registerStatus = windowRegister(comm, result.ptr, size, &result.window, NCCL_WIN_COLL_SYMMETRIC);
    if (registerStatus != ncclSuccess)
    {
        // Do not leak the allocation when the window cannot be registered.
        ncclMemFree(result.ptr);
        TLLM_THROW("ncclCommWindowRegister failed with error: %d", registerStatus);
    }

    TLLM_LOG_TRACE("[NCCLUtil] Allocated and registered NCCL window buffer: handle=%d, ptr=%p, size=%zu, window=%p",
        handle, result.ptr, size, static_cast<void*>(result.window));
    return result;
}
// Lookup helper that does no locking of its own; the public accessors in this file call it
// while holding mMutex.
// Returns the first pooled buffer of `comm` whose pointer equals `ptr`, or a
// default-constructed NCCLWindowBuffer when the communicator or the pointer is unknown.
NCCLWindowBuffer NCCLWindowAllocator::searchBufferLocked(ncclComm_t comm, void* ptr) const
{
    auto const poolIt = mBufferPool.find(comm);
    if (poolIt != mBufferPool.end())
    {
        auto const& buffers = poolIt->second;
        for (auto cur = buffers.begin(); cur != buffers.end(); ++cur)
        {
            if (cur->buffer.ptr == ptr)
            {
                return cur->buffer;
            }
        }
    }
    return NCCLWindowBuffer();
}
// Ensures that exactly one cleanup callback for `comm` is registered with the
// NcclCommResourceManager. The callback invokes cleanupBuffersForComm(comm).
// Does no locking of its own; requestBuffer() calls it while holding mMutex.
void NCCLWindowAllocator::registerBufferCleanup(ncclComm_t comm)
{
    // insert() reports whether the communicator was newly added; if it was already present
    // the callback has been registered before and there is nothing left to do.
    bool const newlyAdded = mRegisteredComms.insert(comm).second;
    if (!newlyAdded)
    {
        return;
    }

    auto& manager = NcclCommResourceManager::getInstance();
    manager.registerResource(
        comm, [this, comm]() { this->cleanupBuffersForComm(comm); }, "NCCLWindowAllocator");
}
// Deregisters and frees every window buffer pooled for `comm`, then forgets the communicator.
//
// Steps: synchronize the CUDA device, take the allocator lock, warn about buffers still
// marked in-use, deregister each window and free its memory, and finally erase the pool and
// the registration marker. NCCL failures are logged as warnings and do not abort the cleanup.
// A null comm, or a comm that is not currently registered, is a no-op.
void NCCLWindowAllocator::cleanupBuffersForComm(ncclComm_t comm) noexcept
{
    if (comm == nullptr)
    {
        return;
    }

    // Let outstanding device work finish before windows are deregistered and memory is freed.
    cudaError_t syncStatus = cudaDeviceSynchronize();
    if (syncStatus != cudaSuccess)
    {
        // Keep going: the failure may stem from an earlier, unrelated error.
        TLLM_LOG_WARNING("[NCCLUtil] cudaDeviceSynchronize failed with error: %d before cleanup for comm %p",
            syncStatus, static_cast<void*>(comm));
    }

    std::lock_guard<std::mutex> guard(mMutex);

    // Not registered means the cleanup already ran, or the communicator was never seen.
    if (mRegisteredComms.count(comm) == 0)
    {
        return;
    }

    auto poolIt = mBufferPool.find(comm);
    if (poolIt == mBufferPool.end())
    {
        // Registered but no pool exists: only the marker needs clearing.
        mRegisteredComms.erase(comm);
        return;
    }

    auto& buffers = poolIt->second;
    TLLM_LOG_TRACE(
        "[NCCLUtil] Cleaning up %zu NCCL window buffers for comm %p", buffers.size(), static_cast<void*>(comm));

    // Buffers should have been released by now; report any that were not.
    size_t stillInUse = 0;
    for (auto const& entry : buffers)
    {
        stillInUse += (entry.inUse ? 1 : 0);
    }
    if (stillInUse != 0)
    {
        TLLM_LOG_WARNING(
            "[NCCLUtil] Cleaning up %zu buffers still marked as in-use for comm %p. "
            "This may indicate buffers weren't properly released before cleanup.",
            stillInUse, static_cast<void*>(comm));
    }

    for (auto& entry : buffers)
    {
        auto& buf = entry.buffer;
        if (!buf.isValid())
        {
            continue;
        }

        // Deregister even when the buffer is marked in-use: the communicator is going away.
        // comm is known to be non-null here because of the guard at the top.
        if (buf.window)
        {
            ncclResult_t deregStatus = ncclCommWindowDeregister(comm, buf.window);
            if (deregStatus != ncclSuccess)
            {
                TLLM_LOG_WARNING(
                    "[NCCLUtil] ncclCommWindowDeregister failed with error: %d for comm %p, "
                    "window %p (buffer inUse: %d)",
                    deregStatus, static_cast<void*>(comm), static_cast<void*>(buf.window), entry.inUse);
            }
        }

        // Free the device memory regardless of whether deregistration succeeded.
        if (buf.ptr)
        {
            try
            {
                ncclResult_t freeStatus = ncclMemFree(buf.ptr);
                if (freeStatus != ncclSuccess)
                {
                    TLLM_LOG_WARNING("[NCCLUtil] ncclMemFree failed with error: %d", freeStatus);
                }
            }
            catch (...)
            {
                TLLM_LOG_ERROR("[NCCLUtil] Exception during ncclMemFree for ptr %p", buf.ptr);
            }
        }

        TLLM_LOG_TRACE("[NCCLUtil] Freed NCCL window buffer: ptr=%p, size=%zu", buf.ptr, buf.size);
    }

    mBufferPool.erase(poolIt);
    mRegisteredComms.erase(comm);
}
} // namespace tensorrt_llm::common::nccl_util
#endif // ENABLE_MULTI_DEVICE