TensorRT-LLMs/cpp/tensorrt_llm/common/ncclUtils.h
Yihan Wang 9df4dad3b6
[None][fix] Introduce inline namespace to avoid symbol collision (#9541)
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
2025-12-12 23:32:15 +08:00

403 lines
13 KiB
C++

/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#include <torch/extension.h>
#endif
#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#if ENABLE_MULTI_DEVICE
#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif
TRTLLM_NAMESPACE_BEGIN
namespace common::nccl_util
{
//==============================================================================
// NCCL Helper - Dynamic Library Loading
//==============================================================================
// Helper class for dynamically loading NCCL symbols (ncclMemAlloc, ncclCommWindowRegister).
// Resolving these symbols at runtime (rather than linking against them) allows the code to work with NCCL
// libraries that may or may not export them.
// Only one instance exists: construction/destruction are private and copy/move are deleted; use getInstance().
// NOTE(review): what the getters return when a symbol is missing (presumably nullptr) is decided in the .cpp
// implementation, not here — confirm before relying on it.
class NCCLHelper
{
public:
// Returns the process-wide instance.
static NCCLHelper& getInstance();
// Function-pointer types used to call the dynamically resolved symbols.
// ncclCommWindowRegister: (communicator, buffer pointer, size in bytes, out window handle, int flags).
using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int);
// ncclMemAlloc: (out pointer, size in bytes).
using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t);
// Get function pointer for ncclCommWindowRegister
ncclCommWindowRegisterFunc getNCCLCommWindowRegister();
// Get function pointer for ncclMemAlloc
ncclMemAllocFunc getNCCLMemAlloc();
// Check if NCCL library is successfully loaded
bool isLoaded() const;
NCCLHelper(NCCLHelper const&) = delete;
NCCLHelper& operator=(NCCLHelper const&) = delete;
NCCLHelper(NCCLHelper&&) = delete;
NCCLHelper& operator=(NCCLHelper&&) = delete;
private:
NCCLHelper();
~NCCLHelper();
// Loading helpers; presumably driven by the constructor (implementation not visible in this header).
void loadNCCLLibrary();
void* loadLibraryHandle(char const* libName);
// NOTE(review): takes a void* handle while mLibraryHandle is HMODULE on Windows — the .cpp presumably casts.
void* getSymbolAddress(void* handle, char const* symbolName);
// Native handle of the loaded NCCL library (Windows: HMODULE from <windows.h>; otherwise a <dlfcn.h>-style void*).
#ifdef _WIN32
HMODULE mLibraryHandle;
#else
void* mLibraryHandle;
#endif
// Cached function pointers resolved from mLibraryHandle.
ncclCommWindowRegisterFunc mNCCLCommWindowRegister;
ncclMemAllocFunc mNCCLMemAlloc;
// Whether the library was loaded; reported by isLoaded().
bool mIsLoaded;
};
//==============================================================================
// NCCL Resource Management
//==============================================================================
// Resource cleanup callback type. Each registered callback is invoked before its NCCL communicator is destroyed.
using ResourceCleanupFunc = std::function<void()>;
// Manages resources associated with NCCL communicators. Singleton (private ctor/dtor, deleted copy/move) that keeps,
// per NCCL comm, an ordered list of (cleanup callback, debug name) entries. Resources are cleaned up when the
// communicator is destroyed; all state is guarded by a single mutex (mMutex).
// NOTE(review): if cleanupResources() invokes callbacks while holding mMutex (not visible here), a callback must not
// call back into this manager or it will deadlock — confirm in the .cpp.
class NcclCommResourceManager
{
public:
// Returns the process-wide instance.
static NcclCommResourceManager& getInstance() noexcept;
// Register a resource cleanup function for a specific NCCL communicator.
// The cleanup function will be called before ncclCommDestroy.
// debugName is an optional label stored alongside the callback (nullptr allowed).
// Thread-safe: Uses global mutex to serialize all operations.
void registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName = nullptr);
// Cleanup all resources associated with a communicator. Called automatically by
// the shared_ptr deleter before ncclCommDestroy.
// Thread-safe: Uses global mutex to serialize cleanup operations.
// Order-preserving: Resources are cleaned up in registration order.
void cleanupResources(ncclComm_t comm) noexcept;
// Check if a communicator has registered resources.
bool hasResources(ncclComm_t comm) const noexcept;
// Get the number of resources registered for a communicator.
size_t getResourceCount(ncclComm_t comm) const noexcept;
NcclCommResourceManager(NcclCommResourceManager const&) = delete;
NcclCommResourceManager& operator=(NcclCommResourceManager const&) = delete;
NcclCommResourceManager(NcclCommResourceManager&&) = delete;
NcclCommResourceManager& operator=(NcclCommResourceManager&&) = delete;
private:
NcclCommResourceManager() = default;
~NcclCommResourceManager() = default;
// (cleanup callback, debug name) for one registered resource.
using ResourceEntry = std::pair<ResourceCleanupFunc, std::string>;
// mutable so const queries (hasResources, getResourceCount) can lock it.
mutable std::mutex mMutex;
// Per-communicator entries in registration order.
std::unordered_map<ncclComm_t, std::vector<ResourceEntry>> mCommResources;
};
// RAII helper that registers a resource's cleanup with NcclCommResourceManager for a given NCCL communicator.
//
// The resource and its cleanup callback live in a heap-allocated state object shared (via std::shared_ptr) between
// this wrapper and the callback handed to the manager. The registered callback never refers to `this`, so it
// stays valid even if the wrapper is destroyed before the communicator's resources are cleaned up. Capturing
// `this` instead would leave the manager holding a dangling pointer in that case, because nothing unregisters it.
//
// The wrapper itself does not run the cleanup. The cleanup runs only when the manager processes `comm` (per
// NcclCommResourceManager, before ncclCommDestroy). The wrapper is neither copyable nor movable.
template <typename ResourceType>
class NcclCommResource
{
public:
    // @param comm       Communicator whose resource cleanup triggers `cleanup`.
    // @param resource   Resource to store; forwarded into the shared state.
    // @param cleanup    Called with the stored resource during cleanup; an empty function is skipped.
    // @param debugName  Optional label passed through to NcclCommResourceManager::registerResource.
    NcclCommResource(ncclComm_t comm, ResourceType&& resource, std::function<void(ResourceType&)> cleanup,
        char const* debugName = nullptr)
        : mState(std::make_shared<State>(std::forward<ResourceType>(resource), std::move(cleanup)))
    {
        // Capture the shared state by value: the manager may invoke this callback after the wrapper is gone.
        NcclCommResourceManager::getInstance().registerResource(
            comm,
            [state = mState]()
            {
                if (state->cleanup)
                {
                    state->cleanup(state->resource);
                }
            },
            debugName);
    }

