Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 07:53:55 +08:00)
[None][fix] Pre-Allocation for Auto-Tuning NCCL_SYMMETRIC (#11326)
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
This commit is contained in:
parent 9c2d23c2e5
commit 5130cbd73e
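
This change registers a new torch custom op, trtllm::preallocate_nccl_window_buffer, that warms the NCCL window-buffer pool before the auto-tuner times the NCCL_SYMMETRIC all-reduce tactic. As a rough orientation before the diff, here is a minimal invocation sketch; the tensor shape, dtype, and rank list are made-up placeholders, and only the op schema ("Tensor input, int[] group, int count") comes from the registration added below.

    import torch

    # Assumes the TensorRT-LLM torch extension is loaded so torch.ops.trtllm exists.
    # Shape, dtype, and ranks below are illustrative placeholders, not values from the commit.
    hidden_states = torch.empty(128, 4096, dtype=torch.bfloat16, device="cuda")
    tp_group = [0, 1, 2, 3]
    # Warm two NCCL window buffers sized for this (tokens x hidden) input.
    torch.ops.trtllm.preallocate_nccl_window_buffer(hidden_states, tp_group, 2)
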
@@ -626,8 +626,10 @@ void NCCLWindowAllocator::cleanupBuffersForComm(ncclComm_t comm) noexcept
    // Check for buffers still in use - this shouldn't happen if cleanup is called properly,
    // but we log a warning if it does
    size_t inUseCount = 0;
    size_t totalBytes = 0;
    for (auto const& entry : commIt->second)
    {
        totalBytes += entry.buffer.size;
        if (entry.inUse)
        {
            ++inUseCount;
@@ -640,6 +642,8 @@ void NCCLWindowAllocator::cleanupBuffersForComm(ncclComm_t comm) noexcept
            "This may indicate buffers weren't properly released before cleanup.",
            inUseCount, static_cast<void*>(comm));
    }
    TLLM_LOG_DEBUG("[NCCLUtil] NCCL window allocator teardown for comm %p: %zu buffers, %zu bytes total",
        static_cast<void*>(comm), commIt->second.size(), totalBytes);

    for (auto& entry : commIt->second)
    {

@@ -1428,6 +1428,75 @@ private:

#endif // ENABLE_MULTI_DEVICE

void preallocateNCCLWindowBuffer(
    torch::Tensor const& input, torch::List<int64_t> const& group, const int64_t buffersPerSize)
{
#if ENABLE_MULTI_DEVICE
    if (buffersPerSize <= 0 || group.size() == 0 || input.numel() == 0 || input.size(0) == 0)
    {
        return;
    }

    std::set<int> groupSet;
    for (auto const& rank : group)
    {
        groupSet.insert(static_cast<int>(rank));
    }

    auto const commPtr = getComm(groupSet);
    if (!commPtr || *commPtr == nullptr)
    {
        TLLM_LOG_DEBUG("[preallocateNCCLWindowBuffers] NCCL comm is null; skipping preallocation");
        return;
    }

    using tensorrt_llm::common::nccl_util::NCCLWindowAllocator;
    auto& allocator = NCCLWindowAllocator::getInstance();
    const ncclComm_t comm = *commPtr;

    const int64_t numTokens = input.size(0);
    const int64_t elementsPerToken = input.numel() / numTokens;
    if (elementsPerToken <= 0)
    {
        return;
    }
    const size_t bufferSize = static_cast<size_t>(numTokens) * static_cast<size_t>(elementsPerToken)
        * static_cast<size_t>(input.element_size());
    if (bufferSize == 0)
    {
        return;
    }
    TLLM_LOG_DEBUG("[preallocateNCCLWindowBuffer] Pre-allocating %ld buffer(s) for tokens=%ld (%zu bytes) comm %p",
        buffersPerSize, numTokens, bufferSize, static_cast<void*>(comm));
    std::vector<void*> allocatedPtrs;
    allocatedPtrs.reserve(buffersPerSize);
    try
    {
        for (int64_t i = 0; i < buffersPerSize; ++i)
        {
            auto buffer = allocator.requestBuffer(comm, bufferSize);
            if (!buffer.isValid())
            {
                break;
            }
            allocatedPtrs.push_back(buffer.ptr);
        }
    }
    catch (std::exception const& e)
    {
        TLLM_LOG_DEBUG("[preallocateNCCLWindowBuffer] requestBuffer failed for %zu bytes: %s", bufferSize, e.what());
    }

    for (auto ptr : allocatedPtrs)
    {
        allocator.releaseBuffer(comm, ptr);
    }
#else
    (void) group;
    (void) buffersPerSize;
#endif
}

std::vector<torch::Tensor> allreduce_raw(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
    torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
    torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> workspace,
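
The helper above derives a single buffer size from the input (tokens × elements-per-token × element size), requests buffersPerSize window buffers from the NCCLWindowAllocator singleton, and then releases them all, so they stay cached in the allocator's pool and the timed NCCL_SYMMETRIC runs skip the expensive window registration. The sketch below is a toy Python model of that request-then-release warm-up pattern; the Pool class is purely illustrative and is not the real NCCLWindowAllocator.

    # Toy stand-in for the warm-up pattern: request buffers of the target size,
    # then release them so they remain cached for later requests of that size.
    class Pool:
        def __init__(self) -> None:
            self._free: dict[int, list[bytearray]] = {}

        def request(self, size: int) -> bytearray:
            bucket = self._free.get(size, [])
            return bucket.pop() if bucket else bytearray(size)  # allocate on cache miss

        def release(self, buf: bytearray) -> None:
            self._free.setdefault(len(buf), []).append(buf)  # keep it cached

    pool = Pool()
    warm = [pool.request(1 << 20) for _ in range(2)]  # "preallocate" two 1 MiB buffers
    for buf in warm:
        pool.release(buf)  # cached now; subsequent 1 MiB requests hit the pool
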
@@ -1747,6 +1816,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
        "int rank,"
        "int nranks,"
        "float eps) -> Tensor[]");
    m.def("preallocate_nccl_window_buffer(Tensor input, int[] group, int count) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
@@ -1756,6 +1826,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
    m.impl("allreduce_pg", &tensorrt_llm::torch_ext::allreduce_pg);
    m.impl("moe_allreduce", &tensorrt_llm::torch_ext::moe_allreduce);
    m.impl("moe_finalize_allreduce", &tensorrt_llm::torch_ext::moe_finalize_allreduce);
    m.impl("preallocate_nccl_window_buffer", &tensorrt_llm::torch_ext::preallocateNCCLWindowBuffer);
}

TORCH_LIBRARY_IMPL(trtllm, CPU, m)
@@ -1767,4 +1838,5 @@ TORCH_LIBRARY_IMPL(trtllm, CPU, m)
            reinterpret_cast<int64_t*>(workspace.data_ptr()), (int) tp_size);
        return std::vector<at::Tensor>{};
    });
    m.impl("preallocate_nccl_window_buffer", [](at::Tensor const&, torch::List<int64_t> const&, int64_t) { return; });
}

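The three registrations above expose the helper through the dispatcher: the schema is declared in TORCH_LIBRARY_FRAGMENT, the CUDA backend binds the real preallocateNCCLWindowBuffer, and the CPU backend gets a no-op lambda so the op can be called from any device without error. A small hedged probe from Python, assuming the extension library has been loaded:

    import torch

    # getattr-with-default mirrors the hasattr guard used on the Python side below;
    # if the extension is not loaded, the op simply is not present.
    op = getattr(torch.ops.trtllm, "preallocate_nccl_window_buffer", None)
    if op is not None:
        cpu_input = torch.zeros(4, 8)  # CPU dispatch hits the registered no-op lambda
        op(cpu_input, [0, 1], 2)       # returns None and allocates nothing
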
@@ -1,5 +1,6 @@
import threading
from functools import lru_cache
from typing import List, Mapping, Optional, Tuple, Union
from typing import ClassVar, List, Mapping, Optional, Tuple, Union

import torch
import triton  # type: ignore[import]
@@ -1655,6 +1656,8 @@ def _(


class AllReduceRunner(TunableRunner):
    _prealloc_lock: ClassVar[threading.Lock] = threading.Lock()
    _prealloc_done: ClassVar[set] = set()
    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            0, 0, get_last_power_of_2_num_tokens_buckets(8192),
@@ -1683,6 +1686,42 @@ class AllReduceRunner(TunableRunner):
            self.op,
        )

    @classmethod
    def _maybe_preallocate_buffers(cls,
                                   input_tensor: torch.Tensor,
                                   group: List[int],
                                   do_preparation: bool = False) -> None:
        if not do_preparation:
            return
        if not hasattr(torch.ops.trtllm, "preallocate_nccl_window_buffer"):
            return
        if input_tensor.numel() == 0 or input_tensor.size(0) == 0:
            return
        if hasattr(torch.cuda, "is_current_stream_capturing"):
            try:
                if torch.cuda.is_current_stream_capturing():
                    return
            except (RuntimeError, AssertionError):
                # If capture status can't be queried, avoid prealloc to be safe.
                return

        num_tokens = int(input_tensor.size(0))
        if num_tokens <= 0:
            return
        group_key = tuple(group)
        cache_key = (group_key, num_tokens)
        with cls._prealloc_lock:
            if cache_key in cls._prealloc_done:
                return
            cls._prealloc_done.add(cache_key)

        logger.debug(
            "[tunable_allreduce] Pre-allocating NCCL window buffers: "
            "tokens=%d group=%s", num_tokens, list(group))
        prealloc_input = input_tensor
        torch.ops.trtllm.preallocate_nccl_window_buffer(prealloc_input, group,
                                                        2)

    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
@@ -1715,8 +1754,19 @@ class AllReduceRunner(TunableRunner):
        self,
        inputs: List[torch.Tensor],
        tactic: int = -1,
        do_preparation: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        input, residual, norm_weight, scale, bias, workspace = inputs
        if do_preparation:
            valid_tactics = self.get_valid_tactics(inputs,
                                                   OptimizationProfile(),
                                                   **kwargs)
            if AllReduceStrategy.NCCL_SYMMETRIC.value in valid_tactics:
                self._maybe_preallocate_buffers(input,
                                                self.group,
                                                do_preparation=True)
            return input
        if tactic == -1:
            tactic = AllReduceStrategy.NCCL_SYMMETRIC.value

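In the forward path above, a call with do_preparation=True performs no all-reduce at all: it checks whether NCCL_SYMMETRIC is among the valid tactics and, if so, pre-allocates the window buffers, then returns the input unchanged. The sketch below shows, under assumed names and a simplified flow (the real AutoTuner differs), how a tuner might drive one preparation pass before timing each tactic so the timed runs do not pay the buffer-registration cost.

    import torch

    def tune(runner, inputs, tactics):
        # One preparation pass: warms NCCL window buffers via _maybe_preallocate_buffers.
        runner.forward(inputs, do_preparation=True)
        timings = {}
        for tactic in tactics:
            start = torch.cuda.Event(enable_timing=True)
            stop = torch.cuda.Event(enable_timing=True)
            start.record()
            runner.forward(inputs, tactic=tactic)
            stop.record()
            torch.cuda.synchronize()
            timings[tactic] = start.elapsed_time(stop)  # milliseconds
        return min(timings, key=timings.get)  # best-performing tactic
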
@@ -1735,6 +1785,15 @@ class AllReduceRunner(TunableRunner):
    )


@torch.library.register_fake("trtllm::preallocate_nccl_window_buffer")
def _(
    input: torch.Tensor,
    group: List[int],
    count: int,
) -> None:
    return None


@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
def tunable_allreduce(
    input: torch.Tensor,