mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
293 lines
12 KiB
C++
293 lines
12 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#define UCX_WRAPPER_LIB_NAME "tensorrt_llm_ucx_wrapper"
|
|
|
|
#if defined(_WIN32)
|
|
#include <windows.h>
|
|
#define dllOpen(name) LoadLibrary(name ".dll")
|
|
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
|
|
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
|
|
#else // For non-Windows platforms
|
|
#include <dlfcn.h>
|
|
#define dllOpen(name) dlopen("lib" name ".so", RTLD_LAZY)
|
|
#define dllClose(handle) dlclose(handle)
|
|
#define dllGetSym(handle, name) dlsym(handle, name)
|
|
#endif // defined(_WIN32)
|
|
|
|
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
|
|
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
|
|
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
|
|
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
|
#include "tensorrt_llm/common/assert.h"
|
|
#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tensorrt_llm/common/envUtils.h"
|
|
#include "tensorrt_llm/executor/cache_transmission/mpi_utils/connection.h"
|
|
#include "tensorrt_llm/executor/dataTransceiverState.h"
|
|
#include "tensorrt_llm/executor/executor.h"
|
|
#include "tensorrt_llm/runtime/common.h"
|
|
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
|
#include "gtest/gtest.h"
|
|
#include <csignal>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <cstdio>
|
|
#include <cstdlib>
|
|
#include <gmock/gmock.h>
|
|
#include <memory>
|
|
#include <random>
|
|
#include <tensorrt_llm/batch_manager/mlaCacheFormatter.h>
|
|
#include <tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h>
|
|
|
|
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
|
using LlmRequest = tensorrt_llm::batch_manager::LlmRequest;
|
|
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
|
|
using namespace tensorrt_llm::batch_manager;
|
|
namespace texec = tensorrt_llm::executor;
|
|
|
|
namespace
|
|
{
|
|
std::mutex mDllMutex;
|
|
|
|
std::unique_ptr<texec::kv_cache::ConnectionManager> makeOneUcxConnectionManager()
|
|
{
|
|
std::lock_guard<std::mutex> lock(mDllMutex);
|
|
void* WrapperLibHandle{nullptr};
|
|
WrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME);
|
|
TLLM_CHECK_WITH_INFO(WrapperLibHandle != nullptr, "UCX wrapper library is not open correctly.");
|
|
auto load_sym = [](void* handle, char const* name)
|
|
{
|
|
void* ret = dllGetSym(handle, name);
|
|
|
|
TLLM_CHECK_WITH_INFO(ret != nullptr,
|
|
"Unable to load UCX wrapper library symbol, possible cause is that TensorRT LLM library is not "
|
|
"built with UCX support, please rebuild in UCX-enabled environment.");
|
|
return ret;
|
|
};
|
|
std::unique_ptr<tensorrt_llm::executor::kv_cache::ConnectionManager> (*makeUcxConnectionManager)();
|
|
*(void**) (&makeUcxConnectionManager) = load_sym(WrapperLibHandle, "makeUcxConnectionManager");
|
|
return makeUcxConnectionManager();
|
|
}
|
|
|
|
class UcxCommTest : public ::testing::Test
|
|
{
|
|
};
|
|
|
|
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
|
|
using TransceiverTag = tensorrt_llm::batch_manager::TransceiverTag;
|
|
|
|
TEST_F(UcxCommTest, Basic)
|
|
{
|
|
|
|
try
|
|
{
|
|
TransceiverTag::Id id1;
|
|
TransceiverTag::Id id2;
|
|
|
|
auto connectionManager1 = makeOneUcxConnectionManager();
|
|
EXPECT_NE(connectionManager1, nullptr);
|
|
auto connectionManager2 = makeOneUcxConnectionManager();
|
|
EXPECT_NE(connectionManager2, nullptr);
|
|
auto CommState1 = connectionManager1->getCommState();
|
|
auto CommState2 = connectionManager2->getCommState();
|
|
ASSERT_EQ(CommState1.isSocketState(), true);
|
|
ASSERT_EQ(CommState2.isSocketState(), true);
|
|
|
|
auto connections1 = connectionManager2->getConnections(CommState1);
|
|
ASSERT_EQ(connections1.size(), 1);
|
|
auto connection1 = connections1[0];
|
|
id1 = TransceiverTag::Id::REQUEST_SEND;
|
|
connection1->send(DataContext{TransceiverTag::kID_TAG}, &id1, sizeof(id1));
|
|
|
|
auto connection1Peer = connectionManager1->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id2, sizeof(id2));
|
|
ASSERT_EQ(id2, id1);
|
|
constexpr size_t bufferSize = 1024;
|
|
std::vector<char> buffer(bufferSize);
|
|
// Fill buffer with random data
|
|
std::generate(buffer.begin(), buffer.end(), []() { return static_cast<char>(std::rand()); });
|
|
|
|
connection1->send(DataContext{0x74}, buffer.data(), buffer.size());
|
|
|
|
std::vector<char> recvBuffer(buffer.size());
|
|
|
|
connection1Peer->recv(DataContext{0x74}, recvBuffer.data(), recvBuffer.size());
|
|
|
|
ASSERT_EQ(memcmp(buffer.data(), recvBuffer.data(), buffer.size()), 0);
|
|
|
|
// Test with CUDA memory
|
|
tensorrt_llm::runtime::BufferManager bufferManager{std::make_shared<tensorrt_llm::runtime::CudaStream>()};
|
|
|
|
// Create and fill source CUDA buffer with random data
|
|
auto srcBuffer = bufferManager.gpu(buffer.size(), nvinfer1::DataType::kINT8);
|
|
bufferManager.copy(buffer.data(), *srcBuffer);
|
|
bufferManager.getStream().synchronize();
|
|
|
|
auto dstBuffer = bufferManager.gpu(buffer.size(), nvinfer1::DataType::kINT8);
|
|
|
|
// Send CUDA buffer using connection1
|
|
connection1->send(DataContext{0x75}, srcBuffer->data(), srcBuffer->getSizeInBytes());
|
|
|
|
// Receive into CUDA buffer using connection1Peer
|
|
connection1Peer->recv(DataContext{0x75}, dstBuffer->data(), dstBuffer->getSizeInBytes());
|
|
|
|
std::vector<char> recvCudaBuffer(buffer.size());
|
|
bufferManager.copy(*dstBuffer, recvCudaBuffer.data(), dstBuffer->getMemoryType());
|
|
bufferManager.getStream().synchronize();
|
|
|
|
ASSERT_EQ(memcmp(buffer.data(), recvCudaBuffer.data(), buffer.size()), 0);
|
|
}
|
|
catch (std::exception const& e)
|
|
{
|
|
std::string error = e.what();
|
|
if (error.find("UCX wrapper library is not open correctly") != std::string::npos
|
|
|| error.find("Unable to load UCX wrapper library symbol") != std::string::npos)
|
|
{
|
|
GTEST_SKIP() << "UCX wrapper library is not open correctly. Skip this test case.";
|
|
}
|
|
|
|
throw e;
|
|
}
|
|
}
|
|
|
|
TEST_F(UcxCommTest, multiSend)
|
|
{
|
|
try
|
|
{
|
|
TransceiverTag::Id id1;
|
|
TransceiverTag::Id id2;
|
|
TransceiverTag::Id id1Peer;
|
|
TransceiverTag::Id id2Peer;
|
|
|
|
auto manager1 = makeOneUcxConnectionManager();
|
|
auto manager2 = makeOneUcxConnectionManager();
|
|
auto managerRecv = makeOneUcxConnectionManager();
|
|
|
|
auto connection1 = managerRecv->getConnections(manager1->getCommState())[0];
|
|
auto connection2 = managerRecv->getConnections(manager2->getCommState())[0];
|
|
id1 = TransceiverTag::Id::REQUEST_SEND;
|
|
id2 = TransceiverTag::Id::REQUEST_SEND;
|
|
connection1->send(DataContext{TransceiverTag::kID_TAG}, &id1, sizeof(id1));
|
|
connection2->send(DataContext{TransceiverTag::kID_TAG}, &id2, sizeof(id2));
|
|
auto connection1Peer = manager1->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id1Peer, sizeof(id1Peer));
|
|
auto connection2Peer = manager2->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id2Peer, sizeof(id2Peer));
|
|
ASSERT_EQ(id1Peer, id1);
|
|
ASSERT_EQ(id2Peer, id2);
|
|
constexpr size_t bufferSize = 1024;
|
|
std::vector<char> buffer1(bufferSize);
|
|
std::vector<char> buffer2(bufferSize);
|
|
std::generate(buffer1.begin(), buffer1.end(), []() { return static_cast<char>(std::rand()); });
|
|
std::generate(buffer2.begin(), buffer2.end(), []() { return static_cast<char>(std::rand()); });
|
|
|
|
connection1Peer->send(DataContext{0x74}, buffer1.data(), buffer1.size());
|
|
connection2Peer->send(DataContext{0x74}, buffer2.data(), buffer2.size());
|
|
|
|
std::vector<char> recvBuffer1(buffer1.size());
|
|
std::vector<char> recvBuffer2(buffer2.size());
|
|
connection2->recv(DataContext{0x74}, recvBuffer2.data(), recvBuffer2.size());
|
|
connection1->recv(DataContext{0x74}, recvBuffer1.data(), recvBuffer1.size());
|
|
ASSERT_EQ(memcmp(buffer1.data(), recvBuffer1.data(), buffer1.size()), 0);
|
|
ASSERT_EQ(memcmp(buffer2.data(), recvBuffer2.data(), buffer2.size()), 0);
|
|
|
|
tensorrt_llm::runtime::BufferManager bufferManager{std::make_shared<tensorrt_llm::runtime::CudaStream>()};
|
|
|
|
auto srcBuffer1 = bufferManager.gpu(buffer1.size(), nvinfer1::DataType::kINT8);
|
|
auto srcBuffer2 = bufferManager.gpu(buffer2.size(), nvinfer1::DataType::kINT8);
|
|
bufferManager.copy(buffer1.data(), *srcBuffer1);
|
|
bufferManager.copy(buffer2.data(), *srcBuffer2);
|
|
bufferManager.getStream().synchronize();
|
|
|
|
auto dstBuffer1 = bufferManager.gpu(buffer1.size(), nvinfer1::DataType::kINT8);
|
|
auto dstBuffer2 = bufferManager.gpu(buffer2.size(), nvinfer1::DataType::kINT8);
|
|
|
|
connection1Peer->send(DataContext{0x75}, srcBuffer1->data(), srcBuffer1->getSizeInBytes());
|
|
connection2Peer->send(DataContext{0x75}, srcBuffer2->data(), srcBuffer2->getSizeInBytes());
|
|
connection2->recv(DataContext{0x75}, dstBuffer2->data(), dstBuffer2->getSizeInBytes());
|
|
|
|
connection1->recv(DataContext{0x75}, dstBuffer1->data(), dstBuffer1->getSizeInBytes());
|
|
std::vector<char> recvCudaBuffer1(buffer1.size());
|
|
std::vector<char> recvCudaBuffer2(buffer2.size());
|
|
bufferManager.copy(*dstBuffer1, recvCudaBuffer1.data(), dstBuffer1->getMemoryType());
|
|
bufferManager.copy(*dstBuffer2, recvCudaBuffer2.data(), dstBuffer2->getMemoryType());
|
|
bufferManager.getStream().synchronize();
|
|
|
|
ASSERT_EQ(memcmp(buffer1.data(), recvCudaBuffer1.data(), buffer1.size()), 0);
|
|
ASSERT_EQ(memcmp(buffer2.data(), recvCudaBuffer2.data(), buffer2.size()), 0);
|
|
}
|
|
catch (std::exception const& e)
|
|
{
|
|
std::string error = e.what();
|
|
if (error.find("UCX wrapper library is not open correctly") != std::string::npos
|
|
|| error.find("Unable to load UCX wrapper library symbol") != std::string::npos)
|
|
{
|
|
GTEST_SKIP() << "UCX wrapper library is not open correctly. Skip this test case.";
|
|
}
|
|
|
|
throw e;
|
|
}
|
|
}
|
|
|
|
TEST_F(UcxCommTest, CommCache)
|
|
{
|
|
|
|
try
|
|
{
|
|
TransceiverTag::Id id1;
|
|
TransceiverTag::Id id2;
|
|
|
|
auto connectionManager1 = makeOneUcxConnectionManager();
|
|
EXPECT_NE(connectionManager1, nullptr);
|
|
auto connectionManager2 = makeOneUcxConnectionManager();
|
|
EXPECT_NE(connectionManager2, nullptr);
|
|
auto CommState1 = connectionManager1->getCommState();
|
|
auto CommState2 = connectionManager2->getCommState();
|
|
ASSERT_EQ(CommState1.isSocketState(), true);
|
|
ASSERT_EQ(CommState2.isSocketState(), true);
|
|
|
|
auto connections1 = connectionManager2->getConnections(CommState1);
|
|
ASSERT_EQ(connections1.size(), 1);
|
|
auto connection1 = connections1[0];
|
|
id1 = TransceiverTag::Id::REQUEST_SEND;
|
|
connection1->send(DataContext{TransceiverTag::kID_TAG}, &id1, sizeof(id1));
|
|
|
|
auto connection1Peer = connectionManager1->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id2, sizeof(id2));
|
|
ASSERT_EQ(id2, id1);
|
|
auto connection1Cached = connectionManager2->getConnections(CommState1)[0];
|
|
ASSERT_EQ(connection1Cached, connection1);
|
|
id1 = TransceiverTag::Id::REQUEST_SEND;
|
|
connection1Cached->send(DataContext{TransceiverTag::kID_TAG}, &id1, sizeof(id1));
|
|
|
|
auto connection1PeerCached
|
|
= connectionManager1->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id2, sizeof(id2));
|
|
ASSERT_EQ(id2, id1);
|
|
|
|
ASSERT_EQ(connection1PeerCached, connection1Peer);
|
|
}
|
|
catch (std::exception const& e)
|
|
{
|
|
std::string error = e.what();
|
|
if (error.find("UCX wrapper library is not open correctly") != std::string::npos
|
|
|| error.find("Unable to load UCX wrapper library symbol") != std::string::npos)
|
|
{
|
|
GTEST_SKIP() << "UCX wrapper library is not open correctly. Skip this test case.";
|
|
}
|
|
|
|
throw e;
|
|
}
|
|
}
|
|
|
|
}; // namespace
|