TensorRT-LLM/cpp/tests/unit_tests/runtime/loraUtilsTest.cpp
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/loraModule.h"
#include "tensorrt_llm/runtime/loraUtils.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <NvInferRuntime.h>
#include <algorithm>
#include <optional>
#include <stdexcept>
namespace tensorrt_llm::runtime::lora
{
using TensorPtr = ITensor::SharedPtr;
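
// Fixture providing a CUDA stream and a BufferManager used to allocate the test tensors.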
class LoraUtilsTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
protected:
LoraUtilsTest() {}
void SetUp() override
{
mStream = std::make_unique<CudaStream>();
mManager = std::make_unique<BufferManager>(mStream);
}
std::unique_ptr<BufferManager> mManager;
BufferManager::CudaStreamPtr mStream;
};
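
// loraValidateRequestTensorDims should throw when either (or both) of the weights/config tensors is missing.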
TEST_F(LoraUtilsTest, null_values)
{
std::optional<TensorPtr> optReqLoraWeights = std::nullopt;
std::optional<TensorPtr> optReqLoraConfig = mManager->emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kHALF);
EXPECT_THAT([&]() { loraValidateRequestTensorDims(optReqLoraWeights, optReqLoraConfig); },
testing::Throws<std::runtime_error>());
optReqLoraConfig = std::nullopt;
EXPECT_THAT([&]() { loraValidateRequestTensorDims(optReqLoraWeights, optReqLoraConfig); },
testing::Throws<std::runtime_error>());
}
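
// Tensors with the wrong rank or the wrong memory type should be rejected; well-formed 3-D CPU tensors should pass.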
TEST_F(LoraUtilsTest, dims_mem_type)
{
std::optional<TensorPtr> optReqLoraWeights = mManager->cpu(ITensor::makeShape({1, 2}), nvinfer1::DataType::kHALF);
std::optional<TensorPtr> optReqLoraConfig
= mManager->cpu(ITensor::makeShape({1, 2, 3}), nvinfer1::DataType::kINT32);
EXPECT_THAT([&]() { loraValidateRequestTensorDims(optReqLoraWeights, optReqLoraConfig); },
testing::Throws<std::runtime_error>());
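// Weights placed in GPU memory should also be rejected; host (CPU) tensors appear to be expected here.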
std::optional<TensorPtr> optGpuWeights = mManager->gpu(ITensor::makeShape({1, 2, 50}), nvinfer1::DataType::kHALF);
EXPECT_THAT([&]() { loraValidateRequestTensorDims(optGpuWeights, optReqLoraConfig); },
testing::Throws<std::runtime_error>());
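// Correctly shaped 3-D CPU tensors should validate without throwing.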
optReqLoraWeights = mManager->cpu(ITensor::makeShape({1, 2, 50}), nvinfer1::DataType::kHALF);
optReqLoraConfig = mManager->cpu(ITensor::makeShape({1, 2, 3}), nvinfer1::DataType::kINT32);
loraValidateRequestTensorDims(optReqLoraWeights, optReqLoraConfig);
}
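
// Full request validation against the model/world config: LoRA module availability, max LoRA rank, and task id presence.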
TEST_F(LoraUtilsTest, loraValidateRequestTensors)
{
auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT);
auto worldConfig = WorldConfig();
std::optional<TensorPtr> optReqLoraWeights
= mManager->cpu(ITensor::makeShape({1, 2, 32}), nvinfer1::DataType::kFLOAT);
std::optional<TensorPtr> optReqLoraConfig
= mManager->cpu(ITensor::makeShape({1, 2, 3}), nvinfer1::DataType::kINT32);
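// Two config rows, each presumably laid out as (module id, layer idx, adapter size): module 1 (kATTN_Q) on layers 0 and 1 with adapter size 4.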
std::vector<int32_t> config{1, 0, 4, 1, 1, 4};
auto configPtr = bufferCast<int32_t>(*optReqLoraConfig.value());
std::copy_n(config.data(), config.size(), configPtr);
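// No LoRA modules are enabled on the model yet, so validation is expected to fail.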
EXPECT_THAT([&]()
{ loraValidateRequestTensors(12345, optReqLoraWeights, optReqLoraConfig, modelConfig, worldConfig); },
testing::Throws<std::runtime_error>());
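// Enable a single kATTN_Q module so the module id referenced by the config rows is recognized.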
std::vector<LoraModule> modules{
LoraModule(LoraModule::ModuleType::kATTN_Q, 4, 4, false, true, -1, 0),
};
modelConfig.setLoraModules(modules);
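// Still expected to throw, presumably because the model's max LoRA rank has not been raised to cover the adapter size of 4.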
EXPECT_THAT([&]()
{ loraValidateRequestTensors(12345, optReqLoraWeights, optReqLoraConfig, modelConfig, worldConfig); },
testing::Throws<std::runtime_error>());
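// With a sufficient max LoRA rank the same request should validate cleanly.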
modelConfig.setMaxLoraRank(4);
loraValidateRequestTensors(12345, optReqLoraWeights, optReqLoraConfig, modelConfig, worldConfig);
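// A request that carries LoRA tensors but no task id should be rejected.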
EXPECT_THAT([&]()
{ loraValidateRequestTensors(std::nullopt, optReqLoraWeights, optReqLoraConfig, modelConfig, worldConfig); },
testing::Throws<std::runtime_error>());
}
} // namespace tensorrt_llm::runtime::lora