    ResourceType& get()
    {
        return mState->resource;
    }

    ResourceType const& get() const
    {
        return mState->resource;
    }

    NcclCommResource(NcclCommResource const&) = delete;
    NcclCommResource& operator=(NcclCommResource const&) = delete;
    NcclCommResource(NcclCommResource&&) = delete;
    NcclCommResource& operator=(NcclCommResource&&) = delete;

private:
    // Data shared between this wrapper and the callback registered with the manager.
    struct State
    {
        State(ResourceType&& resourceIn, std::function<void(ResourceType&)> cleanupIn)
            : resource(std::forward<ResourceType>(resourceIn))
            , cleanup(std::move(cleanupIn))
        {
        }

        ResourceType resource;
        std::function<void(ResourceType&)> cleanup;
    };

    std::shared_ptr<State> mState;
};
//==============================================================================
// NCCL Window Buffer Allocation
//==============================================================================
// Descriptor for one NCCL window-registered buffer: device pointer, pool handle, size in bytes, and the NCCL window
// the memory is registered under. A default-constructed descriptor (null pointer, handle -1, size 0, null window)
// is invalid.
struct NCCLWindowBuffer
{
    void* ptr;           // Device pointer (same as UBBuffer.addr)
    int handle;          // Buffer handle/index (for compatibility with UB interface)
    size_t size;         // Size in bytes
    ncclWindow_t window; // NCCL window handle

    NCCLWindowBuffer(void* p = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr)
        : ptr{p}
        , handle{h}
        , size{s}
        , window{w}
    {
    }

    // A buffer is usable only if it has memory (non-null pointer, non-zero size) and is registered
    // (non-null window, non-negative handle).
    [[nodiscard]] bool isValid() const
    {
        bool const hasMemory = (ptr != nullptr) && (size > 0);
        bool const isRegistered = (window != nullptr) && (handle >= 0);
        return hasMemory && isRegistered;
    }

