[https://nvbugs/5716787][fix] terminate nixl running when exiting (#9785)

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Co-authored-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
This commit is contained in:
Chuang Zhu 2025-12-13 00:15:02 +08:00 committed by GitHub
parent 9c59c9f920
commit 4cc4cbe926
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 54 additions and 4 deletions

View File

@@ -66,6 +66,7 @@ public:
[[nodiscard]] virtual std::vector<Connection const*> getConnections(CommState const& state) = 0;
[[nodiscard]] virtual CommState const& getCommState() const = 0;
[[nodiscard]] virtual bool isRunning() const = 0;
};
} // namespace tensorrt_llm::executor::kv_cache

View File

@@ -360,6 +360,12 @@ public:
RequestInfo info;
auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info)
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
if (connection == nullptr && !mManager->isRunning())
{
TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
return info;
}
if (!isAgent)
{
TLLM_CHECK(id == TransceiverTag::Id::REQUEST_SEND);
@@ -616,6 +622,10 @@ private:
if (!mReadyResponses.empty())
{
auto const& requestInfo = recvRequestInfo();
if (mTerminate || !mManager->isRunning())
{
return;
}
auto reqId = requestInfo.getRequestId();
{

View File

@@ -319,6 +319,10 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batc
{
while (true)
{
if (!mIsRunning)
{
return nullptr;
}
updateUnhandledNotifications();
std::scoped_lock lock(mNotificationMutex);
auto it = mUnhandledNotifications.begin();
@@ -491,6 +495,11 @@ void AgentConnectionManager::waitForNotification(std::string const& remoteAgentN
while (true)
{
if (!mIsRunning)
{
return;
}
updateUnhandledNotifications();
std::scoped_lock lock(mNotificationMutex);
auto it = mUnhandledNotifications.begin();
@@ -587,6 +596,13 @@ std::string const& AgentConnectionManager::getAgentName() const
/// Destructor: signals shutdown, then releases NIXL-registered memory.
///
/// NOTE: the statement order matters. mIsRunning must be cleared FIRST so
/// that any thread still spinning in recvConnectionAndRequestInfo() or
/// waitForNotification() (both poll !mIsRunning and bail out) exits before
/// the agent's registered memory is torn down below.
AgentConnectionManager::~AgentConnectionManager()
{
    mIsRunning = false;
    // Undo the memory registration performed for the transfer agent.
    // Assumes mRegMemDescs holds exactly the descriptors registered at
    // construction time — TODO(review) confirm against the constructor.
    m_Agent->deregisterMemory(mRegMemDescs);
}
/// @brief Reports whether this connection manager is still operational.
/// @return false once the destructor has begun shutting the manager down.
bool AgentConnectionManager::isRunning() const
{
    // Explicit atomic load (seq_cst, same as the implicit conversion) of the
    // shutdown flag that the destructor clears to unblock polling loops.
    return mIsRunning.load();
}
} // namespace tensorrt_llm::executor::kv_cache

View File

@@ -296,6 +296,7 @@ public:
void waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo);
void waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo);
void waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo);
[[nodiscard]] bool isRunning() const override;
private:
std::map<std::string, std::shared_ptr<AgentConnection>> mConnections;
@@ -309,6 +310,7 @@ private:
int mDeviceId;
std::string mAgentName;
MemoryDescs mRegMemDescs;
std::atomic<bool> mIsRunning{true};
};
} // namespace tensorrt_llm::executor::kv_cache

View File

@@ -77,4 +77,13 @@ CommState const& MpiConnectionManager::getCommState() const
return mCommState;
}
/// @brief Reports whether this MPI connection manager is still operational.
/// @return false once destruction has started.
bool MpiConnectionManager::isRunning() const
{
    // Explicit atomic load (seq_cst, identical to the implicit bool
    // conversion) of the flag cleared by ~MpiConnectionManager().
    return mIsRunning.load();
}
/// Destructor: flips the running flag so callers observing isRunning()
/// stop handing work to this manager promptly.
MpiConnectionManager::~MpiConnectionManager()
{
    // Explicit atomic store (seq_cst, same as plain assignment).
    mIsRunning.store(false);
}
} // namespace tensorrt_llm::executor::kv_cache

View File

@@ -42,14 +42,17 @@ class MpiConnectionManager : public ConnectionManager
{
public:
MpiConnectionManager(mpi::MpiComm const* comm);
~MpiConnectionManager();
MpiConnection const* recvConnect(DataContext const& ctx, void* data, size_t size) override;
[[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;
[[nodiscard]] CommState const& getCommState() const override;
[[nodiscard]] bool isRunning() const override;
private:
mpi::MpiComm const* mComm;
std::map<int, MpiConnection> mConnections;
CommState mCommState;
std::atomic<bool> mIsRunning{true};
};
} // namespace tensorrt_llm::executor::kv_cache

View File

@@ -504,7 +504,7 @@ UcxConnectionManager::~UcxConnectionManager()
socket.close();
mZmqRepThread.join();
}
mIsRunning = false;
mZmqRepSocket.close();
mZmqContext.close();
@@ -673,6 +673,11 @@ std::vector<Connection const*> UcxConnectionManager::getConnections(CommState co
return ret;
}
/// @brief Reports whether this UCX connection manager is still operational.
/// @return false once the destructor has begun tearing the manager down.
bool UcxConnectionManager::isRunning() const
{
    // Explicit atomic load (seq_cst, same as the implicit conversion) of the
    // shutdown flag cleared early in ~UcxConnectionManager().
    return mIsRunning.load();
}
CommState const& UcxConnectionManager::getCommState() const
{
return mCommState;

View File

@@ -62,6 +62,7 @@ private:
zmq::socket_t mZmqRepSocket;
std::string mZmqRepEndpoint;
std::thread mZmqRepThread;
std::atomic<bool> mIsRunning{true};
UcxConnection::ConnectionIdType getNewConnectionId(std::shared_ptr<ucxx::Endpoint> const& newEp);
UcxConnection::ConnectionIdType addConnection(std::string const& ip, uint16_t port);
@@ -85,6 +86,8 @@ public:
{
return mRank;
}
[[nodiscard]] bool isRunning() const override;
};
#if defined(__clang__)

View File

@@ -1052,8 +1052,9 @@ def test_llm_context_only_timed_out():
@pytest.mark.part0
@skip_ray
@pytest.mark.parametrize("sender_future_timeout_ms", [100, 1000])
def test_llm_context_only_timed_out_kv_cache_exhausted(
sender_future_timeout_ms):
@pytest.mark.parametrize("backend", ["NIXL", "UCX"])
def test_llm_context_only_timed_out_kv_cache_exhausted(sender_future_timeout_ms,
backend):
tp_size = 1
use_overlap = False
enable_iter_req_stats = False
@@ -1073,7 +1074,7 @@ def test_llm_context_only_timed_out_kv_cache_exhausted(
kv_cache_config=kv_cache_config,
tensor_parallel_size=tp_size,
cache_transceiver_config=CacheTransceiverConfig(
backend="UCX",
backend=backend,
kv_transfer_timeout_ms=1000,
kv_transfer_sender_future_timeout_ms=sender_future_timeout_ms),
**llm_args_extra)