[TRTLLM-7349][feat] Adding new orchestrator type -- ray (#7520)
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Co-authored-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Co-authored-by: Erin Ho <14718778+hchings@users.noreply.github.com>
This commit is contained in:
parent: 9d098e3142
commit: 88ea2c4ee9

.gitignore (vendored): 1 line changed
@@ -46,6 +46,7 @@ tensorrt_llm/deep_ep_cpp_tllm.pyi
tensorrt_llm/deep_gemm/
tensorrt_llm/deep_gemm_cpp_tllm.*.so
tensorrt_llm/deep_gemm_cpp_tllm.pyi
tensorrt_llm/pg_utils_bindings.*.so
*docs/cpp_docs*
*docs/source/_cpp_gen*
docs/source/**/*.rst

@@ -23,9 +23,17 @@
#include "tensorrt_llm/executor/cacheCommunicator.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <pybind11/pybind11.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/custom_class.h>
#include <torch/python.h>
#include <type_traits>
#include <vector>

using SizeType32 = tensorrt_llm::runtime::SizeType32;

@@ -43,6 +51,134 @@ class BaseKVCacheManager;
class CacheSender;
class CacheReceiver;

class CacheTransceiverComm
{
public:
    // Construct from a non-owning raw pointer, won't take ownership of the pointer
    explicit CacheTransceiverComm(mpi::MpiComm const* mpiComm)
        : mMpiComm(std::shared_ptr<mpi::MpiComm const>(nullptr), mpiComm)
    {
    }

    // Construct from a shared_ptr with shared ownership
    explicit CacheTransceiverComm(std::shared_ptr<mpi::MpiComm const> mpiComm)
        : mMpiComm(std::move(mpiComm))
    {
    }

    // Construct from a ProcessGroup communicator
    explicit CacheTransceiverComm(c10::intrusive_ptr<c10d::ProcessGroup> pgComm)
        : mPgComm(std::move(pgComm))
    {
    }

    ~CacheTransceiverComm() = default;

    bool isMpi() const noexcept
    {
        return mMpiComm != nullptr;
    }

    int getRank() const
    {
        if (isMpi())
        {
            return mMpiComm->getRank();
        }
        return mPgComm->getRank();
    }

    int getSize() const
    {
        if (isMpi())
        {
            return mMpiComm->getSize();
        }
        return mPgComm->getSize();
    }

    void allgather(void const* sendbuf, void* recvbuf, int count, mpi::MpiType dtype) const
    {
        if (isMpi())
        {
            mMpiComm->allgather(sendbuf, recvbuf, count, dtype);
            return;
        }
        TLLM_THROW("Input arguments only supported in mpi");
    }

    template <typename Input, typename Output>
    bool allgather(Input input, Output output, c10d::AllgatherOptions options = c10d::AllgatherOptions()) const
    {
        if (isMpi())
        {
            TLLM_THROW("Input arguments only supported in pg");
        }
        tensorrt_llm::pg_utils::PgHelper pgh{mPgComm};

        PGCHECK_THROW(pgh.allgather(input, output, options));
        return true;
    }

    template <typename Input, typename Output>
    bool allgatherv(Input input, Output output, std::vector<int> const& sizes,
        c10d::AllgatherOptions options = c10d::AllgatherOptions()) const
    {
        if (isMpi())
        {
            TLLM_THROW("Input arguments only supported in pg");
        }
        tensorrt_llm::pg_utils::PgHelper pgh{mPgComm};
        PGCHECK_THROW(pgh.allgatherv(input, output, sizes, options));
        return true;
    }

    bool allgatherv(void const* sendbuf, int sendcount, mpi::MpiType sendtype, void* recvbuf,
        std::vector<int> const& recvcounts, std::vector<int> const& displs, mpi::MpiType recvtype) const
    {
        if (isMpi())
        {
            mMpiComm->allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype);
            return true;
        }
        TLLM_THROW("Input arguments only supported in mpi");
    }

    CacheTransceiverComm split(int color, int key)
    {
        if (isMpi())
        {
            auto subgroup = mMpiComm->split(color, key);
            return CacheTransceiverComm(std::make_shared<mpi::MpiComm const>(std::move(subgroup)));
        }
        bool const initialized = Py_IsInitialized();
        TLLM_CHECK_WITH_INFO(initialized, "Trying to use ProcessGroup communicator but Python is not initialized");
        try
        {
            c10::intrusive_ptr<c10d::ProcessGroup> pgSub;
            {
                pybind11::gil_scoped_acquire gil;
                auto const m = pybind11::module::import("tensorrt_llm._torch.distributed.pg_utils");
                // Properly box the existing intrusive_ptr ProcessGroup into an IValue
                // and convert to a Python object without constructing a new instance.
                auto const py_pg = torch::jit::toPyObject(c10::IValue(mPgComm));

                auto const py_sub_pg = m.attr("split")(color, key, py_pg);
                pgSub = torch::jit::toCustomClass<c10d::ProcessGroup>(py_sub_pg);
            }
            return CacheTransceiverComm(pgSub);
        }
        catch (...)
        {
            TLLM_THROW("Failed to split process group");
        }
    }

private:
    std::shared_ptr<mpi::MpiComm const> mMpiComm;
    c10::intrusive_ptr<c10d::ProcessGroup> mPgComm;
};

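For reference, a minimal usage sketch of the class above (not part of the diff; `exampleExchangeRanks` is a hypothetical helper, assumed to be compiled in the same namespace as `CacheTransceiverComm`). Callers pick the raw-pointer/MPI-type overloads when `isMpi()` is true and the templated `std::ref`-based overloads otherwise, mirroring how `cacheTransceiver.cpp` uses it later in this diff.

// Sketch only: exchange each rank's id over whichever backend the comm wraps.
#include <functional>
#include <memory>
#include <vector>

void exampleExchangeRanks(std::shared_ptr<CacheTransceiverComm> const& comm)
{
    int const myRank = comm->getRank();
    std::vector<int> allRanks(comm->getSize());
    if (comm->isMpi())
    {
        // MPI backend: raw-pointer overload with an explicit MPI datatype.
        comm->allgather(&myRank, allRanks.data(), 1, mpi::MpiType::kINT32);
    }
    else
    {
        // ProcessGroup backend: templated overload routed through PgHelper.
        comm->allgather(&myRank, std::ref(allRanks), {});
    }
}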
class CacheTransceiverFactory
{
public:
@@ -124,9 +260,11 @@ private:
    std::unique_ptr<CacheReceiver> mCacheReceiver;
    std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
    std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
    mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
    std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,
        mMpiGroupTPInDPComm;
    mpi::MpiComm const* mMpiWorldComm{nullptr};

    std::shared_ptr<CacheTransceiverComm> mGroupComm;
    std::shared_ptr<CacheTransceiverComm> mGroupTensorParaComm, mGroupPipeParaComm, mGroupDataComm, mGroupTPInDPComm;

    executor::kv_cache::CommState const* mCommState;
    std::unique_ptr<executor::kv_cache::CacheState> mCacheState;
    std::unique_ptr<executor::kv_cache::ConnectionManager> mManager;

cpp/include/tensorrt_llm/common/bindingUtils.h (new file): 72 lines
@@ -0,0 +1,72 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "c10/util/intrusive_ptr.h"
#include <Python.h>

namespace tensorrt_llm::common
{

// Adapted from pybind11's example implementation:
// https://github.com/pybind/pybind11/blob/master/include/pybind11/conduit/pybind11_conduit_v1.h
// Copyright (c) 2024 The pybind Community.

inline void* get_raw_pointer_ephemeral(
    PyObject* py_obj, std::type_info const* cpp_type_info, std::string const& pybind11_abi)
{
    PyObject* cpp_type_info_capsule = PyCapsule_New(
        const_cast<void*>(static_cast<void const*>(cpp_type_info)), typeid(std::type_info).name(), nullptr);
    if (cpp_type_info_capsule == nullptr)
    {
        return nullptr;
    }
    PyObject* cpp_conduit = PyObject_CallMethod(
        py_obj, "_pybind11_conduit_v1_", "yOy", pybind11_abi.c_str(), cpp_type_info_capsule, "raw_pointer_ephemeral");
    Py_DECREF(cpp_type_info_capsule);
    if (cpp_conduit == nullptr)
    {
        return nullptr;
    }
    void* raw_ptr = PyCapsule_GetPointer(cpp_conduit, cpp_type_info->name());
    Py_DECREF(cpp_conduit);
    if (PyErr_Occurred())
    {
        return nullptr;
    }
    return raw_ptr;
}

template <typename T, typename E>
T* get_type_pointer_ephemeral(PyObject* py_obj, std::string pybind11_abi)
{
    void* raw_ptr = get_raw_pointer_ephemeral(py_obj, &typeid(T), pybind11_abi);
    if (raw_ptr == nullptr)
    {
        throw E();
    }
    return static_cast<T*>(raw_ptr);
}

template <typename T, typename E>
c10::intrusive_ptr<T> get_intrusive_ptr(PyObject* py_obj, std::string pybind11_abi)
{
    auto* const p = get_type_pointer_ephemeral<T, E>(py_obj, pybind11_abi);
    return c10::intrusive_ptr<T>::reclaim_copy(p);
}

} // namespace tensorrt_llm::common
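A hedged usage sketch of `get_intrusive_ptr` (not part of the diff): it recovers the `c10d::ProcessGroup` wrapped by a `torch.distributed` Python object, the same call the nanobind `CacheTransceiverComm` binding makes further down in this commit. `toProcessGroup` and `ConduitError` are illustrative names; `py_obj` is a borrowed `PyObject*` and `pybind11_abi` is the ABI string of the pybind11 build that created the object.

// Sketch only, assuming the includes below are available in the translation unit.
#include "tensorrt_llm/common/bindingUtils.h"
#include <Python.h>
#include <exception>
#include <string>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

// E must be default-constructible; it is thrown when the conduit handshake fails.
struct ConduitError : std::exception
{
    char const* what() const noexcept override
    {
        return "pybind11 conduit lookup failed";
    }
};

c10::intrusive_ptr<c10d::ProcessGroup> toProcessGroup(PyObject* py_obj, std::string const& pybind11_abi)
{
    return tensorrt_llm::common::get_intrusive_ptr<c10d::ProcessGroup, ConduitError>(py_obj, pybind11_abi);
}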
@@ -35,6 +35,7 @@
#include <cstdlib>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>

#if ENABLE_MULTI_DEVICE
@@ -425,7 +426,29 @@ public:
        return !(rhs == *this);
    }

    bool couldUseMPI() const
    {
        if (!mDisableMPI.has_value())
        {
            char* val = std::getenv("TLLM_DISABLE_MPI");
            if (val != NULL && std::string(val) == "1")
            {
                mDisableMPI = true;
            }
            else
            {
                mDisableMPI = false;
            }
        }
        if (mDisableMPI.value())
        {
            throw std::runtime_error("MPI is disabled, DON\'T USE MPI");
        }
        return true;
    }

private:
    mutable std::optional<bool> mDisableMPI;
    //! \brief Corresponds to `world()` by default, but can be overridden per process.
    static MpiComm& mutableSession();

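A small sketch (assumption, not from the diff) of how a call site can honor the new `TLLM_DISABLE_MPI` switch: `couldUseMPI()` returns true when MPI is allowed and throws when `TLLM_DISABLE_MPI=1`, so MPI-only collectives are never reached under a non-MPI orchestrator. `gatherIfMpiAllowed` is a hypothetical helper; the `allgather` signature is the one used elsewhere in this commit.

// Sketch only.
#include <vector>

void gatherIfMpiAllowed(tensorrt_llm::mpi::MpiComm const& comm, int localValue, std::vector<int>& allValues)
{
    if (comm.couldUseMPI()) // throws std::runtime_error when TLLM_DISABLE_MPI=1
    {
        allValues.resize(comm.getSize());
        comm.allgather(&localValue, allValues.data(), 1, tensorrt_llm::mpi::MpiType::kINT32);
    }
}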
cpp/include/tensorrt_llm/runtime/utils/pgUtils.h (new file): 284 lines
@@ -0,0 +1,284 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/torch.h>
#include <vector>

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"

// Check async op.
inline c10::intrusive_ptr<c10d::Work> pgCheckHelper(
    c10::intrusive_ptr<c10d::Work> work, char const* const file, int const line, char const* info)
{
    if (work == nullptr)
    {
        auto const msg = std::string("[TensorRT-LLM][ERROR] empty work returned from: ") + info;
        tensorrt_llm::common::throwRuntimeError(file, line, msg);
    }

    try
    {
        work->wait();
    }
    catch (...)
    {
        auto msg = std::string("[TensorRT-LLM][ERROR] Torch distributed operation error: ") + info;
        std::throw_with_nested(tensorrt_llm::common::TllmException(file, line, msg.c_str()));
    }

    return work;
}

// Check sync op.
inline void pgCheckHelper(bool success, char const* const file, int const line, char const* info)
{
    if (!success)
    {
        throw std::runtime_error(std::string("[TensorRT-LLM][ERROR] Torch distributed operation error: ") + info);
    }
}

#define PGCHECK_THROW(op) pgCheckHelper(op, __FILE__, __LINE__, #op)
#define PGCHECK_THROW_WITH_INFO(op, info) pgCheckHelper(op, __FILE__, __LINE__, info)

inline bool useMPI()
{
    bool useMPI = true;
    char* val = std::getenv("TLLM_DISABLE_MPI");
    if (val != nullptr && std::string(val) == "1")
    {
        useMPI = false;
    }
    return useMPI;
}

namespace tensorrt_llm::pg_utils
{

// ProcessGroup management functions
c10::intrusive_ptr<c10d::ProcessGroup> get_world_pg();

c10::intrusive_ptr<c10d::ProcessGroup> get_local_pg();

void init_pg(c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_world,
    c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_local);

// Tensor wrapping utilities for ProcessGroup operations
inline torch::Tensor wrap_tensor(torch::Tensor data)
{
    return data;
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
torch::Tensor wrap_tensor(T* data, size_t size)
{
    if constexpr (std::is_same_v<std::decay_t<T>, char>)
    {
        // `char` does not have a guaranteed specialization in CppTypeToScalarType
        // across PyTorch builds. Treat `char` as kChar (int8) explicitly.
        return at::from_blob(data, {static_cast<int64_t>(size)}, c10::TensorOptions{}.dtype(torch::kChar));
    }
    else if constexpr (std::is_same_v<std::decay_t<T>, uint64_t>)
    {
        // `uint64_t` may not have a guaranteed specialization in CppTypeToScalarType
        // across PyTorch builds. Treat `uint64_t` as kLong (int64) explicitly.
        return at::from_blob(data, {static_cast<int64_t>(size)}, c10::TensorOptions{}.dtype(torch::kLong));
    }
    else
    {
        return at::from_blob(data, {static_cast<int64_t>(size)},
            c10::TensorOptions{}.dtype(torch::CppTypeToScalarType<std::decay_t<T>>::value));
    }
}

template <typename T, typename = std::enable_if_t<std::is_void_v<T>>, typename = void>
torch::Tensor wrap_tensor(T* data, size_t size)
{
    return at::from_blob(data, {static_cast<int64_t>(size)}, c10::TensorOptions{}.dtype(torch::kChar));
}

template <typename T>
torch::Tensor wrap_tensor(T const* data, size_t size)
{
    return wrap_tensor(const_cast<T*>(data), size);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
torch::Tensor wrap_tensor(T& data)
{
    return wrap_tensor(&data, 1);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
torch::Tensor wrap_tensor(std::reference_wrapper<T> data)
{
    return wrap_tensor(&data.get(), 1);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
torch::Tensor wrap_tensor(T* data)
{
    return wrap_tensor(data, 1);
}

template <typename T>
torch::Tensor wrap_tensor(std::vector<T>& data)
{
    return wrap_tensor(data.data(), data.size());
}

template <typename T>
torch::Tensor wrap_tensor(std::vector<T> const& data)
{
    return wrap_tensor(data.data(), data.size());
}

template <typename T>
torch::Tensor wrap_tensor(std::reference_wrapper<std::vector<T>> data)
{
    auto& ref = data.get();
    return wrap_tensor(ref.data(), ref.size());
}

template <typename T>
torch::Tensor wrap_tensor(std::reference_wrapper<std::vector<T> const> data)
{
    auto const& ref = data.get();
    return wrap_tensor(ref.data(), ref.size());
}

template <typename T>
torch::Tensor wrap_tensor(std::vector<T>* data)
{
    return wrap_tensor(data->data(), data->size());
}

// ProcessGroup Helper - convenient wrapper around ProcessGroup operations
struct PgHelper
{
    c10::intrusive_ptr<c10d::ProcessGroup> pg;

    PgHelper(c10::intrusive_ptr<c10d::ProcessGroup> pg)
        : pg(pg)
    {
    }

    template <typename Input, typename Output>
    c10::intrusive_ptr<c10d::Work> allgather(
        Input input, Output output, c10d::AllgatherOptions options = c10d::AllgatherOptions())
    {
        auto inputTensor = wrap_tensor(input);
        auto outputTensor = wrap_tensor(output);

        return pg->_allgather_base(outputTensor, inputTensor, options);
    }

    template <typename Input>
    c10::intrusive_ptr<c10d::Work> allreduce(Input input, c10d::AllreduceOptions options = c10d::AllreduceOptions())
    {
        std::vector inputs{wrap_tensor(input)};

        return pg->allreduce(inputs, options);
    }

    template <typename Input>
    c10::intrusive_ptr<c10d::Work> send(Input input, int dstRank, int tag)
    {
        std::vector inputs{wrap_tensor(input)};

        return pg->send(inputs, dstRank, tag);
    }

    template <typename Output>
    c10::intrusive_ptr<c10d::Work> recv(Output output, int srcRank, int tag)
    {
        std::vector outputs{wrap_tensor(output)};

        return pg->recv(outputs, srcRank, tag);
    }

    // Variable-size allgather helper implemented via padding + slicing on Tensors.
    template <typename Input, typename Output, typename SizeT = int64_t>
    bool allgatherv(Input input, Output output, std::vector<SizeT> const& sizes,
        c10d::AllgatherOptions options = c10d::AllgatherOptions())
    {
        auto const worldSize = pg->getSize();

        TLLM_CHECK_WITH_INFO(
            static_cast<int>(sizes.size()) == worldSize, "sizes.size() must equal worldSize in allgatherv");

        at::Tensor inputTensor = wrap_tensor(input);
        SizeT const localSize = static_cast<SizeT>(inputTensor.numel());
        TLLM_CHECK_WITH_INFO(
            sizes[pg->getRank()] == localSize, "sizes[rank] must equal local input size in allgatherv");

        SizeT const maxSize = *std::max_element(sizes.begin(), sizes.end());
        auto tensorOptions = inputTensor.options();

        at::Tensor paddedInput = at::zeros({static_cast<int64_t>(maxSize)}, tensorOptions);
        if (localSize > 0)
        {
            paddedInput.narrow(0, 0, static_cast<int64_t>(localSize)).copy_(inputTensor);
        }

        at::Tensor paddedOutput
            = at::empty({static_cast<int64_t>(maxSize) * static_cast<int64_t>(worldSize)}, tensorOptions);

        PGCHECK_THROW(pg->_allgather_base(paddedOutput, paddedInput, options)->wait());

        // Prepare compact output tensor backed by 'output'
        SizeT const totalSize = std::accumulate(sizes.begin(), sizes.end(), static_cast<SizeT>(0));
        at::Tensor outputTensor = wrap_tensor(output);
        TLLM_CHECK_WITH_INFO(outputTensor.numel() == static_cast<int64_t>(totalSize),
            "output tensor numel must equal total size in allgatherv");

        // Slice and compact
        size_t writeOffset = 0;
        for (int r = 0; r < worldSize; ++r)
        {
            int64_t const validCount = static_cast<int64_t>(sizes[static_cast<size_t>(r)]);
            int64_t const srcOffset = static_cast<int64_t>(r) * static_cast<int64_t>(maxSize);
            if (validCount > 0)
            {
                outputTensor.narrow(0, static_cast<int64_t>(writeOffset), validCount)
                    .copy_(paddedOutput.narrow(0, srcOffset, validCount));
                writeOffset += static_cast<size_t>(validCount);
            }
        }

        return true;
    }

    // Convenience overload to accept sizes passed via std::cref(...)
    template <typename Input, typename Output, typename SizeT = int64_t>
    bool allgatherv(Input input, Output output, std::reference_wrapper<std::vector<SizeT> const> sizes,
        c10d::AllgatherOptions options = c10d::AllgatherOptions())
    {
        return allgatherv<Input, Output, SizeT>(input, output, sizes.get(), options);
    }
};

} // namespace tensorrt_llm::pg_utils
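A hedged usage sketch of `PgHelper` (not part of the diff): a fixed-size allgather of per-rank sizes followed by a variable-size `allgatherv` of a byte payload, mirroring the pattern `UcxConnectionManager` uses later in this commit. `exchangePayload` is a hypothetical helper; it assumes `get_world_pg()` has already been initialized via `init_pg`.

// Sketch only.
#include <functional>
#include <numeric>
#include <vector>

void exchangePayload(std::vector<char> const& localPayload)
{
    using tensorrt_llm::pg_utils::PgHelper;
    auto const worldPg = tensorrt_llm::pg_utils::get_world_pg();
    PgHelper pgh{worldPg};

    // 1) Fixed-size allgather: one int in, worldSize ints out.
    int localSize = static_cast<int>(localPayload.size());
    std::vector<int> sizes(worldPg->getSize());
    PGCHECK_THROW(pgh.allgather(&localSize, std::ref(sizes), {}));

    // 2) Variable-size allgather: pads to the max size internally, then compacts into recvBuffer.
    std::vector<char> recvBuffer(std::accumulate(sizes.begin(), sizes.end(), 0));
    PGCHECK_THROW(pgh.allgatherv(std::ref(localPayload), std::ref(recvBuffer), std::cref(sizes), {}));
}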
@@ -107,3 +107,8 @@ if(ENABLE_UCX)
  target_link_libraries(${BATCH_MANAGER_STATIC_TARGET} PUBLIC)

endif()

find_library(TORCH_PYTHON_LIB torch_python REQUIRED
             HINTS ${TORCH_INSTALL_PREFIX}/lib)
target_link_libraries(${BATCH_MANAGER_STATIC_TARGET}
                      PUBLIC ${TORCH_PYTHON_LIB} Python3::Python pg_utils)

@@ -21,7 +21,6 @@
#include <limits>
#include <sstream>
#define UCX_WRAPPER_LIB_NAME "tensorrt_llm_ucx_wrapper"

#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary(name ".dll")
@@ -48,6 +47,7 @@
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <algorithm>
#include <cstddef>
#include <numeric>
@@ -114,14 +114,22 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
    std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
    executor::kv_cache::CacheState::AttentionType attentionType,
    std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig)
    : mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session()))
    , mCacheTransceiverConfig{cacheTransceiverConfig}
    : mCacheTransceiverConfig{cacheTransceiverConfig}
{
    using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
    if (useMPI())
    {
        mGroupComm = std::make_shared<CacheTransceiverComm>(std::addressof(tensorrt_llm::mpi::MpiComm::session()));
    }
    else
    {
        mGroupComm = std::make_shared<CacheTransceiverComm>(tensorrt_llm::pg_utils::get_world_pg());
    }

    if (worldConfig.isTensorParallel())
    {
        mMpiGroupTensorParaComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(
            mMpiGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
        mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
    }
    int kvFactor = 2;
    if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
@@ -142,12 +150,11 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
            - TPRankInDPGroup)
            / TPSizeInDPGroup;
        // <PP,DP,TP>
        mMpiGroupDataComm
            = std::make_shared<tensorrt_llm::mpi::MpiComm>(mMpiGroupComm->split(DPRank, worldConfig.getRank()));
        mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
        if (worldConfig.isTensorParallel())
        {
            mMpiGroupTPInDPComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(
                mMpiGroupComm->split(worldConfig.getRank() / TPSizeInDPGroup, worldConfig.getRank()));
            mGroupTPInDPComm = std::make_shared<CacheTransceiverComm>(
                mGroupComm->split(worldConfig.getRank() / TPSizeInDPGroup, worldConfig.getRank()));
        }
    }
    bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
@@ -302,28 +309,39 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
}

std::vector<LlmRequest::RequestIdType> gatherRequestIds(
    mpi::MpiComm const& mpiComm, std::vector<LlmRequest::RequestIdType> const& requestIds)
    std::shared_ptr<CacheTransceiverComm> const& mComm, std::vector<LlmRequest::RequestIdType> const& requestIds)
{
    int localSize = static_cast<int>(requestIds.size());
    std::vector<int> sizes(mpiComm.getSize());
    mpiComm.allgather(&localSize, sizes.data(), 1, mpi::MpiType::kINT32);
    std::vector<int> displs(mpiComm.getSize());
    int totalSize = 0;
    for (int i = 0; i < mpiComm.getSize(); i++)
    std::vector<int> sizes(mComm->getSize());
    std::vector<LlmRequest::RequestIdType> retData;
    if (useMPI())
    {
        displs[i] = totalSize;
        totalSize += sizes[i];
        mComm->allgather(&localSize, sizes.data(), 1, mpi::MpiType::kINT32);
        std::vector<int> displs(mComm->getSize());
        size_t totalSize = 0;
        for (int i = 0; i < mComm->getSize(); i++)
        {
            displs[i] = totalSize;
            totalSize += sizes[i];
        }
        retData.resize(totalSize);
        mComm->allgatherv(requestIds.data(), static_cast<int>(requestIds.size()), mpi::MpiType::kUINT64, retData.data(),
            sizes, displs, mpi::MpiType::kUINT64);
    }
    else
    {
        mComm->allgather(&localSize, std::ref(sizes), {});
        size_t totalSize = std::accumulate(sizes.begin(), sizes.end(), 0);
        retData.resize(totalSize);
        mComm->allgatherv(std::ref(requestIds), std::ref(retData), std::cref(sizes), {});
    }
    std::vector<LlmRequest::RequestIdType> retData(totalSize);
    mpiComm.allgatherv(requestIds.data(), static_cast<int>(requestIds.size()), mpi::MpiType::kUINT64, retData.data(),
        sizes, displs, mpi::MpiType::kUINT64);
    return retData;
}

void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm, LlmRequest* request)
{
    namespace su = executor::serialize_utils;
    int worldSize = mpiComm.getSize();
    int worldSize = mComm->getSize();

    std::ostringstream oStream;
    su::serialize(request->getKvCacheTransferStart(), oStream);
@@ -335,7 +353,14 @@ void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
    auto recvBufferSize = sendBufferSize * worldSize;
    std::vector<char> recvBuffer(recvBufferSize);

    mpiComm.allgather(sendBuffer.data(), recvBuffer.data(), sendBufferSize, mpi::MpiType::kCHAR);
    if (useMPI())
    {
        mComm->allgather(sendBuffer.data(), recvBuffer.data(), sendBufferSize, mpi::MpiType::kCHAR);
    }
    else
    {
        mComm->allgather(std::ref(sendBuffer), std::ref(recvBuffer), {});
    }

    su::VectorWrapBuf<char> strbuf(recvBuffer);
    std::istream is(&strbuf);
@@ -353,7 +378,14 @@ void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
    std::size_t localKVCacheSize = request->getKvCacheSize();
    std::vector<std::size_t> allKVCacheSizes(worldSize, 0);

    mpiComm.allgather(&localKVCacheSize, allKVCacheSizes.data(), 1, mpi::MpiType::kUINT64);
    if (useMPI())
    {
        mComm->allgather(&localKVCacheSize, allKVCacheSizes.data(), 1, mpi::MpiType::kUINT64);
    }
    else
    {
        mComm->allgather(&localKVCacheSize, std::ref(allKVCacheSizes), {});
    }

    std::size_t totalKVCacheSize = 0;
    for (int rank = 0; rank < worldSize; rank++)
@@ -362,7 +394,7 @@ void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
    }

    // Update the latest KV cache transfer time for leader rank
    if (mpiComm.getRank() == 0)
    if (mComm->getRank() == 0)
    {
        request->setKvCacheTransferStart(minStartTime);
        request->setKvCacheTransferEnd(maxEndTime);
@@ -373,7 +405,7 @@ void updateKVCacheTransferBW(mpi::MpiComm const& mpiComm, LlmRequest* request)
void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
{
    bool blockAll = !atLeastRequestNum.has_value();
    auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupTPInDPComm : mMpiGroupTensorParaComm;
    auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupTPInDPComm : mGroupTensorParaComm;
    std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
    for (auto&& [request, future] : mSenderFutures)
    {
@@ -386,7 +418,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
    std::unordered_map<LlmRequest::RequestIdType, int> frequencyMap;
    if ((syncComm) && syncComm->getSize() > 1)
    {
        auto gatherRequestIdVec = gatherRequestIds(*syncComm, contextCompleteRequestIds);
        auto gatherRequestIdVec = gatherRequestIds(syncComm, contextCompleteRequestIds);
        for (auto&& requestId : gatherRequestIdVec)
        {
            frequencyMap[requestId]++;
@@ -462,10 +494,10 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
    std::unordered_map<LlmRequest::RequestIdType, int> frequencyMap;

    std::vector<LlmRequest::RequestIdType> toBlockRequestIds;
    auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupDataComm.get() : mMpiGroupComm;
    auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
    if ((syncComm) && syncComm->getSize() > 1)
    {
        auto gatherRequestIdVec = gatherRequestIds(*syncComm, genTransferReadyRequestIds);
        auto gatherRequestIdVec = gatherRequestIds(syncComm, genTransferReadyRequestIds);
        for (auto&& requestId : gatherRequestIdVec)
        {
            frequencyMap[requestId]++;
@@ -493,8 +525,16 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
            break;
        }
        toCompleteIdSet.insert(freqVec.at(idx).first);
        TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " checkGenTransferStatus at least from freqVec requestId: %zu ",
            freqVec.at(idx).first);
        if (useMPI())
        {
            TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
                " checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first);
        }
        else
        {
            TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
                " checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first);
        }
        idx++;
    }
    idx = 0;
@@ -509,9 +549,18 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
        if (toCompleteIdSet.find(mRequesterFutures.at(idx).first->mRequestId) == toCompleteIdSet.end())
        {
            toCompleteIdSet.insert(mRequesterFutures.at(idx).first->mRequestId);
            TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
                " checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
                mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
            if (useMPI())
            {
                TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
                    " checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
                    mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
            }
            else
            {
                TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
                    " checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
                    mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
            }
        }
        idx++;
    }
@@ -521,12 +570,29 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
        {
            toCompleteIdSet.insert(requestId);
        }
        TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " checkGenTransferStatus freqVec requestId: %zu,freq:%d ",
            requestId, freq);
        if (useMPI())
        {
            TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " checkGenTransferStatus freqVec requestId: %zu,freq:%d ",
                requestId, freq);
        }
        else
        {
            TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
                " checkGenTransferStatus freqVec requestId: %zu,freq:%d ", requestId, freq);
        }
    }
    if (useMPI())
    {
        TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
            " checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
            atLeastRequestNum.value_or(0));
    }
    else
    {
        TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
            " checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
            atLeastRequestNum.value_or(0));
    }
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
        " checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
        atLeastRequestNum.value_or(0));
    for (auto it = mRequesterFutures.begin(); it != mRequesterFutures.end();)
    {
        if (blockAll || toCompleteIdSet.find(it->first->mRequestId) != toCompleteIdSet.end())
@@ -539,9 +605,8 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
                // Gather the kv cache transfer time from all workers and update to leader rank
                if (!common::getEnvKVCacheTransferOutputPath().empty())
                {
                    auto syncComm
                        = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupDataComm.get() : mMpiGroupComm;
                    updateKVCacheTransferBW(*syncComm, it->first);
                    auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
                    updateKVCacheTransferBW(syncComm, it->first);
                }
            }
            catch (std::exception const& e)
@@ -550,9 +615,18 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
                    "Error occurred during generation transfer for request %ld: %s", it->first->mRequestId, e.what());
                it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
            }
            TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
                "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
                it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
            if (useMPI())
            {
                TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
                    "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
                    it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
            }
            else
            {
                TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
                    "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
                    it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
            }
            it = mRequesterFutures.erase(it);
        }
        else

@@ -3,6 +3,7 @@
if(ENABLE_UCX)
  find_package(ucx REQUIRED)
  find_package(ucxx REQUIRED)
  find_package(Torch REQUIRED)

  include_directories(${3RDPARTY_DIR}/cppzmq)

@@ -23,12 +24,16 @@ if(ENABLE_UCX)
  set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")
  target_compile_definitions(${UCX_WRAPPER_TARGET}
                             PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")

  target_include_directories(${UCX_WRAPPER_TARGET}
                             PRIVATE ${PROJECT_SOURCE_DIR}/include)
  target_include_directories(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_INCLUDE_DIRS})

  target_link_libraries(${UCX_WRAPPER_TARGET}
                        PRIVATE $<LINK_LIBRARY:WHOLE_ARCHIVE,ucxx::ucxx>)
  target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ucxx::ucxx ucx::ucs)
  target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ${CUDA_RT_LIB})

  # Add include directories
  target_include_directories(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_INCLUDE_DIRS})
  target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ${TORCH_LIBRARIES})
  target_link_libraries(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_LIBRARIES})
  target_link_libraries(${UCX_WRAPPER_TARGET} PRIVATE pg_utils)
endif()

@@ -81,15 +81,15 @@ UcxConnection::UcxConnection(ConnectionIdType connectionId, std::shared_ptr<ucxx
    }
    catch (std::exception const& e)
    {
        std::string error = "Error in UcxConnection constructor for rank "
            + std::to_string(mpi::MpiComm::world().getRank()) + ": " + e.what();
        std::string error = std::string("Error in UcxConnection constructor for rank ")
            + std::to_string(mManager->getRank()) + ": " + e.what();
        TLLM_THROW(error);
    }

    mSendTagPrefix = mConnectionIdInPeer;
    mRecvTagPrefix = mConnectionId;

    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
    TLLM_LOG_DEBUG(mManager->getRank(),
        "UcxConnection::UcxConnection, mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
        mConnectionIdInPeer, mFromRequester);
}
@@ -97,7 +97,7 @@ UcxConnection::UcxConnection(ConnectionIdType connectionId, std::shared_ptr<ucxx
UcxConnection::~UcxConnection()
{

    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
    TLLM_LOG_DEBUG(mManager->getRank(),
        "UcxConnection::~UcxConnection, mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
        mConnectionIdInPeer, mFromRequester);
    // TODO: how to close the endpoint safely?
@@ -105,7 +105,7 @@ UcxConnection::~UcxConnection()

void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, size_t size) const
{
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
    TLLM_LOG_DEBUG(mManager->getRank(),
        "start UcxConnection::sendConnectionId , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d",
        mConnectionId, mConnectionIdInPeer, mFromRequester);

@@ -126,7 +126,7 @@ void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, s
    }
    TLLM_CHECK_WITH_INFO(req->isCompleted(), "sendConnectionId should be completed");
    req->checkError();
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
    TLLM_LOG_DEBUG(mManager->getRank(),
        "end UcxConnection::sendConnectionId , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d",
        mConnectionId, mConnectionIdInPeer, mFromRequester);
}
@@ -138,7 +138,7 @@ void UcxConnection::send(DataContext const& ctx, void const* data, size_t size)
        sendConnectionId(ctx, data, size);
        return;
    }
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
    TLLM_LOG_DEBUG(mManager->getRank(),
        "start UcxConnection::send , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
        mConnectionIdInPeer, mFromRequester);

@@ -156,7 +156,8 @@ void UcxConnection::send(DataContext const& ctx, void const* data, size_t size)
    TLLM_CHECK_WITH_INFO(req->isCompleted(), "send should be completed");
    // throw if there is error
    req->checkError();
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),

    TLLM_LOG_DEBUG(mManager->getRank(),
        "end UcxConnection::send , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
        mConnectionIdInPeer, mFromRequester);
}
@@ -164,7 +165,7 @@ void UcxConnection::send(DataContext const& ctx, void const* data, size_t size)
void UcxConnection::recv(DataContext const& ctx, void* data, size_t size) const
{
    // Guard to ensure CUDA context is initialized for UCX ops
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
    TLLM_LOG_DEBUG(mManager->getRank(),
        "start UcxConnection::recv , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
        mConnectionIdInPeer, mFromRequester);
    TLLM_CHECK_WITH_INFO((mEndpoint), "recvBuffer called without established communicator channel.");
@@ -180,7 +181,8 @@ void UcxConnection::recv(DataContext const& ctx, void* data, size_t size) const
    TLLM_CHECK_WITH_INFO(req->isCompleted(), "recv should be completed");
    // throw if there is error
    req->checkError();
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),

    TLLM_LOG_DEBUG(mManager->getRank(),
        "end UcxConnection::recv , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
        mConnectionIdInPeer, mFromRequester);
}

@@ -19,14 +19,31 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <chrono>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <functional>
#include <iostream>
#include <mutex>
#include <numeric>
#include <regex>
#include <string>
#include <sys/socket.h>
#include <thread>
#include <ucxx/address.h>
#include <ucxx/typedefs.h>
#include <unistd.h>
#include <vector>

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>

using tensorrt_llm::pg_utils::get_world_pg;
using tensorrt_llm::pg_utils::PgHelper;

namespace tensorrt_llm::executor::kv_cache
{
@@ -73,12 +90,12 @@ public:
    }
};

std::string getLocalIpByNic(std::string const& interface)
std::string getLocalIpByNic(std::string const& interface, int rank)
{
    struct ifaddrs* ifaddr = nullptr;
    if (getifaddrs(&ifaddr) == -1)
    {
        TLLM_LOG_ERROR(mpi::MpiComm::world().getRank(),
        TLLM_LOG_ERROR(rank,
            "getLocalIpByNic: Can't get local ip from NIC Interface. Please check whether TRTLLM_UCX_INTERFACE is set "
            "correctly.");
        return std::string{};
@@ -118,17 +135,17 @@ std::string getLocalIpByNic(std::string const& interface)
    }

    freeifaddrs(ifaddr);
    TLLM_LOG_ERROR(mpi::MpiComm::world().getRank(),
        "Can't get local ip from NIC Interface. Please check whether TRTLLM_UCX_INTERFACE is set correctly.");
    TLLM_LOG_ERROR(
        rank, "Can't get local ip from NIC Interface. Please check whether TRTLLM_UCX_INTERFACE is set correctly.");
    return std::string{};
}

std::string getLocalIpByHostname()
std::string getLocalIpByHostname(int rank)
{
    char hostname[256]{};
    if (gethostname(hostname, sizeof(hostname)) == -1)
    {
        TLLM_LOG_ERROR(mpi::MpiComm::world().getRank(), "getLocalIpByHostname: Can't get hostname");
        TLLM_LOG_ERROR(rank, "getLocalIpByHostname: Can't get hostname");
        return std::string{};
    }

@@ -140,7 +157,7 @@ std::string getLocalIpByHostname()
    struct addrinfo* res = nullptr;
    if (getaddrinfo(hostname, nullptr, &hints, &res) != 0)
    {
        TLLM_LOG_WARNING(mpi::MpiComm::world().getRank(), "getLocalIpByHostname: Can't get address info for hostname");
        TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get address info for hostname");
        return std::string{};
    }

@@ -174,11 +191,11 @@ std::string getLocalIpByHostname()
    }

    freeaddrinfo(res);
    TLLM_LOG_WARNING(mpi::MpiComm::world().getRank(), "getLocalIpByHostname: Can't get local ip from hostname");
    TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get local ip from hostname");
    return std::string{};
}

std::string getLocalIpByRemoteOrHostName()
std::string getLocalIpByRemoteOrHostName(int rank)
{

    // Try IPv4
@@ -238,20 +255,20 @@ std::string getLocalIpByRemoteOrHostName()
    }

    // Try hostname
    return getLocalIpByHostname();
    return getLocalIpByHostname(rank);
}

static std::string getLocalIp()
static std::string getLocalIp(int rank)
{
    std::string ucxInterface = common::getEnvUCXInterface();
    std::string localIP = {};
    if (!ucxInterface.empty())
    {
        localIP = getLocalIpByNic(ucxInterface);
        localIP = getLocalIpByNic(ucxInterface, rank);
    }
    if (localIP.empty())
    {
        localIP = getLocalIpByRemoteOrHostName();
        localIP = getLocalIpByRemoteOrHostName(rank);
    }
    // check whether the localIP is valid
    if (localIP.empty())
@@ -278,10 +295,31 @@ std::optional<std::pair<std::string, int>> parse_zmq_endpoint(std::string const&
}

UcxConnectionManager::UcxConnectionManager()

{
    try
    {
        if (useMPI())
        {
            mRank = mpi::MpiComm::session().getRank();
            mWorldSize = mpi::MpiComm::session().getSize();
        }
        else
        {
            auto const worldPg = get_world_pg();
            if (worldPg)
            {
                mRank = worldPg->getRank();
                mWorldSize = worldPg->getSize();
                TLLM_LOG_DEBUG(mRank, "UCX using Torch process group - rank: %d, world size: %d", mRank, mWorldSize);
            }
            else
            {
                TLLM_LOG_DEBUG(mRank, "WARNING: Process group is null, defaulting to single process");
                mRank = 0;
                mWorldSize = 1;
            }
        }

        TLLM_CUDA_CHECK(cudaGetDevice(&mDevice));
        mUcxCtx = ucxx::createContext({{"RNDV_PIPELINE_ERROR_HANDLING", "y"}}, UCP_FEATURE_TAG);
        int device = mDevice;
@@ -302,7 +340,7 @@ UcxConnectionManager::UcxConnectionManager()

        mZmqRepSocket = zmq::socket_t(mZmqContext, zmq::socket_type::rep);
        mZmqRepSocket.set(zmq::sockopt::sndhwm, 1000);
        std::string localIp = getLocalIp();
        std::string localIp = getLocalIp(mRank);
        if (localIp.find(':') != std::string::npos)
        {
            // ipv6
@@ -310,62 +348,85 @@ UcxConnectionManager::UcxConnectionManager()

            localIp = "[" + localIp + "]";
        }
        TLLM_LOG_INFO(
            mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager localIp: %s", localIp.c_str());
        TLLM_LOG_INFO(mRank, "UcxConnectionManager::UcxConnectionManager localIp: %s", localIp.c_str());
        mZmqRepSocket.bind("tcp://" + localIp + ":*");
        mZmqRepEndpoint = mZmqRepSocket.get(zmq::sockopt::last_endpoint);
        TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s",
            mZmqRepEndpoint.c_str());
        TLLM_LOG_INFO(mRank, "UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s", mZmqRepEndpoint.c_str());
        auto parse_result = parse_zmq_endpoint(mZmqRepEndpoint);
        TLLM_CHECK_WITH_INFO(parse_result.has_value(), "Failed to parse ZMQ endpoint");
        auto [ip, port] = parse_result.value();
        TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager ip: %s, port: %d",
            ip.c_str(), port);
        TLLM_LOG_INFO(mRank, "UcxConnectionManager::UcxConnectionManager ip: %s, port: %d", ip.c_str(), port);

        SocketState socketState{static_cast<uint16_t>(port), ip};
        std::vector<executor::kv_cache::SocketState> socketStates(mpi::MpiComm::session().getSize());
        std::vector<executor::kv_cache::SocketState> socketStates(mWorldSize);

        if (mpi::MpiComm::session().getSize() > 1)
        {

            mpi::MpiComm::session().barrier();
            namespace su = executor::serialize_utils;

            std::ostringstream oStream;
            su::serialize(socketState, oStream);
            auto str = oStream.str();
            std::vector<char> buffer(str.begin(), str.end());
            std::vector<SizeType32> sizeofBuffer(mpi::MpiComm::session().getSize());
            SizeType32 bufferSize = buffer.size();
            mpi::MpiComm::session().allgather(&bufferSize, sizeofBuffer.data(), 1, mpi::MpiType::kINT32);
            SizeType32 recvBufferSize = std::accumulate(sizeofBuffer.begin(), sizeofBuffer.end(), 0);
            std::vector<char> recvBuffer(recvBufferSize);
            std::vector<int> displs(mpi::MpiComm::session().getSize());
            for (int r = 0; r < mpi::MpiComm::session().getSize(); r++)
            {
                displs[r] = (r == 0) ? 0 : (displs[r - 1] + sizeofBuffer[r - 1]);
            }
            mpi::MpiComm::session().allgatherv(buffer.data(), bufferSize, mpi::MpiType::kCHAR, recvBuffer.data(),
                sizeofBuffer, displs, mpi::MpiType::kCHAR);

            // deserialize
            for (int i = 0; i < mpi::MpiComm::session().getSize(); i++)
            {
                std::vector<char> serBuffer(
                    recvBuffer.begin() + displs[i], recvBuffer.begin() + (displs[i] + sizeofBuffer[i]));
                su::VectorWrapBuf<char> strbuf(serBuffer);
                std::istream is(&strbuf);
                socketStates[i] = su::deserialize<executor::kv_cache::SocketState>(is);
                TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " recv socketStates[%d]: %s", i,
                    socketStates[i].toString().c_str());
            }
        }
        else
        if (mWorldSize == 1)
        {
            socketStates[0] = socketState;
        }
        mCommState = CommState(socketStates, mpi::MpiComm::session().getRank());
        TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " ***** UCX mCommState: %s", mCommState.toString().c_str());
        else
        {
            namespace su = executor::serialize_utils;
            std::ostringstream oStream;
            su::serialize(socketState, oStream);
            auto serializedData = oStream.str();
            std::vector<char> buffer(serializedData.begin(), serializedData.end());
            std::vector<SizeType32> sizeofBuffer(mWorldSize);
            SizeType32 bufferSize = buffer.size();

            if (useMPI())
            {
                mpi::MpiComm::session().barrier();

                mpi::MpiComm::session().allgather(&bufferSize, sizeofBuffer.data(), 1, mpi::MpiType::kINT32);
                SizeType32 recvBufferSize = std::accumulate(sizeofBuffer.begin(), sizeofBuffer.end(), 0);
                std::vector<char> recvBuffer(recvBufferSize);
                std::vector<int> displs(mpi::MpiComm::session().getSize());
                for (int r = 0; r < mpi::MpiComm::session().getSize(); r++)
                {
                    displs[r] = (r == 0) ? 0 : (displs[r - 1] + sizeofBuffer[r - 1]);
                }
                mpi::MpiComm::session().allgatherv(buffer.data(), bufferSize, mpi::MpiType::kCHAR, recvBuffer.data(),
                    sizeofBuffer, displs, mpi::MpiType::kCHAR);

                // deserialize
                for (int i = 0; i < mpi::MpiComm::session().getSize(); i++)
                {
                    std::vector<char> serBuffer(
                        recvBuffer.begin() + displs[i], recvBuffer.begin() + (displs[i] + sizeofBuffer[i]));
                    su::VectorWrapBuf<char> strbuf(serBuffer);
                    std::istream is(&strbuf);
                    socketStates[i] = su::deserialize<executor::kv_cache::SocketState>(is);
                    TLLM_LOG_DEBUG(mRank, " recv socketStates[%d]: %s", i, socketStates[i].toString().c_str());
                }
            }
            else
            {
                auto const worldPg = get_world_pg();
                PgHelper pgh{worldPg};
                PGCHECK_THROW(worldPg->barrier());

                PGCHECK_THROW(pgh.allgather(&bufferSize, std::ref(sizeofBuffer), {}));
                SizeType32 recvBufferSize = std::accumulate(sizeofBuffer.begin(), sizeofBuffer.end(), 0);
                std::vector<char> recvBuffer(recvBufferSize);

                PGCHECK_THROW(pgh.allgatherv(std::ref(buffer), std::ref(recvBuffer), std::cref(sizeofBuffer), {}));

                // deserialize
                char* begin = reinterpret_cast<char*>(recvBuffer.data());
                for (int r = 0; r < mWorldSize; ++r)
                {
                    std::vector<char> serBuffer(begin, begin + sizeofBuffer[r]);
                    begin += sizeofBuffer[r];
                    su::VectorWrapBuf<char> strbuf(serBuffer);
                    std::istream is(&strbuf);
                    socketStates[r] = su::deserialize<executor::kv_cache::SocketState>(is);
                    TLLM_LOG_DEBUG(mRank, " recv socketStates[%d]: %s", r, socketStates[r].toString().c_str());
                }
            }
        }
        mCommState = CommState(socketStates, mRank);
        TLLM_LOG_DEBUG(mRank, " ***** UCX mCommState: %s", mCommState.toString().c_str());

        mZmqRepThread = std::thread(
            [this]()
@@ -417,7 +478,7 @@ UcxConnectionManager::UcxConnectionManager()

UcxConnectionManager::~UcxConnectionManager()
{
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "UcxConnectionManager::~UcxConnectionManager");
    TLLM_LOG_DEBUG(mRank, "UcxConnectionManager::~UcxConnectionManager");

    for (auto& worker : mWorkersPool)
    {
@@ -447,7 +508,7 @@ UcxConnectionManager::~UcxConnectionManager()
    mZmqRepSocket.close();

    mZmqContext.close();
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "END UcxConnectionManager::~UcxConnectionManager");
    TLLM_LOG_DEBUG(mRank, "END UcxConnectionManager::~UcxConnectionManager");
}

void UcxConnectionManager::addConnection(std::string const& workerAddress)
@@ -472,8 +533,7 @@ void UcxConnectionManager::addConnection(std::string const& workerAddress)
    }
    catch (std::exception const& e)
    {
        std::string error = "Error in addConnection(connRequest) for rank "
            + std::to_string(mpi::MpiComm::world().getRank()) + ": " + e.what();
        std::string error = "Error in addConnection(connRequest) for rank " + std::to_string(mRank) + ": " + e.what();
        TLLM_THROW(error);
    }
}
@@ -538,8 +598,8 @@ UcxConnection::ConnectionIdType UcxConnectionManager::addConnection(std::string
    }
    catch (std::exception const& e)
    {
        std::string error = "Error in addConnection(ip) for rank " + std::to_string(mpi::MpiComm::world().getRank())
            + " ip: " + ip + " port: " + std::to_string(port) + ": " + e.what();
        std::string error = "Error in addConnection(ip) for rank " + std::to_string(mRank) + " ip: " + ip
            + " port: " + std::to_string(port) + ": " + e.what();
        TLLM_THROW(error);
    }
}
@@ -570,20 +630,18 @@ Connection const* UcxConnectionManager::recvConnect(DataContext const& ctx, void
        = *reinterpret_cast<UcxConnection::ConnectionIdType*>(buffer.data() + size);
    std::scoped_lock lock(mConnectionsMutex, mConnectionFuturesMutex);
    TLLM_CHECK_WITH_INFO(mConnectionFutures.find(connectionId) != mConnectionFutures.end(),
        "connectionFuture not found In recvConnect connectionId : %lu , worldRank: %d", connectionId,
        mpi::MpiComm::world().getRank());
        "connectionFuture not found In recvConnect connectionId : %lu , worldRank: %d", connectionId, mRank);
    if (mConnectionFutures.at(connectionId).valid())
    {
        // wait for the connection to be created
        mConnectionFutures.at(connectionId).get();
    }
    TLLM_CHECK_WITH_INFO(mConnections.find(connectionId) != mConnections.end(),
        "Connection not found In recvConnect connectionId: %lu , worldRank: %d", connectionId,
        mpi::MpiComm::world().getRank());
        "Connection not found In recvConnect connectionId: %lu , worldRank: %d", connectionId, mRank);

    TLLM_CHECK(!mConnections[connectionId]->isFromRequester());

    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "recvConnect connectionId: %lu , sendIDData:%lu", connectionId,
    TLLM_LOG_DEBUG(mRank, "recvConnect connectionId: %lu , sendIDData:%lu", connectionId,
        *reinterpret_cast<uint64_t*>(buffer.data()));

    return mConnections[connectionId].get();

@@ -55,6 +55,8 @@ private:
    std::mutex mAddressToConnectionIdMutex;
    CommState mCommState;
    int mDevice;
    int mRank;
    int mWorldSize;
    std::atomic<UcxConnection::ConnectionIdType> mConnectionIdCounter{1};
    zmq::context_t mZmqContext;
    zmq::socket_t mZmqRepSocket;
@@ -78,6 +80,11 @@ public:
    Connection const* recvConnect(DataContext const& ctx, void* data, size_t size) override;
    std::vector<Connection const*> getConnections(CommState const& state) override;
    [[nodiscard]] CommState const& getCommState() const override;

    [[nodiscard]] int getRank() const
    {
        return mRank;
    }
};

#if defined(__clang__)

@ -15,6 +15,7 @@ set(SRCS
|
||||
executor/executor.cpp
|
||||
executor/executorConfig.cpp
|
||||
executor/request.cpp
|
||||
process_group/bindings.cpp
|
||||
runtime/bindings.cpp
|
||||
runtime/hostfunc.cpp
|
||||
runtime/moeBindings.cpp
|
||||
@ -46,7 +47,8 @@ target_link_libraries(
|
||||
torch_python
|
||||
${CUDA_DRV_LIB}
|
||||
${CUDA_NVML_LIB}
|
||||
th_common)
|
||||
th_common
|
||||
pg_utils)
|
||||
target_compile_definitions(
|
||||
${TRTLLM_NB_MODULE} PUBLIC TRTLLM_NB_MODULE=${TRTLLM_NB_MODULE}
|
||||
PYBIND11_DETAILED_ERROR_MESSAGES=1)
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "cacheTransceiver.h"
|
||||
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/common/bindingUtils.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/nanobind/common/customCasters.h"
|
||||
#include <ATen/ATen.h>
|
||||
@ -28,6 +29,7 @@
|
||||
#include <nanobind/stl/vector.h>
|
||||
#include <nanobind/trampoline.h>
|
||||
#include <torch/extension.h>
|
||||
#include <typeinfo>
|
||||
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
|
||||
@ -103,6 +105,79 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
|
||||
nb::arg("tokens_per_block"), nb::arg("world_config"), nb::arg("attention_layer_num_per_pp"),
|
||||
nb::arg("dtype"), nb::arg("attention_type"), nb::arg("cache_transceiver_config") = std::nullopt);
|
||||
|
||||
nb::class_<tb::CacheTransceiverComm>(m, "CacheTransceiverComm")
|
||||
.def(
|
||||
"__init__",
|
||||
[](tb::CacheTransceiverComm* self, nb::object pg_obj, std::string pybind11_abi)
|
||||
{
|
||||
new (self) tb::CacheTransceiverComm(
|
||||
common::get_intrusive_ptr<c10d::ProcessGroup, nb::python_error>(pg_obj.ptr(), pybind11_abi));
|
||||
},
|
||||
nb::arg("process_group"), nb::arg("pybind11_abi"))
|
||||
.def("get_rank", &tb::CacheTransceiverComm::getRank)
|
||||
.def("get_size", &tb::CacheTransceiverComm::getSize)
|
||||
.def("split", &tb::CacheTransceiverComm::split, nb::arg("color"), nb::arg("key"))
|
||||
.def(
|
||||
"allgather",
|
||||
[](tb::CacheTransceiverComm const& self, int64_t input)
|
||||
{
|
||||
std::vector<int64_t> out(static_cast<size_t>(self.getSize()));
|
||||
c10d::AllgatherOptions options;
|
||||
bool ok = self.allgather(input, std::ref(out), options);
|
||||
return nb::make_tuple(ok, out);
|
||||
},
|
||||
nb::arg("input"))
|
||||
.def(
|
||||
"allgather",
|
||||
[](tb::CacheTransceiverComm const& self, double input)
|
||||
{
|
||||
std::vector<double> out(static_cast<size_t>(self.getSize()));
|
||||
c10d::AllgatherOptions options;
|
||||
bool ok = self.allgather(input, std::ref(out), options);
|
||||
return nb::make_tuple(ok, out);
|
||||
},
|
||||
nb::arg("input"))
|
||||
.def(
|
||||
"allgather",
|
||||
[](tb::CacheTransceiverComm const& self, char input)
|
||||
{
|
||||
std::vector<char> out(static_cast<size_t>(self.getSize()));
|
||||
c10d::AllgatherOptions options;
|
||||
bool ok = self.allgather(input, std::ref(out), options);
|
||||
return nb::make_tuple(ok, out);
|
||||
},
|
||||
nb::arg("input"))
|
||||
.def(
|
||||
"allgatherv",
|
||||
[](tb::CacheTransceiverComm const& self, std::vector<int64_t> input, std::vector<int> const& sizes)
|
||||
{
|
||||
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
|
||||
std::vector<int64_t> output(total_size);
|
||||
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
|
||||
return nb::make_tuple(ok, output);
|
||||
},
|
||||
nb::arg("input"), nb::arg("sizes"))
|
||||
.def(
|
||||
"allgatherv",
|
||||
[](tb::CacheTransceiverComm const& self, std::vector<double> input, std::vector<int> const& sizes)
|
||||
{
|
||||
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
|
||||
std::vector<double> output(total_size);
|
||||
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
|
||||
return nb::make_tuple(ok, output);
|
||||
},
|
||||
nb::arg("input"), nb::arg("sizes"))
|
||||
.def(
|
||||
"allgatherv",
|
||||
[](tb::CacheTransceiverComm const& self, std::vector<char> input, std::vector<int> const& sizes)
|
||||
{
|
||||
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
|
||||
std::vector<char> output(total_size);
|
||||
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
|
||||
return nb::make_tuple(ok, output);
|
||||
},
|
||||
nb::arg("input"), nb::arg("sizes"));
|
||||
|
||||
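The chained `.def` calls above define the Python-facing surface of `CacheTransceiverComm`. A minimal, hypothetical sketch of driving it from Python follows; the import path, the way the `c10d.ProcessGroup` is obtained, and the source of the `pybind11_abi` string are assumptions, while the method names, argument names, and return shapes come from the bindings themselves.

```python
# Hypothetical usage sketch of the CacheTransceiverComm binding (module path and ABI source assumed).
import torch
import torch.distributed as dist
from tensorrt_llm.bindings.internal import batch_manager  # assumed module path

dist.init_process_group(backend="nccl")
pg = dist.distributed_c10d._get_default_group()     # a c10d.ProcessGroup
abi = getattr(torch._C, "_PYBIND11_BUILD_ABI", "")  # assumed ABI-tag source

comm = batch_manager.CacheTransceiverComm(process_group=pg, pybind11_abi=abi)

# allgather returns (ok, values) with one element per rank.
ok, ranks = comm.allgather(comm.get_rank())

# allgatherv flattens variable-length per-rank chunks; sizes gives each rank's count.
ok, flat = comm.allgatherv([1, 2, 3], sizes=[3] * comm.get_size())

# split(color, key) derives a sub-communicator, mirroring the MPI-style API.
sub = comm.split(color=comm.get_rank() % 2, key=comm.get_rank())
```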
nb::class_<tb::kv_cache_manager::CacheTransBufferManager>(m, "CacheTransBufferManager")
|
||||
.def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), nb::arg("cache_manager"),
|
||||
nb::arg("max_num_tokens") = std::nullopt)
|
||||
|
||||
@ -40,6 +40,7 @@
|
||||
#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/nanobind/common/tllmExceptions.h"
|
||||
#include "tensorrt_llm/nanobind/executor/bindings.h"
|
||||
#include "tensorrt_llm/nanobind/process_group/bindings.h"
|
||||
#include "tensorrt_llm/nanobind/runtime/bindings.h"
|
||||
#include "tensorrt_llm/nanobind/testing/modelSpecBinding.h"
|
||||
#include "tensorrt_llm/nanobind/thop/bindings.h"
|
||||
@ -125,6 +126,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
|
||||
// Create submodule for executor bindings.
|
||||
auto mExecutor = m.def_submodule("executor", "Executor bindings");
|
||||
auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime");
|
||||
auto mInternalProcessGroup = mInternal.def_submodule("process_group", "PyTorch ProcessGroup internal bindings");
|
||||
auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
|
||||
auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
|
||||
auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
|
||||
@ -485,6 +487,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
|
||||
.def_prop_ro("pinned", &tr::MemoryCounters::getPinned)
|
||||
.def_prop_ro("uvm", &tr::MemoryCounters::getUVM);
|
||||
|
||||
tensorrt_llm::nanobind::process_group::initBindings(mInternalProcessGroup);
|
||||
tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime);
|
||||
tensorrt_llm::nanobind::testing::initBindings(mInternalTesting);
|
||||
tpb::initBindings(mInternalBatchManager);
|
||||
|
||||
cpp/tensorrt_llm/nanobind/process_group/bindings.cpp (new file, 43 lines)
@ -0,0 +1,43 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "bindings.h"
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include "tensorrt_llm/common/bindingUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/pgUtils.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
namespace tensorrt_llm::nanobind::process_group
|
||||
{
|
||||
|
||||
void initBindings(nb::module_& m)
|
||||
{
|
||||
|
||||
m.def("init_pg",
|
||||
[](nb::object world_pg_obj, nb::object local_pg_obj, std::string const& pybind11_abi)
|
||||
{
|
||||
using Pg = c10d::ProcessGroup;
|
||||
using E = nb::python_error;
|
||||
|
||||
pg_utils::init_pg(common::get_intrusive_ptr<Pg, E>(world_pg_obj.ptr(), pybind11_abi),
|
||||
common::get_intrusive_ptr<Pg, E>(local_pg_obj.ptr(), pybind11_abi));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::nanobind::process_group
|
||||
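Since `init_pg` is what hands the world and node-local `ProcessGroup`s to the C++ runtime, a hypothetical call sequence from Python is sketched below. The top-level module path, the way the local group is formed, and the ABI-string source are assumptions; the function name and argument order come from the binding above.

```python
# Hypothetical sketch: registering world/local ProcessGroups with the runtime via init_pg.
import torch
import torch.distributed as dist
from tensorrt_llm.bindings.internal import process_group  # assumed module path

dist.init_process_group(backend="nccl")
world_pg = dist.distributed_c10d._get_default_group()

# Assumption: a single-node run, so the node-local group equals the world group.
local_pg = world_pg

abi = getattr(torch._C, "_PYBIND11_BUILD_ABI", "")  # assumed ABI-tag source
process_group.init_pg(world_pg, local_pg, abi)
```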
cpp/tensorrt_llm/nanobind/process_group/bindings.h (new file, 26 lines)
@ -0,0 +1,26 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
namespace nb = nanobind;
|
||||
|
||||
namespace tensorrt_llm::nanobind::process_group
|
||||
{
|
||||
void initBindings(nb::module_& m);
|
||||
} // namespace tensorrt_llm::nanobind::process_group
|
||||
@ -14,6 +14,7 @@ set(SRCS
|
||||
executor/executor.cpp
|
||||
executor/executorConfig.cpp
|
||||
executor/request.cpp
|
||||
process_group/bindings.cpp
|
||||
runtime/bindings.cpp
|
||||
runtime/hostfunc.cpp
|
||||
common/tllmExceptions.cpp
|
||||
@ -48,7 +49,8 @@ target_link_libraries(
|
||||
torch_python
|
||||
${CUDA_DRV_LIB}
|
||||
${CUDA_NVML_LIB}
|
||||
th_common)
|
||||
th_common
|
||||
pg_utils)
|
||||
target_compile_definitions(
|
||||
${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
|
||||
PYBIND11_DETAILED_ERROR_MESSAGES=1)
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "cacheTransceiver.h"
|
||||
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/common/bindingUtils.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/functional.h>
|
||||
@ -26,6 +27,7 @@
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include <torch/extension.h>
|
||||
#include <typeinfo>
|
||||
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
|
||||
@ -99,6 +101,79 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
|
||||
py::arg("tokens_per_block"), py::arg("world_config"), py::arg("attention_layer_num_per_pp"),
|
||||
py::arg("dtype"), py::arg("attention_type"), py::arg("cache_transceiver_config") = std::nullopt);
|
||||
|
||||
py::classh<tb::CacheTransceiverComm>(m, "CacheTransceiverComm")
|
||||
.def(py::init(
|
||||
[](py::object pg_obj, std::string pybind11_abi)
|
||||
{
|
||||
return new CacheTransceiverComm(
|
||||
common::get_intrusive_ptr<c10d::ProcessGroup, py::error_already_set>(
|
||||
pg_obj.ptr(), pybind11_abi));
|
||||
}),
|
||||
py::arg("process_group"), py::arg("pybind11_abi"))
|
||||
.def("get_rank", &tb::CacheTransceiverComm::getRank)
|
||||
.def("get_size", &tb::CacheTransceiverComm::getSize)
|
||||
.def("split", &tb::CacheTransceiverComm::split, py::arg("color"), py::arg("key"))
|
||||
.def(
|
||||
"allgather",
|
||||
[](tb::CacheTransceiverComm const& self, int64_t input)
|
||||
{
|
||||
std::vector<int64_t> out(static_cast<size_t>(self.getSize()));
|
||||
c10d::AllgatherOptions options;
|
||||
bool ok = self.allgather(input, std::ref(out), options);
|
||||
return py::make_tuple(ok, out);
|
||||
},
|
||||
py::arg("input"))
|
||||
.def(
|
||||
"allgather",
|
||||
[](tb::CacheTransceiverComm const& self, double input)
|
||||
{
|
||||
std::vector<double> out(static_cast<size_t>(self.getSize()));
|
||||
c10d::AllgatherOptions options;
|
||||
bool ok = self.allgather(input, std::ref(out), options);
|
||||
return py::make_tuple(ok, out);
|
||||
},
|
||||
py::arg("input"))
|
||||
.def(
|
||||
"allgather",
|
||||
[](tb::CacheTransceiverComm const& self, char input)
|
||||
{
|
||||
std::vector<char> out(static_cast<size_t>(self.getSize()));
|
||||
c10d::AllgatherOptions options;
|
||||
bool ok = self.allgather(input, std::ref(out), options);
|
||||
return py::make_tuple(ok, out);
|
||||
},
|
||||
py::arg("input"))
|
||||
.def(
|
||||
"allgatherv",
|
||||
[](tb::CacheTransceiverComm const& self, std::vector<int64_t> input, std::vector<int> const& sizes)
|
||||
{
|
||||
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
|
||||
std::vector<int64_t> output(total_size);
|
||||
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
|
||||
return py::make_tuple(ok, output);
|
||||
},
|
||||
py::arg("input"), py::arg("sizes"))
|
||||
.def(
|
||||
"allgatherv",
|
||||
[](tb::CacheTransceiverComm const& self, std::vector<double> input, std::vector<int> const& sizes)
|
||||
{
|
||||
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
|
||||
std::vector<double> output(total_size);
|
||||
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
|
||||
return py::make_tuple(ok, output);
|
||||
},
|
||||
py::arg("input"), py::arg("sizes"))
|
||||
.def(
|
||||
"allgatherv",
|
||||
[](tb::CacheTransceiverComm const& self, std::vector<char> input, std::vector<int> const& sizes)
|
||||
{
|
||||
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
|
||||
std::vector<char> output(total_size);
|
||||
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
|
||||
return py::make_tuple(ok, output);
|
||||
},
|
||||
py::arg("input"), py::arg("sizes"));
|
||||
|
||||
py::class_<tb::kv_cache_manager::CacheTransBufferManager>(m, "CacheTransBufferManager")
|
||||
.def(py::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), py::arg("cache_manager"),
|
||||
py::arg("max_num_tokens") = std::nullopt)
|
||||
|
||||
@ -34,6 +34,7 @@
|
||||
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/pybind/common/tllmExceptions.h"
|
||||
#include "tensorrt_llm/pybind/executor/bindings.h"
|
||||
#include "tensorrt_llm/pybind/process_group/bindings.h"
|
||||
#include "tensorrt_llm/pybind/runtime/bindings.h"
|
||||
#include "tensorrt_llm/pybind/testing/modelSpecBinding.h"
|
||||
#include "tensorrt_llm/pybind/thop/bindings.h"
|
||||
@ -117,6 +118,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
// Create submodule for executor bindings.
|
||||
auto mExecutor = m.def_submodule("executor", "Executor bindings");
|
||||
auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime");
|
||||
auto mInternalProcessGroup = mInternal.def_submodule("process_group", "PyTorch ProcessGroup internal bindings");
|
||||
auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
|
||||
auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
|
||||
auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
|
||||
@ -473,6 +475,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_property_readonly("pinned", &tr::MemoryCounters::getPinned)
|
||||
.def_property_readonly("uvm", &tr::MemoryCounters::getUVM);
|
||||
|
||||
tensorrt_llm::pybind::process_group::initBindings(mInternalProcessGroup);
|
||||
tensorrt_llm::pybind::runtime::initBindings(mInternalRuntime);
|
||||
tensorrt_llm::pybind::testing::initBindings(mInternalTesting);
|
||||
tpb::initBindings(mInternalBatchManager);
|
||||
|
||||
cpp/tensorrt_llm/pybind/process_group/bindings.cpp (new file, 40 lines)
@ -0,0 +1,40 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "tensorrt_llm/common/bindingUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/pgUtils.h"
|
||||
|
||||
namespace tensorrt_llm::pybind::process_group
|
||||
{
|
||||
|
||||
void initBindings(py::module_& m)
|
||||
{
|
||||
|
||||
m.def("init_pg",
|
||||
[](py::object world_pg_obj, py::object local_pg_obj, std::string const& pybind11_abi)
|
||||
{
|
||||
using Pg = c10d::ProcessGroup;
|
||||
using E = py::error_already_set;
|
||||
|
||||
pg_utils::init_pg(common::get_intrusive_ptr<Pg, E>(world_pg_obj.ptr(), pybind11_abi),
|
||||
common::get_intrusive_ptr<Pg, E>(local_pg_obj.ptr(), pybind11_abi));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::pybind::process_group
|
||||
cpp/tensorrt_llm/pybind/process_group/bindings.h (new file, 27 lines)
@ -0,0 +1,27 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace tensorrt_llm::pybind::process_group
|
||||
{
|
||||
void initBindings(py::module_& m);
|
||||
} // namespace tensorrt_llm::pybind::process_group
|
||||
@ -111,3 +111,6 @@ if(NOT WIN32)
|
||||
target_link_libraries(runtime_src PUBLIC libnuma::libnuma)
|
||||
target_link_options(runtime_src PUBLIC ${CONAN_LIBNUMA_LINK_OPTIONS})
|
||||
endif()
|
||||
|
||||
# Add utils subdirectory for pg_utils module
|
||||
add_subdirectory(utils)
|
||||
|
||||
cpp/tensorrt_llm/runtime/utils/CMakeLists.txt (new file, 29 lines)
@ -0,0 +1,29 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain a copy of
|
||||
# the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
|
||||
# Create the pg_utils shared library
|
||||
add_library(pg_utils SHARED pgUtils.cpp)
|
||||
|
||||
set_property(TARGET pg_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
# Include directories
|
||||
target_include_directories(pg_utils PUBLIC ${PROJECT_SOURCE_DIR}/include
|
||||
${TORCH_INCLUDE_DIRS})
|
||||
|
||||
target_link_libraries(pg_utils PUBLIC ${TORCH_LIBRARIES})
|
||||
|
||||
# Find torch_python
|
||||
find_library(TORCH_PYTHON_LIB torch_python REQUIRED
|
||||
HINTS ${TORCH_INSTALL_PREFIX}/lib)
|
||||
@ -222,6 +222,7 @@ void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
|
||||
|
||||
void MpiComm::barrier() const
|
||||
{
|
||||
couldUseMPI();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Barrier(mComm));
|
||||
#else
|
||||
@ -267,6 +268,7 @@ size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dty
|
||||
|
||||
std::unique_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const
|
||||
{
|
||||
couldUseMPI();
|
||||
std::unique_ptr<MpiRequest> r = std::make_unique<MpiRequest>();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest);
|
||||
@ -278,11 +280,13 @@ std::unique_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiTy
|
||||
|
||||
std::unique_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
|
||||
{
|
||||
couldUseMPI();
|
||||
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
|
||||
}
|
||||
|
||||
void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
|
||||
{
|
||||
couldUseMPI();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm);
|
||||
#else
|
||||
@ -292,12 +296,14 @@ void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
|
||||
|
||||
void MpiComm::bcast(runtime::IBuffer& buf, int root) const
|
||||
{
|
||||
couldUseMPI();
|
||||
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
|
||||
}
|
||||
|
||||
std::unique_ptr<MpiRequest> MpiComm::sendAsync(
|
||||
void const* buffer, size_t size, MpiType dtype, int dest, MpiTag tag) const
|
||||
{
|
||||
couldUseMPI();
|
||||
TLLM_LOG_DEBUG("start MPI_Isend with dest %d, tag %d, size %d", dest, static_cast<int>(tag), size);
|
||||
std::unique_ptr<MpiRequest> r = std::make_unique<MpiRequest>();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
@ -311,11 +317,13 @@ std::unique_ptr<MpiRequest> MpiComm::sendAsync(
|
||||
|
||||
std::unique_ptr<MpiRequest> MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, MpiTag tag) const
|
||||
{
|
||||
couldUseMPI();
|
||||
return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
|
||||
}
|
||||
|
||||
void MpiComm::sendRawTag(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
|
||||
{
|
||||
couldUseMPI();
|
||||
TLLM_LOG_DEBUG("start MPI_Send with dest %d, tag %d, size %d", dest, tag, size);
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm);
|
||||
@ -327,16 +335,19 @@ void MpiComm::sendRawTag(void const* buffer, size_t size, MpiType dtype, int des
|
||||
|
||||
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, MpiTag tag) const
|
||||
{
|
||||
couldUseMPI();
|
||||
sendRawTag(buffer, size, dtype, dest, static_cast<int>(tag));
|
||||
}
|
||||
|
||||
void MpiComm::send(runtime::IBuffer const& buf, int dest, MpiTag tag) const
|
||||
{
|
||||
couldUseMPI();
|
||||
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
|
||||
}
|
||||
|
||||
MPI_Status MpiComm::recvRawTag(void* buffer, size_t size, MpiType dtype, int source, int tag) const
|
||||
{
|
||||
couldUseMPI();
|
||||
TLLM_LOG_DEBUG("start MPI_Recv with source %d, tag %d, size %d", source, tag, size);
|
||||
MPI_Status status{};
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
@ -350,11 +361,13 @@ MPI_Status MpiComm::recvRawTag(void* buffer, size_t size, MpiType dtype, int sou
|
||||
|
||||
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, MpiTag tag) const
|
||||
{
|
||||
couldUseMPI();
|
||||
return recvRawTag(buffer, size, dtype, source, static_cast<int>(tag));
|
||||
}
|
||||
|
||||
MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, MpiTag tag) const
|
||||
{
|
||||
couldUseMPI();
|
||||
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
|
||||
}
|
||||
|
||||
@ -382,6 +395,7 @@ MpiComm const& MpiComm::setRawSessionByFortran(int64_t fortranHandle)
|
||||
|
||||
void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
|
||||
{
|
||||
couldUseMPI();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
|
||||
#else
|
||||
@ -391,6 +405,7 @@ void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType d
|
||||
|
||||
void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
|
||||
{
|
||||
couldUseMPI();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
|
||||
#else
|
||||
@ -401,6 +416,7 @@ void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType d
|
||||
void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
|
||||
std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
|
||||
{
|
||||
couldUseMPI();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
|
||||
getMpiDtype(recvtype), mComm));
|
||||
@ -450,6 +466,7 @@ bool MpiComm::iprobe(int source, MpiTag tag, MPI_Status* status) const
|
||||
|
||||
void MpiComm::recvPoll(int source, MpiTag tag, int periodMs) const
|
||||
{
|
||||
couldUseMPI();
|
||||
MPI_Status status;
|
||||
while (!iprobe(source, tag, &status))
|
||||
{
|
||||
|
||||
cpp/tensorrt_llm/runtime/utils/pgUtils.cpp (new file, 44 lines)
@ -0,0 +1,44 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/runtime/utils/pgUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
namespace tensorrt_llm::pg_utils
|
||||
{
|
||||
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> pg_world;
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> pg_local;
|
||||
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> get_world_pg()
|
||||
{
|
||||
return pg_world;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> get_local_pg()
|
||||
{
|
||||
return pg_local;
|
||||
}
|
||||
|
||||
void init_pg(c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_world,
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_local)
|
||||
{
|
||||
TLLM_LOG_DEBUG(process_group_world->getRank(), "Init process group on rank %d", process_group_world->getRank());
|
||||
pg_world = process_group_world;
|
||||
pg_local = process_group_local;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::pg_utils
|
||||
@ -97,8 +97,9 @@ add_library(
|
||||
finegrained_mixed_dtype_gemm_thop.cpp
|
||||
tinygemm2.cpp)
|
||||
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
target_link_libraries(th_common PRIVATE ${TORCH_LIBRARIES} th_utils
|
||||
${Python3_LIBRARIES} ${SHARED_TARGET})
|
||||
target_link_libraries(
|
||||
th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}
|
||||
${SHARED_TARGET} pg_utils)
|
||||
|
||||
if(USING_OSS_CUTLASS_LOW_LATENCY_GEMM)
|
||||
target_compile_definitions(th_common
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "tensorrt_llm/common/opUtils.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/pgUtils.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
@ -28,8 +29,12 @@
|
||||
#include <vector>
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
#include <nccl.h>
|
||||
#include <torch/csrc/distributed/c10d/FileStore.hpp>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
using tensorrt_llm::pg_utils::PgHelper;
|
||||
|
||||
namespace torch_ext
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
@ -115,6 +120,95 @@ private:
|
||||
std::shared_ptr<ncclComm_t> mNcclComm;
|
||||
};
|
||||
|
||||
class AllgatherPgOp
|
||||
{
|
||||
public:
|
||||
AllgatherPgOp(std::set<int> group, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
|
||||
: mGroup(std::move(group))
|
||||
, mProcessGroup(process_group_)
|
||||
{
|
||||
}
|
||||
|
||||
~AllgatherPgOp() = default;
|
||||
|
||||
int initialize() noexcept
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, mProcessGroup->getRank());
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::pair<torch::Tensor, c10::intrusive_ptr<c10d::Work>> run(
|
||||
torch::Tensor input, torch::optional<torch::List<int64_t>> sizes, bool coalescing = false)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mProcessGroup.get() != nullptr, "mProcessGroup should be initialized before use");
|
||||
std::vector<int64_t> outputShape = input.sizes().vec();
|
||||
if (sizes.has_value())
|
||||
{
|
||||
outputShape[0] = std::accumulate(sizes.value().begin(), sizes.value().end(), 0, std::plus<>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
outputShape[0] *= mGroup.size();
|
||||
}
|
||||
auto output = torch::empty(outputShape, input.options());
|
||||
|
||||
PgHelper pgh{mProcessGroup};
|
||||
|
||||
c10::intrusive_ptr<c10d::Work> work;
|
||||
if (sizes.has_value())
|
||||
{
|
||||
std::vector inputs{input};
|
||||
int64_t split_offset = 0;
|
||||
std::vector<torch::Tensor> outputTensors{};
|
||||
for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
|
||||
{
|
||||
auto split_size = sizes.value()[root];
|
||||
outputTensors.push_back(
|
||||
output.index({torch::indexing::Slice(split_offset, split_offset + split_size)}));
|
||||
split_offset += split_size;
|
||||
}
|
||||
std::vector<std::vector<torch::Tensor>> outputs{outputTensors};
|
||||
work = mProcessGroup->allgather(outputs, inputs, {});
|
||||
}
|
||||
else
|
||||
{
|
||||
work = pgh.allgather(input, output, {});
|
||||
}
|
||||
|
||||
if (!coalescing)
|
||||
{
|
||||
PGCHECK_THROW_WITH_INFO(work, "ProcessGroup: allgather");
|
||||
return {output, nullptr};
|
||||
}
|
||||
|
||||
return {output, work};
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> run_list(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes)
|
||||
{
|
||||
std::vector<torch::Tensor> output_list;
|
||||
std::vector<c10::intrusive_ptr<c10d::Work>> work_list;
|
||||
output_list.reserve(input_list.size());
|
||||
work_list.reserve(input_list.size());
|
||||
mProcessGroup->startCoalescing(c10::DeviceType::CUDA);
|
||||
for (auto const& input : input_list)
|
||||
{
|
||||
auto [output, work] = run(input, sizes, true);
|
||||
output_list.push_back(output);
|
||||
work_list.push_back(work); // Hold work objects (input & output tensors) until endCoalescing wait finished
|
||||
}
|
||||
if (auto work = mProcessGroup->endCoalescing(c10::DeviceType::CUDA))
|
||||
{
|
||||
PGCHECK_THROW_WITH_INFO(work, "ProcessGroup: allgather, end coalescing");
|
||||
}
|
||||
return output_list;
|
||||
}
|
||||
|
||||
private:
|
||||
std::set<int> mGroup;
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> mProcessGroup;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
@ -136,6 +230,24 @@ torch::Tensor allgather(torch::Tensor input, torch::optional<torch::List<int64_t
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
torch::Tensor allgather_pg(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes,
|
||||
torch::List<int64_t> group_, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
std::set<int> group;
|
||||
for (int64_t rank : group_)
|
||||
{
|
||||
group.insert(static_cast<int>(rank));
|
||||
}
|
||||
AllgatherPgOp op(group, process_group_);
|
||||
op.initialize();
|
||||
auto [output, _] = op.run(input, sizes);
|
||||
return output;
|
||||
#else
|
||||
return input;
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> allgather_list(
|
||||
torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes, torch::List<int64_t> group_)
|
||||
{
|
||||
@ -154,16 +266,42 @@ std::vector<torch::Tensor> allgather_list(
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> allgather_list_pg(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes,
|
||||
torch::List<int64_t> group_, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
std::set<int> group;
|
||||
for (int64_t rank : group_)
|
||||
{
|
||||
group.insert(static_cast<int>(rank));
|
||||
}
|
||||
AllgatherPgOp op(group, process_group_);
|
||||
op.initialize();
|
||||
auto output_list = op.run_list(input_list, sizes);
|
||||
return output_list;
|
||||
#else
|
||||
return input_list.vec();
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
} // namespace torch_ext
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def("allgather(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
|
||||
m.def(
|
||||
"allgather_pg(Tensor input, SymInt[]? sizes, int[] group, __torch__.torch.classes.c10d.ProcessGroup "
|
||||
"process_group) -> Tensor");
|
||||
m.def("allgather_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
|
||||
m.def(
|
||||
"allgather_list_pg(Tensor[] input_list, SymInt[]? sizes, int[] group, "
|
||||
"__torch__.torch.classes.c10d.ProcessGroup process_group) -> Tensor[]");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("allgather", &torch_ext::allgather);
|
||||
m.impl("allgather_pg", &torch_ext::allgather_pg);
|
||||
m.impl("allgather_list", &torch_ext::allgather_list);
|
||||
m.impl("allgather_list_pg", &torch_ext::allgather_list_pg);
|
||||
}
|
||||
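These schemas register the ProcessGroup-backed variants alongside the existing NCCL-communicator ops. A hedged sketch of invoking them through `torch.ops` is shown below; the way the `ProcessGroup` handle is obtained and passed is an assumption, while the signatures come from the `TORCH_LIBRARY_FRAGMENT` block above.

```python
# Sketch of the ProcessGroup-backed allgather custom ops (CUDA tensors assumed).
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
pg = dist.distributed_c10d._get_default_group()  # assumed way to obtain the ProcessGroup
group = list(range(dist.get_world_size()))

x = torch.randn(4, 8, device="cuda")

# Even gather: dim 0 of the output is 4 * world_size.
out = torch.ops.trtllm.allgather_pg(x, None, group, pg)

# Uneven gather: per-rank row counts; the output has sum(sizes) rows.
sizes = [4] * dist.get_world_size()
out_v = torch.ops.trtllm.allgather_pg(x, sizes, group, pg)

# Coalesced gather over several tensors in one startCoalescing/endCoalescing window.
outs = torch.ops.trtllm.allgather_list_pg([x, x.clone()], None, group, pg)
```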
|
||||
@ -30,6 +30,7 @@
|
||||
#include "tensorrt_llm/runtime/mcastDeviceMemory.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/pgUtils.h"
|
||||
#include "tensorrt_llm/thop/fp4Quantize.h"
|
||||
#include "tensorrt_llm/thop/fp8Op.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
@ -37,7 +38,13 @@
|
||||
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
#include <ATen/cuda/EmptyTensor.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <nccl.h>
|
||||
#include <torch/csrc/distributed/c10d/FileStore.hpp>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
|
||||
#include <torch/csrc/distributed/c10d/Store.hpp>
|
||||
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
#include <nvml.h>
|
||||
#include <torch/extension.h>
|
||||
@ -50,6 +57,9 @@
|
||||
using tensorrt_llm::kernels::AllReduceFusionOp;
|
||||
using tensorrt_llm::kernels::AllReduceStrategyType;
|
||||
using tensorrt_llm::mpi::MpiTag;
|
||||
using tensorrt_llm::pg_utils::get_world_pg;
|
||||
using tensorrt_llm::pg_utils::get_local_pg;
|
||||
using tensorrt_llm::pg_utils::PgHelper;
|
||||
|
||||
namespace torch_ext
|
||||
{
|
||||
@ -59,6 +69,14 @@ namespace torch_ext
|
||||
namespace
|
||||
{
|
||||
|
||||
template <class... Ts>
|
||||
struct overloaded : Ts...
|
||||
{
|
||||
using Ts::operator()...;
|
||||
};
|
||||
template <class... Ts>
|
||||
overloaded(Ts...) -> overloaded<Ts...>;
|
||||
|
||||
class NvmlManager
|
||||
{
|
||||
public:
|
||||
@ -141,6 +159,79 @@ std::set<int> getLocalGroup(std::set<int> const& group)
|
||||
return localGroup;
|
||||
}
|
||||
|
||||
std::set<int> getLocalGroupTorch(std::set<int> const& group)
|
||||
{
|
||||
auto const worldPg = get_world_pg();
|
||||
auto const myRank = worldPg->getRank();
|
||||
auto const localPg = get_local_pg();
|
||||
auto const myLocalRank = localPg->getRank();
|
||||
auto const localSize = static_cast<uint32_t>(localPg->getSize());
|
||||
|
||||
PgHelper pgh_local{localPg};
|
||||
PgHelper pgh_world{worldPg}; // for p2p
|
||||
|
||||
std::vector<int32_t> ranks(localSize, -1);
|
||||
std::vector<int32_t> localRanks(localSize, -1);
|
||||
|
||||
if (group.size() >= localSize)
|
||||
{
|
||||
PGCHECK_THROW(pgh_local.allgather(&myRank, ref(ranks), {}));
|
||||
PGCHECK_THROW(pgh_local.allgather(&myLocalRank, ref(localRanks), {}));
|
||||
}
|
||||
else
|
||||
{
|
||||
int tag = static_cast<int>(MpiTag::kDefault);
|
||||
|
||||
if (myRank == *group.begin())
|
||||
{
|
||||
// Leader: gather from peers (world ranks), then broadcast full localSize arrays.
|
||||
size_t cnt = 0;
|
||||
ranks[cnt++] = myRank;
|
||||
int tmp;
|
||||
for (auto it = std::next(group.begin()); it != group.end(); ++it)
|
||||
{
|
||||
PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
|
||||
ranks[cnt++] = tmp;
|
||||
}
|
||||
for (auto it = std::next(group.begin()); it != group.end(); ++it)
|
||||
{
|
||||
PGCHECK_THROW(pgh_world.send(ref(ranks), *it, tag));
|
||||
}
|
||||
|
||||
cnt = 0;
|
||||
localRanks[cnt++] = myLocalRank;
|
||||
for (auto it = std::next(group.begin()); it != group.end(); ++it)
|
||||
{
|
||||
PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
|
||||
localRanks[cnt++] = tmp;
|
||||
}
|
||||
for (auto it = std::next(group.begin()); it != group.end(); ++it)
|
||||
{
|
||||
PGCHECK_THROW(pgh_world.send(ref(localRanks), *it, tag));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
int leader = *group.begin();
|
||||
|
||||
PGCHECK_THROW(pgh_world.send(&myRank, leader, tag));
|
||||
PGCHECK_THROW(pgh_world.recv(ref(ranks), leader, tag));
|
||||
|
||||
PGCHECK_THROW(pgh_world.send(&myLocalRank, leader, tag));
|
||||
PGCHECK_THROW(pgh_world.recv(ref(localRanks), leader, tag));
|
||||
}
|
||||
}
|
||||
|
||||
std::set<int> localGroup;
|
||||
for (size_t i = 0; i < ranks.size(); ++i)
|
||||
{
|
||||
int world_r = ranks[i];
|
||||
if (group.find(world_r) != group.end())
|
||||
localGroup.insert(localRanks[i]);
|
||||
}
|
||||
return localGroup;
|
||||
}
|
||||
|
||||
class AllreduceOp
|
||||
{
|
||||
public:
|
||||
@ -154,8 +245,27 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
AllreduceOp(std::set<int> group, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_,
|
||||
nvinfer1::DataType type, AllReduceStrategyType strategy, AllReduceFusionOp op, float eps)
|
||||
: mGroup(std::move(group))
|
||||
, mType(type)
|
||||
, mStrategy(strategy)
|
||||
, mOp(op)
|
||||
, mEps(eps)
|
||||
, mNcclComm(process_group_)
|
||||
{
|
||||
}
|
||||
|
||||
~AllreduceOp() = default;
|
||||
|
||||
int getRank() const
|
||||
{
|
||||
return std::visit(
|
||||
overloaded{[&](std::shared_ptr<ncclComm_t> const&) { return COMM_SESSION.getRank(); },
|
||||
[&](c10::intrusive_ptr<c10d::ProcessGroup> const& torchPg) { return get_world_pg()->getRank(); }},
|
||||
mNcclComm);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> run(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
|
||||
torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
|
||||
torch::optional<torch::Tensor> const& bias, bool trigger_completion_at_end,
|
||||
@ -169,7 +279,7 @@ public:
|
||||
AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);
|
||||
|
||||
// Log runtime strategy
|
||||
auto const rank = COMM_SESSION.getRank();
|
||||
auto const rank = getRank();
|
||||
logRunTimeStrategy(runtime_strategy, rank);
|
||||
|
||||
// Dispatch to different allreduce implementations
|
||||
@ -192,14 +302,18 @@ public:
|
||||
|
||||
int initialize()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
|
||||
mNcclComm = getComm(mGroup);
|
||||
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, getRank());
|
||||
if (mNcclComm.index() == 0)
|
||||
{
|
||||
mNcclComm = getComm(mGroup);
|
||||
}
|
||||
if (mStrategy != AllReduceStrategyType::NCCL && mStrategy != AllReduceStrategyType::UB)
|
||||
{
|
||||
|
||||
initGroupTopology();
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
|
||||
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, getRank());
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -288,13 +402,25 @@ private:
|
||||
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
|
||||
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
|
||||
{
|
||||
torch::Tensor reduce_output;
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
int size = input.numel();
|
||||
|
||||
torch::Tensor reduce_output = torch::empty_like(input);
|
||||
NCCLCHECK_THROW(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), size, (*getDtypeMap())[mType],
|
||||
ncclSum, *mNcclComm, stream));
|
||||
std::visit(overloaded{[&](std::shared_ptr<ncclComm_t>& rawComm)
|
||||
{
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
int size = input.numel();
|
||||
reduce_output = torch::empty_like(input);
|
||||
NCCLCHECK_THROW(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), size,
|
||||
(*getDtypeMap())[mType], ncclSum, *rawComm, stream));
|
||||
},
|
||||
[&](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg)
|
||||
{
|
||||
reduce_output = input.clone();
|
||||
// TLLM_LOG_INFO("AllReduce Rank: %d, tensor numel: %d", torchPg->getRank(),
|
||||
// reduce_output.numel());
|
||||
std::vector tensors{reduce_output};
|
||||
PGCHECK_THROW(torchPg->allreduce(tensors, {c10d::ReduceOp::SUM}));
|
||||
}},
|
||||
mNcclComm);
|
||||
|
||||
if (mOp == AllReduceFusionOp::NONE)
|
||||
{
|
||||
@ -307,12 +433,13 @@ private:
|
||||
|
||||
std::vector<torch::Tensor> runNCCLAllReduceSymmetric(torch::Tensor const& input,
|
||||
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
|
||||
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
|
||||
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
|
||||
{
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
int size = input.numel();
|
||||
auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
|
||||
auto ub_tensor0 = input;
|
||||
auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
|
||||
if (ub_buffer0.invalid())
|
||||
{
|
||||
@ -321,13 +448,23 @@ private:
|
||||
cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(),
|
||||
cudaMemcpyDeviceToDevice, stream);
|
||||
ub_buffer0 = symmetric_ub_buffer0;
|
||||
ub_tensor0 = symmetric_input;
|
||||
}
|
||||
|
||||
TLLM_CHECK(!ub_buffer0.invalid());
|
||||
auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
|
||||
|
||||
NCCLCHECK(ncclAllReduce(
|
||||
ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
|
||||
std::visit(overloaded{[&, norm_out_ = norm_out](std::shared_ptr<ncclComm_t>& rawComm)
|
||||
{
|
||||
NCCLCHECK_THROW(ncclAllReduce(ub_buffer0.addr, norm_out_.mutable_data_ptr(), size,
|
||||
(*getDtypeMap())[mType], ncclSum, *rawComm, stream));
|
||||
},
|
||||
[&, norm_out_ = norm_out](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg)
|
||||
{
|
||||
PGCHECK_THROW(PgHelper{torchPg}.allreduce(ub_tensor0, {c10d::ReduceOp::SUM}));
|
||||
std::ignore = norm_out_.copy_(ub_tensor0, true);
|
||||
}},
|
||||
mNcclComm);
|
||||
|
||||
if (mOp == AllReduceFusionOp::NONE)
|
||||
{
|
||||
@ -348,7 +485,7 @@ private:
|
||||
int hidden_size = input.size(-1);
|
||||
|
||||
auto const tp_size = mGroup.size();
|
||||
auto const cur_rank = COMM_SESSION.getRank();
|
||||
auto const cur_rank = getRank();
|
||||
int tp_rank = 0;
|
||||
|
||||
for (auto const& currentRank : mGroup)
|
||||
@ -418,7 +555,7 @@ private:
|
||||
int seq_len = input.size(0);
|
||||
|
||||
auto const tp_size = mGroup.size();
|
||||
auto const cur_rank = COMM_SESSION.getRank();
|
||||
auto const cur_rank = getRank();
|
||||
int tp_rank = 0;
|
||||
|
||||
for (auto const& currentRank : mGroup)
|
||||
@ -737,9 +874,13 @@ private:
|
||||
|
||||
void setGroupTopology()
|
||||
{
|
||||
auto const rank = COMM_SESSION.getRank();
|
||||
auto const rank = getRank();
|
||||
TLLM_LOG_INFO("Detecting local TP group for rank %d", rank);
|
||||
std::set<int> local_group = getLocalGroup(mGroup);
|
||||
std::set<int> local_group = std::visit(
|
||||
overloaded{[&](std::shared_ptr<ncclComm_t>&) { return getLocalGroup(mGroup); },
|
||||
[&](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg) { return getLocalGroupTorch(mGroup); }},
|
||||
mNcclComm);
|
||||
|
||||
if (mGroup.size() != local_group.size())
|
||||
{
|
||||
mIsP2PSupported = false;
|
||||
@ -750,18 +891,17 @@ private:
|
||||
TLLM_LOG_INFO("TP group is intra-node for rank %d", rank);
|
||||
|
||||
NvmlManager nvml_manager;
|
||||
std::unordered_set<int> visited_device;
|
||||
mIsP2PSupported = true;
|
||||
mIsNVLINKSupported = true;
|
||||
|
||||
// TODO(ytong): Should we provide group topology info instead of querying it here?
|
||||
// Use cudaDeviceCanAccessPeer to determine whether p2p is supported,
|
||||
// and use nvml to determine whether there are nvlink links between ranks.
|
||||
for (int first_device_id : local_group)
|
||||
{
|
||||
for (int second_device_id : local_group)
|
||||
{
|
||||
if (first_device_id == second_device_id
|
||||
|| visited_device.find(second_device_id) != visited_device.end())
|
||||
if (first_device_id >= second_device_id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
@ -842,7 +982,6 @@ private:
|
||||
|
||||
mIsNVLINKSupported &= is_NVLINK;
|
||||
}
|
||||
visited_device.insert(first_device_id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -999,14 +1138,14 @@ private:
|
||||
AllReduceStrategyType mStrategy;
|
||||
AllReduceFusionOp mOp;
|
||||
float mEps;
|
||||
std::shared_ptr<ncclComm_t> mNcclComm;
|
||||
std::variant<std::shared_ptr<ncclComm_t>, c10::intrusive_ptr<c10d::ProcessGroup>> mNcclComm;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
std::vector<torch::Tensor> allreduce(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
|
||||
std::vector<torch::Tensor> allreduce_raw(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
|
||||
torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
|
||||
torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> workspace,
|
||||
torch::List<int64_t> const& group_, int64_t const strategy_, int64_t const fusion_op_, double const eps_,
|
||||
@ -1030,6 +1169,46 @@ std::vector<torch::Tensor> allreduce(torch::Tensor const& input, torch::optional
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> allreduce_pg(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
|
||||
torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
|
||||
torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> const& workspace,
|
||||
torch::List<int64_t> const& group_, int64_t rank, c10::intrusive_ptr<c10d::ProcessGroup> const& pg,
|
||||
int64_t const strategy_, int64_t const fusion_op_, double const eps_, bool const trigger_completion_at_end_)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
|
||||
auto const strategy = static_cast<AllReduceStrategyType>(int8_t(strategy_));
|
||||
auto const fusion_op = static_cast<AllReduceFusionOp>(int8_t(fusion_op_));
|
||||
float const eps = eps_;
|
||||
std::set<int> group;
|
||||
|
||||
for (int64_t my_rank : group_)
|
||||
{
|
||||
group.insert(static_cast<int>(my_rank));
|
||||
}
|
||||
|
||||
// Get the NCCL rank of this process within process_group_
|
||||
auto it = group.find(rank);
|
||||
if (it == group.end())
|
||||
{
|
||||
throw std::runtime_error("Rank not found in group");
|
||||
}
|
||||
int nccl_rank = std::distance(group.begin(), it);
|
||||
|
||||
if (nccl_rank != pg->getRank())
|
||||
{
|
||||
throw std::runtime_error("nccl_rank != pg->getRank()");
|
||||
}
|
||||
|
||||
AllreduceOp op(group, pg, dtype, strategy, fusion_op, eps);
|
||||
op.initialize();
|
||||
auto ret = op.run(input, residual, norm_weight, scale, bias, trigger_completion_at_end_, workspace);
|
||||
return ret;
|
||||
#else
|
||||
return {input};
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
// residual [m, hidden_dim]
|
||||
// norm_weight [hidden_dim]
|
||||
// device_num_experts [1]
|
||||
@ -1231,6 +1410,21 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
"int op,"
|
||||
"float eps,"
|
||||
"bool trigger_completion_at_end) -> Tensor[]");
|
||||
m.def(
|
||||
"allreduce_pg("
|
||||
"Tensor input,"
|
||||
"Tensor? residual,"
|
||||
"Tensor? norm_weight,"
|
||||
"Tensor? scale,"
|
||||
"Tensor? bias,"
|
||||
"Tensor? workspace,"
|
||||
"int[] group,"
|
||||
"int rank,"
|
||||
"__torch__.torch.classes.c10d.ProcessGroup pg,"
|
||||
"int strategy,"
|
||||
"int op,"
|
||||
"float eps,"
|
||||
"bool trigger_completion_at_end) -> Tensor[]");
|
||||
m.def(
|
||||
"moe_allreduce("
|
||||
"Tensor residual,"
|
||||
@ -1262,7 +1456,8 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("mnnvl_twoshot_allreduce", &torch_ext::mnnvlTwoShotAllReduce);
|
||||
m.impl("mnnvl_twoshot_rmsnorm", &torch_ext::twoShotRMSNorm);
|
||||
m.impl("allreduce", &torch_ext::allreduce);
|
||||
m.impl("allreduce", &torch_ext::allreduce_raw);
|
||||
m.impl("allreduce_pg", &torch_ext::allreduce_pg);
|
||||
m.impl("moe_allreduce", &torch_ext::moe_allreduce);
|
||||
m.impl("moe_finalize_allreduce", &torch_ext::moe_finalize_allreduce);
|
||||
}
|
||||
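`allreduce_pg` mirrors the existing `allreduce` schema but additionally takes the caller's rank and a `ProcessGroup`; the op checks that `rank`'s position inside `group` matches the group's own rank. A hedged sketch follows; the numeric strategy and fusion-op values are assumptions, not taken from this diff.

```python
# Sketch of calling the ProcessGroup-backed fused allreduce (enum values assumed).
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
pg = dist.distributed_c10d._get_default_group()
rank = dist.get_rank()
group = list(range(dist.get_world_size()))

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)

STRATEGY_NCCL = 1  # assumed integer value of AllReduceStrategyType::NCCL
FUSION_NONE = 0    # assumed integer value of AllReduceFusionOp::NONE

out, = torch.ops.trtllm.allreduce_pg(
    x,
    None, None, None, None, None,  # residual, norm_weight, scale, bias, workspace
    group, rank, pg,
    STRATEGY_NCCL, FUSION_NONE, 1e-6,
    False,                         # trigger_completion_at_end
)
```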
|
||||
@ -18,18 +18,22 @@
|
||||
#include "tensorrt_llm/common/opUtils.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/pgUtils.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
#include <nccl.h>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
#include <cassert>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
using tensorrt_llm::pg_utils::PgHelper;
|
||||
|
||||
namespace torch_ext
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
@ -49,9 +53,9 @@ public:
|
||||
|
||||
int initialize()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
|
||||
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, -1);
|
||||
mNcclComm = getComm(mGroup);
|
||||
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
|
||||
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, -1);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -125,6 +129,86 @@ private:
|
||||
std::shared_ptr<ncclComm_t> mNcclComm;
|
||||
};
|
||||
|
||||
class ReducescatterPgOp
|
||||
{
|
||||
public:
|
||||
ReducescatterPgOp(std::set<int> group, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
|
||||
: mGroup(std::move(group))
|
||||
, mProcessGroup(process_group_)
|
||||
{
|
||||
}
|
||||
|
||||
~ReducescatterPgOp() = default;
|
||||
|
||||
int initialize() noexcept
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, mProcessGroup->getRank());
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::pair<torch::Tensor, c10::intrusive_ptr<c10d::Work>> run(
|
||||
torch::Tensor input, torch::optional<torch::List<int64_t>> sizes, bool coalescing = false)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mProcessGroup.get() != nullptr, "mProcessGroup should be initialized before use");
|
||||
auto rank = mProcessGroup->getRank();
|
||||
std::vector<int64_t> outputShape = input.sizes().vec();
|
||||
if (sizes.has_value())
|
||||
{
|
||||
TLLM_CHECK(sizes.value().size() == mGroup.size());
|
||||
outputShape[0] = sizes.value()[rank];
|
||||
}
|
||||
else
|
||||
{
|
||||
outputShape[0] = outputShape[0] / mGroup.size();
|
||||
}
|
||||
auto output = torch::empty(outputShape, input.options());
|
||||
|
||||
int64_t split_offset = 0;
|
||||
std::vector<torch::Tensor> inputTensors{};
|
||||
for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
|
||||
{
|
||||
auto split_size = sizes.has_value() ? sizes.value()[root] : outputShape[0];
|
||||
inputTensors.push_back(input.index({torch::indexing::Slice(split_offset, split_offset + split_size)}));
|
||||
split_offset += split_size;
|
||||
}
|
||||
std::vector<torch::Tensor> outputs{output};
|
||||
std::vector<std::vector<torch::Tensor>> inputs{inputTensors};
|
||||
auto work = mProcessGroup->reduce_scatter(outputs, inputs, {});
|
||||
|
||||
if (!coalescing)
|
||||
{
|
||||
PGCHECK_THROW_WITH_INFO(work, "ProcessGroup: reduce_scatter");
|
||||
return {output, nullptr};
|
||||
}
|
||||
|
||||
return {output, work};
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> run_list(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes)
|
||||
{
|
||||
std::vector<torch::Tensor> output_list;
|
||||
std::vector<c10::intrusive_ptr<c10d::Work>> work_list;
|
||||
output_list.reserve(input_list.size());
|
||||
work_list.reserve(input_list.size());
|
||||
mProcessGroup->startCoalescing(c10::DeviceType::CUDA);
|
||||
for (auto const& input : input_list)
|
||||
{
|
||||
auto [output, work] = run(input, sizes, true);
|
||||
output_list.push_back(output);
|
||||
work_list.push_back(work); // Hold work objects (input & output tensors) until endCoalescing wait finished
|
||||
}
|
||||
if (auto work = mProcessGroup->endCoalescing(c10::DeviceType::CUDA))
|
||||
{
|
||||
PGCHECK_THROW_WITH_INFO(work, "ProcessGroup: reduce_scatter, end coalescing");
|
||||
}
|
||||
return output_list;
|
||||
}
|
||||
|
||||
private:
|
||||
std::set<int> mGroup;
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> mProcessGroup;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
@ -147,6 +231,24 @@ extern torch::Tensor reducescatter(
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
extern torch::Tensor reducescatter_pg(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes,
|
||||
torch::List<int64_t> group_, c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
std::set<int> group;
|
||||
for (int64_t rank : group_)
|
||||
{
|
||||
group.insert(static_cast<int>(rank));
|
||||
}
|
||||
ReducescatterPgOp op(group, process_group_);
|
||||
op.initialize();
|
||||
auto [output, _] = op.run(input, sizes);
|
||||
return output;
|
||||
#else
|
||||
return input;
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
extern std::vector<torch::Tensor> reducescatter_list(
|
||||
torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes, torch::List<int64_t> group_)
|
||||
{
|
||||
@ -165,16 +267,42 @@ extern std::vector<torch::Tensor> reducescatter_list(
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
extern std::vector<torch::Tensor> reducescatter_list_pg(torch::TensorList input_list,
|
||||
torch::optional<torch::List<int64_t>> sizes, torch::List<int64_t> group_,
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> const& process_group_)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
std::set<int> group;
|
||||
for (int64_t rank : group_)
|
||||
{
|
||||
group.insert(static_cast<int>(rank));
|
||||
}
|
||||
ReducescatterPgOp op(group, process_group_);
|
||||
op.initialize();
|
||||
auto output_list = op.run_list(input_list, sizes);
|
||||
return output_list;
|
||||
#else
|
||||
return input_list.vec();
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
} // namespace torch_ext
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def("reducescatter(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
|
||||
m.def(
|
||||
"reducescatter_pg(Tensor input, SymInt[]? sizes, int[] group, __torch__.torch.classes.c10d.ProcessGroup "
|
||||
"process_group) -> Tensor");
|
||||
m.def("reducescatter_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
|
||||
m.def(
|
||||
"reducescatter_list_pg(Tensor[] input_list, SymInt[]? sizes, int[] group, "
|
||||
"__torch__.torch.classes.c10d.ProcessGroup process_group) -> Tensor[]");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("reducescatter", &torch_ext::reducescatter);
|
||||
m.impl("reducescatter_pg", &torch_ext::reducescatter_pg);
|
||||
m.impl("reducescatter_list", &torch_ext::reducescatter_list);
|
||||
m.impl("reducescatter_list_pg", &torch_ext::reducescatter_list_pg);
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@ add_gtest(loraConfigTest loraConfigTest.cpp)
|
||||
add_gtest(intervalSetTest intervalSetTest.cpp)
|
||||
add_gtest(dynamicBatchTunerTest dynamicBatchTunerTest.cpp)
|
||||
add_gtest(ucxCommTest ucxCommTest.cpp)
|
||||
target_link_libraries(ucxCommTest PRIVATE ${Python3_LIBRARIES})
|
||||
|
||||
if(NIXL_ROOT)
|
||||
add_gtest(transferAgentTest transferAgentTest.cpp)
|
||||
|
||||
@ -12,5 +12,7 @@
|
||||
add_subdirectory(kernels)
|
||||
|
||||
add_gtest(cacheTransceiverTest cacheTransceiverTest.cpp)
|
||||
target_link_libraries(cacheTransceiverTest PRIVATE ${Python3_LIBRARIES})
|
||||
|
||||
add_gtest(mpiUtilsTest mpiUtilsTest.cpp)
|
||||
add_gtest(userBufferTest userBufferTest.cpp)
|
||||
|
||||
BIN
docs/source/media/ray_orchestrator_architecture.jpg
Normal file
Binary file not shown.
51
examples/ray_orchestrator/README.md
Normal file
@ -0,0 +1,51 @@
<div align="center">
|
||||
|
||||
# TensorRT-LLM with Ray orchestrator
|
||||
|
||||
</div>
|
||||
|
||||
<div align="left">
|
||||
|
||||
This folder contains examples for a prototype **Ray orchestrator** that supports on-demand LLM instance spin-up and flexible GPU placement across single- and multi-node inference. It’s a first step toward making TensorRT-LLM a better fit for Reinforcement learning from human feedback (RLHF) workflows. For RLHF, [Ray](https://docs.ray.io/en/latest/index.html) — unlike MPI’s fixed world size and placement — can dynamically spawn and reconnect distributed inference actors, each with its own parallelism strategy.
|
||||
|
||||
This feature is a prototype and under active development. MPI remains the default.
|
||||
|
||||
|
||||
## Quick Start
|
||||
To use Ray orchestrator, you need to first install Ray.
|
||||
```shell
|
||||
cd examples/ray_orchestrator
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Run a simple `TP=2` example with a Hugging Face model:
|
||||
|
||||
```shell
|
||||
python llm_inference_distributed_ray.py
|
||||
```
|
||||
|
||||
This example is the same as in `/examples/llm-api`, with the only change being `orchestrator_type="ray"` on `LLM()`. Other examples can be adapted similarly by toggling this flag.
|
||||
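For reference, a minimal sketch of that change (the model name and sampling values below are simply the ones used by the bundled examples):

```python
from tensorrt_llm import LLM, SamplingParams

# The only Ray-specific change is orchestrator_type="ray"; everything else
# follows the regular LLM API examples.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    orchestrator_type="ray",  # use the Ray orchestrator instead of MPI
    tensor_parallel_size=2,  # 2-way tensor parallelism, matching the TP=2 example
)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
for output in llm.generate(["Hello, my name is"], sampling_params):
    print(output.outputs[0].text)
```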


## Features
### Available
- Generate text asynchronously (refer to [llm_inference_async_ray.py](llm_inference_async_ray.py))
- Multi-node inference (refer to [multi-node README](./multi_nodes/README.md))
- Disaggregated serving (refer to [disagg README](./disaggregated/README.md))

**Initial testing has been focused on LLaMA and DeepSeek variants. Please open an Issue if you encounter problems with other models so we can prioritize support.**

### Upcoming
- Performance optimization
- Integration with RLHF frameworks, such as [NVIDIA Nemo-RL](https://github.com/NVIDIA-NeMo/RL) and [Verl](https://github.com/volcengine/verl).

## Architecture
This feature introduces new classes such as [RayExecutor](/tensorrt_llm/executor/ray_executor.py) and [RayGPUWorker](/tensorrt_llm/executor/ray_gpu_worker.py) for Ray actor lifecycle management and distributed inference. In Ray mode, collective ops run on [torch.distributed](https://docs.pytorch.org/tutorials/beginner/dist_overview.html) without MPI. We welcome contributions to improve and extend this support.

![ray_orchestrator_architecture](/docs/source/media/ray_orchestrator_architecture.jpg)
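To make the "collectives on torch.distributed" idea concrete, here is a minimal, self-contained sketch of the general pattern (Ray actors joining a `torch.distributed` process group). The actor class, the `gloo` backend, and the hard-coded address/port are assumptions for illustration only, not TensorRT-LLM APIs:

```python
import os

import ray
import torch.distributed as dist


@ray.remote
class ToyWorker:
    """Illustrative Ray actor that joins a torch.distributed group (no MPI)."""

    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: int):
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)
        # gloo keeps the sketch CPU-only; the real workers build NCCL groups on GPUs.
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

    def allgather_rank(self):
        # A torch.distributed collective running inside a Ray actor.
        out = [None] * dist.get_world_size()
        dist.all_gather_object(out, dist.get_rank())
        return out


if __name__ == "__main__":
    ray.init()
    world_size = 2
    workers = [
        ToyWorker.remote(r, world_size, "127.0.0.1", 29500)
        for r in range(world_size)
    ]
    print(ray.get([w.allgather_rank.remote() for w in workers]))  # [[0, 1], [0, 1]]
```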


## Disclaimer
The code is a prototype and subject to change. Currently, there are no guarantees regarding functionality, performance, or stability.

</div>
28
examples/ray_orchestrator/disaggregated/README.md
Normal file
@ -0,0 +1,28 @@
# Disaggregated Serving with Ray orchestrator
TensorRT-LLM supports a prototype [Ray orchestrator](../README.md) as an alternative to MPI.

Running disaggregated serving with Ray follows [the same workflow as in MPI](/examples/disaggregated/README.md), except that `orchestrator_type="ray"` must be set on the `LLM` class, and `CUDA_VISIBLE_DEVICES` can be omitted since Ray handles GPU placement.
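When the servers are launched through `trtllm-serve`, this switch is passed via the extra LLM API options file. For reference, the quick-start script below generates an `extra_llm_config.yaml` along these lines (the values shown are the script's defaults, not requirements):

```yaml
cache_transceiver_config:
  backend: "UCX"
  max_tokens_in_buffer: 2048
disable_overlap_scheduler: true
# Ray executor configuration
orchestrator_type: "ray"
```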

## Quick Start
This script is a shorthand to launch a single-GPU context server, a single-GPU generation server, and the disaggregated server within a single Ray cluster. Please see [this documentation](/examples/disaggregated/README.md) for details on adjusting parallel settings.

```bash
# requires a total of two GPUs
bash -e disagg_serving_local.sh
```

Once the disaggregated server is ready, you can send requests to it using curl:
```bash
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "prompt": "NVIDIA is a great company because",
        "max_tokens": 16,
        "temperature": 0
    }' -w "\n"
```

## Disclaimer
The code is a prototype and subject to change. Currently, there are no guarantees regarding functionality, performance, or stability.
146
examples/ray_orchestrator/disaggregated/disagg_serving_local.sh
Normal file
@ -0,0 +1,146 @@
#!/bin/bash
|
||||
|
||||
# Parse command line arguments
|
||||
BACKEND="ray"
|
||||
ATTACH_MODE=false
|
||||
TP_SIZE=1
|
||||
USAGE="Usage: $0 [--executor ray|mpi] [--attach] [--tp_size N] [--help]"
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--executor)
|
||||
BACKEND="$2"
|
||||
shift 2
|
||||
;;
|
||||
--attach)
|
||||
ATTACH_MODE=true
|
||||
shift
|
||||
;;
|
||||
--tp_size)
|
||||
TP_SIZE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--help|-h)
|
||||
echo "$USAGE"
|
||||
echo "Options:"
|
||||
echo " --executor ray|mpi Choose distributed executor (default: ray)"
|
||||
echo " --attach Attach to existing ray cluster (skip ray start/stop)"
|
||||
echo " --tp_size N Tensor parallel size (default: 1)"
|
||||
echo " --help, -h Show this help message"
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1"
|
||||
echo "$USAGE"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ "$BACKEND" != "ray" && "$BACKEND" != "mpi" ]]; then
|
||||
echo "Error: Executor must be either 'ray' or 'mpi'"
|
||||
echo "$USAGE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Executor: $BACKEND"
|
||||
echo "Tensor parallel size: $TP_SIZE"
|
||||
if [[ "$ATTACH_MODE" == "true" ]]; then
|
||||
echo "Attach mode enabled - will not manage ray cluster"
|
||||
fi
|
||||
|
||||
# Generate extra_llm_config.yaml based on executor type
|
||||
echo "Generating extra_llm_config.yaml for executor: $BACKEND"
|
||||
if [[ "$BACKEND" == "ray" ]]; then
|
||||
cat > extra_llm_config.yaml << EOF
|
||||
# extra_llm_config.yaml when launching disaggregated server instances.
|
||||
cache_transceiver_config:
|
||||
backend: "UCX"
|
||||
max_tokens_in_buffer: 2048
|
||||
disable_overlap_scheduler: true
|
||||
# Ray executor configuration
|
||||
orchestrator_type: "ray"
|
||||
EOF
|
||||
else
|
||||
cat > extra_llm_config.yaml << EOF
|
||||
# extra_llm_config.yaml when launching disaggregated server instances.
|
||||
cache_transceiver_config:
|
||||
backend: "UCX"
|
||||
max_tokens_in_buffer: 2048
|
||||
disable_overlap_scheduler: true
|
||||
# Using default executor MPI (no orchestrator_type specified)
|
||||
EOF
|
||||
fi
|
||||
|
||||
# Generate disaggregated server config
|
||||
echo "Generating disagg_config_local.yaml"
|
||||
cat > disagg_config_local.yaml << EOF
|
||||
hostname: localhost
|
||||
port: 8000
|
||||
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
free_gpu_memory_fraction: 0.25
|
||||
backend: "pytorch"
|
||||
disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: $TP_SIZE
|
||||
pipeline_parallel_size: 1
|
||||
kv_cache_config:
|
||||
free_gpu_memory_fraction: 0.2
|
||||
cache_transceiver_config:
|
||||
backend: "UCX"
|
||||
urls:
|
||||
- "localhost:8001"
|
||||
generation_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: $TP_SIZE
|
||||
pipeline_parallel_size: 1
|
||||
cache_transceiver_config:
|
||||
backend: "UCX"
|
||||
urls:
|
||||
- "localhost:8002"
|
||||
EOF
|
||||
|
||||
# Conditionally start ray head if using ray backend and not in attach mode
|
||||
RAY_STARTED=false
|
||||
if [[ "$BACKEND" == "ray" && "$ATTACH_MODE" != "true" ]]; then
|
||||
echo "Checking if ray cluster is already running..."
|
||||
if ray status > /dev/null 2>&1; then
|
||||
echo "Ray cluster is already running. Stopping existing cluster first..."
|
||||
ray stop
|
||||
sleep 2
|
||||
fi
|
||||
echo "Launching ray head..."
|
||||
ray start --head --disable-usage-stats
|
||||
RAY_STARTED=true
|
||||
elif [[ "$BACKEND" == "ray" && "$ATTACH_MODE" == "true" ]]; then
|
||||
echo "Attach mode: Skipping ray cluster management"
|
||||
fi
|
||||
|
||||
# Launching context servers
|
||||
echo "Launching context servers..."
|
||||
if [[ "$BACKEND" == "mpi" ]]; then
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
fi
|
||||
|
||||
trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --tp_size $TP_SIZE --port 8001 --kv_cache_free_gpu_memory_fraction 0.15 --backend pytorch --extra_llm_api_options extra_llm_config.yaml &> output_ctx0 &
|
||||
|
||||
if [[ "$BACKEND" == "mpi" ]]; then
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
fi
|
||||
# Launching generation servers
|
||||
echo "Launching generation servers..."
|
||||
trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --tp_size $TP_SIZE --port 8002 --kv_cache_free_gpu_memory_fraction 0.15 --backend pytorch --extra_llm_api_options extra_llm_config.yaml &> output_gen0 &
|
||||
|
||||
# Launching disaggregated server
|
||||
echo "Launching disaggregated server..."
|
||||
trtllm-serve disaggregated -c disagg_config_local.yaml
|
||||
|
||||
# Cleanup
|
||||
if [[ "$RAY_STARTED" == "true" && "$ATTACH_MODE" != "true" ]]; then
|
||||
echo "Stopping ray..."
|
||||
ray stop
|
||||
fi
|
||||
|
||||
echo "Cleaning up generated extra_llm_config.yaml..."
|
||||
rm -f extra_llm_config.yaml
|
||||
55
examples/ray_orchestrator/llm_inference_async_ray.py
Normal file
@ -0,0 +1,55 @@
# Generate text asynchronously with Ray orchestrator.
import asyncio

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig


def main():
    # Configure KV cache memory usage fraction.
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
                                    max_tokens=4096,
                                    enable_block_reuse=True)

    # model could accept HF model name or a path to local HF model.
    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        kv_cache_config=kv_cache_config,
        max_seq_len=1024,
        max_batch_size=1,
        orchestrator_type="ray",  # Enable Ray orchestrator
        # Enable 2-way tensor parallelism
        # tensor_parallel_size=2
    )

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create sampling params.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Async generation based on Python coroutines
    async def task(prompt: str):
        output = await llm.generate_async(prompt, sampling_params)
        print(
            f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
        )

    async def main():
        tasks = [task(prompt) for prompt in prompts]
        await asyncio.gather(*tasks)

    asyncio.run(main())

    # Example output:
    # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
    # Prompt: 'The capital of France is', Generated text: 'Paris.'
    # Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'


if __name__ == '__main__':
    main()
56
examples/ray_orchestrator/llm_inference_distributed_ray.py
Normal file
@ -0,0 +1,56 @@
# Generate text with Ray orchestrator.
import argparse

from tensorrt_llm import LLM, SamplingParams


def main():
    # model could accept HF model name or a path to local HF model.
    llm = LLM(
        model=args.model_dir,
        orchestrator_type="ray",  # Enable Ray orchestrator
        # Enable 2-way tensor parallelism
        tensor_parallel_size=args.tp_size,
        # Enable 2-way pipeline parallelism if needed
        pipeline_parallel_size=args.pp_size,
        # Enable 2-way expert parallelism for MoE model's expert weights
        moe_expert_parallel_size=args.moe_ep_size)

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create sampling params.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    for output in llm.generate(prompts, sampling_params):
        print(
            f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
        )

    # Example output:
    # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
    # Prompt: 'The capital of France is', Generated text: 'Paris.'
    # Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'


def parse_arguments():
    parser = argparse.ArgumentParser(
        description='LLM Inference with Ray orchestrator')
    parser.add_argument('--model_dir',
                        type=str,
                        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                        help="Model checkpoint directory")
    parser.add_argument('--tp_size', type=int, default=2)
    parser.add_argument('--pp_size', type=int, default=1)
    parser.add_argument('--moe_ep_size', type=int, default=-1)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    main()
43
examples/ray_orchestrator/multi_nodes/README.md
Normal file
@ -0,0 +1,43 @@
# Multi-node inference with Ray orchestrator
TensorRT-LLM supports a prototype [Ray orchestrator](../README.md) as an alternative to MPI. The following example shows how to start a Ray cluster for multi-node inference.


## Quick Start

**Prerequisite:** a container image with TensorRT-LLM preinstalled (or suitable for installing it). The examples use Slurm and [Enroot](https://github.com/NVIDIA/enroot). If you use a different setup, adapt the following scripts and commands to your multi-node environment.

1. Allocate nodes and open a shell on the head node:

```shell
# e.g., 2 nodes with 8 GPUs per node
>> salloc -t 240 -N 2 -p interactive

>> srun --pty -p interactive bash
```
2. Once on the head node, launch a multi-node Ray cluster (an optional driver-command shortcut is shown after this step):
```shell
# Remember to set the CONTAINER and MOUNTS environment variables (or the variables inside the script) to your paths.
# You can add the TensorRT-LLM installation command in this script if it is not preinstalled in your container.
>> bash -e run_cluster.sh
```
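
Optionally, `run_cluster.sh` also accepts a driver command after a `--` separator and runs it on the head node once all nodes have joined (see the `COMMAND` handling in the script); whether this shortcut works out of the box may depend on your environment:
```shell
# Launch the cluster and, once it is up, run the async example as the driver.
>> bash -e run_cluster.sh -- python examples/ray_orchestrator/llm_inference_async_ray.py
```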

3. Enter the head container and run your TensorRT-LLM driver script.

Note that this step requires TensorRT-LLM to be installed in the containers on all nodes. If it isn’t, install it manually inside each node’s container.

```shell
# On the head node
>> sacct

# Grab the Slurm step ID with Job Name "ray-head"
>> srun --jobid=<Your Step ID> --overlap --pty bash

>> enroot list -f # get process id
>> enroot exec <process id> bash

# Adjust this script to a model and parallel settings suitable for multi-node inference (e.g., TP8 or TP4PP4).
>> python examples/ray_orchestrator/llm_inference_async_ray.py
```

## Disclaimer
The code is a prototype and subject to change. Currently, there are no guarantees regarding functionality, performance, or stability.
228
examples/ray_orchestrator/multi_nodes/run_cluster.sh
Normal file
@ -0,0 +1,228 @@
#!/usr/bin/env bash
|
||||
|
||||
# NOTE: Multi-node with Ray orchestrator in TensorRT-LLM is an experimental feature and may not work on all systems.
|
||||
# This script launches a Ray cluster and connects all allocated nodes for multi-node inference.
|
||||
|
||||
# The following variables are expected to be set in the environment:
|
||||
# CONTAINER: the path of the container image that has the desired TensorRT-LLM version installed.
|
||||
# MOUNTS: directory mount specification in format src:dest
|
||||
|
||||
# Run inside an already‑active allocated node.
|
||||
# To start a Ray cluster across nodes:
|
||||
# >> bash -e run_cluster.sh
|
||||
#
|
||||
# See multi_nodes/README.md for more details.
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
: "${CONTAINER:?Set CONTAINER to the container image path (e.g. /path/to/trtllm.sqfs)}"
|
||||
: "${MOUNTS:?Set MOUNTS to mount spec src:dst[,src2:dst2...] (no spaces)}"
|
||||
RAY_PORT=${RAY_PORT:-6379}
|
||||
|
||||
RUN_ID=$(date +%m%d-%H%M-%S)
|
||||
HEAD_NAME="ray-head-${RUN_ID}"
|
||||
|
||||
SLURM_SUBMIT_DIR=$PWD
|
||||
BASE_LOG_DIR=${BASE_LOG_DIR:-${SLURM_SUBMIT_DIR:-$(pwd)}}
|
||||
LOG_DIR="$BASE_LOG_DIR/${SLURM_JOB_ID}-logs"
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
COMMAND=""
|
||||
if [[ "$#" -gt 0 ]]; then
|
||||
for arg; do [[ "$arg" == "--" ]] && shift && break; done
|
||||
COMMAND="$*"
|
||||
fi
|
||||
|
||||
MIN_WORKER_PORT=${MIN_WORKER_PORT:-54001}
|
||||
MAX_WORKER_PORT=${MAX_WORKER_PORT:-54257}
|
||||
|
||||
COMMON_SRUN_ARGS+=" --mpi=pmix"
|
||||
COMMON_SRUN_ARGS+=" --container-remap-root --container-writable"
|
||||
COMMON_SRUN_ARGS+=" --container-mounts=$MOUNTS"
|
||||
COMMON_SRUN_ARGS+=" --container-image=$CONTAINER"
|
||||
COMMON_SRUN_ARGS+=" --container-workdir=$SLURM_SUBMIT_DIR"
|
||||
COMMON_SRUN_ARGS+=" -p $SLURM_JOB_PARTITION"
|
||||
COMMON_SRUN_ARGS+=" -A $SLURM_JOB_ACCOUNT"
|
||||
COMMON_SRUN_ARGS+=" --gres=gpu:$SLURM_GPUS_ON_NODE"
|
||||
|
||||
# Getting the node names and IP addresses in the SLURM allocation
|
||||
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
|
||||
nodes_array=($nodes)
|
||||
ip_addresses_array=()
|
||||
|
||||
for node in $nodes; do
|
||||
ip_address=$(host $node | awk '/has address/ { print $4 }')
|
||||
ip_addresses_array+=("$ip_address")
|
||||
done
|
||||
|
||||
head_node=${nodes_array[0]}
|
||||
head_node_ip=${ip_addresses_array[0]}
|
||||
WORKERS=("${nodes_array[@]:1}")
|
||||
|
||||
ip_head=$head_node_ip:$RAY_PORT
|
||||
|
||||
BLUE='\e[96m'
|
||||
GREEN='\e[32m'
|
||||
RESET='\e[0m'
|
||||
|
||||
echo -e "${BLUE}[INFO] Head node : $head_node${RESET}"
|
||||
echo -e "${BLUE}[INFO] Worker(s) : ${WORKERS[*]}${RESET}"
|
||||
echo -e "${BLUE}[INFO] GPUs per node : ${SLURM_GPUS_ON_NODE}${RESET}"
|
||||
echo -e "${BLUE}[INFO] CONTAINER: '$CONTAINER'${RESET}"
|
||||
echo -e "${BLUE}[INFO] COMMAND = '$COMMAND'${RESET}"
|
||||
echo -e "${BLUE}[INFO] Logs : $LOG_DIR${RESET}"
|
||||
|
||||
########################################################
|
||||
# Start Ray cluster on head node
|
||||
########################################################
|
||||
|
||||
# enabled dashboard only for debug
|
||||
# Add apt-get install -y --no-install-recommends libzmq3-dev for multi-node disagg
|
||||
|
||||
head_cmd=$(cat <<EOF
|
||||
# WAR: clean all slurm / MPI / PMIx env to avoid pmix mismatch error
|
||||
for v in \$(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print \$1}'); do
|
||||
unset "\$v"
|
||||
done
|
||||
# ---- mark that the container shell is alive ----
|
||||
touch "$LOG_DIR/STARTED_RAY_HEAD"
|
||||
export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
|
||||
export RAY_DEDUP_LOGS=0
|
||||
export TRTLLM_UCX_INTERFACE=eth0
|
||||
|
||||
ray start --head \
|
||||
--port=$RAY_PORT \
|
||||
--node-ip-address="$head_node_ip" \
|
||||
--disable-usage-stats \
|
||||
--include-dashboard=true \
|
||||
--min-worker-port=${MIN_WORKER_PORT} \
|
||||
--max-worker-port=${MAX_WORKER_PORT} \
|
||||
--num-cpus=16 \
|
||||
--block
|
||||
EOF
|
||||
)
|
||||
|
||||
echo "head_cmd: $head_cmd"
|
||||
srun --overlap $COMMON_SRUN_ARGS --job-name=ray-head --container-name="${HEAD_NAME}" --nodes=1 --ntasks=1 -w "$head_node" \
|
||||
bash -c "$head_cmd" 2>&1 | tee -a "$LOG_DIR/ray-head.log" &
|
||||
|
||||
|
||||
########################################################
|
||||
# Wait til Ray cluster ready on head
|
||||
########################################################
|
||||
|
||||
sleep 30
|
||||
|
||||
echo "[INFO] waiting for head container..."
|
||||
while ! srun --overlap -N1 -n1 -w "$head_node" bash -c "test -f '$LOG_DIR/STARTED_RAY_HEAD'" >/dev/null 2>&1; do
|
||||
echo "[INFO][$(date)] Waiting for head node container to start..."
|
||||
sleep 2
|
||||
done
|
||||
|
||||
ATTEMPTS=15
|
||||
until srun --overlap --nodes=1 --ntasks=1 --cpu-bind=none -w "$head_node" \
|
||||
--container-name="${HEAD_NAME}" bash -lc 'ray status >/dev/null 2>&1'
|
||||
do
|
||||
((ATTEMPTS--)) || { echo "[ERROR] Ray head did not come up."; exit 1; }
|
||||
sleep 4
|
||||
done
|
||||
echo -e "${GREEN}[INFO] Ray head is UP! ✔${RESET}"
|
||||
|
||||
|
||||
########################################################
|
||||
# Start Ray worker nodes
|
||||
########################################################
|
||||
|
||||
NUM_WORKERS=${#WORKERS[@]}
|
||||
for idx in "${!WORKERS[@]}"; do
|
||||
W=${WORKERS[$idx]}
|
||||
|
||||
worker_cmd=$(
|
||||
cat <<EOF
|
||||
# WAR: clean all slurm / MPI / PMIx env to avoid pmix mismatch error
|
||||
for v in \$(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print \$1}'); do
|
||||
unset "\$v"
|
||||
done
|
||||
|
||||
export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
|
||||
export RAY_DEDUP_LOGS=0
|
||||
export TRTLLM_UCX_INTERFACE=eth0
|
||||
|
||||
ray start --address="$ip_head" \
|
||||
--disable-usage-stats \
|
||||
--min-worker-port=${MIN_WORKER_PORT} \
|
||||
--max-worker-port=${MAX_WORKER_PORT} \
|
||||
--num-cpus=16 \
|
||||
--block
|
||||
EOF
|
||||
)
|
||||
|
||||
echo "worker_cmd (node=$W, idx=$idx): $worker_cmd"
|
||||
srun --overlap $COMMON_SRUN_ARGS --exact --container-name="ray-worker-${idx}" --nodes=1 --ntasks=1 --cpu-bind=none \
|
||||
-w "$W" bash -lc "$worker_cmd" 2>&1 | tee -a "$LOG_DIR/ray-worker-${idx}.log" &
|
||||
done
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Wait until every node is connected in the Ray cluster
|
||||
##############################################################################
|
||||
|
||||
expected_nodes=$(( ${#WORKERS[@]} + 1 )) # head + workers
|
||||
echo "[INFO] Waiting for ${#WORKERS[@]} worker nodes to join..."
|
||||
|
||||
ATTEMPTS=30
|
||||
while (( ATTEMPTS-- )); do
|
||||
# run ray status inside the head container
|
||||
status=$(srun --overlap --nodes=1 --ntasks=1 --cpu-bind=none \
|
||||
-w "$head_node" --container-name="${HEAD_NAME}" ray status 2>/dev/null)
|
||||
|
||||
# to count active nodes
|
||||
active_count=$(
|
||||
sed -n '/^Active:/,/^Pending:/p' <<<"$status" |
|
||||
grep -c 'node_'
|
||||
)
|
||||
|
||||
echo "[INFO] Detected $active_count worker node(s) out of $expected_nodes total expected node(s)."
|
||||
|
||||
[[ $active_count -eq $expected_nodes ]] && break
|
||||
sleep 4
|
||||
done
|
||||
|
||||
[[ $active_count -eq $expected_nodes ]] || { echo '[ERROR] Ray nodes timed-out'; exit 1; }
|
||||
|
||||
echo -e "${GREEN}[INFO] All nodes have joined Ray cluster. ✔ \`ray status\`:${RESET}"
|
||||
printf '%s\n' "$status"
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Run TRT-LLM driver (if given)
|
||||
##############################################################################
|
||||
|
||||
if [[ -n "$COMMAND" ]]; then
|
||||
driver_srun_cmd=(
|
||||
srun --overlap
|
||||
--no-container-mount-home
|
||||
--container-name="${HEAD_NAME}"
|
||||
--nodes=1 --ntasks=1 -w "$head_node"
|
||||
bash -lc "$COMMAND"
|
||||
)
|
||||
# --container-workdir="$CONTAINER_CWD"
|
||||
|
||||
echo -e "${BLUE}[INFO] Driver srun command:${RESET}"
|
||||
printf ' %q ' "${driver_srun_cmd[@]}"
|
||||
echo
|
||||
|
||||
set +e
|
||||
"${driver_srun_cmd[@]}" 2>&1 | tee -a "$LOG_DIR/trtllm-command.log"
|
||||
DRIVER_RC=$?
|
||||
set -e
|
||||
|
||||
if [[ $DRIVER_RC -ne 0 ]]; then
|
||||
echo "[WARN] driver exited with status $DRIVER_RC – Ray cluster left running."
|
||||
fi
|
||||
else
|
||||
echo "[INFO] No COMMAND supplied. Idling. Press Ctrl+C to exit."
|
||||
fi
|
||||
|
||||
wait
|
||||
3
examples/ray_orchestrator/requirements.txt
Normal file
@ -0,0 +1,3 @@
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
ray[default]
|
||||
@ -1279,6 +1279,15 @@ def getMakoArgsFromStageName(stageName, parseSysinfo=false) {
|
||||
} else {
|
||||
makoArgs += ["auto_trigger=others"]
|
||||
}
|
||||
if (stageName.contains("-Ray-")) {
|
||||
// If stageName contains "-Ray-", add "orchestrator=ray" to makoArgs
|
||||
// At this point, only tests with orchestrator=ray or unspecified orchestrator will be run.
|
||||
// Mark tests with orchestrator=mpi to exclude them from Ray stage.
|
||||
makoArgs += ["orchestrator=ray"]
|
||||
} else {
|
||||
// Otherwise select tests with orchestrator=mpi or unspecified orchestrator
|
||||
makoArgs += ["orchestrator=mpi"]
|
||||
}
|
||||
|
||||
if (parseSysinfo) {
|
||||
def taskConfig = parseMultiNodeTaskConfigFromStageName(stageName)
|
||||
@ -1708,6 +1717,11 @@ def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CO
|
||||
"--perf-log-formats yaml"
|
||||
]
|
||||
}
|
||||
if (stageName.contains("-Ray-")) {
|
||||
testCmdLine += ["--run-ray"]
|
||||
|
||||
trtllm_utils.llmExecStepWithRetry(pipeline, script: "pip3 install ray[default]")
|
||||
}
|
||||
// Test Coverage
|
||||
def TRTLLM_WHL_PATH = sh(returnStdout: true, script: "pip3 show tensorrt_llm | grep Location | cut -d ' ' -f 2").replaceAll("\\s","")
|
||||
sh "echo ${TRTLLM_WHL_PATH}"
|
||||
@ -2049,6 +2063,7 @@ def launchTestJobs(pipeline, testFilter)
|
||||
"DGX_H100-2_GPUs-PyTorch-Others-1": ["dgx-h100-x2", "l0_dgx_h100", 1, 1, 2],
|
||||
"DGX_H100-4_GPUs-PyTorch-GptOss-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
|
||||
"DGX_H100-4_GPUs-PyTorch-Others-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
|
||||
"DGX_H100-2_GPUs-PyTorch-Ray-1": ["dgx-h100-x2", "l0_dgx_h100", 1, 1, 2],
|
||||
"DGX_H100-4_GPUs-CPP-1": ["dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
|
||||
"A10-PyTorch-1": ["a10", "l0_a10", 1, 1],
|
||||
"A10-CPP-1": ["a10", "l0_a10", 1, 1],
|
||||
@ -2071,6 +2086,7 @@ def launchTestJobs(pipeline, testFilter)
|
||||
"H100_PCIe-PyTorch-1": ["h100-cr", "l0_h100", 1, 3],
|
||||
"H100_PCIe-PyTorch-2": ["h100-cr", "l0_h100", 2, 3],
|
||||
"H100_PCIe-PyTorch-3": ["h100-cr", "l0_h100", 3, 3],
|
||||
"H100_PCIe-PyTorch-Ray-1": ["h100-cr", "l0_h100", 1, 1],
|
||||
"H100_PCIe-CPP-1": ["h100-cr", "l0_h100", 1, 2],
|
||||
"H100_PCIe-CPP-2": ["h100-cr", "l0_h100", 2, 2],
|
||||
"H100_PCIe-TensorRT-1": ["h100-cr", "l0_h100", 1, 2],
|
||||
@ -2158,6 +2174,7 @@ def launchTestJobs(pipeline, testFilter)
|
||||
"B300-PyTorch-1": ["b300-single", "l0_b300", 1, 1],
|
||||
"DGX_B200-4_GPUs-PyTorch-1": ["b200-x4", "l0_dgx_b200", 1, 2, 4],
|
||||
"DGX_B200-4_GPUs-PyTorch-2": ["b200-x4", "l0_dgx_b200", 2, 2, 4],
|
||||
"DGX_B200-4_GPUs-PyTorch-Ray-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4],
|
||||
"DGX_B200-8_GPUs-PyTorch-1": ["b200-x8", "l0_dgx_b200", 1, 1, 8],
|
||||
"DGX_B200-4_GPUs-PyTorch-Post-Merge-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4],
|
||||
"DGX_B300-4_GPUs-PyTorch-Post-Merge-1": ["b300-x4", "l0_dgx_b300", 1, 1, 4],
|
||||
|
||||
@ -606,7 +606,8 @@ def main(*,
|
||||
build_deep_ep = "OFF"
|
||||
build_deep_gemm = "OFF"
|
||||
else:
|
||||
targets.extend(["th_common", "bindings", "deep_ep", "deep_gemm"])
|
||||
targets.extend(
|
||||
["th_common", "bindings", "deep_ep", "deep_gemm", "pg_utils"])
|
||||
build_pyt = "ON"
|
||||
build_deep_ep = "ON"
|
||||
build_deep_gemm = "ON"
|
||||
@ -811,6 +812,8 @@ def main(*,
|
||||
build_dir /
|
||||
"tensorrt_llm/kernels/decoderMaskedMultiheadAttention/libdecoder_attention_1.so",
|
||||
lib_dir / "libdecoder_attention_1.so")
|
||||
install_file(build_dir / "tensorrt_llm/runtime/utils/libpg_utils.so",
|
||||
lib_dir / "libpg_utils.so")
|
||||
|
||||
deep_ep_dir = pkg_dir / "deep_ep"
|
||||
if deep_ep_dir.is_symlink():
|
||||
|
||||
5
setup.py
@ -104,8 +104,9 @@ else:
|
||||
'libs/libnvinfer_plugin_tensorrt_llm.so',
|
||||
'libs/libtensorrt_llm_ucx_wrapper.so', 'libs/libdecoder_attention_0.so',
|
||||
'libs/libtensorrt_llm_nixl_wrapper.so', 'libs/nixl/**/*',
|
||||
'libs/ucx/**/*', 'libs/libdecoder_attention_1.so',
|
||||
'libs/nvshmem/License.txt', 'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
|
||||
'libs/ucx/**/*', 'libs/libpg_utils.so',
|
||||
'libs/libdecoder_attention_1.so', 'libs/nvshmem/License.txt',
|
||||
'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
|
||||
'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so',
|
||||
'deep_ep/LICENSE', 'deep_ep_cpp_tllm.*.so', "include/**/*",
|
||||
'deep_gemm/LICENSE', 'deep_gemm/include/**/*',
|
||||
|
||||
@ -17,6 +17,8 @@ import struct
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
|
||||
try:
|
||||
from cuda.bindings import driver as cuda
|
||||
from cuda.bindings import runtime as cudart
|
||||
@ -72,11 +74,11 @@ class IpcMemory:
|
||||
def __init__(self, mapping: Mapping, size: int, open_ipc: bool = True):
|
||||
self.mapping = mapping
|
||||
self.open_ipc = open_ipc and mapping.tp_size <= mapping.gpus_per_node
|
||||
self.peer_ptrs = [0] * mapping.tp_size
|
||||
self.local_ptr = 0
|
||||
|
||||
if self.open_ipc:
|
||||
self.peer_ptrs, self.local_ptr = IpcMemory.open_ipc_memory(self.mapping, size, True)
|
||||
else:
|
||||
self.peer_ptrs = [0] * mapping.tp_size
|
||||
self.local_ptr = 0
|
||||
|
||||
def __del__(self):
|
||||
if not sys.is_finalizing() and self.open_ipc:
|
||||
@ -103,9 +105,15 @@ class IpcMemory:
|
||||
size += alignment - (size % alignment)
|
||||
return size
|
||||
|
||||
comm = mpi_comm().Split(
|
||||
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
|
||||
)
|
||||
if mpi_disabled():
|
||||
from tensorrt_llm._utils import torch_comm
|
||||
|
||||
allgather = torch_comm().tp_allgather
|
||||
else:
|
||||
comm = mpi_comm().Split(
|
||||
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
|
||||
)
|
||||
allgather = comm.allgather
|
||||
|
||||
# see allocateIpcMemory in cpp/tensorrt_llm/runtime/ipcUtils.cpp for alignment reason
|
||||
# 1 << 21 is 2MB
|
||||
@ -116,8 +124,8 @@ class IpcMemory:
|
||||
_raise_if_error(cudart.cudaMemset(local_ptr, 0, aligned_size)[0])
|
||||
error, local_handle = cudart.cudaIpcGetMemHandle(local_ptr)
|
||||
_raise_if_error(error)
|
||||
handles_reserved = allgather(local_handle.reserved)
|
||||
|
||||
handles_reserved = comm.allgather(local_handle.reserved)
|
||||
handles = []
|
||||
for reserved in handles_reserved:
|
||||
handle = cudart.cudaIpcMemHandle_t()
|
||||
@ -141,6 +149,8 @@ class IpcMemory:
|
||||
def close_ipc_memory(mapping: Mapping, peer_ptrs: List[int]):
|
||||
for node, ptr in enumerate(peer_ptrs):
|
||||
if node == mapping.tp_rank:
|
||||
_raise_if_error(cudart.cudaFree(ptr)[0])
|
||||
if ptr != 0:
|
||||
_raise_if_error(cudart.cudaFree(ptr)[0])
|
||||
else:
|
||||
_raise_if_error(cudart.cudaIpcCloseMemHandle(ptr)[0])
|
||||
if ptr != 0:
|
||||
_raise_if_error(cudart.cudaIpcCloseMemHandle(ptr)[0])
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
@ -10,6 +9,8 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from tensorrt_llm._utils import get_free_port as _get_free_port
|
||||
|
||||
from ..utils.logger import ad_logger
|
||||
|
||||
# TODO: check to what extend we can reuse _torch/distributed.py
|
||||
@ -69,10 +70,7 @@ def all_gather_object(object_list, object, group=None):
|
||||
|
||||
|
||||
def get_free_port():
|
||||
sock = socket.socket()
|
||||
sock.bind(("", 0))
|
||||
port = sock.getsockname()[1]
|
||||
return port
|
||||
return _get_free_port()
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
|
||||
@ -10,7 +10,7 @@ from ..._utils import get_sm_version
|
||||
def _register_fake():
|
||||
|
||||
@torch.library.register_fake("trtllm::allreduce")
|
||||
def _(
|
||||
def allreduce(
|
||||
input,
|
||||
residual,
|
||||
norm_weight,
|
||||
@ -55,6 +55,25 @@ def _register_fake():
|
||||
else:
|
||||
return [torch.empty_like(input)]
|
||||
|
||||
@torch.library.register_fake("trtllm::allreduce_pg")
|
||||
def _(
|
||||
input,
|
||||
residual,
|
||||
norm_weight,
|
||||
scale,
|
||||
bias,
|
||||
workspace,
|
||||
group,
|
||||
rank,
|
||||
pg,
|
||||
strategy,
|
||||
op,
|
||||
eps,
|
||||
trigger_completion_at_end,
|
||||
):
|
||||
return allreduce(input, residual, norm_weight, scale, bias, workspace,
|
||||
group, strategy, op, eps, trigger_completion_at_end)
|
||||
|
||||
#MNNVL Allreduce
|
||||
@torch.library.register_fake("trtllm::mnnvl_twoshot_allreduce")
|
||||
def _(input, buffer, buffer_flags, buffer_size, wait_for_results):
|
||||
@ -76,13 +95,17 @@ def _register_fake():
|
||||
return [norm_out, residual_out]
|
||||
|
||||
@torch.library.register_fake("trtllm::allgather")
|
||||
def _(input, sizes, group):
|
||||
def allgather(input, sizes, group):
|
||||
if sizes is None:
|
||||
output_shape = (len(group) * input.shape[0], *input.shape[1:])
|
||||
else:
|
||||
output_shape = (sum(sizes), *input.shape[1:])
|
||||
return input.new_empty(output_shape)
|
||||
|
||||
@torch.library.register_fake("trtllm::allgather_pg")
|
||||
def _(input, sizes, group, process_group):
|
||||
return allgather(input, sizes, group)
|
||||
|
||||
@torch.library.register_fake("trtllm::cublas_scaled_mm")
|
||||
def _(
|
||||
mat_a: torch.Tensor,
|
||||
@ -439,7 +462,7 @@ def _register_fake():
|
||||
dtype=gemm2_output.dtype)
|
||||
|
||||
@torch.library.register_fake("trtllm::allgather_list")
|
||||
def _(input_list, sizes, group):
|
||||
def allgather_list(input_list, sizes, group):
|
||||
assert len(input_list) > 0
|
||||
|
||||
def create_output_tensor(i):
|
||||
@ -452,8 +475,12 @@ def _register_fake():
|
||||
|
||||
return [create_output_tensor(i) for i in input_list]
|
||||
|
||||
@torch.library.register_fake("trtllm::allgather_list_pg")
|
||||
def _(input_list, sizes, group, process_group):
|
||||
return allgather_list(input_list, sizes, group)
|
||||
|
||||
@torch.library.register_fake("trtllm::reducescatter")
|
||||
def _(input, sizes, group):
|
||||
def reducescatter(input, sizes, group):
|
||||
import tensorrt_llm
|
||||
local_rank = tensorrt_llm.mpi_rank()
|
||||
|
||||
@ -464,6 +491,10 @@ def _register_fake():
|
||||
shape[0] = sizes[local_rank]
|
||||
return input.new_empty(shape)
|
||||
|
||||
@torch.library.register_fake("trtllm::reducescatter_pg")
|
||||
def _(input, sizes, group, process_group):
|
||||
return reducescatter(input, sizes, group)
|
||||
|
||||
@torch.library.register_fake("trtllm::block_scale_interleave")
|
||||
def _(sf: torch.Tensor):
|
||||
rows = sf.shape[-2]
|
||||
|
||||
165
tensorrt_llm/_torch/device_mesh.py
Normal file
@ -0,0 +1,165 @@
from functools import wraps
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import get_process_group_ranks
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensorrt_llm.mapping import MappingBase as _MappingBaseForTypeCheck
|
||||
else:
|
||||
_MappingBaseForTypeCheck = object
|
||||
|
||||
|
||||
def require_device_mesh(func):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if DeviceMeshTopologyImpl.device_mesh is None:
|
||||
self.build_mesh()
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class SingleProcessGroup:
|
||||
|
||||
@staticmethod
|
||||
def get_group():
|
||||
return dist.group.WORLD if dist.is_initialized(
|
||||
) else SingleProcessGroup()
|
||||
|
||||
@staticmethod
|
||||
def rank():
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def size():
|
||||
return 1
|
||||
|
||||
|
||||
class DeviceMeshTopologyImpl(_MappingBaseForTypeCheck):
|
||||
device_mesh = None
|
||||
tp_mesh = None
|
||||
|
||||
# Access Torch ProcessGroup
|
||||
@property
|
||||
@require_device_mesh
|
||||
def tp_group_pg(self):
|
||||
return self._get_mesh_dim_by_name('tp').get_group()
|
||||
|
||||
@property
|
||||
@require_device_mesh
|
||||
def pp_group_pg(self):
|
||||
return self._get_mesh_dim_by_name('pp').get_group()
|
||||
|
||||
@property
|
||||
@require_device_mesh
|
||||
def cp_group_pg(self):
|
||||
return self._get_mesh_dim_by_name('cp').get_group()
|
||||
|
||||
@property
|
||||
@require_device_mesh
|
||||
def moe_tp_group_pg(self):
|
||||
return self._get_mesh_dim_by_name('moe_tp').get_group()
|
||||
|
||||
@property
|
||||
@require_device_mesh
|
||||
def moe_ep_group_pg(self):
|
||||
return self._get_mesh_dim_by_name('moe_ep').get_group()
|
||||
|
||||
# Access rank
|
||||
@property
|
||||
def tp_rank(self) -> int:
|
||||
assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
|
||||
return self.tp_group_pg.rank()
|
||||
|
||||
@property
|
||||
def pp_rank(self) -> int:
|
||||
assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
|
||||
return self.pp_group_pg.rank()
|
||||
|
||||
@property
|
||||
def cp_rank(self) -> int:
|
||||
# TODO: WIP
|
||||
assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
|
||||
return self.cp_group_pg.rank()
|
||||
|
||||
# Access group ranks
|
||||
@property
|
||||
def tp_group(self) -> List[int]:
|
||||
return self._get_group_ranks(self.tp_group_pg)
|
||||
|
||||
@property
|
||||
def pp_group(self) -> List[int]:
|
||||
return self._get_group_ranks(self.pp_group_pg)
|
||||
|
||||
@property
|
||||
def cp_group(self) -> List[int]:
|
||||
return self._get_group_ranks(self.cp_group_pg)
|
||||
|
||||
@property
|
||||
def moe_tp_group(self) -> List[int]:
|
||||
return self._get_group_ranks(self.moe_tp_group_pg)
|
||||
|
||||
@property
|
||||
def moe_ep_group(self) -> List[int]:
|
||||
return self._get_group_ranks(self.moe_ep_group_pg)
|
||||
|
||||
def build_mesh(self):
|
||||
cls = DeviceMeshTopologyImpl
|
||||
|
||||
if self.world_size == 1 or cls.device_mesh is not None:
|
||||
# only build mesh once
|
||||
return
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
raise RuntimeError(
|
||||
"DeviceMesh creation requested but torch.distributed process group "
|
||||
"has not been initialised.")
|
||||
|
||||
dims = ["cp", "pp"]
|
||||
shape = [self.cp_size, self.pp_size]
|
||||
|
||||
if self.moe_ep_size > 1:
|
||||
dims += ["moe_tp", "moe_ep"]
|
||||
shape += [self.moe_tp_size, self.moe_ep_size]
|
||||
else:
|
||||
dims += ["tp"]
|
||||
shape += [self.tp_size]
|
||||
|
||||
cls.device_mesh = init_device_mesh(
|
||||
"cuda",
|
||||
mesh_shape=tuple(shape),
|
||||
mesh_dim_names=tuple(dims),
|
||||
)
|
||||
|
||||
if self.moe_ep_size > 1:
|
||||
cls.tp_mesh = cls.device_mesh["moe_tp",
|
||||
"moe_ep"]._flatten(mesh_dim_name="tp")
|
||||
logger.debug(f"DeviceMeshTopology.device_mesh: {cls.device_mesh}")
|
||||
logger.debug(f"DeviceMeshTopology.tp_mesh: {cls.tp_mesh}")
|
||||
|
||||
@require_device_mesh
|
||||
def _get_mesh_dim_by_name(self, name: str) -> dist.DeviceMesh:
|
||||
cls = DeviceMeshTopologyImpl
|
||||
|
||||
if cls.device_mesh is None and self.world_size == 1:
|
||||
return SingleProcessGroup()
|
||||
|
||||
if name == 'tp':
|
||||
if 'tp' in cls.device_mesh.mesh_dim_names:
|
||||
return cls.device_mesh['tp']
|
||||
else:
|
||||
return cls.tp_mesh
|
||||
else:
|
||||
assert name in cls.device_mesh.mesh_dim_names, f"Dimension name {name} not found in device mesh."
|
||||
return cls.device_mesh[name]
|
||||
|
||||
def _get_group_ranks(self, pg) -> List[int]:
|
||||
if self.world_size == 1:
|
||||
return [0]
|
||||
return get_process_group_ranks(pg)
|
||||
@ -1,12 +1,14 @@
|
||||
import math
|
||||
import os
|
||||
import pickle # nosec B403
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import (_object_to_tensor,
|
||||
_tensor_to_object)
|
||||
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
@ -14,11 +16,19 @@ except Exception:
|
||||
MPI = None # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
|
||||
|
||||
from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
|
||||
mpi_isend, mpi_isend_object, mpi_recv,
|
||||
mpi_recv_object, mpi_send, mpi_send_object)
|
||||
mpi_disabled, mpi_isend, mpi_isend_object,
|
||||
mpi_recv, mpi_recv_object, mpi_send,
|
||||
mpi_send_object, torch_pybind11_abi)
|
||||
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
|
||||
from tensorrt_llm.bindings.internal.process_group import init_pg
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
try:
|
||||
import ray
|
||||
except ModuleNotFoundError:
|
||||
from tensorrt_llm import ray_stub as ray
|
||||
|
||||
|
||||
class Distributed(ABC):
|
||||
|
||||
@ -327,6 +337,7 @@ def safe_gather(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
|
||||
|
||||
|
||||
class MPIDist(Distributed):
|
||||
tp_comm: MPI.Comm
|
||||
|
||||
def __init__(self, mapping: Mapping):
|
||||
super().__init__(mapping)
|
||||
@ -401,65 +412,361 @@ class MPIDist(Distributed):
|
||||
return self.pp_comm.bcast(obj, root)
|
||||
|
||||
|
||||
class MultiHandleWrapper:
|
||||
"""
|
||||
Wrapper that encapsulates multiple handles and provides a single wait() interface
|
||||
to unify the API between MPIDist and TorchDist.
|
||||
"""
|
||||
|
||||
def __init__(self, handles):
|
||||
self.handles = handles if isinstance(handles, list) else [handles]
|
||||
|
||||
def wait(self):
|
||||
for handle in self.handles:
|
||||
try:
|
||||
handle.wait()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Asynchronous operation failed: {e}") from e
|
||||
|
||||
|
||||
class TorchDist(Distributed):
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return torch.distributed.get_rank()
|
||||
|
||||
def __init__(self, mapping: Mapping):
|
||||
super().__init__(mapping)
|
||||
if not dist.is_initialized():
|
||||
master_ip = os.getenv("MASTER_ADDR", "localhost")
|
||||
# TODO: fix the constant default port
|
||||
master_port = os.getenv("MASTER_PORT", "6000")
|
||||
init_method = f"tcp://{master_ip}:{master_port}"
|
||||
dist.init_process_group(backend="nccl",
|
||||
init_method=init_method,
|
||||
world_size=mapping.world_size,
|
||||
rank=mapping.rank)
|
||||
self.device_tp_group = dist.new_group(mapping.tp_group, backend="nccl")
|
||||
self.cpu_tp_group = dist.new_group(mapping.tp_group, backend="gloo")
|
||||
self.device_cp_group = dist.new_group(mapping.cp_group, backend="nccl")
|
||||
self.cpu_cp_group = dist.new_group(mapping.cp_group, backend="gloo")
|
||||
assert dist.is_initialized(
|
||||
), "torch.distributed should be initialized before TorchDist"
|
||||
|
||||
def broadcast_tp(self, obj, root=0):
|
||||
if root not in self.mapping.tp_group:
|
||||
return obj
|
||||
elif self.rank == root:
|
||||
torch.distributed.broadcast_object_list([obj],
|
||||
src=root,
|
||||
group=self.cpu_tp_group)
|
||||
return obj
|
||||
self.cluster_info = None
|
||||
|
||||
from tensorrt_llm._utils import set_torch_comm
|
||||
set_torch_comm(self) # Set as global instance
|
||||
mapping.build_mesh()
|
||||
|
||||
self.setup_local_comm()
|
||||
self.default_store = torch.distributed.distributed_c10d._get_default_store(
|
||||
)
|
||||
|
||||
init_pg(torch.distributed.group.WORLD, self.local_comm,
|
||||
torch_pybind11_abi())
|
||||
|
||||
def setup_local_comm(self):
|
||||
self._get_cluster_info()
|
||||
|
||||
# node IP -> list of ranks
|
||||
ip_to_ranks = {}
|
||||
for rank, (node_ip, _) in enumerate(self.cluster_info):
|
||||
ip_to_ranks.setdefault(node_ip, []).append(int(rank))
|
||||
|
||||
self.local_comm = None
|
||||
for ranks in ip_to_ranks.values():
|
||||
# All global ranks from the default process group must participate in this call,
|
||||
# even if some ranks are not part of the new process group being created
|
||||
pg = dist.new_group(ranks=ranks, backend='cuda:nccl,cpu:gloo')
|
||||
if int(self.rank) in ranks:
|
||||
logger.debug(
|
||||
f"[Rank {self.rank}] Done setting local comm. ip_to_ranks: {ip_to_ranks}"
|
||||
)
|
||||
self.local_comm = pg
|
||||
|
||||
def _get_cluster_info(self):
|
||||
if self.cluster_info is not None:
|
||||
return self.cluster_info
|
||||
|
||||
if ray.is_initialized():
|
||||
node_ip = ray.util.get_node_ip_address()
|
||||
else:
|
||||
recv = [None]
|
||||
torch.distributed.broadcast_object_list(recv,
|
||||
src=root,
|
||||
group=self.cpu_tp_group)
|
||||
return recv[0]
|
||||
raise RuntimeError("Ray is not initialized")
|
||||
|
||||
def broadcast_cp(self, obj, root=0):
|
||||
if root not in self.mapping.cp_group:
|
||||
return obj
|
||||
elif self.rank == root:
|
||||
torch.distributed.broadcast_object_list([obj],
|
||||
src=root,
|
||||
group=self.cpu_cp_group)
|
||||
return obj
|
||||
else:
|
||||
recv = [None]
|
||||
torch.distributed.broadcast_object_list(recv,
|
||||
src=root,
|
||||
group=self.cpu_cp_group)
|
||||
return recv[0]
|
||||
gpu_index = [int(id) for id in ray.get_gpu_ids()]
|
||||
|
||||
assert len(gpu_index) == 1
|
||||
|
||||
# Gather node ip
|
||||
node_list = [None] * torch.distributed.get_world_size()
|
||||
|
||||
torch.distributed.all_gather_object(node_list, node_ip)
|
||||
|
||||
# Gather gpu index
|
||||
gpu_list = [None] * torch.distributed.get_world_size()
|
||||
torch.distributed.all_gather_object(gpu_list, gpu_index[0])
|
||||
|
||||
# Gather rank
|
||||
rank_list = [None] * torch.distributed.get_world_size()
|
||||
torch.distributed.all_gather_object(rank_list,
|
||||
torch.distributed.get_rank())
|
||||
|
||||
rank_info_list = [None] * torch.distributed.get_world_size()
|
||||
for i in range(len(rank_list)):
|
||||
rank_info_list[rank_list[i]] = (node_list[i], gpu_list[i])
|
||||
|
||||
self.cluster_info = rank_info_list
|
||||
|
||||
logger.debug(f"Cluster info: {self.cluster_info}")
|
||||
return self.cluster_info
|
||||
|
||||
@staticmethod
|
||||
def log_op(func, enable_log=False):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if enable_log:
|
||||
logger.debug(
|
||||
f"{func.__name__} enter: {args[1:]}, {kwargs}, rank: {torch.distributed.get_rank()}"
|
||||
)
|
||||
ret = func(*args, **kwargs)
|
||||
|
||||
if enable_log:
|
||||
logger.debug(f"{func.__name__} exit: {ret}")
|
||||
return ret
|
||||
|
||||
return wrapper
|
||||
|
||||
@log_op
|
||||
def broadcast(self, obj, root=0):
|
||||
assert not (self.mapping.has_cp_ulysses() and self.mapping.has_tp()
|
||||
), 'Unsupported mix of Ulysses CP and TP.'
|
||||
|
||||
if mpi_disabled():
|
||||
if isinstance(obj, torch.Tensor):
|
||||
dist.broadcast(obj, src=root)
|
||||
return obj
|
||||
else:
|
||||
obj_list = [obj]
|
||||
dist.broadcast_object_list(obj_list, src=root)
|
||||
return obj_list[0]
|
||||
|
||||
if self.mapping.has_cp_ulysses():
|
||||
self.broadcast_cp(obj, root)
|
||||
elif self.mapping.has_tp():
|
||||
self.broadcast_tp(obj, root)
|
||||
|
||||
@log_op
|
||||
def allgather(self, obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
output_list = [
|
||||
torch.empty_like(obj) for _ in range(self.world_size)
|
||||
]
|
||||
dist.all_gather(output_list, obj)
|
||||
return output_list
|
||||
else:
|
||||
pass
|
||||
obj_list = [None] * self.world_size
|
||||
dist.all_gather_object(obj_list, obj)
|
||||
return obj_list
|
||||
|
||||
@log_op
|
||||
def barrier(self):
|
||||
dist.barrier()
|
||||
|
||||
@log_op
|
||||
def isend(self, buf: np.ndarray, dest, tag=0):
|
||||
# non-blocking send numpy buffer
|
||||
tensor = torch.from_numpy(buf)
|
||||
return dist.isend(tensor, dst=dest, tag=tag)
|
||||
|
||||
@log_op
|
||||
def send(self, buf: np.ndarray, dest, tag=0):
|
||||
raise NotImplementedError(
|
||||
"blocking send is not implemented for TorchDist")
|
||||
|
||||
@log_op
|
||||
def recv(self, buf: np.ndarray, src, tag=0):
|
||||
# in-place recv numpy buffer
|
||||
tensor = torch.empty_like(torch.from_numpy(buf))
|
||||
dist.recv(tensor, src=src, tag=tag)
|
||||
return tensor.numpy()
|
||||
|
||||
@log_op
|
||||
def isend_tensor(self, tensor: torch.Tensor, dest, tag=0):
|
||||
return dist.isend(tensor, dst=dest, tag=tag)
|
||||
|
||||
@log_op
|
||||
def recv_tensor(self, tensor: torch.Tensor, src, tag=0):
|
||||
dist.recv(tensor, src=src, tag=tag)
|
||||
return tensor
|
||||
|
||||
@log_op
|
||||
def recv_object(self, src, tag=0):
|
||||
size_tensor = torch.tensor([0], dtype=torch.int32)
|
||||
torch.distributed.recv(size_tensor,
|
||||
src=src,
|
||||
tag=tag,
|
||||
group=torch.distributed.group.WORLD)
|
||||
bytes_size = size_tensor.item()
|
||||
recv_tensor = torch.empty(bytes_size, dtype=torch.uint8)
|
||||
torch.distributed.recv(recv_tensor,
|
||||
src=src,
|
||||
tag=tag,
|
||||
group=torch.distributed.group.WORLD)
|
||||
return _tensor_to_object(recv_tensor, bytes_size,
|
||||
torch.distributed.group.WORLD)
|
||||
|
||||
@log_op
|
||||
def send_object(self, obj, dest, tag=0):
|
||||
raise NotImplementedError(
|
||||
"send_object is not implemented for TorchDist")
|
||||
|
||||
@log_op
|
||||
def isend_object(self, obj, dest, tag=0):
|
||||
input_tensor, local_size = _object_to_tensor(
|
||||
obj, torch.device("cpu"), torch.distributed.group.WORLD)
|
||||
|
||||
# Send object size
|
||||
works = []
|
||||
works.append(
|
||||
torch.distributed.isend(torch.tensor([local_size],
|
||||
dtype=torch.int32),
|
||||
dst=dest,
|
||||
tag=tag))
|
||||
works.append(torch.distributed.isend(input_tensor, dst=dest, tag=tag))
|
||||
return MultiHandleWrapper(works)
|
||||
|
||||
@log_op
|
||||
def recv_object_from_isend(self, src, tag):
|
||||
size_tensor = torch.tensor([0], dtype=torch.int32)
|
||||
torch.distributed.recv(size_tensor, src=src, tag=tag)
|
||||
bytes_size = size_tensor.item()
|
||||
recv_tensor = torch.empty(bytes_size, dtype=torch.uint8)
|
||||
torch.distributed.recv(recv_tensor, src=src, tag=tag)
|
||||
return _tensor_to_object(recv_tensor, bytes_size,
|
||||
torch.distributed.group.WORLD)
|
||||
|
||||
@log_op
|
||||
def allreduce(self,
|
||||
obj: int | float | torch.Tensor,
|
||||
op=torch.distributed.ReduceOp.SUM):
|
||||
is_base_type = isinstance(obj, int) or isinstance(obj, float)
|
||||
if is_base_type:
|
||||
obj = torch.tensor(obj)
|
||||
|
||||
dist.all_reduce(obj, op=op)
|
||||
|
||||
if is_base_type:
|
||||
obj = obj.item()
|
||||
|
||||
return obj
|
||||
|
||||
@log_op
|
||||
def tp_allgather(self, obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
output_list = [
|
||||
torch.empty_like(obj)
|
||||
for _ in range(self.mapping.tp_group_pg.size())
|
||||
]
|
||||
dist.all_gather(output_list, obj, group=self.mapping.tp_group_pg)
|
||||
return output_list
|
||||
else:
|
||||
output_list = [None] * self.mapping.tp_group_pg.size()
|
||||
dist.all_gather_object(output_list,
|
||||
obj,
|
||||
group=self.mapping.tp_group_pg)
|
||||
return output_list
|
||||
|
||||
@log_op
|
||||
def tp_gather(self, obj, dst=0):
|
||||
global_rank = torch.distributed.get_rank()
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if global_rank == dst:
|
||||
output_list = [
|
||||
torch.empty_like(obj)
|
||||
for _ in range(self.mapping.tp_group_pg.size())
|
||||
]
|
||||
else:
|
||||
output_list = None
|
||||
dist.gather(obj,
|
||||
output_list,
|
||||
dst=dst,
|
||||
group=self.mapping.tp_group_pg)
|
||||
return output_list
|
||||
else:
|
||||
output_list = [None] * self.mapping.tp_group_pg.size()
|
||||
if global_rank == dst:
|
||||
output_list = [None] * self.mapping.tp_group_pg.size()
|
||||
else:
|
||||
output_list = None
|
||||
dist.gather_object(obj,
|
||||
output_list,
|
||||
dst=dst,
|
||||
group=self.mapping.tp_group_pg)
|
||||
return output_list
|
||||
|
||||
@log_op
|
||||
def tp_broadcast(self, obj, root=0):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
|
||||
return obj
|
||||
else:
|
||||
ret = [obj]
|
||||
torch.distributed.broadcast_object_list(
|
||||
ret,
|
||||
src=root,
|
||||
group=self.mapping.tp_group_pg,
|
||||
device=torch.device("cpu"))
|
||||
return ret[0]
|
||||
|
||||
@log_op
|
||||
def pp_allgather(self, obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
output_list = [
|
||||
torch.empty_like(obj)
|
||||
for _ in range(self.mapping.pp_group_pg.size())
|
||||
]
|
||||
dist.all_gather(output_list, obj, group=self.mapping.pp_group_pg)
|
||||
return output_list
|
||||
else:
|
||||
output_list = [None] * self.mapping.pp_group_pg.size()
|
||||
dist.all_gather_object(output_list,
|
||||
obj,
|
||||
group=self.mapping.pp_group_pg)
|
||||
return output_list
|
||||
|
||||
@log_op
|
||||
def pp_gather(self, obj, dst=0):
|
||||
global_rank = torch.distributed.get_rank()
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if global_rank == dst:
|
||||
output_list = [
|
||||
torch.empty_like(obj)
|
||||
for _ in range(self.mapping.pp_group_pg.size())
|
||||
]
|
||||
else:
|
||||
output_list = None
|
||||
dist.gather(obj,
|
||||
output_list,
|
||||
dst=dst,
|
||||
group=self.mapping.pp_group_pg)
|
||||
return output_list
|
||||
else:
|
||||
output_list = [None] * self.mapping.pp_group_pg.size()
|
||||
if global_rank == dst:
|
||||
output_list = [None] * self.mapping.pp_group_pg.size()
|
||||
else:
|
||||
output_list = None
|
||||
dist.gather_object(obj,
|
||||
output_list,
|
||||
dst=dst,
|
||||
group=self.mapping.pp_group_pg)
|
||||
return output_list
|
||||
|
||||
@log_op
|
||||
def pp_broadcast(self, obj, root=0):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
dist.broadcast(obj, src=root, group=self.mapping.pp_group_pg)
|
||||
return obj
|
||||
else:
|
||||
ret = [obj]
|
||||
torch.distributed.broadcast_object_list(
|
||||
ret,
|
||||
src=root,
|
||||
group=self.mapping.pp_group_pg,
|
||||
device=torch.device("cpu"))
|
||||
return ret[0]
|
||||
|
||||
|
||||
# TODO: rename to PPCommNCCL
|
||||
class PPComm:
|
||||
|
||||
def __init__(self, global_mapping: Mapping):
|
||||
@ -480,20 +787,49 @@ class PPComm:
|
||||
self.nccl_comm.recv(tensor, src)
|
||||
|
||||
|
||||
class PPCommTorch:

    def __init__(self, global_mapping: Mapping):
        self.mapping = global_mapping
        self.pg = self.mapping.pp_group_pg
        self.pg_group = self.mapping.pp_group

    def _global_to_local_rank(self, global_rank: int):
        assert global_rank in self.pg_group
        return self.pg_group.index(global_rank)

    def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
        if dest is None:
            dest = self.mapping.next_pp_rank()

        self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()

    def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
        if src is None:
            src = self.mapping.prev_pp_rank()

        self.pg.recv([tensor], self._global_to_local_rank(src), tag=0).wait()


_pp_comm = None


def init_pp_comm(mapping):
    """Initialize PPComm once at startup"""
    global _pp_comm
    if mpi_disabled():
        _pp_comm = PPCommTorch(mapping)
    else:
        _pp_comm = PPComm(mapping)


@TorchDist.log_op
def pp_recv(tensor):
    """Receive tensors from previous pp rank."""
    _pp_comm.recv(tensor)


@TorchDist.log_op
def pp_send(tensor):
    """Send tensors to next pp rank."""
    _pp_comm.send(tensor)

@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from tensorrt_llm._utils import mpi_comm
|
||||
from tensorrt_llm._utils import mpi_comm, mpi_disabled
|
||||
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
|
||||
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
|
||||
AllReduceStrategy, MoEAllReduceParams)
|
||||
@ -185,26 +185,33 @@ def allgather(
|
||||
val.shape[dim] == sizes[mapping.tp_rank] for val in input
|
||||
if val is not None
|
||||
])
|
||||
|
||||
# Inputs are reshaped in this way to pass necessary shape information to the allgather op
|
||||
if isinstance(input, torch.Tensor):
|
||||
torch_op = torch.ops.trtllm.allgather
|
||||
if mpi_disabled():
|
||||
torch_op = torch.ops.trtllm.allgather_pg
|
||||
else:
|
||||
torch_op = torch.ops.trtllm.allgather
|
||||
|
||||
output_info = get_output_info(input, dim)
|
||||
input = input.contiguous().view(-1, output_info['numel_base'])
|
||||
else:
|
||||
input, valid = filter_valid_input(input)
|
||||
torch_op = torch.ops.trtllm.allgather_list
|
||||
if mpi_disabled():
|
||||
torch_op = torch.ops.trtllm.allgather_list_pg
|
||||
else:
|
||||
torch_op = torch.ops.trtllm.allgather_list
|
||||
|
||||
output_info = [get_output_info(val, dim) for val in input]
|
||||
input = [
|
||||
val.contiguous().view(-1, val_info['numel_base'])
|
||||
for val, val_info in zip(input, output_info)
|
||||
]
|
||||
|
||||
output = torch_op(
|
||||
input,
|
||||
sizes,
|
||||
mapping.tp_group,
|
||||
)
|
||||
if mpi_disabled():
|
||||
output = torch_op(input, sizes, mapping.tp_group,
|
||||
mapping.tp_group_pg.boxed())
|
||||
else:
|
||||
output = torch_op(input, sizes, mapping.tp_group)
|
||||
|
||||
def convert_output(x, x_info):
|
||||
if dim == 0:
|
||||
@ -300,23 +307,29 @@ def reducescatter(
|
||||
return x
|
||||
|
||||
if isinstance(input, torch.Tensor):
|
||||
torch_op = torch.ops.trtllm.reducescatter
|
||||
if mpi_disabled():
|
||||
torch_op = torch.ops.trtllm.reducescatter_pg
|
||||
else:
|
||||
torch_op = torch.ops.trtllm.reducescatter
|
||||
output_info = get_output_info(input, dim)
|
||||
input = convert_input(input, output_info)
|
||||
else:
|
||||
input, valid = filter_valid_input(input)
|
||||
torch_op = torch.ops.trtllm.reducescatter_list
|
||||
if mpi_disabled():
|
||||
torch_op = torch.ops.trtllm.reducescatter_list_pg
|
||||
else:
|
||||
torch_op = torch.ops.trtllm.reducescatter_list
|
||||
output_info = [get_output_info(val, dim) for val in input]
|
||||
input = [
|
||||
convert_input(val, val_info)
|
||||
for val, val_info in zip(input, output_info)
|
||||
]
|
||||
|
||||
output = torch_op(
|
||||
input,
|
||||
sizes,
|
||||
mapping.tp_group,
|
||||
)
|
||||
if mpi_disabled():
|
||||
output = torch_op(input, sizes, mapping.tp_group,
|
||||
mapping.tp_group_pg.boxed())
|
||||
else:
|
||||
output = torch_op(input, sizes, mapping.tp_group)
|
||||
|
||||
if isinstance(input, torch.Tensor):
|
||||
output = output.view(output_info['output_shape'])
|
||||
@ -489,6 +502,9 @@ class AllReduce(nn.Module):
|
||||
self.workspace = None
|
||||
self.strategy = strategy
|
||||
self.mnnvl_allreduce = None
|
||||
self._disable_mpi = mpi_disabled()
|
||||
|
||||
self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce
|
||||
|
||||
if self.mapping.tp_size > 1:
|
||||
# When Strategy is UB, it is guaranteed that the workspace is not used.
|
||||
@ -572,7 +588,17 @@ class AllReduce(nn.Module):
|
||||
# Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL
|
||||
if allreduce_strategy == AllReduceStrategy.MNNVL:
|
||||
allreduce_strategy = AllReduceStrategy.AUTO
|
||||
output = torch.ops.trtllm.allreduce(
|
||||
|
||||
additional_args = {}
|
||||
if self._disable_mpi:
|
||||
pg = self.mapping.tp_group_pg
|
||||
assert pg is not None, "TP ProcessGroup not initialised"
|
||||
additional_args = {
|
||||
"rank": torch.distributed.get_rank(),
|
||||
"pg": pg.boxed(),
|
||||
}
|
||||
|
||||
output = self.all_reduce_op(
|
||||
input=input,
|
||||
residual=all_reduce_params.residual,
|
||||
norm_weight=all_reduce_params.norm_weight,
|
||||
@ -585,6 +611,7 @@ class AllReduce(nn.Module):
|
||||
eps=all_reduce_params.eps,
|
||||
trigger_completion_at_end=all_reduce_params.
|
||||
trigger_completion_at_end,
|
||||
**additional_args,
|
||||
)
|
||||
|
||||
return output if len(output) > 1 else output[0]
|
||||
|
||||
43
tensorrt_llm/_torch/distributed/pg_utils.py
Normal file
@ -0,0 +1,43 @@
import torch
import torch.distributed as dist


def split(color: int, key: int,
          pg_boxed: torch.ScriptObject) -> torch.ScriptObject:
    """Create a subgroup ProcessGroup.

    This gathers (color, key) from all ranks, selects members with matching color,
    sorts them by key to determine rank ordering, and creates a new ProcessGroup
    for those ranks using torch.distributed.new_group, inheriting backend from
    the global ProcessGroup.
    """
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed is not initialized")
    try:
        pg = torch.distributed.ProcessGroup.unbox(pg_boxed)
    except Exception as e:
        raise ValueError(f"Error unboxing ProcessGroup: {e}") from e

    group_size = dist.get_world_size(group=pg)
    # gather (color, key, global_rank) within the provided pg
    payload = (int(color), int(key), int(dist.get_rank(group=pg)))
    gathered = [None] * group_size
    dist.all_gather_object(gathered, payload, group=pg)

    members = []
    for c, k, global_rank in gathered:
        if c == color:
            members.append((int(k), int(global_rank)))

    members.sort()
    ranks = [r for _, r in members]
    if not ranks:
        raise ValueError(f"Split by color {color} produced empty subgroup")
    if (current_rank := dist.get_rank()) not in ranks:
        raise ValueError(
            f"Current rank {current_rank} not in color {color} subgroup")

    # Create subgroup under the provided pg; ranks are global ranks
    sub_pg = dist.new_group(ranks=ranks, use_local_synchronization=True)
    # Return TorchScript boxed ProcessGroup so C++ can unwrap it
    return sub_pg.boxed()
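Illustrative only, not part of the new file: a minimal sketch of calling split() to carve a tensor-parallel subgroup out of the world ProcessGroup. make_tp_subgroup and tp_size are hypothetical names; the import path follows the file location added above, and ProcessGroup.boxed() is assumed available as it is used elsewhere in this diff.

import torch
import torch.distributed as dist

from tensorrt_llm._torch.distributed import pg_utils


def make_tp_subgroup(tp_size: int) -> torch.ScriptObject:
    # Ranks that share a color end up in the same subgroup; key orders them.
    rank = dist.get_rank()
    color = rank // tp_size
    key = rank % tp_size
    # Returns a TorchScript-boxed ProcessGroup, ready to hand to the C++ side.
    return pg_utils.split(color, key, dist.group.WORLD.boxed())
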
@ -10,7 +10,7 @@ from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
from tensorrt_llm._utils import mpi_disabled, nvtx_range
|
||||
from tensorrt_llm.mapping import CpType
|
||||
|
||||
from ..distributed import Distributed
|
||||
@ -69,6 +69,8 @@ class ExecutorRequestQueue:
|
||||
self.is_shutdown = False
|
||||
self.should_exclude_last_generation_logits = False
|
||||
|
||||
self._disable_mpi = mpi_disabled()
|
||||
|
||||
def _get_from_request_queue(
|
||||
self,
|
||||
timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
|
||||
@ -267,8 +269,14 @@ class ExecutorRequestQueue:
|
||||
) -> List[RequestQueueItem]:
|
||||
"""Common logic for fetching and processing requests from the queue."""
|
||||
# Calculate timeout
|
||||
timeout = None if (total_num_active_requests == 0) and len(
|
||||
self.waiting_queue) == 0 else datetime.timedelta(0)
|
||||
idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0
|
||||
if idle:
|
||||
# In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0
|
||||
# reaches the broadcast path regularly to prevent trtllm-serve timeout when idle.
|
||||
timeout = datetime.timedelta(
|
||||
seconds=1200) if self._disable_mpi else None
|
||||
else:
|
||||
timeout = datetime.timedelta(0)
|
||||
|
||||
# Fetch requests from rank 0
|
||||
new_requests = []
|
||||
@ -569,7 +577,13 @@ class ExecutorRequestQueue:
|
||||
payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag)
|
||||
|
||||
if not self.dist.is_last_pp_rank:
|
||||
self.dist.send_object(payloads, self.dist.next_pp_rank, tag)
|
||||
if self._disable_mpi:
|
||||
isend_payload = self.dist.isend_object(payloads,
|
||||
self.dist.next_pp_rank,
|
||||
tag)
|
||||
isend_payload.wait()
|
||||
else:
|
||||
self.dist.send_object(payloads, self.dist.next_pp_rank, tag)
|
||||
|
||||
return payloads
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import datetime
|
||||
import functools
|
||||
import gc
|
||||
import os
|
||||
import pickle # nosec B403
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
@ -21,8 +22,8 @@ except ImportError:
|
||||
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import (
|
||||
ResourceManagerType, request_context)
|
||||
from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
|
||||
is_trace_enabled, nvtx_range, trace_func)
|
||||
from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
|
||||
mpi_disabled, nvtx_range, trace_func)
|
||||
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
|
||||
FinishReason, InflightBatchingStats,
|
||||
IterationStats, KvCacheStats,
|
||||
@ -166,7 +167,7 @@ class PyExecutor:
|
||||
peft_cache_config: Optional[PeftCacheConfig] = None):
|
||||
super(PyExecutor, self).__init__()
|
||||
self.device_id = torch.cuda.current_device()
|
||||
self.global_rank = global_mpi_rank()
|
||||
self.global_rank = dist.rank
|
||||
|
||||
self.peft_cache_config = peft_cache_config
|
||||
|
||||
@ -213,6 +214,7 @@ class PyExecutor:
|
||||
self.response_lock = threading.Lock()
|
||||
self.response_cv = threading.Condition(self.response_lock)
|
||||
self.responses = {}
|
||||
self.result_wait_queues = {}
|
||||
|
||||
# kv cache events
|
||||
self.kv_cache_manager = self.resource_manager.resource_managers.get(
|
||||
@ -232,6 +234,7 @@ class PyExecutor:
|
||||
self.num_scheduled_requests: int = 0
|
||||
self.benchmark_req_queues_size = int(
|
||||
os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))
|
||||
self._disable_mpi = mpi_disabled()
|
||||
|
||||
# list of requests in each PP micro batch
|
||||
self.num_micro_batches = self.dist.pp_size
|
||||
@ -391,11 +394,19 @@ class PyExecutor:
|
||||
def __exit__(self):
|
||||
self.shutdown()
|
||||
|
||||
def enqueue_requests(self, requests: List[ExecutorRequest]) -> List[int]:
|
||||
def enqueue_requests(
|
||||
self,
|
||||
requests: List[ExecutorRequest],
|
||||
result_wait_queue: "Optional[ray.actor.ActorHandle]" = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Enqueue new requests
|
||||
"""
|
||||
req_ids = self.executor_request_queue.enqueue_requests(requests)
|
||||
if result_wait_queue is not None:
|
||||
with self.response_cv:
|
||||
for req_id in req_ids:
|
||||
self.result_wait_queues[req_id] = result_wait_queue
|
||||
return req_ids
|
||||
|
||||
def await_responses(
|
||||
@ -420,6 +431,7 @@ class PyExecutor:
|
||||
for req_id in id:
|
||||
responses.append(
|
||||
self._await_single_response(id=req_id, timeout=timeout))
|
||||
|
||||
return responses
|
||||
|
||||
def cancel_request(self, id: int):
|
||||
@ -477,14 +489,18 @@ class PyExecutor:
|
||||
def wait_shutdown(self):
|
||||
self.shutdown_event.wait()
|
||||
|
||||
def enqueue_request(self,
|
||||
request: ExecutorRequest,
|
||||
query: Optional[List] = None) -> int:
|
||||
def enqueue_request(
|
||||
self,
|
||||
request: ExecutorRequest,
|
||||
query: Optional[List] = None,
|
||||
result_wait_queue: "Optional[ray.actor.ActorHandle]" = None) -> int:
|
||||
"""
|
||||
Enqueue a new request, query is only used in `StarAttention`.
|
||||
"""
|
||||
req_id = self.executor_request_queue.enqueue_request(request, query)
|
||||
|
||||
if result_wait_queue is not None:
|
||||
with self.response_cv:
|
||||
self.result_wait_queues[req_id] = result_wait_queue
|
||||
return req_id
|
||||
|
||||
def set_gather_responses(self, gather_all_responses):
|
||||
@ -900,10 +916,12 @@ class PyExecutor:
|
||||
if previous_batch is not None:
|
||||
sample_state = previous_batch.sample_state
|
||||
if not self.dist.is_last_pp_rank:
|
||||
recv_object_funct = self.dist.recv_object_from_isend if self._disable_mpi \
|
||||
else self.dist.recv_object
|
||||
torch.cuda.nvtx.range_push(
|
||||
"_handle_new_tokens_inter_pp")
|
||||
# Receive tokens from previous pp rank (w.r.t model forward direction)
|
||||
sample_state.host = self.dist.recv_object(
|
||||
sample_state.host = recv_object_funct(
|
||||
src=self.dist.prev_pp_rank,
|
||||
tag=prev_microbatch_id,
|
||||
)
|
||||
@ -1849,7 +1867,6 @@ class PyExecutor:
|
||||
for request in failed_requests:
|
||||
req_id = request.py_request_id
|
||||
request.state = LlmRequestState.GENERATION_COMPLETE
|
||||
self._terminate_request(request)
|
||||
error_responses[req_id] = LlmResponse(
|
||||
request_id=req_id,
|
||||
error_msg=error_msg,
|
||||
@ -1861,7 +1878,9 @@ class PyExecutor:
|
||||
request for request in self.active_requests
|
||||
if request not in requests
|
||||
]
|
||||
self._enqueue_responses(error_responses.items())
|
||||
self._enqueue_responses(list(error_responses.items()))
|
||||
for request in failed_requests:
|
||||
self._terminate_request(request)
|
||||
|
||||
def _terminate_request(self, request: LlmRequest):
|
||||
if self._disagg_pp_termination_handler is not None:
|
||||
@ -1885,6 +1904,8 @@ class PyExecutor:
|
||||
self.resource_manager.free_resources(request)
|
||||
else:
|
||||
self.resource_manager.free_resources(request)
|
||||
if self.gather_all_responses or self.dist.rank == 0:
|
||||
self.result_wait_queues.pop(request.py_request_id, None)
|
||||
|
||||
def _is_request_in_transmission(self, request) -> bool:
|
||||
"""Check if a request is currently in transmission state."""
|
||||
@ -1960,6 +1981,13 @@ class PyExecutor:
|
||||
self.responses[req_id].append(resp)
|
||||
else:
|
||||
self.responses.update({req_id: [resp]})
|
||||
# (TODO: joyang) There are other types of responses, we need to sort out.
|
||||
if type(
|
||||
resp
|
||||
) == LlmResponse and req_id in self.result_wait_queues and self.result_wait_queues[
|
||||
req_id] is not None:
|
||||
self.result_wait_queues[req_id].put_response.remote(
|
||||
resp.client_id, resp)
|
||||
self.response_cv.notify_all()
|
||||
|
||||
@nvtx_range("_handle_first_token_response")
|
||||
@ -2030,9 +2058,10 @@ class PyExecutor:
|
||||
|
||||
self.active_requests.clear()
|
||||
self.active_requests.extend(new_active_requests)
|
||||
# Request should be terminated after enqueueing response to ensure we can enqueue response successfully.
|
||||
self._enqueue_responses(new_responses)
|
||||
for request in requests_to_terminate:
|
||||
self._terminate_request(request)
|
||||
self._enqueue_responses(new_responses)
|
||||
return requests_to_terminate
|
||||
|
||||
@nvtx_range("_terminate_ctx_finished_requests")
|
||||
|
||||
@ -12,7 +12,7 @@ import torch
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
|
||||
from tensorrt_llm._utils import get_sm_version
|
||||
from tensorrt_llm._utils import get_sm_version, mpi_disabled
|
||||
from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
|
||||
ContextChunkingPolicy,
|
||||
GuidedDecodingConfig)
|
||||
@ -28,7 +28,7 @@ from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.quantization import QuantAlgo
|
||||
|
||||
from ..attention_backend.interface import AttentionRuntimeFeatures
|
||||
from ..distributed import MPIDist
|
||||
from ..distributed import MPIDist, TorchDist
|
||||
from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
|
||||
get_spec_resource_manager)
|
||||
from ..utils import _get_allow_chain_drafter
|
||||
@ -293,7 +293,10 @@ def create_py_executor(
|
||||
pytorch_backend_config.disable_overlap_scheduler = True
|
||||
|
||||
mapping = _get_mapping(llm_args.parallel_config.to_mapping())
|
||||
dist = MPIDist(mapping=mapping)
|
||||
if mpi_disabled():
|
||||
dist = TorchDist(mapping=mapping)
|
||||
else:
|
||||
dist = MPIDist(mapping=mapping)
|
||||
|
||||
cache_transceiver_config = None
|
||||
if llm_args.cache_transceiver_config is not None:
|
||||
|
||||
@ -9,6 +9,7 @@ import torch
|
||||
|
||||
import tensorrt_llm
|
||||
import tensorrt_llm.bindings
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
|
||||
from tensorrt_llm.lora_helper import LoraConfig
|
||||
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
|
||||
@ -610,7 +611,12 @@ class KVCacheManager(BaseResourceManager):
|
||||
|
||||
if mapping.world_size > 1:
|
||||
# make sure all ranks use same value for maxTokens
|
||||
max_tokens = mpi_comm().allreduce(max_tokens, op=MPI.MIN)
|
||||
if mpi_disabled():
|
||||
from tensorrt_llm._utils import torch_comm
|
||||
max_tokens = torch_comm().allreduce(
|
||||
max_tokens, op=torch.distributed.ReduceOp.MIN)
|
||||
else:
|
||||
max_tokens = mpi_comm().allreduce(max_tokens, op=MPI.MIN)
|
||||
|
||||
# get number of blocks
|
||||
blocks_in_primary_pool = math.ceil(max_tokens / tokens_per_block)
|
||||
|
||||
@ -13,7 +13,7 @@ import torch.nn.functional as F
|
||||
|
||||
from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
|
||||
MakeDecodingBatchInputOutput
|
||||
from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding
|
||||
from tensorrt_llm._utils import mpi_disabled, nvtx_range, torch_dtype_to_binding
|
||||
from tensorrt_llm.bindings import (CudaStream, DataType, ModelConfig,
|
||||
WorldConfig, make_sampling_config)
|
||||
from tensorrt_llm.bindings.executor import (DecodingConfig, DecodingMode,
|
||||
@ -1781,8 +1781,16 @@ class TRTLLMSampler(Sampler):
|
||||
2 if self.is_trt_overlap else 1)
|
||||
self.micro_batch_idx = 0
|
||||
|
||||
self.world_config = WorldConfig.mpi(mapping.gpus_per_node,
|
||||
mapping.tp_size, mapping.pp_size)
|
||||
if mpi_disabled():
|
||||
self.world_config = WorldConfig(mapping.tp_size,
|
||||
mapping.pp_size,
|
||||
mapping.cp_size,
|
||||
rank=mapping.rank,
|
||||
gpus_per_node=mapping.gpus_per_node)
|
||||
else:
|
||||
self.world_config = WorldConfig.mpi(mapping.gpus_per_node,
|
||||
mapping.tp_size,
|
||||
mapping.pp_size)
|
||||
self.model_config = ModelConfig(vocab_size, num_hidden_layers,
|
||||
num_hidden_layers, 0, num_heads,
|
||||
hidden_size, self.model_datatype)
|
||||
|
||||
@ -312,3 +312,11 @@ def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
|
||||
# It's here so that unit tests can mock it and turn it off.
|
||||
def _get_allow_chain_drafter() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_device_uuid(device_idx: int) -> str:
|
||||
"""Get the UUID of a CUDA device using torch cuda api"""
|
||||
|
||||
property = torch.cuda.get_device_properties(device_idx)
|
||||
uuid = "GPU-" + str(property.uuid)
|
||||
return uuid
|
||||
|
||||
@ -19,6 +19,7 @@ import json
|
||||
import linecache
|
||||
import math
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import tempfile
|
||||
import trace
|
||||
@ -468,6 +469,12 @@ def dim_resolve_negative(dim, ndim):
|
||||
return tuple(pos)
|
||||
|
||||
|
||||
def get_free_port():
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here
OMPI_COMM_TYPE_HOST = 9

@ -490,11 +497,44 @@ def local_mpi_comm():
    return local_comm


# Global TorchDist instance for Ray orchestrator
_torch_comm = None


def set_torch_comm(torch_comm_instance):
    """Set global TorchDist instance"""
    global _torch_comm
    _torch_comm = torch_comm_instance


def torch_comm():
    """Get global TorchDist instance"""
    if _torch_comm is None:
        raise RuntimeError(
            "TorchDist not initialized. Call set_torch_comm() first.")
    return _torch_comm


def mpi_disabled() -> bool:
    """True if TLLM_DISABLE_MPI is set to "1", False otherwise."""
    return os.environ.get("TLLM_DISABLE_MPI") == "1"


def mpi_rank():
    if mpi_disabled():
        try:
            return torch.distributed.get_rank()
        except ValueError:
            # Fallback: return 0 when MPI is absent (Ray / Slurm PMIx)
            return 0
    return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0


def global_mpi_rank():
    if mpi_disabled():
        # Fallback: return 0 when MPI is absent (Ray / Slurm PMIx)
        return 0

    return MPI.COMM_WORLD.Get_rank() if ENABLE_MULTI_DEVICE else 0


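Illustrative only, not part of the diff: a small sketch of how the TLLM_DISABLE_MPI switch changes the helpers above on the Ray path; the printed values are simply what the fallbacks return.

import os

os.environ["TLLM_DISABLE_MPI"] = "1"  # Ray orchestrator path

from tensorrt_llm._utils import global_mpi_rank, mpi_disabled, mpi_rank

assert mpi_disabled()
print(mpi_rank())         # torch.distributed rank, or 0 if no process group yet
print(global_mpi_rank())  # always 0 when MPI is disabled
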
@ -702,6 +742,9 @@ def is_trace_enabled(env_var: str):
|
||||
value = os.environ.get(env_var, "-1")
|
||||
if value == "ALL":
|
||||
return True
|
||||
if value == "-1":
|
||||
# early return w/o calling global_mpi_rank() for Ray path
|
||||
return False
|
||||
try:
|
||||
return int(value) == global_mpi_rank()
|
||||
except ValueError:
|
||||
@ -1150,3 +1193,13 @@ def set_prometheus_multiproc_dir() -> object:
|
||||
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
|
||||
logger.info(
|
||||
f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
|
||||
|
||||
|
||||
TORCH_PYBIND11_ABI = None
|
||||
|
||||
|
||||
def torch_pybind11_abi() -> str:
|
||||
global TORCH_PYBIND11_ABI
|
||||
if TORCH_PYBIND11_ABI is None:
|
||||
TORCH_PYBIND11_ABI = f"{torch._C._PYBIND11_COMPILER_TYPE}{torch._C._PYBIND11_STDLIB}{torch._C._PYBIND11_BUILD_ABI}"
|
||||
return TORCH_PYBIND11_ABI
|
||||
|
||||
@ -100,6 +100,15 @@ class BaseWorker(GenerationExecutor):
|
||||
if global_mpi_size() > 1:
|
||||
logger.set_rank(self.global_rank)
|
||||
|
||||
def _get_comm_ranks_device_id(self):
|
||||
device_id = self.global_rank % torch.cuda.device_count()
|
||||
torch.cuda.set_device(device_id)
|
||||
# Make sure C++ executor would use same devices/ranks as py_executor
|
||||
global_rank = global_mpi_rank()
|
||||
comm_ranks = mpi_comm().allgather(global_rank)
|
||||
device_ids = mpi_comm().allgather(device_id)
|
||||
return comm_ranks, device_ids
|
||||
|
||||
def setup_engine(self):
|
||||
"""
|
||||
Setup the engine for the worker.
|
||||
@ -108,21 +117,12 @@ class BaseWorker(GenerationExecutor):
|
||||
if isinstance(self._engine, list):
|
||||
self._engine = self._engine[self.rank]
|
||||
|
||||
def _get_comm_ranks_device_id():
|
||||
device_id = self.global_rank % torch.cuda.device_count()
|
||||
torch.cuda.set_device(device_id)
|
||||
# Make sure C++ executor would use same devices/ranks as py_executor
|
||||
global_rank = global_mpi_rank()
|
||||
comm_ranks = mpi_comm().allgather(global_rank)
|
||||
device_ids = mpi_comm().allgather(device_id)
|
||||
return comm_ranks, device_ids
|
||||
|
||||
def _create_py_executor():
|
||||
args = {}
|
||||
assert hasattr(
|
||||
self.llm_args, "backend"
|
||||
), "llm_args should be with backend in _create_py_executor"
|
||||
_ = _get_comm_ranks_device_id()
|
||||
_ = self._get_comm_ranks_device_id()
|
||||
if self.llm_args.backend == "pytorch":
|
||||
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
|
||||
create_py_executor
|
||||
@ -168,7 +168,7 @@ class BaseWorker(GenerationExecutor):
|
||||
executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
|
||||
processor_batched=self._batched_logits_processor,
|
||||
replicate=False)
|
||||
comm_ranks, device_ids = _get_comm_ranks_device_id()
|
||||
comm_ranks, device_ids = self._get_comm_ranks_device_id()
|
||||
executor_config.parallel_config = tllm.ParallelConfig(
|
||||
participant_ids=comm_ranks, device_ids=device_ids)
|
||||
|
||||
@ -308,7 +308,9 @@ class BaseWorker(GenerationExecutor):
|
||||
model_config=self._runtime_model_config,
|
||||
uids=[str(prompt_adapter_request.adapter_id)])
|
||||
|
||||
def _enqueue_request(self, request: GenerationRequest) -> int:
|
||||
def _enqueue_request(self,
|
||||
request: GenerationRequest,
|
||||
result_wait_queue=None) -> int:
|
||||
assert request.id is not None
|
||||
py_lora_path = None
|
||||
if self._lora_manager is not None and request.lora_request is not None:
|
||||
@ -506,10 +508,20 @@ class BaseWorker(GenerationExecutor):
|
||||
if request.query_token_ids is not None:
|
||||
# pytorch star attention workflow
|
||||
# a workaround to avoid public interface update
|
||||
req_id = self.engine.enqueue_request(executor_request,
|
||||
request.query_token_ids)
|
||||
if self._is_pytorch_backend and result_wait_queue is not None:
|
||||
req_id = self.engine.enqueue_request(
|
||||
executor_request,
|
||||
request.query_token_ids,
|
||||
result_wait_queue=result_wait_queue)
|
||||
else:
|
||||
req_id = self.engine.enqueue_request(
|
||||
executor_request, request.query_token_ids)
|
||||
else:
|
||||
req_id = self.engine.enqueue_request(executor_request)
|
||||
if self._is_pytorch_backend and result_wait_queue is not None:
|
||||
req_id = self.engine.enqueue_request(
|
||||
executor_request, result_wait_queue=result_wait_queue)
|
||||
else:
|
||||
req_id = self.engine.enqueue_request(executor_request)
|
||||
return req_id
|
||||
except Exception as e:
|
||||
raise RequestError(str(e)) from e
|
||||
@ -557,7 +569,7 @@ class BaseWorker(GenerationExecutor):
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> bool:
|
||||
self.shutdown()
|
||||
return exc_type is None or exc_type == BaseWorker.WorkerExit
|
||||
return exc_type is None or exc_type == self.WorkerExit
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
@ -7,8 +7,8 @@ import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import (TYPE_CHECKING, AsyncIterable, Generator, List, Optional,
|
||||
Union)
|
||||
from typing import (TYPE_CHECKING, AsyncIterable, Dict, Generator, List,
|
||||
Optional, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -21,7 +21,7 @@ from .._utils import mpi_world_size
|
||||
from ..bindings import executor as tllm
|
||||
from ..builder import Engine
|
||||
from ..disaggregated_params import DisaggregatedParams
|
||||
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig
|
||||
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, TorchLlmArgs
|
||||
from ..llmapi.llm_utils import KvCacheRetentionConfig
|
||||
from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
|
||||
need_spawn_mpi_workers)
|
||||
@ -103,6 +103,9 @@ class GenerationExecutor(ABC):
|
||||
self._iter_kv_events_result: IterationResult | None = None
|
||||
self._iter_stats_result: IterationResult | None = None
|
||||
|
||||
def use_ray_queue(self) -> bool:
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def submit(self, request: GenerationRequest) -> GenerationResult:
|
||||
pass
|
||||
@ -355,6 +358,23 @@ class GenerationExecutor(ABC):
|
||||
self._iter_kv_events_result.set_timeout(timeout)
|
||||
return self._iter_kv_events_result
|
||||
|
||||
@staticmethod
|
||||
def _create_ray_executor(
|
||||
worker_kwargs: Dict,
|
||||
model_world_size: int,
|
||||
postproc_worker_config: PostprocWorkerConfig,
|
||||
is_llm_executor: bool,
|
||||
tp_size: int,
|
||||
kv_connector_config: Optional[KvCacheConnectorConfig] = None):
|
||||
from .ray_executor import RayExecutor
|
||||
|
||||
return RayExecutor(worker_kwargs,
|
||||
model_world_size=model_world_size,
|
||||
postproc_worker_config=postproc_worker_config,
|
||||
is_llm_executor=is_llm_executor,
|
||||
tp_size=tp_size,
|
||||
kv_connector_config=kv_connector_config)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
engine: Union[Path, Engine],
|
||||
@ -372,6 +392,7 @@ class GenerationExecutor(ABC):
|
||||
hf_model_dir: Optional[Path] = None,
|
||||
tokenizer: Optional[TokenizerBase] = None,
|
||||
llm_args: Optional[BaseLlmArgs] = None,
|
||||
**args,
|
||||
) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
|
||||
# local imports to avoid cyclic importing
|
||||
from .proxy import GenerationExecutorProxy
|
||||
@ -406,6 +427,20 @@ class GenerationExecutor(ABC):
|
||||
if lora_config:
|
||||
worker_kwargs["lora_config"] = lora_config
|
||||
|
||||
        orchestrator_type = None if not isinstance(
            llm_args, TorchLlmArgs) else llm_args.orchestrator_type
        if orchestrator_type == "ray":
            return GenerationExecutor._create_ray_executor(
                worker_kwargs,
                model_world_size,
                postproc_worker_config,
                is_llm_executor=is_llm_executor,
                tp_size=args.get("tp_size", 1),
                kv_connector_config=kv_connector_config)
        elif orchestrator_type is not None:
            raise ValueError(
                f"Unsupported orchestrator_type: {orchestrator_type}")

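Illustrative only, not part of the diff: a sketch of selecting the Ray orchestrator from the LLM API so that create() takes the _create_ray_executor branch above. The model path is a placeholder and tensor_parallel_size is shown only as a typical argument.

from tensorrt_llm import LLM

# Dispatches executor creation to RayExecutor (orchestrator_type == "ray").
llm = LLM(model="/path/to/model",  # placeholder checkpoint directory
          orchestrator_type="ray",
          tensor_parallel_size=2)
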
# The case where the Python main process is launched by mpirun
|
||||
mpirun_launch = external_mpi_comm_available(model_world_size)
|
||||
# The case where the Python main process utilizes mpi4py to spawn MPI workers
|
||||
|
||||
305
tensorrt_llm/executor/ray_executor.py
Normal file
@ -0,0 +1,305 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
import ray
|
||||
except ModuleNotFoundError as e:
|
||||
e.msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator"""
|
||||
raise
|
||||
|
||||
from ray.util.placement_group import (PlacementGroup,
|
||||
PlacementGroupSchedulingStrategy,
|
||||
get_current_placement_group,
|
||||
placement_group)
|
||||
|
||||
from tensorrt_llm._utils import get_free_port
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
from .._utils import nvtx_range_debug
|
||||
from ..llmapi.llm_args import KvCacheConnectorConfig
|
||||
from .executor import GenerationExecutor
|
||||
from .postproc_worker import PostprocWorkerConfig
|
||||
from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
|
||||
from .request import GenerationRequest
|
||||
from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
|
||||
|
||||
__all__ = [
|
||||
"RayExecutor",
|
||||
]
|
||||
|
||||
|
||||
class RayExecutor(GenerationExecutor):
|
||||
|
||||
def __init__(self,
|
||||
worker_kwargs: Dict,
|
||||
model_world_size: int,
|
||||
postproc_worker_config: PostprocWorkerConfig,
|
||||
is_llm_executor: bool,
|
||||
tp_size=1,
|
||||
kv_connector_config: Optional[KvCacheConnectorConfig] = None):
|
||||
os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'
|
||||
os.environ["RAY_DEDUP_LOGS"] = "0" # for debug
|
||||
|
||||
super().__init__(model_world_size, postproc_worker_config,
|
||||
is_llm_executor)
|
||||
|
||||
self.has_start_local_cluser = False
|
||||
runtime_env = {
|
||||
"env_vars": {
|
||||
"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"
|
||||
}
|
||||
}
|
||||
|
||||
ray_init_args = {
|
||||
"include_dashboard": False,
|
||||
"namespace": "trtllm",
|
||||
"ignore_reinit_error": True,
|
||||
"runtime_env": runtime_env
|
||||
}
|
||||
|
||||
if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1":
|
||||
try:
|
||||
ray.init(address="auto", **ray_init_args)
|
||||
logger.info(f"Attached to an existing Ray cluster.")
|
||||
except ConnectionError:
|
||||
logger.info(f"Ray cluster not found, starting a new one.")
|
||||
|
||||
if not ray.is_initialized():
|
||||
ray.init(**ray_init_args)
|
||||
self.has_start_local_cluser = True
|
||||
else:
|
||||
ray.init(address="local", **ray_init_args)
|
||||
self.has_start_local_cluser = True
|
||||
|
||||
self.world_size = model_world_size
|
||||
self.tp_size = tp_size
|
||||
self.master_address = ray.util.get_node_ip_address()
|
||||
self.master_port = get_free_port()
|
||||
|
||||
self.response_queue = RayAsyncQueue.options(runtime_env={
|
||||
"env_vars": {
|
||||
"TLLM_DISABLE_MPI": "1"
|
||||
}
|
||||
}).remote()
|
||||
self.response_sync_queue = RaySyncQueue.options(runtime_env={
|
||||
"env_vars": {
|
||||
"TLLM_DISABLE_MPI": "1"
|
||||
}
|
||||
}).remote()
|
||||
self.async_response_queue_weakref = self.create_actor_weak_ref(
|
||||
self.response_queue)
|
||||
self.sync_response_queue_weakref = self.create_actor_weak_ref(
|
||||
self.response_sync_queue)
|
||||
self.response_queue.warmup.remote()
|
||||
self.response_sync_queue.warmup.remote()
|
||||
|
||||
worker_kwargs = dict(**worker_kwargs,
|
||||
postproc_worker_config=postproc_worker_config,
|
||||
is_llm_executor=is_llm_executor,
|
||||
kv_connector_config=kv_connector_config)
|
||||
|
||||
self.create_workers(RayGPUWorker, worker_kwargs)
|
||||
|
||||
@staticmethod
|
||||
def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
|
||||
state, _, _ = actor_handle._serialization_helper()
|
||||
return ray.actor.ActorHandle._deserialization_helper(state,
|
||||
weak_ref=True)
|
||||
|
||||
def use_ray_queue(self) -> bool:
|
||||
return True
|
||||
|
||||
def create_workers(self, worker_cls, worker_kwargs):
|
||||
# When set to be a fraction, it allows Ray to schedule
|
||||
# multiple actors on a single GPU for colocate use cases.
|
||||
num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))
|
||||
logger.debug(f"{num_gpus=} for each worker.")
|
||||
|
||||
runtime_env = ray.runtime_env.RuntimeEnv()
|
||||
runtime_env["env_vars"] = os.environ.copy()
|
||||
runtime_env["env_vars"].update({
|
||||
"TLLM_DISABLE_MPI": "1",
|
||||
"MASTER_ADDR": self.master_address, # head-IP for NCCL/Gloo
|
||||
"MASTER_PORT": str(self.master_port)
|
||||
})
|
||||
|
||||
self.placement_group, self.bundle_indices = self._get_placement_group(
|
||||
tp_size=self.tp_size)
|
||||
|
||||
self.workers = [
|
||||
RayWorkerWrapper.options(
|
||||
num_gpus=num_gpus,
|
||||
runtime_env=runtime_env, # per-actor env
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=self.placement_group,
|
||||
placement_group_bundle_index=self.bundle_indices[rank],
|
||||
)).remote(worker_cls, worker_kwargs, self.world_size, rank)
|
||||
for rank in range(self.world_size)
|
||||
]
|
||||
|
||||
ray.get([worker.__ray_ready__.remote() for worker in self.workers])
|
||||
|
||||
def call_all_ray_workers(self, func: str, leader_only: bool,
|
||||
async_call: bool, *args, **kwargs):
|
||||
workers = (self.workers[0], ) if leader_only else self.workers
|
||||
|
||||
if async_call:
|
||||
return [
|
||||
getattr(worker, func).remote(*args, **kwargs)
|
||||
for worker in workers
|
||||
]
|
||||
else:
|
||||
return ray.get([
|
||||
getattr(worker, func).remote(*args, **kwargs)
|
||||
for worker in workers
|
||||
])
|
||||
|
||||
def collective_rpc(self,
|
||||
method: str,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None,
|
||||
non_block: bool = False,
|
||||
unique_reply_rank: Optional[int] = None) -> list[Any]:
|
||||
workers = (self.workers[unique_reply_rank],
|
||||
) if unique_reply_rank is not None else self.workers
|
||||
kwargs = kwargs or {}
|
||||
|
||||
refs = []
|
||||
for w in workers:
|
||||
try:
|
||||
refs.append(getattr(w, method).remote(*args, **kwargs))
|
||||
except AttributeError:
|
||||
# Here worker is the RayWorkerWrapper.
|
||||
# For extended worker methods, we need to use call_worker_method since
|
||||
# Ray actor doesn't work with __getattr__ delegation.
|
||||
refs.append(w.call_worker_method.remote(method, *args,
|
||||
**kwargs))
|
||||
|
||||
return refs if non_block else ray.get(refs)
|
||||
|
||||
def submit(self, request: GenerationRequest) -> GenerationResult:
|
||||
"""
|
||||
Low-level API to the executor. Return a "future" GenerationResult
|
||||
which can be waited.
|
||||
Forwards the request to the workers through the request queue.
|
||||
"""
|
||||
request.set_id(self._get_next_client_id())
|
||||
logprob_params = self._get_logprob_params(request)
|
||||
|
||||
result = GenerationResult(
|
||||
request,
|
||||
background_error_handler=self._handle_background_error,
|
||||
executor=self,
|
||||
disaggregated_params=request.disaggregated_params,
|
||||
logprob_params=logprob_params)
|
||||
|
||||
with nvtx_range_debug("request_queue.put"):
|
||||
self.call_all_ray_workers("enqueue_request",
|
||||
leader_only=True,
|
||||
request=request,
|
||||
async_call=False,
|
||||
result_wait_queue=result.queue)
|
||||
|
||||
return result
|
||||
|
||||
def report_device_ids(self) -> list[str]:
|
||||
gpu_ids = self.call_all_ray_workers("report_device_id",
|
||||
leader_only=False,
|
||||
async_call=False)
|
||||
return sorted(gpu_ids)
|
||||
|
||||
def abort_request(self, request_id: int) -> None:
|
||||
self.call_all_ray_workers("abort_request",
|
||||
leader_only=True,
|
||||
async_call=False,
|
||||
request_id=request_id)
|
||||
|
||||
def shutdown(self):
|
||||
# Release actors
|
||||
self.response_queue = None
|
||||
self.response_sync_queue = None
|
||||
self.async_response_queue_weakref = None
|
||||
self.sync_response_queue_weakref = None
|
||||
|
||||
self.workers = None
|
||||
if hasattr(self,
|
||||
"placement_group") and self.placement_group is not None:
|
||||
ray.util.remove_placement_group(self.placement_group)
|
||||
self.placement_group = None
|
||||
self.bundle_indices = None
|
||||
|
||||
if self.has_start_local_cluser:
|
||||
logger.debug("Shutting down Ray cluster")
|
||||
ray.shutdown()
|
||||
|
||||
@property
|
||||
def enable_postprocess_parallel(self) -> bool:
|
||||
ret = super().enable_postprocess_parallel
|
||||
assert ret == False, "Postprocess parallel is not supported in RayExecutor"
|
||||
return ret
|
||||
|
||||
def _get_placement_group(self,
|
||||
tp_size: int) -> Tuple[PlacementGroup, List[int]]:
|
||||
"""
|
||||
Either use the existing placement group from driver script (e.g., in the case of RL FW integration),
|
||||
or create a default PACK placement group where each bundle has tp_size GPUs.
|
||||
- When tp_size ≤ GPUs per node, keep one TP group per node.
|
||||
- When tp_size > GPUs per node, allow a TP group span nodes.
|
||||
- rank 0 must be put on the driver node
|
||||
"""
|
||||
bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None)
|
||||
|
||||
if bundle_indices:
|
||||
pg = get_current_placement_group()
|
||||
if pg is not None:
|
||||
bundle_indices = list(map(int, bundle_indices.split(",")))
|
||||
assert len(bundle_indices) == self.world_size, (
|
||||
f"Need {self.world_size} bundle indices for world_size, got {bundle_indices=}"
|
||||
)
|
||||
assert len(set(bundle_indices)) == len(bundle_indices), \
|
||||
f"TRTLLM_RAY_BUNDLE_INDICES cannot have duplicate values, but got {bundle_indices=}."
|
||||
|
||||
assert max(bundle_indices) < len(pg.bundle_specs), \
|
||||
f"{bundle_indices=} out of range for PG with {len(pg.bundle_specs)} bundles"
|
||||
|
||||
logger.info(
|
||||
f"Found existing placement group {pg.bundle_specs=}. {bundle_indices=}"
|
||||
)
|
||||
|
||||
# TODO: need to pin TP group onto the same node for RL FW integration case
|
||||
|
||||
return pg, bundle_indices
|
||||
else:
|
||||
logger.warning(
|
||||
f"Ignoring TRTLLM_RAY_BUNDLE_INDICES={bundle_indices} because no global placement group is found."
|
||||
)
|
||||
|
||||
if self.world_size % tp_size:
|
||||
raise ValueError("world_size must be a multiple of tp_size")
|
||||
|
||||
head_tag = f"node:{self.master_address}"
|
||||
nodes = ray.nodes()
|
||||
gpus_per_node = int(nodes[0]["Resources"].get(
|
||||
"GPU", 0)) # assume symmetric across nodes
|
||||
|
||||
bundle_cpu = bundle_gpu = min(tp_size, gpus_per_node)
|
||||
|
||||
bundles, bundle_indices = [], []
|
||||
current = 0
|
||||
for rank in range(self.world_size):
|
||||
if current == 0:
|
||||
bundle = {"GPU": bundle_gpu, "CPU": bundle_cpu}
|
||||
if len(bundles) == 0:
|
||||
bundle[head_tag] = 0.01 # to force placement on head node
|
||||
bundles.append(bundle)
|
||||
|
||||
bundle_indices.append(len(bundles) - 1)
|
||||
current = (current + 1) % bundle_gpu
|
||||
|
||||
strategy = "PACK"
|
||||
logger.debug(
|
||||
f"[Strategy={strategy}] Bundles: {bundles} for tp_size: {tp_size} and world_size: {self.world_size}"
|
||||
)
|
||||
pg = placement_group(bundles, strategy=strategy)
|
||||
|
||||
return pg, bundle_indices
|
||||
202
tensorrt_llm/executor/ray_gpu_worker.py
Normal file
@ -0,0 +1,202 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Optional, Union
|
||||
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from ..bindings import executor as tllm
|
||||
from ..builder import Engine
|
||||
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig
|
||||
from ..llmapi.tokenizer import TokenizerBase
|
||||
from ..lora_helper import LoraConfig
|
||||
from ..sampling_params import BatchedLogitsProcessor
|
||||
from .base_worker import BaseWorker
|
||||
from .postproc_worker import PostprocWorkerConfig
|
||||
from .request import GenerationRequest
|
||||
from .result import GenerationResult
|
||||
|
||||
__all__ = [
|
||||
"RayGPUWorker",
|
||||
"RayWorkerWrapper",
|
||||
]
|
||||
|
||||
|
||||
@ray.remote
|
||||
class RayWorkerWrapper:
|
||||
|
||||
def __init__(self, worker_cls, worker_kwargs, world_size, rank):
|
||||
self.master_address = os.environ["MASTER_ADDR"]
|
||||
self.master_port = os.environ["MASTER_PORT"]
|
||||
|
||||
# Ray can't pickle TensorRT logger
|
||||
global logger
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
# Expect to see global counts w/ RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1,
|
||||
# unless CUDA_VISIBLE_DEVICES is set.
|
||||
logger.debug(
|
||||
f"CUDA device count visible to Ray: {torch.cuda.device_count()}")
|
||||
|
||||
# Physical gpu id
|
||||
self.gpu = int(ray.get_gpu_ids()[0])
|
||||
local_gpu = self.physical_to_local_id(self.gpu)
|
||||
|
||||
torch.distributed.init_process_group(
|
||||
backend="cuda:nccl,cpu:gloo",
|
||||
init_method=f"tcp://{self.master_address}:{self.master_port}",
|
||||
world_size=world_size,
|
||||
rank=rank)
|
||||
|
||||
logger.info(
|
||||
f"[Rank {rank}] Finished PG init. Global GPU ID: {self.gpu}, local GPU ID: {local_gpu}"
|
||||
)
|
||||
|
||||
torch.cuda.set_device(local_gpu)
|
||||
|
||||
self.worker = worker_cls(device_id=local_gpu, **worker_kwargs)
|
||||
|
||||
def submit(self, request: GenerationRequest) -> GenerationResult:
|
||||
return self.worker.submit(request)
|
||||
|
||||
def enqueue_request(self,
|
||||
request: GenerationRequest,
|
||||
result_wait_queue: Queue | None = None) -> int:
|
||||
return self.worker.enqueue_request(request, result_wait_queue)
|
||||
|
||||
def abort_request(self, request_id: int) -> None:
|
||||
self.worker.abort_request(request_id)
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
from tensorrt_llm._torch.utils import get_device_uuid
|
||||
local_id = self.physical_to_local_id(self.gpu)
|
||||
return get_device_uuid(local_id)
|
||||
|
||||
@staticmethod
|
||||
def physical_to_local_id(phys_id: int) -> int:
|
||||
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
|
||||
if not visible_devices:
|
||||
return phys_id
|
||||
id_mapping = list(map(int, visible_devices.split(",")))
|
||||
return id_mapping.index(phys_id)
|
||||
|
||||
def call_worker_method(self, method_name: str, *args, **kwargs):
|
||||
"""Generic method to call any method on the underlying worker."""
|
||||
if hasattr(self.worker, method_name):
|
||||
method = getattr(self.worker, method_name)
|
||||
if callable(method):
|
||||
return method(*args, **kwargs)
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"'{method_name}' is not callable on the underlying worker")
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"Underlying worker has no method '{method_name}'")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Customizes the actor's prefix in the Ray logs.
|
||||
|
||||
This makes it easier to identify which worker is producing specific log messages.
|
||||
Refer to https://github.com/NVIDIA-NeMo/RL/blob/faad02113c3c502437ccb339cb848796334aedd9/nemo_rl/models/policy/dtensor_policy_worker_v2.py#L95
|
||||
"""
|
||||
if torch.distributed.is_initialized():
|
||||
return f"{self.__class__.__qualname__}[rank={torch.distributed.get_rank()}]"
|
||||
else:
|
||||
return f"{self.__class__.__qualname__}"
|
||||
|
||||
|
||||
class RayGPUWorker(BaseWorker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_id: int,
|
||||
engine: Union[Path, Engine],
|
||||
executor_config: Optional[tllm.ExecutorConfig] = None,
|
||||
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
|
||||
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
|
||||
is_llm_executor: Optional[bool] = None,
|
||||
lora_config: Optional[LoraConfig] = None,
|
||||
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
|
||||
hf_model_dir: Optional[Path] = None,
|
||||
tokenizer: Optional[TokenizerBase] = None,
|
||||
llm_args: Optional[BaseLlmArgs] = None,
|
||||
) -> None:
|
||||
global logger
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
super().__init__(
|
||||
engine=engine,
|
||||
executor_config=executor_config,
|
||||
batched_logits_processor=batched_logits_processor,
|
||||
postproc_worker_config=postproc_worker_config,
|
||||
is_llm_executor=is_llm_executor,
|
||||
lora_config=lora_config,
|
||||
kv_connector_config=kv_connector_config,
|
||||
hf_model_dir=hf_model_dir,
|
||||
tokenizer=tokenizer,
|
||||
llm_args=llm_args,
|
||||
)
|
||||
|
||||
if not self._is_pytorch_backend:
|
||||
raise ValueError(f"Ray GPU worker only supports PyTorch backend.")
|
||||
|
||||
self.device_id = device_id
|
||||
|
||||
# Override rank attributes using torch
|
||||
self.global_rank = torch.distributed.get_rank()
|
||||
if self.global_rank > 1:
|
||||
logger.set_rank(self.global_rank)
|
||||
|
||||
self.setup_engine()
|
||||
|
||||
def _get_comm_ranks_device_id(self):
|
||||
# Make sure C++ executor would use same devices/ranks as py_executor
|
||||
global_rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
comm_ranks = [None] * world_size
|
||||
device_ids = [None] * world_size
|
||||
|
||||
torch.distributed.all_gather_object(comm_ranks, global_rank)
|
||||
torch.distributed.all_gather_object(device_ids, self.device_id)
|
||||
return comm_ranks, device_ids
|
||||
|
||||
def enqueue_request(self,
|
||||
request: GenerationRequest,
|
||||
result_wait_queue: Queue | None = None) -> int:
|
||||
return self._enqueue_request(request, result_wait_queue)
|
||||
|
||||
def submit(self, request: GenerationRequest):
|
||||
raise NotImplementedError(
|
||||
"Ray GPU worker does not support submit() yet.")
|
||||
|
||||
def shutdown(self):
|
||||
|
||||
if self.doing_shutdown:
|
||||
return
|
||||
else:
|
||||
self.doing_shutdown = True
|
||||
|
||||
logger.debug(f'Worker {self.rank} shutting down...')
|
||||
|
||||
if self.engine is not None:
|
||||
self.engine.shutdown()
|
||||
self.engine = None
|
||||
|
||||
assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
|
||||
if (self.llm_args.backend == "pytorch"
|
||||
and hasattr(self, "checkpoint_loader")
|
||||
and self.checkpoint_loader is not None):
|
||||
self.checkpoint_loader.cleanup()
|
||||
self.checkpoint_loader = None
|
||||
|
||||
# Check if there are any errors from the threads before shutdown.
|
||||
self._handle_background_error()
|
||||
|
||||
logger.debug(f"Worker {self.rank} shutdown done.")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
import weakref
|
||||
from dataclasses import dataclass, field
|
||||
from queue import Empty, Queue
|
||||
@ -10,7 +11,12 @@ from weakref import WeakMethod
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .._utils import nvtx_range_debug
|
||||
try:
|
||||
import ray
|
||||
except ModuleNotFoundError:
|
||||
from tensorrt_llm import ray_stub as ray
|
||||
|
||||
from .._utils import mpi_disabled, nvtx_range_debug
|
||||
from ..bindings import executor as tllm
|
||||
from ..disaggregated_params import DisaggregatedParams
|
||||
from ..llmapi.tracer import global_tracer
|
||||
@ -146,12 +152,104 @@ class CompletionOutput:
|
||||
return self.logprobs[self._last_logprobs_len:]
|
||||
|
||||
|
||||
def warmup_tensorrt_llm():
|
||||
import tensorrt_llm
|
||||
print("Warmup by importing tensorrt_llm with version",
|
||||
tensorrt_llm.version.__version__)
|
||||
|
||||
|
||||
@ray.remote(max_concurrency=1000000, num_cpus=2)
|
||||
class RayAsyncQueue:
|
||||
"""Ray actor for async response handling."""
|
||||
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
self.event_map = {}
|
||||
self.warmup_done = False
|
||||
|
||||
def register(self, key: int):
|
||||
assert key not in self.event_map, f"Key {key} already registered"
|
||||
self.event_map[key] = asyncio.Event()
|
||||
|
||||
def unregister(self, key: int):
|
||||
if key in self.event_map:
|
||||
del self.event_map[key]
|
||||
|
||||
if key in self.data:
|
||||
del self.data[key]
|
||||
|
||||
def warmup(self):
|
||||
if self.warmup_done:
|
||||
return
|
||||
warmup_tensorrt_llm()
|
||||
self.warmup_done = True
|
||||
|
||||
def put_response(self, key: int, item: Any):
|
||||
assert key in self.event_map, f"Key {key} not registered"
|
||||
self.data[key] = item
|
||||
self.event_map[key].set()
|
||||
|
||||
async def get_async(self, key: int):
|
||||
assert key in self.event_map, f"Key {key} not registered"
|
||||
await self.event_map[key].wait()
|
||||
self.event_map[key].clear()
|
||||
ret = self.data[key]
|
||||
del self.data[key]
|
||||
return ret
|
||||
|
||||
|
||||
SYNC_QUEUE_MAX_CONCURRENCY = 2
|
||||
|
||||
|
||||
@ray.remote(max_concurrency=SYNC_QUEUE_MAX_CONCURRENCY,
|
||||
num_cpus=SYNC_QUEUE_MAX_CONCURRENCY)
|
||||
class RaySyncQueue:
|
||||
"""Ray actor for sync response handling."""
|
||||
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
self.event_map = {}
|
||||
self.semaphore = threading.Semaphore(SYNC_QUEUE_MAX_CONCURRENCY - 1)
|
||||
self.warmup_done = False
|
||||
|
||||
def register(self, key: int):
|
||||
assert key not in self.event_map, f"Key {key} already registered"
|
||||
self.event_map[key] = threading.Event()
|
||||
self.event_map[key]
|
||||
|
||||
def unregister(self, key: int):
|
||||
if key in self.event_map:
|
||||
del self.event_map[key]
|
||||
|
||||
if key in self.data:
|
||||
del self.data[key]
|
||||
|
||||
def warmup(self):
|
||||
if self.warmup_done:
|
||||
return
|
||||
warmup_tensorrt_llm()
|
||||
self.warmup_done = True
|
||||
|
||||
def put_response(self, key: int, item: Any):
|
||||
self.data[key] = item
|
||||
self.event_map[key].set()
|
||||
|
||||
def get(self, key: int):
|
||||
with self.semaphore:
|
||||
self.event_map[key].wait()
|
||||
self.event_map[key].clear()
|
||||
ret = self.data[key]
|
||||
del self.data[key]
|
||||
return ret
|
||||
|
||||
|
||||
class GenerationResultBase:
|
||||
''' This holds the core logic of the GenerationResult class. '''
|
||||
|
||||
def __init__(self,
|
||||
id: int,
|
||||
sampling_params: SamplingParams,
|
||||
ray_queue: Optional[RayAsyncQueue] = None,
|
||||
background_error_handler: Optional[Callable] = None,
|
||||
postproc_params: "Optional[PostprocParams]" = None):
|
||||
self.id = id
|
||||
@ -165,12 +263,22 @@ class GenerationResultBase:
|
||||
self._done = False
|
||||
self.metrics_dict = {}
|
||||
|
||||
if has_event_loop():
|
||||
self.aqueue = AsyncQueue()
|
||||
self.queue = self.aqueue.sync_q
|
||||
if ray_queue is not None:
|
||||
if has_event_loop():
|
||||
self.aqueue = ray_queue
|
||||
self.queue = self.aqueue
|
||||
else:
|
||||
self.queue = ray_queue
|
||||
self.aqueue = None
|
||||
|
||||
ray.get(self.queue.register.remote(id))
|
||||
else:
|
||||
self.queue = Queue()
|
||||
self.aqueue = None
|
||||
if has_event_loop():
|
||||
self.aqueue = AsyncQueue()
|
||||
self.queue = self.aqueue.sync_q
|
||||
else:
|
||||
self.queue = Queue()
|
||||
self.aqueue = None
|
||||
|
||||
# In Sampling mode, the Executor runtime will return best_of sequences
|
||||
# in total, which the LLM API will select the n-best sequences among
|
||||
@ -419,6 +527,12 @@ class GenerationResultBase:
|
||||
else:
|
||||
raise ValueError(f"Unknown response type: {response}")
|
||||
|
||||
if self._done and mpi_disabled():
|
||||
assert hasattr(
|
||||
self.queue, "unregister"
|
||||
), "Ray path should be activated for unregistering the Ray queue."
|
||||
self.queue.unregister.remote(self.id)
|
||||
|
||||
def record_stats(self,
|
||||
output: CompletionOutput,
|
||||
stats: Optional[dict[str, float]] = None) -> None:
|
||||
@ -541,9 +655,15 @@ class GenerationResult(GenerationResultBase):
|
||||
disaggregated_params: Optional[DisaggregatedParams] = None,
|
||||
logprob_params: Optional[LogprobParams] = None,
|
||||
) -> None:
|
||||
use_async_queue = has_event_loop()
|
||||
shared_queue = None
|
||||
if executor and executor.use_ray_queue():
|
||||
shared_queue = executor.async_response_queue_weakref if use_async_queue else executor.sync_response_queue_weakref
|
||||
|
||||
super().__init__(
|
||||
generation_request.id,
|
||||
generation_request.sampling_params,
|
||||
shared_queue,
|
||||
background_error_handler,
|
||||
postproc_params=generation_request.postproc_params,
|
||||
)
|
||||
@ -597,13 +717,25 @@ class GenerationResult(GenerationResultBase):
|
||||
if hasattr(self, "_logprob_params"):
|
||||
del self._logprob_params
|
||||
|
||||
def _handle_ray_response(self, response: Any):
|
||||
return response
|
||||
|
||||
def _result_step(self, timeout: Optional[float] = None):
|
||||
response = self.queue.get(timeout=timeout)
|
||||
if mpi_disabled():
|
||||
response = ray.get(self.queue.get.remote(self.request_id))
|
||||
response = self._handle_ray_response(response)
|
||||
else:
|
||||
response = self.queue.get()
|
||||
|
||||
self._handle_response(response)
|
||||
|
||||
async def _aresult_step(self):
|
||||
assert self.aqueue is not None, "The asyncio event loop was not present during initialization, so async operations are not available."
|
||||
response = await self.aqueue.get()
|
||||
if mpi_disabled():
|
||||
response = await self.aqueue.get_async.remote(self.request_id)
|
||||
response = self._handle_ray_response(response)
|
||||
else:
|
||||
response = await self.aqueue.get()
|
||||
global_tracer().log_instant("result_step.get")
|
||||
self._handle_response(response)
|
||||
|
||||
|
||||
@ -41,9 +41,6 @@ __all__ = [
|
||||
|
||||
class GenerationExecutorWorker(BaseWorker):
|
||||
|
||||
class WorkerExit(GeneratorExit):
|
||||
pass
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Union[Path, Engine],
|
||||
@ -248,10 +245,6 @@ class GenerationExecutorWorker(BaseWorker):
|
||||
if isinstance(self.engine, PyExecutor):
|
||||
self.engine.wait_shutdown()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> bool:
|
||||
self.shutdown()
|
||||
return exc_type is None or exc_type == GenerationExecutorWorker.WorkerExit
|
||||
|
||||
|
||||
@print_traceback_on_error
|
||||
def worker_main(
|
||||
|
||||
@ -13,6 +13,7 @@ import transformers
|
||||
from tqdm import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
from tensorrt_llm.inputs.data import TextPrompt
|
||||
from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams
|
||||
from tensorrt_llm.inputs.registry import DefaultInputProcessor
|
||||
@ -124,6 +125,7 @@ class BaseLLM:
|
||||
**kwargs: Any) -> None:
|
||||
|
||||
self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
|
||||
self._orchestrator_type = kwargs.get("orchestrator_type", None)
|
||||
self._llm_id = None
|
||||
|
||||
log_level = logger.level
|
||||
@ -134,6 +136,12 @@ class BaseLLM:
|
||||
if backend == "pytorch":
|
||||
logger.info("Using LLM with PyTorch backend")
|
||||
llm_args_cls = TorchLlmArgs
|
||||
if self._orchestrator_type == "ray" or mpi_disabled():
|
||||
self._orchestrator_type = "ray"
|
||||
os.environ["TLLM_DISABLE_MPI"] = "1"
|
||||
# Propagate to args construction
|
||||
kwargs["orchestrator_type"] = "ray"
|
||||
|
||||
elif backend == '_autodeploy':
|
||||
logger.info("Using LLM with AutoDeploy backend")
|
||||
from .._torch.auto_deploy.llm_args import \
|
||||
@ -984,6 +992,34 @@ class _TorchLLM(BaseLLM):
|
||||
backend=backend,
|
||||
**kwargs)
|
||||
|
||||
@set_api_status("prototype")
|
||||
def _collective_rpc(self,
|
||||
method: str,
|
||||
args: tuple[Any, ...] = (),
|
||||
kwargs: Optional[dict] = None,
|
||||
non_block: bool = False,
|
||||
unique_reply_rank: Optional[int] = None) -> list[Any]:
|
||||
"""
|
||||
Execute an RPC call on all GPU workers. Currently, this is only supported for RayExecutor.
|
||||
|
||||
Args:
|
||||
method (str): The name of the worker method to execute.
|
||||
args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
|
||||
kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
|
||||
non_block (bool): If True, do not block; return immediately instead of waiting for all workers to complete the RPC call. Defaults to False.
|
||||
unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list[Any]: A list of results from each worker.
|
||||
"""
|
||||
if hasattr(self._executor, 'collective_rpc'):
|
||||
return self._executor.collective_rpc(method, args, kwargs,
|
||||
non_block, unique_reply_rank)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Executor type {type(self._executor)} does not support collective RPC."
|
||||
)
|
||||
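A hedged usage sketch for the prototype API above; "report_device_id" is a hypothetical worker method name used only for illustration and is not part of this change.

# Assumes an LLM built with orchestrator_type="ray" (so RayExecutor is active)
# and a worker method named "report_device_id" exposed on each GPU worker.
results = llm._collective_rpc("report_device_id", non_block=False)
print(results)  # one entry per worker rank

# Restrict the reply to a single rank:
head_only = llm._collective_rpc("report_device_id", unique_reply_rank=0)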
|
||||
def _build_model(self):
|
||||
super()._build_model()
|
||||
assert self._engine_dir is None
|
||||
|
||||
@ -2437,6 +2437,13 @@ class TorchLlmArgs(BaseLlmArgs):
|
||||
status="prototype",
|
||||
)
|
||||
|
||||
orchestrator_type: Optional[Literal["ray"]] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"The orchestrator type to use. Options: 'ray'. Defaults to None, which uses MPI.",
|
||||
status="prototype",
|
||||
)
|
||||
|
||||
# PrivateVars
|
||||
_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
|
||||
|
||||
|
||||
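For reference, a minimal sketch of driving the new orchestrator_type field through the LLM API, mirroring the pp2 Ray integration test added later in this change (the model path and memory fraction are illustrative):

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

with LLM("meta-llama/Llama-3.1-8B-Instruct",
         orchestrator_type="ray",
         pipeline_parallel_size=2,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.6)) as llm:
    outputs = llm.generate(["The capital of Germany is"])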
@ -17,6 +17,9 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
|
||||
|
||||
class CpType(IntEnum):
|
||||
# CP type for ulysses parallelism
|
||||
@ -29,101 +32,12 @@ class CpType(IntEnum):
|
||||
HELIX = 3
|
||||
|
||||
|
||||
class Mapping(object):
|
||||
'''
|
||||
A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2
|
||||
class MappingBase:
|
||||
"""Base class for distributed mapping configurations"""
|
||||
|
||||
2 tp groups:
|
||||
|
||||
- [0, 1, 2, 3]
|
||||
- [4, 5, 6, 7]
|
||||
|
||||
4 pp groups:
|
||||
|
||||
- [0, 4]
|
||||
- [1, 5]
|
||||
- [2, 6]
|
||||
- [3, 7]
|
||||
|
||||
A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1
|
||||
|
||||
2 tp groups:
|
||||
|
||||
- [0, 1, 2, 3]
|
||||
- [4, 5, 6, 7]
|
||||
|
||||
4 cp groups:
|
||||
|
||||
- [0, 4]
|
||||
- [1, 5]
|
||||
- [2, 6]
|
||||
- [3, 7]
|
||||
|
||||
A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4
|
||||
|
||||
4 moe_tp groups:
|
||||
|
||||
- [0, 4]
|
||||
- [1, 5]
|
||||
- [2, 6]
|
||||
- [3, 7]
|
||||
|
||||
2 moe_ep groups:
|
||||
|
||||
- [0, 1, 2, 3]
|
||||
- [4, 5, 6, 7]
|
||||
|
||||
2 nodes with 16 GPUs, moe_tp_size = 2, moe_ep_size = 4, pp_size = 2
|
||||
|
||||
8 moe_tp groups:
|
||||
|
||||
- [0 4]
|
||||
- [1 5]
|
||||
- [2 6]
|
||||
- [3 7]
|
||||
- [8 12]
|
||||
- [9 13]
|
||||
- [10 14]
|
||||
- [11 15]
|
||||
|
||||
4 moe_ep groups:
|
||||
|
||||
- [0, 1, 2, 3]
|
||||
- [4, 5, 6, 7]
|
||||
- [8, 9, 10, 11]
|
||||
- [12, 13, 14, 15]
|
||||
|
||||
8 pp groups:
|
||||
|
||||
- [0 8]
|
||||
- [1 9]
|
||||
- [2 10]
|
||||
- [3 11]
|
||||
- [4 12]
|
||||
- [5 13]
|
||||
- [6 14]
|
||||
- [7 15]
|
||||
|
||||
2 nodes with 8 GPUs, tp_size 2, pp_size 2, cp_size 2
|
||||
|
||||
4 tp groups:
|
||||
- [0, 1]
|
||||
- [2, 3]
|
||||
- [4, 5]
|
||||
- [6, 7]
|
||||
|
||||
4 pp groups:
|
||||
- [0, 4]
|
||||
- [1, 5]
|
||||
- [2, 6]
|
||||
- [3, 7]
|
||||
|
||||
4 cp groups:
|
||||
- [0, 2]
|
||||
- [1, 3]
|
||||
- [4, 6]
|
||||
- [5, 7]
|
||||
'''
|
||||
tp_rank: int
|
||||
pp_rank: int
|
||||
cp_rank: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -196,11 +110,11 @@ class Mapping(object):
|
||||
)
|
||||
|
||||
moe_tp_ep_size = moe_tp_size * moe_ep_size
|
||||
moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size
|
||||
if moe_tp_cluster_ep_size != moe_world_size:
|
||||
self.moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size
|
||||
if self.moe_tp_cluster_ep_size != moe_world_size:
|
||||
raise ValueError(
|
||||
"moe_tp_size * moe_ep_size * moe_cluster_size must equal to moe_world_size, "
|
||||
f"but got {moe_tp_cluster_ep_size} != {moe_world_size}")
|
||||
f"but got {self.moe_tp_cluster_ep_size} != {moe_world_size}")
|
||||
|
||||
attn_tp_cp_size = attn_tp_size * attn_cp_size
|
||||
if attn_tp_cp_size != tp_size * cp_size:
|
||||
@ -213,6 +127,9 @@ class Mapping(object):
|
||||
raise NotImplementedError(
|
||||
f"CP {cp_type} doesn't support MoE tp/ep yet")
|
||||
|
||||
if moe_cluster_size > 1:
|
||||
assert moe_ep_size == 1
|
||||
|
||||
self.tp_size = tp_size
|
||||
self.cp_size = cp_size
|
||||
self.cp_config = cp_config if cp_config is not None else {}
|
||||
@ -230,6 +147,7 @@ class Mapping(object):
|
||||
self.enable_lm_head_tp_in_adp = enable_lm_head_tp_in_adp
|
||||
self.rank = rank
|
||||
self.gpus_per_node = gpus_per_node
|
||||
|
||||
self.pp_groups = []
|
||||
self.cp_groups = []
|
||||
self.tp_groups = []
|
||||
@ -237,60 +155,8 @@ class Mapping(object):
|
||||
self.moe_tp_groups = []
|
||||
self.moe_ep_groups = []
|
||||
|
||||
if moe_cluster_size > 1:
|
||||
assert moe_ep_size == 1
|
||||
|
||||
# init pp group
|
||||
for i in range(tp_size * cp_size):
|
||||
ranks = range(i, world_size, tp_size * cp_size)
|
||||
self.pp_groups.append(list(ranks))
|
||||
|
||||
# init cp group
|
||||
for i in range(pp_size):
|
||||
for j in range(tp_size):
|
||||
ranks = range(i * tp_size * cp_size + j,
|
||||
(i + 1) * tp_size * cp_size + j, tp_size)
|
||||
self.cp_groups.append(list(ranks))
|
||||
|
||||
# init tp group
|
||||
for i in range(pp_size):
|
||||
for j in range(cp_size):
|
||||
ranks = range(i * tp_size * cp_size + j * tp_size,
|
||||
i * tp_size * cp_size + (j + 1) * tp_size)
|
||||
self.tp_groups.append(list(ranks))
|
||||
|
||||
# init moe tp group
|
||||
for i in range(pp_size):
|
||||
for j in range(moe_cluster_size * moe_ep_size):
|
||||
ranks = range(i * moe_tp_cluster_ep_size + j,
|
||||
(i + 1) * moe_tp_cluster_ep_size,
|
||||
moe_cluster_size * moe_ep_size)
|
||||
self.moe_tp_groups.append(list(ranks))
|
||||
|
||||
# init moe cluster group
|
||||
for i in range(pp_size):
|
||||
for j in range(moe_tp_size):
|
||||
ranks = range(
|
||||
i * moe_tp_cluster_ep_size +
|
||||
j * moe_cluster_size * moe_ep_size,
|
||||
i * moe_tp_cluster_ep_size +
|
||||
(j + 1) * moe_cluster_size * moe_ep_size)
|
||||
self.moe_cluster_groups.append(list(ranks))
|
||||
|
||||
# init moe ep group
|
||||
for i in range(pp_size):
|
||||
for j in range(moe_tp_size):
|
||||
for k in range(moe_cluster_size):
|
||||
ranks = range(
|
||||
i * moe_tp_cluster_ep_size +
|
||||
j * moe_cluster_size * moe_ep_size + k * moe_ep_size,
|
||||
i * moe_tp_cluster_ep_size +
|
||||
j * moe_cluster_size * moe_ep_size +
|
||||
(k + 1) * moe_ep_size)
|
||||
self.moe_ep_groups.append(list(ranks))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Mapping):
|
||||
if not isinstance(other, MappingBase):
|
||||
return NotImplemented
|
||||
|
||||
return (self.world_size == other.world_size and self.rank == other.rank
|
||||
@ -338,20 +204,6 @@ class Mapping(object):
|
||||
)
|
||||
self._rank = rank
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return 0 if self.auto_parallel else self.rank % self.tp_size
|
||||
|
||||
@property
|
||||
def pp_rank(self):
|
||||
return 0 if self.auto_parallel else self.rank // (self.tp_size *
|
||||
self.cp_size)
|
||||
|
||||
@property
|
||||
def cp_rank(self):
|
||||
return 0 if self.auto_parallel else self.rank % (
|
||||
self.tp_size * self.cp_size) // self.tp_size
|
||||
|
||||
@property
|
||||
def moe_tp_rank(self):
|
||||
return self.tp_rank // (self.moe_ep_size * self.moe_cluster_size)
|
||||
@ -364,37 +216,11 @@ class Mapping(object):
|
||||
def moe_ep_rank(self):
|
||||
return self.tp_rank % self.moe_ep_size
|
||||
|
||||
@property
|
||||
def tp_group(self):
|
||||
return self.tp_groups[self.pp_rank * self.cp_size + self.cp_rank]
|
||||
|
||||
@property
|
||||
def pp_group(self):
|
||||
return self.pp_groups[self.cp_rank * self.tp_size + self.tp_rank]
|
||||
|
||||
@property
|
||||
def cp_group(self):
|
||||
return self.cp_groups[self.pp_rank * self.tp_size + self.tp_rank]
|
||||
|
||||
@property
|
||||
def moe_tp_group(self):
|
||||
return self.moe_tp_groups[self.pp_rank * self.moe_cluster_size *
|
||||
self.moe_ep_size +
|
||||
self.moe_cluster_rank * self.moe_ep_size +
|
||||
self.moe_ep_rank]
|
||||
|
||||
@property
|
||||
def moe_cluster_group(self):
|
||||
return self.moe_cluster_groups[self.pp_rank * self.moe_tp_size +
|
||||
self.moe_tp_rank]
|
||||
|
||||
@property
|
||||
def moe_ep_group(self):
|
||||
return self.moe_ep_groups[self.pp_rank * self.moe_tp_size *
|
||||
self.moe_cluster_size +
|
||||
self.moe_tp_rank * self.moe_cluster_size +
|
||||
self.moe_cluster_rank]
|
||||
|
||||
@property
|
||||
def node_rank(self):
|
||||
return self.rank // self.gpus_per_node
|
||||
@ -517,3 +343,280 @@ class Mapping(object):
|
||||
'enable_attention_dp': self.enable_attention_dp,
|
||||
'enable_lm_head_tp_in_adp': self.enable_lm_head_tp_in_adp,
|
||||
}
|
||||
|
||||
|
||||
class Mapping(MappingBase):
|
||||
"""
|
||||
A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2
|
||||
|
||||
2 tp groups:
|
||||
|
||||
- [0, 1, 2, 3]
|
||||
- [4, 5, 6, 7]
|
||||
|
||||
4 pp groups:
|
||||
|
||||
- [0, 4]
|
||||
- [1, 5]
|
||||
- [2, 6]
|
||||
- [3, 7]
|
||||
|
||||
A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1
|
||||
|
||||
2 tp groups:
|
||||
|
||||
- [0, 1, 2, 3]
|
||||
- [4, 5, 6, 7]
|
||||
|
||||
4 cp groups:
|
||||
|
||||
- [0, 4]
|
||||
- [1, 5]
|
||||
- [2, 6]
|
||||
- [3, 7]
|
||||
|
||||
A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4
|
||||
|
||||
4 moe_tp groups:
|
||||
|
||||
- [0, 4]
|
||||
- [1, 5]
|
||||
- [2, 6]
|
||||
- [3, 7]
|
||||
|
||||
2 moe_ep groups:
|
||||
|
||||
- [0, 1, 2, 3]
|
||||
- [4, 5, 6, 7]
|
||||
|
||||
2 nodes with 16 GPUs, moe_tp_size = 2, moe_ep_size = 4, pp_size = 2
|
||||
|
||||
8 moe_tp groups:
|
||||
|
||||
- [0 4]
|
||||
- [1 5]
|
||||
- [2 6]
|
||||
- [3 7]
|
||||
- [8 12]
|
||||
- [9 13]
|
||||
- [10 14]
|
||||
- [11 15]
|
||||
|
||||
4 moe_ep groups:
|
||||
|
||||
- [0, 1, 2, 3]
|
||||
- [4, 5, 6, 7]
|
||||
- [8, 9, 10, 11]
|
||||
- [12, 13, 14, 15]
|
||||
|
||||
8 pp groups:
|
||||
|
||||
- [0 8]
|
||||
- [1 9]
|
||||
- [2 10]
|
||||
- [3 11]
|
||||
- [4 12]
|
||||
- [5 13]
|
||||
- [6 14]
|
||||
- [7 15]
|
||||
|
||||
2 nodes with 8 GPUs, tp_size 2, pp_size 2, cp_size 2
|
||||
|
||||
4 tp groups:
|
||||
- [0, 1]
|
||||
- [2, 3]
|
||||
- [4, 5]
|
||||
- [6, 7]
|
||||
|
||||
4 pp groups:
|
||||
- [0, 4]
|
||||
- [1, 5]
|
||||
- [2, 6]
|
||||
- [3, 7]
|
||||
|
||||
4 cp groups:
|
||||
- [0, 2]
|
||||
- [1, 3]
|
||||
- [4, 6]
|
||||
- [5, 7]
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if mpi_disabled():
|
||||
return super().__new__(DeviceMeshTopology)
|
||||
else:
|
||||
return super().__new__(MpiTopology)
|
||||
|
||||
# Intentionally repeated for type hints
|
||||
def __init__(
|
||||
self,
|
||||
world_size=1,
|
||||
rank=0,
|
||||
gpus_per_node=8,
|
||||
*,
|
||||
cp_size=1,
|
||||
cp_config=None,
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
moe_cluster_size=-1, # -1 means no moe
|
||||
moe_tp_size=-1, # -1 means no moe
|
||||
moe_ep_size=-1, # -1 means no moe
|
||||
attn_tp_size=-1,
|
||||
attn_cp_size=-1,
|
||||
auto_parallel=False,
|
||||
enable_attention_dp=False,
|
||||
enable_lm_head_tp_in_adp=False):
|
||||
super().__init__(world_size=world_size,
|
||||
rank=rank,
|
||||
gpus_per_node=gpus_per_node,
|
||||
cp_size=cp_size,
|
||||
cp_config=cp_config,
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size,
|
||||
moe_cluster_size=moe_cluster_size,
|
||||
moe_tp_size=moe_tp_size,
|
||||
moe_ep_size=moe_ep_size,
|
||||
attn_tp_size=attn_tp_size,
|
||||
attn_cp_size=attn_cp_size,
|
||||
auto_parallel=auto_parallel,
|
||||
enable_attention_dp=enable_attention_dp,
|
||||
enable_lm_head_tp_in_adp=enable_lm_head_tp_in_adp)
|
||||
|
||||
# DeviceMesh specific methods
|
||||
@property
|
||||
def tp_group_pg(self):
|
||||
raise NotImplementedError("tp_group_pg is not implemented.")
|
||||
|
||||
@property
|
||||
def pp_group_pg(self):
|
||||
raise NotImplementedError("pp_group_pg is not implemented.")
|
||||
|
||||
@property
|
||||
def cp_group_pg(self):
|
||||
raise NotImplementedError("cp_group_pg is not implemented.")
|
||||
|
||||
@property
|
||||
def moe_tp_group_pg(self):
|
||||
raise NotImplementedError("moe_tp_group_pg is not implemented.")
|
||||
|
||||
@property
|
||||
def moe_ep_group_pg(self):
|
||||
raise NotImplementedError("moe_ep_group_pg is not implemented.")
|
||||
|
||||
def build_mesh(self):
|
||||
raise NotImplementedError("build_mesh is not implemented.")
|
||||
|
||||
|
||||
class MpiTopology(Mapping):
|
||||
'''MPI-based mapping implementation'''
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._init_parallel_groups()
|
||||
|
||||
@property
|
||||
def tp_rank(self) -> int:
|
||||
return 0 if self.auto_parallel else self.rank % self.tp_size
|
||||
|
||||
@property
|
||||
def pp_rank(self) -> int:
|
||||
return 0 if self.auto_parallel else self.rank // (self.tp_size *
|
||||
self.cp_size)
|
||||
|
||||
@property
|
||||
def cp_rank(self) -> int:
|
||||
return 0 if self.auto_parallel else self.rank % (
|
||||
self.tp_size * self.cp_size) // self.tp_size
|
||||
|
||||
@property
|
||||
def tp_group(self) -> List[int]:
|
||||
return self.tp_groups[self.pp_rank * self.cp_size + self.cp_rank]
|
||||
|
||||
@property
|
||||
def pp_group(self) -> List[int]:
|
||||
return self.pp_groups[self.cp_rank * self.tp_size + self.tp_rank]
|
||||
|
||||
@property
|
||||
def cp_group(self) -> List[int]:
|
||||
return self.cp_groups[self.pp_rank * self.tp_size + self.tp_rank]
|
||||
|
||||
@property
|
||||
def moe_tp_group(self) -> List[int]:
|
||||
return self.moe_tp_groups[self.pp_rank * self.moe_cluster_size *
|
||||
self.moe_ep_size +
|
||||
self.moe_cluster_rank * self.moe_ep_size +
|
||||
self.moe_ep_rank]
|
||||
|
||||
@property
|
||||
def moe_ep_group(self) -> List[int]:
|
||||
return self.moe_ep_groups[self.pp_rank * self.moe_tp_size *
|
||||
self.moe_cluster_size +
|
||||
self.moe_tp_rank * self.moe_cluster_size +
|
||||
self.moe_cluster_rank]
|
||||
|
||||
@property
|
||||
def moe_cluster_group(self) -> List[int]:
|
||||
return self.moe_cluster_groups[self.pp_rank * self.moe_tp_size +
|
||||
self.moe_tp_rank]
|
||||
|
||||
def _init_parallel_groups(self):
|
||||
# init pp group
|
||||
for i in range(self.tp_size * self.cp_size):
|
||||
ranks = range(i, self.world_size, self.tp_size * self.cp_size)
|
||||
self.pp_groups.append(list(ranks))
|
||||
|
||||
# init cp group
|
||||
for i in range(self.pp_size):
|
||||
for j in range(self.tp_size):
|
||||
ranks = range(i * self.tp_size * self.cp_size + j,
|
||||
(i + 1) * self.tp_size * self.cp_size + j,
|
||||
self.tp_size)
|
||||
self.cp_groups.append(list(ranks))
|
||||
|
||||
# init tp group
|
||||
for i in range(self.pp_size):
|
||||
for j in range(self.cp_size):
|
||||
ranks = range(
|
||||
i * self.tp_size * self.cp_size + j * self.tp_size,
|
||||
i * self.tp_size * self.cp_size + (j + 1) * self.tp_size)
|
||||
self.tp_groups.append(list(ranks))
|
||||
|
||||
# init moe tp group
|
||||
for i in range(self.pp_size):
|
||||
for j in range(self.moe_cluster_size * self.moe_ep_size):
|
||||
ranks = range(i * self.moe_tp_cluster_ep_size + j,
|
||||
(i + 1) * self.moe_tp_cluster_ep_size,
|
||||
self.moe_cluster_size * self.moe_ep_size)
|
||||
self.moe_tp_groups.append(list(ranks))
|
||||
|
||||
# init moe cluster group
|
||||
for i in range(self.pp_size):
|
||||
for j in range(self.moe_tp_size):
|
||||
ranks = range(
|
||||
i * self.moe_tp_cluster_ep_size +
|
||||
j * self.moe_cluster_size * self.moe_ep_size,
|
||||
i * self.moe_tp_cluster_ep_size +
|
||||
(j + 1) * self.moe_cluster_size * self.moe_ep_size)
|
||||
self.moe_cluster_groups.append(list(ranks))
|
||||
|
||||
# init moe ep group
|
||||
for i in range(self.pp_size):
|
||||
for j in range(self.moe_tp_size):
|
||||
for k in range(self.moe_cluster_size):
|
||||
ranks = range(
|
||||
i * self.moe_tp_cluster_ep_size +
|
||||
j * self.moe_cluster_size * self.moe_ep_size +
|
||||
k * self.moe_ep_size, i * self.moe_tp_cluster_ep_size +
|
||||
j * self.moe_cluster_size * self.moe_ep_size +
|
||||
(k + 1) * self.moe_ep_size)
|
||||
self.moe_ep_groups.append(list(ranks))
|
||||
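As a concrete check of the index arithmetic above (a worked example, not new behavior): with world_size = 8, tp_size = 2, cp_size = 2, pp_size = 2, rank 5 decomposes as follows, matching the "2 nodes with 8 GPUs" example in the class docstring.

# rank = 5, tp_size = 2, cp_size = 2, pp_size = 2
tp_rank = 5 % 2                # 1
pp_rank = 5 // (2 * 2)         # 1
cp_rank = (5 % (2 * 2)) // 2   # 0
# tp groups: [0, 1], [2, 3], [4, 5], [6, 7]  -> rank 5 belongs to [4, 5]
# pp groups: [0, 4], [1, 5], [2, 6], [3, 7]  -> rank 5 belongs to [1, 5]
# cp groups: [0, 2], [1, 3], [4, 6], [5, 7]  -> rank 5 belongs to [5, 7]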
|
||||
|
||||
class DeviceMeshTopology(DeviceMeshTopologyImpl, Mapping):
|
||||
"""PyTorch DeviceMesh-based mapping implementation"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert mpi_disabled(
|
||||
), "DeviceMeshTopology is only available in Ray orchestrator mode."
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
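Because Mapping.__new__ dispatches on mpi_disabled(), callers keep constructing Mapping as before and transparently get either topology. A hedged sketch of the two modes, mirroring the parity test added below (Ray mode additionally assumes torch.distributed has been initialized):

import os

# Default (MPI) mode: the instance is an MpiTopology with eagerly built rank groups.
m = Mapping(world_size=8, rank=0, gpus_per_node=8, tp_size=4, pp_size=2)

# Ray mode: with TLLM_DISABLE_MPI=1, the same call returns a DeviceMeshTopology;
# build_mesh() is then called (as the parity test does) before the mesh-backed
# group properties are used.
os.environ["TLLM_DISABLE_MPI"] = "1"
m_ray = Mapping(world_size=8, rank=0, gpus_per_node=8, tp_size=4, pp_size=2)
m_ray.build_mesh()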
40
tensorrt_llm/ray_stub.py
Normal file
@ -0,0 +1,40 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import functools
|
||||
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
|
||||
if mpi_disabled():
|
||||
raise RuntimeError(
|
||||
"Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
|
||||
)
|
||||
|
||||
|
||||
def remote(*args, **kwargs):
|
||||
|
||||
def decorator(func):
|
||||
# Returns a function that always raises.
|
||||
# Decorated class depends on ray, but ray is not installed.
|
||||
@functools.wraps(func)
|
||||
def stub_checker(*_, **__):
|
||||
raise RuntimeError(
|
||||
"Ray not installed, cannot use Ray based feature.")
|
||||
|
||||
return stub_checker
|
||||
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
return decorator(args[0])
|
||||
|
||||
return decorator
|
||||
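An illustrative sketch of how the stub behaves when Ray is not installed: decoration succeeds in both the bare and the parameterized form (like ray.remote), but any later call raises.

@remote
def ping():
    return "pong"


@remote(num_cpus=1)
def pong():
    return "ping"

# Calling either stub raises:
#   RuntimeError: Ray not installed, cannot use Ray based feature.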
@ -31,7 +31,7 @@ from tensorrt_llm.quantization import QuantAlgo
|
||||
from ..conftest import (get_device_count, get_device_memory, llm_models_root,
|
||||
parametrize_with_ids, skip_no_hopper,
|
||||
skip_post_blackwell, skip_pre_ada, skip_pre_blackwell,
|
||||
skip_pre_hopper)
|
||||
skip_pre_hopper, skip_ray)
|
||||
from .accuracy_core import (GSM8K, MMLU, MMMU, CnnDailymail, GPQADiamond,
|
||||
JsonModeEval, LlmapiAccuracyTestHarness)
|
||||
|
||||
@ -1403,6 +1403,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
|
||||
@pytest.mark.skip_less_device(4)
|
||||
@skip_pre_hopper
|
||||
@skip_ray
|
||||
@parametrize_with_ids("torch_compile", [False, True])
|
||||
@parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
|
||||
[(False, False, False, False),
|
||||
|
||||
26
tests/integration/defs/accuracy/test_llm_api_pytorch_ray.py
Normal file
@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm import LLM
|
||||
from tensorrt_llm.llmapi import KvCacheConfig
|
||||
|
||||
from ..conftest import llm_models_root
|
||||
from .accuracy_core import MMLU, LlmapiAccuracyTestHarness
|
||||
|
||||
pytestmark = pytest.mark.ray
|
||||
|
||||
|
||||
class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
|
||||
|
||||
@pytest.mark.skip_less_device(2)
|
||||
@pytest.mark.skip_less_device_memory(32000)
|
||||
def test_pp2_ray(self):
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
|
||||
|
||||
with LLM(self.MODEL_PATH,
|
||||
orchestrator_type="ray",
|
||||
pipeline_parallel_size=2,
|
||||
kv_cache_config=kv_cache_config) as llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
@ -16,6 +16,7 @@ import copy
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import socket
|
||||
import time
|
||||
from difflib import SequenceMatcher
|
||||
from pathlib import Path
|
||||
@ -1135,3 +1136,14 @@ def get_mmlu_accuracy(output):
|
||||
print(f"MMLU weighted average accuracy is: {mmlu_accuracy}")
|
||||
|
||||
return mmlu_accuracy
|
||||
|
||||
|
||||
def wait_for_server(host, port, timeout_seconds=180):
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
try:
|
||||
with socket.create_connection((host, port), timeout=5):
|
||||
return True
|
||||
except (socket.error, ConnectionRefusedError, OSError):
|
||||
time.sleep(2)
|
||||
return False
|
||||
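A typical call, matching how the Ray disaggregated test path below gates client traffic on server readiness:

# Poll the disaggregated server endpoint before issuing client requests.
if not wait_for_server("localhost", 8000, timeout_seconds=900):
    raise RuntimeError(
        "Disaggregated server failed to start within 900 seconds")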
|
||||
@ -36,6 +36,7 @@ import tqdm
|
||||
import yaml
|
||||
from _pytest.mark import ParameterSet
|
||||
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
from tensorrt_llm.bindings import ipc_nvls_supported
|
||||
from tensorrt_llm.llmapi.mpi_session import get_mpi_world_size
|
||||
|
||||
@ -2071,6 +2072,13 @@ def pytest_addoption(parser):
|
||||
parser.addoption("--perf",
|
||||
action="store_true",
|
||||
help="'--perf' will run perf tests")
|
||||
parser.addoption(
|
||||
"--run-ray",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
"Enable Ray orchestrator path for integration tests (disables MPI).",
|
||||
)
|
||||
parser.addoption(
|
||||
"--perf-log-formats",
|
||||
help=
|
||||
@ -2171,6 +2179,8 @@ def pytest_collection_modifyitems(session, config, items):
|
||||
def pytest_configure(config):
|
||||
# avoid thread leak of tqdm's TMonitor
|
||||
tqdm.tqdm.monitor_interval = 0
|
||||
if config.getoption("--run-ray"):
|
||||
os.environ["TLLM_DISABLE_MPI"] = "1"
|
||||
|
||||
|
||||
def deselect_by_regex(regexp, items, test_prefix, config):
|
||||
@ -2270,6 +2280,10 @@ def check_nvlink():
|
||||
skip_nvlink_inactive = pytest.mark.skipif(check_nvlink() is False,
|
||||
reason="nvlink is inactive.")
|
||||
|
||||
skip_ray = pytest.mark.skipif(
|
||||
os.environ.get("TLLM_DISABLE_MPI") == "1",
|
||||
reason="This test is skipped for Ray orchestrator.")
|
||||
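Usage is a plain decorator on tests that must not run under the Ray orchestrator, as the DeepSeek accuracy test earlier in this change shows:

@skip_ray
def test_mpi_only_feature():
    ...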
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def eval_venv(llm_venv):
|
||||
@ -2448,3 +2462,15 @@ def torch_empty_cache() -> None:
|
||||
"""
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ray_cleanup(llm_venv) -> None:
|
||||
yield
|
||||
|
||||
if mpi_disabled():
|
||||
llm_venv.run_cmd([
|
||||
"-m",
|
||||
"ray.scripts.scripts",
|
||||
"stop",
|
||||
])
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
@ -21,10 +22,12 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from defs.common import wait_for_server
|
||||
from defs.conftest import (get_sm_version, llm_models_root, skip_arm,
|
||||
skip_no_hopper)
|
||||
from defs.trt_test_alternative import check_call, check_output, popen
|
||||
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
|
||||
@ -267,13 +270,146 @@ def get_test_config(test_desc, example_dir, test_root):
|
||||
return config_map[test_desc]
|
||||
|
||||
|
||||
def get_extra_llm_config(config, suffix, cwd):
|
||||
extra_llm_config = {
|
||||
'orchestrator_type': 'ray',
|
||||
}
|
||||
for key, value in config.items():
|
||||
if key not in ['num_instances', 'urls']:
|
||||
extra_llm_config[key] = value
|
||||
|
||||
temp_fd, extra_config_file = tempfile.mkstemp(suffix='_%s.yaml' % suffix,
|
||||
dir=cwd)
|
||||
with os.fdopen(temp_fd, 'w') as f:
|
||||
yaml.dump(extra_llm_config, f)
|
||||
|
||||
return extra_config_file
|
||||
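A hedged sketch of what the helper produces for a context-server section: keys other than num_instances and urls are carried through, and orchestrator_type is forced to ray (tensor_parallel_size here is illustrative, not taken from the diff):

ctx_servers = {
    "num_instances": 1,
    "urls": ["localhost:8001"],
    "tensor_parallel_size": 2,
}
# Writes a temp YAML file under `cwd` whose contents reduce to:
#   orchestrator_type: ray
#   tensor_parallel_size: 2
extra_cfg_path = get_extra_llm_config(ctx_servers, "ctx", cwd=".")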
|
||||
|
||||
def generate_worker_commands(model_path, config, server_config,
|
||||
extra_config_file, server_role):
|
||||
worker_commands = []
|
||||
|
||||
assert model_path, "model path is required."
|
||||
|
||||
for url in server_config['urls']:
|
||||
host, port = url.split(':')
|
||||
cmd = [
|
||||
'trtllm-serve', model_path, '--host', host, '--port', port,
|
||||
'--backend', config['backend'], '--extra_llm_api_options',
|
||||
extra_config_file, '--server_role', server_role
|
||||
]
|
||||
worker_commands.append(cmd)
|
||||
return worker_commands
|
||||
|
||||
|
||||
def run_client_tests(example_dir,
|
||||
config_file,
|
||||
test_desc,
|
||||
num_iters,
|
||||
env,
|
||||
server_start_timeout,
|
||||
prompt_file,
|
||||
extra_endpoints_test,
|
||||
server_url,
|
||||
workers_proc,
|
||||
server_proc,
|
||||
use_ray=False):
|
||||
"""Run client tests against the disaggregated server."""
|
||||
client_dir = f"{example_dir}/clients"
|
||||
for _ in range(num_iters):
|
||||
client_cmd = [
|
||||
'python3', f'{client_dir}/disagg_client.py', '-c', f'{config_file}',
|
||||
'-p', f'{client_dir}/{prompt_file}', '--ignore-eos',
|
||||
'--server-start-timeout',
|
||||
str(server_start_timeout)
|
||||
]
|
||||
if prompt_file == "long_prompts.json":
|
||||
# Use max_tokens 4 for long prompts to reduce test time
|
||||
client_cmd.extend(['--max-tokens', '4'])
|
||||
|
||||
# Prepare poll processes
|
||||
worker_processes = []
|
||||
if use_ray:
|
||||
for proc_cm in workers_proc:
|
||||
worker_processes.append(proc_cm.__enter__())
|
||||
else:
|
||||
worker_processes = [workers_proc]
|
||||
|
||||
poll_procs = worker_processes + [server_proc]
|
||||
check_call(client_cmd, env=env, poll_procs=poll_procs)
|
||||
|
||||
# Streaming client run
|
||||
streaming_client_cmd = client_cmd + [
|
||||
'--streaming', '-o', 'output_streaming.json'
|
||||
]
|
||||
check_call(streaming_client_cmd, env=env, poll_procs=poll_procs)
|
||||
|
||||
# Run the chat completion endpoint test only for TinyLlama
|
||||
if test_desc == "overlap" or test_desc == "trtllm_sampler":
|
||||
chat_client_cmd = client_cmd + [
|
||||
'-e', 'chat', '-o', 'output_chat.json'
|
||||
]
|
||||
check_call(chat_client_cmd, env=env, poll_procs=poll_procs)
|
||||
|
||||
streaming_chat_client_cmd = chat_client_cmd + [
|
||||
'--streaming', '-o', 'output_streaming_chat.json'
|
||||
]
|
||||
check_call(streaming_chat_client_cmd,
|
||||
env=env,
|
||||
poll_procs=poll_procs)
|
||||
|
||||
# Skip output verification for long prompts test
|
||||
if prompt_file == "long_prompts.json":
|
||||
continue
|
||||
|
||||
if extra_endpoints_test is not None:
|
||||
extra_endpoints_test(server_url)
|
||||
|
||||
# Verify outputs
|
||||
not_expected_strings = ["Berlin Berlin"]
|
||||
|
||||
output_files = ['output.json', 'output_streaming.json']
|
||||
if test_desc == "overlap" or test_desc == "trtllm_sampler":
|
||||
# Disable streaming chat completion for overlap test
|
||||
# due to bug
|
||||
output_files.extend(['output_chat.json'])
|
||||
|
||||
if test_desc.startswith("gen_only"):
|
||||
continue
|
||||
|
||||
for output_file in output_files:
|
||||
with open(output_file, 'r') as f:
|
||||
content = f.read()
|
||||
if "deepseek_v3_lite" in test_desc or output_file == "output_chat.json":
|
||||
expected_strings = [
|
||||
"Berlin", ["Asyncio is a", "Asyncio module in"]
|
||||
]
|
||||
else:
|
||||
expected_strings = [
|
||||
"The capital of Germany is Berlin",
|
||||
"Asyncio is a Python library"
|
||||
]
|
||||
for expected_string in expected_strings:
|
||||
if isinstance(expected_string, list):
|
||||
# At least one of the strings in the list should be found in the content
|
||||
assert any(
|
||||
string in content for string in expected_string
|
||||
), f"None of the strings in {expected_string} found in {output_file}"
|
||||
else:
|
||||
assert expected_string in content, f"Expected string '{expected_string}' not found in {output_file}"
|
||||
for not_expected_string in not_expected_strings:
|
||||
assert not_expected_string not in content, f"Unexpected string '{not_expected_string}' found in {output_file}"
|
||||
|
||||
|
||||
def run_disaggregated_test(example_dir,
|
||||
test_desc,
|
||||
num_iters=5,
|
||||
env=None,
|
||||
cwd=None,
|
||||
prompt_file="prompts.json",
|
||||
extra_endpoints_test: Callable[[str], None] = None):
|
||||
extra_endpoints_test: Callable[[str], None] = None,
|
||||
model_path=None):
|
||||
"""Run disaggregated test with given configuration."""
|
||||
cleanup_output_files()
|
||||
run_env = env.copy()
|
||||
@ -282,11 +418,42 @@ def run_disaggregated_test(example_dir,
|
||||
num_ranks, config_file = get_test_config(test_desc, example_dir,
|
||||
os.path.dirname(__file__))
|
||||
|
||||
workers_cmd = [
|
||||
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
|
||||
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
|
||||
config_file
|
||||
]
|
||||
use_ray = mpi_disabled()
|
||||
if not use_ray:
|
||||
workers_cmd = [
|
||||
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
|
||||
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
|
||||
config_file
|
||||
]
|
||||
else:
|
||||
with open(config_file, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
if config['backend'] != "pytorch":
|
||||
pytest.skip(
|
||||
"Ray orchestrator is only supported with pytorch backend.")
|
||||
|
||||
extra_config_files = []
|
||||
workers_cmds = []
|
||||
subprocess.run(['ray', 'start', '--head', '--disable-usage-stats'],
|
||||
check=True)
|
||||
|
||||
# Generate ctx and gen server worker commands
|
||||
ctx_extra_config_file = get_extra_llm_config(config['context_servers'],
|
||||
"ctx", cwd)
|
||||
extra_config_files.append(ctx_extra_config_file)
|
||||
workers_cmds.extend(
|
||||
generate_worker_commands(model_path, config,
|
||||
config['context_servers'],
|
||||
ctx_extra_config_file, 'context'))
|
||||
|
||||
gen_extra_config_file = get_extra_llm_config(
|
||||
config['generation_servers'], "gen", cwd)
|
||||
extra_config_files.append(gen_extra_config_file)
|
||||
workers_cmds.extend(
|
||||
generate_worker_commands(model_path, config,
|
||||
config['generation_servers'],
|
||||
gen_extra_config_file, 'generation'))
|
||||
|
||||
server_start_timeout = 900
|
||||
server_cmd = [
|
||||
@ -296,101 +463,79 @@ def run_disaggregated_test(example_dir,
|
||||
server_url = get_disagg_server_url_from_cfg(config_file)
|
||||
|
||||
try:
|
||||
with ( # Start workers
|
||||
open('output_workers.log', 'w') as output_workers,
|
||||
popen(workers_cmd,
|
||||
stdout=output_workers,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=run_env,
|
||||
cwd=cwd) as workers_proc,
|
||||
# Start server
|
||||
open('output_disagg.log', 'w') as output_disagg,
|
||||
popen(server_cmd,
|
||||
stdout=output_disagg,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=run_env,
|
||||
cwd=cwd) as server_proc):
|
||||
client_dir = f"{example_dir}/clients"
|
||||
for _ in range(num_iters):
|
||||
client_cmd = [
|
||||
'python3', f'{client_dir}/disagg_client.py', '-c',
|
||||
f'{config_file}', '-p', f'{client_dir}/{prompt_file}',
|
||||
'--ignore-eos', '--server-start-timeout',
|
||||
str(server_start_timeout)
|
||||
]
|
||||
if prompt_file == "long_prompts.json":
|
||||
# Use max_tokens 4 for long prompts to reduce test time
|
||||
client_cmd.extend(['--max-tokens', '4'])
|
||||
check_call(client_cmd,
|
||||
env=env,
|
||||
poll_procs=[workers_proc, server_proc])
|
||||
if not use_ray:
|
||||
with ( # Start workers
|
||||
open('output_workers.log', 'w') as output_workers,
|
||||
popen(workers_cmd,
|
||||
stdout=output_workers,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=run_env,
|
||||
cwd=cwd) as workers_proc,
|
||||
# Start server
|
||||
open('output_disagg.log', 'w') as output_disagg,
|
||||
popen(server_cmd,
|
||||
stdout=output_disagg,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=run_env,
|
||||
cwd=cwd) as server_proc):
|
||||
run_client_tests(example_dir,
|
||||
config_file,
|
||||
test_desc,
|
||||
num_iters,
|
||||
env,
|
||||
server_start_timeout,
|
||||
prompt_file,
|
||||
extra_endpoints_test,
|
||||
server_url,
|
||||
workers_proc,
|
||||
server_proc,
|
||||
use_ray=False)
|
||||
|
||||
# Streaming client run
|
||||
streaming_client_cmd = client_cmd + [
|
||||
'--streaming', '-o', 'output_streaming.json'
|
||||
]
|
||||
check_call(streaming_client_cmd,
|
||||
env=env,
|
||||
poll_procs=[workers_proc, server_proc])
|
||||
else:
|
||||
workers_proc = []
|
||||
with contextlib.ExitStack() as stack:
|
||||
workers_log = stack.enter_context(
|
||||
open('output_workers.log', 'w'))
|
||||
|
||||
# Run the chat completion endpoint test only for TinyLlama
|
||||
if test_desc == "overlap" or test_desc == "trtllm_sampler":
|
||||
chat_client_cmd = client_cmd + [
|
||||
'-e', 'chat', '-o', 'output_chat.json'
|
||||
]
|
||||
check_call(chat_client_cmd,
|
||||
env=env,
|
||||
poll_procs=[workers_proc, server_proc])
|
||||
for cmd in workers_cmds:
|
||||
proc = stack.enter_context(
|
||||
popen(
|
||||
cmd,
|
||||
stdout=workers_log,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=run_env,
|
||||
cwd=cwd,
|
||||
))
|
||||
workers_proc.append(proc)
|
||||
|
||||
streaming_chat_client_cmd = chat_client_cmd + [
|
||||
'--streaming', '-o', 'output_streaming_chat.json'
|
||||
]
|
||||
check_call(streaming_chat_client_cmd,
|
||||
env=env,
|
||||
poll_procs=[workers_proc, server_proc])
|
||||
output_disagg = stack.enter_context(
|
||||
open('output_disagg.log', 'w'))
|
||||
server_proc = stack.enter_context(
|
||||
popen(server_cmd,
|
||||
stdout=output_disagg,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=run_env,
|
||||
cwd=cwd))
|
||||
|
||||
# Skip output verification for long prompts test
|
||||
if prompt_file == "long_prompts.json":
|
||||
continue
|
||||
if not wait_for_server("localhost",
|
||||
8000,
|
||||
timeout_seconds=server_start_timeout):
|
||||
raise RuntimeError(
|
||||
f"Disaggregated server failed to start within {server_start_timeout} seconds"
|
||||
)
|
||||
|
||||
if extra_endpoints_test is not None:
|
||||
extra_endpoints_test(server_url)
|
||||
|
||||
# Verify outputs
|
||||
not_expected_strings = ["Berlin Berlin"]
|
||||
|
||||
output_files = ['output.json', 'output_streaming.json']
|
||||
if test_desc == "overlap" or test_desc == "trtllm_sampler":
|
||||
# Disable streaming chat completion for overlap test
|
||||
# due to bug
|
||||
output_files.extend(['output_chat.json'])
|
||||
|
||||
if test_desc.startswith("gen_only"):
|
||||
continue
|
||||
|
||||
for output_file in output_files:
|
||||
with open(output_file, 'r') as f:
|
||||
content = f.read()
|
||||
if "deepseek_v3_lite" in test_desc or output_file == "output_chat.json":
|
||||
expected_strings = [
|
||||
"Berlin", ["Asyncio is a", "Asyncio module in"]
|
||||
]
|
||||
else:
|
||||
expected_strings = [
|
||||
"The capital of Germany is Berlin",
|
||||
"Asyncio is a Python library"
|
||||
]
|
||||
for expected_string in expected_strings:
|
||||
if isinstance(expected_string, list):
|
||||
# At least one of the strings in the list should be found in the content
|
||||
assert any(
|
||||
string in content
|
||||
for string in expected_string
|
||||
), f"None of the strings in {expected_string} found in {output_file}"
|
||||
else:
|
||||
assert expected_string in content, f"Expected string '{expected_string}' not found in {output_file}"
|
||||
for not_expected_string in not_expected_strings:
|
||||
assert not_expected_string not in content, f"Unexpected string '{not_expected_string}' found in {output_file}"
|
||||
run_client_tests(example_dir,
|
||||
config_file,
|
||||
test_desc,
|
||||
num_iters,
|
||||
env,
|
||||
server_start_timeout,
|
||||
prompt_file,
|
||||
extra_endpoints_test,
|
||||
server_url,
|
||||
workers_proc,
|
||||
server_proc,
|
||||
use_ray=True)
|
||||
except Exception:
|
||||
# Print outputs on error
|
||||
logger.error("-------- Workers output --------")
|
||||
@ -402,10 +547,16 @@ def run_disaggregated_test(example_dir,
|
||||
logger.error(f.read())
|
||||
raise
|
||||
finally:
|
||||
server_proc.terminate()
|
||||
workers_proc.terminate()
|
||||
server_proc.wait()
|
||||
workers_proc.wait()
|
||||
if use_ray:
|
||||
subprocess.run(['ray', 'stop', '--force'], check=False)
|
||||
for extra_file in extra_config_files:
|
||||
if os.path.exists(extra_file):
|
||||
os.remove(extra_file)
|
||||
elif 'server_proc' in locals() and 'workers_proc' in locals():
|
||||
server_proc.terminate()
|
||||
workers_proc.terminate()
|
||||
server_proc.wait()
|
||||
workers_proc.wait()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
|
||||
@ -831,7 +982,8 @@ def test_disaggregated_ctxpp2_genpp2(disaggregated_test_root, llm_venv,
|
||||
run_disaggregated_test(disaggregated_example_root,
|
||||
"ctxpp2_genpp2",
|
||||
env=llm_venv._new_env,
|
||||
cwd=llm_venv.get_working_directory())
|
||||
cwd=llm_venv.get_working_directory(),
|
||||
model_path=llama_model_root)
|
||||
|
||||
|
||||
@pytest.mark.skip_less_device(4)
|
||||
@ -851,7 +1003,8 @@ def test_disaggregated_ctxtp2_genpp2(disaggregated_test_root, llm_venv,
|
||||
run_disaggregated_test(disaggregated_example_root,
|
||||
"ctxtp2_genpp2",
|
||||
env=llm_venv._new_env,
|
||||
cwd=llm_venv.get_working_directory())
|
||||
cwd=llm_venv.get_working_directory(),
|
||||
model_path=llama_model_root)
|
||||
|
||||
|
||||
@pytest.mark.skip_less_device(4)
|
||||
@ -871,7 +1024,8 @@ def test_disaggregated_ctxpp2_gentp2(disaggregated_test_root, llm_venv,
|
||||
run_disaggregated_test(disaggregated_example_root,
|
||||
"ctxpp2_gentp2",
|
||||
env=llm_venv._new_env,
|
||||
cwd=llm_venv.get_working_directory())
|
||||
cwd=llm_venv.get_working_directory(),
|
||||
model_path=llama_model_root)
|
||||
|
||||
|
||||
@pytest.mark.skip_less_device(8)
|
||||
@ -932,7 +1086,8 @@ def test_disaggregated_ctxpp4_gentp4(disaggregated_test_root, llm_venv,
|
||||
run_disaggregated_test(disaggregated_example_root,
|
||||
"ctxpp4_gentp4",
|
||||
env=llm_venv._new_env,
|
||||
cwd=llm_venv.get_working_directory())
|
||||
cwd=llm_venv.get_working_directory(),
|
||||
model_path=llama_model_root)
|
||||
|
||||
|
||||
@skip_no_hopper
|
||||
@ -1021,7 +1176,8 @@ def test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp(
|
||||
run_disaggregated_test(disaggregated_example_root,
|
||||
"deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp",
|
||||
env=llm_venv._new_env,
|
||||
cwd=llm_venv.get_working_directory())
|
||||
cwd=llm_venv.get_working_directory(),
|
||||
model_path=deepseek_v3_model_root)
|
||||
|
||||
|
||||
@skip_no_hopper
|
||||
@ -1048,7 +1204,8 @@ def test_disaggregated_deepseek_v3_lite_fp8_ucx(disaggregated_test_root,
|
||||
run_disaggregated_test(disaggregated_example_root,
|
||||
"deepseek_v3_lite_fp8_ucx",
|
||||
env=env,
|
||||
cwd=llm_venv.get_working_directory())
|
||||
cwd=llm_venv.get_working_directory(),
|
||||
model_path=deepseek_v3_model_root)
|
||||
|
||||
|
||||
@skip_no_hopper
|
||||
@ -1074,7 +1231,8 @@ def test_disaggregated_deepseek_v3_lite_fp8_nixl(disaggregated_test_root,
|
||||
run_disaggregated_test(disaggregated_example_root,
|
||||
"deepseek_v3_lite_fp8_nixl",
|
||||
env=env,
|
||||
cwd=llm_venv.get_working_directory())
|
||||
cwd=llm_venv.get_working_directory(),
|
||||
model_path=deepseek_v3_model_root)
|
||||
|
||||
|
||||
@skip_no_hopper
|
||||
@ -1262,7 +1420,8 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp(
|
||||
disaggregated_example_root,
|
||||
"deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp",
|
||||
env=llm_venv._new_env,
|
||||
cwd=llm_venv.get_working_directory())
|
||||
cwd=llm_venv.get_working_directory(),
|
||||
model_path=deepseek_v3_model_root)
|
||||
|
||||
|
||||
@skip_no_hopper
|
||||
|
||||
102
tests/integration/defs/examples/test_ray.py
Normal file
@ -0,0 +1,102 @@
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
from defs.common import venv_check_call, wait_for_server
|
||||
from defs.conftest import get_device_count, llm_models_root
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ray_example_root(llm_root):
|
||||
example_root = os.path.join(llm_root, "examples", "ray_orchestrator")
|
||||
return example_root
|
||||
|
||||
|
||||
def test_llm_inference_async_ray(ray_example_root, llm_venv):
|
||||
script_path = os.path.join(ray_example_root, "llm_inference_async_ray.py")
|
||||
venv_check_call(llm_venv, [script_path])
|
||||
|
||||
|
||||
@pytest.mark.skip_less_device(2)
|
||||
@pytest.mark.parametrize("tp_size,pp_size,ep_size", [
|
||||
(2, 1, -1),
|
||||
(1, 2, -1),
|
||||
(2, 2, -1),
|
||||
(2, 1, 2),
|
||||
],
|
||||
ids=["tp2", "pp2", "tp2pp2", "tep2"])
|
||||
def test_llm_inference_distributed_ray(ray_example_root, llm_venv, tp_size,
|
||||
pp_size, ep_size):
|
||||
world_size = tp_size * pp_size
|
||||
|
||||
if get_device_count() < world_size:
|
||||
pytest.skip(f"Need {world_size} GPUs.")
|
||||
|
||||
script_path = os.path.join(ray_example_root,
|
||||
"llm_inference_distributed_ray.py")
|
||||
|
||||
cmd = [
|
||||
script_path, "--tp_size",
|
||||
str(tp_size), "--pp_size",
|
||||
str(pp_size), "--moe_ep_size",
|
||||
str(ep_size)
|
||||
]
|
||||
|
||||
if ep_size != -1:
|
||||
model_dir = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16"
|
||||
cmd.extend(["--model_dir", model_dir])
|
||||
|
||||
venv_check_call(llm_venv, cmd)
|
||||
|
||||
|
||||
@pytest.mark.skip_less_device(2)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2], ids=["tp1", "tp2"])
|
||||
def test_ray_disaggregated_serving(ray_example_root, llm_venv, tp_size):
|
||||
if get_device_count() < tp_size * 2:
|
||||
pytest.skip(f"Need {tp_size * 2} GPUs.")
|
||||
|
||||
disagg_dir = os.path.join(ray_example_root, "disaggregated")
|
||||
script_path = os.path.join(disagg_dir, "disagg_serving_local.sh")
|
||||
|
||||
subprocess.run("ray stop --force", shell=True, check=False)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
["bash", script_path, "--executor", "ray", "--tp_size",
|
||||
str(tp_size)],
|
||||
cwd=disagg_dir,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
try:
|
||||
assert wait_for_server("localhost", 8000, timeout_seconds=180), \
|
||||
"Disaggregated server failed to start within 3 minutes"
|
||||
|
||||
result = subprocess.run([
|
||||
"curl", "-sS", "-w", "\n%{http_code}",
|
||||
"http://localhost:8000/v1/completions", "-H",
|
||||
"Content-Type: application/json", "-d",
|
||||
'{"model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0","prompt":"NVIDIA is a great company because","max_tokens":16,"temperature":0}'
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30)
|
||||
|
||||
*body_lines, status_line = result.stdout.strip().splitlines()
|
||||
body = "\n".join(body_lines)
|
||||
status = int(status_line)
|
||||
|
||||
print("HTTP status:", status)
|
||||
print("Response body:", body)
|
||||
|
||||
assert result.returncode == 0, f"curl exit {result.returncode}"
|
||||
assert status == 200, f"Expected 200, got {status}"
|
||||
|
||||
finally:
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=10)
|
||||
except Exception:
|
||||
proc.kill()
|
||||
|
||||
subprocess.run("ray stop --force", shell=True, check=False)
|
||||
subprocess.run("pkill -9 -f trtllm-serve", shell=True, check=False)
|
||||
@ -109,6 +109,10 @@ if is_linux():
|
||||
torch_inductors.append(pid)
|
||||
continue
|
||||
|
||||
# Improve readability for ray processes, which tend to have many empty args.
|
||||
if cmdline and cmdline[0].startswith('ray::'):
|
||||
cmdline = [arg for arg in cmdline if arg]
|
||||
|
||||
lines.append(f"{pid}: {cmdline}")
|
||||
persist_pids.append(pid)
|
||||
except psutil.Error:
|
||||
|
||||
@ -13,6 +13,7 @@ l0_dgx_b200:
|
||||
terms:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
- unittest/_torch/multi_gpu_modeling -k "deepseek"
|
||||
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
|
||||
@ -53,6 +54,29 @@ l0_dgx_b200:
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
|
||||
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 4
|
||||
lte: 4
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*b200*'
|
||||
linux_distribution_name: ubuntu*
|
||||
terms:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
orchestrator: ray
|
||||
tests:
|
||||
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu4"
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
|
||||
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
|
||||
- disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
|
||||
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0]
|
||||
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_gentp4[TinyLlama-1.1B-Chat-v1.0]
|
||||
- examples/test_ray.py::test_llm_inference_distributed_ray[tp2pp2]
|
||||
- examples/test_ray.py::test_ray_disaggregated_serving[tp2]
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
@ -66,6 +90,7 @@ l0_dgx_b200:
|
||||
terms:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (180)
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (180)
|
||||
@ -86,6 +111,7 @@ l0_dgx_b200:
|
||||
terms:
|
||||
stage: post_merge
|
||||
backend: pytorch
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (180)
|
||||
- condition:
|
||||
@ -101,6 +127,7 @@ l0_dgx_b200:
|
||||
terms:
|
||||
stage: post_merge
|
||||
backend: pytorch
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
|
||||
- unittest/_torch/multi_gpu_modeling/test_llama3.py::test_llama_3_3
|
||||
|
||||
@ -13,6 +13,7 @@ l0_dgx_h100:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
auto_trigger: others
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu2"
|
||||
- unittest/_torch/multi_gpu -m "not post_merge" TIMEOUT (90)
|
||||
@ -54,6 +55,7 @@ l0_dgx_h100:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
auto_trigger: others
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
# ------------- PyTorch tests ---------------
|
||||
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu4"
|
||||
@ -99,6 +101,7 @@ l0_dgx_h100:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
auto_trigger: deepseek
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
- unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp1-bf16-trtllm-deepseekv3_lite]
|
||||
- unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp4-bf16-trtllm-deepseekv3_lite]
|
||||
@ -156,6 +159,7 @@ l0_dgx_h100:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
auto_trigger: gpt_oss
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto]
|
||||
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton-auto]
|
||||
@ -177,6 +181,7 @@ l0_dgx_h100:
|
||||
stage: pre_merge
|
||||
backend: cpp
|
||||
auto_trigger: others
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
# ------------- CPP tests ---------------
|
||||
- cpp/test_multi_gpu.py::test_mpi_utils[90]
|
||||
@ -218,3 +223,24 @@ l0_dgx_h100:
|
||||
- cpp/test_multi_gpu.py::TestDisagg::test_spawn_orchestrator[llama-ucx_kvcache-90]
|
||||
- cpp/test_multi_gpu.py::TestDisagg::test_orchestrator_params[llama-nixl_kvcache-90] TIMEOUT (90)
|
||||
- cpp/test_multi_gpu.py::TestDisagg::test_spawn_orchestrator[llama-nixl_kvcache-90]
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 2
|
||||
lte: 2
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*h100*'
|
||||
linux_distribution_name: ubuntu*
|
||||
terms:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
orchestrator: ray
|
||||
tests:
|
||||
- unittest/_torch/ray_orchestrator/multi_gpu -m "gpu2"
|
||||
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu2"
|
||||
- accuracy/test_llm_api_pytorch_ray.py::TestLlama3_1_8BInstruct::test_pp2_ray
|
||||
- examples/test_ray.py::test_llm_inference_distributed_ray[tp2]
|
||||
- examples/test_ray.py::test_llm_inference_distributed_ray[pp2]
|
||||
- examples/test_ray.py::test_llm_inference_distributed_ray[tep2]
|
||||
- examples/test_ray.py::test_ray_disaggregated_serving[tp1]
|
||||
|
||||
@ -12,6 +12,7 @@ l0_h100:
|
||||
terms:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
# ------------- PyTorch tests ---------------
|
||||
- unittest/_torch/attention
|
||||
@ -115,6 +116,32 @@ l0_h100:
|
||||
# ------------- AutoDeploy tests ---------------
|
||||
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
|
||||
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 1
|
||||
lte: 1
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*h100*'
|
||||
linux_distribution_name: ubuntu*
|
||||
terms:
|
||||
stage: pre_merge
|
||||
backend: pytorch
|
||||
orchestrator: ray
|
||||
tests:
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
|
||||
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
|
||||
- disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_llama[False-False-TinyLlama-1.1B-Chat-v1.0]
|
||||
- disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_deepseek[False-False-DeepSeek-V3-Lite-fp8/fp8]
|
||||
- disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-False-Qwen3-8B-FP8]
|
||||
- unittest/_torch/executor
|
||||
- unittest/_torch/ray_orchestrator/single_gpu
|
||||
- unittest/llmapi/test_llm_pytorch.py
|
||||
- examples/test_ray.py::test_llm_inference_async_ray
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
@ -190,6 +217,7 @@ l0_h100:
|
||||
terms:
|
||||
stage: post_merge
|
||||
backend: pytorch
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
# ------------- PyTorch tests ---------------
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=FLASHINFER-torch_compile=False]
|
||||
|
||||
@ -43,6 +43,7 @@ def create_llm(model_dir, disable_overlap_scheduler, sampler_type):
|
||||
|
||||
@pytest.mark.parametrize("sampler_type", ["TorchSampler", "TRTLLMSampler"])
|
||||
@pytest.mark.high_cuda_memory
|
||||
@pytest.mark.mpi_ray_parity
|
||||
def test_overlap_scheduler_consistency(model_path, test_case, sampler_type):
|
||||
# Test configuration
|
||||
prompts = test_case["prompts"]
|
||||
|
||||
@ -16,6 +16,8 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
|
||||
ResourceManagerType
|
||||
)
|
||||
# isort: on
|
||||
from utils.util import skip_ray
|
||||
|
||||
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
|
||||
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
@ -322,6 +324,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
|
||||
self.assertTrue(pytorch_config_custom.cuda_graph_padding_enabled,
|
||||
"Custom enable_padding should be respected")
|
||||
|
||||
@skip_ray
|
||||
def test_prepare_tp_inputs_with_helix_parallelism(self) -> None:
|
||||
"""Test _prepare_tp_inputs function with helix parallelism."""
|
||||
|
||||
|
||||
23
tests/unittest/_torch/ray_orchestrator/conftest.py
Normal file
@ -0,0 +1,23 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
if config.getoption("--run-ray"):
|
||||
os.environ["TLLM_DISABLE_MPI"] = "1"
|
||||
os.environ["TLLM_RAY_FORCE_LOCAL_CLUSTER"] = "1"
|
||||
|
||||
|
||||
run_ray_flag = "--run-ray" in sys.argv
|
||||
if run_ray_flag:
|
||||
os.environ["TLLM_DISABLE_MPI"] = "1"
|
||||
os.environ["TLLM_RAY_FORCE_LOCAL_CLUSTER"] = "1"
|
||||
|
||||
if not mpi_disabled():
|
||||
pytest.skip(
|
||||
"Ray tests are only tested in Ray CI stage or with --run-ray flag",
|
||||
allow_module_level=True)
|
||||
147
tests/unittest/_torch/ray_orchestrator/multi_gpu/test_mapping.py
Normal file
@ -0,0 +1,147 @@
import copy
import os
import unittest

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.mapping import Mapping


class TestMapping(unittest.TestCase):

    @pytest.mark.gpu2
    def test_device_mesh_parity(self):
        """To ensure parity between Ray and MPI Mapping instance."""
        # (tp, pp, cp, moe_tp, moe_ep)
        combos = [
            # no cp
            (2, 1, 1, -1, -1),  # -1 means no MoE in Mapping
            # 8 GPUs, no cp
            (4, 2, 1, -1, -1),
            (2, 4, 1, -1, -1),
            # 8 GPUs with cp
            (4, 1, 2, -1, -1),
            (2, 1, 4, -1, -1),
            # with moe_tp, moe_ep
            (8, 1, 1, 2, 4),
            (2, 1, 1, 1, 2)
        ]

        num_gpus = torch.cuda.device_count()

        for tp, pp, cp, moe_tp, moe_ep in combos:
            world_size = tp * pp * cp
            print(
                f"\n\n=== TP={tp}, PP={pp}, CP={cp}, MOE_TP={moe_tp}, MOE_EP={moe_ep} ==="
            )

            if world_size > num_gpus:
                print(
                    f"SKIPPING: need {world_size} GPUs. Only have {num_gpus}.")
                continue

            mp.spawn(
                self._worker,
                args=(world_size, get_free_port(), tp, pp, cp, moe_tp, moe_ep),
                nprocs=world_size,
                join=True,
            )

    @pytest.mark.gpu2
    def test_picklable(self):
        world_size = tp = 2

        mp.spawn(
            self._worker,
            args=(world_size, get_free_port(), tp, 1, 1, -1, -1, "pickle"),
            nprocs=world_size,
            join=True,
        )

    @staticmethod
    def _worker(rank: int,
                world_size: int,
                master_port: int,
                tp=1,
                pp=1,
                cp=1,
                moe_tp=1,
                moe_ep=1,
                test_type="parity") -> None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank)
        torch.cuda.set_device(rank)

        if test_type == "parity":
            if "TLLM_DISABLE_MPI" in os.environ:
                del os.environ["TLLM_DISABLE_MPI"]

            mapping_mpi = Mapping(
                world_size=world_size,
                rank=rank,
                gpus_per_node=world_size,
                tp_size=tp,
                pp_size=pp,
                cp_size=cp,
                moe_tp_size=moe_tp,
                moe_ep_size=moe_ep,
            )

        os.environ["TLLM_DISABLE_MPI"] = "1"
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

        mapping_device_mesh = Mapping(
            world_size=world_size,
            rank=rank,
            gpus_per_node=world_size,
            tp_size=tp,
            pp_size=pp,
            cp_size=cp,
            moe_tp_size=moe_tp,
            moe_ep_size=moe_ep,
        )

        if test_type == "parity":
            mapping_device_mesh.build_mesh()

            properties = []
            for dim in mapping_device_mesh.device_mesh.mesh_dim_names:
                properties.append(f"{dim}_rank")
                properties.append(f"{dim}_group")

            for prop in properties:
                mpi_value = getattr(mapping_mpi, prop)
                device_mesh_value = getattr(mapping_device_mesh, prop)

                if rank == 0:
                    print(
                        f" {prop}: MPI={mpi_value}, DeviceMesh={device_mesh_value}"
                    )

                assert mpi_value == device_mesh_value, \
                    f"Property {prop} mismatch: MPI={mpi_value}, DeviceMesh={device_mesh_value} (rank {rank})"
        elif test_type == "pickle":
            mapping = mapping_device_mesh

            tp_group = mapping.tp_group
            print(f"tp_group: {tp_group}")
            assert DeviceMeshTopologyImpl.device_mesh is not None
            mapping_copy = copy.deepcopy(mapping)

            # check static mesh still exists
            assert mapping_copy.device_mesh is not None
            print(f"tp_group after deepcopy: {mapping.tp_group}")
            assert mapping.tp_group == mapping_copy.tp_group

        else:
            raise ValueError(f"Invalid test type: {test_type}")

        dist.destroy_process_group()
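The parity test above compares topology-derived properties rank by rank, using the `{dim}_rank` / `{dim}_group` naming pattern it iterates over. As a rough illustration (a sketch, not part of this diff; it assumes the classic MPI-style Mapping computes these values arithmetically without any process group):

# Sketch: rank-local topology math that the parity assertions compare.
from tensorrt_llm.mapping import Mapping

m = Mapping(world_size=8, rank=5, gpus_per_node=8, tp_size=4, pp_size=2)
print(m.tp_rank, m.pp_rank)   # position of rank 5 inside its TP / PP groups
print(m.tp_group)             # global ranks that form this rank's TP group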
381
tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops_ray.py
Normal file
@ -0,0 +1,381 @@
import os

import pytest
import torch

try:
    import ray
except ModuleNotFoundError:
    from tensorrt_llm import ray_stub as ray

from tensorrt_llm._torch.distributed.communicator import TorchDist
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
from tensorrt_llm.mapping import Mapping


@ray.remote(num_gpus=1)
class AllgatherPGTest:

    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.master_address = os.environ["MASTER_ADDR"]
        self.master_port = os.environ["MASTER_PORT"]
        assert len(ray.get_gpu_ids()) == 1

        torch.distributed.init_process_group(
            backend="cuda:nccl,cpu:gloo",
            init_method=f"tcp://{self.master_address}:{self.master_port}",
            world_size=world_size,
            rank=rank)

    @torch.inference_mode()
    def run_allgather_pg_op(self, test_tensor, expected_result, sizes):
        torch.cuda.set_device(0)
        test_tensor = test_tensor.cuda(0)
        expected_result = expected_result.cuda(0)

        mapping = Mapping(world_size=self.world_size,
                          gpus_per_node=self.world_size,
                          tp_size=self.world_size,
                          rank=self.rank)
        if torch.distributed.is_initialized():
            TorchDist(mapping)
        else:
            raise RuntimeError("torch.distributed is not initialized")

        module = torch
        op_path = ['ops', 'trtllm', 'allgather_pg']
        for attr in op_path:
            module = getattr(module, attr)
        allgather_pg_op = module
        output = allgather_pg_op(test_tensor, sizes, mapping.tp_group,
                                 mapping.tp_group_pg.boxed())

        if isinstance(output, (list, tuple)):
            result_tensor = output[0]
        else:
            result_tensor = output

        rtol, atol = 0.05, 0.15
        torch.testing.assert_close(result_tensor,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        return True


@ray.remote(num_gpus=1)
class ReducescatterPGTest:

    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.master_address = os.environ["MASTER_ADDR"]
        self.master_port = os.environ["MASTER_PORT"]
        assert len(ray.get_gpu_ids()) == 1

        torch.distributed.init_process_group(
            backend="cuda:nccl,cpu:gloo",
            init_method=f"tcp://{self.master_address}:{self.master_port}",
            world_size=world_size,
            rank=rank)

    @torch.inference_mode()
    def run_reducescatter_pg_op(self, test_tensor, expected_result, sizes):
        torch.cuda.set_device(0)
        test_tensor = test_tensor.cuda(0)
        expected_result = expected_result.cuda(0)

        mapping = Mapping(world_size=self.world_size,
                          gpus_per_node=self.world_size,
                          tp_size=self.world_size,
                          rank=self.rank)
        if torch.distributed.is_initialized():
            TorchDist(mapping)
        else:
            raise RuntimeError("torch.distributed is not initialized")

        module = torch
        op_path = ['ops', 'trtllm', 'reducescatter_pg']
        for attr in op_path:
            module = getattr(module, attr)
        reducescatter_pg_op = module
        output = reducescatter_pg_op(test_tensor, sizes, mapping.tp_group,
                                     mapping.tp_group_pg.boxed())

        if isinstance(output, (list, tuple)):
            result_tensor = output[0]
        else:
            result_tensor = output

        rtol, atol = 0.05, 0.15
        torch.testing.assert_close(result_tensor,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        return True


@ray.remote(num_gpus=1)
class AllreducePGTest:

    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.master_address = os.environ["MASTER_ADDR"]
        self.master_port = os.environ["MASTER_PORT"]
        assert len(ray.get_gpu_ids()) == 1

        torch.distributed.init_process_group(
            backend="cuda:nccl,cpu:gloo",
            init_method=f"tcp://{self.master_address}:{self.master_port}",
            world_size=world_size,
            rank=rank)

    @torch.inference_mode()
    def run_allreduce_pg_op(self, test_tensor, expected_result):
        torch.cuda.set_device(0)
        test_tensor = test_tensor.cuda(0)
        expected_result = expected_result.cuda(0)

        mapping = Mapping(world_size=self.world_size,
                          gpus_per_node=self.world_size,
                          tp_size=self.world_size,
                          rank=self.rank)
        if torch.distributed.is_initialized():
            TorchDist(mapping)
        else:
            raise RuntimeError("torch.distributed is not initialized")

        module = torch
        op_path = ['ops', 'trtllm', 'allreduce_pg']
        for attr in op_path:
            module = getattr(module, attr)
        allreduce_pg_op = module
        output = allreduce_pg_op(
            input=test_tensor,
            residual=None,
            norm_weight=None,
            scale=None,
            bias=None,
            workspace=None,
            group=mapping.tp_group,
            strategy=AllReduceStrategy.NCCL,
            op=AllReduceFusionOp.NONE,  # Pure allreduce, no fusion
            eps=1e-5,
            trigger_completion_at_end=True,
            rank=self.rank,
            pg=mapping.tp_group_pg.boxed())

        if isinstance(output, (list, tuple)):
            result_tensor = output[0]
        else:
            result_tensor = output

        rtol, atol = 0.05, 0.15
        torch.testing.assert_close(result_tensor,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        return True


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("hidden_size", [128, 1024],
                         ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 64], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("var_len", [True, False], ids=lambda x: f"var_len:{x}")
def test_allgather_pg_op(seq_len, hidden_size, var_len):
    torch.manual_seed(42)
    dtype = torch.bfloat16
    world_size = 2
    if var_len:
        test_tensor_list = [
            torch.randn((seq_len * (i + 1), hidden_size), dtype=dtype)
            for i in range(world_size)
        ]
        expected_result = torch.cat(test_tensor_list, dim=0)
        sizes = [seq_len * (i + 1) for i in range(world_size)]
    else:
        test_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
        expected_result = test_tensor.repeat(world_size, 1)
        sizes = None
    ray_init_args = {
        "include_dashboard": False,
        "namespace": "test_allgather_pg_op",
        "ignore_reinit_error": True
    }

    try:
        ray.init(address="local", **ray_init_args)

        master_port = get_free_port()
        runtime_env = ray.runtime_env.RuntimeEnv()
        runtime_env["env_vars"] = os.environ.copy()
        runtime_env["env_vars"].update({
            "TLLM_DISABLE_MPI": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(master_port)
        })

        remotePGTests = []
        for rank in range(world_size):
            remotePGTests.append(
                AllgatherPGTest.options(runtime_env=runtime_env).remote(
                    rank, world_size))

        if var_len:
            results = ray.get([
                remotePGTest.run_allgather_pg_op.remote(test_tensor,
                                                        expected_result, sizes)
                for remotePGTest, test_tensor in zip(remotePGTests,
                                                     test_tensor_list)
            ])
        else:
            results = ray.get([
                remotePGTest.run_allgather_pg_op.remote(test_tensor,
                                                        expected_result, sizes)
                for remotePGTest in remotePGTests
            ])
    except Exception as e:
        if ray.is_initialized():
            ray.shutdown()
        raise e

    ray.shutdown()
    for r in results:
        assert r is True


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("hidden_size", [128, 1024],
                         ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 64], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("var_len", [True, False], ids=lambda x: f"var_len:{x}")
def test_reducescatter_pg_op(seq_len, hidden_size, var_len):
    torch.manual_seed(42)
    dtype = torch.bfloat16
    world_size = 2
    if var_len:
        total_seq_len = sum([seq_len * (i + 1) for i in range(world_size)])
        test_tensor = torch.randn((total_seq_len, hidden_size), dtype=dtype)
        expected_result_list = []
        offset = 0
        for i in range(world_size):
            expected_result_list.append(test_tensor[offset:offset + seq_len *
                                                    (i + 1)] * world_size)
            offset += seq_len * (i + 1)
        sizes = [seq_len * (i + 1) for i in range(world_size)]
    else:
        test_tensor = torch.randn((seq_len * world_size, hidden_size),
                                  dtype=dtype)
        expected_result_list = [
            test_tensor[i * seq_len:(i + 1) * seq_len] * world_size
            for i in range(world_size)
        ]
        sizes = None
    ray_init_args = {
        "include_dashboard": False,
        "namespace": "test_reducescatter_pg_op",
        "ignore_reinit_error": True
    }

    try:
        ray.init(address="local", **ray_init_args)

        master_port = get_free_port()
        runtime_env = ray.runtime_env.RuntimeEnv()
        runtime_env["env_vars"] = os.environ.copy()
        runtime_env["env_vars"].update({
            "TLLM_DISABLE_MPI": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(master_port)
        })

        remotePGTests = []
        for rank in range(world_size):
            remotePGTests.append(
                ReducescatterPGTest.options(runtime_env=runtime_env).remote(
                    rank, world_size))

        if var_len:
            results = ray.get([
                remotePGTest.run_reducescatter_pg_op.remote(
                    test_tensor, expected_result, sizes)
                for remotePGTest, expected_result in zip(
                    remotePGTests, expected_result_list)
            ])
        else:
            results = ray.get([
                remotePGTest.run_reducescatter_pg_op.remote(
                    test_tensor, expected_result, sizes)
                for remotePGTest, expected_result in zip(
                    remotePGTests, expected_result_list)
            ])
    except Exception as e:
        if ray.is_initialized():
            ray.shutdown()
        raise e

    ray.shutdown()
    for r in results:
        assert r is True


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("hidden_size", [128, 1024],
                         ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 64], ids=lambda x: f"seqlen:{x}")
def test_allreduce_pg_op(seq_len, hidden_size):
    torch.manual_seed(42)
    dtype = torch.bfloat16
    world_size = 2
    test_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
    expected_result = test_tensor * world_size

    ray_init_args = {
        "include_dashboard": False,
        "namespace": "test_allreduce_pg_op",
        "ignore_reinit_error": True
    }

    try:
        ray.init(address="local", **ray_init_args)

        master_port = get_free_port()
        runtime_env = ray.runtime_env.RuntimeEnv()
        runtime_env["env_vars"] = os.environ.copy()
        runtime_env["env_vars"].update({
            "TLLM_DISABLE_MPI": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(master_port)
        })

        remotePGTests = []
        for rank in range(world_size):
            remotePGTests.append(
                AllreducePGTest.options(runtime_env=runtime_env).remote(
                    rank, world_size))

        results = ray.get([
            remotePGTest.run_allreduce_pg_op.remote(test_tensor,
                                                    expected_result)
            for remotePGTest in remotePGTests
        ])
    except Exception as e:
        if ray.is_initialized():
            ray.shutdown()
        raise e

    ray.shutdown()

    for r in results:
        assert r is True
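The three op tests above share one pattern: spawn one Ray actor per GPU, bootstrap a torch.distributed process group inside each actor via MASTER_ADDR/MASTER_PORT passed through the runtime_env, then call a process-group-backed collective. A condensed sketch of that pattern (not the trtllm custom ops themselves, just a plain torch.distributed all_reduce; assumes ray is installed and two GPUs are visible):

# Sketch of the per-GPU Ray actor pattern used by the tests above.
import os

import ray
import torch
import torch.distributed as dist

from tensorrt_llm._utils import get_free_port


@ray.remote(num_gpus=1)
class Worker:

    def __init__(self, rank, world_size):
        # MASTER_ADDR / MASTER_PORT arrive via the actor's runtime_env.
        dist.init_process_group(
            backend="cuda:nccl,cpu:gloo",
            init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
            world_size=world_size,
            rank=rank)

    def allreduce_ones(self):
        t = torch.ones(4, device="cuda")
        dist.all_reduce(t)  # sums across both workers -> tensor of twos
        return t.cpu()


if __name__ == "__main__":
    ray.init(address="local", include_dashboard=False)
    env = {"env_vars": {"MASTER_ADDR": "127.0.0.1",
                        "MASTER_PORT": str(get_free_port())}}
    workers = [Worker.options(runtime_env=env).remote(r, 2) for r in range(2)]
    print(ray.get([w.allreduce_ones.remote() for w in workers]))
    ray.shutdown()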
@ -0,0 +1,21 @@
import os

import pytest

from tensorrt_llm import LLM
from tensorrt_llm._torch.utils import get_device_uuid


@pytest.mark.gpu2
def test_cuda_visible_device():
    """Placement via cuda_visible_device"""
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              orchestrator_type="ray")

    infer_actor_uuids = llm._collective_rpc("report_device_id")

    del os.environ["CUDA_VISIBLE_DEVICES"]
    assert infer_actor_uuids[0] == get_device_uuid(1)
    print(f"{infer_actor_uuids=}")
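For reference, a minimal usage sketch of the new orchestrator_type argument exercised above (a sketch, not part of this diff; model path and prompt are placeholders, and sampling follows the standard LLM API):

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", orchestrator_type="ray")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)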
@ -0,0 +1,87 @@
import importlib
import os
import unittest

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from tensorrt_llm._utils import get_free_port, torch_pybind11_abi


class TestCacheTransceiverComm(unittest.TestCase):

    def test_cache_transceiver_comm(self):
        mp.spawn(
            self._worker,
            args=(4, get_free_port()),
            nprocs=4,
            join=True,
        )

    @staticmethod
    def _worker(rank: int, world_size: int, master_port: int):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank)

        try:
            dist.init_process_group(
                backend="gloo",
                rank=rank,
                world_size=world_size,
            )
            world_pg = torch.distributed.group.WORLD
            bm = importlib.import_module(
                "tensorrt_llm.bindings.internal.batch_manager")
            cacheComm = getattr(bm, "CacheTransceiverComm")
            comm = cacheComm(world_pg, torch_pybind11_abi())

            # Test split
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            color = rank // 2
            key = rank % 2
            sub = comm.split(color, key)

            expected_group_size = 2
            assert sub.get_size() == expected_group_size

            # Test allgather
            ok, gathered_ranks = sub.allgather(rank)
            assert ok is True
            expected_world_ranks = [
                r for r in range(world_size) if (r // 2) == color
            ]
            assert gathered_ranks == expected_world_ranks

            # Test allgatherv
            local_len = rank + 1
            payload = [rank] * local_len

            ok_sizes, sizes64 = sub.allgather(local_len)
            assert ok_sizes is True
            sizes = [int(x) for x in sizes64]

            ok_v, out = sub.allgatherv(payload, sizes)
            assert ok_v is True

            expected_concat = []
            for r in expected_world_ranks:
                expected_concat.extend([r] * (r + 1))
            assert out == expected_concat

            # Test allgatherv with char
            char_payload = [chr(65 + rank)] * local_len
            ok_char, char_out = sub.allgatherv(char_payload, sizes)
            assert ok_char is True

            expected_char_concat = []
            for r in expected_world_ranks:
                expected_char_concat.extend([chr(65 + r)] * (r + 1))
            assert char_out == expected_char_concat
        finally:
            if dist.is_initialized():
                dist.destroy_process_group()
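The split(color, key) call above follows the familiar MPI_Comm_split semantics: ranks passing the same color land in the same sub-communicator, ordered by key. A tiny sketch (plain Python, not part of the diff) of the grouping the assertions expect for world_size=4:

world_size = 4
groups = {}
for rank in range(world_size):
    color, key = rank // 2, rank % 2
    groups.setdefault(color, []).append((key, rank))
print({c: [r for _, r in sorted(members)] for c, members in groups.items()})
# -> {0: [0, 1], 1: [2, 3]}, matching expected_world_ranks in the test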
@ -64,6 +64,7 @@ def test_register_fake(custom_ops):
        "trtllm::mtp_prepare_drafter_inputs_op",
        "trtllm::selective_scan",
        "trtllm::reducescatter_list",
        "trtllm::reducescatter_list_pg",
        "trtllm::fp8_per_tensor_scale_moe_runner",
        "trtllm::migrate_to_host_accessible",
        "trtllm::mnnvl_moe_alltoallv_prepare_without_allgather",

@ -75,6 +75,10 @@ methods:
        annotation: Optional[str]
        default: null
        status: deprecated
      orchestrator_type:
        annotation: Optional[Literal['ray']]
        default: null
        status: prototype
      build_config:
        annotation: Optional[tensorrt_llm.llmapi.llm_args.BuildConfig]
        default: null

@ -79,6 +79,12 @@ def pytest_addoption(parser):
        help=
        "Prepend a prefix to the test names. Useful for distinguishing different test runs in a test report."
    )
    parser.addoption(
        "--run-ray",
        action="store_true",
        default=False,
        help="Run Ray-marked tests (by default they are skipped).",
    )


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
@ -94,8 +100,23 @@ def pytest_collection_modifyitems(session, config, items):
    for item in items:
        item._nodeid = f"{test_prefix}/{item._nodeid}"

    # Ray tests are disabled by default
    run_ray = config.getoption("--run-ray") or os.environ.get(
        "TLLM_RUN_RAY_TESTS") == "1"
    if not run_ray:
        skip_marker = pytest.mark.skip(
            reason=
            "Ray tests skipped; pass --run-ray or set TLLM_RUN_RAY_TESTS=1")
        for item in items:
            if "ray" in item.keywords:
                item.add_marker(skip_marker)


def pytest_sessionstart(session):
    if session.config.getoption("--run-ray"):
        os.environ["TLLM_DISABLE_MPI"] = "1"
        os.environ["TLLM_RAY_FORCE_LOCAL_CLUSTER"] = "1"

    # To counter TransformerEngine v2.3's lazy_compile deferral,
    # which would otherwise make pytest think there is a thread leakage.
    import torch._inductor.async_compile  # noqa: F401
@ -147,3 +168,70 @@ def mpi_pool_executor(request):
    # make the number of workers visible to tests
    setattr(executor, "num_workers", num_workers)
    yield executor


def pytest_generate_tests(metafunc: pytest.Metafunc):
    if metafunc.definition.get_closest_marker('mpi_ray_parity'):
        run_ray = metafunc.config.getoption("--run-ray") or os.environ.get(
            "TLLM_RUN_RAY_TESTS") == "1"
        if run_ray:
            metafunc.parametrize(
                'ray_mode',
                [
                    pytest.param('ray', id='ray', marks=pytest.mark.ray),
                ],
                indirect=True,
            )


@pytest.fixture
def ray_mode(request):
    return getattr(request, 'param', 'mpi')


@pytest.fixture(autouse=True)
def _maybe_force_ray(request, monkeypatch, ray_mode):
    """
    Patch the LLM class (torch only) to use Ray executor.
    """
    if 'mpi_ray_parity' not in request.node.keywords or ray_mode != 'ray':
        return

    def wrap_llm(cls):

        class LLMProxy(cls):

            def __init__(self, *args, **kwargs):
                kwargs["orchestrator_type"] = "ray"
                super().__init__(*args, **kwargs)

        return LLMProxy

    test_mod = request.node.module

    # Only patch the torch LLM class
    if hasattr(test_mod, 'LLM'):
        try:
            from tensorrt_llm._tensorrt_engine import LLM as LLM_legacy
            is_trtllm_backend = (test_mod.LLM is LLM_legacy)
        except Exception:
            is_trtllm_backend = False
        if not is_trtllm_backend:
            monkeypatch.setattr(test_mod,
                                'LLM',
                                wrap_llm(test_mod.LLM),
                                raising=False)
    if hasattr(test_mod, 'LLM_torch'):
        monkeypatch.setattr(test_mod,
                            'LLM_torch',
                            wrap_llm(test_mod.LLM_torch),
                            raising=False)

    try:
        import tensorrt_llm.llmapi.llm as llm_mod
        monkeypatch.setattr(llm_mod,
                            'LLM',
                            wrap_llm(llm_mod.LLM),
                            raising=False)
    except Exception:
        pass

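Taken together, the hooks above mean a test opted in via the new marker runs in the default (MPI) mode unless --run-ray or TLLM_RUN_RAY_TESTS=1 is set, in which case it is parametrized with a Ray variant whose LLM construction is transparently forced to orchestrator_type="ray" by the autouse fixture. A hedged usage sketch (test body and model path are placeholders):

import pytest

from tensorrt_llm import LLM


@pytest.mark.mpi_ray_parity
def test_generate_smoke():
    # Under --run-ray the fixture above swaps LLM for a proxy that injects
    # orchestrator_type="ray"; otherwise this runs with the MPI orchestrator.
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    outputs = llm.generate(["Hello"])
    assert outputs[0].outputs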
@ -550,6 +550,7 @@ def _test_llm_generate_async(model_name=default_model_name,


@pytest.mark.parametrize("chunked", [True, False])
@pytest.mark.part0
@pytest.mark.mpi_ray_parity
def test_llm_generate_async_with_stream_interval(chunked):
    model_path = get_model_path('llama-models-v2/llama-v2-7b-hf')
    max_num_tokens = 256

@ -27,7 +27,7 @@ from .test_llm import (_test_llm_capture_request_error, get_model_path,
                       tinyllama_logits_processor_test_harness)
from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
                        skip_gpu_memory_less_than_80gb,
                        skip_gpu_memory_less_than_138gb)
                        skip_gpu_memory_less_than_138gb, skip_ray)
from utils.llm_data import llm_models_root
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.executor.request import LoRARequest
@ -50,6 +50,7 @@ def test_tinyllama_logits_processor(enable_chunked_prefill):
        backend="pytorch", enable_chunked_prefill=enable_chunked_prefill)


@skip_ray
@pytest.mark.parametrize(
    "return_context_logits, use_overlap, enable_iter_req_stats", [
        (False, False, False),
@ -66,6 +67,7 @@ def test_llm_get_stats(return_context_logits, use_overlap,
                          enable_iter_req_stats=enable_iter_req_stats)


@skip_ray
@pytest.mark.parametrize(
    "return_context_logits, use_overlap, enable_iter_req_stats", [
        (False, False, False),
@ -88,6 +90,7 @@ def test_llm_capture_request_error():


@force_ampere
@pytest.mark.mpi_ray_parity
@pytest.mark.parametrize(
    "sampling_params",
    [
@ -175,6 +178,7 @@ def test_llm_reward_model():
    assert not outputs[0].outputs[0].text


@skip_ray
def test_llm_perf_metrics():
    llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config)
    sampling_params = SamplingParams(max_tokens=10, return_perf_metrics=True)
@ -200,6 +204,7 @@ def test_llm_perf_metrics():
    assert perf_metrics.last_iter == perf_metrics.iter


@skip_ray
def test_llm_prometheus():
    test_prompts = [
        "Hello, my name is",
@ -221,6 +226,7 @@ def test_llm_prometheus():
            assert request_output.outputs is not None


@skip_ray
@pytest.mark.parametrize("streaming", [True, False])
def test_llm_with_postprocess_parallel_and_result_handler(streaming):
    run_llm_with_postprocess_parallel_and_result_handler(streaming,
@ -404,6 +410,7 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca
                                               repeats_per_call=1)


@skip_ray
@skip_gpu_memory_less_than_40gb
def test_llama_7b_peft_cache_config_affects_peft_cache_size():
    """Tests that LLM arg of peft_cache_config affects the peft cache sizes.
@ -832,6 +839,7 @@ FailingExecutor = type(
    })


@skip_ray
def test_llm_with_proxy_error():
    """Test that LLM properly handles GenerationExecutorWorker constructor failures.

@ -885,6 +893,7 @@ def test_min_tokens(use_speculative: bool):
        assert len(res.outputs[0].token_ids) == output_len


@skip_ray
@pytest.mark.parametrize(
    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",
    [
@ -907,6 +916,7 @@ def test_llm_return_logprobs(prompt_logprobs: Optional[int],
                             backend=backend)


@skip_ray
@pytest.mark.parametrize(
    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits",
    [

@ -22,3 +22,5 @@ markers =
    post_merge: this test should only run in post merge
    high_cuda_memory: this test uses a lot of CUDA memory (typically more than 12GB)
    no_xdist: this test should not run when using pytest-xdist
    ray: mark Ray-based tests
    mpi_ray_parity: parametrize a test to also run with Ray executor variant

@ -446,3 +446,8 @@ def check_accuracy(a, b, atol, rtol, percent):
    if not (mismatch_percent < 1 - percent):
        raise Exception("Mismatch percentage is %f for rtol %f" %
                        (mismatch_percent, rtol))


skip_ray = pytest.mark.skipif(
    os.environ.get("TLLM_DISABLE_MPI") == "1",
    reason="This test is skipped for Ray orchestrator.")