TensorRT-LLMs/cpp/tensorrt_llm/runtime/tllmRuntime.cpp
Kaiyu Xie 1730a587d8
Update TensorRT-LLM (#2363)
* Update TensorRT-LLM

---------

Co-authored-by: tonylek <137782967+tonylek@users.noreply.github.com>
2024-10-22 20:27:35 +08:00

596 lines
22 KiB
C++

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tllmRuntime.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/common/safetensors.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tllmLogger.h"
#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <type_traits>
using namespace tensorrt_llm::runtime;
using TensorMap = StringPtrMap<ITensor>;
namespace
{
static_assert(std::is_signed<SizeType32>::value, "SizeType32 must be signed");
nvinfer1::Dims shapeToDims(std::vector<std::size_t> const& shape)
{
TLLM_CHECK(shape.size() <= nvinfer1::Dims::MAX_DIMS);
nvinfer1::Dims dims;
auto constexpr dim_max = std::numeric_limits<ITensor::DimType64>::max();
dims.nbDims = static_cast<std::int32_t>(shape.size());
for (std::size_t i = 0; i < shape.size(); ++i)
{
// shape[i] >= 0 because it has unsigned type. Check upper bound:
TLLM_CHECK(shape[i] <= static_cast<std::size_t>(dim_max));
dims.d[i] = static_cast<ITensor::DimType64>(shape[i]);
}
return dims;
}
std::vector<std::size_t> dimsToShape(nvinfer1::Dims const& dims)
{
TLLM_CHECK(dims.nbDims >= 0);
std::vector<std::size_t> shape(dims.nbDims);
for (std::int32_t i = 0; i < dims.nbDims; ++i)
{
TLLM_CHECK(dims.d[i] >= 0);
shape[i] = static_cast<std::size_t>(dims.d[i]);
}
return shape;
}
tensorrt_llm::runtime::TllmLogger defaultLogger{};
class StreamReader final : public nvinfer1::IStreamReader
{
public:
StreamReader(std::filesystem::path fp)
{
mFile.open(fp.string(), std::ios::binary | std::ios::in);
TLLM_CHECK_WITH_INFO(mFile.good(), std::string("Error opening engine file: " + fp.string()));
}
virtual ~StreamReader()
{
if (mFile.is_open())
{
mFile.close();
}
}
int64_t read(void* destination, int64_t nbBytes) final
{
if (!mFile.good())
{
return -1;
}
mFile.read(static_cast<char*>(destination), nbBytes);
return mFile.gcount();
}
std::ifstream mFile;
};
void setWeightStreaming(nvinfer1::ICudaEngine& engine, float const gpuWeightsPercent)
{
if (gpuWeightsPercent < 1)
{
int64_t streamableSize = engine.getStreamableWeightsSize();
int64_t budget = gpuWeightsPercent * streamableSize;
TLLM_LOG_INFO("Set gpu weights percent to %f, which is %lld bytes. Valid range: %lld bytes - %lld bytes.",
gpuWeightsPercent, budget, 0, streamableSize);
engine.setWeightStreamingBudgetV2(budget);
}
}
} // namespace
// Creates the TensorRT runtime, deserializes the engine described by `rawEngine` (file path, raw address/size,
// or host memory), applies the weight-streaming budget, and allocates one GPU buffer for execution-context
// device memory that every context created by addContext() shares.
// `logger` may be null, in which case defaultLogger is used.
// Throws if the runtime cannot be created, the engine type is unsupported, or deserialization fails.
TllmRuntime::TllmRuntime(
    RawEngine const& rawEngine, nvinfer1::ILogger* logger, float gpuWeightsPercent, bool useShapeInference)
    : mStream(std::make_shared<CudaStream>())
    , mBufferManager{mStream, true} // Ensure to trim the memory pool on destruction.
    , mRuntime{nvinfer1::createInferRuntime(logger ? *logger : defaultLogger)}
    , mUseShapeInference{useShapeInference}
{
    // createInferRuntime signals failure by returning nullptr. Fail with a clear message instead of
    // dereferencing it in deserializeCudaEngine below.
    TLLM_CHECK_WITH_INFO(mRuntime != nullptr, "Failed to create TensorRT runtime.");
    switch (rawEngine.getType())
    {
    case RawEngine::Type::FilePath:
    {
        auto reader = StreamReader(rawEngine.getPath());
        mEngine.reset(mRuntime->deserializeCudaEngine(reader));
        break;
    }
    case RawEngine::Type::AddressWithSize:
        mEngine.reset(mRuntime->deserializeCudaEngine(rawEngine.getAddress(), rawEngine.getSize()));
        break;
    case RawEngine::Type::HostMemory:
        mEngine.reset(
            mRuntime->deserializeCudaEngine(rawEngine.getHostMemory()->data(), rawEngine.getHostMemory()->size()));
        break;
    default: TLLM_THROW("Unsupported raw engine type.");
    }
    TLLM_CHECK_WITH_INFO(mEngine != nullptr, "Failed to deserialize cuda engine.");
    mEngineInspector.reset(mEngine->createEngineInspector());
    // Must run before querying the device memory size, since the budget affects how much memory the engine needs.
    setWeightStreaming(getEngine(), gpuWeightsPercent);
    auto const devMemorySize = mEngine->getDeviceMemorySizeV2();
    mEngineBuffer = mBufferManager.gpu(devMemorySize);
    // Print context memory size for CI/CD to track.
    TLLM_LOG_INFO("[MemUsageChange] Allocated %.2f MiB for execution context memory.",
        static_cast<double>(devMemorySize) / 1048576.0);
    cacheTensorNames();
}
// Splits the engine's I/O tensors into mInputTensorNames and mOutputTensorNames, keeping engine order within
// each list. Tensors whose I/O mode is neither input nor output are ignored.
void TllmRuntime::cacheTensorNames()
{
    auto const tensorCount = mEngine->getNbIOTensors();
    for (std::int32_t idx = 0; idx < tensorCount; ++idx)
    {
        char const* const tensorName = mEngine->getIOTensorName(idx);
        switch (mEngine->getTensorIOMode(tensorName))
        {
        case nvinfer1::TensorIOMode::kINPUT: mInputTensorNames.emplace_back(tensorName); break;
        case nvinfer1::TensorIOMode::kOUTPUT: mOutputTensorNames.emplace_back(tensorName); break;
        default: break;
        }
    }
}
// Creates a new execution context bound to optimization profile `profileIndex` and appends it to mContexts.
// The context uses the shared device-memory buffer allocated in the constructor. For the first context, the
// engine I/O tables are printed when TRACE logging is enabled.
// Throws if `profileIndex` is out of range or the context cannot be created; in that case mContexts is unchanged.
// Returns a reference to the newly created context.
nvinfer1::IExecutionContext& TllmRuntime::addContext(std::int32_t profileIndex)
{
    TLLM_CHECK(0 <= profileIndex && profileIndex < mEngine->getNbOptimizationProfiles());
    mContexts.emplace_back(mEngine->createExecutionContextWithoutDeviceMemory());
    if (!mContexts.back())
    {
        // Remove the null slot before throwing. Otherwise mContexts would keep a null entry that later loops
        // over mContexts (e.g. setLayerProfiler) would dereference.
        mContexts.pop_back();
        if (mEngine->getStreamableWeightsSize() > 0)
        {
            TLLM_THROW("Failed to allocate memory for weights. Please try reducing --gpu_weights_percent.");
        }
        else
        {
            TLLM_THROW("Internal Error: Failed to create an execution context.");
        }
    }
    auto& context = *mContexts.back();
    // Every context points at the same device-memory buffer (mEngineBuffer) allocated once in the constructor.
    context.setDeviceMemoryV2(mEngineBuffer->data(), static_cast<int64_t>(mEngineBuffer->getCapacity()));
    if (tensorrt_llm::common::Logger::getLogger()->isEnabled(tensorrt_llm::common::Logger::TRACE)
        && mContexts.size() == 1)
    {
        printEngineInfo();
    }
    context.setOptimizationProfileAsync(profileIndex, mStream->get());
    // If nvtx verbosity is DETAILED, print an info about potential perf overhead.
    if (context.getNvtxVerbosity() == nvinfer1::ProfilingVerbosity::kDETAILED)
    {
        TLLM_LOG_INFO(
            "The engine was built with kDETAILED profiling verbosity, which may result in small overheads at runtime.");
    }
    return context;
}
// Logs (TLLM_LOG_TRACE) two tables describing the engine:
//   1. every I/O tensor's name, direction (I/O), location (GPU/CPU), data type and build-time shape;
//   2. for every optimization profile, the Min/Opt/Max shape of every tensor.
// Input shapes come straight from the engine's profiles. Output shapes are derived by applying each input
// configuration to the first execution context (mContexts[0]) and reading back context.getTensorShape().
// Side effect: mContexts[0] is left with the input shapes (and, for host/shape tensors, input addresses) of the
// last profile's Max configuration.
// NOTE(review): the active optimization profile of mContexts[0] is never switched here, yet shapes from every
// profile k are applied to it and the bool returned by setInputShape is ignored. When called from addContext this
// runs before setOptimizationProfileAsync. Output shapes reported for profiles other than the active one may
// therefore be stale. TODO confirm against TensorRT setInputShape semantics.
void TllmRuntime::printEngineInfo()
{
auto& context = *(mContexts[0]);
int const nIO = mEngine->getNbIOTensors(); // Count of Input / Output tensor
int const nOP = mEngine->getNbOptimizationProfiles(); // Count of Optimization Profile
std::size_t mwn = 0; // Maximum Width of tensor Name
std::size_t mws = 0; // Maximum Width of tensor Shape
// Get information of engine input / output
std::vector<std::string> tensorNameList{};
for (int i = 0; i < nIO; ++i)
{
tensorNameList.push_back(std::string(mEngine->getIOTensorName(i)));
}
// tiv[i] has the keys "mode" ("I"/"O"), "location" ("GPU"/"CPU"), "data_type" and "build_shape".
std::vector<std::map<std::string, std::string>> tiv(nIO); // Tensor Information Vector
// topv[i][k] holds the {Min, Opt, Max} shapes of tensor i under profile k. For outputs it starts empty and is
// filled by the shape-inference loop further down.
std::vector<std::vector<std::vector<nvinfer1::Dims64>>> topv(nIO); // Tensor Optimization Profile Vector
for (int i = 0; i < nIO; ++i)
{
std::string name{tensorNameList[i]};
char const* nameC{name.c_str()}; // name of C-style
mwn = std::max(mwn, name.size());
tiv[i]["mode"] = mEngine->getTensorIOMode(nameC) == nvinfer1::TensorIOMode::kINPUT ? "I" : "O";
tiv[i]["location"] = mEngine->getTensorLocation(nameC) == nvinfer1::TensorLocation::kDEVICE ? "GPU" : "CPU";
tiv[i]["data_type"] = dataTypeToString(mEngine->getTensorDataType(nameC));
tiv[i]["build_shape"] = shapeToString(mEngine->getTensorShape(nameC));
mws = std::max(mws, tiv[i]["build_shape"].size());
if (tiv[i]["mode"] == "I")
{
// Collect Min/Opt/Max for every profile of this input.
std::vector<std::vector<nvinfer1::Dims64>> topPerTensor(nOP);
for (int k = 0; k < nOP; ++k)
{
if (tiv[i]["location"] == std::string("GPU"))
{
std::vector<nvinfer1::Dims64> top(3);
top[0] = mEngine->getProfileShape(nameC, k, nvinfer1::OptProfileSelector::kMIN);
top[1] = mEngine->getProfileShape(nameC, k, nvinfer1::OptProfileSelector::kOPT);
top[2] = mEngine->getProfileShape(nameC, k, nvinfer1::OptProfileSelector::kMAX);
topPerTensor[k] = top;
// Max is the widest of the three, so it bounds the column width.
mws = std::max(mws, shapeToString(top[2]).size());
}
else
{
// Shape input tensor, not used in TRT-LLM support yet
// The profile *values* (not shapes) are stored in the Dims64 `d` array, with nbDims = tensor rank.
// NOTE(review): getProfileTensorValues results are not null-checked before std::copy.
std::vector<nvinfer1::Dims64> top(3);
int const nDim = mEngine->getTensorShape(nameC).nbDims;
nvinfer1::Dims64 tensorShape{nDim, {-1}};
int const* pos = nullptr;
pos = mEngine->getProfileTensorValues(nameC, k, nvinfer1::OptProfileSelector::kMIN);
std::copy(pos, pos + nDim, tensorShape.d);
top[0] = tensorShape;
pos = mEngine->getProfileTensorValues(nameC, k, nvinfer1::OptProfileSelector::kOPT);
std::copy(pos, pos + nDim, tensorShape.d);
top[1] = tensorShape;
pos = mEngine->getProfileTensorValues(nameC, k, nvinfer1::OptProfileSelector::kMAX);
std::copy(pos, pos + nDim, tensorShape.d);
top[2] = tensorShape;
topPerTensor[k] = top;
}
}
topv[i] = topPerTensor;
}
else
{
topv[i] = std::vector<std::vector<nvinfer1::Dims64>>(nOP);
}
}
// Set input shape to get output shape
// For each profile k and each selector j (Min, Opt, Max), apply every input and then read every output.
// Each output gets exactly one Dims appended per j, so topv[i][k] ends up with 3 entries for outputs too.
for (int k = 0; k < nOP; ++k)
{
for (int j = 0; j < 3; ++j) // Min, Opt, Max
{
for (int i = 0; i < nIO; ++i)
{
std::string name = tensorNameList[i];
char const* nameC = name.c_str();
if (tiv[i]["mode"] == "I")
{
if (tiv[i]["location"] == std::string("GPU"))
{
context.setInputShape(nameC, topv[i][k][j]);
}
else
{
// Shape input tensor, not used in TRT-LLM support yet
// NOTE(review): this address points into the local `topv` and dangles once this function returns.
// It also points at int64 Dims64 storage, while values are read back as int32 elsewhere. Verify
// before enabling shape tensors.
context.setInputTensorAddress(nameC, topv[i][k][j].d);
}
}
else
{
// NOTE(review): these checks run at the first output in I/O order. They assume the engine lists
// every input before its first output. TODO confirm.
TLLM_CHECK_WITH_INFO(context.allInputDimensionsSpecified(), "Input dimensions not specified");
TLLM_CHECK_WITH_INFO(context.allInputShapesSpecified(), "Input shapes not specified");
if (tiv[i]["location"] == std::string("GPU"))
{
topv[i][k].push_back(context.getTensorShape(nameC));
}
else
{
// Shape input tensor, not used in TRT-LLM support yet
int const nDim = mEngine->getTensorShape(nameC).nbDims;
nvinfer1::Dims64 tensorShape{nDim, {}};
int const* pos = reinterpret_cast<int const*>(context.getTensorAddress(nameC));
std::copy(pos, pos + nDim, tensorShape.d);
topv[i][k].push_back(tensorShape);
}
}
}
}
}
// Print information of engine input / output
// Row width: name (mwn) + "|I/O|Location|DataType|" (23 chars) + shape (mws) + "|" (1 char) = mwn + mws + 24.
std::string info;
TLLM_LOG_TRACE("Information of engine input / output.");
TLLM_LOG_TRACE(std::string(mwn + mws + 24, '='));
info = alignText("Name", mwn) + "|I/O|Location|DataType|" + alignText("Shape", mws) + "|";
TLLM_LOG_TRACE(info.c_str());
TLLM_LOG_TRACE(std::string(mwn + mws + 24, '-'));
for (int i = 0; i < nIO; ++i)
{
info = alignText(tensorNameList[i], mwn, false) + "|";
info += alignText(tiv[i]["mode"], 3) + "|";
info += alignText(tiv[i]["location"], 8) + "|";
info += alignText(tiv[i]["data_type"], 8) + "|";
info += alignText(tiv[i]["build_shape"], mws) + "|";
TLLM_LOG_TRACE(info.c_str());
}
TLLM_LOG_TRACE(std::string(mwn + mws + 24, '='));
// Print information of optimization profile
// Row width: name (mwn) + "|" + 3 x (shape column of mws + "|") = mwn + mws * 3 + 4.
TLLM_LOG_TRACE("Information of optimization profile.");
for (int k = 0; k < nOP; ++k)
{
TLLM_LOG_TRACE("Optimization Profile %d:", k);
TLLM_LOG_TRACE(std::string(mwn + mws * 3 + 4, '='));
info = alignText("Name", mwn) + "|";
info += alignText("Min", mws) + "|";
info += alignText("Opt", mws) + "|";
info += alignText("Max", mws) + "|";
TLLM_LOG_TRACE(info.c_str());
TLLM_LOG_TRACE(std::string(mwn + mws * 3 + 4, '-'));
for (int i = 0; i < nIO; ++i)
{
auto const& top = topv[i][k];
info = alignText(tensorNameList[i], mwn, false) + "|";
info += alignText(shapeToString(top[0]), mws) + "|";
info += alignText(shapeToString(top[1]), mws) + "|";
info += alignText(shapeToString(top[2]), mws) + "|";
TLLM_LOG_TRACE(info.c_str());
}
TLLM_LOG_TRACE(std::string(mwn + mws * 3 + 4, '='));
}
return;
}
// Destroys all execution contexts, releasing them front to back, and then empties mContexts.
void TllmRuntime::clearContexts()
{
    std::for_each(mContexts.begin(), mContexts.end(), [](auto& ctx) { ctx.reset(); });
    mContexts.clear();
}
// Enqueues execution context `contextIndex` on the runtime stream via enqueueV3, runs sync_check_cuda_error,
// and returns whether the enqueue succeeded.
bool TllmRuntime::executeContext(SizeType32 contextIndex) const
{
    NVTX3_FUNC_RANGE();
    bool const enqueued = getContext(contextIndex).enqueueV3(mStream->get());
    sync_check_cuda_error();
    return enqueued;
}
// Binds tensors from `tensorMap` to the inputs of execution context `contextIndex`.
// For every name still in mInputTensorNames it:
//   - validates the data type (an FP8 tensor is accepted where the engine declares FP16, see the WAR below);
//   - sets the input shape, throwing with the active profile's min/max bounds if TensorRT rejects it;
//   - sets the input address. An empty tensor with a null data pointer is bound to a lazily allocated
//     one-element dummy GPU buffer, because TensorRT does not accept nullptr.
// Inputs missing from `tensorMap` throw when `throwOnMiss` is true and are skipped otherwise.
void TllmRuntime::setInputTensorsImpl(SizeType32 contextIndex, TensorMap const& tensorMap, bool throwOnMiss)
{
    NVTX3_FUNC_RANGE();
    auto& context = getContext(contextIndex);
    for (auto const& name : mInputTensorNames)
    {
        auto const pos = tensorMap.find(name);
        if (pos == tensorMap.end())
        {
            if (throwOnMiss)
            {
                auto expectedShape = mEngine->getTensorShape(name.c_str());
                TLLM_THROW("Input tensor '%s' not found; expected shape: %s", name.c_str(),
                    ITensor::toString(expectedShape).c_str());
            }
            else
            {
                continue;
            }
        }
        auto const& tensor = pos->second;
        auto const tensorDtype = tensor->getDataType();
        auto const engineDtype = mEngine->getTensorDataType(name.c_str());
        // WAR: TRT does not support mixed FP8 and FP16 input, so engine expects FP16 tensors.
        TLLM_CHECK_WITH_INFO(tensorDtype == engineDtype
                || (tensorDtype == nvinfer1::DataType::kFP8 && engineDtype == nvinfer1::DataType::kHALF),
            "%s: expected type %d, provided type %d", name.c_str(), static_cast<std::int32_t>(engineDtype),
            static_cast<std::int32_t>(tensorDtype));
        auto const tensorShape = tensor->getShape();
        auto const setInputShapeSuccess = context.setInputShape(name.c_str(), tensorShape);
        if (!setInputShapeSuccess)
        {
            // Report the bounds of the profile this context is actually bound to. A context's index in mContexts
            // is not an optimization-profile index, because addContext() accepts an arbitrary profile.
            auto const profileIndex = context.getOptimizationProfile();
            auto const minShape
                = mEngine->getProfileShape(name.c_str(), profileIndex, nvinfer1::OptProfileSelector::kMIN);
            auto const maxShape
                = mEngine->getProfileShape(name.c_str(), profileIndex, nvinfer1::OptProfileSelector::kMAX);
            TLLM_THROW("Tensor '%s' has invalid shape %s, expected in range min %s, max %s", name.c_str(),
                ITensor::toString(tensorShape).c_str(), ITensor::toString(minShape).c_str(),
                ITensor::toString(maxShape).c_str());
        }
        auto* const data = tensor->data();
        if (data)
        {
            context.setInputTensorAddress(name.c_str(), data);
        }
        else
        {
            TLLM_CHECK_WITH_INFO(tensor->getSize() == 0, std::string("Invalid data for tensor: ") + name.c_str());
            // TensorRT runtime does not support nullptr.
            if (!mDummyTensor)
            {
                mDummyTensor = mBufferManager.gpu(ITensor::makeShape({1}));
            }
            context.setInputTensorAddress(name.c_str(), mDummyTensor->data());
        }
    }
}
// Binds tensors that stay fixed across calls (e.g. managed weights) on every existing execution context.
// Names absent from `tensorMap` are skipped. The bound names are then removed from mInputTensorNames
// (the relative order of the remaining names is kept), so later setInputTensors calls no longer require them.
// Throws if no context has been created yet.
void TllmRuntime::setStaticInputTensors(TensorMap const& tensorMap)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_FUNC_RANGE();
    TLLM_CHECK_WITH_INFO(getNbContexts() > 0, "Contexts should be created before calling setStaticInputTensors");
    for (auto ctxIdx = 0; ctxIdx < getNbContexts(); ++ctxIdx)
    {
        setInputTensorsImpl(ctxIdx, tensorMap, false);
    }
    auto const isStatic = [&tensorMap](auto const& inputName) { return tensorMap.find(inputName) != tensorMap.end(); };
    auto const firstRemoved = std::remove_if(mInputTensorNames.begin(), mInputTensorNames.end(), isStatic);
    mInputTensorNames.erase(firstRemoved, mInputTensorNames.end());
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
// Binds all per-call inputs of context `contextIndex`; a missing input throws.
// When shape inference is enabled, TensorRT's inferShapes is run. It throws on a negative result, or names the
// first tensor whose shape is still unknown. Finally, all input dimensions and shapes must be specified.
void TllmRuntime::setInputTensors(SizeType32 contextIndex, TensorMap const& tensorMap)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_FUNC_RANGE();
    setInputTensorsImpl(contextIndex, tensorMap, true);
    auto& context = getContext(contextIndex);
    if (mUseShapeInference)
    {
        NVTX3_SCOPED_RANGE(infer_shapes);
        char const* firstMissing = nullptr;
        auto const missingCount = context.inferShapes(1, &firstMissing);
        if (missingCount < 0)
        {
            TLLM_THROW("Invalid input shape");
        }
        if (missingCount > 0)
        {
            TLLM_THROW("Input shape not specified: %s", firstMissing);
        }
    }
    {
        NVTX3_SCOPED_RANGE(final_checks);
        TLLM_CHECK_WITH_INFO(context.allInputDimensionsSpecified(), "Input dimensions not specified");
        TLLM_CHECK_WITH_INFO(context.allInputShapesSpecified(), "Input shapes not specified");
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
// Binds output buffers for context `contextIndex`.
// Outputs present in `tensorMap` are type-checked (FP8 accepted where the engine declares FP16). With shape
// inference on, they are also reshaped to the context's computed shape. Outputs missing from `tensorMap` are
// allocated on the GPU and inserted into `tensorMap` when shape inference is on; otherwise a missing output throws.
void TllmRuntime::setOutputTensors(SizeType32 contextIndex, TensorMap& tensorMap)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_FUNC_RANGE();
    auto& context = getContext(contextIndex);
    for (auto const& name : mOutputTensorNames)
    {
        char const* const tensorName = name.c_str();
        auto const engineDtype = mEngine->getTensorDataType(tensorName);
        auto const found = tensorMap.find(name);
        if (found == tensorMap.end())
        {
            if (!mUseShapeInference)
            {
                TLLM_THROW("Tensor %s is not found in tensorMap and shape inference is not allowed", tensorName);
            }
            // Allocate a buffer of the inferred shape and hand it back to the caller through tensorMap.
            auto const inferredDims = context.getTensorShape(tensorName);
            auto buffer = ITensor::SharedPtr(mBufferManager.gpu(inferredDims, engineDtype));
            tensorMap.insert(found, std::make_pair(name, buffer));
            context.setTensorAddress(tensorName, buffer->data());
            continue;
        }
        auto const& tensor = found->second;
        auto const tensorDtype = tensor->getDataType();
        // WAR: TRT does not support mixed FP8 and FP16 input, so engine expects FP16 tensors.
        TLLM_CHECK_WITH_INFO(tensorDtype == engineDtype
                || (tensorDtype == nvinfer1::DataType::kFP8 && engineDtype == nvinfer1::DataType::kHALF),
            "%s: expected type %d, provided type %d", tensorName, static_cast<std::int32_t>(engineDtype),
            static_cast<std::int32_t>(tensorDtype));
        if (mUseShapeInference)
        {
            tensor->reshape(context.getTensorShape(tensorName));
        }
        context.setTensorAddress(tensorName, tensor->data());
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
// Returns the CUDA stream created (via std::make_shared) in the constructor. Ownership stays with this runtime.
CudaStream const& TllmRuntime::getStream() const
{
    CudaStream const& stream = *mStream;
    return stream;
}
// Returns true if a profiler is attached to execution context `contextId`.
// Throws if `contextId` is not a valid index into mContexts.
bool TllmRuntime::hasLayerProfiler(SizeType32 contextId) const
{
    // operator[] with an out-of-range index is undefined behavior. Fail loudly instead, mirroring the
    // index validation in addContext.
    TLLM_CHECK_WITH_INFO(0 <= contextId && contextId < static_cast<SizeType32>(mContexts.size()),
        "Invalid context id %d; number of contexts is %zu", contextId, mContexts.size());
    return mContexts[contextId]->getProfiler() != nullptr;
}
void TllmRuntime::setLayerProfiler()
{
mLayerProfiler = std::make_unique<LayerProfiler>();
for (auto& context : mContexts)
{
context->setProfiler(mLayerProfiler.get());
context->setEnqueueEmitsProfile(false);
}
}
// Returns the text report produced by the attached LayerProfiler. Throws if setLayerProfiler() was never called.
std::string TllmRuntime::getLayerProfileInfo() const
{
    TLLM_CHECK(mLayerProfiler);
    auto& profiler = *mLayerProfiler;
    return profiler.getLayerProfile();
}
// Asks execution context `contextId` to report its collected layer timings to the attached profiler.
// Throws if `contextId` is not a valid index into mContexts.
void TllmRuntime::reportToProfiler(SizeType32 contextId)
{
    // operator[] with an out-of-range index is undefined behavior. Fail loudly instead, mirroring the
    // index validation in addContext.
    TLLM_CHECK_WITH_INFO(0 <= contextId && contextId < static_cast<SizeType32>(mContexts.size()),
        "Invalid context id %d; number of contexts is %zu", contextId, mContexts.size());
    mContexts[contextId]->reportToProfiler();
}
// Copies the engine's managed weights to the GPU, stores them in mManagedWeightsMap, and binds them as static inputs.
// Source: the in-memory map carried by `rawEngine` when present. Otherwise the file
// "rank<localRank>_managed_weights.safetensors" next to the engine file, whose tensor dtypes must match the engine.
// Throws if neither source is available or a file tensor's dtype mismatches.
void TllmRuntime::loadManagedWeights(RawEngine const& rawEngine, int localRank)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_FUNC_RANGE();
    auto& engine = getEngine();
    auto& manager = getBufferManager();
    if (!rawEngine.getManagedWeightsMapOpt().has_value())
    {
        TLLM_LOG_DEBUG("Loading managed weights from file");
        auto const enginePath = rawEngine.getPathOpt();
        TLLM_CHECK_WITH_INFO(enginePath.has_value(), "Engine path is not set.");
        auto const fileName = "rank" + std::to_string(localRank) + "_managed_weights.safetensors";
        auto const weightPath = enginePath->parent_path() / fileName;
        auto managed_weights = common::safetensors::ISafeTensor::open(weightPath.string().c_str());
        for (auto const& name : managed_weights->keys())
        {
            TLLM_LOG_DEBUG("Loading managed weight: %s", name.c_str());
            auto const weight = managed_weights->getTensor(name.c_str());
            TLLM_CHECK(weight->dtype() == engine.getTensorDataType(name.c_str()));
            std::shared_ptr<ITensor> deviceTensor{
                manager.allocate(MemoryType::kGPU, weight->trtDims(), weight->dtype())};
            manager.copy(weight->data(), *deviceTensor, MemoryType::kCPU);
            mManagedWeightsMap.insert(std::make_pair(name, deviceTensor));
        }
    }
    else
    {
        TLLM_LOG_DEBUG("Loading managed weights from raw engine");
        auto executorMap = rawEngine.getManagedWeightsMapOpt().value();
        for (auto const& [name, weight] : executorMap)
        {
            TLLM_LOG_DEBUG("Loading managed weight: %s", name.c_str());
            auto hostTensor = tensorrt_llm::executor::detail::toITensor(weight);
            std::shared_ptr<ITensor> deviceTensor{manager.copyFrom(*hostTensor, MemoryType::kGPU)};
            mManagedWeightsMap.insert(std::make_pair(name, deviceTensor));
        }
    }
    setStaticInputTensors(mManagedWeightsMap);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}