mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[None][fix] nccl symmetric with graceful fallbacks (#11042)
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
This commit is contained in:
parent
393c3d259e
commit
4e10bf8950
@ -383,6 +383,21 @@ NCCLWindowBuffer NCCLWindowAllocator::requestBuffer(ncclComm_t comm, size_t size
|
||||
return bestFit->buffer;
|
||||
}
|
||||
|
||||
// No available buffer found, avoid registration during CUDA graph capture
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
|
||||
auto capture_err = cudaStreamIsCapturing(stream, &capture_status);
|
||||
if (capture_err != cudaSuccess)
|
||||
{
|
||||
TLLM_LOG_DEBUG("[NCCLUtil] cudaStreamIsCapturing failed: %s", cudaGetErrorString(capture_err));
|
||||
}
|
||||
if (capture_err == cudaSuccess && capture_status != cudaStreamCaptureStatusNone)
|
||||
{
|
||||
TLLM_LOG_DEBUG("[NCCLUtil] Skipping NCCL window allocation during capture for comm %p (requested: %zu)",
|
||||
static_cast<void*>(comm), size);
|
||||
return NCCLWindowBuffer();
|
||||
}
|
||||
|
||||
// No available buffer found, allocate a new one
|
||||
TLLM_LOG_TRACE(
|
||||
"[NCCLUtil] Allocating new NCCL window buffer for comm %p, size=%zu", static_cast<void*>(comm), size);
|
||||
|
||||
@ -21,6 +21,8 @@
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <nccl.h>
|
||||
#include <torch/extension.h>
|
||||
#endif
|
||||
@ -32,7 +34,6 @@
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
@ -377,15 +378,23 @@ inline std::pair<torch::Tensor, NCCLWindowBuffer> createNCCLWindowTensor(
|
||||
|
||||
// Request buffer from allocator
|
||||
auto& allocator = NCCLWindowAllocator::getInstance();
|
||||
auto buffer = allocator.requestBuffer(comm, buffer_size);
|
||||
NCCLWindowBuffer buffer;
|
||||
|
||||
try
|
||||
{
|
||||
buffer = allocator.requestBuffer(comm, buffer_size);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_DEBUG("[createNCCLWindowTensor] requestBuffer failed; returning invalid buffer: %s", e.what());
|
||||
return std::make_pair(torch::Tensor(), NCCLWindowBuffer());
|
||||
}
|
||||
|
||||
// Defensive validation: ensure buffer is valid before proceeding
|
||||
if (!buffer.isValid())
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Failed to allocate NCCL window buffer: invalid buffer returned from requestBuffer "
|
||||
<< "(comm=" << static_cast<void*>(comm) << ", buffer_size=" << buffer_size << ")";
|
||||
throw std::runtime_error(oss.str());
|
||||
TLLM_LOG_DEBUG("[createNCCLWindowTensor] invalid buffer returned from requestBuffer; returning invalid buffer");
|
||||
return std::make_pair(torch::Tensor(), NCCLWindowBuffer());
|
||||
}
|
||||
|
||||
// Create custom deleter that releases the buffer
|
||||
|
||||
@ -42,6 +42,7 @@
|
||||
#include <ATen/cuda/EmptyTensor.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <nccl.h>
|
||||
#include <torch/csrc/distributed/c10d/FileStore.hpp>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
@ -295,7 +296,6 @@ public:
|
||||
auto const rank = getRank();
|
||||
TLLM_LOG_DEBUG(
|
||||
"AllReduceOp runtime strategy for rank %d: " + tensorrt_llm::kernels::toString(runtime_strategy), rank);
|
||||
|
||||
// Dispatch to different allreduce implementations
|
||||
switch (runtime_strategy)
|
||||
{
|
||||
@ -508,8 +508,9 @@ private:
|
||||
minRegistrationThreshold = static_cast<size_t>(std::atoi(envThreshold)) * input.element_size();
|
||||
}
|
||||
|
||||
// Search for existing buffer
|
||||
auto& allocator = NCCLWindowAllocator::getInstance();
|
||||
|
||||
// Search for existing buffer
|
||||
auto windowBuffer0 = allocator.searchBuffer(comm, input.data_ptr());
|
||||
|
||||
torch::Tensor inputTensor = input;
|
||||
@ -532,11 +533,22 @@ private:
|
||||
// Large buffer: create window buffer and copy input (can swap inputTensor reference)
|
||||
auto [symmetricInput, symmetricBuffer0]
|
||||
= createNCCLWindowTensor(comm, input.sizes(), input.scalar_type());
|
||||
TLLM_CUDA_CHECK(cudaMemcpyAsync(
|
||||
symmetricBuffer0.ptr, input.data_ptr(), bufferSizeBytes, cudaMemcpyDeviceToDevice, stream));
|
||||
windowBuffer0 = symmetricBuffer0;
|
||||
inputTensor = symmetricInput; // Swap to window-backed tensor
|
||||
inputPtr = windowBuffer0.ptr;
|
||||
if (!symmetricBuffer0.isValid())
|
||||
{
|
||||
TLLM_LOG_DEBUG(
|
||||
"[runNCCLAllReduceSymmetric] No valid symmetric buffer available; "
|
||||
"falling back to non-symmetric ncclAllReduce (input buffer)");
|
||||
// inputTensor and inputPtr remain pointing to original input
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaMemcpyAsync(
|
||||
symmetricBuffer0.ptr, input.data_ptr(), bufferSizeBytes, cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
windowBuffer0 = symmetricBuffer0;
|
||||
inputTensor = symmetricInput; // Swap to window-backed tensor
|
||||
inputPtr = windowBuffer0.ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@ -547,8 +559,14 @@ private:
|
||||
|
||||
// Use window-backed output buffer
|
||||
auto [normOut, windowBuffer1] = createNCCLWindowTensor(comm, input.sizes(), input.scalar_type());
|
||||
torch::Tensor outputTensor = normOut;
|
||||
void* outputPtr = windowBuffer1.ptr;
|
||||
torch::Tensor outputTensor = windowBuffer1.isValid() ? normOut : torch::empty_like(inputTensor);
|
||||
void* outputPtr = windowBuffer1.isValid() ? windowBuffer1.ptr : outputTensor.data_ptr();
|
||||
if (!windowBuffer1.isValid())
|
||||
{
|
||||
TLLM_LOG_DEBUG(
|
||||
"[runNCCLAllReduceSymmetric] No valid symmetric buffer available; "
|
||||
"using plain CUDA tensor for output");
|
||||
}
|
||||
|
||||
// Perform allreduce
|
||||
NCCLCHECK_THROW(ncclAllReduce(inputPtr, outputPtr, size, (*getDtypeMap())[mType], ncclSum, comm, stream));
|
||||
|
||||
@ -1695,8 +1695,7 @@ class AllReduceRunner(TunableRunner):
|
||||
**kwargs,
|
||||
) -> List[int]:
|
||||
valid_strategies = [
|
||||
# TODO: NCCL_SYMMETRIC will cause hang during tuning process
|
||||
# AllReduceStrategy.NCCL_SYMMETRIC.value,
|
||||
AllReduceStrategy.NCCL_SYMMETRIC.value,
|
||||
AllReduceStrategy.NCCL.value,
|
||||
]
|
||||
# Fallback in allreduceOp is set to NCCL_SYMMETRIC as default
|
||||
@ -1724,8 +1723,7 @@ class AllReduceRunner(TunableRunner):
|
||||
) -> torch.Tensor:
|
||||
input, residual, norm_weight, scale, bias, workspace = inputs
|
||||
if tactic == -1:
|
||||
# TODO: Use NCCL instead of NCCL_SYMMETRIC to avoid hanging during tuning process
|
||||
tactic = AllReduceStrategy.NCCL.value
|
||||
tactic = AllReduceStrategy.NCCL_SYMMETRIC.value
|
||||
|
||||
return torch.ops.trtllm.allreduce(
|
||||
input,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user