diff --git a/cpp/tensorrt_llm/common/ncclUtils.cpp b/cpp/tensorrt_llm/common/ncclUtils.cpp
index 21b976d73d..70e2a5dcb0 100644
--- a/cpp/tensorrt_llm/common/ncclUtils.cpp
+++ b/cpp/tensorrt_llm/common/ncclUtils.cpp
@@ -383,6 +383,21 @@ NCCLWindowBuffer NCCLWindowAllocator::requestBuffer(ncclComm_t comm, size_t size
         return bestFit->buffer;
     }
 
+    // No available buffer found, avoid registration during CUDA graph capture
+    auto stream = at::cuda::getCurrentCUDAStream();
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    auto capture_err = cudaStreamIsCapturing(stream, &capture_status);
+    if (capture_err != cudaSuccess)
+    {
+        TLLM_LOG_DEBUG("[NCCLUtil] cudaStreamIsCapturing failed: %s", cudaGetErrorString(capture_err));
+    }
+    if (capture_err == cudaSuccess && capture_status != cudaStreamCaptureStatusNone)
+    {
+        TLLM_LOG_DEBUG("[NCCLUtil] Skipping NCCL window allocation during capture for comm %p (requested: %zu)",
+            static_cast<void*>(comm), size);
+        return NCCLWindowBuffer();
+    }
+
     // No available buffer found, allocate a new one
     TLLM_LOG_TRACE(
         "[NCCLUtil] Allocating new NCCL window buffer for comm %p, size=%zu", static_cast<void*>(comm), size);
diff --git a/cpp/tensorrt_llm/common/ncclUtils.h b/cpp/tensorrt_llm/common/ncclUtils.h
index c809394a48..506dcc5557 100644
--- a/cpp/tensorrt_llm/common/ncclUtils.h
+++ b/cpp/tensorrt_llm/common/ncclUtils.h
@@ -21,6 +21,8 @@
 #include "tensorrt_llm/common/logger.h"
 
 #if ENABLE_MULTI_DEVICE
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
 #include <mpi.h>
 #include <nccl.h>
 #endif
@@ -32,7 +34,6 @@
 #include <memory>
 #include <mutex>
 #include <optional>
-#include <sstream>
 #include <string>
 #include <tuple>
 #include <unordered_map>
@@ -377,15 +378,23 @@ inline std::pair<torch::Tensor, NCCLWindowBuffer> createNCCLWindowTensor(
 
     // Request buffer from allocator
     auto& allocator = NCCLWindowAllocator::getInstance();
-    auto buffer = allocator.requestBuffer(comm, buffer_size);
+    NCCLWindowBuffer buffer;
+
+    try
+    {
+        buffer = allocator.requestBuffer(comm, buffer_size);
+    }
+    catch (std::exception const& e)
+    {
+        TLLM_LOG_DEBUG("[createNCCLWindowTensor] requestBuffer failed; returning invalid buffer: %s", e.what());
+        return std::make_pair(torch::Tensor(), NCCLWindowBuffer());
+    }
 
     // Defensive validation: ensure buffer is valid before proceeding
     if (!buffer.isValid())
     {
-        std::ostringstream oss;
-        oss << "Failed to allocate NCCL window buffer: invalid buffer returned from requestBuffer "
-            << "(comm=" << static_cast<void*>(comm) << ", buffer_size=" << buffer_size << ")";
-        throw std::runtime_error(oss.str());
+        TLLM_LOG_DEBUG("[createNCCLWindowTensor] invalid buffer returned from requestBuffer; returning invalid buffer");
+        return std::make_pair(torch::Tensor(), NCCLWindowBuffer());
     }
 
     // Create custom deleter that releases the buffer
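A note on the capture check added to `requestBuffer` above: `cudaStreamIsCapturing` is the CUDA runtime query for whether a stream is currently recording into a CUDA graph. NCCL window registration cannot be replayed from a graph, so the allocator hands back an invalid `NCCLWindowBuffer` rather than registering new memory mid-capture. A minimal standalone sketch of the same guard (the `streamIsCapturing` helper and the `main` driver are illustrative, not part of this patch):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Returns true only when `stream` is actively being captured into a CUDA
// graph. A failed query is treated as "not capturing", mirroring the patch,
// which logs the error at debug level and falls through to normal allocation.
static bool streamIsCapturing(cudaStream_t stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    cudaError_t err = cudaStreamIsCapturing(stream, &status);
    if (err != cudaSuccess)
    {
        std::fprintf(stderr, "cudaStreamIsCapturing failed: %s\n", cudaGetErrorString(err));
        return false;
    }
    return status != cudaStreamCaptureStatusNone;
}

int main()
{
    cudaStream_t stream = nullptr; // legacy default stream, never capturing here
    std::printf("capturing: %d\n", streamIsCapturing(stream));
    return 0;
}
```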
diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp
index c753242518..00fe82c6fd 100644
--- a/cpp/tensorrt_llm/thop/allreduceOp.cpp
+++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -42,6 +42,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -295,7 +296,6 @@ public:
         auto const rank = getRank();
         TLLM_LOG_DEBUG(
            "AllReduceOp runtime strategy for rank %d: " + tensorrt_llm::kernels::toString(runtime_strategy), rank);
-
        // Dispatch to different allreduce implementations
        switch (runtime_strategy)
        {
@@ -508,8 +508,9 @@ private:
            minRegistrationThreshold = static_cast<size_t>(std::atoi(envThreshold)) * input.element_size();
        }
 
-        // Search for existing buffer
        auto& allocator = NCCLWindowAllocator::getInstance();
+
+        // Search for existing buffer
        auto windowBuffer0 = allocator.searchBuffer(comm, input.data_ptr());
 
        torch::Tensor inputTensor = input;
@@ -532,11 +533,22 @@ private:
                // Large buffer: create window buffer and copy input (can swap inputTensor reference)
                auto [symmetricInput, symmetricBuffer0]
                    = createNCCLWindowTensor(comm, input.sizes(), input.scalar_type());
-                TLLM_CUDA_CHECK(cudaMemcpyAsync(
-                    symmetricBuffer0.ptr, input.data_ptr(), bufferSizeBytes, cudaMemcpyDeviceToDevice, stream));
-                windowBuffer0 = symmetricBuffer0;
-                inputTensor = symmetricInput; // Swap to window-backed tensor
-                inputPtr = windowBuffer0.ptr;
+                if (!symmetricBuffer0.isValid())
+                {
+                    TLLM_LOG_DEBUG(
+                        "[runNCCLAllReduceSymmetric] No valid symmetric buffer available; "
+                        "falling back to non-symmetric ncclAllReduce (input buffer)");
+                    // inputTensor and inputPtr remain pointing to original input
+                }
+                else
+                {
+                    TLLM_CUDA_CHECK(cudaMemcpyAsync(
+                        symmetricBuffer0.ptr, input.data_ptr(), bufferSizeBytes, cudaMemcpyDeviceToDevice, stream));
+
+                    windowBuffer0 = symmetricBuffer0;
+                    inputTensor = symmetricInput; // Swap to window-backed tensor
+                    inputPtr = windowBuffer0.ptr;
+                }
            }
        }
        else
@@ -547,8 +559,14 @@ private:
        // Use window-backed output buffer
        auto [normOut, windowBuffer1]
            = createNCCLWindowTensor(comm, input.sizes(), input.scalar_type());
-        torch::Tensor outputTensor = normOut;
-        void* outputPtr = windowBuffer1.ptr;
+        torch::Tensor outputTensor = windowBuffer1.isValid() ? normOut : torch::empty_like(inputTensor);
+        void* outputPtr = windowBuffer1.isValid() ? windowBuffer1.ptr : outputTensor.data_ptr();
+        if (!windowBuffer1.isValid())
+        {
+            TLLM_LOG_DEBUG(
+                "[runNCCLAllReduceSymmetric] No valid symmetric buffer available; "
+                "using plain CUDA tensor for output");
+        }
 
        // Perform allreduce
        NCCLCHECK_THROW(ncclAllReduce(inputPtr, outputPtr, size, (*getDtypeMap())[mType], ncclSum, comm, stream));
diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index 5b683637c6..0fbb5ceacf 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -1695,8 +1695,7 @@ class AllReduceRunner(TunableRunner):
         **kwargs,
     ) -> List[int]:
         valid_strategies = [
-            # TODO: NCCL_SYMMETRIC will cause hang during tuning process
-            # AllReduceStrategy.NCCL_SYMMETRIC.value,
+            AllReduceStrategy.NCCL_SYMMETRIC.value,
             AllReduceStrategy.NCCL.value,
         ]
         # Fallback in allreduceOp is set to NCCL_SYMMETRIC as default
@@ -1724,8 +1723,7 @@ class AllReduceRunner(TunableRunner):
     ) -> torch.Tensor:
         input, residual, norm_weight, scale, bias, workspace = inputs
         if tactic == -1:
-            # TODO: Use NCCL instead of NCCL_SYMMETRIC to avoid hanging during tuning process
-            tactic = AllReduceStrategy.NCCL.value
+            tactic = AllReduceStrategy.NCCL_SYMMETRIC.value
         return torch.ops.trtllm.allreduce(
             input,
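Taken together, the C++ changes establish a new error-signaling contract: `requestBuffer` and `createNCCLWindowTensor` report failure through a default-constructed (invalid) `NCCLWindowBuffer` rather than by throwing, and each caller checks `isValid()` and degrades to plain device buffers while still running `ncclAllReduce`. That contract is what makes it safe to put `NCCL_SYMMETRIC` back into the tuner's valid strategies above. A compilable toy sketch of the contract, using hypothetical stand-in types rather than the real TensorRT-LLM classes:

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for NCCLWindowBuffer: a default-constructed value is
// the "invalid" sentinel that every caller must test with isValid().
struct WindowBuffer
{
    void* ptr = nullptr;
    std::size_t size = 0;
    bool isValid() const { return ptr != nullptr; }
};

// Illustrative caller-side fallback, mirroring runNCCLAllReduceSymmetric:
// use the window pointer when one is available, otherwise keep the original
// (non-symmetric) buffer so the collective still runs.
void* pickBuffer(WindowBuffer const& window, void* fallback)
{
    return window.isValid() ? window.ptr : fallback;
}

int main()
{
    char plain[64];                          // stands in for a plain CUDA allocation
    WindowBuffer none{};                     // e.g. returned during graph capture
    WindowBuffer some{plain, sizeof plain};  // window-backed case
    std::printf("fallback used: %s\n", pickBuffer(none, plain) == plain ? "yes" : "no");
    std::printf("window used:   %s\n", pickBuffer(some, nullptr) == plain ? "yes" : "no");
    return 0;
}
```

Presumably the sentinel was chosen over an exception so that the CUDA-graph path never unwinds mid-capture; only the symmetric-window optimization is skipped, never the allreduce itself.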