mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5716787][fix] terminate nixl running when exiting (#9785)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Co-authored-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
This commit is contained in:
parent
9c59c9f920
commit
4cc4cbe926
@ -66,6 +66,7 @@ public:
|
||||
[[nodiscard]] virtual std::vector<Connection const*> getConnections(CommState const& state) = 0;
|
||||
|
||||
[[nodiscard]] virtual CommState const& getCommState() const = 0;
|
||||
[[nodiscard]] virtual bool isRunning() const = 0;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::executor::kv_cache
|
||||
|
||||
@ -360,6 +360,12 @@ public:
|
||||
RequestInfo info;
|
||||
auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info)
|
||||
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
|
||||
if (connection == nullptr && !mManager->isRunning())
|
||||
{
|
||||
TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
|
||||
return info;
|
||||
}
|
||||
|
||||
if (!isAgent)
|
||||
{
|
||||
TLLM_CHECK(id == TransceiverTag::Id::REQUEST_SEND);
|
||||
@ -616,6 +622,10 @@ private:
|
||||
if (!mReadyResponses.empty())
|
||||
{
|
||||
auto const& requestInfo = recvRequestInfo();
|
||||
if (mTerminate || !mManager->isRunning())
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto reqId = requestInfo.getRequestId();
|
||||
|
||||
{
|
||||
|
||||
@ -319,6 +319,10 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batc
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
if (!mIsRunning)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
updateUnhandledNotifications();
|
||||
std::scoped_lock lock(mNotificationMutex);
|
||||
auto it = mUnhandledNotifications.begin();
|
||||
@ -491,6 +495,11 @@ void AgentConnectionManager::waitForNotification(std::string const& remoteAgentN
|
||||
while (true)
|
||||
{
|
||||
|
||||
if (!mIsRunning)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
updateUnhandledNotifications();
|
||||
std::scoped_lock lock(mNotificationMutex);
|
||||
auto it = mUnhandledNotifications.begin();
|
||||
@ -587,6 +596,13 @@ std::string const& AgentConnectionManager::getAgentName() const
|
||||
|
||||
AgentConnectionManager::~AgentConnectionManager()
|
||||
{
|
||||
mIsRunning = false;
|
||||
m_Agent->deregisterMemory(mRegMemDescs);
|
||||
}
|
||||
|
||||
bool AgentConnectionManager::isRunning() const
|
||||
{
|
||||
return mIsRunning;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::executor::kv_cache
|
||||
|
||||
@ -296,6 +296,7 @@ public:
|
||||
void waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo);
|
||||
void waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo);
|
||||
void waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo);
|
||||
[[nodiscard]] bool isRunning() const override;
|
||||
|
||||
private:
|
||||
std::map<std::string, std::shared_ptr<AgentConnection>> mConnections;
|
||||
@ -309,6 +310,7 @@ private:
|
||||
int mDeviceId;
|
||||
std::string mAgentName;
|
||||
MemoryDescs mRegMemDescs;
|
||||
std::atomic<bool> mIsRunning{true};
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::executor::kv_cache
|
||||
|
||||
@ -77,4 +77,13 @@ CommState const& MpiConnectionManager::getCommState() const
|
||||
return mCommState;
|
||||
}
|
||||
|
||||
bool MpiConnectionManager::isRunning() const
|
||||
{
|
||||
return mIsRunning;
|
||||
}
|
||||
|
||||
MpiConnectionManager::~MpiConnectionManager()
|
||||
{
|
||||
mIsRunning = false;
|
||||
}
|
||||
} // namespace tensorrt_llm::executor::kv_cache
|
||||
|
||||
@ -42,14 +42,17 @@ class MpiConnectionManager : public ConnectionManager
|
||||
{
|
||||
public:
|
||||
MpiConnectionManager(mpi::MpiComm const* comm);
|
||||
~MpiConnectionManager();
|
||||
MpiConnection const* recvConnect(DataContext const& ctx, void* data, size_t size) override;
|
||||
[[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;
|
||||
[[nodiscard]] CommState const& getCommState() const override;
|
||||
[[nodiscard]] bool isRunning() const override;
|
||||
|
||||
private:
|
||||
mpi::MpiComm const* mComm;
|
||||
std::map<int, MpiConnection> mConnections;
|
||||
CommState mCommState;
|
||||
std::atomic<bool> mIsRunning{true};
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::executor::kv_cache
|
||||
|
||||
@ -504,7 +504,7 @@ UcxConnectionManager::~UcxConnectionManager()
|
||||
socket.close();
|
||||
mZmqRepThread.join();
|
||||
}
|
||||
|
||||
mIsRunning = false;
|
||||
mZmqRepSocket.close();
|
||||
|
||||
mZmqContext.close();
|
||||
@ -673,6 +673,11 @@ std::vector<Connection const*> UcxConnectionManager::getConnections(CommState co
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool UcxConnectionManager::isRunning() const
|
||||
{
|
||||
return mIsRunning;
|
||||
}
|
||||
|
||||
CommState const& UcxConnectionManager::getCommState() const
|
||||
{
|
||||
return mCommState;
|
||||
|
||||
@ -62,6 +62,7 @@ private:
|
||||
zmq::socket_t mZmqRepSocket;
|
||||
std::string mZmqRepEndpoint;
|
||||
std::thread mZmqRepThread;
|
||||
std::atomic<bool> mIsRunning{true};
|
||||
|
||||
UcxConnection::ConnectionIdType getNewConnectionId(std::shared_ptr<ucxx::Endpoint> const& newEp);
|
||||
UcxConnection::ConnectionIdType addConnection(std::string const& ip, uint16_t port);
|
||||
@ -85,6 +86,8 @@ public:
|
||||
{
|
||||
return mRank;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool isRunning() const override;
|
||||
};
|
||||
|
||||
#if defined(__clang__)
|
||||
|
||||
@ -1052,8 +1052,9 @@ def test_llm_context_only_timed_out():
|
||||
@pytest.mark.part0
|
||||
@skip_ray
|
||||
@pytest.mark.parametrize("sender_future_timeout_ms", [100, 1000])
|
||||
def test_llm_context_only_timed_out_kv_cache_exhausted(
|
||||
sender_future_timeout_ms):
|
||||
@pytest.mark.parametrize("backend", ["NIXL", "UCX"])
|
||||
def test_llm_context_only_timed_out_kv_cache_exhausted(sender_future_timeout_ms,
|
||||
backend):
|
||||
tp_size = 1
|
||||
use_overlap = False
|
||||
enable_iter_req_stats = False
|
||||
@ -1073,7 +1074,7 @@ def test_llm_context_only_timed_out_kv_cache_exhausted(
|
||||
kv_cache_config=kv_cache_config,
|
||||
tensor_parallel_size=tp_size,
|
||||
cache_transceiver_config=CacheTransceiverConfig(
|
||||
backend="UCX",
|
||||
backend=backend,
|
||||
kv_transfer_timeout_ms=1000,
|
||||
kv_transfer_sender_future_timeout_ms=sender_future_timeout_ms),
|
||||
**llm_args_extra)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user