    // Negation of isValid().
    [[nodiscard]] bool invalid() const
    {
        return !isValid();
    }

    // Same as `ptr`; provided so code written against the UBBuffer interface can use this type.
    void* addr() const
    {
        return ptr;
    }
};
// Manages NCCL window-registered buffers with pooling and automatic cleanup.
// Buffers are tied to the lifetime of their associated NCCL communicator.
// Singleton (private ctor/dtor, deleted copy/move). Pool state is guarded by mMutex; the const query methods can
// lock it because it is mutable.
// NOTE(review): which public methods take mMutex is decided in the .cpp — confirm before calling them from a
// cleanup callback.
class NCCLWindowAllocator
{
public:
// Returns the process-wide instance.
static NCCLWindowAllocator& getInstance();
// Request a buffer for the given communicator and size.
// If an unused buffer of at least the requested size exists for this communicator, it will be reused.
// Uses best-fit strategy: selects the smallest available buffer that meets the size requirement.
// Otherwise, a new buffer is allocated and registered.
NCCLWindowBuffer requestBuffer(ncclComm_t comm, size_t size);
// Search for a buffer by pointer. Returns an invalid buffer if not found.
// This matches the UBManager.search_buffer() interface.
NCCLWindowBuffer searchBuffer(ncclComm_t comm, void* ptr) const;
// Release a buffer back to the pool for potential reuse (the memory itself stays allocated and registered).
void releaseBuffer(ncclComm_t comm, void* ptr);
// Get the window handle for a specific buffer pointer
ncclWindow_t getWindow(ncclComm_t comm, void* ptr) const;
// Get the size of a specific buffer pointer
size_t getSize(ncclComm_t comm, void* ptr) const;
// Get buffer info by pointer
NCCLWindowBuffer getBufferInfo(ncclComm_t comm, void* ptr) const;
// Get the number of buffers allocated for a communicator
size_t getBufferCount(ncclComm_t comm) const;
// Get the number of buffers in use for a communicator
size_t getBufferInUseCount(ncclComm_t comm) const;
// Check if a communicator is valid (non-null)
// Note: We don't track cleaned-up comms because NCCL can reuse memory addresses.
// All non-null comms are considered valid and will be registered when first used.
bool isCommValid(ncclComm_t comm) const noexcept;
NCCLWindowAllocator(NCCLWindowAllocator const&) = delete;
NCCLWindowAllocator& operator=(NCCLWindowAllocator const&) = delete;
NCCLWindowAllocator(NCCLWindowAllocator&&) = delete;
NCCLWindowAllocator& operator=(NCCLWindowAllocator&&) = delete;
private:
NCCLWindowAllocator() = default;
~NCCLWindowAllocator() = default;
// Allocate a new buffer and register it with NCCL as a window
NCCLWindowBuffer allocateAndRegisterBuffer(ncclComm_t comm, size_t size, int handle);
// Search for a buffer by pointer (assumes mMutex is already locked)
NCCLWindowBuffer searchBufferLocked(ncclComm_t comm, void* ptr) const;
// Register cleanup function for all buffers associated with a communicator
// (presumably with NcclCommResourceManager; tracked via mRegisteredComms — confirm in the .cpp).
void registerBufferCleanup(ncclComm_t comm);
// Cleanup all buffers for a specific communicator
void cleanupBuffersForComm(ncclComm_t comm) noexcept;
// One pooled buffer and whether it is currently handed out.
struct BufferEntry
{
NCCLWindowBuffer buffer;
bool inUse;
};
mutable std::mutex mMutex;
// Pooled buffers per communicator.
std::unordered_map<ncclComm_t, std::vector<BufferEntry>> mBufferPool;
// Communicators for which registerBufferCleanup() has already been done.
std::unordered_set<ncclComm_t> mRegisteredComms;
};
// RAII wrapper for NCCL window buffers
class ScopedNCCLWindowBuffer
{
public:
ScopedNCCLWindowBuffer(ncclComm_t comm, size_t size)
: mComm(comm)
, mBuffer(NCCLWindowAllocator::getInstance().requestBuffer(comm, size))
{
}
~ScopedNCCLWindowBuffer()
{
if (mBuffer.isValid())
{
NCCLWindowAllocator::getInstance().releaseBuffer(mComm, mBuffer.ptr);
}
}
void* getPtr() const
{
return mBuffer.ptr;
}
size_t getSize() const
{
return mBuffer.size;
}
ncclWindow_t getWindow() const
{
return mBuffer.window;
}
NCCLWindowBuffer const& getBuffer() const
{
return mBuffer;
}
ScopedNCCLWindowBuffer(ScopedNCCLWindowBuffer const&) = delete;
ScopedNCCLWindowBuffer& operator=(ScopedNCCLWindowBuffer const&) = delete;
ScopedNCCLWindowBuffer(ScopedNCCLWindowBuffer&&) = delete;
ScopedNCCLWindowBuffer& operator=(ScopedNCCLWindowBuffer&&) = delete;
private:
ncclComm_t mComm;
NCCLWindowBuffer mBuffer;
};
// Creates a PyTorch tensor backed by an NCCL window buffer.
//
// The tensor views the buffer memory directly (no copy) with contiguous row-major strides derived from `shape`.
// Its deleter returns the buffer to NCCLWindowAllocator's pool once the tensor's storage is freed.
// This is analogous to torch_ext::create_userbuffers_tensor() but for NCCLWindowAllocator.
//
// @param comm   Communicator to request the buffer for.
// @param shape  Tensor sizes; every dimension must be non-negative.
// @param dtype  Element type; the requested size is numel(shape) * elementSize(dtype) bytes.
// @return       The tensor and a copy of the buffer descriptor it views.
// @throws std::runtime_error if `shape` contains a negative dimension or the allocator returns an invalid buffer.
// NOTE(review): a zero-element shape requests 0 bytes; NCCLWindowBuffer::isValid() requires size > 0, so this
// presumably throws unless the allocator hands back a larger pooled buffer — confirm intended behavior.
inline std::pair<torch::Tensor, NCCLWindowBuffer> createNCCLWindowTensor(
    ncclComm_t comm, at::IntArrayRef shape, torch::ScalarType dtype)
{
    // Validate before allocating. A negative dimension would either make the byte count negative (it wraps to a
    // huge size_t in requestBuffer), or, with an even number of negative dims, allocate a buffer that is then
    // leaked: torch::from_blob rejects negative sizes before the releasing deleter takes ownership.
    for (auto const dim : shape)
    {
        if (dim < 0)
        {
            std::ostringstream oss;
            oss << "Failed to create NCCL window tensor: negative dimension " << dim << " in shape " << shape;
            throw std::runtime_error(oss.str());
        }
    }

    // Calculate buffer size in bytes.
    int64_t const numel = std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
    int64_t const buffer_size = numel * static_cast<int64_t>(torch::elementSize(dtype));

    // Contiguous strides: the innermost dimension has stride 1; each outer stride is the next stride times its size.
    std::vector<int64_t> strides_vec(shape.size(), 1);
    for (size_t i = shape.size(); i > 1; --i)
    {
        strides_vec[i - 2] = strides_vec[i - 1] * shape[i - 1];
    }

    // Request buffer from allocator
    auto& allocator = NCCLWindowAllocator::getInstance();
    auto buffer = allocator.requestBuffer(comm, static_cast<size_t>(buffer_size));

    // Defensive validation: ensure buffer is valid before proceeding
    if (!buffer.isValid())
    {
        std::ostringstream oss;
        oss << "Failed to allocate NCCL window buffer: invalid buffer returned from requestBuffer "
            << "(comm=" << static_cast<void*>(comm) << ", buffer_size=" << buffer_size << ")";
        throw std::runtime_error(oss.str());
    }

    // The deleter returns the buffer to the pool; it captures the pointer, not the descriptor.
    auto deleter = [comm, ptr = buffer.ptr](void*) { NCCLWindowAllocator::getInstance().releaseBuffer(comm, ptr); };

    // Create tensor from the buffer
    auto tensor = torch::from_blob(buffer.ptr, shape, strides_vec, deleter, torch::dtype(dtype).device(torch::kCUDA));
    return {std::move(tensor), buffer};
}
} // namespace common::nccl_util
TRTLLM_NAMESPACE_END
#endif // ENABLE_MULTI_DEVICE