From 5130cbd73e9f3b82fe04e4acdb6b94eaeb02b67c Mon Sep 17 00:00:00 2001
From: Ludwig Schneider
Date: Thu, 12 Feb 2026 16:31:51 -0600
Subject: [PATCH] [None][fix] Pre-Allocation for Auto-Tuning NCCL_SYMMETRIC
 (#11326)

Signed-off-by: Ludwig Schneider
---
 cpp/tensorrt_llm/common/ncclUtils.cpp         |  4 ++
 cpp/tensorrt_llm/thop/allreduceOp.cpp         | 72 +++++++++++++++++++
 .../_torch/custom_ops/torch_custom_ops.py     | 61 +++++++++++++++-
 3 files changed, 136 insertions(+), 1 deletion(-)

diff --git a/cpp/tensorrt_llm/common/ncclUtils.cpp b/cpp/tensorrt_llm/common/ncclUtils.cpp
index 70e2a5dcb0..7cae70cf35 100644
--- a/cpp/tensorrt_llm/common/ncclUtils.cpp
+++ b/cpp/tensorrt_llm/common/ncclUtils.cpp
@@ -626,8 +626,10 @@ void NCCLWindowAllocator::cleanupBuffersForComm(ncclComm_t comm) noexcept
     // Check for buffers still in use - this shouldn't happen if cleanup is called properly,
     // but we log a warning if it does
     size_t inUseCount = 0;
+    size_t totalBytes = 0;
     for (auto const& entry : commIt->second)
     {
+        totalBytes += entry.buffer.size;
         if (entry.inUse)
         {
             ++inUseCount;
@@ -640,6 +642,8 @@ void NCCLWindowAllocator::cleanupBuffersForComm(ncclComm_t comm) noexcept
             "This may indicate buffers weren't properly released before cleanup.",
             inUseCount, static_cast<void*>(comm));
     }
+    TLLM_LOG_DEBUG("[NCCLUtil] NCCL window allocator teardown for comm %p: %zu buffers, %zu bytes total",
+        static_cast<void*>(comm), commIt->second.size(), totalBytes);
 
     for (auto& entry : commIt->second)
     {
diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp
index 251ce440b5..b5dc61d3a0 100644
--- a/cpp/tensorrt_llm/thop/allreduceOp.cpp
+++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -1428,6 +1428,75 @@ private:
 
 #endif // ENABLE_MULTI_DEVICE
 
+void preallocateNCCLWindowBuffer(
+    torch::Tensor const& input, torch::List<int64_t> const& group, const int64_t buffersPerSize)
+{
+#if ENABLE_MULTI_DEVICE
+    if (buffersPerSize <= 0 || group.size() == 0 || input.numel() == 0 || input.size(0) == 0)
+    {
+        return;
+    }
+
+    std::set<int> groupSet;
+    for (auto const& rank : group)
+    {
+        groupSet.insert(static_cast<int>(rank));
+    }
+
+    auto const commPtr = getComm(groupSet);
+    if (!commPtr || *commPtr == nullptr)
+    {
+        TLLM_LOG_DEBUG("[preallocateNCCLWindowBuffer] NCCL comm is null; skipping preallocation");
+        return;
+    }
+
+    using tensorrt_llm::common::nccl_util::NCCLWindowAllocator;
+    auto& allocator = NCCLWindowAllocator::getInstance();
+    const ncclComm_t comm = *commPtr;
+
+    const int64_t numTokens = input.size(0);
+    const int64_t elementsPerToken = input.numel() / numTokens;
+    if (elementsPerToken <= 0)
+    {
+        return;
+    }
+    const size_t bufferSize = static_cast<size_t>(numTokens) * static_cast<size_t>(elementsPerToken)
+        * static_cast<size_t>(input.element_size());
+    if (bufferSize == 0)
+    {
+        return;
+    }
+    TLLM_LOG_DEBUG("[preallocateNCCLWindowBuffer] Pre-allocating %ld buffer(s) for tokens=%ld (%zu bytes) comm %p",
+        buffersPerSize, numTokens, bufferSize, static_cast<void*>(comm));
+    std::vector<void*> allocatedPtrs;
+    allocatedPtrs.reserve(buffersPerSize);
+    try
+    {
+        for (int64_t i = 0; i < buffersPerSize; ++i)
+        {
+            auto buffer = allocator.requestBuffer(comm, bufferSize);
+            if (!buffer.isValid())
+            {
+                break;
+            }
+            allocatedPtrs.push_back(buffer.ptr);
+        }
+    }
+    catch (std::exception const& e)
+    {
+        TLLM_LOG_DEBUG("[preallocateNCCLWindowBuffer] requestBuffer failed for %zu bytes: %s", bufferSize, e.what());
+    }
+
+    for (auto ptr : allocatedPtrs)
+    {
+        allocator.releaseBuffer(comm, ptr);
+    }
+#else
+    (void) group;
+    (void) buffersPerSize;
+#endif
+}
+
 std::vector<torch::Tensor> allreduce_raw(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
     torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
     torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> workspace,
@@ -1747,6 +1816,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
         "int rank,"
         "int nranks,"
         "float eps) -> Tensor[]");
+    m.def("preallocate_nccl_window_buffer(Tensor input, int[] group, int count) -> ()");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
@@ -1756,6 +1826,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
     m.impl("allreduce_pg", &tensorrt_llm::torch_ext::allreduce_pg);
     m.impl("moe_allreduce", &tensorrt_llm::torch_ext::moe_allreduce);
     m.impl("moe_finalize_allreduce", &tensorrt_llm::torch_ext::moe_finalize_allreduce);
+    m.impl("preallocate_nccl_window_buffer", &tensorrt_llm::torch_ext::preallocateNCCLWindowBuffer);
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CPU, m)
@@ -1767,4 +1838,5 @@ TORCH_LIBRARY_IMPL(trtllm, CPU, m)
                 reinterpret_cast<int64_t*>(workspace.data_ptr()), (int) tp_size);
             return std::vector<torch::Tensor>{};
         });
+    m.impl("preallocate_nccl_window_buffer", [](at::Tensor const&, torch::List<int64_t> const&, int64_t) { return; });
 }
diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index a8201cf714..482e22d5a7 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -1,5 +1,6 @@
+import threading
 from functools import lru_cache
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import ClassVar, List, Mapping, Optional, Tuple, Union
 
 import torch
 import triton  # type: ignore[import]
@@ -1655,6 +1656,8 @@ def _(
 
 
 class AllReduceRunner(TunableRunner):
+    _prealloc_lock: ClassVar[threading.Lock] = threading.Lock()
+    _prealloc_done: ClassVar[set] = set()
     tuning_config = TuningConfig(
         dynamic_tensor_specs=(DynamicTensorSpec(
             0, 0, get_last_power_of_2_num_tokens_buckets(8192),
@@ -1683,6 +1686,42 @@ class AllReduceRunner(TunableRunner):
             self.op,
         )
 
+    @classmethod
+    def _maybe_preallocate_buffers(cls,
+                                   input_tensor: torch.Tensor,
+                                   group: List[int],
+                                   do_preparation: bool = False) -> None:
+        if not do_preparation:
+            return
+        if not hasattr(torch.ops.trtllm, "preallocate_nccl_window_buffer"):
+            return
+        if input_tensor.numel() == 0 or input_tensor.size(0) == 0:
+            return
+        if hasattr(torch.cuda, "is_current_stream_capturing"):
+            try:
+                if torch.cuda.is_current_stream_capturing():
+                    return
+            except (RuntimeError, AssertionError):
+                # If capture status can't be queried, avoid prealloc to be safe.
+                return
+
+        num_tokens = int(input_tensor.size(0))
+        if num_tokens <= 0:
+            return
+        group_key = tuple(group)
+        cache_key = (group_key, num_tokens)
+        with cls._prealloc_lock:
+            if cache_key in cls._prealloc_done:
+                return
+            cls._prealloc_done.add(cache_key)
+
+        logger.debug(
+            "[tunable_allreduce] Pre-allocating NCCL window buffers: "
+            "tokens=%d group=%s", num_tokens, list(group))
+        prealloc_input = input_tensor
+        torch.ops.trtllm.preallocate_nccl_window_buffer(prealloc_input, group,
+                                                        2)
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
@@ -1715,8 +1754,19 @@ class AllReduceRunner(TunableRunner):
         self,
         inputs: List[torch.Tensor],
         tactic: int = -1,
+        do_preparation: bool = False,
+        **kwargs,
     ) -> torch.Tensor:
         input, residual, norm_weight, scale, bias, workspace = inputs
+        if do_preparation:
+            valid_tactics = self.get_valid_tactics(inputs,
+                                                   OptimizationProfile(),
+                                                   **kwargs)
+            if AllReduceStrategy.NCCL_SYMMETRIC.value in valid_tactics:
+                self._maybe_preallocate_buffers(input,
+                                                self.group,
+                                                do_preparation=True)
+            return input
         if tactic == -1:
             tactic = AllReduceStrategy.NCCL_SYMMETRIC.value
 
@@ -1735,6 +1785,15 @@ class AllReduceRunner(TunableRunner):
         )
 
 
+@torch.library.register_fake("trtllm::preallocate_nccl_window_buffer")
+def _(
+    input: torch.Tensor,
+    group: List[int],
+    count: int,
+) -> None:
+    return None
+
+
 @torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
 def tunable_allreduce(
     input: torch.Tensor,
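
Usage sketch (illustrative only). It assumes a multi-device TensorRT-LLM build in which
importing tensorrt_llm registers the trtllm custom ops and the NCCL communicator for
`group` already exists on the calling rank; the group, shape, and dtype below are
placeholder values, not taken from the patch.

    import torch
    import tensorrt_llm  # noqa: F401 -- loads the trtllm custom-op library

    group = [0, 1]  # hypothetical 2-rank tensor-parallel group
    x = torch.empty(64, 4096, dtype=torch.bfloat16, device="cuda")  # (num_tokens, hidden)

    # Request (and immediately release) two NCCL window buffers sized for this
    # token count, so later NCCL_SYMMETRIC all-reduces on same-sized inputs can
    # reuse already-registered window buffers instead of allocating them lazily.
    torch.ops.trtllm.preallocate_nccl_window_buffer(x, group, 2)

During auto-tuning the same path is exercised implicitly: AllReduceRunner.forward(...,
do_preparation=True) calls _maybe_preallocate_buffers, which invokes the op once per
(group, num_tokens) pair whenever NCCL_SYMMETRIC is among the valid tactics.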