mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Add NCCL Symmetric Integration for All Reduce (#4500)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
This commit is contained in:
parent
980929e1a9
commit
82276167e6
@ -56,6 +56,8 @@ enum class AllReduceStrategyType : int8_t
|
||||
ONESHOT = 4,
|
||||
TWOSHOT = 5,
|
||||
LOWPRECISION = 6,
|
||||
MNNVL = 7,
|
||||
NCCL_SYMMETRIC = 8,
|
||||
};
|
||||
|
||||
enum class AllReduceStrategyConfig : int8_t
|
||||
|
||||
@ -14,47 +14,58 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ub_allocator.h"
|
||||
#include "tensorrt_llm/common/opUtils.h"
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace tensorrt_llm::runtime::ub
|
||||
{
|
||||
UserBufferAllocator& UserBufferAllocator::Instance()
|
||||
{
|
||||
static UserBufferAllocator _;
|
||||
return _;
|
||||
}
|
||||
|
||||
void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& world_config)
|
||||
{
|
||||
if (!is_initialized())
|
||||
if (use_nccl_symmetric)
|
||||
{
|
||||
ub_comm_ = nullptr;
|
||||
world_config_ = world_config;
|
||||
create_communicator_grouped2(&ub_comm_, world_config_);
|
||||
TLLM_CHECK(ub_comm_ != nullptr);
|
||||
is_initialized_ = true;
|
||||
static NCCLUserBufferAllocator _;
|
||||
return _;
|
||||
}
|
||||
else
|
||||
{
|
||||
static UserBufferAllocator _;
|
||||
return _;
|
||||
}
|
||||
}
|
||||
|
||||
bool UserBufferAllocator::is_initialized()
|
||||
void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
|
||||
{
|
||||
return is_initialized_;
|
||||
if (!isInitialized())
|
||||
{
|
||||
mUbComm = nullptr;
|
||||
mWorldConfig = worldConfig;
|
||||
create_communicator_grouped2(&mUbComm, worldConfig);
|
||||
TLLM_CHECK(mUbComm != nullptr);
|
||||
mIsInitialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
UBBuffer UserBufferAllocator::register_ub_buffer(size_t bytes)
|
||||
bool UserBufferAllocator::isInitialized()
|
||||
{
|
||||
TLLM_CHECK(is_initialized());
|
||||
return mIsInitialized;
|
||||
}
|
||||
|
||||
UBBuffer UserBufferAllocator::registerUBBuffer(size_t bytes)
|
||||
{
|
||||
TLLM_CHECK(isInitialized());
|
||||
void* addr = nullptr;
|
||||
int handle = -1;
|
||||
handle = register_user_buffer_collective((void**) &addr, bytes, ub_comm_);
|
||||
handle = register_user_buffer_collective((void**) &addr, bytes, mUbComm);
|
||||
return {addr, handle, bytes};
|
||||
}
|
||||
|
||||
UBBuffer UserBufferAllocator::allocate(size_t bytes)
|
||||
{
|
||||
TLLM_CHECK(is_initialized());
|
||||
auto ub_buffer = register_ub_buffer(bytes);
|
||||
TLLM_CHECK(isInitialized());
|
||||
auto ub_buffer = registerUBBuffer(bytes);
|
||||
TLLM_CHECK(!ub_buffer.invalid());
|
||||
buffers_.push_back(ub_buffer);
|
||||
mBuffers.push_back(ub_buffer);
|
||||
return ub_buffer;
|
||||
}
|
||||
|
||||
@ -62,13 +73,177 @@ void UserBufferAllocator::deallocate(void* addr) {}
|
||||
|
||||
UBBuffer UserBufferAllocator::get(int idx)
|
||||
{
|
||||
TLLM_CHECK(is_initialized() && idx < buffers_.size() && !buffers_[idx].invalid());
|
||||
return buffers_[idx];
|
||||
TLLM_CHECK(isInitialized() && idx < mBuffers.size() && !mBuffers[idx].invalid());
|
||||
return mBuffers[idx];
|
||||
}
|
||||
|
||||
communicator* UserBufferAllocator::comm()
|
||||
{
|
||||
TLLM_CHECK(is_initialized());
|
||||
return ub_comm_;
|
||||
TLLM_CHECK(isInitialized());
|
||||
return mUbComm;
|
||||
}
|
||||
|
||||
void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
|
||||
{
|
||||
if (!isInitialized())
|
||||
{
|
||||
TLLM_LOG_INFO("Initializing NCCLUserBufferAllocator");
|
||||
std::set<int> group;
|
||||
for (int i = 0; i < worldConfig.getSize(); i++)
|
||||
{
|
||||
group.insert(i);
|
||||
}
|
||||
mComm = getComm(group);
|
||||
mIsInitialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes)
|
||||
{
|
||||
TLLM_CHECK(isInitialized());
|
||||
UBBuffer ub_buffer;
|
||||
|
||||
auto& ncclHelper = getNCCLHelper();
|
||||
if (!ncclHelper.isLoaded())
|
||||
{
|
||||
TLLM_THROW("NCCL library could not be loaded for dynamic symbol access");
|
||||
}
|
||||
|
||||
auto ncclMemAllocFunc = ncclHelper.getNCCLMemAlloc();
|
||||
auto ncclCommWindowRegisterFunc = ncclHelper.getNCCLCommWindowRegister();
|
||||
|
||||
NCCLCHECK(ncclMemAllocFunc(&ub_buffer.addr, bytes));
|
||||
NCCLCHECK(ncclCommWindowRegisterFunc((*mComm), ub_buffer.addr, bytes, &ub_buffer.window, NCCL_WIN_COLL_SYMMETRIC));
|
||||
ub_buffer.handle = 5;
|
||||
ub_buffer.size = bytes;
|
||||
return ub_buffer;
|
||||
}
|
||||
|
||||
// Static member definitions
|
||||
std::unique_ptr<NCCLHelper> NCCLUserBufferAllocator::mNCCLHelper = nullptr;
|
||||
|
||||
NCCLHelper& NCCLUserBufferAllocator::getNCCLHelper()
|
||||
{
|
||||
if (!mNCCLHelper)
|
||||
{
|
||||
mNCCLHelper = std::make_unique<NCCLHelper>();
|
||||
}
|
||||
return *mNCCLHelper;
|
||||
}
|
||||
|
||||
// NCCLHelper implementation
|
||||
NCCLHelper::NCCLHelper()
|
||||
: mLibraryHandle(nullptr)
|
||||
, mNCCLCommWindowRegister(nullptr)
|
||||
, mNCCLMemAlloc(nullptr)
|
||||
, mIsLoaded(false)
|
||||
{
|
||||
loadNCCLLibrary();
|
||||
}
|
||||
|
||||
NCCLHelper::~NCCLHelper()
|
||||
{
|
||||
if (mLibraryHandle)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
FreeLibrary(mLibraryHandle);
|
||||
#else
|
||||
dlclose(mLibraryHandle);
|
||||
#endif
|
||||
mLibraryHandle = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void NCCLHelper::loadNCCLLibrary()
|
||||
{
|
||||
try
|
||||
{
|
||||
#ifdef _WIN32
|
||||
char const* libraryNames[] = {"nccl.dll"};
|
||||
#else
|
||||
char const* libraryNames[] = {"libnccl.so"};
|
||||
#endif
|
||||
|
||||
for (int i = 0; libraryNames[i] != nullptr; ++i)
|
||||
{
|
||||
mLibraryHandle = loadLibraryHandle(libraryNames[i]);
|
||||
if (mLibraryHandle)
|
||||
{
|
||||
TLLM_LOG_INFO("Successfully loaded NCCL library: %s", libraryNames[i]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!mLibraryHandle)
|
||||
{
|
||||
TLLM_LOG_WARNING("Failed to load NCCL library");
|
||||
return;
|
||||
}
|
||||
|
||||
// Load the required symbols
|
||||
mNCCLCommWindowRegister
|
||||
= reinterpret_cast<ncclCommWindowRegisterFunc>(getSymbolAddress(mLibraryHandle, "ncclCommWindowRegister"));
|
||||
|
||||
mNCCLMemAlloc = reinterpret_cast<ncclMemAllocFunc>(getSymbolAddress(mLibraryHandle, "ncclMemAlloc"));
|
||||
|
||||
if (mNCCLCommWindowRegister == nullptr)
|
||||
{
|
||||
TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported.");
|
||||
}
|
||||
|
||||
if (mNCCLMemAlloc)
|
||||
{
|
||||
mIsLoaded = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_WARNING("Failed to load required NCCL symbols");
|
||||
}
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_WARNING("Exception while loading NCCL library: %s", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void* NCCLHelper::loadLibraryHandle(char const* libName)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
return LoadLibraryA(libName);
|
||||
#else
|
||||
return dlopen(libName, RTLD_LAZY | RTLD_GLOBAL);
|
||||
#endif
|
||||
}
|
||||
|
||||
void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName)
|
||||
{
|
||||
if (!handle)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
return GetProcAddress(static_cast<HMODULE>(handle), symbolName);
|
||||
#else
|
||||
return dlsym(handle, symbolName);
|
||||
#endif
|
||||
}
|
||||
|
||||
NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister()
|
||||
{
|
||||
return mNCCLCommWindowRegister;
|
||||
}
|
||||
|
||||
NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc()
|
||||
{
|
||||
return mNCCLMemAlloc;
|
||||
}
|
||||
|
||||
bool NCCLHelper::isLoaded() const
|
||||
{
|
||||
return mIsLoaded;
|
||||
}
|
||||
|
||||
bool UserBufferAllocator::use_nccl_symmetric = false;
|
||||
|
||||
}; // namespace tensorrt_llm::runtime::ub
|
||||
|
||||
@ -14,9 +14,16 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "nccl.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
#include <memory>
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
#include "userbuffers.h"
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm::runtime::ub
|
||||
@ -28,11 +35,13 @@ struct UBBuffer
|
||||
void* addr;
|
||||
int handle;
|
||||
size_t size;
|
||||
ncclWindow_t window;
|
||||
|
||||
UBBuffer(void* a = nullptr, int h = -1, size_t s = 0)
|
||||
UBBuffer(void* a = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr)
|
||||
: addr(a)
|
||||
, handle(h)
|
||||
, size(s)
|
||||
, window(w)
|
||||
{
|
||||
}
|
||||
|
||||
@ -49,21 +58,74 @@ public:
|
||||
|
||||
UserBufferAllocator() = default;
|
||||
|
||||
void initialize(tensorrt_llm::runtime::WorldConfig const& world_config);
|
||||
bool is_initialized();
|
||||
virtual void initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig);
|
||||
bool isInitialized();
|
||||
UBBuffer allocate(size_t bytes);
|
||||
void deallocate(void* addr);
|
||||
UBBuffer get(int idx);
|
||||
communicator* comm();
|
||||
virtual UBBuffer registerUBBuffer(size_t bytes);
|
||||
|
||||
static bool use_nccl_symmetric;
|
||||
|
||||
private:
|
||||
UBBuffer register_ub_buffer(size_t bytes);
|
||||
communicator* mUbComm;
|
||||
|
||||
communicator* ub_comm_;
|
||||
std::vector<UBBuffer> buffers_;
|
||||
bool is_initialized_;
|
||||
tensorrt_llm::runtime::WorldConfig world_config_;
|
||||
protected:
|
||||
std::vector<UBBuffer> mBuffers;
|
||||
bool mIsInitialized;
|
||||
tensorrt_llm::runtime::WorldConfig mWorldConfig;
|
||||
};
|
||||
|
||||
class NCCLHelper
|
||||
{
|
||||
public:
|
||||
NCCLHelper();
|
||||
~NCCLHelper();
|
||||
|
||||
// Dynamic loading function type definition
|
||||
using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int);
|
||||
using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t);
|
||||
|
||||
// Get function pointer for ncclCommWindowRegister
|
||||
ncclCommWindowRegisterFunc getNCCLCommWindowRegister();
|
||||
|
||||
// Get function pointer for ncclMemAlloc
|
||||
ncclMemAllocFunc getNCCLMemAlloc();
|
||||
|
||||
// Check if NCCL library is successfully loaded
|
||||
bool isLoaded() const;
|
||||
|
||||
private:
|
||||
void loadNCCLLibrary();
|
||||
void* loadLibraryHandle(char const* libName);
|
||||
void* getSymbolAddress(void* handle, char const* symbolName);
|
||||
|
||||
#ifdef _WIN32
|
||||
HMODULE mLibraryHandle;
|
||||
#else
|
||||
void* mLibraryHandle;
|
||||
#endif
|
||||
|
||||
ncclCommWindowRegisterFunc mNCCLCommWindowRegister;
|
||||
ncclMemAllocFunc mNCCLMemAlloc;
|
||||
bool mIsLoaded;
|
||||
};
|
||||
|
||||
class NCCLUserBufferAllocator : public UserBufferAllocator
|
||||
{
|
||||
public:
|
||||
void initialize(tensorrt_llm::runtime::WorldConfig const& world_config) override;
|
||||
UBBuffer registerUBBuffer(size_t bytes) override;
|
||||
|
||||
// Get shared NCCLHelper instance
|
||||
static NCCLHelper& getNCCLHelper();
|
||||
|
||||
private:
|
||||
std::shared_ptr<ncclComm_t> mComm;
|
||||
static std::unique_ptr<NCCLHelper> mNCCLHelper;
|
||||
};
|
||||
|
||||
#else
|
||||
using communicator = void;
|
||||
#endif
|
||||
|
||||
@ -36,7 +36,7 @@ void ub_initialize(int tp_size)
|
||||
|
||||
bool ub_is_initialized()
|
||||
{
|
||||
return UserBufferAllocator::Instance().is_initialized();
|
||||
return UserBufferAllocator::Instance().isInitialized();
|
||||
}
|
||||
|
||||
UBBuffer ub_allocate(size_t bytes)
|
||||
|
||||
@ -29,11 +29,14 @@ UserBuffersManager& UserBuffersManager::get_instance()
|
||||
return allocator;
|
||||
}
|
||||
|
||||
void UserBuffersManager::initialize(
|
||||
int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size)
|
||||
void UserBuffersManager::initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
|
||||
int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
tensorrt_llm::runtime::WorldConfig world_config(tp_size, pp_size, cp_size, rank, gpus_per_node);
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
UserBufferAllocator::Instance().use_nccl_symmetric = use_nccl_symmetric;
|
||||
#endif
|
||||
tensorrt_llm::runtime::ub::ub_initialize(world_config);
|
||||
TLLM_CHECK(tensorrt_llm::runtime::ub::ub_is_initialized());
|
||||
buffer_size_ = buffer_size;
|
||||
@ -95,10 +98,11 @@ tensorrt_llm::runtime::ub::communicator* UserBuffersManager::comm()
|
||||
return tensorrt_llm::runtime::ub::ub_comm();
|
||||
}
|
||||
|
||||
void initialize_userbuffers_manager(
|
||||
int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size)
|
||||
void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
|
||||
int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric)
|
||||
{
|
||||
UserBuffersManager::get_instance().initialize(tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size);
|
||||
UserBuffersManager::get_instance().initialize(
|
||||
tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size, use_nccl_symmetric);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::runtime::ub
|
||||
|
||||
@ -46,8 +46,9 @@ public:
|
||||
//! @param gpus_per_node The number of GPUs per node.
|
||||
//! @param buffer_size The size of the buffer to allocate. All buffers allocated by this manager will have this
|
||||
//! size.
|
||||
void initialize(
|
||||
int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size);
|
||||
//! @param use_nccl_symmetric Whether to use NCCL symmetric communication.
|
||||
void initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node,
|
||||
int64_t buffer_size, bool use_nccl_symmetric);
|
||||
|
||||
//! @brief Create a UB tensor from the given shape, strides and data type. The function will choose available UB
|
||||
//! buffer or create a new one if no available buffer is found.
|
||||
@ -75,7 +76,7 @@ private:
|
||||
int64_t buffer_size_;
|
||||
};
|
||||
|
||||
void initialize_userbuffers_manager(
|
||||
int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size);
|
||||
void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
|
||||
int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric);
|
||||
|
||||
} // namespace tensorrt_llm::runtime::ub
|
||||
|
||||
@ -456,7 +456,8 @@ void initBindings(pybind11::module_& m)
|
||||
.value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO)
|
||||
.value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB)
|
||||
.value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT)
|
||||
.value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT);
|
||||
.value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT)
|
||||
.value("NCCL_SYMMETRIC", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_SYMMETRIC);
|
||||
|
||||
// Initialize MoeLoadBalancer bindings
|
||||
initMoeBindings(m);
|
||||
|
||||
@ -163,9 +163,9 @@ public:
|
||||
{
|
||||
size_t size = input.numel();
|
||||
size_t seq_len = input.size(0);
|
||||
size_t bytes_per_element = input.element_size();
|
||||
TLLM_LOG_DEBUG("All reduce message size is %zu", size * bytes_per_element);
|
||||
|
||||
// If strategy is set to UB, UB must be used as UB impl output is special and cannot be used
|
||||
// by others.
|
||||
AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);
|
||||
|
||||
// Log runtime strategy
|
||||
@ -177,6 +177,8 @@ public:
|
||||
{
|
||||
case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias);
|
||||
case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias);
|
||||
case AllReduceStrategyType::NCCL_SYMMETRIC:
|
||||
return runNCCLAllReduceSymmetric(input, residual, norm_weight, scale, bias);
|
||||
case AllReduceStrategyType::MIN_LATENCY:
|
||||
case AllReduceStrategyType::ONESHOT:
|
||||
case AllReduceStrategyType::TWOSHOT:
|
||||
@ -303,6 +305,39 @@ private:
|
||||
return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> runNCCLAllReduceSymmetric(torch::Tensor const& input,
|
||||
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
|
||||
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
|
||||
{
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
int size = input.numel();
|
||||
auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
|
||||
auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
|
||||
if (ub_buffer0.invalid())
|
||||
{
|
||||
auto [symmetric_input, symmetric_ub_buffer0]
|
||||
= torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
|
||||
cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(),
|
||||
cudaMemcpyDeviceToDevice, stream);
|
||||
ub_buffer0 = symmetric_ub_buffer0;
|
||||
}
|
||||
|
||||
TLLM_CHECK(!ub_buffer0.invalid());
|
||||
auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
|
||||
|
||||
NCCLCHECK(ncclAllReduce(
|
||||
ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
|
||||
|
||||
if (mOp == AllReduceFusionOp::NONE)
|
||||
{
|
||||
return {norm_out};
|
||||
}
|
||||
|
||||
// Treat any other patterns as fallback cases.
|
||||
return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, norm_out);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input,
|
||||
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
|
||||
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
|
||||
@ -633,6 +668,10 @@ private:
|
||||
{
|
||||
runtime_strategy = AllReduceStrategyType::NCCL;
|
||||
}
|
||||
else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
|
||||
{
|
||||
runtime_strategy = AllReduceStrategyType::NCCL_SYMMETRIC;
|
||||
}
|
||||
else
|
||||
{
|
||||
// This is for DEBUG and BENCHMARK purpose. It will overried the strategy if AUTO is set.
|
||||
@ -658,6 +697,11 @@ private:
|
||||
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank);
|
||||
break;
|
||||
}
|
||||
case AllReduceStrategyType::NCCL_SYMMETRIC:
|
||||
{
|
||||
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank);
|
||||
break;
|
||||
}
|
||||
case AllReduceStrategyType::MIN_LATENCY:
|
||||
{
|
||||
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: MIN_LATENCY", rank);
|
||||
@ -673,7 +717,7 @@ private:
|
||||
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
default: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UNKNOWN: %d", rank, strategy); break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
@ -16,6 +17,7 @@ from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
|
||||
|
||||
_thread_local = threading.local()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
|
||||
|
||||
@ -130,7 +130,8 @@ class ModelConfig(Generic[TConfig]):
|
||||
"ONESHOT": AllReduceStrategy.ONESHOT,
|
||||
"TWOSHOT": AllReduceStrategy.TWOSHOT,
|
||||
"LOWPRECISION": AllReduceStrategy.LOWPRECISION,
|
||||
"MNNVL": AllReduceStrategy.MNNVL
|
||||
"MNNVL": AllReduceStrategy.MNNVL,
|
||||
"NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC
|
||||
}
|
||||
key = strategy.upper()
|
||||
return maps[key] if key in maps else AllReduceStrategy.AUTO
|
||||
|
||||
@ -322,10 +322,14 @@ class PyTorchModelEngine(ModelEngine):
|
||||
and not self.enable_attention_dp)
|
||||
|
||||
try:
|
||||
use_ub_for_nccl = (
|
||||
pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
|
||||
and self._init_userbuffers(self.model.config.hidden_size))
|
||||
if pytorch_backend_config.torch_compile_enabled:
|
||||
set_torch_compiling(True)
|
||||
use_ub = pytorch_backend_config.torch_compile_enable_userbuffers and self._init_userbuffers(
|
||||
self.model.config.hidden_size)
|
||||
use_ub = not use_ub_for_nccl and (
|
||||
pytorch_backend_config.torch_compile_enable_userbuffers
|
||||
and self._init_userbuffers(self.model.config.hidden_size))
|
||||
self._torch_compile_backend = Backend(
|
||||
pytorch_backend_config.torch_compile_inductor_enabled,
|
||||
enable_userbuffers=use_ub,
|
||||
@ -2232,12 +2236,12 @@ class PyTorchModelEngine(ModelEngine):
|
||||
# Disable UB for unsupported platforms
|
||||
if not ub.ub_supported():
|
||||
return False
|
||||
ub.initialize_userbuffers_manager(self.mapping.tp_size,
|
||||
self.mapping.pp_size,
|
||||
self.mapping.cp_size,
|
||||
self.mapping.rank,
|
||||
self.mapping.gpus_per_node,
|
||||
hidden_size * self.max_num_tokens * 2)
|
||||
use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
|
||||
ub.initialize_userbuffers_manager(
|
||||
self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
|
||||
self.mapping.rank, self.mapping.gpus_per_node,
|
||||
hidden_size * self.max_num_tokens * 2, use_nccl_symmetric)
|
||||
|
||||
return True
|
||||
|
||||
def load_weights_from_target_model(self,
|
||||
|
||||
@ -3882,6 +3882,7 @@ class AllReduceStrategy(IntEnum):
|
||||
TWOSHOT = 5
|
||||
LOWPRECISION = 6
|
||||
MNNVL = 7
|
||||
NCCL_SYMMETRIC = 8
|
||||
|
||||
|
||||
class AllReduceFusionOp(IntEnum):
|
||||
|
||||
@ -2098,14 +2098,12 @@ class TorchLlmArgs(BaseLlmArgs):
|
||||
status="prototype",
|
||||
)
|
||||
|
||||
allreduce_strategy: Optional[
|
||||
Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT',
|
||||
'LOWPRECISION', 'MNNVL']] = Field(
|
||||
default='AUTO',
|
||||
description="Allreduce strategy to use.",
|
||||
status="beta",
|
||||
)
|
||||
|
||||
allreduce_strategy: Optional[Literal[
|
||||
'AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT',
|
||||
'LOWPRECISION', 'MNNVL',
|
||||
'NCCL_SYMMETRIC']] = Field(default='AUTO',
|
||||
description="Allreduce strategy to use.",
|
||||
status="beta")
|
||||
checkpoint_loader: Optional[object] = Field(
|
||||
default=None,
|
||||
description="The checkpoint loader to use for this LLM instance.",
|
||||
|
||||
@ -21,9 +21,9 @@ import pytest
|
||||
import torch
|
||||
from mpi4py import MPI
|
||||
from mpi4py.futures import MPIPoolExecutor
|
||||
from utils.util import skip_pre_blackwell
|
||||
|
||||
import tensorrt_llm
|
||||
import tensorrt_llm.bindings.internal.userbuffers as ub
|
||||
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
|
||||
AllReduceParams)
|
||||
from tensorrt_llm.functional import AllReduceStrategy
|
||||
@ -55,6 +55,7 @@ def run_single_rank(
|
||||
dtype,
|
||||
fused_add_norm,
|
||||
reference_output_list,
|
||||
strategy,
|
||||
):
|
||||
rank = tensorrt_llm.mpi_rank()
|
||||
torch.cuda.set_device(rank)
|
||||
@ -70,6 +71,7 @@ def run_single_rank(
|
||||
rank,
|
||||
fused_add_norm,
|
||||
reference_output_list,
|
||||
strategy,
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
@ -89,6 +91,7 @@ def row_linear_residual_norm_fusion_forward(
|
||||
tensor_parallel_rank: int,
|
||||
fusion: bool,
|
||||
reference_output_list: list[tuple[torch.Tensor, ...]],
|
||||
strategy: AllReduceStrategy,
|
||||
):
|
||||
|
||||
# Move all tensors to GPU
|
||||
@ -100,6 +103,12 @@ def row_linear_residual_norm_fusion_forward(
|
||||
for ref_output in reference_output_list
|
||||
]
|
||||
|
||||
if strategy == AllReduceStrategy.NCCL_SYMMETRIC:
|
||||
ub.initialize_userbuffers_manager(
|
||||
tensor_parallel_size, 1, 1, tensor_parallel_rank,
|
||||
torch.cuda.device_count(),
|
||||
x_list[0].nelement() * x_list[0].element_size(), True)
|
||||
|
||||
MPI.COMM_WORLD.barrier()
|
||||
|
||||
# Create a single AllReduce instance to be reused for all sequence lengths
|
||||
@ -109,7 +118,7 @@ def row_linear_residual_norm_fusion_forward(
|
||||
tp_size=tensor_parallel_size,
|
||||
rank=tensor_parallel_rank,
|
||||
),
|
||||
strategy=AllReduceStrategy.MNNVL,
|
||||
strategy=strategy,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
@ -152,7 +161,6 @@ def row_linear_residual_norm_fusion_forward(
|
||||
)
|
||||
|
||||
|
||||
@skip_pre_blackwell
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="needs 2 GPUs to run this test")
|
||||
@pytest.mark.parametrize(
|
||||
@ -165,12 +173,16 @@ def row_linear_residual_norm_fusion_forward(
|
||||
ids=lambda x: f"hidden:{x}")
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16],
|
||||
ids=lambda x: f"dtype:{torch.finfo(x).dtype}")
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [AllReduceStrategy.MNNVL, AllReduceStrategy.NCCL_SYMMETRIC],
|
||||
ids=lambda x: f"strategy:{x}")
|
||||
@pytest.mark.parametrize(
|
||||
"fusion",
|
||||
[True, False],
|
||||
ids=["fusion", "no_fusion"],
|
||||
)
|
||||
def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, fusion):
|
||||
def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, strategy,
|
||||
fusion):
|
||||
|
||||
torch.manual_seed(42)
|
||||
tensor_parallel_size = 2
|
||||
@ -216,6 +228,7 @@ def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, fusion):
|
||||
dtype,
|
||||
fusion,
|
||||
reference_output_list,
|
||||
strategy,
|
||||
) for i in range(tensor_parallel_size)
|
||||
]),
|
||||
)
|
||||
|
||||
@ -35,7 +35,8 @@ pytestmark = pytest.mark.threadleak(enabled=False)
|
||||
|
||||
def init_userbuffers_allocator(tp_size, rank, max_ub_size):
|
||||
ub.initialize_userbuffers_manager(tp_size, 1, 1, rank,
|
||||
torch.cuda.device_count(), max_ub_size)
|
||||
torch.cuda.device_count(), max_ub_size,
|
||||
False)
|
||||
|
||||
|
||||
def create_userbuffers_tensor(shape, dtype):
|
||||
|
||||
@ -144,7 +144,7 @@ methods:
|
||||
default: False
|
||||
status: prototype
|
||||
allreduce_strategy:
|
||||
annotation: Optional[Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', 'LOWPRECISION', 'MNNVL']]
|
||||
annotation: Optional[Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', 'LOWPRECISION', 'MNNVL', 'NCCL_SYMMETRIC']]
|
||||
default: AUTO
|
||||
status: beta
|
||||
decoding_config:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user