TensorRT-LLM/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h

/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "tensorrt_llm/runtime/mcastDeviceMemory.h"
#include "tensorrt_llm/runtime/torchUtils.h"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>

namespace tensorrt_llm::runtime
{

//! \brief Wrapper class for McastDeviceMemory to facilitate PyTorch tensor creation.
//! It manages a buffer accessible via unicast or multicast for multi-node communication.
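//! The unicast pointer maps one rank's local replica of the region, while the multicast pointer
//! maps the same region on every rank in the group, so a single store through it can reach all replicas.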
class McastGPUBuffer
{
public:
    // Disallow copy construction and assignment
    McastGPUBuffer(McastGPUBuffer const&) = delete;
    McastGPUBuffer& operator=(McastGPUBuffer const&) = delete;

    //! \brief Constructor for McastGPUBuffer.
    //! \param bufSize The total size of the buffer in bytes.
    //! \param groupSize The number of ranks in the communication group.
    //! \param groupRank The rank of the local process within the group.
    //! \param device The CUDA device for buffer allocation.
    //! \param mnNvlink Flag indicating if multi-node NVLink is used.
    McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, at::Device device, bool mnNvlink)
        : mMcastDeviceMemory(bufSize, groupSize, groupRank, device.index(), mnNvlink)
        , mBufSize(bufSize)
        , mLocalDevice(device)
    {
    }

    //! \brief Returns a PyTorch tensor view of the unicast buffer portion for a specific rank.
    //! \param rank The target rank for the unicast pointer.
    //! \param sizes The desired shape (dimensions) of the tensor.
    //! \param dtype The data type of the tensor elements.
    //! \param storageOffset The offset in elements from the start of the buffer.
    //! \return An ATen tensor wrapping the unicast buffer section.
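    //! \note The returned tensor aliases the buffer without owning it; this McastGPUBuffer
    //! must outlive every view it hands out.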
    at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
    {
        size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
        size_t const elementSize = c10::elementSize(dtype);
        size_t const reqSize = (numel + storageOffset) * elementSize;
        TORCH_CHECK(reqSize <= mBufSize, "McastGPUBuffer::getUCBuffer: the requested size (", reqSize,
            " bytes) exceeds the allocated size (", mBufSize, " bytes)");
        auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize;
        auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
    }

    //! \brief Returns a PyTorch tensor view of the multicast buffer portion.
    //! \param sizes The desired shape (dimensions) of the tensor.
    //! \param dtype The data type of the tensor elements.
    //! \param storageOffset The offset in elements from the start of the buffer.
    //! \return An ATen tensor wrapping the multicast buffer section.
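    //! \note Stores through the multicast view are fanned out to every rank's replica; like the
    //! unicast view, the tensor does not own the underlying memory.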
    at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
    {
        size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
        size_t const elementSize = c10::elementSize(dtype);
        size_t const reqSize = (numel + storageOffset) * elementSize;
        TORCH_CHECK(reqSize <= mBufSize, "McastGPUBuffer::getMCBuffer: the requested size (", reqSize,
            " bytes) exceeds the allocated size (", mBufSize, " bytes)");
        auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize;
        auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
    }

private:
    //! Underlying memory manager for multi-node communication.
    tensorrt_llm::runtime::McastDeviceMemory mMcastDeviceMemory;
    size_t mBufSize;         //!< Total size of the managed buffer.
    at::Device mLocalDevice; //!< The local CUDA device.
};

} // namespace tensorrt_llm::runtime
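
A minimal usage sketch, not part of the header: names such as makeAllReduceScratch and kBufSize are hypothetical, and groupSize/groupRank are assumed to come from the caller's communicator. Each rank constructs one McastGPUBuffer with identical parameters, writes its shard through its own unicast view, and lets kernels publish results through the multicast view, which maps the same offset on every rank.

#include "tensorrt_llm/runtime/mcastGPUBuffer.h"

// Hypothetical per-rank setup; groupSize/groupRank would come from the
// surrounding process group (e.g. MPI or torch.distributed).
void makeAllReduceScratch(uint32_t groupSize, uint32_t groupRank, int deviceIndex)
{
    using tensorrt_llm::runtime::McastGPUBuffer;

    constexpr size_t kBufSize = 1 << 20; // 1 MiB of scratch space (assumption)
    at::Device device(at::kCUDA, static_cast<c10::DeviceIndex>(deviceIndex));

    // mnNvlink=true requests the multi-node NVLink (MNNVL) multicast path.
    McastGPUBuffer buffer(kBufSize, groupSize, groupRank, device, /*mnNvlink=*/true);

    // This rank's private view of the first 1024 bf16 elements...
    at::Tensor local = buffer.getUCBuffer(groupRank, {1024}, at::kBFloat16, /*storageOffset=*/0);
    // ...and the multicast alias of the same region on every rank.
    at::Tensor shared = buffer.getMCBuffer({1024}, at::kBFloat16, /*storageOffset=*/0);
}

Because both views are created over the same underlying offset, a kernel can reduce into the unicast view and broadcast through the multicast view without any extra copies; the buffer object itself must stay alive for as long as either tensor is in use.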