[TRTLLM-7349][feat] Adding new orchestrator type -- ray (#7520)

Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Co-authored-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Co-authored-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Jonas Yang CN 2025-10-04 08:12:24 +08:00 committed by GitHub
parent 9d098e3142
commit 88ea2c4ee9
91 changed files with 5538 additions and 603 deletions
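
For orientation, a minimal usage sketch of the feature this commit adds. It assumes the new Ray orchestrator is selected through an `orchestrator_type` argument on the LLM API and that MPI is switched off via the `TLLM_DISABLE_MPI` environment variable gated on throughout this diff; the exact Python-side knob is an assumption, not confirmed by this excerpt.

# Hedged sketch -- `orchestrator_type` and its accepted values are assumptions, not confirmed by this diff excerpt.
import os
os.environ["TLLM_DISABLE_MPI"] = "1"  # this diff switches between MPI and torch ProcessGroup paths on this env var

from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/model",       # placeholder checkpoint path
    orchestrator_type="ray",      # assumed knob for the new Ray orchestrator
)
print(llm.generate(["Hello"])[0].outputs[0].text)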

.gitignore vendored
View File

@ -46,6 +46,7 @@ tensorrt_llm/deep_ep_cpp_tllm.pyi
tensorrt_llm/deep_gemm/
tensorrt_llm/deep_gemm_cpp_tllm.*.so
tensorrt_llm/deep_gemm_cpp_tllm.pyi
tensorrt_llm/pg_utils_bindings.*.so
*docs/cpp_docs*
*docs/source/_cpp_gen*
docs/source/**/*.rst

View File

@ -23,9 +23,17 @@
#include "tensorrt_llm/executor/cacheCommunicator.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <pybind11/pybind11.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/custom_class.h>
#include <torch/python.h>
#include <type_traits>
#include <vector>
using SizeType32 = tensorrt_llm::runtime::SizeType32;
@ -43,6 +51,134 @@ class BaseKVCacheManager;
class CacheSender;
class CacheReceiver;
class CacheTransceiverComm
{
public:
// Construct from a non-owning raw pointer, won't take ownership of the pointer
explicit CacheTransceiverComm(mpi::MpiComm const* mpiComm)
: mMpiComm(std::shared_ptr<mpi::MpiComm const>(nullptr), mpiComm)
{
}
// Construct from a shared_ptr with shared ownership
explicit CacheTransceiverComm(std::shared_ptr<mpi::MpiComm const> mpiComm)
: mMpiComm(std::move(mpiComm))
{
}
// Construct from a ProcessGroup communicator
explicit CacheTransceiverComm(c10::intrusive_ptr<c10d::ProcessGroup> pgComm)
: mPgComm(std::move(pgComm))
{
}
~CacheTransceiverComm() = default;
bool isMpi() const noexcept
{
return mMpiComm != nullptr;
}
int getRank() const
{
if (isMpi())
{
return mMpiComm->getRank();
}
return mPgComm->getRank();
}
int getSize() const
{
if (isMpi())
{
return mMpiComm->getSize();
}
return mPgComm->getSize();
}
void allgather(void const* sendbuf, void* recvbuf, int count, mpi::MpiType dtype) const
{
if (isMpi())
{
mMpiComm->allgather(sendbuf, recvbuf, count, dtype);
return;
}
TLLM_THROW("Input arguments only supported in mpi");
}
template <typename Input, typename Output>
bool allgather(Input input, Output output, c10d::AllgatherOptions options = c10d::AllgatherOptions()) const
{
if (isMpi())
{
TLLM_THROW("Input arguments only supported in pg");
}
tensorrt_llm::pg_utils::PgHelper pgh{mPgComm};
PGCHECK_THROW(pgh.allgather(input, output, options));
return true;
}
template <typename Input, typename Output>
bool allgatherv(Input input, Output output, std::vector<int> const& sizes,
c10d::AllgatherOptions options = c10d::AllgatherOptions()) const
{
if (isMpi())
{
TLLM_THROW("Input arguments only supported in pg");
}
tensorrt_llm::pg_utils::PgHelper pgh{mPgComm};
PGCHECK_THROW(pgh.allgatherv(input, output, sizes, options));
return true;
}
bool allgatherv(void const* sendbuf, int sendcount, mpi::MpiType sendtype, void* recvbuf,
std::vector<int> const& recvcounts, std::vector<int> const& displs, mpi::MpiType recvtype) const
{
if (isMpi())
{
mMpiComm->allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype);
return true;
}
TLLM_THROW("Input arguments only supported in mpi");
}
CacheTransceiverComm split(int color, int key)
{
if (isMpi())
{
auto subgroup = mMpiComm->split(color, key);
return CacheTransceiverComm(std::make_shared<mpi::MpiComm const>(std::move(subgroup)));
}
bool const initialized = Py_IsInitialized();
TLLM_CHECK_WITH_INFO(initialized, "Trying to use ProcessGroup communicator but Python is not initialized");
try
{
c10::intrusive_ptr<c10d::ProcessGroup> pgSub;
{
pybind11::gil_scoped_acquire gil;
auto const m = pybind11::module::import("tensorrt_llm._torch.distributed.pg_utils");
// Properly box the existing intrusive_ptr ProcessGroup into an IValue
// and convert to a Python object without constructing a new instance.
auto const py_pg = torch::jit::toPyObject(c10::IValue(mPgComm));
auto const py_sub_pg = m.attr("split")(color, key, py_pg);
pgSub = torch::jit::toCustomClass<c10d::ProcessGroup>(py_sub_pg);
}
return CacheTransceiverComm(pgSub);
}
catch (...)
{
TLLM_THROW("Failed to split process group");
}
}
private:
std::shared_ptr<mpi::MpiComm const> mMpiComm;
c10::intrusive_ptr<c10d::ProcessGroup> mPgComm;
};
class CacheTransceiverFactory
{
public:
@ -124,9 +260,11 @@ private:
std::unique_ptr<CacheReceiver> mCacheReceiver;
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,
mMpiGroupTPInDPComm;
mpi::MpiComm const* mMpiWorldComm{nullptr};
std::shared_ptr<CacheTransceiverComm> mGroupComm;
std::shared_ptr<CacheTransceiverComm> mGroupTensorParaComm, mGroupPipeParaComm, mGroupDataComm, mGroupTPInDPComm;
executor::kv_cache::CommState const* mCommState;
std::unique_ptr<executor::kv_cache::CacheState> mCacheState;
std::unique_ptr<executor::kv_cache::ConnectionManager> mManager;

View File

@ -0,0 +1,72 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "c10/util/intrusive_ptr.h"
#include <Python.h>
namespace tensorrt_llm::common
{
// Adapted from pybind11's example implementation:
// https://github.com/pybind/pybind11/blob/master/include/pybind11/conduit/pybind11_conduit_v1.h
// Copyright (c) 2024 The pybind Community.
inline void* get_raw_pointer_ephemeral(
PyObject* py_obj, std::type_info const* cpp_type_info, std::string const& pybind11_abi)
{
PyObject* cpp_type_info_capsule = PyCapsule_New(
const_cast<void*>(static_cast<void const*>(cpp_type_info)), typeid(std::type_info).name(), nullptr);
if (cpp_type_info_capsule == nullptr)
{
return nullptr;
}
PyObject* cpp_conduit = PyObject_CallMethod(
py_obj, "_pybind11_conduit_v1_", "yOy", pybind11_abi.c_str(), cpp_type_info_capsule, "raw_pointer_ephemeral");
Py_DECREF(cpp_type_info_capsule);
if (cpp_conduit == nullptr)
{
return nullptr;
}
void* raw_ptr = PyCapsule_GetPointer(cpp_conduit, cpp_type_info->name());
Py_DECREF(cpp_conduit);
if (PyErr_Occurred())
{
return nullptr;
}
return raw_ptr;
}
template <typename T, typename E>
T* get_type_pointer_ephemeral(PyObject* py_obj, std::string pybind11_abi)
{
void* raw_ptr = get_raw_pointer_ephemeral(py_obj, &typeid(T), pybind11_abi);
if (raw_ptr == nullptr)
{
throw E();
}
return static_cast<T*>(raw_ptr);
}
template <typename T, typename E>
c10::intrusive_ptr<T> get_intrusive_ptr(PyObject* py_obj, std::string pybind11_abi)
{
auto* const p = get_type_pointer_ephemeral<T, E>(py_obj, pybind11_abi);
return c10::intrusive_ptr<T>::reclaim_copy(p);
}
} // namespace tensorrt_llm::common

View File

@ -35,6 +35,7 @@
#include <cstdlib>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#if ENABLE_MULTI_DEVICE
@ -425,7 +426,29 @@ public:
return !(rhs == *this);
}
bool couldUseMPI() const
{
if (!mDisableMPI.has_value())
{
char* val = std::getenv("TLLM_DISABLE_MPI");
if (val != nullptr && std::string(val) == "1")
{
mDisableMPI = true;
}
else
{
mDisableMPI = false;
}
}
if (mDisableMPI.value())
{
throw std::runtime_error("MPI is disabled, DON'T USE MPI");
}
return true;
}
private:
mutable std::optional<bool> mDisableMPI;
//! \brief Corresponds to `world()` by default, but can be overridden per process.
static MpiComm& mutableSession();

View File

@ -0,0 +1,284 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/torch.h>
#include <vector>
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"
// Check async op.
inline c10::intrusive_ptr<c10d::Work> pgCheckHelper(
c10::intrusive_ptr<c10d::Work> work, char const* const file, int const line, char const* info)
{
if (work == nullptr)
{
auto const msg = std::string("[TensorRT-LLM][ERROR] empty work returned from: ") + info;
tensorrt_llm::common::throwRuntimeError(file, line, msg);
}
try
{
work->wait();
}
catch (...)
{
auto msg = std::string("[TensorRT-LLM][ERROR] Torch distributed operation error: ") + info;
std::throw_with_nested(tensorrt_llm::common::TllmException(file, line, msg.c_str()));
}
return work;
}
// Check sync op.
inline void pgCheckHelper(bool success, char const* const file, int const line, char const* info)
{
if (!success)
{
throw std::runtime_error(std::string("[TensorRT-LLM][ERROR] Torch distributed operation error: ") + info);
}
}
#define PGCHECK_THROW(op) pgCheckHelper(op, __FILE__, __LINE__, #op)
#define PGCHECK_THROW_WITH_INFO(op, info) pgCheckHelper(op, __FILE__, __LINE__, info)
inline bool useMPI()
{
bool useMPI = true;
char* val = std::getenv("TLLM_DISABLE_MPI");
if (val != nullptr && std::string(val) == "1")
{
useMPI = false;
}
return useMPI;
}
namespace tensorrt_llm::pg_utils
{
// ProcessGroup management functions
c10::intrusive_ptr<c10d::ProcessGroup> get_world_pg();
c10::intrusive_ptr<c10d::ProcessGroup> get_local_pg();
void init_pg(c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_world,
c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_local);
// Tensor wrapping utilities for ProcessGroup operations
inline torch::Tensor wrap_tensor(torch::Tensor data)
{
return data;
}
template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
torch::Tensor wrap_tensor(T* data, size_t size)
{
if constexpr (std::is_same_v<std::decay_t<T>, char>)
{
// `char` does not have a guaranteed specialization in CppTypeToScalarType
// across PyTorch builds. Treat `char` as kChar (int8) explicitly.
return at::from_blob(data, {static_cast<int64_t>(size)}, c10::TensorOptions{}.dtype(torch::kChar));
}
else if constexpr (std::is_same_v<std::decay_t<T>, uint64_t>)
{
// `uint64_t` may not have a guaranteed specialization in CppTypeToScalarType
// across PyTorch builds. Treat `uint64_t` as kLong (int64) explicitly.
return at::from_blob(data, {static_cast<int64_t>(size)}, c10::TensorOptions{}.dtype(torch::kLong));
}
else
{
return at::from_blob(data, {static_cast<int64_t>(size)},
c10::TensorOptions{}.dtype(torch::CppTypeToScalarType<std::decay_t<T>>::value));
}
}
template <typename T, typename = std::enable_if_t<std::is_void_v<T>>, typename = void>
torch::Tensor wrap_tensor(T* data, size_t size)
{
return at::from_blob(data, {static_cast<int64_t>(size)}, c10::TensorOptions{}.dtype(torch::kChar));
}
template <typename T>
torch::Tensor wrap_tensor(T const* data, size_t size)
{
return wrap_tensor(const_cast<T*>(data), size);
}
template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
torch::Tensor wrap_tensor(T& data)
{
return wrap_tensor(&data, 1);
}
template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
torch::Tensor wrap_tensor(std::reference_wrapper<T> data)
{
return wrap_tensor(&data.get(), 1);
}
template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
torch::Tensor wrap_tensor(T* data)
{
return wrap_tensor(data, 1);
}
template <typename T>
torch::Tensor wrap_tensor(std::vector<T>& data)
{
return wrap_tensor(data.data(), data.size());
}
template <typename T>
torch::Tensor wrap_tensor(std::vector<T> const& data)
{
return wrap_tensor(data.data(), data.size());
}
template <typename T>
torch::Tensor wrap_tensor(std::reference_wrapper<std::vector<T>> data)
{
auto& ref = data.get();
return wrap_tensor(ref.data(), ref.size());
}
template <typename T>
torch::Tensor wrap_tensor(std::reference_wrapper<std::vector<T> const> data)
{
auto const& ref = data.get();
return wrap_tensor(ref.data(), ref.size());
}
template <typename T>
torch::Tensor wrap_tensor(std::vector<T>* data)
{
return wrap_tensor(data->data(), data->size());
}
// ProcessGroup Helper - convenient wrapper around ProcessGroup operations
struct PgHelper
{
c10::intrusive_ptr<c10d::ProcessGroup> pg;
PgHelper(c10::intrusive_ptr<c10d::ProcessGroup> pg)
: pg(pg)
{
}
template <typename Input, typename Output>
c10::intrusive_ptr<c10d::Work> allgather(
Input input, Output output, c10d::AllgatherOptions options = c10d::AllgatherOptions())
{
auto inputTensor = wrap_tensor(input);
auto outputTensor = wrap_tensor(output);
return pg->_allgather_base(outputTensor, inputTensor, options);
}
template <typename Input>
c10::intrusive_ptr<c10d::Work> allreduce(Input input, c10d::AllreduceOptions options = c10d::AllreduceOptions())
{
std::vector inputs{wrap_tensor(input)};
return pg->allreduce(inputs, options);
}
template <typename Input>
c10::intrusive_ptr<c10d::Work> send(Input input, int dstRank, int tag)
{
std::vector inputs{wrap_tensor(input)};
return pg->send(inputs, dstRank, tag);
}
template <typename Output>
c10::intrusive_ptr<c10d::Work> recv(Output output, int srcRank, int tag)
{
std::vector outputs{wrap_tensor(output)};
return pg->recv(outputs, srcRank, tag);
}
// Variable-size allgather helper implemented via padding + slicing on Tensors.
template <typename Input, typename Output, typename SizeT = int64_t>
bool allgatherv(Input input, Output output, std::vector<SizeT> const& sizes,
c10d::AllgatherOptions options = c10d::AllgatherOptions())
{
auto const worldSize = pg->getSize();
TLLM_CHECK_WITH_INFO(
static_cast<int>(sizes.size()) == worldSize, "sizes.size() must equal worldSize in allgatherv");
at::Tensor inputTensor = wrap_tensor(input);
SizeT const localSize = static_cast<SizeT>(inputTensor.numel());
TLLM_CHECK_WITH_INFO(
sizes[pg->getRank()] == localSize, "sizes[rank] must equal local input size in allgatherv");
SizeT const maxSize = *std::max_element(sizes.begin(), sizes.end());
auto tensorOptions = inputTensor.options();
at::Tensor paddedInput = at::zeros({static_cast<int64_t>(maxSize)}, tensorOptions);
if (localSize > 0)
{
paddedInput.narrow(0, 0, static_cast<int64_t>(localSize)).copy_(inputTensor);
}
at::Tensor paddedOutput
= at::empty({static_cast<int64_t>(maxSize) * static_cast<int64_t>(worldSize)}, tensorOptions);
PGCHECK_THROW(pg->_allgather_base(paddedOutput, paddedInput, options)->wait());
// Prepare compact output tensor backed by 'output'
SizeT const totalSize = std::accumulate(sizes.begin(), sizes.end(), static_cast<SizeT>(0));
at::Tensor outputTensor = wrap_tensor(output);
TLLM_CHECK_WITH_INFO(outputTensor.numel() == static_cast<int64_t>(totalSize),
"output tensor numel must equal total size in allgatherv");
// Slice and compact
size_t writeOffset = 0;
for (int r = 0; r < worldSize; ++r)
{
int64_t const validCount = static_cast<int64_t>(sizes[static_cast<size_t>(r)]);
int64_t const srcOffset = static_cast<int64_t>(r) * static_cast<int64_t>(maxSize);
if (validCount > 0)
{
outputTensor.narrow(0, static_cast<int64_t>(writeOffset), validCount)
.copy_(paddedOutput.narrow(0, srcOffset, validCount));
writeOffset += static_cast<size_t>(validCount);
}
}
return true;
}
// Convenience overload to accept sizes passed via std::cref(...)
template <typename Input, typename Output, typename SizeT = int64_t>
bool allgatherv(Input input, Output output, std::reference_wrapper<std::vector<SizeT> const> sizes,
c10d::AllgatherOptions options = c10d::AllgatherOptions())
{
return allgatherv<Input, Output, SizeT>(input, output, sizes.get(), options);
}
};
} // namespace tensorrt_llm::pg_utils
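
PgHelper::allgatherv above emulates a variable-size allgather by padding each rank's contribution to the maximum size, running a fixed-size _allgather_base, and slicing the valid prefix out of each rank's slot. A minimal PyTorch sketch of the same padding + slicing scheme, for illustration only (not part of this commit; assumes the default process group is already initialized):

# Illustrative only -- mirrors the padding + slicing idea of PgHelper::allgatherv in plain torch.distributed.
import torch
import torch.distributed as dist

def allgatherv(local: torch.Tensor, sizes: list[int]) -> torch.Tensor:
    world = dist.get_world_size()
    assert len(sizes) == world and sizes[dist.get_rank()] == local.numel()
    max_size = max(sizes)
    padded = torch.zeros(max_size, dtype=local.dtype, device=local.device)
    padded[: local.numel()] = local                      # pad this rank's contribution
    gathered = torch.empty(world * max_size, dtype=local.dtype, device=local.device)
    dist.all_gather_into_tensor(gathered, padded)        # fixed-size allgather
    # compact: keep only the first sizes[r] elements of each rank's slot
    return torch.cat([gathered[r * max_size : r * max_size + sizes[r]] for r in range(world)])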

View File

@ -107,3 +107,8 @@ if(ENABLE_UCX)
target_link_libraries(${BATCH_MANAGER_STATIC_TARGET} PUBLIC)
endif()
find_library(TORCH_PYTHON_LIB torch_python REQUIRED
HINTS ${TORCH_INSTALL_PREFIX}/lib)
target_link_libraries(${BATCH_MANAGER_STATIC_TARGET}
PUBLIC ${TORCH_PYTHON_LIB} Python3::Python pg_utils)

View File

@ -21,7 +21,6 @@
#include <limits>
#include <sstream>
#define UCX_WRAPPER_LIB_NAME "tensorrt_llm_ucx_wrapper"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary(name ".dll")
@ -48,6 +47,7 @@
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <algorithm>
#include <cstddef>
#include <numeric>
@ -114,14 +114,22 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
executor::kv_cache::CacheState::AttentionType attentionType,
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig)
: mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session()))
, mCacheTransceiverConfig{cacheTransceiverConfig}
: mCacheTransceiverConfig{cacheTransceiverConfig}
{
using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
if (useMPI())
{
mGroupComm = std::make_shared<CacheTransceiverComm>(std::addressof(tensorrt_llm::mpi::MpiComm::session()));
}
else
{
mGroupComm = std::make_shared<CacheTransceiverComm>(tensorrt_llm::pg_utils::get_world_pg());
}
if (worldConfig.isTensorParallel())
{
mMpiGroupTensorParaComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(
mMpiGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
}
int kvFactor = 2;
if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
@ -142,12 +150,11 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
- TPRankInDPGroup)
/ TPSizeInDPGroup;
// <PP,DP,TP>
mMpiGroupDataComm
= std::make_shared<tensorrt_llm::mpi::MpiComm>(mMpiGroupComm->split(DPRank, worldConfig.getRank()));
mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
if (worldConfig.isTensorParallel())
{
mMpiGroupTPInDPComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(
mMpiGroupComm->split(worldConfig.getRank() / TPSizeInDPGroup, worldConfig.getRank()));
mGroupTPInDPComm = std::make_shared<CacheTransceiverComm>(
mGroupComm->split(worldConfig.getRank() / TPSizeInDPGroup, worldConfig.getRank()));
}
}
bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
@ -302,28 +309,39 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
}
std::vector<LlmRequest::RequestIdType> gatherRequestIds(
mpi::MpiComm const& mpiComm, std::vector<LlmRequest::RequestIdType> const& requestIds)
std::shared_ptr<CacheTransceiverComm> const& mComm, std::vector<LlmRequest::RequestIdType> const& requestIds)
{
int localSize = static_cast<int>(requestIds.size());
std::vector<int> sizes(mpiComm.getSize());
mpiComm.allgather(&localSize, sizes.data(), 1, mpi::MpiType::kINT32);
std::vector<int> displs(mpiComm.getSize());
int totalSize = 0;
for (int i = 0; i < mpiComm.getSize(); i++)
std::vector<int> sizes(mComm->getSize());
std::vector<LlmRequest::RequestIdType> retData;
if (useMPI())
{
displs[i] = totalSize;
totalSize += sizes[i];
mComm->allgather(&localSize, sizes.data(), 1, mpi::MpiType::kINT32);
std::vector<int> displs(mComm->getSize());
size_t totalSize = 0;
for (int i = 0; i < mComm->getSize(); i++)
{
displs[i] = totalSize;
totalSize += sizes[i];
}
retData.resize(totalSize);
mComm->allgatherv(requestIds.data(), static_cast<int>(requestIds.size()), mpi::MpiType::kUINT64, retData.data(),
sizes, displs, mpi::MpiType::kUINT64);
}
else
{
mComm->allgather(&localSize, std::ref(sizes), {});
size_t totalSize = std::accumulate(sizes.begin(), sizes.end(), 0);
retData.resize(totalSize);
mComm->allgatherv(std::ref(requestIds), std::ref(retData), std::cref(sizes), {});
}
std::vector<LlmRequest::RequestIdType> retData(totalSize);
mpiComm.allgatherv(requestIds.data(), static_cast<int>(requestIds.size()), mpi::MpiType::kUINT64, retData.data(),
sizes, displs, mpi::MpiType::kUINT64);
return retData;
}
void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm, LlmRequest* request)
{
namespace su = executor::serialize_utils;
int worldSize = mpiComm.getSize();
int worldSize = mComm->getSize();
std::ostringstream oStream;
su::serialize(request->getKvCacheTransferStart(), oStream);
@ -335,7 +353,14 @@ void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
auto recvBufferSize = sendBufferSize * worldSize;
std::vector<char> recvBuffer(recvBufferSize);
mpiComm.allgather(sendBuffer.data(), recvBuffer.data(), sendBufferSize, mpi::MpiType::kCHAR);
if (useMPI())
{
mComm->allgather(sendBuffer.data(), recvBuffer.data(), sendBufferSize, mpi::MpiType::kCHAR);
}
else
{
mComm->allgather(std::ref(sendBuffer), std::ref(recvBuffer), {});
}
su::VectorWrapBuf<char> strbuf(recvBuffer);
std::istream is(&strbuf);
@ -353,7 +378,14 @@ void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
std::size_t localKVCacheSize = request->getKvCacheSize();
std::vector<std::size_t> allKVCacheSizes(worldSize, 0);
mpiComm.allgather(&localKVCacheSize, allKVCacheSizes.data(), 1, mpi::MpiType::kUINT64);
if (useMPI())
{
mComm->allgather(&localKVCacheSize, allKVCacheSizes.data(), 1, mpi::MpiType::kUINT64);
}
else
{
mComm->allgather(&localKVCacheSize, std::ref(allKVCacheSizes), {});
}
std::size_t totalKVCacheSize = 0;
for (int rank = 0; rank < worldSize; rank++)
@ -362,7 +394,7 @@ void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
}
// Update the latest KV cache transfer time for leader rank
if (mpiComm.getRank() == 0)
if (mComm->getRank() == 0)
{
request->setKvCacheTransferStart(minStartTime);
request->setKvCacheTransferEnd(maxEndTime);
@ -373,7 +405,7 @@ void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
{
bool blockAll = !atLeastRequestNum.has_value();
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupTPInDPComm : mMpiGroupTensorParaComm;
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupTPInDPComm : mGroupTensorParaComm;
std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
for (auto&& [request, future] : mSenderFutures)
{
@ -386,7 +418,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
std::unordered_map<LlmRequest::RequestIdType, int> frequencyMap;
if ((syncComm) && syncComm->getSize() > 1)
{
auto gatherRequestIdVec = gatherRequestIds(*syncComm, contextCompleteRequestIds);
auto gatherRequestIdVec = gatherRequestIds(syncComm, contextCompleteRequestIds);
for (auto&& requestId : gatherRequestIdVec)
{
frequencyMap[requestId]++;
@ -462,10 +494,10 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
std::unordered_map<LlmRequest::RequestIdType, int> frequencyMap;
std::vector<LlmRequest::RequestIdType> toBlockRequestIds;
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupDataComm.get() : mMpiGroupComm;
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
if ((syncComm) && syncComm->getSize() > 1)
{
auto gatherRequestIdVec = gatherRequestIds(*syncComm, genTransferReadyRequestIds);
auto gatherRequestIdVec = gatherRequestIds(syncComm, genTransferReadyRequestIds);
for (auto&& requestId : gatherRequestIdVec)
{
frequencyMap[requestId]++;
@ -493,8 +525,16 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
break;
}
toCompleteIdSet.insert(freqVec.at(idx).first);
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " checkGenTransferStatus at least from freqVec requestId: %zu ",
freqVec.at(idx).first);
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
" checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first);
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
" checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first);
}
idx++;
}
idx = 0;
@ -509,9 +549,18 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
if (toCompleteIdSet.find(mRequesterFutures.at(idx).first->mRequestId) == toCompleteIdSet.end())
{
toCompleteIdSet.insert(mRequesterFutures.at(idx).first->mRequestId);
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
" checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
" checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
" checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
}
}
idx++;
}
@ -521,12 +570,29 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
{
toCompleteIdSet.insert(requestId);
}
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " checkGenTransferStatus freqVec requestId: %zu,freq:%d ",
requestId, freq);
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " checkGenTransferStatus freqVec requestId: %zu,freq:%d ",
requestId, freq);
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
" checkGenTransferStatus freqVec requestId: %zu,freq:%d ", requestId, freq);
}
}
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
" checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
atLeastRequestNum.value_or(0));
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
" checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
atLeastRequestNum.value_or(0));
}
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
" checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
atLeastRequestNum.value_or(0));
for (auto it = mRequesterFutures.begin(); it != mRequesterFutures.end();)
{
if (blockAll || toCompleteIdSet.find(it->first->mRequestId) != toCompleteIdSet.end())
@ -539,9 +605,8 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
// Gather the kv cache transfer time from all workers and update to leader rank
if (!common::getEnvKVCacheTransferOutputPath().empty())
{
auto syncComm
= mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupDataComm.get() : mMpiGroupComm;
updateKVCacheTransferBW(*syncComm, it->first);
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
updateKVCacheTransferBW(syncComm, it->first);
}
}
catch (std::exception const& e)
@ -550,9 +615,18 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
"Error occurred during generation transfer for request %ld: %s", it->first->mRequestId, e.what());
it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
}
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
it = mRequesterFutures.erase(it);
}
else

View File

@ -3,6 +3,7 @@
if(ENABLE_UCX)
find_package(ucx REQUIRED)
find_package(ucxx REQUIRED)
find_package(Torch REQUIRED)
include_directories(${3RDPARTY_DIR}/cppzmq)
@ -23,12 +24,16 @@ if(ENABLE_UCX)
set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")
target_compile_definitions(${UCX_WRAPPER_TARGET}
PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")
target_include_directories(${UCX_WRAPPER_TARGET}
PRIVATE ${PROJECT_SOURCE_DIR}/include)
target_include_directories(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_INCLUDE_DIRS})
target_link_libraries(${UCX_WRAPPER_TARGET}
PRIVATE $<LINK_LIBRARY:WHOLE_ARCHIVE,ucxx::ucxx>)
target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ucxx::ucxx ucx::ucs)
target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ${CUDA_RT_LIB})
# Add include directories
target_include_directories(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_INCLUDE_DIRS})
target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ${TORCH_LIBRARIES})
target_link_libraries(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_LIBRARIES})
target_link_libraries(${UCX_WRAPPER_TARGET} PRIVATE pg_utils)
endif()

View File

@ -81,15 +81,15 @@ UcxConnection::UcxConnection(ConnectionIdType connectionId, std::shared_ptr<ucxx
}
catch (std::exception const& e)
{
std::string error = "Error in UcxConnection constructor for rank "
+ std::to_string(mpi::MpiComm::world().getRank()) + ": " + e.what();
std::string error = std::string("Error in UcxConnection constructor for rank ")
+ std::to_string(mManager->getRank()) + ": " + e.what();
TLLM_THROW(error);
}
mSendTagPrefix = mConnectionIdInPeer;
mRecvTagPrefix = mConnectionId;
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
TLLM_LOG_DEBUG(mManager->getRank(),
"UcxConnection::UcxConnection, mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
}
@ -97,7 +97,7 @@ UcxConnection::UcxConnection(ConnectionIdType connectionId, std::shared_ptr<ucxx
UcxConnection::~UcxConnection()
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
TLLM_LOG_DEBUG(mManager->getRank(),
"UcxConnection::~UcxConnection, mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
// TODO: how to close the endpoint safely?
@ -105,7 +105,7 @@ UcxConnection::~UcxConnection()
void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, size_t size) const
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
TLLM_LOG_DEBUG(mManager->getRank(),
"start UcxConnection::sendConnectionId , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d",
mConnectionId, mConnectionIdInPeer, mFromRequester);
@ -126,7 +126,7 @@ void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, s
}
TLLM_CHECK_WITH_INFO(req->isCompleted(), "sendConnectionId should be completed");
req->checkError();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
TLLM_LOG_DEBUG(mManager->getRank(),
"end UcxConnection::sendConnectionId , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d",
mConnectionId, mConnectionIdInPeer, mFromRequester);
}
@ -138,7 +138,7 @@ void UcxConnection::send(DataContext const& ctx, void const* data, size_t size)
sendConnectionId(ctx, data, size);
return;
}
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
TLLM_LOG_DEBUG(mManager->getRank(),
"start UcxConnection::send , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
@ -156,7 +156,8 @@ void UcxConnection::send(DataContext const& ctx, void const* data, size_t size)
TLLM_CHECK_WITH_INFO(req->isCompleted(), "send should be completed");
// throw if there is error
req->checkError();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
TLLM_LOG_DEBUG(mManager->getRank(),
"end UcxConnection::send , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
}
@ -164,7 +165,7 @@ void UcxConnection::send(DataContext const& ctx, void const* data, size_t size)
void UcxConnection::recv(DataContext const& ctx, void* data, size_t size) const
{
// Guard to ensure CUDA context is initialized for UCX ops
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
TLLM_LOG_DEBUG(mManager->getRank(),
"start UcxConnection::recv , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
TLLM_CHECK_WITH_INFO((mEndpoint), "recvBuffer called without established communicator channel.");
@ -180,7 +181,8 @@ void UcxConnection::recv(DataContext const& ctx, void* data, size_t size) const
TLLM_CHECK_WITH_INFO(req->isCompleted(), "recv should be completed");
// throw if there is error
req->checkError();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
TLLM_LOG_DEBUG(mManager->getRank(),
"end UcxConnection::recv , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
}

View File

@ -19,14 +19,31 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <chrono>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <functional>
#include <iostream>
#include <mutex>
#include <numeric>
#include <regex>
#include <string>
#include <sys/socket.h>
#include <thread>
#include <ucxx/address.h>
#include <ucxx/typedefs.h>
#include <unistd.h>
#include <vector>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
using tensorrt_llm::pg_utils::get_world_pg;
using tensorrt_llm::pg_utils::PgHelper;
namespace tensorrt_llm::executor::kv_cache
{
@ -73,12 +90,12 @@ public:
}
};
std::string getLocalIpByNic(std::string const& interface)
std::string getLocalIpByNic(std::string const& interface, int rank)
{
struct ifaddrs* ifaddr = nullptr;
if (getifaddrs(&ifaddr) == -1)
{
TLLM_LOG_ERROR(mpi::MpiComm::world().getRank(),
TLLM_LOG_ERROR(rank,
"getLocalIpByNic: Can't get local ip from NIC Interface. Please check whether TRTLLM_UCX_INTERFACE is set "
"correctly.");
return std::string{};
@ -118,17 +135,17 @@ std::string getLocalIpByNic(std::string const& interface)
}
freeifaddrs(ifaddr);
TLLM_LOG_ERROR(mpi::MpiComm::world().getRank(),
"Can't get local ip from NIC Interface. Please check whether TRTLLM_UCX_INTERFACE is set correctly.");
TLLM_LOG_ERROR(
rank, "Can't get local ip from NIC Interface. Please check whether TRTLLM_UCX_INTERFACE is set correctly.");
return std::string{};
}
std::string getLocalIpByHostname()
std::string getLocalIpByHostname(int rank)
{
char hostname[256]{};
if (gethostname(hostname, sizeof(hostname)) == -1)
{
TLLM_LOG_ERROR(mpi::MpiComm::world().getRank(), "getLocalIpByHostname: Can't get hostname");
TLLM_LOG_ERROR(rank, "getLocalIpByHostname: Can't get hostname");
return std::string{};
}
@ -140,7 +157,7 @@ std::string getLocalIpByHostname()
struct addrinfo* res = nullptr;
if (getaddrinfo(hostname, nullptr, &hints, &res) != 0)
{
TLLM_LOG_WARNING(mpi::MpiComm::world().getRank(), "getLocalIpByHostname: Can't get address info for hostname");
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get address info for hostname");
return std::string{};
}
@ -174,11 +191,11 @@ std::string getLocalIpByHostname()
}
freeaddrinfo(res);
TLLM_LOG_WARNING(mpi::MpiComm::world().getRank(), "getLocalIpByHostname: Can't get local ip from hostname");
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get local ip from hostname");
return std::string{};
}
std::string getLocalIpByRemoteOrHostName()
std::string getLocalIpByRemoteOrHostName(int rank)
{
// Try IPv4
@ -238,20 +255,20 @@ std::string getLocalIpByRemoteOrHostName()
}
// Try hostname
return getLocalIpByHostname();
return getLocalIpByHostname(rank);
}
static std::string getLocalIp()
static std::string getLocalIp(int rank)
{
std::string ucxInterface = common::getEnvUCXInterface();
std::string localIP = {};
if (!ucxInterface.empty())
{
localIP = getLocalIpByNic(ucxInterface);
localIP = getLocalIpByNic(ucxInterface, rank);
}
if (localIP.empty())
{
localIP = getLocalIpByRemoteOrHostName();
localIP = getLocalIpByRemoteOrHostName(rank);
}
// check whether the localIP is valid
if (localIP.empty())
@ -278,10 +295,31 @@ std::optional<std::pair<std::string, int>> parse_zmq_endpoint(std::string const&
}
UcxConnectionManager::UcxConnectionManager()
{
try
{
if (useMPI())
{
mRank = mpi::MpiComm::session().getRank();
mWorldSize = mpi::MpiComm::session().getSize();
}
else
{
auto const worldPg = get_world_pg();
if (worldPg)
{
mRank = worldPg->getRank();
mWorldSize = worldPg->getSize();
TLLM_LOG_DEBUG(mRank, "UCX using Torch process group - rank: %d, world size: %d", mRank, mWorldSize);
}
else
{
mRank = 0;
mWorldSize = 1;
TLLM_LOG_DEBUG(mRank, "WARNING: Process group is null, defaulting to single process");
}
}
TLLM_CUDA_CHECK(cudaGetDevice(&mDevice));
mUcxCtx = ucxx::createContext({{"RNDV_PIPELINE_ERROR_HANDLING", "y"}}, UCP_FEATURE_TAG);
int device = mDevice;
@ -302,7 +340,7 @@ UcxConnectionManager::UcxConnectionManager()
mZmqRepSocket = zmq::socket_t(mZmqContext, zmq::socket_type::rep);
mZmqRepSocket.set(zmq::sockopt::sndhwm, 1000);
std::string localIp = getLocalIp();
std::string localIp = getLocalIp(mRank);
if (localIp.find(':') != std::string::npos)
{
// ipv6
@ -310,62 +348,85 @@ UcxConnectionManager::UcxConnectionManager()
localIp = "[" + localIp + "]";
}
TLLM_LOG_INFO(
mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager localIp: %s", localIp.c_str());
TLLM_LOG_INFO(mRank, "UcxConnectionManager::UcxConnectionManager localIp: %s", localIp.c_str());
mZmqRepSocket.bind("tcp://" + localIp + ":*");
mZmqRepEndpoint = mZmqRepSocket.get(zmq::sockopt::last_endpoint);
TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s",
mZmqRepEndpoint.c_str());
TLLM_LOG_INFO(mRank, "UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s", mZmqRepEndpoint.c_str());
auto parse_result = parse_zmq_endpoint(mZmqRepEndpoint);
TLLM_CHECK_WITH_INFO(parse_result.has_value(), "Failed to parse ZMQ endpoint");
auto [ip, port] = parse_result.value();
TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager ip: %s, port: %d",
ip.c_str(), port);
TLLM_LOG_INFO(mRank, "UcxConnectionManager::UcxConnectionManager ip: %s, port: %d", ip.c_str(), port);
SocketState socketState{static_cast<uint16_t>(port), ip};
std::vector<executor::kv_cache::SocketState> socketStates(mpi::MpiComm::session().getSize());
std::vector<executor::kv_cache::SocketState> socketStates(mWorldSize);
if (mpi::MpiComm::session().getSize() > 1)
{
mpi::MpiComm::session().barrier();
namespace su = executor::serialize_utils;
std::ostringstream oStream;
su::serialize(socketState, oStream);
auto str = oStream.str();
std::vector<char> buffer(str.begin(), str.end());
std::vector<SizeType32> sizeofBuffer(mpi::MpiComm::session().getSize());
SizeType32 bufferSize = buffer.size();
mpi::MpiComm::session().allgather(&bufferSize, sizeofBuffer.data(), 1, mpi::MpiType::kINT32);
SizeType32 recvBufferSize = std::accumulate(sizeofBuffer.begin(), sizeofBuffer.end(), 0);
std::vector<char> recvBuffer(recvBufferSize);
std::vector<int> displs(mpi::MpiComm::session().getSize());
for (int r = 0; r < mpi::MpiComm::session().getSize(); r++)
{
displs[r] = (r == 0) ? 0 : (displs[r - 1] + sizeofBuffer[r - 1]);
}
mpi::MpiComm::session().allgatherv(buffer.data(), bufferSize, mpi::MpiType::kCHAR, recvBuffer.data(),
sizeofBuffer, displs, mpi::MpiType::kCHAR);
// deserialize
for (int i = 0; i < mpi::MpiComm::session().getSize(); i++)
{
std::vector<char> serBuffer(
recvBuffer.begin() + displs[i], recvBuffer.begin() + (displs[i] + sizeofBuffer[i]));
su::VectorWrapBuf<char> strbuf(serBuffer);
std::istream is(&strbuf);
socketStates[i] = su::deserialize<executor::kv_cache::SocketState>(is);
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " recv socketStates[%d]: %s", i,
socketStates[i].toString().c_str());
}
}
else
if (mWorldSize == 1)
{
socketStates[0] = socketState;
}
mCommState = CommState(socketStates, mpi::MpiComm::session().getRank());
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " ***** UCX mCommState: %s", mCommState.toString().c_str());
else
{
namespace su = executor::serialize_utils;
std::ostringstream oStream;
su::serialize(socketState, oStream);
auto serializedData = oStream.str();
std::vector<char> buffer(serializedData.begin(), serializedData.end());
std::vector<SizeType32> sizeofBuffer(mWorldSize);
SizeType32 bufferSize = buffer.size();
if (useMPI())
{
mpi::MpiComm::session().barrier();
mpi::MpiComm::session().allgather(&bufferSize, sizeofBuffer.data(), 1, mpi::MpiType::kINT32);
SizeType32 recvBufferSize = std::accumulate(sizeofBuffer.begin(), sizeofBuffer.end(), 0);
std::vector<char> recvBuffer(recvBufferSize);
std::vector<int> displs(mpi::MpiComm::session().getSize());
for (int r = 0; r < mpi::MpiComm::session().getSize(); r++)
{
displs[r] = (r == 0) ? 0 : (displs[r - 1] + sizeofBuffer[r - 1]);
}
mpi::MpiComm::session().allgatherv(buffer.data(), bufferSize, mpi::MpiType::kCHAR, recvBuffer.data(),
sizeofBuffer, displs, mpi::MpiType::kCHAR);
// deserialize
for (int i = 0; i < mpi::MpiComm::session().getSize(); i++)
{
std::vector<char> serBuffer(
recvBuffer.begin() + displs[i], recvBuffer.begin() + (displs[i] + sizeofBuffer[i]));
su::VectorWrapBuf<char> strbuf(serBuffer);
std::istream is(&strbuf);
socketStates[i] = su::deserialize<executor::kv_cache::SocketState>(is);
TLLM_LOG_DEBUG(mRank, " recv socketStates[%d]: %s", i, socketStates[i].toString().c_str());
}
}
else
{
auto const worldPg = get_world_pg();
PgHelper pgh{worldPg};
PGCHECK_THROW(worldPg->barrier());
PGCHECK_THROW(pgh.allgather(&bufferSize, std::ref(sizeofBuffer), {}));
SizeType32 recvBufferSize = std::accumulate(sizeofBuffer.begin(), sizeofBuffer.end(), 0);
std::vector<char> recvBuffer(recvBufferSize);
PGCHECK_THROW(pgh.allgatherv(std::ref(buffer), std::ref(recvBuffer), std::cref(sizeofBuffer), {}));
// deserialize
char* begin = reinterpret_cast<char*>(recvBuffer.data());
for (int r = 0; r < mWorldSize; ++r)
{
std::vector<char> serBuffer(begin, begin + sizeofBuffer[r]);
begin += sizeofBuffer[r];
su::VectorWrapBuf<char> strbuf(serBuffer);
std::istream is(&strbuf);
socketStates[r] = su::deserialize<executor::kv_cache::SocketState>(is);
TLLM_LOG_DEBUG(mRank, " recv socketStates[%d]: %s", r, socketStates[r].toString().c_str());
}
}
}
mCommState = CommState(socketStates, mRank);
TLLM_LOG_DEBUG(mRank, " ***** UCX mCommState: %s", mCommState.toString().c_str());
mZmqRepThread = std::thread(
[this]()
@ -417,7 +478,7 @@ UcxConnectionManager::UcxConnectionManager()
UcxConnectionManager::~UcxConnectionManager()
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "UcxConnectionManager::~UcxConnectionManager");
TLLM_LOG_DEBUG(mRank, "UcxConnectionManager::~UcxConnectionManager");
for (auto& worker : mWorkersPool)
{
@ -447,7 +508,7 @@ UcxConnectionManager::~UcxConnectionManager()
mZmqRepSocket.close();
mZmqContext.close();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "END UcxConnectionManager::~UcxConnectionManager");
TLLM_LOG_DEBUG(mRank, "END UcxConnectionManager::~UcxConnectionManager");
}
void UcxConnectionManager::addConnection(std::string const& workerAddress)
@ -472,8 +533,7 @@ void UcxConnectionManager::addConnection(std::string const& workerAddress)
}
catch (std::exception const& e)
{
std::string error = "Error in addConnection(connRequest) for rank "
+ std::to_string(mpi::MpiComm::world().getRank()) + ": " + e.what();
std::string error = "Error in addConnection(connRequest) for rank " + std::to_string(mRank) + ": " + e.what();
TLLM_THROW(error);
}
}
@ -538,8 +598,8 @@ UcxConnection::ConnectionIdType UcxConnectionManager::addConnection(std::string
}
catch (std::exception const& e)
{
std::string error = "Error in addConnection(ip) for rank " + std::to_string(mpi::MpiComm::world().getRank())
+ " ip: " + ip + " port: " + std::to_string(port) + ": " + e.what();
std::string error = "Error in addConnection(ip) for rank " + std::to_string(mRank) + " ip: " + ip
+ " port: " + std::to_string(port) + ": " + e.what();
TLLM_THROW(error);
}
}
@ -570,20 +630,18 @@ Connection const* UcxConnectionManager::recvConnect(DataContext const& ctx, void
= *reinterpret_cast<UcxConnection::ConnectionIdType*>(buffer.data() + size);
std::scoped_lock lock(mConnectionsMutex, mConnectionFuturesMutex);
TLLM_CHECK_WITH_INFO(mConnectionFutures.find(connectionId) != mConnectionFutures.end(),
"connectionFuture not found In recvConnect connectionId : %lu , worldRank: %d", connectionId,
mpi::MpiComm::world().getRank());
"connectionFuture not found In recvConnect connectionId : %lu , worldRank: %d", connectionId, mRank);
if (mConnectionFutures.at(connectionId).valid())
{
// wait for the connection to be created
mConnectionFutures.at(connectionId).get();
}
TLLM_CHECK_WITH_INFO(mConnections.find(connectionId) != mConnections.end(),
"Connection not found In recvConnect connectionId: %lu , worldRank: %d", connectionId,
mpi::MpiComm::world().getRank());
"Connection not found In recvConnect connectionId: %lu , worldRank: %d", connectionId, mRank);
TLLM_CHECK(!mConnections[connectionId]->isFromRequester());
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "recvConnect connectionId: %lu , sendIDData:%lu", connectionId,
TLLM_LOG_DEBUG(mRank, "recvConnect connectionId: %lu , sendIDData:%lu", connectionId,
*reinterpret_cast<uint64_t*>(buffer.data()));
return mConnections[connectionId].get();

View File

@ -55,6 +55,8 @@ private:
std::mutex mAddressToConnectionIdMutex;
CommState mCommState;
int mDevice;
int mRank;
int mWorldSize;
std::atomic<UcxConnection::ConnectionIdType> mConnectionIdCounter{1};
zmq::context_t mZmqContext;
zmq::socket_t mZmqRepSocket;
@ -78,6 +80,11 @@ public:
Connection const* recvConnect(DataContext const& ctx, void* data, size_t size) override;
std::vector<Connection const*> getConnections(CommState const& state) override;
[[nodiscard]] CommState const& getCommState() const override;
[[nodiscard]] int getRank() const
{
return mRank;
}
};
#if defined(__clang__)

View File

@ -15,6 +15,7 @@ set(SRCS
executor/executor.cpp
executor/executorConfig.cpp
executor/request.cpp
process_group/bindings.cpp
runtime/bindings.cpp
runtime/hostfunc.cpp
runtime/moeBindings.cpp
@ -46,7 +47,8 @@ target_link_libraries(
torch_python
${CUDA_DRV_LIB}
${CUDA_NVML_LIB}
th_common)
th_common
pg_utils)
target_compile_definitions(
${TRTLLM_NB_MODULE} PUBLIC TRTLLM_NB_MODULE=${TRTLLM_NB_MODULE}
PYBIND11_DETAILED_ERROR_MESSAGES=1)

View File

@ -18,6 +18,7 @@
#include "cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/bindingUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/nanobind/common/customCasters.h"
#include <ATen/ATen.h>
@ -28,6 +29,7 @@
#include <nanobind/stl/vector.h>
#include <nanobind/trampoline.h>
#include <torch/extension.h>
#include <typeinfo>
using SizeType32 = tensorrt_llm::runtime::SizeType32;
@ -103,6 +105,79 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
nb::arg("tokens_per_block"), nb::arg("world_config"), nb::arg("attention_layer_num_per_pp"),
nb::arg("dtype"), nb::arg("attention_type"), nb::arg("cache_transceiver_config") = std::nullopt);
nb::class_<tb::CacheTransceiverComm>(m, "CacheTransceiverComm")
.def(
"__init__",
[](tb::CacheTransceiverComm* self, nb::object pg_obj, std::string pybind11_abi)
{
new (self) tb::CacheTransceiverComm(
common::get_intrusive_ptr<c10d::ProcessGroup, nb::python_error>(pg_obj.ptr(), pybind11_abi));
},
nb::arg("process_group"), nb::arg("pybind11_abi"))
.def("get_rank", &tb::CacheTransceiverComm::getRank)
.def("get_size", &tb::CacheTransceiverComm::getSize)
.def("split", &tb::CacheTransceiverComm::split, nb::arg("color"), nb::arg("key"))
.def(
"allgather",
[](tb::CacheTransceiverComm const& self, int64_t input)
{
std::vector<int64_t> out(static_cast<size_t>(self.getSize()));
c10d::AllgatherOptions options;
bool ok = self.allgather(input, std::ref(out), options);
return nb::make_tuple(ok, out);
},
nb::arg("input"))
.def(
"allgather",
[](tb::CacheTransceiverComm const& self, double input)
{
std::vector<double> out(static_cast<size_t>(self.getSize()));
c10d::AllgatherOptions options;
bool ok = self.allgather(input, std::ref(out), options);
return nb::make_tuple(ok, out);
},
nb::arg("input"))
.def(
"allgather",
[](tb::CacheTransceiverComm const& self, char input)
{
std::vector<char> out(static_cast<size_t>(self.getSize()));
c10d::AllgatherOptions options;
bool ok = self.allgather(input, std::ref(out), options);
return nb::make_tuple(ok, out);
},
nb::arg("input"))
.def(
"allgatherv",
[](tb::CacheTransceiverComm const& self, std::vector<int64_t> input, std::vector<int> const& sizes)
{
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
std::vector<int64_t> output(total_size);
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
return nb::make_tuple(ok, output);
},
nb::arg("input"), nb::arg("sizes"))
.def(
"allgatherv",
[](tb::CacheTransceiverComm const& self, std::vector<double> input, std::vector<int> const& sizes)
{
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
std::vector<double> output(total_size);
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
return nb::make_tuple(ok, output);
},
nb::arg("input"), nb::arg("sizes"))
.def(
"allgatherv",
[](tb::CacheTransceiverComm const& self, std::vector<char> input, std::vector<int> const& sizes)
{
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
std::vector<char> output(total_size);
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
return nb::make_tuple(ok, output);
},
nb::arg("input"), nb::arg("sizes"));
nb::class_<tb::kv_cache_manager::CacheTransBufferManager>(m, "CacheTransBufferManager")
.def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), nb::arg("cache_manager"),
nb::arg("max_num_tokens") = std::nullopt)

View File

@ -40,6 +40,7 @@
#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h"
#include "tensorrt_llm/nanobind/common/tllmExceptions.h"
#include "tensorrt_llm/nanobind/executor/bindings.h"
#include "tensorrt_llm/nanobind/process_group/bindings.h"
#include "tensorrt_llm/nanobind/runtime/bindings.h"
#include "tensorrt_llm/nanobind/testing/modelSpecBinding.h"
#include "tensorrt_llm/nanobind/thop/bindings.h"
@ -125,6 +126,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
// Create submodule for executor bindings.
auto mExecutor = m.def_submodule("executor", "Executor bindings");
auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime");
auto mInternalProcessGroup = mInternal.def_submodule("process_group", "PyTorch ProcessGroup internal bindings");
auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
@ -485,6 +487,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.def_prop_ro("pinned", &tr::MemoryCounters::getPinned)
.def_prop_ro("uvm", &tr::MemoryCounters::getUVM);
tensorrt_llm::nanobind::process_group::initBindings(mInternalProcessGroup);
tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime);
tensorrt_llm::nanobind::testing::initBindings(mInternalTesting);
tpb::initBindings(mInternalBatchManager);

View File

@ -0,0 +1,43 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include <nanobind/stl/string.h>
#include "tensorrt_llm/common/bindingUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
namespace nb = nanobind;
namespace tensorrt_llm::nanobind::process_group
{
void initBindings(nb::module_& m)
{
m.def("init_pg",
[](nb::object world_pg_obj, nb::object local_pg_obj, std::string const& pybind11_abi)
{
using Pg = c10d::ProcessGroup;
using E = nb::python_error;
pg_utils::init_pg(common::get_intrusive_ptr<Pg, E>(world_pg_obj.ptr(), pybind11_abi),
common::get_intrusive_ptr<Pg, E>(local_pg_obj.ptr(), pybind11_abi));
});
}
} // namespace tensorrt_llm::nanobind::process_group

View File

@ -0,0 +1,26 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nanobind/nanobind.h>
namespace nb = nanobind;
namespace tensorrt_llm::nanobind::process_group
{
void initBindings(nb::module_& m);
} // namespace tensorrt_llm::nanobind::process_group

View File

@ -14,6 +14,7 @@ set(SRCS
executor/executor.cpp
executor/executorConfig.cpp
executor/request.cpp
process_group/bindings.cpp
runtime/bindings.cpp
runtime/hostfunc.cpp
common/tllmExceptions.cpp
@ -48,7 +49,8 @@ target_link_libraries(
torch_python
${CUDA_DRV_LIB}
${CUDA_NVML_LIB}
th_common)
th_common
pg_utils)
target_compile_definitions(
${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
PYBIND11_DETAILED_ERROR_MESSAGES=1)

View File

@ -18,6 +18,7 @@
#include "cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/bindingUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include <ATen/ATen.h>
#include <pybind11/functional.h>
@ -26,6 +27,7 @@
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
#include <typeinfo>
using SizeType32 = tensorrt_llm::runtime::SizeType32;
@ -99,6 +101,79 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
py::arg("tokens_per_block"), py::arg("world_config"), py::arg("attention_layer_num_per_pp"),
py::arg("dtype"), py::arg("attention_type"), py::arg("cache_transceiver_config") = std::nullopt);
py::classh<tb::CacheTransceiverComm>(m, "CacheTransceiverComm")
.def(py::init(
[](py::object pg_obj, std::string pybind11_abi)
{
return new CacheTransceiverComm(
common::get_intrusive_ptr<c10d::ProcessGroup, py::error_already_set>(
pg_obj.ptr(), pybind11_abi));
}),
py::arg("process_group"), py::arg("pybind11_abi"))
.def("get_rank", &tb::CacheTransceiverComm::getRank)
.def("get_size", &tb::CacheTransceiverComm::getSize)
.def("split", &tb::CacheTransceiverComm::split, py::arg("color"), py::arg("key"))
.def(
"allgather",
[](tb::CacheTransceiverComm const& self, int64_t input)
{
std::vector<int64_t> out(static_cast<size_t>(self.getSize()));
c10d::AllgatherOptions options;
bool ok = self.allgather(input, std::ref(out), options);
return py::make_tuple(ok, out);
},
py::arg("input"))
.def(
"allgather",
[](tb::CacheTransceiverComm const& self, double input)
{
std::vector<double> out(static_cast<size_t>(self.getSize()));
c10d::AllgatherOptions options;
bool ok = self.allgather(input, std::ref(out), options);
return py::make_tuple(ok, out);
},
py::arg("input"))
.def(
"allgather",
[](tb::CacheTransceiverComm const& self, char input)
{
std::vector<char> out(static_cast<size_t>(self.getSize()));
c10d::AllgatherOptions options;
bool ok = self.allgather(input, std::ref(out), options);
return py::make_tuple(ok, out);
},
py::arg("input"))
.def(
"allgatherv",
[](tb::CacheTransceiverComm const& self, std::vector<int64_t> input, std::vector<int> const& sizes)
{
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
std::vector<int64_t> output(total_size);
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
return py::make_tuple(ok, output);
},
py::arg("input"), py::arg("sizes"))
.def(
"allgatherv",
[](tb::CacheTransceiverComm const& self, std::vector<double> input, std::vector<int> const& sizes)
{
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
std::vector<double> output(total_size);
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
return py::make_tuple(ok, output);
},
py::arg("input"), py::arg("sizes"))
.def(
"allgatherv",
[](tb::CacheTransceiverComm const& self, std::vector<char> input, std::vector<int> const& sizes)
{
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
std::vector<char> output(total_size);
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
return py::make_tuple(ok, output);
},
py::arg("input"), py::arg("sizes"));
py::class_<tb::kv_cache_manager::CacheTransBufferManager>(m, "CacheTransBufferManager")
.def(py::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), py::arg("cache_manager"),
py::arg("max_num_tokens") = std::nullopt)

View File

@ -34,6 +34,7 @@
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/common/tllmExceptions.h"
#include "tensorrt_llm/pybind/executor/bindings.h"
#include "tensorrt_llm/pybind/process_group/bindings.h"
#include "tensorrt_llm/pybind/runtime/bindings.h"
#include "tensorrt_llm/pybind/testing/modelSpecBinding.h"
#include "tensorrt_llm/pybind/thop/bindings.h"
@ -117,6 +118,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
// Create submodule for executor bindings.
auto mExecutor = m.def_submodule("executor", "Executor bindings");
auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime");
auto mInternalProcessGroup = mInternal.def_submodule("process_group", "PyTorch ProcessGroup internal bindings");
auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
@ -473,6 +475,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_property_readonly("pinned", &tr::MemoryCounters::getPinned)
.def_property_readonly("uvm", &tr::MemoryCounters::getUVM);
tensorrt_llm::pybind::process_group::initBindings(mInternalProcessGroup);
tensorrt_llm::pybind::runtime::initBindings(mInternalRuntime);
tensorrt_llm::pybind::testing::initBindings(mInternalTesting);
tpb::initBindings(mInternalBatchManager);

View File

@ -0,0 +1,40 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include "tensorrt_llm/common/bindingUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
namespace tensorrt_llm::pybind::process_group
{
void initBindings(py::module_& m)
{
m.def("init_pg",
[](py::object world_pg_obj, py::object local_pg_obj, std::string const& pybind11_abi)
{
using Pg = c10d::ProcessGroup;
using E = py::error_already_set;
pg_utils::init_pg(common::get_intrusive_ptr<Pg, E>(world_pg_obj.ptr(), pybind11_abi),
common::get_intrusive_ptr<Pg, E>(local_pg_obj.ptr(), pybind11_abi));
});
}
} // namespace tensorrt_llm::pybind::process_group

View File

@ -0,0 +1,27 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace tensorrt_llm::pybind::process_group
{
void initBindings(py::module_& m);
} // namespace tensorrt_llm::pybind::process_group

View File

@ -111,3 +111,6 @@ if(NOT WIN32)
target_link_libraries(runtime_src PUBLIC libnuma::libnuma)
target_link_options(runtime_src PUBLIC ${CONAN_LIBNUMA_LINK_OPTIONS})
endif()
# Add utils subdirectory for pg_utils module
add_subdirectory(utils)

View File

@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# Create the pg_utils shared library
add_library(pg_utils SHARED pgUtils.cpp)
set_property(TARGET pg_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
# Include directories
target_include_directories(pg_utils PUBLIC ${PROJECT_SOURCE_DIR}/include
${TORCH_INCLUDE_DIRS})
target_link_libraries(pg_utils PUBLIC ${TORCH_LIBRARIES})
# Find torch_python
find_library(TORCH_PYTHON_LIB torch_python REQUIRED
HINTS ${TORCH_INSTALL_PREFIX}/lib)

View File

@ -222,6 +222,7 @@ void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
void MpiComm::barrier() const
{
couldUseMPI();
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Barrier(mComm));
#else
@ -267,6 +268,7 @@ size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dty
std::unique_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const
{
couldUseMPI();
std::unique_ptr<MpiRequest> r = std::make_unique<MpiRequest>();
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest);
@ -278,11 +280,13 @@ std::unique_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiTy
std::unique_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
{
couldUseMPI();
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
{
couldUseMPI();
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm);
#else
@ -292,12 +296,14 @@ void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
void MpiComm::bcast(runtime::IBuffer& buf, int root) const
{
couldUseMPI();
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
std::unique_ptr<MpiRequest> MpiComm::sendAsync(
void const* buffer, size_t size, MpiType dtype, int dest, MpiTag tag) const
{
couldUseMPI();
TLLM_LOG_DEBUG("start MPI_Isend with dest %d, tag %d, size %d", dest, static_cast<int>(tag), size);
std::unique_ptr<MpiRequest> r = std::make_unique<MpiRequest>();
#if ENABLE_MULTI_DEVICE
@ -311,11 +317,13 @@ std::unique_ptr<MpiRequest> MpiComm::sendAsync(
std::unique_ptr<MpiRequest> MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, MpiTag tag) const
{
couldUseMPI();
return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
void MpiComm::sendRawTag(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
couldUseMPI();
TLLM_LOG_DEBUG("start MPI_Send with dest %d, tag %d, size %d", dest, tag, size);
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm);
@ -327,16 +335,19 @@ void MpiComm::sendRawTag(void const* buffer, size_t size, MpiType dtype, int des
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, MpiTag tag) const
{
couldUseMPI();
sendRawTag(buffer, size, dtype, dest, static_cast<int>(tag));
}
void MpiComm::send(runtime::IBuffer const& buf, int dest, MpiTag tag) const
{
couldUseMPI();
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
MPI_Status MpiComm::recvRawTag(void* buffer, size_t size, MpiType dtype, int source, int tag) const
{
couldUseMPI();
TLLM_LOG_DEBUG("start MPI_Recv with source %d, tag %d, size %d", source, tag, size);
MPI_Status status{};
#if ENABLE_MULTI_DEVICE
@ -350,11 +361,13 @@ MPI_Status MpiComm::recvRawTag(void* buffer, size_t size, MpiType dtype, int sou
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, MpiTag tag) const
{
couldUseMPI();
return recvRawTag(buffer, size, dtype, source, static_cast<int>(tag));
}
MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, MpiTag tag) const
{
couldUseMPI();
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
}
@ -382,6 +395,7 @@ MpiComm const& MpiComm::setRawSessionByFortran(int64_t fortranHandle)
void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
{
couldUseMPI();
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
#else
@ -391,6 +405,7 @@ void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType d
void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
{
couldUseMPI();
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
#else
@ -401,6 +416,7 @@ void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType d
void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
{
couldUseMPI();
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
getMpiDtype(recvtype), mComm));
@ -450,6 +466,7 @@ bool MpiComm::iprobe(int source, MpiTag tag, MPI_Status* status) const
void MpiComm::recvPoll(int source, MpiTag tag, int periodMs) const
{
couldUseMPI();
MPI_Status status;
while (!iprobe(source, tag, &status))
{

View File

@ -0,0 +1,44 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include "tensorrt_llm/common/logger.h"
namespace tensorrt_llm::pg_utils
{
c10::intrusive_ptr<c10d::ProcessGroup> pg_world;
c10::intrusive_ptr<c10d::ProcessGroup> pg_local;
c10::intrusive_ptr<c10d::ProcessGroup> get_world_pg()
{
return pg_world;
}
c10::intrusive_ptr<c10d::ProcessGroup> get_local_pg()
{
return pg_local;
}
void init_pg(c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_world,
c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_local)
{
TLLM_LOG_DEBUG(process_group_world->getRank(), "Init process group on rank %d", process_group_world->getRank());
pg_world = process_group_world;
pg_local = process_group_local;
}
} // namespace tensorrt_llm::pg_utils
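A minimal sketch of how these globals are expected to be filled in from Python before any ProcessGroup-based op runs: build a world group and a per-node local group with `torch.distributed`, then hand both to `init_pg` through the internal binding. The module path `tensorrt_llm.bindings.internal.process_group`, the `LOCAL_WORLD_SIZE`-based node grouping, and how the pybind11 ABI string is obtained are assumptions for illustration only.
```python
# Illustrative sketch; module path and ABI handling are assumptions.
import os

import torch.distributed as dist
from tensorrt_llm.bindings.internal import process_group as pg  # assumed module path


def setup_process_groups(pybind11_abi: str) -> None:
    # World group: every rank participating in inference.
    dist.init_process_group(backend="nccl")
    world_pg = dist.group.WORLD

    # Local group: ranks sharing this node (assumes equal ranks per node).
    local_world = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
    node_id = dist.get_rank() // local_world
    local_pg = None
    # new_group() must be entered by all ranks for every group being created.
    for n in range(dist.get_world_size() // local_world):
        ranks = list(range(n * local_world, (n + 1) * local_world))
        g = dist.new_group(ranks=ranks, backend="nccl")
        if n == node_id:
            local_pg = g

    # Registers pg_world / pg_local for the C++ side (init_pg above).
    pg.init_pg(world_pg, local_pg, pybind11_abi)
```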

View File

@ -97,8 +97,9 @@ add_library(
finegrained_mixed_dtype_gemm_thop.cpp
tinygemm2.cpp)
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(th_common PRIVATE ${TORCH_LIBRARIES} th_utils
${Python3_LIBRARIES} ${SHARED_TARGET})
target_link_libraries(
th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}
${SHARED_TARGET} pg_utils)
if(USING_OSS_CUTLASS_LOW_LATENCY_GEMM)
target_compile_definitions(th_common

View File

@ -18,6 +18,7 @@
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <NvInferRuntime.h>
#include <c10/cuda/CUDAStream.h>
@ -28,8 +29,12 @@
#include <vector>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#endif // ENABLE_MULTI_DEVICE
using tensorrt_llm::pg_utils::PgHelper;
namespace torch_ext
{
#if ENABLE_MULTI_DEVICE
@ -115,6 +120,95 @@ private:
std::shared_ptr<ncclComm_t> mNcclComm;
};
class AllgatherPgOp
{
public:
AllgatherPgOp(std::set<int> group, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
: mGroup(std::move(group))
, mProcessGroup(process_group_)
{
}
~AllgatherPgOp() = default;
int initialize() noexcept
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, mProcessGroup->getRank());
return 0;
}
std::pair<torch::Tensor, c10::intrusive_ptr<c10d::Work>> run(
torch::Tensor input, torch::optional<torch::List<int64_t>> sizes, bool coalescing = false)
{
TLLM_CHECK_WITH_INFO(mProcessGroup.get() != nullptr, "mProcessGroup should be initialized before used");
std::vector<int64_t> outputShape = input.sizes().vec();
if (sizes.has_value())
{
outputShape[0] = std::accumulate(sizes.value().begin(), sizes.value().end(), 0, std::plus<>{});
}
else
{
outputShape[0] *= mGroup.size();
}
auto output = torch::empty(outputShape, input.options());
PgHelper pgh{mProcessGroup};
c10::intrusive_ptr<c10d::Work> work;
if (sizes.has_value())
{
std::vector inputs{input};
int64_t split_offset = 0;
std::vector<torch::Tensor> outputTensors{};
for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
{
auto split_size = sizes.value()[root];
outputTensors.push_back(
output.index({torch::indexing::Slice(split_offset, split_offset + split_size)}));
split_offset += split_size;
}
std::vector<std::vector<torch::Tensor>> outputs{outputTensors};
work = mProcessGroup->allgather(outputs, inputs, {});
}
else
{
work = pgh.allgather(input, output, {});
}
if (!coalescing)
{
PGCHECK_THROW_WITH_INFO(work, "ProcessGroup: allgather");
return {output, nullptr};
}
return {output, work};
}
std::vector<torch::Tensor> run_list(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes)
{
std::vector<torch::Tensor> output_list;
std::vector<c10::intrusive_ptr<c10d::Work>> work_list;
output_list.reserve(input_list.size());
work_list.reserve(input_list.size());
mProcessGroup->startCoalescing(c10::DeviceType::CUDA);
for (auto const& input : input_list)
{
auto [output, work] = run(input, sizes, true);
output_list.push_back(output);
work_list.push_back(work); // Hold work objects (input & output tensors) until endCoalescing wait finished
}
if (auto work = mProcessGroup->endCoalescing(c10::DeviceType::CUDA))
{
PGCHECK_THROW_WITH_INFO(work, "ProcessGroup: allgather, end coalescing");
}
return output_list;
}
private:
std::set<int> mGroup;
c10::intrusive_ptr<c10d::ProcessGroup> mProcessGroup;
};
} // namespace
#endif // ENABLE_MULTI_DEVICE
@ -136,6 +230,24 @@ torch::Tensor allgather(torch::Tensor input, torch::optional<torch::List<int64_t
#endif // ENABLE_MULTI_DEVICE
}
torch::Tensor allgather_pg(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes,
torch::List<int64_t> group_, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
{
#if ENABLE_MULTI_DEVICE
std::set<int> group;
for (int64_t rank : group_)
{
group.insert(static_cast<int>(rank));
}
AllgatherPgOp op(group, process_group_);
op.initialize();
auto [output, _] = op.run(input, sizes);
return output;
#else
return input;
#endif // ENABLE_MULTI_DEVICE
}
std::vector<torch::Tensor> allgather_list(
torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes, torch::List<int64_t> group_)
{
@ -154,16 +266,42 @@ std::vector<torch::Tensor> allgather_list(
#endif // ENABLE_MULTI_DEVICE
}
std::vector<torch::Tensor> allgather_list_pg(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes,
torch::List<int64_t> group_, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
{
#if ENABLE_MULTI_DEVICE
std::set<int> group;
for (int64_t rank : group_)
{
group.insert(static_cast<int>(rank));
}
AllgatherPgOp op(group, process_group_);
op.initialize();
auto output_list = op.run_list(input_list, sizes);
return output_list;
#else
return input_list.vec();
#endif // ENABLE_MULTI_DEVICE
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("allgather(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
m.def(
"allgather_pg(Tensor input, SymInt[]? sizes, int[] group, __torch__.torch.classes.c10d.ProcessGroup "
"process_group) -> Tensor");
m.def("allgather_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
m.def(
"allgather_list_pg(Tensor[] input_list, SymInt[]? sizes, int[] group, "
"__torch__.torch.classes.c10d.ProcessGroup process_group) -> Tensor[]");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("allgather", &torch_ext::allgather);
m.impl("allgather_pg", &torch_ext::allgather_pg);
m.impl("allgather_list", &torch_ext::allgather_list);
m.impl("allgather_list_pg", &torch_ext::allgather_list_pg);
}
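A hedged usage sketch for the new `allgather_pg` op, following the schema registered above. That the `ProcessGroup` object returned by `torch.distributed` is accepted directly for the `__torch__.torch.classes.c10d.ProcessGroup` argument is an assumption here.
```python
# Sketch only: even-split allgather through trtllm::allgather_pg.
import torch
import torch.distributed as dist


def allgather_pg_demo() -> torch.Tensor:
    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    pg = dist.group.WORLD  # assumed to be accepted as the ProcessGroup argument

    x = torch.full((4, 8), float(rank), device="cuda", dtype=torch.float16)
    group = list(range(world))

    # sizes=None: every rank contributes the same number of rows, so the
    # result is (world * 4, 8), ordered by rank.
    return torch.ops.trtllm.allgather_pg(x, None, group, pg)
```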

View File

@ -30,6 +30,7 @@
#include "tensorrt_llm/runtime/mcastDeviceMemory.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include "tensorrt_llm/thop/fp4Quantize.h"
#include "tensorrt_llm/thop/fp8Op.h"
#include "tensorrt_llm/thop/thUtils.h"
@ -37,7 +38,13 @@
#if ENABLE_MULTI_DEVICE
#include <ATen/cuda/EmptyTensor.h>
#include <c10/util/irange.h>
#include <nccl.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#endif // ENABLE_MULTI_DEVICE
#include <nvml.h>
#include <torch/extension.h>
@ -50,6 +57,9 @@
using tensorrt_llm::kernels::AllReduceFusionOp;
using tensorrt_llm::kernels::AllReduceStrategyType;
using tensorrt_llm::mpi::MpiTag;
using tensorrt_llm::pg_utils::get_world_pg;
using tensorrt_llm::pg_utils::get_local_pg;
using tensorrt_llm::pg_utils::PgHelper;
namespace torch_ext
{
@ -59,6 +69,14 @@ namespace torch_ext
namespace
{
template <class... Ts>
struct overloaded : Ts...
{
using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
class NvmlManager
{
public:
@ -141,6 +159,79 @@ std::set<int> getLocalGroup(std::set<int> const& group)
return localGroup;
}
std::set<int> getLocalGroupTorch(std::set<int> const& group)
{
auto const worldPg = get_world_pg();
auto const myRank = worldPg->getRank();
auto const localPg = get_local_pg();
auto const myLocalRank = localPg->getRank();
auto const localSize = static_cast<uint32_t>(localPg->getSize());
PgHelper pgh_local{localPg};
PgHelper pgh_world{worldPg}; // for p2p
std::vector<int32_t> ranks(localSize, -1);
std::vector<int32_t> localRanks(localSize, -1);
if (group.size() >= localSize)
{
PGCHECK_THROW(pgh_local.allgather(&myRank, ref(ranks), {}));
PGCHECK_THROW(pgh_local.allgather(&myLocalRank, ref(localRanks), {}));
}
else
{
int tag = static_cast<int>(MpiTag::kDefault);
if (myRank == *group.begin())
{
// Leader: gather from peers (world ranks), then broadcast full localSize arrays.
size_t cnt = 0;
ranks[cnt++] = myRank;
int tmp;
for (auto it = std::next(group.begin()); it != group.end(); ++it)
{
PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
ranks[cnt++] = tmp;
}
for (auto it = std::next(group.begin()); it != group.end(); ++it)
{
PGCHECK_THROW(pgh_world.send(ref(ranks), *it, tag));
}
cnt = 0;
localRanks[cnt++] = myLocalRank;
for (auto it = std::next(group.begin()); it != group.end(); ++it)
{
PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
localRanks[cnt++] = tmp;
}
for (auto it = std::next(group.begin()); it != group.end(); ++it)
{
PGCHECK_THROW(pgh_world.send(ref(localRanks), *it, tag));
}
}
else
{
int leader = *group.begin();
PGCHECK_THROW(pgh_world.send(&myRank, leader, tag));
PGCHECK_THROW(pgh_world.recv(ref(ranks), leader, tag));
PGCHECK_THROW(pgh_world.send(&myLocalRank, leader, tag));
PGCHECK_THROW(pgh_world.recv(ref(localRanks), leader, tag));
}
}
std::set<int> localGroup;
for (size_t i = 0; i < ranks.size(); ++i)
{
int world_r = ranks[i];
if (group.find(world_r) != group.end())
localGroup.insert(localRanks[i]);
}
return localGroup;
}
class AllreduceOp
{
public:
@ -154,8 +245,27 @@ public:
{
}
AllreduceOp(std::set<int> group, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_,
nvinfer1::DataType type, AllReduceStrategyType strategy, AllReduceFusionOp op, float eps)
: mGroup(std::move(group))
, mType(type)
, mStrategy(strategy)
, mOp(op)
, mEps(eps)
, mNcclComm(process_group_)
{
}
~AllreduceOp() = default;
int getRank() const
{
return std::visit(
overloaded{[&](std::shared_ptr<ncclComm_t> const&) { return COMM_SESSION.getRank(); },
[&](c10::intrusive_ptr<c10d::ProcessGroup> const& torchPg) { return get_world_pg()->getRank(); }},
mNcclComm);
}
std::vector<torch::Tensor> run(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
torch::optional<torch::Tensor> const& bias, bool trigger_completion_at_end,
@ -169,7 +279,7 @@ public:
AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);
// Log runtime strategy
auto const rank = COMM_SESSION.getRank();
auto const rank = getRank();
logRunTimeStrategy(runtime_strategy, rank);
// Dispatch to different allreduce implementations
@ -192,14 +302,18 @@ public:
int initialize()
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
mNcclComm = getComm(mGroup);
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, getRank());
if (mNcclComm.index() == 0)
{
mNcclComm = getComm(mGroup);
}
if (mStrategy != AllReduceStrategyType::NCCL && mStrategy != AllReduceStrategyType::UB)
{
initGroupTopology();
}
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, getRank());
return 0;
}
@ -288,13 +402,25 @@ private:
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
{
torch::Tensor reduce_output;
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int size = input.numel();
torch::Tensor reduce_output = torch::empty_like(input);
NCCLCHECK_THROW(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), size, (*getDtypeMap())[mType],
ncclSum, *mNcclComm, stream));
std::visit(overloaded{[&](std::shared_ptr<ncclComm_t>& rawComm)
{
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int size = input.numel();
reduce_output = torch::empty_like(input);
NCCLCHECK_THROW(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), size,
(*getDtypeMap())[mType], ncclSum, *rawComm, stream));
},
[&](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg)
{
reduce_output = input.clone();
// TLLM_LOG_INFO("AllReduce Rank: %d, tensor numel: %d", torchPg->getRank(),
// reduce_output.numel());
std::vector tensors{reduce_output};
PGCHECK_THROW(torchPg->allreduce(tensors, {c10d::ReduceOp::SUM}));
}},
mNcclComm);
if (mOp == AllReduceFusionOp::NONE)
{
@ -307,12 +433,13 @@ private:
std::vector<torch::Tensor> runNCCLAllReduceSymmetric(torch::Tensor const& input,
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
{
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int size = input.numel();
auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
auto ub_tensor0 = input;
auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
if (ub_buffer0.invalid())
{
@ -321,13 +448,23 @@ private:
cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(),
cudaMemcpyDeviceToDevice, stream);
ub_buffer0 = symmetric_ub_buffer0;
ub_tensor0 = symmetric_input;
}
TLLM_CHECK(!ub_buffer0.invalid());
auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
NCCLCHECK(ncclAllReduce(
ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
std::visit(overloaded{[&, norm_out_ = norm_out](std::shared_ptr<ncclComm_t>& rawComm)
{
NCCLCHECK_THROW(ncclAllReduce(ub_buffer0.addr, norm_out_.mutable_data_ptr(), size,
(*getDtypeMap())[mType], ncclSum, *rawComm, stream));
},
[&, norm_out_ = norm_out](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg)
{
PGCHECK_THROW(PgHelper{torchPg}.allreduce(ub_tensor0, {c10d::ReduceOp::SUM}));
std::ignore = norm_out_.copy_(ub_tensor0, true);
}},
mNcclComm);
if (mOp == AllReduceFusionOp::NONE)
{
@ -348,7 +485,7 @@ private:
int hidden_size = input.size(-1);
auto const tp_size = mGroup.size();
auto const cur_rank = COMM_SESSION.getRank();
auto const cur_rank = getRank();
int tp_rank = 0;
for (auto const& currentRank : mGroup)
@ -418,7 +555,7 @@ private:
int seq_len = input.size(0);
auto const tp_size = mGroup.size();
auto const cur_rank = COMM_SESSION.getRank();
auto const cur_rank = getRank();
int tp_rank = 0;
for (auto const& currentRank : mGroup)
@ -737,9 +874,13 @@ private:
void setGroupTopology()
{
auto const rank = COMM_SESSION.getRank();
auto const rank = getRank();
TLLM_LOG_INFO("Detecting local TP group for rank %d", rank);
std::set<int> local_group = getLocalGroup(mGroup);
std::set<int> local_group = std::visit(
overloaded{[&](std::shared_ptr<ncclComm_t>&) { return getLocalGroup(mGroup); },
[&](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg) { return getLocalGroupTorch(mGroup); }},
mNcclComm);
if (mGroup.size() != local_group.size())
{
mIsP2PSupported = false;
@ -750,18 +891,17 @@ private:
TLLM_LOG_INFO("TP group is intra-node for rank %d", rank);
NvmlManager nvml_manager;
std::unordered_set<int> visited_device;
mIsP2PSupported = true;
mIsNVLINKSupported = true;
// TODO(ytong): Should we provide group topology info instead of querying it here?
// Use cudaDeviceCanAccessPeer to determine whether p2p is supported,
// and use nvml to determine whether there are nvlink links between ranks.
for (int first_device_id : local_group)
{
for (int second_device_id : local_group)
{
if (first_device_id == second_device_id
|| visited_device.find(second_device_id) != visited_device.end())
if (first_device_id >= second_device_id)
{
continue;
}
@ -842,7 +982,6 @@ private:
mIsNVLINKSupported &= is_NVLINK;
}
visited_device.insert(first_device_id);
}
}
@ -999,14 +1138,14 @@ private:
AllReduceStrategyType mStrategy;
AllReduceFusionOp mOp;
float mEps;
std::shared_ptr<ncclComm_t> mNcclComm;
std::variant<std::shared_ptr<ncclComm_t>, c10::intrusive_ptr<c10d::ProcessGroup>> mNcclComm;
};
} // namespace
#endif // ENABLE_MULTI_DEVICE
std::vector<torch::Tensor> allreduce(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
std::vector<torch::Tensor> allreduce_raw(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> workspace,
torch::List<int64_t> const& group_, int64_t const strategy_, int64_t const fusion_op_, double const eps_,
@ -1030,6 +1169,46 @@ std::vector<torch::Tensor> allreduce(torch::Tensor const& input, torch::optional
#endif // ENABLE_MULTI_DEVICE
}
std::vector<torch::Tensor> allreduce_pg(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> const& workspace,
torch::List<int64_t> const& group_, int64_t rank, c10::intrusive_ptr<c10d::ProcessGroup> const& pg,
int64_t const strategy_, int64_t const fusion_op_, double const eps_, bool const trigger_completion_at_end_)
{
#if ENABLE_MULTI_DEVICE
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
auto const strategy = static_cast<AllReduceStrategyType>(int8_t(strategy_));
auto const fusion_op = static_cast<AllReduceFusionOp>(int8_t(fusion_op_));
float const eps = eps_;
std::set<int> group;
for (int64_t my_rank : group_)
{
group.insert(static_cast<int>(my_rank));
}
// Get the NCCL rank of this process, i.e. its position within process_group_
auto it = group.find(rank);
if (it == group.end())
{
throw std::runtime_error("Rank not found in group");
}
int nccl_rank = std::distance(group.begin(), it);
if (nccl_rank != pg->getRank())
{
throw std::runtime_error("nccl_rank != pg->getRank()");
}
AllreduceOp op(group, pg, dtype, strategy, fusion_op, eps);
op.initialize();
auto ret = op.run(input, residual, norm_weight, scale, bias, trigger_completion_at_end_, workspace);
return ret;
#else
return {input};
#endif // ENABLE_MULTI_DEVICE
}
// residual [m, hidden_dim]
// norm_weight [hidden_dim]
// device_num_experts [1]
@ -1231,6 +1410,21 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
"int op,"
"float eps,"
"bool trigger_completion_at_end) -> Tensor[]");
m.def(
"allreduce_pg("
"Tensor input,"
"Tensor? residual,"
"Tensor? norm_weight,"
"Tensor? scale,"
"Tensor? bias,"
"Tensor? workspace,"
"int[] group,"
"int rank,"
"__torch__.torch.classes.c10d.ProcessGroup pg,"
"int strategy,"
"int op,"
"float eps,"
"bool trigger_completion_at_end) -> Tensor[]");
m.def(
"moe_allreduce("
"Tensor residual,"
@ -1262,7 +1456,8 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("mnnvl_twoshot_allreduce", &torch_ext::mnnvlTwoShotAllReduce);
m.impl("mnnvl_twoshot_rmsnorm", &torch_ext::twoShotRMSNorm);
m.impl("allreduce", &torch_ext::allreduce);
m.impl("allreduce", &torch_ext::allreduce_raw);
m.impl("allreduce_pg", &torch_ext::allreduce_pg);
m.impl("moe_allreduce", &torch_ext::moe_allreduce);
m.impl("moe_finalize_allreduce", &torch_ext::moe_finalize_allreduce);
}
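A hedged sketch of calling `allreduce_pg` with the plain NCCL strategy and no fusion, following the argument order in the schema above. The `AllReduceStrategy` / `AllReduceFusionOp` enums are assumed to be importable from `tensorrt_llm.functional`, and passing the `torch.distributed` group object for the `pg` argument is likewise an assumption. Note that the op checks that `rank`'s position in the sorted group equals `pg.rank()`.
```python
# Sketch only: unfused allreduce via trtllm::allreduce_pg.
import torch
import torch.distributed as dist
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy  # assumed import


def allreduce_pg_demo() -> torch.Tensor:
    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    pg = dist.group.WORLD  # assumed to be accepted as the ProcessGroup argument

    x = torch.ones(16, 1024, device="cuda", dtype=torch.bfloat16)
    group = list(range(world))  # rank must sit at index pg.rank() of this sorted list

    outputs = torch.ops.trtllm.allreduce_pg(
        x,            # input
        None, None,   # residual, norm_weight (no fusion)
        None, None,   # scale, bias
        None,         # workspace (not needed for the plain NCCL path)
        group,
        rank,
        pg,
        int(AllReduceStrategy.NCCL),
        int(AllReduceFusionOp.NONE),
        1e-6,         # eps, unused without fused RMSNorm
        False,        # trigger_completion_at_end
    )
    return outputs[0]  # the op returns a list of tensors
```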

View File

@ -18,18 +18,22 @@
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <NvInferRuntime.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#endif // ENABLE_MULTI_DEVICE
#include <cassert>
#include <set>
#include <vector>
using tensorrt_llm::pg_utils::PgHelper;
namespace torch_ext
{
#if ENABLE_MULTI_DEVICE
@ -49,9 +53,9 @@ public:
int initialize()
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, -1);
mNcclComm = getComm(mGroup);
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, -1);
return 0;
}
@ -125,6 +129,86 @@ private:
std::shared_ptr<ncclComm_t> mNcclComm;
};
class ReducescatterPgOp
{
public:
ReducescatterPgOp(std::set<int> group, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
: mGroup(std::move(group))
, mProcessGroup(process_group_)
{
}
~ReducescatterPgOp() = default;
int initialize() noexcept
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, mProcessGroup->getRank());
return 0;
}
std::pair<torch::Tensor, c10::intrusive_ptr<c10d::Work>> run(
torch::Tensor input, torch::optional<torch::List<int64_t>> sizes, bool coalescing = false)
{
TLLM_CHECK_WITH_INFO(mProcessGroup.get() != nullptr, "mProcessGroup should be initialized before used");
auto rank = mProcessGroup->getRank();
std::vector<int64_t> outputShape = input.sizes().vec();
if (sizes.has_value())
{
TLLM_CHECK(sizes.value().size() == mGroup.size());
outputShape[0] = sizes.value()[rank];
}
else
{
outputShape[0] = outputShape[0] / mGroup.size();
}
auto output = torch::empty(outputShape, input.options());
int64_t split_offset = 0;
std::vector<torch::Tensor> inputTensors{};
for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
{
auto split_size = sizes.has_value() ? sizes.value()[root] : outputShape[0];
inputTensors.push_back(input.index({torch::indexing::Slice(split_offset, split_offset + split_size)}));
split_offset += split_size;
}
std::vector<torch::Tensor> outputs{output};
std::vector<std::vector<torch::Tensor>> inputs{inputTensors};
auto work = mProcessGroup->reduce_scatter(outputs, inputs, {});
if (!coalescing)
{
PGCHECK_THROW_WITH_INFO(work, "ProcessGroup: reduce_scatter");
return {output, nullptr};
}
return {output, work};
}
std::vector<torch::Tensor> run_list(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes)
{
std::vector<torch::Tensor> output_list;
std::vector<c10::intrusive_ptr<c10d::Work>> work_list;
output_list.reserve(input_list.size());
work_list.reserve(input_list.size());
mProcessGroup->startCoalescing(c10::DeviceType::CUDA);
for (auto const& input : input_list)
{
auto [output, work] = run(input, sizes, true);
output_list.push_back(output);
work_list.push_back(work); // Hold work objects (input & output tensors) until endCoalescing wait finished
}
if (auto work = mProcessGroup->endCoalescing(c10::DeviceType::CUDA))
{
PGCHECK_THROW_WITH_INFO(work, "ProcessGroup: reduce_scatter, end coalescing");
}
return output_list;
}
private:
std::set<int> mGroup;
c10::intrusive_ptr<c10d::ProcessGroup> mProcessGroup;
};
} // namespace
#endif // ENABLE_MULTI_DEVICE
@ -147,6 +231,24 @@ extern torch::Tensor reducescatter(
#endif // ENABLE_MULTI_DEVICE
}
extern torch::Tensor reducescatter_pg(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes,
torch::List<int64_t> group_, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
{
#if ENABLE_MULTI_DEVICE
std::set<int> group;
for (int64_t rank : group_)
{
group.insert(static_cast<int>(rank));
}
ReducescatterPgOp op(group, process_group_);
op.initialize();
auto [output, _] = op.run(input, sizes);
return output;
#else
return input;
#endif // ENABLE_MULTI_DEVICE
}
extern std::vector<torch::Tensor> reducescatter_list(
torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes, torch::List<int64_t> group_)
{
@ -165,16 +267,42 @@ extern std::vector<torch::Tensor> reducescatter_list(
#endif // ENABLE_MULTI_DEVICE
}
extern std::vector<torch::Tensor> reducescatter_list_pg(torch::TensorList input_list,
torch::optional<torch::List<int64_t>> sizes, torch::List<int64_t> group_,
c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
{
#if ENABLE_MULTI_DEVICE
std::set<int> group;
for (int64_t rank : group_)
{
group.insert(static_cast<int>(rank));
}
ReducescatterPgOp op(group, process_group_);
op.initialize();
auto output_list = op.run_list(input_list, sizes);
return output_list;
#else
return input_list.vec();
#endif // ENABLE_MULTI_DEVICE
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("reducescatter(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
m.def(
"reducescatter_pg(Tensor input, SymInt[]? sizes, int[] group, __torch__.torch.classes.c10d.ProcessGroup "
"process_group) -> Tensor");
m.def("reducescatter_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
m.def(
"reducescatter_list_pg(Tensor[] input_list, SymInt[]? sizes, int[] group, "
"__torch__.torch.classes.c10d.ProcessGroup process_group) -> Tensor[]");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("reducescatter", &torch_ext::reducescatter);
m.impl("reducescatter_pg", &torch_ext::reducescatter_pg);
m.impl("reducescatter_list", &torch_ext::reducescatter_list);
m.impl("reducescatter_list_pg", &torch_ext::reducescatter_list_pg);
}
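A matching sketch for `reducescatter_pg`, this time exercising the optional `sizes` argument: every rank contributes the full concatenated tensor and receives back only its own reduced slice. As above, passing the `torch.distributed` group object for the `process_group` argument is an assumption.
```python
# Sketch only: uneven-split reduce-scatter via trtllm::reducescatter_pg.
import torch
import torch.distributed as dist


def reducescatter_pg_demo() -> torch.Tensor:
    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    pg = dist.group.WORLD  # assumed to be accepted as the ProcessGroup argument

    # Rank 0 keeps 3 rows, every other rank keeps 2; sizes must sum to the
    # input's dim 0 and have one entry per rank in the group.
    sizes = [3] + [2] * (world - 1)
    x = torch.ones(sum(sizes), 64, device="cuda")
    group = list(range(world))

    # Each rank gets the summed slice of shape (sizes[rank], 64).
    return torch.ops.trtllm.reducescatter_pg(x, sizes, group, pg)
```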

View File

@ -31,6 +31,7 @@ add_gtest(loraConfigTest loraConfigTest.cpp)
add_gtest(intervalSetTest intervalSetTest.cpp)
add_gtest(dynamicBatchTunerTest dynamicBatchTunerTest.cpp)
add_gtest(ucxCommTest ucxCommTest.cpp)
target_link_libraries(ucxCommTest PRIVATE ${Python3_LIBRARIES})
if(NIXL_ROOT)
add_gtest(transferAgentTest transferAgentTest.cpp)

View File

@ -12,5 +12,7 @@
add_subdirectory(kernels)
add_gtest(cacheTransceiverTest cacheTransceiverTest.cpp)
target_link_libraries(cacheTransceiverTest PRIVATE ${Python3_LIBRARIES})
add_gtest(mpiUtilsTest mpiUtilsTest.cpp)
add_gtest(userBufferTest userBufferTest.cpp)

Binary file not shown (new image, 283 KiB)

View File

@ -0,0 +1,51 @@
<div align="center">
# TensorRT-LLM with Ray orchestrator
</div>
<div align="left">
This folder contains examples for a prototype **Ray orchestrator** that supports on-demand LLM instance spin-up and flexible GPU placement across single- and multi-node inference. It's a first step toward making TensorRT-LLM a better fit for reinforcement learning from human feedback (RLHF) workflows. For RLHF, [Ray](https://docs.ray.io/en/latest/index.html), unlike MPI with its fixed world size and placement, can dynamically spawn and reconnect distributed inference actors, each with its own parallelism strategy.
This feature is a prototype and under active development. MPI remains the default.
## Quick Start
To use Ray orchestrator, you need to first install Ray.
```shell
cd examples/ray_orchestrator
pip install -r requirements.txt
```
Run a simple `TP=2` example with a Hugging Face model:
```shell
python llm_inference_distributed_ray.py
```
This example is the same as in `/examples/llm-api`, with the only change being `orchestrator_type="ray"` on `LLM()`. Other examples can be adapted similarly by toggling this flag.
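For reference, the minimal change relative to a standard LLM-API script looks like this:
```python
from tensorrt_llm import LLM

# Same constructor as the standard examples; only the orchestrator flag changes.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    orchestrator_type="ray",  # use the Ray orchestrator instead of the default MPI
    tensor_parallel_size=2,
)
```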
## Features
### Available
- Generate text asynchronously (refer to [llm_inference_async_ray.py](llm_inference_async_ray.py))
- Multi-node inference (refer to [multi-node README](./multi_nodes/README.md))
- Disaggregated serving (refer to [disagg README](./disaggregated/README.md))
**Initial testing has been focused on LLaMA and DeepSeek variants. Please open an Issue if you encounter problems with other models so we can prioritize support.**
### Upcoming
- Performance optimization
- Integration with RLHF frameworks, such as [NVIDIA Nemo-RL](https://github.com/NVIDIA-NeMo/RL) and [Verl](https://github.com/volcengine/verl).
## Architecture
This feature introduces new classes such as [RayExecutor](/tensorrt_llm/executor/ray_executor.py) and [RayGPUWorker](/tensorrt_llm/executor/ray_gpu_worker.py) for Ray actor lifecycle management and distributed inference. In Ray mode, collective ops run on [torch.distributed](https://docs.pytorch.org/tutorials/beginner/dist_overview.html) without MPI. We welcome contributions to improve and extend this support.
![Ray orchestrator architecture](/docs/source/media/ray_orchestrator_architecture.jpg)
## Disclaimer
The code is a prototype and subject to change. Currently, there are no guarantees regarding functionality, performance, or stability.
</div>

View File

@ -0,0 +1,28 @@
# Disaggregated Serving with Ray orchestrator
TensorRT-LLM supports a prototype [Ray orchestrator](../README.md) as an alternative to MPI.
Running disaggregated serving with Ray follows [the same workflow as in MPI](/examples/disaggregated/README.md), except that `orchestrator_type="ray"` must be set on the `LLM` class, and `CUDA_VISIBLE_DEVICES` can be omitted since Ray handles GPU placement.
## Quick Start
This script is a shorthand that launches a single-GPU context server, a single-GPU generation server, and the disaggregated server within a single Ray cluster. Please see [this documentation](/examples/disaggregated/README.md) for details on adjusting parallel settings.
```bash
# requires a total of two GPUs
bash -e disagg_serving_local.sh
```
Once the disaggregated server is ready, you can send requests to the disaggregated server using curl:
```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"prompt": "NVIDIA is a great company because",
"max_tokens": 16,
"temperature": 0
}' -w "\n"
```
## Disclaimer
The code is a prototype and subject to change. Currently, there are no guarantees regarding functionality, performance, or stability.

View File

@ -0,0 +1,146 @@
#!/bin/bash
# Parse command line arguments
BACKEND="ray"
ATTACH_MODE=false
TP_SIZE=1
USAGE="Usage: $0 [--executor ray|mpi] [--attach] [--tp_size N] [--help]"
while [[ $# -gt 0 ]]; do
case $1 in
--executor)
BACKEND="$2"
shift 2
;;
--attach)
ATTACH_MODE=true
shift
;;
--tp_size)
TP_SIZE="$2"
shift 2
;;
--help|-h)
echo "$USAGE"
echo "Options:"
echo " --executor ray|mpi Choose distributed executor (default: ray)"
echo " --attach Attach to existing ray cluster (skip ray start/stop)"
echo " --tp_size N Tensor parallel size (default: 1)"
echo " --help, -h Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
echo "$USAGE"
exit 1
;;
esac
done
if [[ "$BACKEND" != "ray" && "$BACKEND" != "mpi" ]]; then
echo "Error: Executor must be either 'ray' or 'mpi'"
echo "$USAGE"
exit 1
fi
echo "Executor: $BACKEND"
echo "Tensor parallel size: $TP_SIZE"
if [[ "$ATTACH_MODE" == "true" ]]; then
echo "Attach mode enabled - will not manage ray cluster"
fi
# Generate extra_llm_config.yaml based on executor type
echo "Generating extra_llm_config.yaml for executor: $BACKEND"
if [[ "$BACKEND" == "ray" ]]; then
cat > extra_llm_config.yaml << EOF
# extra_llm_config.yaml when launching disaggregated server instances.
cache_transceiver_config:
backend: "UCX"
max_tokens_in_buffer: 2048
disable_overlap_scheduler: true
# Ray executor configuration
orchestrator_type: "ray"
EOF
else
cat > extra_llm_config.yaml << EOF
# extra_llm_config.yaml when launching disaggregated server instances.
cache_transceiver_config:
backend: "UCX"
max_tokens_in_buffer: 2048
disable_overlap_scheduler: true
# Using default executor MPI (no orchestrator_type specified)
EOF
fi
# Generate disaggregated server config
echo "Generating disagg_config_local.yaml"
cat > disagg_config_local.yaml << EOF
hostname: localhost
port: 8000
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
free_gpu_memory_fraction: 0.25
backend: "pytorch"
disable_overlap_scheduler: True
context_servers:
num_instances: 1
tensor_parallel_size: $TP_SIZE
pipeline_parallel_size: 1
kv_cache_config:
free_gpu_memory_fraction: 0.2
cache_transceiver_config:
backend: "UCX"
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: $TP_SIZE
pipeline_parallel_size: 1
cache_transceiver_config:
backend: "UCX"
urls:
- "localhost:8002"
EOF
# Conditionally start ray head if using ray backend and not in attach mode
RAY_STARTED=false
if [[ "$BACKEND" == "ray" && "$ATTACH_MODE" != "true" ]]; then
echo "Checking if ray cluster is already running..."
if ray status > /dev/null 2>&1; then
echo "Ray cluster is already running. Stopping existing cluster first..."
ray stop
sleep 2
fi
echo "Launching ray head..."
ray start --head --disable-usage-stats
RAY_STARTED=true
elif [[ "$BACKEND" == "ray" && "$ATTACH_MODE" == "true" ]]; then
echo "Attach mode: Skipping ray cluster management"
fi
# Launching context servers
echo "Launching context servers..."
if [[ "$BACKEND" == "mpi" ]]; then
export CUDA_VISIBLE_DEVICES=0
fi
trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --tp_size $TP_SIZE --port 8001 --kv_cache_free_gpu_memory_fraction 0.15 --backend pytorch --extra_llm_api_options extra_llm_config.yaml &> output_ctx0 &
if [[ "$BACKEND" == "mpi" ]]; then
export CUDA_VISIBLE_DEVICES=1
fi
# Launching generation servers
echo "Launching generation servers..."
trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --tp_size $TP_SIZE --port 8002 --kv_cache_free_gpu_memory_fraction 0.15 --backend pytorch --extra_llm_api_options extra_llm_config.yaml &> output_gen0 &
# Launching disaggregated server
echo "Launching disaggregated server..."
trtllm-serve disaggregated -c disagg_config_local.yaml
# Cleanup
if [[ "$RAY_STARTED" == "true" && "$ATTACH_MODE" != "true" ]]; then
echo "Stopping ray..."
ray stop
fi
echo "Cleaning up generated extra_llm_config.yaml..."
rm -f extra_llm_config.yaml

View File

@ -0,0 +1,55 @@
# Generate text asynchronously with Ray orchestrator.
import asyncio
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig
def main():
# Configure KV cache memory usage fraction.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
max_tokens=4096,
enable_block_reuse=True)
# The model argument accepts an HF model name or a path to a local HF model.
llm = LLM(
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
kv_cache_config=kv_cache_config,
max_seq_len=1024,
max_batch_size=1,
orchestrator_type="ray", # Enable Ray orchestrator
# Enable 2-way tensor parallelism
# tensor_parallel_size=2
)
# Sample prompts.
prompts = [
"Hello, my name is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Async based on Python coroutines
async def task(prompt: str):
output = await llm.generate_async(prompt, sampling_params)
print(
f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
)
async def main():
tasks = [task(prompt) for prompt in prompts]
await asyncio.gather(*tasks)
asyncio.run(main())
# Example output:
# Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
# Prompt: 'The capital of France is', Generated text: 'Paris.'
# Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'
if __name__ == '__main__':
main()

View File

@ -0,0 +1,56 @@
# Generate text with Ray orchestrator.
import argparse
from tensorrt_llm import LLM, SamplingParams
def main():
# The model argument accepts an HF model name or a path to a local HF model.
llm = LLM(
model=args.model_dir,
orchestrator_type="ray", # Enable Ray orchestrator
# Enable 2-way tensor parallelism
tensor_parallel_size=args.tp_size,
# Enable 2-way pipeline parallelism if needed
pipeline_parallel_size=args.pp_size,
# Enable 2-way expert parallelism for MoE model's expert weights
moe_expert_parallel_size=args.moe_ep_size)
# Sample prompts.
prompts = [
"Hello, my name is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
for output in llm.generate(prompts, sampling_params):
print(
f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
)
# Example output:
# Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
# Prompt: 'The capital of France is', Generated text: 'Paris.'
# Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'
def parse_arguments():
parser = argparse.ArgumentParser(
description='LLM Inference with Ray orchestrator')
parser.add_argument('--model_dir',
type=str,
default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
help="Model checkpoint directory")
parser.add_argument('--tp_size', type=int, default=2)
parser.add_argument('--pp_size', type=int, default=1)
parser.add_argument('--moe_ep_size', type=int, default=-1)
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
main()

View File

@ -0,0 +1,43 @@
# Multi-node inference with Ray orchestrator
TensorRT-LLM supports a prototype [Ray orchestrator](../README.md) as an alternative to MPI. The following example shows how to start a Ray cluster for multi-node inference.
## Quick Start
**Prerequisite:** a container image with TensorRT-LLM preinstalled (or suitable for installing it). The examples use Slurm and [Enroot](https://github.com/NVIDIA/enroot). If you use a different setup, adapt the following scripts and commands to your multi-node environment.
1. Allocate nodes and open a shell on the head node:
```shell
# e.g., 2 nodes with 8 GPUs per node
>> salloc -t 240 -N 2 -p interactive
>> srun --pty -p interactive bash
```
2. Once on the head node, launch a multi-node Ray cluster:
```shell
# Remember to set the CONTAINER and MOUNTS environment variables (or the corresponding variables inside the script) to your paths.
# You can add the TensorRT-LLM installation command in this script if it is not preinstalled in your container.
>> bash -e run_cluster.sh
```
3. Enter the head container and run your TensorRT-LLM driver script
Note that this step requires TensorRT-LLM to be installed in the containers on all nodes. If it isn't, install it manually inside each node's container.
```shell
# On the head node
>> sacct
# Grab the Slurm step ID with Job Name "ray-head"
>> srun --jobid=<Your Step ID> --overlap --pty bash
>> enroot list -f # get process id
>> enroot exec <process id> bash
# You can change this script to use a model and parallel settings suited to multi-node inference (e.g., TP8 or TP4PP4).
>> python examples/ray_orchestrator/llm_inference_async_ray.py
```
## Disclaimer
The code is a prototype and subject to change. Currently, there are no guarantees regarding functionality, performance, or stability.

View File

@ -0,0 +1,228 @@
#!/usr/bin/env bash
# NOTE: Multi-node with Ray orchestrator in TensorRT-LLM is an experimental feature and may not work on all systems.
# This script launches a Ray cluster and connects all allocated nodes for multi-node inference.
# The following variables are expected to be set in the environment:
# CONTAINER: the path of the container image that has the desired TensorRT-LLM version installed.
# MOUNTS: directory mount specification in format src:dest
# Run inside an already-active allocated node.
# To start a Ray cluster across nodes:
# >> bash -e launch_ray.sh
#
# See multi_nodes/README.md for more details.
#
set -euo pipefail
: "${CONTAINER:?Set CONTAINER to the container image path (e.g. /path/to/trtllm.sqfs)}"
: "${MOUNTS:?Set MOUNTS to mount spec src:dst[,src2:dst2...] (no spaces)}"
RAY_PORT=${RAY_PORT:-6379}
RUN_ID=$(date +%m%d-%H%M-%S)
HEAD_NAME="ray-head-${RUN_ID}"
SLURM_SUBMIT_DIR=$PWD
BASE_LOG_DIR=${BASE_LOG_DIR:-${SLURM_SUBMIT_DIR:-$(pwd)}}
LOG_DIR="$BASE_LOG_DIR/${SLURM_JOB_ID}-logs"
mkdir -p "$LOG_DIR"
COMMAND=""
if [[ "$#" -gt 0 ]]; then
for arg; do [[ "$arg" == "--" ]] && shift && break; done
COMMAND="$*"
fi
MIN_WORKER_PORT=${MIN_WORKER_PORT:-54001}
MAX_WORKER_PORT=${MAX_WORKER_PORT:-54257}
COMMON_SRUN_ARGS+=" --mpi=pmix"
COMMON_SRUN_ARGS+=" --container-remap-root --container-writable"
COMMON_SRUN_ARGS+=" --container-mounts=$MOUNTS"
COMMON_SRUN_ARGS+=" --container-image=$CONTAINER"
COMMON_SRUN_ARGS+=" --container-workdir=$SLURM_SUBMIT_DIR"
COMMON_SRUN_ARGS+=" -p $SLURM_JOB_PARTITION"
COMMON_SRUN_ARGS+=" -A $SLURM_JOB_ACCOUNT"
COMMON_SRUN_ARGS+=" --gres=gpu:$SLURM_GPUS_ON_NODE"
# Getting the node names and IP addresses in the SLURM allocation
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)
ip_addresses_array=()
for node in $nodes; do
ip_address=$(host $node | awk '/has address/ { print $4 }')
ip_addresses_array+=("$ip_address")
done
head_node=${nodes_array[0]}
head_node_ip=${ip_addresses_array[0]}
WORKERS=("${nodes_array[@]:1}")
ip_head=$head_node_ip:$RAY_PORT
BLUE='\e[96m'
GREEN='\e[32m'
RESET='\e[0m'
echo -e "${BLUE}[INFO] Head node : $head_node${RESET}"
echo -e "${BLUE}[INFO] Worker(s) : ${WORKERS[*]}${RESET}"
echo -e "${BLUE}[INFO] GPUs per node : ${SLURM_GPUS_ON_NODE}${RESET}"
echo -e "${BLUE}[INFO] CONTAINER: '$CONTAINER'${RESET}"
echo -e "${BLUE}[INFO] COMMAND = '$COMMAND'${RESET}"
echo -e "${BLUE}[INFO] Logs : $LOG_DIR${RESET}"
########################################################
# Start Ray cluster on head node
########################################################
# Enable the dashboard only for debugging
# Add apt-get install -y --no-install-recommends libzmq3-dev for multi-node disagg
head_cmd=$(cat <<EOF
# WAR: clean all slurm / MPI / PMIx env to avoid pmix mismatch error
for v in \$(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print \$1}'); do
unset "\$v"
done
# ---- mark that the container shell is alive ----
touch "$LOG_DIR/STARTED_RAY_HEAD"
export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
export RAY_DEDUP_LOGS=0
export TRTLLM_UCX_INTERFACE=eth0
ray start --head \
--port=$RAY_PORT \
--node-ip-address="$head_node_ip" \
--disable-usage-stats \
--include-dashboard=true \
--min-worker-port=${MIN_WORKER_PORT} \
--max-worker-port=${MAX_WORKER_PORT} \
--num-cpus=16 \
--block
EOF
)
echo "head_cmd: $head_cmd"
srun --overlap $COMMON_SRUN_ARGS --job-name=ray-head --container-name="${HEAD_NAME}" --nodes=1 --ntasks=1 -w "$head_node" \
bash -c "$head_cmd" 2>&1 | tee -a "$LOG_DIR/ray-head.log" &
########################################################
# Wait til Ray cluster ready on head
########################################################
sleep 30
echo "[INFO] waiting for head container..."
while ! srun --overlap -N1 -n1 -w "$head_node" bash -c "test -f '$LOG_DIR/STARTED_RAY_HEAD'" >/dev/null 2>&1; do
echo "[INFO][$(date)] Waiting for head node container to start..."
sleep 2
done
ATTEMPTS=15
until srun --overlap --nodes=1 --ntasks=1 --cpu-bind=none -w "$head_node" \
--container-name="${HEAD_NAME}" bash -lc 'ray status >/dev/null 2>&1'
do
((ATTEMPTS--)) || { echo "[ERROR] Ray head did not come up."; exit 1; }
sleep 4
done
echo -e "${GREEN}[INFO] Ray head is UP! ✔${RESET}"
########################################################
# Start Ray worker nodes
########################################################
NUM_WORKERS=${#WORKERS[@]}
for idx in "${!WORKERS[@]}"; do
W=${WORKERS[$idx]}
worker_cmd=$(
cat <<EOF
# WAR: clean all slurm / MPI / PMIx env to avoid pmix mismatch error
for v in \$(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print \$1}'); do
unset "\$v"
done
export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
export RAY_DEDUP_LOGS=0
export TRTLLM_UCX_INTERFACE=eth0
ray start --address="$ip_head" \
--disable-usage-stats \
--min-worker-port=${MIN_WORKER_PORT} \
--max-worker-port=${MAX_WORKER_PORT} \
--num-cpus=16 \
--block
EOF
)
echo "worker_cmd (node=$W, idx=$idx): $worker_cmd"
srun --overlap $COMMON_SRUN_ARGS --exact --container-name="ray-worker-${idx}" --nodes=1 --ntasks=1 --cpu-bind=none \
-w "$W" bash -lc "$worker_cmd" 2>&1 | tee -a "$LOG_DIR/ray-worker-${idx}.log" &
done
##############################################################################
# Wait until every node is connected in the Ray cluster
##############################################################################
expected_nodes=$(( ${#WORKERS[@]} + 1 )) # head + workers
echo "[INFO] Waiting for ${#WORKERS[@]} worker nodes to join..."
ATTEMPTS=30
while (( ATTEMPTS-- )); do
# run ray status inside the head container
status=$(srun --overlap --nodes=1 --ntasks=1 --cpu-bind=none \
-w "$head_node" --container-name="${HEAD_NAME}" ray status 2>/dev/null)
# Count the active nodes reported by ray status
active_count=$(
sed -n '/^Active:/,/^Pending:/p' <<<"$status" |
grep -c 'node_'
)
echo "[INFO] Detected $active_count worker node(s) out of $expected_nodes total expected node(s)."
[[ $active_count -eq $expected_nodes ]] && break
sleep 4
done
[[ $active_count -eq $expected_nodes ]] || { echo '[ERROR] Ray nodes timed out joining the cluster.'; exit 1; }
echo -e "${GREEN}[INFO] All nodes have joined Ray cluster. ✔ \`ray status\`:${RESET}"
printf '%s\n' "$status"
##############################################################################
# Run TRT-LLM driver (if given)
##############################################################################
if [[ -n "$COMMAND" ]]; then
driver_srun_cmd=(
srun --overlap
--no-container-mount-home
--container-name="${HEAD_NAME}"
--nodes=1 --ntasks=1 -w "$head_node"
bash -lc "$COMMAND"
)
# --container-workdir="$CONTAINER_CWD"
echo -e "${BLUE}[INFO] Driver srun command:${RESET}"
printf ' %q ' "${driver_srun_cmd[@]}"
echo
set +e
"${driver_srun_cmd[@]}" 2>&1 | tee -a "$LOG_DIR/trtllm-command.log"
DRIVER_RC=$?
set -e
if [[ $DRIVER_RC -ne 0 ]]; then
echo "[WARN] driver exited with status $DRIVER_RC Ray cluster left running."
fi
else
echo "[INFO] No COMMAND supplied. Idling. Press Ctrl+C to exit."
fi
wait
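
Once the cluster is up (or after the driver srun attaches to the head container), a Python driver can verify the topology before launching work. This is a minimal sketch, assuming the ray package is installed in the container and the head was started with the defaults above:

import ray

# Attach to the cluster started by this script; "auto" resolves the local head
# started with `ray start --head`, and "trtllm" matches the namespace RayExecutor uses.
ray.init(address="auto", namespace="trtllm")

alive = [n for n in ray.nodes() if n["Alive"]]
print(f"{len(alive)} node(s) joined the Ray cluster")
for n in alive:
    print(n["NodeManagerAddress"], int(n["Resources"].get("GPU", 0)), "GPU(s)")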

View File

@ -0,0 +1,3 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
ray[default]

View File

@ -1279,6 +1279,15 @@ def getMakoArgsFromStageName(stageName, parseSysinfo=false) {
} else {
makoArgs += ["auto_trigger=others"]
}
if (stageName.contains("-Ray-")) {
// If stageName contains "-Ray-", add "orchestrator=ray" to makoArgs.
// In that case, only tests with orchestrator=ray or an unspecified orchestrator will run;
// tests marked with orchestrator=mpi are excluded from the Ray stage.
makoArgs += ["orchestrator=ray"]
} else {
// Otherwise select tests with orchestrator=mpi or unspecified orchestrator
makoArgs += ["orchestrator=mpi"]
}
if (parseSysinfo) {
def taskConfig = parseMultiNodeTaskConfigFromStageName(stageName)
@ -1708,6 +1717,11 @@ def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CO
"--perf-log-formats yaml"
]
}
if (stageName.contains("-Ray-")) {
testCmdLine += ["--run-ray"]
trtllm_utils.llmExecStepWithRetry(pipeline, script: "pip3 install ray[default]")
}
// Test Coverage
def TRTLLM_WHL_PATH = sh(returnStdout: true, script: "pip3 show tensorrt_llm | grep Location | cut -d ' ' -f 2").replaceAll("\\s","")
sh "echo ${TRTLLM_WHL_PATH}"
@ -2049,6 +2063,7 @@ def launchTestJobs(pipeline, testFilter)
"DGX_H100-2_GPUs-PyTorch-Others-1": ["dgx-h100-x2", "l0_dgx_h100", 1, 1, 2],
"DGX_H100-4_GPUs-PyTorch-GptOss-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
"DGX_H100-4_GPUs-PyTorch-Others-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
"DGX_H100-2_GPUs-PyTorch-Ray-1": ["dgx-h100-x2", "l0_dgx_h100", 1, 1, 2],
"DGX_H100-4_GPUs-CPP-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
"A10-PyTorch-1": ["a10", "l0_a10", 1, 1],
"A10-CPP-1": ["a10", "l0_a10", 1, 1],
@ -2071,6 +2086,7 @@ def launchTestJobs(pipeline, testFilter)
"H100_PCIe-PyTorch-1": ["h100-cr", "l0_h100", 1, 3],
"H100_PCIe-PyTorch-2": ["h100-cr", "l0_h100", 2, 3],
"H100_PCIe-PyTorch-3": ["h100-cr", "l0_h100", 3, 3],
"H100_PCIe-PyTorch-Ray-1": ["h100-cr", "l0_h100", 1, 1],
"H100_PCIe-CPP-1": ["h100-cr", "l0_h100", 1, 2],
"H100_PCIe-CPP-2": ["h100-cr", "l0_h100", 2, 2],
"H100_PCIe-TensorRT-1": ["h100-cr", "l0_h100", 1, 2],
@ -2158,6 +2174,7 @@ def launchTestJobs(pipeline, testFilter)
"B300-PyTorch-1": ["b300-single", "l0_b300", 1, 1],
"DGX_B200-4_GPUs-PyTorch-1": ["b200-x4", "l0_dgx_b200", 1, 2, 4],
"DGX_B200-4_GPUs-PyTorch-2": ["b200-x4", "l0_dgx_b200", 2, 2, 4],
"DGX_B200-4_GPUs-PyTorch-Ray-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4],
"DGX_B200-8_GPUs-PyTorch-1": ["b200-x8", "l0_dgx_b200", 1, 1, 8],
"DGX_B200-4_GPUs-PyTorch-Post-Merge-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4],
"DGX_B300-4_GPUs-PyTorch-Post-Merge-1": ["b300-x4", "l0_dgx_b300", 1, 1, 4],

View File

@ -606,7 +606,8 @@ def main(*,
build_deep_ep = "OFF"
build_deep_gemm = "OFF"
else:
targets.extend(["th_common", "bindings", "deep_ep", "deep_gemm"])
targets.extend(
["th_common", "bindings", "deep_ep", "deep_gemm", "pg_utils"])
build_pyt = "ON"
build_deep_ep = "ON"
build_deep_gemm = "ON"
@ -811,6 +812,8 @@ def main(*,
build_dir /
"tensorrt_llm/kernels/decoderMaskedMultiheadAttention/libdecoder_attention_1.so",
lib_dir / "libdecoder_attention_1.so")
install_file(build_dir / "tensorrt_llm/runtime/utils/libpg_utils.so",
lib_dir / "libpg_utils.so")
deep_ep_dir = pkg_dir / "deep_ep"
if deep_ep_dir.is_symlink():

View File

@ -104,8 +104,9 @@ else:
'libs/libnvinfer_plugin_tensorrt_llm.so',
'libs/libtensorrt_llm_ucx_wrapper.so', 'libs/libdecoder_attention_0.so',
'libs/libtensorrt_llm_nixl_wrapper.so', 'libs/nixl/**/*',
'libs/ucx/**/*', 'libs/libdecoder_attention_1.so',
'libs/nvshmem/License.txt', 'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
'libs/ucx/**/*', 'libs/libpg_utils.so',
'libs/libdecoder_attention_1.so', 'libs/nvshmem/License.txt',
'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so',
'deep_ep/LICENSE', 'deep_ep_cpp_tllm.*.so', "include/**/*",
'deep_gemm/LICENSE', 'deep_gemm/include/**/*',

View File

@ -17,6 +17,8 @@ import struct
import sys
from typing import List, Tuple
from tensorrt_llm._utils import mpi_disabled
try:
from cuda.bindings import driver as cuda
from cuda.bindings import runtime as cudart
@ -72,11 +74,11 @@ class IpcMemory:
def __init__(self, mapping: Mapping, size: int, open_ipc: bool = True):
self.mapping = mapping
self.open_ipc = open_ipc and mapping.tp_size <= mapping.gpus_per_node
self.peer_ptrs = [0] * mapping.tp_size
self.local_ptr = 0
if self.open_ipc:
self.peer_ptrs, self.local_ptr = IpcMemory.open_ipc_memory(self.mapping, size, True)
else:
self.peer_ptrs = [0] * mapping.tp_size
self.local_ptr = 0
def __del__(self):
if not sys.is_finalizing() and self.open_ipc:
@ -103,9 +105,15 @@ class IpcMemory:
size += alignment - (size % alignment)
return size
comm = mpi_comm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
if mpi_disabled():
from tensorrt_llm._utils import torch_comm
allgather = torch_comm().tp_allgather
else:
comm = mpi_comm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
allgather = comm.allgather
# see allocateIpcMemory in cpp/tensorrt_llm/runtime/ipcUtils.cpp for alignment reason
# 1 << 21 is 2MB
@ -116,8 +124,8 @@ class IpcMemory:
_raise_if_error(cudart.cudaMemset(local_ptr, 0, aligned_size)[0])
error, local_handle = cudart.cudaIpcGetMemHandle(local_ptr)
_raise_if_error(error)
handles_reserved = allgather(local_handle.reserved)
handles_reserved = comm.allgather(local_handle.reserved)
handles = []
for reserved in handles_reserved:
handle = cudart.cudaIpcMemHandle_t()
@ -141,6 +149,8 @@ class IpcMemory:
def close_ipc_memory(mapping: Mapping, peer_ptrs: List[int]):
for node, ptr in enumerate(peer_ptrs):
if node == mapping.tp_rank:
_raise_if_error(cudart.cudaFree(ptr)[0])
if ptr != 0:
_raise_if_error(cudart.cudaFree(ptr)[0])
else:
_raise_if_error(cudart.cudaIpcCloseMemHandle(ptr)[0])
if ptr != 0:
_raise_if_error(cudart.cudaIpcCloseMemHandle(ptr)[0])

View File

@ -2,7 +2,6 @@
import atexit
import os
import socket
import sys
from typing import Callable, List, Optional, Tuple
@ -10,6 +9,8 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensorrt_llm._utils import get_free_port as _get_free_port
from ..utils.logger import ad_logger
# TODO: check to what extent we can reuse _torch/distributed.py
@ -69,10 +70,7 @@ def all_gather_object(object_list, object, group=None):
def get_free_port():
sock = socket.socket()
sock.bind(("", 0))
port = sock.getsockname()[1]
return port
return _get_free_port()
def get_world_size() -> int:

View File

@ -10,7 +10,7 @@ from ..._utils import get_sm_version
def _register_fake():
@torch.library.register_fake("trtllm::allreduce")
def _(
def allreduce(
input,
residual,
norm_weight,
@ -55,6 +55,25 @@ def _register_fake():
else:
return [torch.empty_like(input)]
@torch.library.register_fake("trtllm::allreduce_pg")
def _(
input,
residual,
norm_weight,
scale,
bias,
workspace,
group,
rank,
pg,
strategy,
op,
eps,
trigger_completion_at_end,
):
return allreduce(input, residual, norm_weight, scale, bias, workspace,
group, strategy, op, eps, trigger_completion_at_end)
#MNNVL Allreduce
@torch.library.register_fake("trtllm::mnnvl_twoshot_allreduce")
def _(input, buffer, buffer_flags, buffer_size, wait_for_results):
@ -76,13 +95,17 @@ def _register_fake():
return [norm_out, residual_out]
@torch.library.register_fake("trtllm::allgather")
def _(input, sizes, group):
def allgather(input, sizes, group):
if sizes is None:
output_shape = (len(group) * input.shape[0], *input.shape[1:])
else:
output_shape = (sum(sizes), *input.shape[1:])
return input.new_empty(output_shape)
@torch.library.register_fake("trtllm::allgather_pg")
def _(input, sizes, group, process_group):
return allgather(input, sizes, group)
@torch.library.register_fake("trtllm::cublas_scaled_mm")
def _(
mat_a: torch.Tensor,
@ -439,7 +462,7 @@ def _register_fake():
dtype=gemm2_output.dtype)
@torch.library.register_fake("trtllm::allgather_list")
def _(input_list, sizes, group):
def allgather_list(input_list, sizes, group):
assert len(input_list) > 0
def create_output_tensor(i):
@ -452,8 +475,12 @@ def _register_fake():
return [create_output_tensor(i) for i in input_list]
@torch.library.register_fake("trtllm::allgather_list_pg")
def _(input_list, sizes, group, process_group):
return allgather_list(input_list, sizes, group)
@torch.library.register_fake("trtllm::reducescatter")
def _(input, sizes, group):
def reducescatter(input, sizes, group):
import tensorrt_llm
local_rank = tensorrt_llm.mpi_rank()
@ -464,6 +491,10 @@ def _register_fake():
shape[0] = sizes[local_rank]
return input.new_empty(shape)
@torch.library.register_fake("trtllm::reducescatter_pg")
def _(input, sizes, group, process_group):
return reducescatter(input, sizes, group)
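
The *_pg fake kernels above delegate to the existing fakes because, under fake tensors, only output shapes and dtypes matter. For context, the same registration pattern for a standalone toy op looks like this (names are purely illustrative and not part of TensorRT-LLM; assumes the PyTorch 2.4+ torch.library API):

import torch

@torch.library.custom_op("demo::tile_gather", mutates_args=())
def tile_gather(x: torch.Tensor, world_size: int) -> torch.Tensor:
    # Eager stand-in for a gather-like collective: concatenate copies along dim 0.
    return torch.cat([x] * world_size, dim=0)

@torch.library.register_fake("demo::tile_gather")
def _(x, world_size):
    # Fake kernel: report only the output shape/dtype; no data is produced.
    return x.new_empty((world_size * x.shape[0], *x.shape[1:]))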
@torch.library.register_fake("trtllm::block_scale_interleave")
def _(sf: torch.Tensor):
rows = sf.shape[-2]

View File

@ -0,0 +1,165 @@
from functools import wraps
from typing import TYPE_CHECKING, List
import torch
import torch.distributed as dist
from torch.distributed import get_process_group_ranks
from torch.distributed.device_mesh import init_device_mesh
from tensorrt_llm.logger import logger
if TYPE_CHECKING:
from tensorrt_llm.mapping import MappingBase as _MappingBaseForTypeCheck
else:
_MappingBaseForTypeCheck = object
def require_device_mesh(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if DeviceMeshTopologyImpl.device_mesh is None:
self.build_mesh()
return func(self, *args, **kwargs)
return wrapper
class SingleProcessGroup:
@staticmethod
def get_group():
return dist.group.WORLD if dist.is_initialized(
) else SingleProcessGroup()
@staticmethod
def rank():
return 0
@staticmethod
def size():
return 1
class DeviceMeshTopologyImpl(_MappingBaseForTypeCheck):
device_mesh = None
tp_mesh = None
# Access Torch ProcessGroup
@property
@require_device_mesh
def tp_group_pg(self):
return self._get_mesh_dim_by_name('tp').get_group()
@property
@require_device_mesh
def pp_group_pg(self):
return self._get_mesh_dim_by_name('pp').get_group()
@property
@require_device_mesh
def cp_group_pg(self):
return self._get_mesh_dim_by_name('cp').get_group()
@property
@require_device_mesh
def moe_tp_group_pg(self):
return self._get_mesh_dim_by_name('moe_tp').get_group()
@property
@require_device_mesh
def moe_ep_group_pg(self):
return self._get_mesh_dim_by_name('moe_ep').get_group()
# Access rank
@property
def tp_rank(self) -> int:
assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
return self.tp_group_pg.rank()
@property
def pp_rank(self) -> int:
assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
return self.pp_group_pg.rank()
@property
def cp_rank(self) -> int:
# TODO: WIP
assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
return self.cp_group_pg.rank()
# Access group ranks
@property
def tp_group(self) -> List[int]:
return self._get_group_ranks(self.tp_group_pg)
@property
def pp_group(self) -> List[int]:
return self._get_group_ranks(self.pp_group_pg)
@property
def cp_group(self) -> List[int]:
return self._get_group_ranks(self.cp_group_pg)
@property
def moe_tp_group(self) -> List[int]:
return self._get_group_ranks(self.moe_tp_group_pg)
@property
def moe_ep_group(self) -> List[int]:
return self._get_group_ranks(self.moe_ep_group_pg)
def build_mesh(self):
cls = DeviceMeshTopologyImpl
if self.world_size == 1 or cls.device_mesh is not None:
# only build mesh once
return
if not torch.distributed.is_initialized():
raise RuntimeError(
"DeviceMesh creation requested but torch.distributed process group "
"has not been initialised.")
dims = ["cp", "pp"]
shape = [self.cp_size, self.pp_size]
if self.moe_ep_size > 1:
dims += ["moe_tp", "moe_ep"]
shape += [self.moe_tp_size, self.moe_ep_size]
else:
dims += ["tp"]
shape += [self.tp_size]
cls.device_mesh = init_device_mesh(
"cuda",
mesh_shape=tuple(shape),
mesh_dim_names=tuple(dims),
)
if self.moe_ep_size > 1:
cls.tp_mesh = cls.device_mesh["moe_tp",
"moe_ep"]._flatten(mesh_dim_name="tp")
logger.debug(f"DeviceMeshTopology.device_mesh: {cls.device_mesh}")
logger.debug(f"DeviceMeshTopology.tp_mesh: {cls.tp_mesh}")
@require_device_mesh
def _get_mesh_dim_by_name(self, name: str) -> dist.DeviceMesh:
cls = DeviceMeshTopologyImpl
if cls.device_mesh is None and self.world_size == 1:
return SingleProcessGroup()
if name == 'tp':
if 'tp' in cls.device_mesh.mesh_dim_names:
return cls.device_mesh['tp']
else:
return cls.tp_mesh
else:
assert name in cls.device_mesh.mesh_dim_names, f"Dimension name {name} not found in device mesh."
return cls.device_mesh[name]
def _get_group_ranks(self, pg) -> List[int]:
if self.world_size == 1:
return [0]
return get_process_group_ranks(pg)
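
The mesh construction above can be reproduced in isolation to see how the named submeshes map to ProcessGroups. A small sketch (run under torchrun with world_size = cp_size * pp_size * tp_size; the sizes are illustrative, in the class they come from the Mapping):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

cp_size, pp_size, tp_size = 1, 2, 4

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda",
                        mesh_shape=(cp_size, pp_size, tp_size),
                        mesh_dim_names=("cp", "pp", "tp"))
tp_pg = mesh["tp"].get_group()  # a regular torch.distributed ProcessGroup
print(f"global rank {dist.get_rank()}: tp rank {tp_pg.rank()} of {tp_pg.size()}")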

View File

@ -1,12 +1,14 @@
import math
import os
import pickle # nosec B403
from abc import ABC, abstractmethod
from functools import wraps
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import (_object_to_tensor,
_tensor_to_object)
try:
from mpi4py import MPI
@ -14,11 +16,19 @@ except Exception:
MPI = None # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
mpi_isend, mpi_isend_object, mpi_recv,
mpi_recv_object, mpi_send, mpi_send_object)
mpi_disabled, mpi_isend, mpi_isend_object,
mpi_recv, mpi_recv_object, mpi_send,
mpi_send_object, torch_pybind11_abi)
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.bindings.internal.process_group import init_pg
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
try:
import ray
except ModuleNotFoundError:
from tensorrt_llm import ray_stub as ray
class Distributed(ABC):
@ -327,6 +337,7 @@ def safe_gather(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
class MPIDist(Distributed):
tp_comm: MPI.Comm
def __init__(self, mapping: Mapping):
super().__init__(mapping)
@ -401,65 +412,361 @@ class MPIDist(Distributed):
return self.pp_comm.bcast(obj, root)
class MultiHandleWrapper:
"""
Wrapper that encapsulates multiple handles and provides a single wait() interface
to unify the API between MPIDist and TorchDist.
"""
def __init__(self, handles):
self.handles = handles if isinstance(handles, list) else [handles]
def wait(self):
for handle in self.handles:
try:
handle.wait()
except Exception as e:
raise RuntimeError(f"Asynchronous operation failed: {e}") from e
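
A short usage sketch of the wrapper; isend_with_size is a hypothetical helper mirroring the size-then-payload pattern used by isend_object further down:

import torch
import torch.distributed as dist

def isend_with_size(payload: torch.Tensor, dest: int, tag: int = 0) -> "MultiHandleWrapper":
    # Send the payload size first, then the payload, and return one waitable handle.
    size = torch.tensor([payload.numel()], dtype=torch.int32)
    handles = [dist.isend(size, dst=dest, tag=tag),
               dist.isend(payload, dst=dest, tag=tag)]
    return MultiHandleWrapper(handles)

# isend_with_size(t, dest=1).wait()  # blocks until both sends have completed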
class TorchDist(Distributed):
@property
def rank(self):
return torch.distributed.get_rank()
def __init__(self, mapping: Mapping):
super().__init__(mapping)
if not dist.is_initialized():
master_ip = os.getenv("MASTER_ADDR", "localhost")
# TODO: fix the constant default port
master_port = os.getenv("MASTER_PORT", "6000")
init_method = f"tcp://{master_ip}:{master_port}"
dist.init_process_group(backend="nccl",
init_method=init_method,
world_size=mapping.world_size,
rank=mapping.rank)
self.device_tp_group = dist.new_group(mapping.tp_group, backend="nccl")
self.cpu_tp_group = dist.new_group(mapping.tp_group, backend="gloo")
self.device_cp_group = dist.new_group(mapping.cp_group, backend="nccl")
self.cpu_cp_group = dist.new_group(mapping.cp_group, backend="gloo")
assert dist.is_initialized(
), "torch.distributed should be initialized before TorchDist"
def broadcast_tp(self, obj, root=0):
if root not in self.mapping.tp_group:
return obj
elif self.rank == root:
torch.distributed.broadcast_object_list([obj],
src=root,
group=self.cpu_tp_group)
return obj
self.cluster_info = None
from tensorrt_llm._utils import set_torch_comm
set_torch_comm(self) # Set as global instance
mapping.build_mesh()
self.setup_local_comm()
self.default_store = torch.distributed.distributed_c10d._get_default_store(
)
init_pg(torch.distributed.group.WORLD, self.local_comm,
torch_pybind11_abi())
def setup_local_comm(self):
self._get_cluster_info()
# node IP -> list of ranks
ip_to_ranks = {}
for rank, (node_ip, _) in enumerate(self.cluster_info):
ip_to_ranks.setdefault(node_ip, []).append(int(rank))
self.local_comm = None
for ranks in ip_to_ranks.values():
# All global ranks in the default process group must participate in this call,
# even ranks that are not part of the new process group being created
pg = dist.new_group(ranks=ranks, backend='cuda:nccl,cpu:gloo')
if int(self.rank) in ranks:
logger.debug(
f"[Rank {self.rank}] Done setting local comm. ip_to_ranks: {ip_to_ranks}"
)
self.local_comm = pg
def _get_cluster_info(self):
if self.cluster_info is not None:
return self.cluster_info
if ray.is_initialized():
node_ip = ray.util.get_node_ip_address()
else:
recv = [None]
torch.distributed.broadcast_object_list(recv,
src=root,
group=self.cpu_tp_group)
return recv[0]
raise RuntimeError("Ray is not initialized")
def broadcast_cp(self, obj, root=0):
if root not in self.mapping.cp_group:
return obj
elif self.rank == root:
torch.distributed.broadcast_object_list([obj],
src=root,
group=self.cpu_cp_group)
return obj
else:
recv = [None]
torch.distributed.broadcast_object_list(recv,
src=root,
group=self.cpu_cp_group)
return recv[0]
gpu_index = [int(id) for id in ray.get_gpu_ids()]
assert len(gpu_index) == 1
# Gather node ip
node_list = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(node_list, node_ip)
# Gather gpu index
gpu_list = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(gpu_list, gpu_index[0])
# Gather rank
rank_list = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(rank_list,
torch.distributed.get_rank())
rank_info_list = [None] * torch.distributed.get_world_size()
for i in range(len(rank_list)):
rank_info_list[rank_list[i]] = (node_list[i], gpu_list[i])
self.cluster_info = rank_info_list
logger.debug(f"Cluster info: {self.cluster_info}")
return self.cluster_info
@staticmethod
def log_op(func, enable_log=False):
@wraps(func)
def wrapper(*args, **kwargs):
if enable_log:
logger.debug(
f"{func.__name__} enter: {args[1:]}, {kwargs}, rank: {torch.distributed.get_rank()}"
)
ret = func(*args, **kwargs)
if enable_log:
logger.debug(f"{func.__name__} exit: {ret}")
return ret
return wrapper
@log_op
def broadcast(self, obj, root=0):
assert not (self.mapping.has_cp_ulysses() and self.mapping.has_tp()
), 'Unsupported mix of Ulysses CP and TP.'
if mpi_disabled():
if isinstance(obj, torch.Tensor):
dist.broadcast(obj, src=root)
return obj
else:
obj_list = [obj]
dist.broadcast_object_list(obj_list, src=root)
return obj_list[0]
if self.mapping.has_cp_ulysses():
self.broadcast_cp(obj, root)
elif self.mapping.has_tp():
self.broadcast_tp(obj, root)
@log_op
def allgather(self, obj):
if isinstance(obj, torch.Tensor):
output_list = [
torch.empty_like(obj) for _ in range(self.world_size)
]
dist.all_gather(output_list, obj)
return output_list
else:
pass
obj_list = [None] * self.world_size
dist.all_gather_object(obj_list, obj)
return obj_list
@log_op
def barrier(self):
dist.barrier()
@log_op
def isend(self, buf: np.ndarray, dest, tag=0):
# non-blocking send numpy buffer
tensor = torch.from_numpy(buf)
return dist.isend(tensor, dst=dest, tag=tag)
@log_op
def send(self, buf: np.ndarray, dest, tag=0):
raise NotImplementedError(
"blocking send is not implemented for TorchDist")
@log_op
def recv(self, buf: np.ndarray, src, tag=0):
# in-place recv numpy buffer
tensor = torch.empty_like(torch.from_numpy(buf))
dist.recv(tensor, src=src, tag=tag)
return tensor.numpy()
@log_op
def isend_tensor(self, tensor: torch.Tensor, dest, tag=0):
return dist.isend(tensor, dst=dest, tag=tag)
@log_op
def recv_tensor(self, tensor: torch.Tensor, src, tag=0):
dist.recv(tensor, src=src, tag=tag)
return tensor
@log_op
def recv_object(self, src, tag=0):
size_tensor = torch.tensor([0], dtype=torch.int32)
torch.distributed.recv(size_tensor,
src=src,
tag=tag,
group=torch.distributed.group.WORLD)
bytes_size = size_tensor.item()
recv_tensor = torch.empty(bytes_size, dtype=torch.uint8)
torch.distributed.recv(recv_tensor,
src=src,
tag=tag,
group=torch.distributed.group.WORLD)
return _tensor_to_object(recv_tensor, bytes_size,
torch.distributed.group.WORLD)
@log_op
def send_object(self, obj, dest, tag=0):
raise NotImplementedError(
"send_object is not implemented for TorchDist")
@log_op
def isend_object(self, obj, dest, tag=0):
input_tensor, local_size = _object_to_tensor(
obj, torch.device("cpu"), torch.distributed.group.WORLD)
# Send object size
works = []
works.append(
torch.distributed.isend(torch.tensor([local_size],
dtype=torch.int32),
dst=dest,
tag=tag))
works.append(torch.distributed.isend(input_tensor, dst=dest, tag=tag))
return MultiHandleWrapper(works)
@log_op
def recv_object_from_isend(self, src, tag):
size_tensor = torch.tensor([0], dtype=torch.int32)
torch.distributed.recv(size_tensor, src=src, tag=tag)
bytes_size = size_tensor.item()
recv_tensor = torch.empty(bytes_size, dtype=torch.uint8)
torch.distributed.recv(recv_tensor, src=src, tag=tag)
return _tensor_to_object(recv_tensor, bytes_size,
torch.distributed.group.WORLD)
@log_op
def allreduce(self,
obj: int | float | torch.Tensor,
op=torch.distributed.ReduceOp.SUM):
is_base_type = isinstance(obj, int) or isinstance(obj, float)
if is_base_type:
obj = torch.tensor(obj)
dist.all_reduce(obj, op=op)
if is_base_type:
obj = obj.item()
return obj
@log_op
def tp_allgather(self, obj):
if isinstance(obj, torch.Tensor):
output_list = [
torch.empty_like(obj)
for _ in range(self.mapping.tp_group_pg.size())
]
dist.all_gather(output_list, obj, group=self.mapping.tp_group_pg)
return output_list
else:
output_list = [None] * self.mapping.tp_group_pg.size()
dist.all_gather_object(output_list,
obj,
group=self.mapping.tp_group_pg)
return output_list
@log_op
def tp_gather(self, obj, dst=0):
global_rank = torch.distributed.get_rank()
if isinstance(obj, torch.Tensor):
if global_rank == dst:
output_list = [
torch.empty_like(obj)
for _ in range(self.mapping.tp_group_pg.size())
]
else:
output_list = None
dist.gather(obj,
output_list,
dst=dst,
group=self.mapping.tp_group_pg)
return output_list
else:
output_list = [None] * self.mapping.tp_group_pg.size()
if global_rank == dst:
output_list = [None] * self.mapping.tp_group_pg.size()
else:
output_list = None
dist.gather_object(obj,
output_list,
dst=dst,
group=self.mapping.tp_group_pg)
return output_list
@log_op
def tp_broadcast(self, obj, root=0):
if isinstance(obj, torch.Tensor):
dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
return obj
else:
ret = [obj]
torch.distributed.broadcast_object_list(
ret,
src=root,
group=self.mapping.tp_group_pg,
device=torch.device("cpu"))
return ret[0]
@log_op
def pp_allgather(self, obj):
if isinstance(obj, torch.Tensor):
output_list = [
torch.empty_like(obj)
for _ in range(self.mapping.pp_group_pg.size())
]
dist.all_gather(output_list, obj, group=self.mapping.pp_group_pg)
return output_list
else:
output_list = [None] * self.mapping.pp_group_pg.size()
dist.all_gather_object(output_list,
obj,
group=self.mapping.pp_group_pg)
return output_list
@log_op
def pp_gather(self, obj, dst=0):
global_rank = torch.distributed.get_rank()
if isinstance(obj, torch.Tensor):
if global_rank == dst:
output_list = [
torch.empty_like(obj)
for _ in range(self.mapping.pp_group_pg.size())
]
else:
output_list = None
dist.gather(obj,
output_list,
dst=dst,
group=self.mapping.pp_group_pg)
return output_list
else:
output_list = [None] * self.mapping.pp_group_pg.size()
if global_rank == dst:
output_list = [None] * self.mapping.pp_group_pg.size()
else:
output_list = None
dist.gather_object(obj,
output_list,
dst=dst,
group=self.mapping.pp_group_pg)
return output_list
@log_op
def pp_broadcast(self, obj, root=0):
if isinstance(obj, torch.Tensor):
dist.broadcast(obj, src=root, group=self.mapping.pp_group_pg)
return obj
else:
ret = [obj]
torch.distributed.broadcast_object_list(
ret,
src=root,
group=self.mapping.pp_group_pg,
device=torch.device("cpu"))
return ret[0]
# TODO: rename to PPCommNCCL
class PPComm:
def __init__(self, global_mapping: Mapping):
@ -480,20 +787,49 @@ class PPComm:
self.nccl_comm.recv(tensor, src)
class PPCommTorch:
def __init__(self, global_mapping: Mapping):
self.mapping = global_mapping
self.pg = self.mapping.pp_group_pg
self.pg_group = self.mapping.pp_group
def _global_to_local_rank(self, global_rank: int):
assert global_rank in self.pg_group
return self.pg_group.index(global_rank)
def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
if dest is None:
dest = self.mapping.next_pp_rank()
self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
if src is None:
src = self.mapping.prev_pp_rank()
self.pg.recv([tensor], self._global_to_local_rank(src), tag=0).wait()
_pp_comm = None
def init_pp_comm(mapping):
"""Initialize PPComm once at startup"""
global _pp_comm
_pp_comm = PPComm(mapping)
if mpi_disabled():
_pp_comm = PPCommTorch(mapping)
else:
_pp_comm = PPComm(mapping)
@TorchDist.log_op
def pp_recv(tensor):
"""Receive tensors from previous pp rank."""
_pp_comm.recv(tensor)
@TorchDist.log_op
def pp_send(tensor):
"""Send tensors to next pp rank."""
_pp_comm.send(tensor)
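
A schematic use of these helpers inside a pipeline step, omitting the distributed initialization that TorchDist performs; `mapping` stands for the already-built Mapping of the current rank:

import torch

init_pp_comm(mapping)  # picks PPCommTorch when TLLM_DISABLE_MPI=1

hidden = torch.empty(8, 4096, device="cuda", dtype=torch.float16)
if mapping.is_first_pp_rank():
    hidden.normal_()
    pp_send(hidden)   # forwarded to mapping.next_pp_rank()
else:
    pp_recv(hidden)   # filled in-place from mapping.prev_pp_rank()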

View File

@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from tensorrt_llm._utils import mpi_comm
from tensorrt_llm._utils import mpi_comm, mpi_disabled
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AllReduceStrategy, MoEAllReduceParams)
@ -185,26 +185,33 @@ def allgather(
val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
# Inputs are reshaped in this way to pass necessary shape information to the allgather op
if isinstance(input, torch.Tensor):
torch_op = torch.ops.trtllm.allgather
if mpi_disabled():
torch_op = torch.ops.trtllm.allgather_pg
else:
torch_op = torch.ops.trtllm.allgather
output_info = get_output_info(input, dim)
input = input.contiguous().view(-1, output_info['numel_base'])
else:
input, valid = filter_valid_input(input)
torch_op = torch.ops.trtllm.allgather_list
if mpi_disabled():
torch_op = torch.ops.trtllm.allgather_list_pg
else:
torch_op = torch.ops.trtllm.allgather_list
output_info = [get_output_info(val, dim) for val in input]
input = [
val.contiguous().view(-1, val_info['numel_base'])
for val, val_info in zip(input, output_info)
]
output = torch_op(
input,
sizes,
mapping.tp_group,
)
if mpi_disabled():
output = torch_op(input, sizes, mapping.tp_group,
mapping.tp_group_pg.boxed())
else:
output = torch_op(input, sizes, mapping.tp_group)
def convert_output(x, x_info):
if dim == 0:
@ -300,23 +307,29 @@ def reducescatter(
return x
if isinstance(input, torch.Tensor):
torch_op = torch.ops.trtllm.reducescatter
if mpi_disabled():
torch_op = torch.ops.trtllm.reducescatter_pg
else:
torch_op = torch.ops.trtllm.reducescatter
output_info = get_output_info(input, dim)
input = convert_input(input, output_info)
else:
input, valid = filter_valid_input(input)
torch_op = torch.ops.trtllm.reducescatter_list
if mpi_disabled():
torch_op = torch.ops.trtllm.reducescatter_list_pg
else:
torch_op = torch.ops.trtllm.reducescatter_list
output_info = [get_output_info(val, dim) for val in input]
input = [
convert_input(val, val_info)
for val, val_info in zip(input, output_info)
]
output = torch_op(
input,
sizes,
mapping.tp_group,
)
if mpi_disabled():
output = torch_op(input, sizes, mapping.tp_group,
mapping.tp_group_pg.boxed())
else:
output = torch_op(input, sizes, mapping.tp_group)
if isinstance(input, torch.Tensor):
output = output.view(output_info['output_shape'])
@ -489,6 +502,9 @@ class AllReduce(nn.Module):
self.workspace = None
self.strategy = strategy
self.mnnvl_allreduce = None
self._disable_mpi = mpi_disabled()
self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce
if self.mapping.tp_size > 1:
# When Strategy is UB, it is guaranteed that the workspace is not used.
@ -572,7 +588,17 @@ class AllReduce(nn.Module):
# Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL
if allreduce_strategy == AllReduceStrategy.MNNVL:
allreduce_strategy = AllReduceStrategy.AUTO
output = torch.ops.trtllm.allreduce(
additional_args = {}
if self._disable_mpi:
pg = self.mapping.tp_group_pg
assert pg is not None, "TP ProcessGroup not initialised"
additional_args = {
"rank": torch.distributed.get_rank(),
"pg": pg.boxed(),
}
output = self.all_reduce_op(
input=input,
residual=all_reduce_params.residual,
norm_weight=all_reduce_params.norm_weight,
@ -585,6 +611,7 @@ class AllReduce(nn.Module):
eps=all_reduce_params.eps,
trigger_completion_at_end=all_reduce_params.
trigger_completion_at_end,
**additional_args,
)
return output if len(output) > 1 else output[0]

View File

@ -0,0 +1,43 @@
import torch
import torch.distributed as dist
def split(color: int, key: int,
pg_boxed: torch.ScriptObject) -> torch.ScriptObject:
"""Create a subgroup ProcessGroup.
This gathers (color, key) from all ranks, selects members with matching color,
sorts them by key to determine rank ordering, and creates a new ProcessGroup
for those ranks using torch.distributed.new_group, inheriting backend from
the global ProcessGroup.
"""
if not dist.is_initialized():
raise RuntimeError("torch.distributed is not initialized")
try:
pg = torch.distributed.ProcessGroup.unbox(pg_boxed)
except Exception as e:
raise ValueError(f"Error unboxing ProcessGroup: {e}") from e
group_size = dist.get_world_size(group=pg)
# gather (color, key, global_rank) within the provided pg
payload = (int(color), int(key), int(dist.get_rank(group=pg)))
gathered = [None] * group_size
dist.all_gather_object(gathered, payload, group=pg)
members = []
for c, k, global_rank in gathered:
if c == color:
members.append((int(k), int(global_rank)))
members.sort()
ranks = [r for _, r in members]
if not ranks:
raise ValueError(f"Split by color {color} produced empty subgroup")
if (current_rank := dist.get_rank()) not in ranks:
raise ValueError(
f"Current rank {current_rank} not in color {color} subgroup")
# Create subgroup under the provided pg; ranks are global ranks
sub_pg = dist.new_group(ranks=ranks, use_local_synchronization=True)
# Return TorchScript boxed ProcessGroup so C++ can unwrap it
return sub_pg.boxed()
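
Usage mirrors MPI_Comm_split: ranks that pass the same color land in one subgroup, ordered by key. A sketch, assuming torch.distributed is already initialized and tp_size divides the world size:

import torch.distributed as dist

tp_size = 4
rank = dist.get_rank()
color = rank // tp_size      # one subgroup per consecutive block of tp_size ranks
key = rank % tp_size         # ordering inside the subgroup
tp_pg = dist.ProcessGroup.unbox(split(color, key, dist.group.WORLD.boxed()))
print(f"rank {rank}: tp rank {dist.get_rank(group=tp_pg)} / {dist.get_world_size(group=tp_pg)}")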

View File

@ -10,7 +10,7 @@ from typing import Dict, Iterable, List, Optional, Tuple
import torch
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm._utils import mpi_disabled, nvtx_range
from tensorrt_llm.mapping import CpType
from ..distributed import Distributed
@ -69,6 +69,8 @@ class ExecutorRequestQueue:
self.is_shutdown = False
self.should_exclude_last_generation_logits = False
self._disable_mpi = mpi_disabled()
def _get_from_request_queue(
self,
timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
@ -267,8 +269,14 @@ class ExecutorRequestQueue:
) -> List[RequestQueueItem]:
"""Common logic for fetching and processing requests from the queue."""
# Calculate timeout
timeout = None if (total_num_active_requests == 0) and len(
self.waiting_queue) == 0 else datetime.timedelta(0)
idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0
if idle:
# In the Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so that rank 0
# reaches the broadcast path regularly and trtllm-serve does not time out while idle.
timeout = datetime.timedelta(
seconds=1200) if self._disable_mpi else None
else:
timeout = datetime.timedelta(0)
# Fetch requests from rank 0
new_requests = []
@ -569,7 +577,13 @@ class ExecutorRequestQueue:
payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag)
if not self.dist.is_last_pp_rank:
self.dist.send_object(payloads, self.dist.next_pp_rank, tag)
if self._disable_mpi:
isend_payload = self.dist.isend_object(payloads,
self.dist.next_pp_rank,
tag)
isend_payload.wait()
else:
self.dist.send_object(payloads, self.dist.next_pp_rank, tag)
return payloads

View File

@ -3,6 +3,7 @@ import datetime
import functools
import gc
import os
import pickle # nosec B403
import threading
import time
import traceback
@ -21,8 +22,8 @@ except ImportError:
from tensorrt_llm._torch.pyexecutor.resource_manager import (
ResourceManagerType, request_context)
from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
is_trace_enabled, nvtx_range, trace_func)
from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
mpi_disabled, nvtx_range, trace_func)
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
FinishReason, InflightBatchingStats,
IterationStats, KvCacheStats,
@ -166,7 +167,7 @@ class PyExecutor:
peft_cache_config: Optional[PeftCacheConfig] = None):
super(PyExecutor, self).__init__()
self.device_id = torch.cuda.current_device()
self.global_rank = global_mpi_rank()
self.global_rank = dist.rank
self.peft_cache_config = peft_cache_config
@ -213,6 +214,7 @@ class PyExecutor:
self.response_lock = threading.Lock()
self.response_cv = threading.Condition(self.response_lock)
self.responses = {}
self.result_wait_queues = {}
# kv cache events
self.kv_cache_manager = self.resource_manager.resource_managers.get(
@ -232,6 +234,7 @@ class PyExecutor:
self.num_scheduled_requests: int = 0
self.benchmark_req_queues_size = int(
os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))
self._disable_mpi = mpi_disabled()
# list of requests in each PP micro batch
self.num_micro_batches = self.dist.pp_size
@ -391,11 +394,19 @@ class PyExecutor:
def __exit__(self):
self.shutdown()
def enqueue_requests(self, requests: List[ExecutorRequest]) -> List[int]:
def enqueue_requests(
self,
requests: List[ExecutorRequest],
result_wait_queue: "Optional[ray.actor.ActorHandle]" = None
) -> List[int]:
"""
Enqueue new requests
"""
req_ids = self.executor_request_queue.enqueue_requests(requests)
if result_wait_queue is not None:
with self.response_cv:
for req_id in req_ids:
self.result_wait_queues[req_id] = result_wait_queue
return req_ids
def await_responses(
@ -420,6 +431,7 @@ class PyExecutor:
for req_id in id:
responses.append(
self._await_single_response(id=req_id, timeout=timeout))
return responses
def cancel_request(self, id: int):
@ -477,14 +489,18 @@ class PyExecutor:
def wait_shutdown(self):
self.shutdown_event.wait()
def enqueue_request(self,
request: ExecutorRequest,
query: Optional[List] = None) -> int:
def enqueue_request(
self,
request: ExecutorRequest,
query: Optional[List] = None,
result_wait_queue: "Optional[ray.actor.ActorHandle]" = None) -> int:
"""
Enqueue a new request, query is only used in `StarAttention`.
"""
req_id = self.executor_request_queue.enqueue_request(request, query)
if result_wait_queue is not None:
with self.response_cv:
self.result_wait_queues[req_id] = result_wait_queue
return req_id
def set_gather_responses(self, gather_all_responses):
@ -900,10 +916,12 @@ class PyExecutor:
if previous_batch is not None:
sample_state = previous_batch.sample_state
if not self.dist.is_last_pp_rank:
recv_object_funct = self.dist.recv_object_from_isend if self._disable_mpi \
else self.dist.recv_object
torch.cuda.nvtx.range_push(
"_handle_new_tokens_inter_pp")
# Receive tokens from previous pp rank (w.r.t model forward direction)
sample_state.host = self.dist.recv_object(
sample_state.host = recv_object_funct(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
)
@ -1849,7 +1867,6 @@ class PyExecutor:
for request in failed_requests:
req_id = request.py_request_id
request.state = LlmRequestState.GENERATION_COMPLETE
self._terminate_request(request)
error_responses[req_id] = LlmResponse(
request_id=req_id,
error_msg=error_msg,
@ -1861,7 +1878,9 @@ class PyExecutor:
request for request in self.active_requests
if request not in requests
]
self._enqueue_responses(error_responses.items())
self._enqueue_responses(list(error_responses.items()))
for request in failed_requests:
self._terminate_request(request)
def _terminate_request(self, request: LlmRequest):
if self._disagg_pp_termination_handler is not None:
@ -1885,6 +1904,8 @@ class PyExecutor:
self.resource_manager.free_resources(request)
else:
self.resource_manager.free_resources(request)
if self.gather_all_responses or self.dist.rank == 0:
self.result_wait_queues.pop(request.py_request_id, None)
def _is_request_in_transmission(self, request) -> bool:
"""Check if a request is currently in transmission state."""
@ -1960,6 +1981,13 @@ class PyExecutor:
self.responses[req_id].append(resp)
else:
self.responses.update({req_id: [resp]})
# (TODO: joyang) There are other types of responses that we still need to sort out.
if type(
resp
) == LlmResponse and req_id in self.result_wait_queues and self.result_wait_queues[
req_id] is not None:
self.result_wait_queues[req_id].put_response.remote(
resp.client_id, resp)
self.response_cv.notify_all()
@nvtx_range("_handle_first_token_response")
@ -2030,9 +2058,10 @@ class PyExecutor:
self.active_requests.clear()
self.active_requests.extend(new_active_requests)
# Requests should be terminated only after their responses are enqueued, to ensure the responses can be enqueued successfully.
self._enqueue_responses(new_responses)
for request in requests_to_terminate:
self._terminate_request(request)
self._enqueue_responses(new_responses)
return requests_to_terminate
@nvtx_range("_terminate_ctx_finished_requests")

View File

@ -12,7 +12,7 @@ import torch
import tensorrt_llm
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm._utils import get_sm_version, mpi_disabled
from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
ContextChunkingPolicy,
GuidedDecodingConfig)
@ -28,7 +28,7 @@ from tensorrt_llm.mapping import Mapping
from tensorrt_llm.quantization import QuantAlgo
from ..attention_backend.interface import AttentionRuntimeFeatures
from ..distributed import MPIDist
from ..distributed import MPIDist, TorchDist
from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
get_spec_resource_manager)
from ..utils import _get_allow_chain_drafter
@ -293,7 +293,10 @@ def create_py_executor(
pytorch_backend_config.disable_overlap_scheduler = True
mapping = _get_mapping(llm_args.parallel_config.to_mapping())
dist = MPIDist(mapping=mapping)
if mpi_disabled():
dist = TorchDist(mapping=mapping)
else:
dist = MPIDist(mapping=mapping)
cache_transceiver_config = None
if llm_args.cache_transceiver_config is not None:

View File

@ -9,6 +9,7 @@ import torch
import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
@ -610,7 +611,12 @@ class KVCacheManager(BaseResourceManager):
if mapping.world_size > 1:
# make sure all ranks use same value for maxTokens
max_tokens = mpi_comm().allreduce(max_tokens, op=MPI.MIN)
if mpi_disabled():
from tensorrt_llm._utils import torch_comm
max_tokens = torch_comm().allreduce(
max_tokens, op=torch.distributed.ReduceOp.MIN)
else:
max_tokens = mpi_comm().allreduce(max_tokens, op=MPI.MIN)
# get number of blocks
blocks_in_primary_pool = math.ceil(max_tokens / tokens_per_block)

View File

@ -13,7 +13,7 @@ import torch.nn.functional as F
from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
MakeDecodingBatchInputOutput
from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding
from tensorrt_llm._utils import mpi_disabled, nvtx_range, torch_dtype_to_binding
from tensorrt_llm.bindings import (CudaStream, DataType, ModelConfig,
WorldConfig, make_sampling_config)
from tensorrt_llm.bindings.executor import (DecodingConfig, DecodingMode,
@ -1781,8 +1781,16 @@ class TRTLLMSampler(Sampler):
2 if self.is_trt_overlap else 1)
self.micro_batch_idx = 0
self.world_config = WorldConfig.mpi(mapping.gpus_per_node,
mapping.tp_size, mapping.pp_size)
if mpi_disabled():
self.world_config = WorldConfig(mapping.tp_size,
mapping.pp_size,
mapping.cp_size,
rank=mapping.rank,
gpus_per_node=mapping.gpus_per_node)
else:
self.world_config = WorldConfig.mpi(mapping.gpus_per_node,
mapping.tp_size,
mapping.pp_size)
self.model_config = ModelConfig(vocab_size, num_hidden_layers,
num_hidden_layers, 0, num_heads,
hidden_size, self.model_datatype)

View File

@ -312,3 +312,11 @@ def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
# It's here so that unit tests can mock it and turn it off.
def _get_allow_chain_drafter() -> bool:
return True
def get_device_uuid(device_idx: int) -> str:
"""Get the UUID of a CUDA device using torch cuda api"""
property = torch.cuda.get_device_properties(device_idx)
uuid = "GPU-" + str(property.uuid)
return uuid
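
For example, a worker can report which physical GPU it was bound to (the UUID is stable across processes, unlike ordinal indices under CUDA_VISIBLE_DEVICES remapping):

import torch

# Illustrative: print the UUID of the device used by the current process.
print(get_device_uuid(torch.cuda.current_device()))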

View File

@ -19,6 +19,7 @@ import json
import linecache
import math
import os
import socket
import struct
import tempfile
import trace
@ -468,6 +469,12 @@ def dim_resolve_negative(dim, ndim):
return tuple(pos)
def get_free_port():
with socket.socket() as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here
OMPI_COMM_TYPE_HOST = 9
@ -490,11 +497,44 @@ def local_mpi_comm():
return local_comm
# Global TorchDist instance for Ray orchestrator
_torch_comm = None
def set_torch_comm(torch_comm_instance):
"""Set global TorchDist instance"""
global _torch_comm
_torch_comm = torch_comm_instance
def torch_comm():
"""Get global TorchDist instance"""
if _torch_comm is None:
raise RuntimeError(
"TorchDist not initialized. Call set_torch_comm() first.")
return _torch_comm
def mpi_disabled() -> bool:
"""True if TLLM_DISABLE_MPI is set to "1", False otherwise."""
return os.environ.get("TLLM_DISABLE_MPI") == "1"
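
This flag is the single switch between the MPI and Ray/torch.distributed paths; a small sketch of how callers branch on it (torch_comm() is only available once TorchDist has been constructed):

import os

os.environ["TLLM_DISABLE_MPI"] = "1"   # illustrative; RayExecutor sets this for its actors
assert mpi_disabled()

# Downstream code then uses ProcessGroup collectives instead of mpi4py, e.g.:
#   torch_comm().tp_allgather(obj)    # Ray path
#   mpi_comm().allgather(obj)         # MPI path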
def mpi_rank():
if mpi_disabled():
try:
return torch.distributed.get_rank()
except ValueError:
# Fallback: return 0 when MPI is absent (Ray / Slurm PMIx)
return 0
return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0
def global_mpi_rank():
if mpi_disabled():
# Fallback: return 0 when MPI is absent (Ray / Slurm PMIx)
return 0
return MPI.COMM_WORLD.Get_rank() if ENABLE_MULTI_DEVICE else 0
@ -702,6 +742,9 @@ def is_trace_enabled(env_var: str):
value = os.environ.get(env_var, "-1")
if value == "ALL":
return True
if value == "-1":
# early return w/o calling global_mpi_rank() for Ray path
return False
try:
return int(value) == global_mpi_rank()
except ValueError:
@ -1150,3 +1193,13 @@ def set_prometheus_multiproc_dir() -> object:
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
logger.info(
f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
TORCH_PYBIND11_ABI = None
def torch_pybind11_abi() -> str:
global TORCH_PYBIND11_ABI
if TORCH_PYBIND11_ABI is None:
TORCH_PYBIND11_ABI = f"{torch._C._PYBIND11_COMPILER_TYPE}{torch._C._PYBIND11_STDLIB}{torch._C._PYBIND11_BUILD_ABI}"
return TORCH_PYBIND11_ABI
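
The returned tag identifies the pybind11 ABI the installed torch was built with; TorchDist forwards it to the C++ process-group bindings so both sides agree on the ABI. A small sketch (the exact string varies per torch build):

from tensorrt_llm._utils import torch_pybind11_abi

print(torch_pybind11_abi())  # concatenation of compiler/stdlib/build-ABI markers
# TorchDist then calls, after torch.distributed is initialized:
#   init_pg(torch.distributed.group.WORLD, local_comm, torch_pybind11_abi())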

View File

@ -100,6 +100,15 @@ class BaseWorker(GenerationExecutor):
if global_mpi_size() > 1:
logger.set_rank(self.global_rank)
def _get_comm_ranks_device_id(self):
device_id = self.global_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)
# Make sure C++ executor would use same devices/ranks as py_executor
global_rank = global_mpi_rank()
comm_ranks = mpi_comm().allgather(global_rank)
device_ids = mpi_comm().allgather(device_id)
return comm_ranks, device_ids
def setup_engine(self):
"""
Setup the engine for the worker.
@ -108,21 +117,12 @@ class BaseWorker(GenerationExecutor):
if isinstance(self._engine, list):
self._engine = self._engine[self.rank]
def _get_comm_ranks_device_id():
device_id = self.global_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)
# Make sure C++ executor would use same devices/ranks as py_executor
global_rank = global_mpi_rank()
comm_ranks = mpi_comm().allgather(global_rank)
device_ids = mpi_comm().allgather(device_id)
return comm_ranks, device_ids
def _create_py_executor():
args = {}
assert hasattr(
self.llm_args, "backend"
), "llm_args should be with backend in _create_py_executor"
_ = _get_comm_ranks_device_id()
_ = self._get_comm_ranks_device_id()
if self.llm_args.backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
@ -168,7 +168,7 @@ class BaseWorker(GenerationExecutor):
executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=self._batched_logits_processor,
replicate=False)
comm_ranks, device_ids = _get_comm_ranks_device_id()
comm_ranks, device_ids = self._get_comm_ranks_device_id()
executor_config.parallel_config = tllm.ParallelConfig(
participant_ids=comm_ranks, device_ids=device_ids)
@ -308,7 +308,9 @@ class BaseWorker(GenerationExecutor):
model_config=self._runtime_model_config,
uids=[str(prompt_adapter_request.adapter_id)])
def _enqueue_request(self, request: GenerationRequest) -> int:
def _enqueue_request(self,
request: GenerationRequest,
result_wait_queue=None) -> int:
assert request.id is not None
py_lora_path = None
if self._lora_manager is not None and request.lora_request is not None:
@ -506,10 +508,20 @@ class BaseWorker(GenerationExecutor):
if request.query_token_ids is not None:
# pytorch star attention workflow
# a workaround to avoid public interface update
req_id = self.engine.enqueue_request(executor_request,
request.query_token_ids)
if self._is_pytorch_backend and result_wait_queue is not None:
req_id = self.engine.enqueue_request(
executor_request,
request.query_token_ids,
result_wait_queue=result_wait_queue)
else:
req_id = self.engine.enqueue_request(
executor_request, request.query_token_ids)
else:
req_id = self.engine.enqueue_request(executor_request)
if self._is_pytorch_backend and result_wait_queue is not None:
req_id = self.engine.enqueue_request(
executor_request, result_wait_queue=result_wait_queue)
else:
req_id = self.engine.enqueue_request(executor_request)
return req_id
except Exception as e:
raise RequestError(str(e)) from e
@ -557,7 +569,7 @@ class BaseWorker(GenerationExecutor):
def __exit__(self, exc_type, exc_value, traceback) -> bool:
self.shutdown()
return exc_type is None or exc_type == BaseWorker.WorkerExit
return exc_type is None or exc_type == self.WorkerExit
def __del__(self):
self.shutdown()

View File

@ -7,8 +7,8 @@ import traceback
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import (TYPE_CHECKING, AsyncIterable, Generator, List, Optional,
Union)
from typing import (TYPE_CHECKING, AsyncIterable, Dict, Generator, List,
Optional, Union)
import numpy as np
import torch
@ -21,7 +21,7 @@ from .._utils import mpi_world_size
from ..bindings import executor as tllm
from ..builder import Engine
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, TorchLlmArgs
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
need_spawn_mpi_workers)
@ -103,6 +103,9 @@ class GenerationExecutor(ABC):
self._iter_kv_events_result: IterationResult | None = None
self._iter_stats_result: IterationResult | None = None
def use_ray_queue(self) -> bool:
return False
@abstractmethod
def submit(self, request: GenerationRequest) -> GenerationResult:
pass
@ -355,6 +358,23 @@ class GenerationExecutor(ABC):
self._iter_kv_events_result.set_timeout(timeout)
return self._iter_kv_events_result
@staticmethod
def _create_ray_executor(
worker_kwargs: Dict,
model_world_size: int,
postproc_worker_config: PostprocWorkerConfig,
is_llm_executor: bool,
tp_size: int,
kv_connector_config: Optional[KvCacheConnectorConfig] = None):
from .ray_executor import RayExecutor
return RayExecutor(worker_kwargs,
model_world_size=model_world_size,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
tp_size=tp_size,
kv_connector_config=kv_connector_config)
@staticmethod
def create(
engine: Union[Path, Engine],
@ -372,6 +392,7 @@ class GenerationExecutor(ABC):
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
**args,
) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
# local imports to avoid cyclic importing
from .proxy import GenerationExecutorProxy
@ -406,6 +427,20 @@ class GenerationExecutor(ABC):
if lora_config:
worker_kwargs["lora_config"] = lora_config
orchestrator_type = None if not isinstance(
llm_args, TorchLlmArgs) else llm_args.orchestrator_type
if orchestrator_type == "ray":
return GenerationExecutor._create_ray_executor(
worker_kwargs,
model_world_size,
postproc_worker_config,
is_llm_executor=is_llm_executor,
tp_size=args.get("tp_size", 1),
kv_connector_config=kv_connector_config)
elif orchestrator_type is not None:
raise ValueError(
f"Unsupported orchestrator_type: {orchestrator_type}")
# The case where the Python main process is launched by mpirun
mpirun_launch = external_mpi_comm_available(model_world_size)
# The case where the Python main process utilizes mpi4py to spawn MPI workers

View File

@ -0,0 +1,305 @@
import os
from typing import Any, Dict, List, Optional, Tuple
try:
import ray
except ModuleNotFoundError as e:
e.msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator"""
raise
from ray.util.placement_group import (PlacementGroup,
PlacementGroupSchedulingStrategy,
get_current_placement_group,
placement_group)
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.logger import logger
from .._utils import nvtx_range_debug
from ..llmapi.llm_args import KvCacheConnectorConfig
from .executor import GenerationExecutor
from .postproc_worker import PostprocWorkerConfig
from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
from .request import GenerationRequest
from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
__all__ = [
"RayExecutor",
]
class RayExecutor(GenerationExecutor):
def __init__(self,
worker_kwargs: Dict,
model_world_size: int,
postproc_worker_config: PostprocWorkerConfig,
is_llm_executor: bool,
tp_size=1,
kv_connector_config: Optional[KvCacheConnectorConfig] = None):
os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'
os.environ["RAY_DEDUP_LOGS"] = "0" # for debug
super().__init__(model_world_size, postproc_worker_config,
is_llm_executor)
self.has_start_local_cluser = False
runtime_env = {
"env_vars": {
"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"
}
}
ray_init_args = {
"include_dashboard": False,
"namespace": "trtllm",
"ignore_reinit_error": True,
"runtime_env": runtime_env
}
if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1":
try:
ray.init(address="auto", **ray_init_args)
logger.info(f"Attached to an existing Ray cluster.")
except ConnectionError:
logger.info(f"Ray cluster not found, starting a new one.")
if not ray.is_initialized():
ray.init(**ray_init_args)
self.has_start_local_cluser = True
else:
ray.init(address="local", **ray_init_args)
self.has_start_local_cluser = True
self.world_size = model_world_size
self.tp_size = tp_size
self.master_address = ray.util.get_node_ip_address()
self.master_port = get_free_port()
self.response_queue = RayAsyncQueue.options(runtime_env={
"env_vars": {
"TLLM_DISABLE_MPI": "1"
}
}).remote()
self.response_sync_queue = RaySyncQueue.options(runtime_env={
"env_vars": {
"TLLM_DISABLE_MPI": "1"
}
}).remote()
self.async_response_queue_weakref = self.create_actor_weak_ref(
self.response_queue)
self.sync_response_queue_weakref = self.create_actor_weak_ref(
self.response_sync_queue)
self.response_queue.warmup.remote()
self.response_sync_queue.warmup.remote()
worker_kwargs = dict(**worker_kwargs,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)
self.create_workers(RayGPUWorker, worker_kwargs)
@staticmethod
def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
state, _, _ = actor_handle._serialization_helper()
return ray.actor.ActorHandle._deserialization_helper(state,
weak_ref=True)
def use_ray_queue(self) -> bool:
return True
def create_workers(self, worker_cls, worker_kwargs):
# When set to a fraction, this allows Ray to schedule
# multiple actors on a single GPU for colocation use cases.
num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))
logger.debug(f"{num_gpus=} for each worker.")
runtime_env = ray.runtime_env.RuntimeEnv()
runtime_env["env_vars"] = os.environ.copy()
runtime_env["env_vars"].update({
"TLLM_DISABLE_MPI": "1",
"MASTER_ADDR": self.master_address, # head-IP for NCCL/Gloo
"MASTER_PORT": str(self.master_port)
})
self.placement_group, self.bundle_indices = self._get_placement_group(
tp_size=self.tp_size)
self.workers = [
RayWorkerWrapper.options(
num_gpus=num_gpus,
runtime_env=runtime_env, # per-actor env
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=self.placement_group,
placement_group_bundle_index=self.bundle_indices[rank],
)).remote(worker_cls, worker_kwargs, self.world_size, rank)
for rank in range(self.world_size)
]
ray.get([worker.__ray_ready__.remote() for worker in self.workers])
def call_all_ray_workers(self, func: str, leader_only: bool,
async_call: bool, *args, **kwargs):
workers = (self.workers[0], ) if leader_only else self.workers
if async_call:
return [
getattr(worker, func).remote(*args, **kwargs)
for worker in workers
]
else:
return ray.get([
getattr(worker, func).remote(*args, **kwargs)
for worker in workers
])
def collective_rpc(self,
method: str,
args: tuple = (),
kwargs: Optional[dict] = None,
non_block: bool = False,
unique_reply_rank: Optional[int] = None) -> list[Any]:
workers = (self.workers[unique_reply_rank],
) if unique_reply_rank is not None else self.workers
kwargs = kwargs or {}
refs = []
for w in workers:
try:
refs.append(getattr(w, method).remote(*args, **kwargs))
except AttributeError:
# Here worker is the RayWorkerWrapper.
# For extended worker methods, we need to use call_worker_method since
# Ray actor doesn't work with __getattr__ delegation.
refs.append(w.call_worker_method.remote(method, *args,
**kwargs))
return refs if non_block else ray.get(refs)
def submit(self, request: GenerationRequest) -> GenerationResult:
"""
Low-level API to the executor. Returns a "future" GenerationResult
that can be waited on.
Forwards the request to the workers through the request queue.
"""
request.set_id(self._get_next_client_id())
logprob_params = self._get_logprob_params(request)
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self,
disaggregated_params=request.disaggregated_params,
logprob_params=logprob_params)
with nvtx_range_debug("request_queue.put"):
self.call_all_ray_workers("enqueue_request",
leader_only=True,
request=request,
async_call=False,
result_wait_queue=result.queue)
return result
def report_device_ids(self) -> list[str]:
gpu_ids = self.call_all_ray_workers("report_device_id",
leader_only=False,
async_call=False)
return sorted(gpu_ids)
def abort_request(self, request_id: int) -> None:
self.call_all_ray_workers("abort_request",
leader_only=True,
async_call=False,
request_id=request_id)
def shutdown(self):
# Release actors
self.response_queue = None
self.response_sync_queue = None
self.async_response_queue_weakref = None
self.sync_response_queue_weakref = None
self.workers = None
if hasattr(self,
"placement_group") and self.placement_group is not None:
ray.util.remove_placement_group(self.placement_group)
self.placement_group = None
self.bundle_indices = None
if self.has_start_local_cluser:
logger.debug("Shutting down Ray cluster")
ray.shutdown()
@property
def enable_postprocess_parallel(self) -> bool:
ret = super().enable_postprocess_parallel
assert ret == False, "Postprocess parallel is not supported in RayExecutor"
return ret
def _get_placement_group(self,
tp_size: int) -> Tuple[PlacementGroup, List[int]]:
"""
Either use the existing placement group from the driver script (e.g., in the case of RL framework integration),
or create a default PACK placement group where each bundle has tp_size GPUs.
- When tp_size <= GPUs per node, keep one TP group per node.
- When tp_size > GPUs per node, allow a TP group to span nodes.
- Rank 0 must be placed on the driver node.
"""
bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None)
if bundle_indices:
pg = get_current_placement_group()
if pg is not None:
bundle_indices = list(map(int, bundle_indices.split(",")))
assert len(bundle_indices) == self.world_size, (
f"Need {self.world_size} bundle indices for world_size, got {bundle_indices=}"
)
assert len(set(bundle_indices)) == len(bundle_indices), \
f"TRTLLM_RAY_BUNDLE_INDICES cannot have duplicate values, but got {bundle_indices=}."
assert max(bundle_indices) < len(pg.bundle_specs), \
f"{bundle_indices=} out of range for PG with {len(pg.bundle_specs)} bundles"
logger.info(
f"Found existing placement group {pg.bundle_specs=}. {bundle_indices=}"
)
# TODO: need to pin the TP group onto the same node for the RL FW integration case
return pg, bundle_indices
else:
logger.warning(
f"Ignoring TRTLLM_RAY_BUNDLE_INDICES={bundle_indices} because no global placement group is found."
)
if self.world_size % tp_size:
raise ValueError("world_size must be a multiple of tp_size")
head_tag = f"node:{self.master_address}"
nodes = ray.nodes()
gpus_per_node = int(nodes[0]["Resources"].get(
"GPU", 0)) # assume symmetric across nodes
bundle_cpu = bundle_gpu = min(tp_size, gpus_per_node)
bundles, bundle_indices = [], []
current = 0
for rank in range(self.world_size):
if current == 0:
bundle = {"GPU": bundle_gpu, "CPU": bundle_cpu}
if len(bundles) == 0:
bundle[head_tag] = 0.01 # to force placement on head node
bundles.append(bundle)
bundle_indices.append(len(bundles) - 1)
current = (current + 1) % bundle_gpu
strategy = "PACK"
logger.debug(
f"[Strategy={strategy}] Bundles: {bundles} for tp_size: {tp_size} and world_size: {self.world_size}"
)
pg = placement_group(bundles, strategy=strategy)
return pg, bundle_indices
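For the RL-framework integration path described in the docstring, here is a hedged sketch of how a driver could pre-create the placement group and hand specific bundles to TRT-LLM via TRTLLM_RAY_BUNDLE_INDICES. The bundle shape, model name, and TP size are assumptions; the LLM must be constructed inside the placement group so that get_current_placement_group() can find it.
# Hedged sketch (assumptions: 4 one-GPU bundles, TinyLlama, tp_size=4).
import os
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()
pg = placement_group([{"GPU": 1, "CPU": 2}] * 4, strategy="PACK")
ray.get(pg.ready())


@ray.remote(num_cpus=1)
def launch_llm():
    # Runs inside the placement group, so _get_placement_group() reuses
    # bundles 0-3 instead of creating its own PACK group.
    os.environ["TRTLLM_RAY_BUNDLE_INDICES"] = "0,1,2,3"
    from tensorrt_llm import LLM
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              orchestrator_type="ray",
              tensor_parallel_size=4)
    return llm.generate("Hello").outputs[0].text


strategy = PlacementGroupSchedulingStrategy(
    placement_group=pg, placement_group_capture_child_tasks=True)
print(ray.get(launch_llm.options(scheduling_strategy=strategy).remote()))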

View File

@ -0,0 +1,202 @@
import os
from pathlib import Path
from queue import Queue
from typing import Optional, Union
import ray
import torch
from ..bindings import executor as tllm
from ..builder import Engine
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig
from ..llmapi.tokenizer import TokenizerBase
from ..lora_helper import LoraConfig
from ..sampling_params import BatchedLogitsProcessor
from .base_worker import BaseWorker
from .postproc_worker import PostprocWorkerConfig
from .request import GenerationRequest
from .result import GenerationResult
__all__ = [
"RayGPUWorker",
"RayWorkerWrapper",
]
@ray.remote
class RayWorkerWrapper:
def __init__(self, worker_cls, worker_kwargs, world_size, rank):
self.master_address = os.environ["MASTER_ADDR"]
self.master_port = os.environ["MASTER_PORT"]
# Ray can't pickle TensorRT logger
global logger
from tensorrt_llm.logger import logger
# Expect to see the global device count when RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1,
# unless CUDA_VISIBLE_DEVICES is set.
logger.debug(
f"CUDA device count visible to Ray: {torch.cuda.device_count()}")
# Physical gpu id
self.gpu = int(ray.get_gpu_ids()[0])
local_gpu = self.physical_to_local_id(self.gpu)
torch.distributed.init_process_group(
backend="cuda:nccl,cpu:gloo",
init_method=f"tcp://{self.master_address}:{self.master_port}",
world_size=world_size,
rank=rank)
logger.info(
f"[Rank {rank}] Finished PG init. Global GPU ID: {self.gpu}, local GPU ID: {local_gpu}"
)
torch.cuda.set_device(local_gpu)
self.worker = worker_cls(device_id=local_gpu, **worker_kwargs)
def submit(self, request: GenerationRequest) -> GenerationResult:
return self.worker.submit(request)
def enqueue_request(self,
request: GenerationRequest,
result_wait_queue: Queue | None = None) -> int:
return self.worker.enqueue_request(request, result_wait_queue)
def abort_request(self, request_id: int) -> None:
self.worker.abort_request(request_id)
def report_device_id(self) -> str:
from tensorrt_llm._torch.utils import get_device_uuid
local_id = self.physical_to_local_id(self.gpu)
return get_device_uuid(local_id)
@staticmethod
def physical_to_local_id(phys_id: int) -> int:
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
if not visible_devices:
return phys_id
id_mapping = list(map(int, visible_devices.split(",")))
return id_mapping.index(phys_id)
def call_worker_method(self, method_name: str, *args, **kwargs):
"""Generic method to call any method on the underlying worker."""
if hasattr(self.worker, method_name):
method = getattr(self.worker, method_name)
if callable(method):
return method(*args, **kwargs)
else:
raise AttributeError(
f"'{method_name}' is not callable on the underlying worker")
else:
raise AttributeError(
f"Underlying worker has no method '{method_name}'")
def __repr__(self) -> str:
"""Customizes the actor's prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
Refer to https://github.com/NVIDIA-NeMo/RL/blob/faad02113c3c502437ccb339cb848796334aedd9/nemo_rl/models/policy/dtensor_policy_worker_v2.py#L95
"""
if torch.distributed.is_initialized():
return f"{self.__class__.__qualname__}[rank={torch.distributed.get_rank()}]"
else:
return f"{self.__class__.__qualname__}"
class RayGPUWorker(BaseWorker):
def __init__(
self,
device_id: int,
engine: Union[Path, Engine],
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
lora_config: Optional[LoraConfig] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> None:
global logger
from tensorrt_llm.logger import logger
super().__init__(
engine=engine,
executor_config=executor_config,
batched_logits_processor=batched_logits_processor,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
lora_config=lora_config,
kv_connector_config=kv_connector_config,
hf_model_dir=hf_model_dir,
tokenizer=tokenizer,
llm_args=llm_args,
)
if not self._is_pytorch_backend:
raise ValueError("Ray GPU worker only supports the PyTorch backend.")
self.device_id = device_id
# Override rank attributes using torch
self.global_rank = torch.distributed.get_rank()
if self.global_rank > 1:
logger.set_rank(self.global_rank)
self.setup_engine()
def _get_comm_ranks_device_id(self):
# Make sure C++ executor would use same devices/ranks as py_executor
global_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
comm_ranks = [None] * world_size
device_ids = [None] * world_size
torch.distributed.all_gather_object(comm_ranks, global_rank)
torch.distributed.all_gather_object(device_ids, self.device_id)
return comm_ranks, device_ids
def enqueue_request(self,
request: GenerationRequest,
result_wait_queue: Queue | None = None) -> int:
return self._enqueue_request(request, result_wait_queue)
def submit(self, request: GenerationRequest):
raise NotImplementedError(
"Ray GPU worker does not support submit() yet.")
def shutdown(self):
if self.doing_shutdown:
return
else:
self.doing_shutdown = True
logger.debug(f'Worker {self.rank} shutting down...')
if self.engine is not None:
self.engine.shutdown()
self.engine = None
assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
if (self.llm_args.backend == "pytorch"
and hasattr(self, "checkpoint_loader")
and self.checkpoint_loader is not None):
self.checkpoint_loader.cleanup()
self.checkpoint_loader = None
# Check if there are any errors from the threads before shutdown.
self._handle_background_error()
logger.debug(f"Worker {self.rank} shutdown done.")
def __enter__(self):
return self
def __del__(self):
self.shutdown()

View File

@ -1,5 +1,6 @@
import asyncio
import json
import threading
import weakref
from dataclasses import dataclass, field
from queue import Empty, Queue
@ -10,7 +11,12 @@ from weakref import WeakMethod
import torch
import torch.nn.functional as F
from .._utils import nvtx_range_debug
try:
import ray
except ModuleNotFoundError:
from tensorrt_llm import ray_stub as ray
from .._utils import mpi_disabled, nvtx_range_debug
from ..bindings import executor as tllm
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.tracer import global_tracer
@ -146,12 +152,104 @@ class CompletionOutput:
return self.logprobs[self._last_logprobs_len:]
def warmup_tensorrt_llm():
import tensorrt_llm
print("Warmup by importing tensorrt_llm with version",
tensorrt_llm.version.__version__)
@ray.remote(max_concurrency=1000000, num_cpus=2)
class RayAsyncQueue:
"""Ray actor for async response handling."""
def __init__(self):
self.data = {}
self.event_map = {}
self.warmup_done = False
def register(self, key: int):
assert key not in self.event_map, f"Key {key} already registered"
self.event_map[key] = asyncio.Event()
def unregister(self, key: int):
if key in self.event_map:
del self.event_map[key]
if key in self.data:
del self.data[key]
def warmup(self):
if self.warmup_done:
return
warmup_tensorrt_llm()
self.warmup_done = True
def put_response(self, key: int, item: Any):
assert key in self.event_map, f"Key {key} not registered"
self.data[key] = item
self.event_map[key].set()
async def get_async(self, key: int):
assert key in self.event_map, f"Key {key} not registered"
await self.event_map[key].wait()
self.event_map[key].clear()
ret = self.data[key]
del self.data[key]
return ret
SYNC_QUEUE_MAX_CONCURRENCY = 2
@ray.remote(max_concurrency=SYNC_QUEUE_MAX_CONCURRENCY,
num_cpus=SYNC_QUEUE_MAX_CONCURRENCY)
class RaySyncQueue:
"""Ray actor for sync response handling."""
def __init__(self):
self.data = {}
self.event_map = {}
self.semaphore = threading.Semaphore(SYNC_QUEUE_MAX_CONCURRENCY - 1)
self.warmup_done = False
def register(self, key: int):
assert key not in self.event_map, f"Key {key} already registered"
self.event_map[key] = threading.Event()
def unregister(self, key: int):
if key in self.event_map:
del self.event_map[key]
if key in self.data:
del self.data[key]
def warmup(self):
if self.warmup_done:
return
warmup_tensorrt_llm()
self.warmup_done = True
def put_response(self, key: int, item: Any):
self.data[key] = item
self.event_map[key].set()
def get(self, key: int):
with self.semaphore:
self.event_map[key].wait()
self.event_map[key].clear()
ret = self.data[key]
del self.data[key]
return ret
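A hedged, self-contained sketch of the register/put_response/get handshake the two queue actors implement; the import path is an assumption based on the file this hunk modifies, and RayAsyncQueue follows the same protocol with get_async() awaited instead.
# Hedged sketch: the request id stands in for the client id used by the executor.
import ray
from tensorrt_llm.executor.result import RaySyncQueue  # assumed module path

ray.init()
queue = RaySyncQueue.remote()
ray.get(queue.warmup.remote())                         # one-time actor warmup

request_id = 42
ray.get(queue.register.remote(request_id))             # result side
queue.put_response.remote(request_id, {"text": "hi"})  # worker side
print(ray.get(queue.get.remote(request_id)))           # blocks until the event is set
ray.get(queue.unregister.remote(request_id))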
class GenerationResultBase:
''' This holds the core logic of the GenerationResult class. '''
def __init__(self,
id: int,
sampling_params: SamplingParams,
ray_queue: Optional[RayAsyncQueue] = None,
background_error_handler: Optional[Callable] = None,
postproc_params: "Optional[PostprocParams]" = None):
self.id = id
@ -165,12 +263,22 @@ class GenerationResultBase:
self._done = False
self.metrics_dict = {}
if has_event_loop():
self.aqueue = AsyncQueue()
self.queue = self.aqueue.sync_q
if ray_queue is not None:
if has_event_loop():
self.aqueue = ray_queue
self.queue = self.aqueue
else:
self.queue = ray_queue
self.aqueue = None
ray.get(self.queue.register.remote(id))
else:
self.queue = Queue()
self.aqueue = None
if has_event_loop():
self.aqueue = AsyncQueue()
self.queue = self.aqueue.sync_q
else:
self.queue = Queue()
self.aqueue = None
# In Sampling mode, the Executor runtime will return best_of sequences
# in total, which the LLM API will select the n-best sequences among
@ -419,6 +527,12 @@ class GenerationResultBase:
else:
raise ValueError(f"Unknown response type: {response}")
if self._done and mpi_disabled():
assert hasattr(
self.queue, "unregister"
), "Ray path should be activated for unregistering the Ray queue."
self.queue.unregister.remote(self.id)
def record_stats(self,
output: CompletionOutput,
stats: Optional[dict[str, float]] = None) -> None:
@ -541,9 +655,15 @@ class GenerationResult(GenerationResultBase):
disaggregated_params: Optional[DisaggregatedParams] = None,
logprob_params: Optional[LogprobParams] = None,
) -> None:
use_async_queue = has_event_loop()
shared_queue = None
if executor and executor.use_ray_queue():
shared_queue = executor.async_response_queue_weakref if use_async_queue else executor.sync_response_queue_weakref
super().__init__(
generation_request.id,
generation_request.sampling_params,
shared_queue,
background_error_handler,
postproc_params=generation_request.postproc_params,
)
@ -597,13 +717,25 @@ class GenerationResult(GenerationResultBase):
if hasattr(self, "_logprob_params"):
del self._logprob_params
def _handle_ray_response(self, response: Any):
return response
def _result_step(self, timeout: Optional[float] = None):
response = self.queue.get(timeout=timeout)
if mpi_disabled():
response = ray.get(self.queue.get.remote(self.request_id))
response = self._handle_ray_response(response)
else:
response = self.queue.get()
self._handle_response(response)
async def _aresult_step(self):
assert self.aqueue is not None, "The asyncio event loop was not present during initialization, so async operations are not available."
response = await self.aqueue.get()
if mpi_disabled():
response = await self.aqueue.get_async.remote(self.request_id)
response = self._handle_ray_response(response)
else:
response = await self.aqueue.get()
global_tracer().log_instant("result_step.get")
self._handle_response(response)

View File

@ -41,9 +41,6 @@ __all__ = [
class GenerationExecutorWorker(BaseWorker):
class WorkerExit(GeneratorExit):
pass
def __init__(
self,
engine: Union[Path, Engine],
@ -248,10 +245,6 @@ class GenerationExecutorWorker(BaseWorker):
if isinstance(self.engine, PyExecutor):
self.engine.wait_shutdown()
def __exit__(self, exc_type, exc_value, traceback) -> bool:
self.shutdown()
return exc_type is None or exc_type == GenerationExecutorWorker.WorkerExit
@print_traceback_on_error
def worker_main(

View File

@ -13,6 +13,7 @@ import transformers
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.inputs.data import TextPrompt
from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams
from tensorrt_llm.inputs.registry import DefaultInputProcessor
@ -124,6 +125,7 @@ class BaseLLM:
**kwargs: Any) -> None:
self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
self._orchestrator_type = kwargs.get("orchestrator_type", None)
self._llm_id = None
log_level = logger.level
@ -134,6 +136,12 @@ class BaseLLM:
if backend == "pytorch":
logger.info("Using LLM with PyTorch backend")
llm_args_cls = TorchLlmArgs
if self._orchestrator_type == "ray" or mpi_disabled():
self._orchestrator_type = "ray"
os.environ["TLLM_DISABLE_MPI"] = "1"
# Propagate to args construction
kwargs["orchestrator_type"] = "ray"
elif backend == '_autodeploy':
logger.info("Using LLM with AutoDeploy backend")
from .._torch.auto_deploy.llm_args import \
@ -984,6 +992,34 @@ class _TorchLLM(BaseLLM):
backend=backend,
**kwargs)
@set_api_status("prototype")
def _collective_rpc(self,
method: str,
args: tuple[Any, ...] = (),
kwargs: Optional[dict] = None,
non_block: bool = False,
unique_reply_rank: Optional[int] = None) -> list[Any]:
"""
Execute an RPC call on all GPU workers. Currently, this is only supported for RayExecutor.
Args:
method (str): The name of the worker method to execute.
args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
non_block (bool): If True, return immediately with Ray object refs instead of blocking for the results. Defaults to False.
unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply. Defaults to None.
Returns:
list[Any]: A list of results from each worker.
"""
if hasattr(self._executor, 'collective_rpc'):
return self._executor.collective_rpc(method, args, kwargs,
non_block, unique_reply_rank)
else:
raise ValueError(
f"Executor type {type(self._executor)} does not support collective RPC."
)
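A hedged sketch of this prototype RPC surface; report_device_id is the wrapper method added earlier in this change, and the model and TP settings are assumptions.
# Hedged usage sketch: one result per Ray worker rank.
import ray
from tensorrt_llm import LLM

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
          orchestrator_type="ray",
          tensor_parallel_size=2)

device_uuids = llm._collective_rpc("report_device_id")          # blocking
refs = llm._collective_rpc("report_device_id", non_block=True)  # object refs
print(device_uuids, ray.get(refs))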
def _build_model(self):
super()._build_model()
assert self._engine_dir is None

View File

@ -2437,6 +2437,13 @@ class TorchLlmArgs(BaseLlmArgs):
status="prototype",
)
orchestrator_type: Optional[Literal["ray"]] = Field(
default=None,
description=
"The orchestrator type to use. Options: 'ray'. Defaults to None, which uses MPI.",
status="prototype",
)
# PrivateVars
_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
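A hedged sketch of the two ways the field above ends up as "ray", following the BaseLLM logic earlier in this change: either pass it explicitly, or disable MPI via the environment before constructing the LLM. The model path is an assumption.
import os
from tensorrt_llm import LLM

llm_explicit = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                   orchestrator_type="ray")            # explicit opt-in

os.environ["TLLM_DISABLE_MPI"] = "1"                   # mpi_disabled() -> True
llm_implicit = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # resolves to "ray"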

View File

@ -17,6 +17,9 @@ from typing import List
import torch
from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl
from tensorrt_llm._utils import mpi_disabled
class CpType(IntEnum):
# CP type for ulysses parallelism
@ -29,101 +32,12 @@ class CpType(IntEnum):
HELIX = 3
class Mapping(object):
'''
A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2
class MappingBase:
"""Base class for distributed mapping configurations"""
2 tp groups:
- [0, 1, 2, 3]
- [4, 5, 6, 7]
4 pp groups:
- [0, 4]
- [1, 5]
- [2, 6]
- [3, 7]
A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1
2 tp groups:
- [0, 1, 2, 3]
- [4, 5, 6, 7]
4 cp groups:
- [0, 4]
- [1, 5]
- [2, 6]
- [3, 7]
A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4
4 moe_tp groups:
- [0, 4]
- [1, 5]
- [2, 6]
- [3, 7]
2 moe_ep groups:
- [0, 1, 2, 3]
- [4, 5, 6, 7]
2 nodes with 16 GPUs, moe_tp_size = 2, moe_ep_size = 4, pp_size = 2
8 moe_tp groups:
- [0 4]
- [1 5]
- [2 6]
- [3 7]
- [8 12]
- [9 13]
- [10 14]
- [11 15]
4 moe_ep groups:
- [0, 1, 2, 3]
- [4, 5, 6, 7]
- [8, 9, 10, 11]
- [12, 13, 14, 15]
8 pp groups:
- [0 8]
- [1 9]
- [2 10]
- [3 11]
- [4 12]
- [5 13]
- [6 14]
- [7 15]
2 nodes with 8 GPUs, tp_size 2, pp_size 2, cp_size 2
4 tp groups:
- [0, 1]
- [2, 3]
- [4, 5]
- [6, 7]
4 pp groups:
- [0, 4]
- [1, 5]
- [2, 6]
- [3, 7]
4 cp groups:
- [0, 2]
- [1, 3]
- [4, 6]
- [5, 7]
'''
tp_rank: int
pp_rank: int
cp_rank: int
def __init__(
self,
@ -196,11 +110,11 @@ class Mapping(object):
)
moe_tp_ep_size = moe_tp_size * moe_ep_size
moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size
if moe_tp_cluster_ep_size != moe_world_size:
self.moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size
if self.moe_tp_cluster_ep_size != moe_world_size:
raise ValueError(
"moe_tp_size * moe_ep_size * moe_cluster_size must equal to moe_world_size, "
f"but got {moe_tp_cluster_ep_size} != {moe_world_size}")
f"but got {self.moe_tp_cluster_ep_size} != {moe_world_size}")
attn_tp_cp_size = attn_tp_size * attn_cp_size
if attn_tp_cp_size != tp_size * cp_size:
@ -213,6 +127,9 @@ class Mapping(object):
raise NotImplementedError(
f"CP {cp_type} doesn't support MoE tp/ep yet")
if moe_cluster_size > 1:
assert moe_ep_size == 1
self.tp_size = tp_size
self.cp_size = cp_size
self.cp_config = cp_config if cp_config is not None else {}
@ -230,6 +147,7 @@ class Mapping(object):
self.enable_lm_head_tp_in_adp = enable_lm_head_tp_in_adp
self.rank = rank
self.gpus_per_node = gpus_per_node
self.pp_groups = []
self.cp_groups = []
self.tp_groups = []
@ -237,60 +155,8 @@ class Mapping(object):
self.moe_tp_groups = []
self.moe_ep_groups = []
if moe_cluster_size > 1:
assert moe_ep_size == 1
# init pp group
for i in range(tp_size * cp_size):
ranks = range(i, world_size, tp_size * cp_size)
self.pp_groups.append(list(ranks))
# init cp group
for i in range(pp_size):
for j in range(tp_size):
ranks = range(i * tp_size * cp_size + j,
(i + 1) * tp_size * cp_size + j, tp_size)
self.cp_groups.append(list(ranks))
# init tp group
for i in range(pp_size):
for j in range(cp_size):
ranks = range(i * tp_size * cp_size + j * tp_size,
i * tp_size * cp_size + (j + 1) * tp_size)
self.tp_groups.append(list(ranks))
# init moe tp group
for i in range(pp_size):
for j in range(moe_cluster_size * moe_ep_size):
ranks = range(i * moe_tp_cluster_ep_size + j,
(i + 1) * moe_tp_cluster_ep_size,
moe_cluster_size * moe_ep_size)
self.moe_tp_groups.append(list(ranks))
# init moe cluster group
for i in range(pp_size):
for j in range(moe_tp_size):
ranks = range(
i * moe_tp_cluster_ep_size +
j * moe_cluster_size * moe_ep_size,
i * moe_tp_cluster_ep_size +
(j + 1) * moe_cluster_size * moe_ep_size)
self.moe_cluster_groups.append(list(ranks))
# init moe ep group
for i in range(pp_size):
for j in range(moe_tp_size):
for k in range(moe_cluster_size):
ranks = range(
i * moe_tp_cluster_ep_size +
j * moe_cluster_size * moe_ep_size + k * moe_ep_size,
i * moe_tp_cluster_ep_size +
j * moe_cluster_size * moe_ep_size +
(k + 1) * moe_ep_size)
self.moe_ep_groups.append(list(ranks))
def __eq__(self, other):
if not isinstance(other, Mapping):
if not isinstance(other, MappingBase):
return NotImplemented
return (self.world_size == other.world_size and self.rank == other.rank
@ -338,20 +204,6 @@ class Mapping(object):
)
self._rank = rank
@property
def tp_rank(self):
return 0 if self.auto_parallel else self.rank % self.tp_size
@property
def pp_rank(self):
return 0 if self.auto_parallel else self.rank // (self.tp_size *
self.cp_size)
@property
def cp_rank(self):
return 0 if self.auto_parallel else self.rank % (
self.tp_size * self.cp_size) // self.tp_size
@property
def moe_tp_rank(self):
return self.tp_rank // (self.moe_ep_size * self.moe_cluster_size)
@ -364,37 +216,11 @@ class Mapping(object):
def moe_ep_rank(self):
return self.tp_rank % self.moe_ep_size
@property
def tp_group(self):
return self.tp_groups[self.pp_rank * self.cp_size + self.cp_rank]
@property
def pp_group(self):
return self.pp_groups[self.cp_rank * self.tp_size + self.tp_rank]
@property
def cp_group(self):
return self.cp_groups[self.pp_rank * self.tp_size + self.tp_rank]
@property
def moe_tp_group(self):
return self.moe_tp_groups[self.pp_rank * self.moe_cluster_size *
self.moe_ep_size +
self.moe_cluster_rank * self.moe_ep_size +
self.moe_ep_rank]
@property
def moe_cluster_group(self):
return self.moe_cluster_groups[self.pp_rank * self.moe_tp_size +
self.moe_tp_rank]
@property
def moe_ep_group(self):
return self.moe_ep_groups[self.pp_rank * self.moe_tp_size *
self.moe_cluster_size +
self.moe_tp_rank * self.moe_cluster_size +
self.moe_cluster_rank]
@property
def node_rank(self):
return self.rank // self.gpus_per_node
@ -517,3 +343,280 @@ class Mapping(object):
'enable_attention_dp': self.enable_attention_dp,
'enable_lm_head_tp_in_adp': self.enable_lm_head_tp_in_adp,
}
class Mapping(MappingBase):
"""
A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2
2 tp groups:
- [0, 1, 2, 3]
- [4, 5, 6, 7]
4 pp groups:
- [0, 4]
- [1, 5]
- [2, 6]
- [3, 7]
A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1
2 tp groups:
- [0, 1, 2, 3]
- [4, 5, 6, 7]
4 cp groups:
- [0, 4]
- [1, 5]
- [2, 6]
- [3, 7]
A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4
4 moe_tp groups:
- [0, 4]
- [1, 5]
- [2, 6]
- [3, 7]
2 moe_ep groups:
- [0, 1, 2, 3]
- [4, 5, 6, 7]
2 nodes with 16 GPUs, moe_tp_size = 2, moe_ep_size = 4, pp_size = 2
8 moe_tp groups:
- [0 4]
- [1 5]
- [2 6]
- [3 7]
- [8 12]
- [9 13]
- [10 14]
- [11 15]
4 moe_ep groups:
- [0, 1, 2, 3]
- [4, 5, 6, 7]
- [8, 9, 10, 11]
- [12, 13, 14, 15]
8 pp groups:
- [0 8]
- [1 9]
- [2 10]
- [3 11]
- [4 12]
- [5 13]
- [6 14]
- [7 15]
2 nodes with 8 GPUs, tp_size 2, pp_size 2, cp_size 2
4 tp groups:
- [0, 1]
- [2, 3]
- [4, 5]
- [6, 7]
4 pp groups:
- [0, 4]
- [1, 5]
- [2, 6]
- [3, 7]
4 cp groups:
- [0, 2]
- [1, 3]
- [4, 6]
- [5, 7]
"""
def __new__(cls, *args, **kwargs):
if mpi_disabled():
return super().__new__(DeviceMeshTopology)
else:
return super().__new__(MpiTopology)
# Intentionally repeated for type hints
def __init__(
self,
world_size=1,
rank=0,
gpus_per_node=8,
*,
cp_size=1,
cp_config=None,
tp_size=1,
pp_size=1,
moe_cluster_size=-1, # -1 means no moe
moe_tp_size=-1, # -1 means no moe
moe_ep_size=-1, # -1 means no moe
attn_tp_size=-1,
attn_cp_size=-1,
auto_parallel=False,
enable_attention_dp=False,
enable_lm_head_tp_in_adp=False):
super().__init__(world_size=world_size,
rank=rank,
gpus_per_node=gpus_per_node,
cp_size=cp_size,
cp_config=cp_config,
tp_size=tp_size,
pp_size=pp_size,
moe_cluster_size=moe_cluster_size,
moe_tp_size=moe_tp_size,
moe_ep_size=moe_ep_size,
attn_tp_size=attn_tp_size,
attn_cp_size=attn_cp_size,
auto_parallel=auto_parallel,
enable_attention_dp=enable_attention_dp,
enable_lm_head_tp_in_adp=enable_lm_head_tp_in_adp)
# DeviceMesh specific methods
@property
def tp_group_pg(self):
raise NotImplementedError("tp_group_pg is not implemented.")
@property
def pp_group_pg(self):
raise NotImplementedError("pp_group_pg is not implemented.")
@property
def cp_group_pg(self):
raise NotImplementedError("cp_group_pg is not implemented.")
@property
def moe_tp_group_pg(self):
raise NotImplementedError("moe_tp_group_pg is not implemented.")
@property
def moe_ep_group_pg(self):
raise NotImplementedError("moe_ep_group_pg is not implemented.")
def build_mesh(self):
raise NotImplementedError("build_mesh is not implemented.")
class MpiTopology(Mapping):
'''MPI-based mapping implementation'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._init_parallel_groups()
@property
def tp_rank(self) -> int:
return 0 if self.auto_parallel else self.rank % self.tp_size
@property
def pp_rank(self) -> int:
return 0 if self.auto_parallel else self.rank // (self.tp_size *
self.cp_size)
@property
def cp_rank(self) -> int:
return 0 if self.auto_parallel else self.rank % (
self.tp_size * self.cp_size) // self.tp_size
@property
def tp_group(self) -> List[int]:
return self.tp_groups[self.pp_rank * self.cp_size + self.cp_rank]
@property
def pp_group(self) -> List[int]:
return self.pp_groups[self.cp_rank * self.tp_size + self.tp_rank]
@property
def cp_group(self) -> List[int]:
return self.cp_groups[self.pp_rank * self.tp_size + self.tp_rank]
@property
def moe_tp_group(self) -> List[int]:
return self.moe_tp_groups[self.pp_rank * self.moe_cluster_size *
self.moe_ep_size +
self.moe_cluster_rank * self.moe_ep_size +
self.moe_ep_rank]
@property
def moe_ep_group(self) -> List[int]:
return self.moe_ep_groups[self.pp_rank * self.moe_tp_size *
self.moe_cluster_size +
self.moe_tp_rank * self.moe_cluster_size +
self.moe_cluster_rank]
@property
def moe_cluster_group(self) -> List[int]:
return self.moe_cluster_groups[self.pp_rank * self.moe_tp_size +
self.moe_tp_rank]
def _init_parallel_groups(self):
# init pp group
for i in range(self.tp_size * self.cp_size):
ranks = range(i, self.world_size, self.tp_size * self.cp_size)
self.pp_groups.append(list(ranks))
# init cp group
for i in range(self.pp_size):
for j in range(self.tp_size):
ranks = range(i * self.tp_size * self.cp_size + j,
(i + 1) * self.tp_size * self.cp_size + j,
self.tp_size)
self.cp_groups.append(list(ranks))
# init tp group
for i in range(self.pp_size):
for j in range(self.cp_size):
ranks = range(
i * self.tp_size * self.cp_size + j * self.tp_size,
i * self.tp_size * self.cp_size + (j + 1) * self.tp_size)
self.tp_groups.append(list(ranks))
# init moe tp group
for i in range(self.pp_size):
for j in range(self.moe_cluster_size * self.moe_ep_size):
ranks = range(i * self.moe_tp_cluster_ep_size + j,
(i + 1) * self.moe_tp_cluster_ep_size,
self.moe_cluster_size * self.moe_ep_size)
self.moe_tp_groups.append(list(ranks))
# init moe cluster group
for i in range(self.pp_size):
for j in range(self.moe_tp_size):
ranks = range(
i * self.moe_tp_cluster_ep_size +
j * self.moe_cluster_size * self.moe_ep_size,
i * self.moe_tp_cluster_ep_size +
(j + 1) * self.moe_cluster_size * self.moe_ep_size)
self.moe_cluster_groups.append(list(ranks))
# init moe ep group
for i in range(self.pp_size):
for j in range(self.moe_tp_size):
for k in range(self.moe_cluster_size):
ranks = range(
i * self.moe_tp_cluster_ep_size +
j * self.moe_cluster_size * self.moe_ep_size +
k * self.moe_ep_size, i * self.moe_tp_cluster_ep_size +
j * self.moe_cluster_size * self.moe_ep_size +
(k + 1) * self.moe_ep_size)
self.moe_ep_groups.append(list(ranks))
class DeviceMeshTopology(DeviceMeshTopologyImpl, Mapping):
"""PyTorch DeviceMesh-based mapping implementation"""
def __init__(self, *args, **kwargs):
assert mpi_disabled(
), "DeviceMeshTopology is only available in Ray orchestrator mode."
super().__init__(*args, **kwargs)

40
tensorrt_llm/ray_stub.py Normal file
View File

@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from tensorrt_llm._utils import mpi_disabled
if mpi_disabled():
raise RuntimeError(
"Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
)
def remote(*args, **kwargs):
def decorator(func):
# Return a function that always raises:
# the decorated class depends on Ray, but Ray is not installed.
@functools.wraps(func)
def stub_checker(*_, **__):
raise RuntimeError(
"Ray not installed, cannot use Ray based feature.")
return stub_checker
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return decorator(args[0])
return decorator
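A hedged sketch of what the stub enables: modules that decorate classes with @ray.remote stay importable on machines without Ray, and only raise if the Ray-backed object is actually used. The fallback import mirrors the one used in result.py above; ResponseQueue is a made-up class for illustration.
try:
    import ray
except ModuleNotFoundError:
    from tensorrt_llm import ray_stub as ray


@ray.remote(max_concurrency=4)
class ResponseQueue:
    def get(self, key):
        return key


# Importing/decorating is harmless; without Ray installed, constructing the
# class raises RuntimeError("Ray not installed, cannot use Ray based feature.").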

View File

@ -31,7 +31,7 @@ from tensorrt_llm.quantization import QuantAlgo
from ..conftest import (get_device_count, get_device_memory, llm_models_root,
parametrize_with_ids, skip_no_hopper,
skip_post_blackwell, skip_pre_ada, skip_pre_blackwell,
skip_pre_hopper)
skip_pre_hopper, skip_ray)
from .accuracy_core import (GSM8K, MMLU, MMMU, CnnDailymail, GPQADiamond,
JsonModeEval, LlmapiAccuracyTestHarness)
@ -1403,6 +1403,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_device(4)
@skip_pre_hopper
@skip_ray
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
[(False, False, False, False),

View File

@ -0,0 +1,26 @@
import pytest
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from ..conftest import llm_models_root
from .accuracy_core import MMLU, LlmapiAccuracyTestHarness
pytestmark = pytest.mark.ray
class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(32000)
def test_pp2_ray(self):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
with LLM(self.MODEL_PATH,
orchestrator_type="ray",
pipeline_parallel_size=2,
kv_cache_config=kv_cache_config) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

View File

@ -16,6 +16,7 @@ import copy
import os
import platform
import re
import socket
import time
from difflib import SequenceMatcher
from pathlib import Path
@ -1135,3 +1136,14 @@ def get_mmlu_accuracy(output):
print(f"MMLU weighted average accuracy is: {mmlu_accuracy}")
return mmlu_accuracy
def wait_for_server(host, port, timeout_seconds=180):
start_time = time.time()
while time.time() - start_time < timeout_seconds:
try:
with socket.create_connection((host, port), timeout=5):
return True
except OSError:
time.sleep(2)
return False

View File

@ -36,6 +36,7 @@ import tqdm
import yaml
from _pytest.mark import ParameterSet
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings import ipc_nvls_supported
from tensorrt_llm.llmapi.mpi_session import get_mpi_world_size
@ -2071,6 +2072,13 @@ def pytest_addoption(parser):
parser.addoption("--perf",
action="store_true",
help="'--perf' will run perf tests")
parser.addoption(
"--run-ray",
action="store_true",
default=False,
help=
"Enable Ray orchestrator path for integration tests (disables MPI).",
)
parser.addoption(
"--perf-log-formats",
help=
@ -2171,6 +2179,8 @@ def pytest_collection_modifyitems(session, config, items):
def pytest_configure(config):
# avoid thread leak of tqdm's TMonitor
tqdm.tqdm.monitor_interval = 0
if config.getoption("--run-ray"):
os.environ["TLLM_DISABLE_MPI"] = "1"
def deselect_by_regex(regexp, items, test_prefix, config):
@ -2270,6 +2280,10 @@ def check_nvlink():
skip_nvlink_inactive = pytest.mark.skipif(check_nvlink() is False,
reason="nvlink is inactive.")
skip_ray = pytest.mark.skipif(
os.environ.get("TLLM_DISABLE_MPI") == "1",
reason="This test is skipped for Ray orchestrator.")
@pytest.fixture(scope="function")
def eval_venv(llm_venv):
@ -2448,3 +2462,15 @@ def torch_empty_cache() -> None:
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
@pytest.fixture(autouse=True)
def ray_cleanup(llm_venv) -> None:
yield
if mpi_disabled():
llm_venv.run_cmd([
"-m",
"ray.scripts.scripts",
"stop",
])

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import re
import subprocess
@ -21,10 +22,12 @@ from typing import Callable
import pytest
import yaml
from defs.common import wait_for_server
from defs.conftest import (get_sm_version, llm_models_root, skip_arm,
skip_no_hopper)
from defs.trt_test_alternative import check_call, check_output, popen
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.logger import logger
@ -267,13 +270,146 @@ def get_test_config(test_desc, example_dir, test_root):
return config_map[test_desc]
def get_extra_llm_config(config, suffix, cwd):
extra_llm_config = {
'orchestrator_type': 'ray',
}
for key, value in config.items():
if key not in ['num_instances', 'urls']:
extra_llm_config[key] = value
temp_fd, extra_config_file = tempfile.mkstemp(suffix='_%s.yaml' % suffix,
dir=cwd)
with os.fdopen(temp_fd, 'w') as f:
yaml.dump(extra_llm_config, f)
return extra_config_file
def generate_worker_commands(model_path, config, server_config,
extra_config_file, server_role):
worker_commands = []
assert model_path, "model path is required."
for url in server_config['urls']:
host, port = url.split(':')
cmd = [
'trtllm-serve', model_path, '--host', host, '--port', port,
'--backend', config['backend'], '--extra_llm_api_options',
extra_config_file, '--server_role', server_role
]
worker_commands.append(cmd)
return worker_commands
def run_client_tests(example_dir,
config_file,
test_desc,
num_iters,
env,
server_start_timeout,
prompt_file,
extra_endpoints_test,
server_url,
workers_proc,
server_proc,
use_ray=False):
"""Run client tests against the disaggregated server."""
client_dir = f"{example_dir}/clients"
for _ in range(num_iters):
client_cmd = [
'python3', f'{client_dir}/disagg_client.py', '-c', f'{config_file}',
'-p', f'{client_dir}/{prompt_file}', '--ignore-eos',
'--server-start-timeout',
str(server_start_timeout)
]
if prompt_file == "long_prompts.json":
# Use max_tokens 4 for long prompts to reduce test time
client_cmd.extend(['--max-tokens', '4'])
# Prepare poll processes
worker_processes = []
if use_ray:
for proc_cm in workers_proc:
worker_processes.append(proc_cm.__enter__())
else:
worker_processes = [workers_proc]
poll_procs = worker_processes + [server_proc]
check_call(client_cmd, env=env, poll_procs=poll_procs)
# Streaming client run
streaming_client_cmd = client_cmd + [
'--streaming', '-o', 'output_streaming.json'
]
check_call(streaming_client_cmd, env=env, poll_procs=poll_procs)
# Run the chat completion endpoint test only for TinyLlama
if test_desc == "overlap" or test_desc == "trtllm_sampler":
chat_client_cmd = client_cmd + [
'-e', 'chat', '-o', 'output_chat.json'
]
check_call(chat_client_cmd, env=env, poll_procs=poll_procs)
streaming_chat_client_cmd = chat_client_cmd + [
'--streaming', '-o', 'output_streaming_chat.json'
]
check_call(streaming_chat_client_cmd,
env=env,
poll_procs=poll_procs)
# Skip output verification for long prompts test
if prompt_file == "long_prompts.json":
continue
if extra_endpoints_test is not None:
extra_endpoints_test(server_url)
# Verify outputs
not_expected_strings = ["Berlin Berlin"]
output_files = ['output.json', 'output_streaming.json']
if test_desc == "overlap" or test_desc == "trtllm_sampler":
# Disable streaming chat completion for overlap test
# due to bug
output_files.extend(['output_chat.json'])
if test_desc.startswith("gen_only"):
continue
for output_file in output_files:
with open(output_file, 'r') as f:
content = f.read()
if "deepseek_v3_lite" in test_desc or output_file == "output_chat.json":
expected_strings = [
"Berlin", ["Asyncio is a", "Asyncio module in"]
]
else:
expected_strings = [
"The capital of Germany is Berlin",
"Asyncio is a Python library"
]
for expected_string in expected_strings:
if isinstance(expected_string, list):
# At least one of the strings in the list should be found in the content
assert any(
string in content for string in expected_string
), f"None of the strings in {expected_string} found in {output_file}"
else:
assert expected_string in content, f"Expected string '{expected_string}' not found in {output_file}"
for not_expected_string in not_expected_strings:
assert not_expected_string not in content, f"Unexpected string '{not_expected_string}' found in {output_file}"
def run_disaggregated_test(example_dir,
test_desc,
num_iters=5,
env=None,
cwd=None,
prompt_file="prompts.json",
extra_endpoints_test: Callable[[str], None] = None):
extra_endpoints_test: Callable[[str], None] = None,
model_path=None):
"""Run disaggregated test with given configuration."""
cleanup_output_files()
run_env = env.copy()
@ -282,11 +418,42 @@ def run_disaggregated_test(example_dir,
num_ranks, config_file = get_test_config(test_desc, example_dir,
os.path.dirname(__file__))
workers_cmd = [
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
config_file
]
use_ray = mpi_disabled()
if not use_ray:
workers_cmd = [
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
config_file
]
else:
with open(config_file, 'r') as f:
config = yaml.safe_load(f)
if config['backend'] != "pytorch":
pytest.skip(
"Ray orchestrator is only supported with pytorch backend.")
extra_config_files = []
workers_cmds = []
subprocess.run(['ray', 'start', '--head', '--disable-usage-stats'],
check=True)
# Generate ctx and gen server worker commands
ctx_extra_config_file = get_extra_llm_config(config['context_servers'],
"ctx", cwd)
extra_config_files.append(ctx_extra_config_file)
workers_cmds.extend(
generate_worker_commands(model_path, config,
config['context_servers'],
ctx_extra_config_file, 'context'))
gen_extra_config_file = get_extra_llm_config(
config['generation_servers'], "gen", cwd)
extra_config_files.append(gen_extra_config_file)
workers_cmds.extend(
generate_worker_commands(model_path, config,
config['generation_servers'],
gen_extra_config_file, 'generation'))
server_start_timeout = 900
server_cmd = [
@ -296,101 +463,79 @@ def run_disaggregated_test(example_dir,
server_url = get_disagg_server_url_from_cfg(config_file)
try:
with ( # Start workers
open('output_workers.log', 'w') as output_workers,
popen(workers_cmd,
stdout=output_workers,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd) as workers_proc,
# Start server
open('output_disagg.log', 'w') as output_disagg,
popen(server_cmd,
stdout=output_disagg,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd) as server_proc):
client_dir = f"{example_dir}/clients"
for _ in range(num_iters):
client_cmd = [
'python3', f'{client_dir}/disagg_client.py', '-c',
f'{config_file}', '-p', f'{client_dir}/{prompt_file}',
'--ignore-eos', '--server-start-timeout',
str(server_start_timeout)
]
if prompt_file == "long_prompts.json":
# Use max_tokens 4 for long prompts to reduce test time
client_cmd.extend(['--max-tokens', '4'])
check_call(client_cmd,
env=env,
poll_procs=[workers_proc, server_proc])
if not use_ray:
with ( # Start workers
open('output_workers.log', 'w') as output_workers,
popen(workers_cmd,
stdout=output_workers,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd) as workers_proc,
# Start server
open('output_disagg.log', 'w') as output_disagg,
popen(server_cmd,
stdout=output_disagg,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd) as server_proc):
run_client_tests(example_dir,
config_file,
test_desc,
num_iters,
env,
server_start_timeout,
prompt_file,
extra_endpoints_test,
server_url,
workers_proc,
server_proc,
use_ray=False)
# Streaming client run
streaming_client_cmd = client_cmd + [
'--streaming', '-o', 'output_streaming.json'
]
check_call(streaming_client_cmd,
env=env,
poll_procs=[workers_proc, server_proc])
else:
workers_proc = []
with contextlib.ExitStack() as stack:
workers_log = stack.enter_context(
open('output_workers.log', 'w'))
# Run the chat completion endpoint test only for TinyLlama
if test_desc == "overlap" or test_desc == "trtllm_sampler":
chat_client_cmd = client_cmd + [
'-e', 'chat', '-o', 'output_chat.json'
]
check_call(chat_client_cmd,
env=env,
poll_procs=[workers_proc, server_proc])
for cmd in workers_cmds:
proc = stack.enter_context(
popen(
cmd,
stdout=workers_log,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd,
))
workers_proc.append(proc)
streaming_chat_client_cmd = chat_client_cmd + [
'--streaming', '-o', 'output_streaming_chat.json'
]
check_call(streaming_chat_client_cmd,
env=env,
poll_procs=[workers_proc, server_proc])
output_disagg = stack.enter_context(
open('output_disagg.log', 'w'))
server_proc = stack.enter_context(
popen(server_cmd,
stdout=output_disagg,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd))
# Skip output verification for long prompts test
if prompt_file == "long_prompts.json":
continue
if not wait_for_server("localhost",
8000,
timeout_seconds=server_start_timeout):
raise RuntimeError(
f"Disaggregated server failed to start within {server_start_timeout} seconds"
)
if extra_endpoints_test is not None:
extra_endpoints_test(server_url)
# Verify outputs
not_expected_strings = ["Berlin Berlin"]
output_files = ['output.json', 'output_streaming.json']
if test_desc == "overlap" or test_desc == "trtllm_sampler":
# Disable streaming chat completion for overlap test
# due to bug
output_files.extend(['output_chat.json'])
if test_desc.startswith("gen_only"):
continue
for output_file in output_files:
with open(output_file, 'r') as f:
content = f.read()
if "deepseek_v3_lite" in test_desc or output_file == "output_chat.json":
expected_strings = [
"Berlin", ["Asyncio is a", "Asyncio module in"]
]
else:
expected_strings = [
"The capital of Germany is Berlin",
"Asyncio is a Python library"
]
for expected_string in expected_strings:
if isinstance(expected_string, list):
# At least one of the strings in the list should be found in the content
assert any(
string in content
for string in expected_string
), f"None of the strings in {expected_string} found in {output_file}"
else:
assert expected_string in content, f"Expected string '{expected_string}' not found in {output_file}"
for not_expected_string in not_expected_strings:
assert not_expected_string not in content, f"Unexpected string '{not_expected_string}' found in {output_file}"
run_client_tests(example_dir,
config_file,
test_desc,
num_iters,
env,
server_start_timeout,
prompt_file,
extra_endpoints_test,
server_url,
workers_proc,
server_proc,
use_ray=True)
except Exception:
# Print outputs on error
logger.error("-------- Workers output --------")
@ -402,10 +547,16 @@ def run_disaggregated_test(example_dir,
logger.error(f.read())
raise
finally:
server_proc.terminate()
workers_proc.terminate()
server_proc.wait()
workers_proc.wait()
if use_ray:
subprocess.run(['ray', 'stop', '--force'], check=False)
for extra_file in extra_config_files:
if os.path.exists(extra_file):
os.remove(extra_file)
elif 'server_proc' in locals() and 'workers_proc' in locals():
server_proc.terminate()
workers_proc.terminate()
server_proc.wait()
workers_proc.wait()
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
@ -831,7 +982,8 @@ def test_disaggregated_ctxpp2_genpp2(disaggregated_test_root, llm_venv,
run_disaggregated_test(disaggregated_example_root,
"ctxpp2_genpp2",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
cwd=llm_venv.get_working_directory(),
model_path=llama_model_root)
@pytest.mark.skip_less_device(4)
@ -851,7 +1003,8 @@ def test_disaggregated_ctxtp2_genpp2(disaggregated_test_root, llm_venv,
run_disaggregated_test(disaggregated_example_root,
"ctxtp2_genpp2",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
cwd=llm_venv.get_working_directory(),
model_path=llama_model_root)
@pytest.mark.skip_less_device(4)
@ -871,7 +1024,8 @@ def test_disaggregated_ctxpp2_gentp2(disaggregated_test_root, llm_venv,
run_disaggregated_test(disaggregated_example_root,
"ctxpp2_gentp2",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
cwd=llm_venv.get_working_directory(),
model_path=llama_model_root)
@pytest.mark.skip_less_device(8)
@ -932,7 +1086,8 @@ def test_disaggregated_ctxpp4_gentp4(disaggregated_test_root, llm_venv,
run_disaggregated_test(disaggregated_example_root,
"ctxpp4_gentp4",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
cwd=llm_venv.get_working_directory(),
model_path=llama_model_root)
@skip_no_hopper
@ -1021,7 +1176,8 @@ def test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp(
run_disaggregated_test(disaggregated_example_root,
"deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
cwd=llm_venv.get_working_directory(),
model_path=deepseek_v3_model_root)
@skip_no_hopper
@ -1048,7 +1204,8 @@ def test_disaggregated_deepseek_v3_lite_fp8_ucx(disaggregated_test_root,
run_disaggregated_test(disaggregated_example_root,
"deepseek_v3_lite_fp8_ucx",
env=env,
cwd=llm_venv.get_working_directory())
cwd=llm_venv.get_working_directory(),
model_path=deepseek_v3_model_root)
@skip_no_hopper
@ -1074,7 +1231,8 @@ def test_disaggregated_deepseek_v3_lite_fp8_nixl(disaggregated_test_root,
run_disaggregated_test(disaggregated_example_root,
"deepseek_v3_lite_fp8_nixl",
env=env,
cwd=llm_venv.get_working_directory())
cwd=llm_venv.get_working_directory(),
model_path=deepseek_v3_model_root)
@skip_no_hopper
@ -1262,7 +1420,8 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp(
disaggregated_example_root,
"deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
cwd=llm_venv.get_working_directory(),
model_path=deepseek_v3_model_root)
@skip_no_hopper

View File

@ -0,0 +1,102 @@
import os
import subprocess
import pytest
from defs.common import venv_check_call, wait_for_server
from defs.conftest import get_device_count, llm_models_root
@pytest.fixture(scope="module")
def ray_example_root(llm_root):
example_root = os.path.join(llm_root, "examples", "ray_orchestrator")
return example_root
def test_llm_inference_async_ray(ray_example_root, llm_venv):
script_path = os.path.join(ray_example_root, "llm_inference_async_ray.py")
venv_check_call(llm_venv, [script_path])
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("tp_size,pp_size,ep_size", [
(2, 1, -1),
(1, 2, -1),
(2, 2, -1),
(2, 1, 2),
],
ids=["tp2", "pp2", "tp2pp2", "tep2"])
def test_llm_inference_distributed_ray(ray_example_root, llm_venv, tp_size,
pp_size, ep_size):
world_size = tp_size * pp_size
if get_device_count() < world_size:
pytest.skip(f"Need {world_size} GPUs.")
script_path = os.path.join(ray_example_root,
"llm_inference_distributed_ray.py")
cmd = [
script_path, "--tp_size",
str(tp_size), "--pp_size",
str(pp_size), "--moe_ep_size",
str(ep_size)
]
if ep_size != -1:
model_dir = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16"
cmd.extend(["--model_dir", model_dir])
venv_check_call(llm_venv, cmd)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("tp_size", [1, 2], ids=["tp1", "tp2"])
def test_ray_disaggregated_serving(ray_example_root, llm_venv, tp_size):
if get_device_count() < tp_size * 2:
pytest.skip(f"Need {tp_size * 2} GPUs.")
disagg_dir = os.path.join(ray_example_root, "disaggregated")
script_path = os.path.join(disagg_dir, "disagg_serving_local.sh")
subprocess.run("ray stop --force", shell=True, check=False)
proc = subprocess.Popen(
["bash", script_path, "--executor", "ray", "--tp_size",
str(tp_size)],
cwd=disagg_dir,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
try:
assert wait_for_server("localhost", 8000, timeout_seconds=180), \
"Disaggregated server failed to start within 3 minutes"
result = subprocess.run([
"curl", "-sS", "-w", "\n%{http_code}",
"http://localhost:8000/v1/completions", "-H",
"Content-Type: application/json", "-d",
'{"model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0","prompt":"NVIDIA is a great company because","max_tokens":16,"temperature":0}'
],
capture_output=True,
text=True,
timeout=30)
*body_lines, status_line = result.stdout.strip().splitlines()
body = "\n".join(body_lines)
status = int(status_line)
print("HTTP status:", status)
print("Response body:", body)
assert result.returncode == 0, f"curl exit {result.returncode}"
assert status == 200, f"Expected 200, got {status}"
finally:
proc.terminate()
try:
proc.wait(timeout=10)
except Exception:
proc.kill()
subprocess.run("ray stop --force", shell=True, check=False)
subprocess.run("pkill -9 -f trtllm-serve", shell=True, check=False)

View File

@ -109,6 +109,10 @@ if is_linux():
torch_inductors.append(pid)
continue
# Improve readability for Ray processes; their cmdline contains many empty args.
if cmdline and cmdline[0].startswith('ray::'):
cmdline = [arg for arg in cmdline if arg]
lines.append(f"{pid}: {cmdline}")
persist_pids.append(pid)
except psutil.Error:

View File

@ -13,6 +13,7 @@ l0_dgx_b200:
terms:
stage: pre_merge
backend: pytorch
orchestrator: mpi
tests:
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
@ -53,6 +54,29 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
- condition:
ranges:
system_gpu_count:
gte: 4
lte: 4
wildcards:
gpu:
- '*b200*'
linux_distribution_name: ubuntu*
terms:
stage: pre_merge
backend: pytorch
orchestrator: ray
tests:
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu4"
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_gentp4[TinyLlama-1.1B-Chat-v1.0]
- examples/test_ray.py::test_llm_inference_distributed_ray[tp2pp2]
- examples/test_ray.py::test_ray_disaggregated_serving[tp2]
- condition:
ranges:
system_gpu_count:
@ -66,6 +90,7 @@ l0_dgx_b200:
terms:
stage: pre_merge
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (180)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (180)
@ -86,6 +111,7 @@ l0_dgx_b200:
terms:
stage: post_merge
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (180)
- condition:
@ -101,6 +127,7 @@ l0_dgx_b200:
terms:
stage: post_merge
backend: pytorch
orchestrator: mpi
tests:
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
- unittest/_torch/multi_gpu_modeling/test_llama3.py::test_llama_3_3

View File

@ -13,6 +13,7 @@ l0_dgx_h100:
stage: pre_merge
backend: pytorch
auto_trigger: others
orchestrator: mpi
tests:
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu2"
- unittest/_torch/multi_gpu -m "not post_merge" TIMEOUT (90)
@ -54,6 +55,7 @@ l0_dgx_h100:
stage: pre_merge
backend: pytorch
auto_trigger: others
orchestrator: mpi
tests:
# ------------- PyTorch tests ---------------
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu4"
@ -99,6 +101,7 @@ l0_dgx_h100:
stage: pre_merge
backend: pytorch
auto_trigger: deepseek
orchestrator: mpi
tests:
- unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp1-bf16-trtllm-deepseekv3_lite]
- unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp4-bf16-trtllm-deepseekv3_lite]
@ -156,6 +159,7 @@ l0_dgx_h100:
stage: pre_merge
backend: pytorch
auto_trigger: gpt_oss
orchestrator: mpi
tests:
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton-auto]
@ -177,6 +181,7 @@ l0_dgx_h100:
stage: pre_merge
backend: cpp
auto_trigger: others
orchestrator: mpi
tests:
# ------------- CPP tests ---------------
- cpp/test_multi_gpu.py::test_mpi_utils[90]
@ -218,3 +223,24 @@ l0_dgx_h100:
- cpp/test_multi_gpu.py::TestDisagg::test_spawn_orchestrator[llama-ucx_kvcache-90]
- cpp/test_multi_gpu.py::TestDisagg::test_orchestrator_params[llama-nixl_kvcache-90] TIMEOUT (90)
- cpp/test_multi_gpu.py::TestDisagg::test_spawn_orchestrator[llama-nixl_kvcache-90]
- condition:
ranges:
system_gpu_count:
gte: 2
lte: 2
wildcards:
gpu:
- '*h100*'
linux_distribution_name: ubuntu*
terms:
stage: pre_merge
backend: pytorch
orchestrator: ray
tests:
- unittest/_torch/ray_orchestrator/multi_gpu -m "gpu2"
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu2"
- accuracy/test_llm_api_pytorch_ray.py::TestLlama3_1_8BInstruct::test_pp2_ray
- examples/test_ray.py::test_llm_inference_distributed_ray[tp2]
- examples/test_ray.py::test_llm_inference_distributed_ray[pp2]
- examples/test_ray.py::test_llm_inference_distributed_ray[tep2]
- examples/test_ray.py::test_ray_disaggregated_serving[tp1]

View File

@ -12,6 +12,7 @@ l0_h100:
terms:
stage: pre_merge
backend: pytorch
orchestrator: mpi
tests:
# ------------- PyTorch tests ---------------
- unittest/_torch/attention
@ -115,6 +116,32 @@ l0_h100:
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype
- condition:
ranges:
system_gpu_count:
gte: 1
lte: 1
wildcards:
gpu:
- '*h100*'
linux_distribution_name: ubuntu*
terms:
stage: pre_merge
backend: pytorch
orchestrator: ray
tests:
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_llama[False-False-TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_deepseek[False-False-DeepSeek-V3-Lite-fp8/fp8]
- disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-False-Qwen3-8B-FP8]
- unittest/_torch/executor
- unittest/_torch/ray_orchestrator/single_gpu
- unittest/llmapi/test_llm_pytorch.py
- examples/test_ray.py::test_llm_inference_async_ray
- condition:
ranges:
system_gpu_count:
@ -190,6 +217,7 @@ l0_h100:
terms:
stage: post_merge
backend: pytorch
orchestrator: mpi
tests:
# ------------- PyTorch tests ---------------
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=FLASHINFER-torch_compile=False]

View File

@ -43,6 +43,7 @@ def create_llm(model_dir, disable_overlap_scheduler, sampler_type):
@pytest.mark.parametrize("sampler_type", ["TorchSampler", "TRTLLMSampler"])
@pytest.mark.high_cuda_memory
@pytest.mark.mpi_ray_parity
def test_overlap_scheduler_consistency(model_path, test_case, sampler_type):
# Test configuration
prompts = test_case["prompts"]

View File

@ -16,6 +16,8 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
ResourceManagerType
)
# isort: on
from utils.util import skip_ray
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm.bindings.executor import KvCacheConfig
@ -322,6 +324,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
self.assertTrue(pytorch_config_custom.cuda_graph_padding_enabled,
"Custom enable_padding should be respected")
@skip_ray
def test_prepare_tp_inputs_with_helix_parallelism(self) -> None:
"""Test _prepare_tp_inputs function with helix parallelism."""

View File

@ -0,0 +1,23 @@
import os
import sys
import pytest
from tensorrt_llm._utils import mpi_disabled
def pytest_configure(config):
if config.getoption("--run-ray"):
os.environ["TLLM_DISABLE_MPI"] = "1"
os.environ["TLLM_RAY_FORCE_LOCAL_CLUSTER"] = "1"
run_ray_flag = "--run-ray" in sys.argv
if run_ray_flag:
os.environ["TLLM_DISABLE_MPI"] = "1"
os.environ["TLLM_RAY_FORCE_LOCAL_CLUSTER"] = "1"
if not mpi_disabled():
pytest.skip(
"Ray tests are only tested in Ray CI stage or with --run-ray flag",
allow_module_level=True)
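For context, a minimal sketch of an equivalent module-level gate, assuming mpi_disabled() simply reflects TLLM_DISABLE_MPI=1; the pytestmark form below is an editorial illustration and not part of the commit.
import os
import pytest
# Ray mode is active once --run-ray (or the Ray CI stage) has set TLLM_DISABLE_MPI=1.
_ray_mode_active = os.environ.get("TLLM_DISABLE_MPI") == "1"
# Placing this in a module under unittest/_torch/ray_orchestrator/ skips it
# whenever Ray mode is not active.
pytestmark = pytest.mark.skipif(
    not _ray_mode_active,
    reason="Ray tests only run in the Ray CI stage or with the --run-ray flag")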

View File

@ -0,0 +1,147 @@
import copy
import os
import unittest
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.mapping import Mapping
class TestMapping(unittest.TestCase):
@pytest.mark.gpu2
def test_device_mesh_parity(self):
"""To ensure parity between Ray and MPI Mapping instance."""
# (tp, pp, cp, moe_tp, moe_ep)
combos = [
# no cp
(2, 1, 1, -1, -1), # -1 means no MoE in Mapping
# 8 GPUs, no cp
(4, 2, 1, -1, -1),
(2, 4, 1, -1, -1),
# 8 GPUs with cp
(4, 1, 2, -1, -1),
(2, 1, 4, -1, -1),
# with moe_tp, moe_ep
(8, 1, 1, 2, 4),
(2, 1, 1, 1, 2)
]
num_gpus = torch.cuda.device_count()
for tp, pp, cp, moe_tp, moe_ep in combos:
world_size = tp * pp * cp
print(
f"\n\n=== TP={tp}, PP={pp}, CP={cp}, MOE_TP={moe_tp}, MOE_EP={moe_ep} ==="
)
if world_size > num_gpus:
print(
f"SKIPPING: need {world_size} GPUs. Only have {num_gpus}.")
continue
mp.spawn(
self._worker,
args=(world_size, get_free_port(), tp, pp, cp, moe_tp, moe_ep),
nprocs=world_size,
join=True,
)
@pytest.mark.gpu2
def test_picklable(self):
world_size = tp = 2
mp.spawn(
self._worker,
args=(world_size, get_free_port(), tp, 1, 1, -1, -1, "pickle"),
nprocs=world_size,
join=True,
)
@staticmethod
def _worker(rank: int,
world_size: int,
master_port: int,
tp=1,
pp=1,
cp=1,
moe_tp=1,
moe_ep=1,
test_type="parity") -> None:
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_RANK"] = str(rank)
torch.cuda.set_device(rank)
if test_type == "parity":
if "TLLM_DISABLE_MPI" in os.environ:
del os.environ["TLLM_DISABLE_MPI"]
mapping_mpi = Mapping(
world_size=world_size,
rank=rank,
gpus_per_node=world_size,
tp_size=tp,
pp_size=pp,
cp_size=cp,
moe_tp_size=moe_tp,
moe_ep_size=moe_ep,
)
os.environ["TLLM_DISABLE_MPI"] = "1"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
mapping_device_mesh = Mapping(
world_size=world_size,
rank=rank,
gpus_per_node=world_size,
tp_size=tp,
pp_size=pp,
cp_size=cp,
moe_tp_size=moe_tp,
moe_ep_size=moe_ep,
)
if test_type == "parity":
mapping_device_mesh.build_mesh()
properties = []
for dim in mapping_device_mesh.device_mesh.mesh_dim_names:
properties.append(f"{dim}_rank")
properties.append(f"{dim}_group")
for prop in properties:
mpi_value = getattr(mapping_mpi, prop)
device_mesh_value = getattr(mapping_device_mesh, prop)
if rank == 0:
print(
f" {prop}: MPI={mpi_value}, DeviceMesh={device_mesh_value}"
)
assert mpi_value == device_mesh_value, \
f"Property {prop} mismatch: MPI={mpi_value}, DeviceMesh={device_mesh_value} (rank {rank})"
elif test_type == "pickle":
mapping = mapping_device_mesh
tp_group = mapping.tp_group
print(f"tp_group: {tp_group}")
assert DeviceMeshTopologyImpl.device_mesh is not None
mapping_copy = copy.deepcopy(mapping)
# check static mesh still exists
assert mapping_copy.device_mesh is not None
print(f"tp_group after deepcopy: {mapping.tp_group}")
assert mapping.tp_group == mapping_copy.tp_group
else:
raise ValueError(f"Invalid test type: {test_type}")
dist.destroy_process_group()

View File

@ -0,0 +1,381 @@
import os
import pytest
import torch
try:
import ray
except ModuleNotFoundError:
from tensorrt_llm import ray_stub as ray
from tensorrt_llm._torch.distributed.communicator import TorchDist
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
from tensorrt_llm.mapping import Mapping
@ray.remote(num_gpus=1)
class AllgatherPGTest:
def __init__(self, rank, world_size):
self.rank = rank
self.world_size = world_size
self.master_address = os.environ["MASTER_ADDR"]
self.master_port = os.environ["MASTER_PORT"]
assert len(ray.get_gpu_ids()) == 1
torch.distributed.init_process_group(
backend="cuda:nccl,cpu:gloo",
init_method=f"tcp://{self.master_address}:{self.master_port}",
world_size=world_size,
rank=rank)
@torch.inference_mode()
def run_allgather_pg_op(self, test_tensor, expected_result, sizes):
torch.cuda.set_device(0)
test_tensor = test_tensor.cuda(0)
expected_result = expected_result.cuda(0)
mapping = Mapping(world_size=self.world_size,
gpus_per_node=self.world_size,
tp_size=self.world_size,
rank=self.rank)
if torch.distributed.is_initialized():
TorchDist(mapping)
else:
raise RuntimeError("torch.distributed is not initialized")
module = torch
op_path = ['ops', 'trtllm', 'allgather_pg']
for attr in op_path:
module = getattr(module, attr)
allgather_pg_op = module
output = allgather_pg_op(test_tensor, sizes, mapping.tp_group,
mapping.tp_group_pg.boxed())
if isinstance(output, (list, tuple)):
result_tensor = output[0]
else:
result_tensor = output
rtol, atol = 0.05, 0.15
torch.testing.assert_close(result_tensor,
expected_result,
rtol=rtol,
atol=atol)
return True
@ray.remote(num_gpus=1)
class ReducescatterPGTest:
def __init__(self, rank, world_size):
self.rank = rank
self.world_size = world_size
self.master_address = os.environ["MASTER_ADDR"]
self.master_port = os.environ["MASTER_PORT"]
assert len(ray.get_gpu_ids()) == 1
torch.distributed.init_process_group(
backend="cuda:nccl,cpu:gloo",
init_method=f"tcp://{self.master_address}:{self.master_port}",
world_size=world_size,
rank=rank)
@torch.inference_mode()
def run_reducescatter_pg_op(self, test_tensor, expected_result, sizes):
torch.cuda.set_device(0)
test_tensor = test_tensor.cuda(0)
expected_result = expected_result.cuda(0)
mapping = Mapping(world_size=self.world_size,
gpus_per_node=self.world_size,
tp_size=self.world_size,
rank=self.rank)
if torch.distributed.is_initialized():
TorchDist(mapping)
else:
raise RuntimeError("torch.distributed is not initialized")
module = torch
op_path = ['ops', 'trtllm', 'reducescatter_pg']
for attr in op_path:
module = getattr(module, attr)
reducescatter_pg_op = module
output = reducescatter_pg_op(test_tensor, sizes, mapping.tp_group,
mapping.tp_group_pg.boxed())
if isinstance(output, (list, tuple)):
result_tensor = output[0]
else:
result_tensor = output
rtol, atol = 0.05, 0.15
torch.testing.assert_close(result_tensor,
expected_result,
rtol=rtol,
atol=atol)
return True
@ray.remote(num_gpus=1)
class AllreducePGTest:
def __init__(self, rank, world_size):
self.rank = rank
self.world_size = world_size
self.master_address = os.environ["MASTER_ADDR"]
self.master_port = os.environ["MASTER_PORT"]
assert len(ray.get_gpu_ids()) == 1
torch.distributed.init_process_group(
backend="cuda:nccl,cpu:gloo",
init_method=f"tcp://{self.master_address}:{self.master_port}",
world_size=world_size,
rank=rank)
@torch.inference_mode()
def run_allreduce_pg_op(self, test_tensor, expected_result):
torch.cuda.set_device(0)
test_tensor = test_tensor.cuda(0)
expected_result = expected_result.cuda(0)
mapping = Mapping(world_size=self.world_size,
gpus_per_node=self.world_size,
tp_size=self.world_size,
rank=self.rank)
if torch.distributed.is_initialized():
TorchDist(mapping)
else:
raise RuntimeError("torch.distributed is not initialized")
module = torch
op_path = ['ops', 'trtllm', 'allreduce_pg']
for attr in op_path:
module = getattr(module, attr)
allreduce_pg_op = module
output = allreduce_pg_op(
input=test_tensor,
residual=None,
norm_weight=None,
scale=None,
bias=None,
workspace=None,
group=mapping.tp_group,
strategy=AllReduceStrategy.NCCL,
op=AllReduceFusionOp.NONE, # Pure allreduce, no fusion
eps=1e-5,
trigger_completion_at_end=True,
rank=self.rank,
pg=mapping.tp_group_pg.boxed())
if isinstance(output, (list, tuple)):
result_tensor = output[0]
else:
result_tensor = output
rtol, atol = 0.05, 0.15
torch.testing.assert_close(result_tensor,
expected_result,
rtol=rtol,
atol=atol)
return True
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("hidden_size", [128, 1024],
ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 64], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("var_len", [True, False], ids=lambda x: f"var_len:{x}")
def test_allgather_pg_op(seq_len, hidden_size, var_len):
torch.manual_seed(42)
dtype = torch.bfloat16
world_size = 2
if var_len:
test_tensor_list = [
torch.randn((seq_len * (i + 1), hidden_size), dtype=dtype)
for i in range(world_size)
]
expected_result = torch.cat(test_tensor_list, dim=0)
sizes = [seq_len * (i + 1) for i in range(world_size)]
else:
test_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
expected_result = test_tensor.repeat(world_size, 1)
sizes = None
ray_init_args = {
"include_dashboard": False,
"namespace": "test_allgather_pg_op",
"ignore_reinit_error": True
}
try:
ray.init(address="local", **ray_init_args)
master_port = get_free_port()
runtime_env = ray.runtime_env.RuntimeEnv()
runtime_env["env_vars"] = os.environ.copy()
runtime_env["env_vars"].update({
"TLLM_DISABLE_MPI": "1",
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": str(master_port)
})
remotePGTests = []
for rank in range(world_size):
remotePGTests.append(
AllgatherPGTest.options(runtime_env=runtime_env).remote(
rank, world_size))
if var_len:
results = ray.get([
remotePGTest.run_allgather_pg_op.remote(test_tensor,
expected_result, sizes)
for remotePGTest, test_tensor in zip(remotePGTests,
test_tensor_list)
])
else:
results = ray.get([
remotePGTest.run_allgather_pg_op.remote(test_tensor,
expected_result, sizes)
for remotePGTest in remotePGTests
])
except Exception as e:
if ray.is_initialized():
ray.shutdown()
raise e
ray.shutdown()
for r in results:
assert r is True
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("hidden_size", [128, 1024],
ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 64], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("var_len", [True, False], ids=lambda x: f"var_len:{x}")
def test_reducescatter_pg_op(seq_len, hidden_size, var_len):
torch.manual_seed(42)
dtype = torch.bfloat16
world_size = 2
if var_len:
total_seq_len = sum([seq_len * (i + 1) for i in range(world_size)])
test_tensor = torch.randn((total_seq_len, hidden_size), dtype=dtype)
expected_result_list = []
offset = 0
for i in range(world_size):
expected_result_list.append(test_tensor[offset:offset + seq_len *
(i + 1)] * world_size)
offset += seq_len * (i + 1)
sizes = [seq_len * (i + 1) for i in range(world_size)]
else:
test_tensor = torch.randn((seq_len * world_size, hidden_size),
dtype=dtype)
expected_result_list = [
test_tensor[i * seq_len:(i + 1) * seq_len] * world_size
for i in range(world_size)
]
sizes = None
ray_init_args = {
"include_dashboard": False,
"namespace": "test_reducescatter_pg_op",
"ignore_reinit_error": True
}
try:
ray.init(address="local", **ray_init_args)
master_port = get_free_port()
runtime_env = ray.runtime_env.RuntimeEnv()
runtime_env["env_vars"] = os.environ.copy()
runtime_env["env_vars"].update({
"TLLM_DISABLE_MPI": "1",
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": str(master_port)
})
remotePGTests = []
for rank in range(world_size):
remotePGTests.append(
ReducescatterPGTest.options(runtime_env=runtime_env).remote(
rank, world_size))
if var_len:
results = ray.get([
remotePGTest.run_reducescatter_pg_op.remote(
test_tensor, expected_result, sizes)
for remotePGTest, expected_result in zip(
remotePGTests, expected_result_list)
])
else:
results = ray.get([
remotePGTest.run_reducescatter_pg_op.remote(
test_tensor, expected_result, sizes)
for remotePGTest, expected_result in zip(
remotePGTests, expected_result_list)
])
except Exception as e:
if ray.is_initialized():
ray.shutdown()
raise e
ray.shutdown()
for r in results:
assert r is True
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("hidden_size", [128, 1024],
ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 64], ids=lambda x: f"seqlen:{x}")
def test_allreduce_pg_op(seq_len, hidden_size):
torch.manual_seed(42)
dtype = torch.bfloat16
world_size = 2
test_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
expected_result = test_tensor * world_size
ray_init_args = {
"include_dashboard": False,
"namespace": "test_allreduce_pg_op",
"ignore_reinit_error": True
}
try:
ray.init(address="local", **ray_init_args)
master_port = get_free_port()
runtime_env = ray.runtime_env.RuntimeEnv()
runtime_env["env_vars"] = os.environ.copy()
runtime_env["env_vars"].update({
"TLLM_DISABLE_MPI": "1",
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": str(master_port)
})
remotePGTests = []
for rank in range(world_size):
remotePGTests.append(
AllreducePGTest.options(runtime_env=runtime_env).remote(
rank, world_size))
results = ray.get([
remotePGTest.run_allreduce_pg_op.remote(test_tensor,
expected_result)
for remotePGTest in remotePGTests
])
except Exception as e:
if ray.is_initialized():
ray.shutdown()
raise e
ray.shutdown()
for r in results:
assert r is True

View File

@ -0,0 +1,21 @@
import os
import pytest
from tensorrt_llm import LLM
from tensorrt_llm._torch.utils import get_device_uuid
@pytest.mark.gpu2
def test_cuda_visible_device():
"""Placement via cuda_visible_device"""
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
orchestrator_type="ray")
infer_actor_uuids = llm._collective_rpc("report_device_id")
del os.environ["CUDA_VISIBLE_DEVICES"]
assert infer_actor_uuids[0] == get_device_uuid(1)
print(f"{infer_actor_uuids=}")

View File

@ -0,0 +1,87 @@
import importlib
import os
import unittest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensorrt_llm._utils import get_free_port, torch_pybind11_abi
class TestCacheTransceiverComm(unittest.TestCase):
def test_cache_transceiver_comm(self):
mp.spawn(
self._worker,
args=(4, get_free_port()),
nprocs=4,
join=True,
)
@staticmethod
def _worker(rank: int, world_size: int, master_port: int):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_RANK"] = str(rank)
try:
dist.init_process_group(
backend="gloo",
rank=rank,
world_size=world_size,
)
world_pg = torch.distributed.group.WORLD
bm = importlib.import_module(
"tensorrt_llm.bindings.internal.batch_manager")
cacheComm = getattr(bm, "CacheTransceiverComm")
comm = cacheComm(world_pg, torch_pybind11_abi())
# Test split
rank = dist.get_rank()
world_size = dist.get_world_size()
color = rank // 2
key = rank % 2
sub = comm.split(color, key)
expected_group_size = 2
assert sub.get_size() == expected_group_size
# Test allgather
ok, gathered_ranks = sub.allgather(rank)
assert ok is True
expected_world_ranks = [
r for r in range(world_size) if (r // 2) == color
]
assert gathered_ranks == expected_world_ranks
# Test allgatherv
local_len = rank + 1
payload = [rank] * local_len
ok_sizes, sizes64 = sub.allgather(local_len)
assert ok_sizes is True
sizes = [int(x) for x in sizes64]
ok_v, out = sub.allgatherv(payload, sizes)
assert ok_v is True
expected_concat = []
for r in expected_world_ranks:
expected_concat.extend([r] * (r + 1))
assert out == expected_concat
# Test allgatherv with char
char_payload = [chr(65 + rank)] * local_len
ok_char, char_out = sub.allgatherv(char_payload, sizes)
assert ok_char is True
expected_char_concat = []
for r in expected_world_ranks:
expected_char_concat.extend([chr(65 + r)] * (r + 1))
assert char_out == expected_char_concat
finally:
if dist.is_initialized():
dist.destroy_process_group()

View File

@ -64,6 +64,7 @@ def test_register_fake(custom_ops):
"trtllm::mtp_prepare_drafter_inputs_op",
"trtllm::selective_scan",
"trtllm::reducescatter_list",
"trtllm::reducescatter_list_pg",
"trtllm::fp8_per_tensor_scale_moe_runner",
"trtllm::migrate_to_host_accessible",
"trtllm::mnnvl_moe_alltoallv_prepare_without_allgather",

View File

@ -75,6 +75,10 @@ methods:
annotation: Optional[str]
default: null
status: deprecated
orchestrator_type:
annotation: Optional[Literal['ray']]
default: null
status: prototype
build_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.BuildConfig]
default: null
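The new orchestrator_type knob is exercised elsewhere in this commit (for example in the CUDA_VISIBLE_DEVICES placement test); a minimal usage sketch, assuming the same TinyLlama checkpoint is reachable:
from tensorrt_llm import LLM
# Prototype-status option: run the LLM API on the Ray orchestrator instead of MPI.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", orchestrator_type="ray")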

View File

@ -79,6 +79,12 @@ def pytest_addoption(parser):
help=
"Prepend a prefix to the test names. Useful for distinguishing different test runs in a test report."
)
parser.addoption(
"--run-ray",
action="store_true",
default=False,
help="Run Ray-marked tests (by default they are skipped).",
)
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
@ -94,8 +100,23 @@ def pytest_collection_modifyitems(session, config, items):
for item in items:
item._nodeid = f"{test_prefix}/{item._nodeid}"
# Ray tests are disabled by default
run_ray = config.getoption("--run-ray") or os.environ.get(
"TLLM_RUN_RAY_TESTS") == "1"
if not run_ray:
skip_marker = pytest.mark.skip(
reason=
"Ray tests skipped; pass --run-ray or set TLLM_RUN_RAY_TESTS=1")
for item in items:
if "ray" in item.keywords:
item.add_marker(skip_marker)
def pytest_sessionstart(session):
if session.config.getoption("--run-ray"):
os.environ["TLLM_DISABLE_MPI"] = "1"
os.environ["TLLM_RAY_FORCE_LOCAL_CLUSTER"] = "1"
# Counter TransformerEngine v2.3's lazy_compile deferral,
# which would otherwise make pytest think there is a thread leak.
import torch._inductor.async_compile # noqa: F401
@ -147,3 +168,70 @@ def mpi_pool_executor(request):
# make the number of workers visible to tests
setattr(executor, "num_workers", num_workers)
yield executor
def pytest_generate_tests(metafunc: pytest.Metafunc):
if metafunc.definition.get_closest_marker('mpi_ray_parity'):
run_ray = metafunc.config.getoption("--run-ray") or os.environ.get(
"TLLM_RUN_RAY_TESTS") == "1"
if run_ray:
metafunc.parametrize(
'ray_mode',
[
pytest.param('ray', id='ray', marks=pytest.mark.ray),
],
indirect=True,
)
@pytest.fixture
def ray_mode(request):
return getattr(request, 'param', 'mpi')
@pytest.fixture(autouse=True)
def _maybe_force_ray(request, monkeypatch, ray_mode):
"""
Patch the LLM class (torch only) to use the Ray executor.
"""
if 'mpi_ray_parity' not in request.node.keywords or ray_mode != 'ray':
return
def wrap_llm(cls):
class LLMProxy(cls):
def __init__(self, *args, **kwargs):
kwargs["orchestrator_type"] = "ray"
super().__init__(*args, **kwargs)
return LLMProxy
test_mod = request.node.module
# Only patch the torch LLM class
if hasattr(test_mod, 'LLM'):
try:
from tensorrt_llm._tensorrt_engine import LLM as LLM_legacy
is_trtllm_backend = (test_mod.LLM is LLM_legacy)
except Exception:
is_trtllm_backend = False
if not is_trtllm_backend:
monkeypatch.setattr(test_mod,
'LLM',
wrap_llm(test_mod.LLM),
raising=False)
if hasattr(test_mod, 'LLM_torch'):
monkeypatch.setattr(test_mod,
'LLM_torch',
wrap_llm(test_mod.LLM_torch),
raising=False)
try:
import tensorrt_llm.llmapi.llm as llm_mod
monkeypatch.setattr(llm_mod,
'LLM',
wrap_llm(llm_mod.LLM),
raising=False)
except Exception:
pass

View File

@ -550,6 +550,7 @@ def _test_llm_generate_async(model_name=default_model_name,
@pytest.mark.parametrize("chunked", [True, False])
@pytest.mark.part0
@pytest.mark.mpi_ray_parity
def test_llm_generate_async_with_stream_interval(chunked):
model_path = get_model_path('llama-models-v2/llama-v2-7b-hf')
max_num_tokens = 256

View File

@ -27,7 +27,7 @@ from .test_llm import (_test_llm_capture_request_error, get_model_path,
tinyllama_logits_processor_test_harness)
from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
skip_gpu_memory_less_than_80gb,
skip_gpu_memory_less_than_138gb)
skip_gpu_memory_less_than_138gb, skip_ray)
from utils.llm_data import llm_models_root
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.executor.request import LoRARequest
@ -50,6 +50,7 @@ def test_tinyllama_logits_processor(enable_chunked_prefill):
backend="pytorch", enable_chunked_prefill=enable_chunked_prefill)
@skip_ray
@pytest.mark.parametrize(
"return_context_logits, use_overlap, enable_iter_req_stats", [
(False, False, False),
@ -66,6 +67,7 @@ def test_llm_get_stats(return_context_logits, use_overlap,
enable_iter_req_stats=enable_iter_req_stats)
@skip_ray
@pytest.mark.parametrize(
"return_context_logits, use_overlap, enable_iter_req_stats", [
(False, False, False),
@ -88,6 +90,7 @@ def test_llm_capture_request_error():
@force_ampere
@pytest.mark.mpi_ray_parity
@pytest.mark.parametrize(
"sampling_params",
[
@ -175,6 +178,7 @@ def test_llm_reward_model():
assert not outputs[0].outputs[0].text
@skip_ray
def test_llm_perf_metrics():
llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config)
sampling_params = SamplingParams(max_tokens=10, return_perf_metrics=True)
@ -200,6 +204,7 @@ def test_llm_perf_metrics():
assert perf_metrics.last_iter == perf_metrics.iter
@skip_ray
def test_llm_prometheus():
test_prompts = [
"Hello, my name is",
@ -221,6 +226,7 @@ def test_llm_prometheus():
assert request_output.outputs is not None
@skip_ray
@pytest.mark.parametrize("streaming", [True, False])
def test_llm_with_postprocess_parallel_and_result_handler(streaming):
run_llm_with_postprocess_parallel_and_result_handler(streaming,
@ -404,6 +410,7 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca
repeats_per_call=1)
@skip_ray
@skip_gpu_memory_less_than_40gb
def test_llama_7b_peft_cache_config_affects_peft_cache_size():
"""Tests that LLM arg of peft_cache_config affects the peft cache sizes.
@ -832,6 +839,7 @@ FailingExecutor = type(
})
@skip_ray
def test_llm_with_proxy_error():
"""Test that LLM properly handles GenerationExecutorWorker constructor failures.
@ -885,6 +893,7 @@ def test_min_tokens(use_speculative: bool):
assert len(res.outputs[0].token_ids) == output_len
@skip_ray
@pytest.mark.parametrize(
"prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",
[
@ -907,6 +916,7 @@ def test_llm_return_logprobs(prompt_logprobs: Optional[int],
backend=backend)
@skip_ray
@pytest.mark.parametrize(
"prompt_logprobs, logprobs, return_context_logits, return_generation_logits",
[

View File

@ -22,3 +22,5 @@ markers =
post_merge: this test should only run in post merge
high_cuda_memory: this test uses a lot of CUDA memory (typically more than 12GB)
no_xdist: this test should not run when using pytest-xdist
ray: mark Ray-based tests
mpi_ray_parity: parametrize a test to also run with Ray executor variant
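Both new markers are applied in the test changes above; a minimal sketch of how a test opts in (the bodies are placeholders, not part of the commit):
import pytest
@pytest.mark.mpi_ray_parity
def test_runs_under_mpi_and_ray():
    # Runs against MPI by default; with --run-ray it is parametrized to use the Ray executor.
    ...
@pytest.mark.ray
def test_ray_only_path():
    # Skipped unless --run-ray is passed or TLLM_RUN_RAY_TESTS=1 is set.
    ...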

View File

@ -446,3 +446,8 @@ def check_accuracy(a, b, atol, rtol, percent):
if not (mismatch_percent < 1 - percent):
raise Exception("Mismatch percentage is %f for rtol %f" %
(mismatch_percent, rtol))
skip_ray = pytest.mark.skipif(
os.environ.get("TLLM_DISABLE_MPI") == "1",
reason="This test is skipped for Ray orchestrator.")