Update TensorRT-LLM (#188)

* Update batch manager
* Update src

---------

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Co-authored-by: jdemouth-nvidia <11447840+jdemouth-nvidia@users.noreply.github.com>
Kaiyu Xie 2023-10-30 16:06:41 +08:00 committed by GitHub
parent d8b408e6dc
commit 4de32a86ae
43 changed files with 912 additions and 2710 deletions


@ -273,13 +273,9 @@ class GptServer
{
public:
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
batch_scheduler::SchedulerPolicy schedulerPolicy, std::optional<int32_t> maxNumSequences,
std::optional<int32_t> maxTokensInPagedKvCache, std::optional<float> kvCacheFreeGpuMemFraction,
std::optional<bool> enableTrtOverlap, std::shared_ptr<Recorder> recorder,
std::optional<uint64_t> terminateReqId)
batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId)
{
const TrtGptModelOptionalParams& optionalParams = TrtGptModelOptionalParams(
maxNumSequences, maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, enableTrtOverlap);
mBatchManager = std::make_shared<GptManager>(
trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
[this](int max_num_requests) { return getInferenceRequests(max_num_requests); },
@ -460,10 +456,8 @@ std::pair<std::vector<std::vector<int32_t>>, std::vector<int32_t>> parseDataset(
}
void benchmarkGptManager(std::string const& modelName, std::filesystem::path const& engineDir, std::string const& type,
std::string const& datasetPath, std::shared_ptr<nvinfer1::ILogger> const& logger,
std::optional<int32_t> maxNumSequences, std::optional<int32_t> maxTokensInPagedKvCache,
std::optional<float> kvCacheFreeGpuMemFraction, std::optional<bool> enableTrtOverlap,
batch_scheduler::SchedulerPolicy schedulerPolicy)
std::string const& datasetPath, int beamWidth, std::shared_ptr<nvinfer1::ILogger> const& logger,
TrtGptModelOptionalParams const& optionalParams, batch_scheduler::SchedulerPolicy schedulerPolicy)
{
auto const worldConfig = WorldConfig::mpi(*logger);
@ -482,6 +476,11 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
TLLM_LOG_ERROR(errStr);
}
ITensor::SharedPtr beamWidthBuffer = BufferManager::cpu(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
auto beamWidthBufferPtr = bufferCast<SizeType>(*beamWidthBuffer);
*beamWidthBufferPtr = beamWidth;
auto beamWidthTensor = NamedTensor(beamWidthBuffer, "beam_width");
// Load dataset
auto dataset = parseDataset(datasetPath);
std::vector<std::vector<NamedTensor>> tensors_list;
@ -494,15 +493,16 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
auto input_ids_tensor = NamedTensor(nvinfer1::DataType::kINT32, input_ids_shape, "input_ids", input_ids.data());
auto request_output_len_tensor
= NamedTensor(nvinfer1::DataType::kINT32, {1, 1}, "request_output_len", &request_output_len);
std::vector<NamedTensor> tensors = {input_ids_tensor, request_output_len_tensor};
tensors_list.push_back(tensors);
std::vector<NamedTensor> tensors
= {std::move(input_ids_tensor), std::move(request_output_len_tensor), beamWidthTensor};
tensors_list.emplace_back(std::move(tensors));
}
const int maxBeamWidth = 1;
const int maxBeamWidth = beamWidth;
auto recorder = std::make_shared<Recorder>();
uint64_t terminateReqId = num_samples + 1;
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, maxNumSequences,
maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, enableTrtOverlap, recorder, terminateReqId);
auto gptServer = std::make_shared<GptServer>(
engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams, recorder, terminateReqId);
if (worldConfig.getRank() == 0)
{
@ -537,16 +537,18 @@ int main(int argc, char* argv[])
"type", "Batching type: IFB or V1(non-IFB) batching.", cxxopts::value<std::string>()->default_value("IFB"));
options.add_options()("dataset", "Dataset that is used for benchmarking BatchManager.",
cxxopts::value<std::string>()->default_value(""));
options.add_options()("max_num_sequences", "Max number of Sequences.", cxxopts::value<int>()->default_value("-1"));
options.add_options()(
"max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>()->default_value("-1"));
options.add_options()("kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.",
cxxopts::value<float>()->default_value("-1"));
"beam_width", "Specify beam width you want to benchmark.", cxxopts::value<int>()->default_value("1"));
options.add_options()("max_num_sequences", "Max number of Sequences.", cxxopts::value<int>());
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
options.add_options()(
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
options.add_options()(
"enable_trt_overlap", "Overlap TRT context preparation and execution", cxxopts::value<bool>());
options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));
options.add_options()("enable_trt_overlap", "Overlap TRT context preparation and execution",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
cxxopts::value<std::string>()->default_value("error"));
@ -573,32 +575,29 @@ int main(int argc, char* argv[])
// Argument: Dataset
auto const datasetPath = result["dataset"].as<std::string>();
// Argument: beam width
auto const beamWidth = result["beam_width"].as<int>();
TrtGptModelOptionalParams optionalParams;
// Argument: Max Num Sequences
std::optional<int32_t> maxNumSequences = std::nullopt;
if (result["max_num_sequences"].as<int>() != -1)
if (result.count("max_num_sequences"))
{
maxNumSequences = result["max_num_sequences"].as<int>();
optionalParams.maxNumSequences = result["max_num_sequences"].as<int>();
}
// Argument: Max tokens in paged K-V Cache
std::optional<int32_t> maxTokensInPagedKvCache = std::nullopt;
if (result["max_tokens_in_paged_kvcache"].as<int>() != -1)
if (result.count("max_tokens_in_paged_kvcache"))
{
maxTokensInPagedKvCache = result["max_tokens_in_paged_kvcache"].as<int>();
optionalParams.kvCacheConfig.maxTokens = result["max_tokens_in_paged_kvcache"].as<int>();
}
// Argument: K-V Cache Free Gpu Mem Fraction
std::optional<float> kvCacheFreeGpuMemFraction = std::nullopt;
if (result["kv_cache_free_gpu_mem_fraction"].as<float>() != -1)
if (result.count("kv_cache_free_gpu_mem_fraction"))
{
kvCacheFreeGpuMemFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
optionalParams.kvCacheConfig.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
}
// Argument: Enable TRT overlap
std::optional<bool> enableTrtOverlap = std::nullopt;
if (result["enable_trt_overlap"].as<bool>() != -1)
if (result.count("enable_trt_overlap"))
{
enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
optionalParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
}
// Argument: Scheduler policy
@ -652,8 +651,7 @@ int main(int argc, char* argv[])
try
{
benchmarkGptManager(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), type,
datasetPath, logger, maxNumSequences, maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, enableTrtOverlap,
schedulerPolicy);
datasetPath, beamWidth, logger, optionalParams, schedulerPolicy);
}
catch (const std::exception& e)
{
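
To see the API change in this file end to end: the four separate std::optional arguments that GptServer used to take are folded into a single TrtGptModelOptionalParams, and optional CLI flags are detected with result.count() instead of a -1 sentinel. A minimal sketch of the resulting call pattern, assembled from the hunks above (illustrative only, not a verified build; variables such as result, engineDir, modelType, schedulerPolicy, recorder and terminateReqId are assumed to be in scope as in the benchmark):

// Optional knobs travel in one bundle; unset flags simply leave the defaults in place.
TrtGptModelOptionalParams optionalParams;
if (result.count("max_tokens_in_paged_kvcache"))
    optionalParams.kvCacheConfig.maxTokens = result["max_tokens_in_paged_kvcache"].as<int>();
if (result.count("kv_cache_free_gpu_mem_fraction"))
    optionalParams.kvCacheConfig.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
if (result.count("enable_trt_overlap"))
    optionalParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();

// The server receives the bundle plus the benchmarked beam width.
auto gptServer = std::make_shared<GptServer>(
    engineDir, modelType, /*maxBeamWidth=*/beamWidth, schedulerPolicy, optionalParams, recorder, terminateReqId);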


@ -35,9 +35,9 @@ namespace trt = nvinfer1;
namespace
{
void benchmarkGptSession(std::string const& modelName, std::filesystem::path const& dataPath,
std::vector<int> const& batchSizes, std::vector<std::vector<int>> const& inOutLen,
std::vector<int> const& batchSizes, int beamWidth, std::vector<std::vector<int>> const& inOutLen,
std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration,
std::optional<SizeType> numMicroBatches, bool cudaGraphMode)
GptSession::Config& sessionConfig, bool cudaGraphMode)
{
auto const json = GptJsonConfig::parse(dataPath / "config.json");
auto const modelConfig = json.getModelConfig();
@ -50,8 +50,6 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
auto const dtype = modelConfig.getDataType();
auto const useHalf = (dtype == nvinfer1::DataType::kHALF);
auto constexpr decoderPerRequest = false;
auto constexpr beamWidth = 1;
SamplingConfig samplingConfig{beamWidth};
samplingConfig.temperature = std::vector{1.0f};
samplingConfig.minLength = std::vector{1};
@ -59,16 +57,24 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
samplingConfig.topK = std::vector{1};
samplingConfig.topP = std::vector{0.0f};
GptSession session{modelConfig, worldConfig, enginePath.string(), logger};
// Use bufferManager for copying data to and from the GPU
auto& bufferManager = session.getBufferManager();
session.setCudaGraphMode(cudaGraphMode);
auto const maxBatchSize = *std::max_element(batchSizes.begin(), batchSizes.end());
sessionConfig.maxBatchSize = maxBatchSize;
sessionConfig.maxBeamWidth = beamWidth;
sessionConfig.decoderPerRequest = false;
sessionConfig.cudaGraphMode = cudaGraphMode;
for (auto inOut : inOutLen)
{
auto const maxInputLength = inOut[0];
auto const maxNewTokens = inOut[1];
sessionConfig.maxSequenceLength = maxInputLength + maxNewTokens;
GptSession session{sessionConfig, modelConfig, worldConfig, enginePath.string(), logger};
// Use bufferManager for copying data to and from the GPU
auto& bufferManager = session.getBufferManager();
auto constexpr endId = 50256;
auto constexpr padId = 50256;
@ -76,9 +82,6 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
{
try
{
session.setup(batchSize, beamWidth, maxInputLength + maxNewTokens, decoderPerRequest, std::nullopt,
numMicroBatches);
std::vector<SizeType> inputLenghtsHost(batchSize, maxInputLength);
auto inputLenghts
= bufferManager.copyFrom(inputLenghtsHost, ITensor::makeShape({batchSize}), MemoryType::kGPU);
@ -133,11 +136,14 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
}
printf("Benchmarking done. Iteration: %d, duration: %.2f sec.\n", iterIdx, curDuration / 1000);
auto averageLatency = curDuration / iterIdx;
if (worldConfig.getRank() == 0)
{
printf("[BENCHMARK] batch_size %d input_length %d output_length %d latency(ms) %.2f\n", batchSize,
maxInputLength, maxNewTokens, averageLatency);
auto const averageLatency = curDuration / iterIdx;
float const tokensPerSec = batchSize * maxNewTokens / (averageLatency / 1000);
printf(
"[BENCHMARK] batch_size %d input_length %d output_length %d latency(ms) %.2f tokensPerSec "
"%.2f\n",
batchSize, maxInputLength, maxNewTokens, averageLatency, tokensPerSec);
}
}
catch (std::runtime_error& e)
@ -154,8 +160,9 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
if (worldConfig.getRank() == 0)
{
printf("%s", e.what());
printf("[BENCHMARK] batch_size %d input_length %d output_length %d latency(ms) N/A\n", batchSize,
maxInputLength, maxNewTokens);
printf(
"[BENCHMARK] batch_size %d input_length %d output_length %d latency(ms) N/A tokensPerSec N/A\n",
batchSize, maxInputLength, maxNewTokens);
}
continue;
}
@ -177,6 +184,8 @@ int main(int argc, char* argv[])
"Specify batch size(s) you want to benchmark. Multiple batch sizes can be separated by \";\", example: "
"\"1;8;64\".",
cxxopts::value<std::string>()->default_value("8"));
options.add_options()(
"beam_width", "Specify beam width you want to benchmark.", cxxopts::value<int>()->default_value("1"));
options.add_options()("input_output_len",
"Specify input-output length(s) you want to benchmark. Multiple input lengths can be separated by \";\", "
"example: \"60,20;128,20\".",
@ -190,8 +199,12 @@ int main(int argc, char* argv[])
cxxopts::value<int>()->default_value("10"));
options.add_options()("duration", "Minimal duration of iterations to measure in seconds.",
cxxopts::value<int>()->default_value("60"));
options.add_options()(
"num_micro_batches", "Number of micro batches if enabling pipeline parallelism.", cxxopts::value<int>());
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
options.add_options()(
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
options.add_options()("enable_cuda_graph", "Execute GPT session with CUDA graph.");
@ -220,6 +233,9 @@ int main(int argc, char* argv[])
batchSizes.push_back(std::stoi(token));
}
// Argument: beam width
auto const beamWidth = result["beam_width"].as<int>();
// Argument: Input-output lengths
std::istringstream ssInOutLenArg;
ssInOutLenArg.str(result["input_output_len"].as<std::string>());
@ -264,11 +280,21 @@ int main(int argc, char* argv[])
return 1;
}
GptSession::Config sessionConfig{0, 0, 0};
// Argument: Number of micro batches
std::optional<SizeType> numMicroBatches{std::nullopt};
if (result.count("num_micro_batches"))
{
numMicroBatches = result["num_micro_batches"].as<int>();
sessionConfig.numMicroBatches = result["num_micro_batches"].as<int>();
}
// Argument: Max tokens in paged K-V Cache
if (result.count("max_tokens_in_paged_kvcache"))
{
sessionConfig.kvCacheConfig.maxTokens = result["max_tokens_in_paged_kvcache"].as<int>();
}
// Argument: K-V Cache Free Gpu Mem Fraction
if (result.count("kv_cache_free_gpu_mem_fraction"))
{
sessionConfig.kvCacheConfig.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
}
// Argument: Enable CUDA graph
@ -279,8 +305,8 @@ int main(int argc, char* argv[])
try
{
benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes,
inOutLen, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(),
numMicroBatches, enableCudaGraph);
beamWidth, inOutLen, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(),
result["duration"].as<int>(), sessionConfig, enableCudaGraph);
}
catch (const std::exception& e)
{
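
Condensing the control-flow change in this benchmark into one sketch (pieced together from the hunks above, not a verified build): the setup parameters move into GptSession::Config, and a fresh session is constructed per input/output length pair instead of calling setup() on a single session.

GptSession::Config sessionConfig{0, 0, 0};   // maxBatchSize, maxBeamWidth, maxSequenceLength filled in below
sessionConfig.maxBatchSize = *std::max_element(batchSizes.begin(), batchSizes.end());
sessionConfig.maxBeamWidth = beamWidth;      // new --beam_width option, default 1
sessionConfig.decoderPerRequest = false;
sessionConfig.cudaGraphMode = enableCudaGraph;

for (auto const& inOut : inOutLen)
{
    sessionConfig.maxSequenceLength = inOut[0] + inOut[1];   // maxInputLength + maxNewTokens
    GptSession session{sessionConfig, modelConfig, worldConfig, enginePath.string(), logger};
    // ... warm-up and timed iterations call session.generate(...) exactly as before ...
}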


@ -0,0 +1,45 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/runtime/common.h"
#include <optional>
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KvCacheConfig
{
public:
using SizeType = tensorrt_llm::runtime::SizeType;
explicit KvCacheConfig(
std::optional<SizeType> maxTokens = std::nullopt, std::optional<float> freeGpuMemoryFraction = std::nullopt)
: maxTokens{maxTokens}
, freeGpuMemoryFraction{freeGpuMemoryFraction}
{
}
std::optional<SizeType> maxTokens;
std::optional<float> freeGpuMemoryFraction;
static constexpr auto kDefaultGpuMemFraction = 0.85f;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
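
A short usage sketch for this new header (example values are made up; how maxTokens and freeGpuMemoryFraction interact is decided by KVCacheManager::getMaxNumTokens, declared later in this commit, not by the config itself):

using tensorrt_llm::batch_manager::kv_cache_manager::KvCacheConfig;

KvCacheConfig unset{};                        // leave both knobs to the consumer's defaults
KvCacheConfig capped{/*maxTokens=*/8192};     // hard cap on tokens held in the paged KV cache
KvCacheConfig byMemory{std::nullopt, 0.9f};   // size the cache from a fraction of free GPU memory

The header also ships kDefaultGpuMemFraction (0.85f) as a constant that consumers can fall back to when neither field is set.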


@ -16,6 +16,7 @@
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"
@ -266,10 +267,11 @@ public:
void schedulingRemoveSequence(SizeType batchSlotIdx);
void getBlockPointersOfBatch(runtime::ITensor::SharedPtr dstPointers, SizeType batchSize, SizeType beamWidth) const;
void getBlockPointersOfBatch(
runtime::ITensor& dstPointers, SizeType firstBatchSlotIdx, SizeType batchSize, SizeType beamWidth) const;
void copyBlockPointers(runtime::ITensor::SharedPtr dstPointers, SizeType dstSlotOffset, SizeType batchSlotIdx,
SizeType beamWidth) const;
void copyBlockPointers(
runtime::ITensor& dstPointers, SizeType dstSlotOffset, SizeType batchSlotIdx, SizeType beamWidth) const;
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
[[nodiscard]] static SizeType constexpr calculatePageSize(tensorrt_llm::runtime::GptModelConfig const& modelConfig)
@ -285,6 +287,10 @@ public:
* modelConfig.getSizePerHead();
}
[[nodiscard]] static SizeType getMaxNumTokens(KvCacheConfig const& config, nvinfer1::DataType dtype,
tensorrt_llm::runtime::GptModelConfig const& modelConfig,
tensorrt_llm::runtime::WorldConfig const& worldConfig);
private:
void resetBlockPointers(SizeType batchSlotIdx, SizeType beamWidth);


@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/runtime/common.h"
#include <optional>
@ -26,51 +27,22 @@ namespace tensorrt_llm::batch_manager
class TrtGptModelOptionalParams
{
using KvCacheConfig = kv_cache_manager::KvCacheConfig;
public:
using SizeType = tensorrt_llm::runtime::SizeType;
TrtGptModelOptionalParams()
: mMaxNumSequences(std::nullopt)
, mMaxTokensInPagedKvCache(std::nullopt)
, mKvCacheFreeGpuMemFraction(std::nullopt)
, mEnableTrtOverlap(std::nullopt)
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true)
: kvCacheConfig{kvCacheConfig}
, maxNumSequences{maxNumSequences}
, enableTrtOverlap{enableTrtOverlap}
{
}
TrtGptModelOptionalParams(std::optional<SizeType> maxNumSequences, std::optional<SizeType> maxTokensInPagedKvCache,
std::optional<float> kvCacheFreeGpuMemFraction, std::optional<bool> enableTrtOverlap)
: mMaxNumSequences(maxNumSequences)
, mMaxTokensInPagedKvCache(maxTokensInPagedKvCache)
, mKvCacheFreeGpuMemFraction(kvCacheFreeGpuMemFraction)
, mEnableTrtOverlap(enableTrtOverlap)
{
}
[[nodiscard]] std::optional<SizeType> getMaxTokensInPagedKvCache() const
{
return mMaxTokensInPagedKvCache;
}
[[nodiscard]] std::optional<float> getKvCacheFreeGpuMemFraction() const
{
return mKvCacheFreeGpuMemFraction;
}
[[nodiscard]] std::optional<float> getMaxNumSequences() const
{
return mMaxNumSequences;
}
[[nodiscard]] std::optional<bool> getEnableTrtOverlap() const
{
return mEnableTrtOverlap;
}
private:
std::optional<SizeType> mMaxNumSequences;
std::optional<SizeType> mMaxTokensInPagedKvCache;
std::optional<float> mKvCacheFreeGpuMemFraction;
std::optional<bool> mEnableTrtOverlap;
KvCacheConfig kvCacheConfig;
std::optional<SizeType> maxNumSequences;
bool enableTrtOverlap;
};
} // namespace tensorrt_llm::batch_manager
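
With the getters gone, the class behaves like a plain options struct; a small construction sketch using only the names visible in this header (values are illustrative):

using tensorrt_llm::batch_manager::TrtGptModelOptionalParams;
using tensorrt_llm::batch_manager::kv_cache_manager::KvCacheConfig;

// Everything the batch manager previously took as four separate optionals now travels together.
TrtGptModelOptionalParams params{
    KvCacheConfig{/*maxTokens=*/4096, /*freeGpuMemoryFraction=*/std::nullopt},
    /*maxNumSequences=*/64,
    /*enableTrtOverlap=*/true};

// Fields are public, so callers may also default-construct and assign them individually,
// as the gptManagerBenchmark changes earlier in this commit do.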


@ -59,7 +59,7 @@ public:
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
bool isFinishedSync() override;
void forwardSync() override;
//! @return [batchSize], indicators of finished requests
[[nodiscard]] std::vector<bool> getFinished() const override
@ -83,13 +83,13 @@ public:
return ITensor::slice(mJointDecodingOutput->ids, 0, mActualBatchSize);
}
//! Execute postProcessRequest and returns OutputIds for request `batchIdx`.
//! Result will only be available after event returned
//! @brief Gather final beam search results for request `batchIdx`.
//! Result will only be available after event returned.
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
//! padding for request `batchIdx`, on gpu
[[nodiscard]] std::tuple<CudaEvent, TensorPtr> getFinalOutputIds(SizeType batchIdx) const override;
//! Execute postProcessRequest and returns OutputIds.
//! @brief Gather final beam search results for all requests.
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
//! ids without padding, on gpu
[[nodiscard]] TensorPtr getFinalOutputIds() const override;
@ -138,7 +138,7 @@ public:
}
private:
//! @brief Gather final results for request `batchIdx`
//! @brief Gather final beam search results for request `batchIdx`.
CudaEvent postProcessRequest(SizeType batchIdx) const;
private:


@ -16,6 +16,7 @@
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
@ -23,17 +24,22 @@
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/ipcUtils.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <NvInferRuntime.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>
namespace tensorrt_llm::batch_manager
{
class TrtGptModelV1;
}
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManager;
@ -54,21 +60,47 @@ class RuntimeBuffers;
class GptSession
{
using KvCacheManager = batch_manager::kv_cache_manager::KVCacheManager;
using KvCacheConfig = batch_manager::kv_cache_manager::KvCacheConfig;
public:
using LoggerPtr = std::shared_ptr<nvinfer1::ILogger>;
GptSession(GptModelConfig const& modelConfig, WorldConfig const& worldConfig, void const* engineBuffer,
std::size_t engineSize, LoggerPtr logger = nullptr);
//! @brief Configuration for session execution and buffer sizes.
//! `generate` may be called with batch size and beam width smaller than the configured parameters.
//! @details `maxBatchSize` will be divided by the number of micro batches to initialize each batch buffer.
class Config
{
public:
Config(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength)
: maxBatchSize{maxBatchSize}
, maxBeamWidth{maxBeamWidth}
, maxSequenceLength{maxSequenceLength}
{
}
GptSession(GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
SizeType maxBatchSize;
SizeType maxBeamWidth;
SizeType maxSequenceLength;
bool decoderPerRequest{false};
bool cudaGraphMode{false};
KvCacheConfig kvCacheConfig{};
std::optional<SizeType> numMicroBatches = std::nullopt;
};
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr);
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::vector<uint8_t> const& engineBuffer, LoggerPtr logger = nullptr)
: GptSession(modelConfig, worldConfig, engineBuffer.data(), engineBuffer.size(), logger)
: GptSession(
sessionConfig, modelConfig, worldConfig, engineBuffer.data(), engineBuffer.size(), std::move(logger))
{
}
GptSession(GptModelConfig const& modelConfig, WorldConfig const& worldConfig, std::string const& engineFile,
LoggerPtr logger = nullptr)
: GptSession(modelConfig, worldConfig, utils::loadEngine(engineFile), logger)
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::string const& engineFile, LoggerPtr logger = nullptr)
: GptSession(sessionConfig, modelConfig, worldConfig, utils::loadEngine(engineFile), std::move(logger))
{
}
@ -91,64 +123,45 @@ public:
return mDevice;
}
[[nodiscard]] bool isCudaGraphMode() const noexcept
{
return mCudaGraphMode;
}
void setCudaGraphMode(bool value)
{
mCudaGraphMode = value;
}
//! @brief Initialize buffers for the given sizes.
//! `generate` may be called with batch size and beam width smaller than the setup parameters.
//! @details `maxBatchSize` will be divided by the number of micro batches to initialize each batch buffer.
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, bool decoderPerRequest,
std::optional<SizeType> maxTokensInPagedKvCache = std::nullopt,
std::optional<SizeType> numMicroBatches = std::nullopt);
void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
if (mNumMicroBatches == 1)
generateSingleBatch(outputs, inputs, samplingConfig);
else
generateMultiBatch(outputs, inputs, samplingConfig);
}
void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig);
private:
void generateSingleBatch(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig);
[[nodiscard]] bool useCudaGraphs()
{
return !mCudaGraphInstances.empty();
}
void generateMultiBatch(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig);
void generateBatched(GenerationOutput& outputs, std::vector<GenerationInput> const& microBatches,
SamplingConfig const& samplingConfig);
using KvCacheManager = batch_manager::kv_cache_manager::KVCacheManager;
void setup(Config const& sessionConfig);
void createContexts(SizeType numMicroBatches);
void createContexts(SizeType numMicroBatches, bool useCudaGraphs);
void createBuffers(SizeType numMicroBatches);
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
void createKvCacheManagers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
SizeType numMicroBatches, std::optional<SizeType> maxTokensInPagedKvCache);
void createKvCacheManager(
SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, KvCacheConfig const& config);
void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);
//! @brief Execute decoder on last PP rank, receive decoder output on other PP ranks.
void decoderStepAsync(
ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, SizeType microBatchId);
void decoderStepAsync(SizeType decoderStep, SizeType microBatchId);
//! @brief Synchronize with the decoder and return the `shouldStop` flag.
bool shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType microBatchId);
//! @brief Collect final output ids on last PP rank and send them to first PP rank.
//! @details Receives are asynchronous on host, so synchronization is required before access.
void finalizeOutputIds(ITensor& outputIds, SizeType microBatchId);
void finalizeOutputIds(SizeType microBatchId);
void kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId);
void kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId, SizeType firstBatchIdx);
ITensor::SharedPtr initNewTokens(
GenerationInput const& inputs, SamplingConfig const& samplingConfig, SizeType microBatchId);
std::function<void(SizeType microBatchId, SizeType step, bool finished)> createOnTokenGeneratedCallback(
GenerationOutput& outputs, SizeType numMicroBatches);
class CudaGraphExecutor
{
public:
@ -180,10 +193,11 @@ private:
bool update(cudaGraph_t const& graph);
void uploadToStream(CudaStream const& stream);
using cudaGraphExecPtr = cudaGraphExec_t;
cudaGraphExecPtr mInstance;
cudaGraphExec_t mInstance;
};
friend class batch_manager::TrtGptModelV1;
private:
GptModelConfig const mModelConfig;
WorldConfig const mWorldConfig;
@ -196,17 +210,17 @@ private:
LoggerPtr mLogger;
std::shared_ptr<TllmRuntime> mRuntime;
std::shared_ptr<KvCacheManager> mKvCacheManager;
SizeType mNumMicroBatches;
// for each micro batch
std::vector<std::shared_ptr<IStatefulGptDecoder>> mDecoders;
std::vector<std::shared_ptr<RuntimeBuffers>> mBuffers;
std::vector<std::shared_ptr<KvCacheManager>> mKvCacheManagers;
std::vector<CudaEvent> mReceivedEvents;
bool mCudaGraphMode{false};
// ping-pong instances
std::array<CudaGraphExecutor, 2> mCudaGraphInstances;
std::vector<CudaGraphExecutor> mCudaGraphInstances;
};
} // namespace tensorrt_llm::runtime
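
Summarizing the session API change in one place, a before/after sketch based only on the declarations in this header (modelConfig, worldConfig, enginePath, logger, outputs, inputs and samplingConfig are assumed to exist; sizes are illustrative):

// Before: construct first, then size buffers with setup(...).
// GptSession session{modelConfig, worldConfig, enginePath, logger};
// session.setup(maxBatchSize, maxBeamWidth, maxSequenceLength, /*decoderPerRequest=*/false);

// After: all sizing lives in GptSession::Config and is applied at construction time.
GptSession::Config config{/*maxBatchSize=*/8, /*maxBeamWidth=*/4, /*maxSequenceLength=*/1024};
config.kvCacheConfig.freeGpuMemoryFraction = 0.85f;   // optional paged KV cache sizing
config.numMicroBatches = 2;                           // optional, for pipeline parallelism
GptSession session{config, modelConfig, worldConfig, enginePath, logger};

// generate() now dispatches internally to single or multi micro-batch execution.
session.generate(outputs, inputs, samplingConfig);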


@ -125,7 +125,7 @@ public:
//! ids without padding for request `batchIdx`, on gpu
virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;
//! Execute postProcessRequest and returns OutputIds for request `batchIdx`.
//! @brief Gather final beam search results for request `batchIdx`.
//! Result will only be available after event returned
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
//! padding for request `batchIdx`, on gpu


@ -83,17 +83,17 @@ public:
//! @brief Run one step for all requests without blocking the host thread.
virtual void forwardAsync(decoder::Output& output, decoder::Input const& input) = 0;
//! @brief Wait for the last call to `forwardAsync` to complete and return whether all sequences have finished.
virtual bool isFinishedSync() = 0;
//! @brief Wait for the last call to `forwardAsync` to complete.
virtual void forwardSync() = 0;
//! @brief Run one step for all requests.
virtual bool forward(decoder::Output& output, decoder::Input const& input)
virtual void forward(decoder::Output& output, decoder::Input const& input)
{
forwardAsync(output, input);
return isFinishedSync();
return forwardSync();
}
//! @brief Gather final results for all requests.
//! @brief Gather final beam search results for all requests.
virtual TensorPtr getFinalOutputIds() const = 0;
//! @returns [batchSize, beamWidth, maxSequenceLength], all token ids, on gpu
@ -105,6 +105,8 @@ public:
//! @returns [1], number of finished sequences, in pinned host memory
virtual TensorPtr getNbFinished() const = 0;
virtual ~IStatefulGptDecoder() = default;
protected:
IStatefulGptDecoder() = default;
};
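
The finish check is now decoupled from synchronization: forwardSync() only waits for the step enqueued by forwardAsync(), and completion is read from the decoder state (for example via getNbFinished()). A minimal sketch of one step under that contract (decoder is assumed to be a concrete implementation such as GptDecoderBatch, with output and input prepared elsewhere):

decoder.forwardAsync(output, input);   // enqueue one decoding step for all requests, non-blocking
// ... overlap host-side work here, e.g. preparing the next step's inputs ...
decoder.forwardSync();                 // block until the enqueued step has completed
// Completion is no longer returned by the sync call; inspect decoder state instead,
// e.g. the pinned-host tensor from getNbFinished().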


@ -93,11 +93,17 @@ public:
return getPipelineParallelRank() == 0;
}
//! \brief Is my rank the last rank in its pipeline?
[[nodiscard]] bool constexpr isLastPipelineParallelRank() const noexcept
{
return getPipelineParallelRank() == getPipelineParallelism() - 1;
}
[[nodiscard]] SizeType constexpr getLastRank() const noexcept
{
return getSize() - 1;
}
[[nodiscard]] std::vector<SizeType> getPipelineParallelGroup() const;
static bool validConfig(nvinfer1::ILogger& logger, SizeType tensorParallelism, SizeType pipelineParallelism);
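
The new helper mirrors isFirstPipelineParallelRank(); a small illustration of how the session code in this commit uses the pair (sketch, not copied verbatim):

if (worldConfig.isLastPipelineParallelRank())
{
    // Only the last pipeline stage sees logits, so only it creates decoders and runs sampling.
}
if (worldConfig.isFirstPipelineParallelRank())
{
    // Only the first stage hands generated tokens back to the caller.
}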


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db7d9dced28c2cab3569073ed5c3fd3b11919df3ef35489f042e5fa6531f1f2f
size 1328386
oid sha256:422df71fccde81a55049fb61996d0b88bbaf1f18866b63c8e73c36b772c2df46
size 1508332


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e661364a37c4b0f1452d71d50ef688a98e27c67798b40210992c0d0b972f0345
size 1336196
oid sha256:0013625bc6b18255f44d6ab38e8ea0bceda6452bddf9df3cf832ad106fc2058d
size 1516676


@ -1,3 +1,3 @@
7b956c958b7097655203190c8484cee7 libtensorrt_llm_batch_manager_static.a
b9695f09d54cbfb340669a07981f4f84 libtensorrt_llm_batch_manager_static.pre_cxx11.a
commit 04aec6adc3db913b6b58f914c70b765a4a745162
bda56cf4ad2242be25115ddecd23e7df libtensorrt_llm_batch_manager_static.a
12d7c8e5b4a018dfd9043fa7db979b5a libtensorrt_llm_batch_manager_static.pre_cxx11.a
7e492cc1057b1091f62d69df81547cb071729e5d commit


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a42b9de39385cd0e80f56b9a26e928029cc13f34c7ce114dc0394f02aa70c336
size 1276192
oid sha256:c5a207480594cb228b7264f28af85b0a820046f64379f11fd7389c701ca5497d
size 1421186


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce20355cb54f0b27a77bdee7bb2d3dfb9f30241ac9a059eb4d319240f8bafd74
size 1266016
oid sha256:80e06e15b9e29ba80c036ba6604a2ce286acb294eddb50015bad53cfdeba4534
size 1423958


@ -44,9 +44,13 @@ namespace kernels
// Use HMMA to compute with FP16/BF16 inputs and FP32 accumulators.
// #define MMHA_USE_HMMA
// Apply the FP8 scaling to Q instead of K.
// Pre-scale Q or P to reduce the number of instructions needed to dequantize the KV cache.
// If you notice a decrease in accuracy when the FP8 KV cache is enabled,
// consider disabling these two flags.
#ifdef ENABLE_FP8
// Apply the FP8 scaling to Q instead of K.
#define MMHA_FP8_SCALE_Q_INSTEAD_OF_K
// Apply the FP8 scaling to P instead of V.
#define MMHA_FP8_SCALE_P_INSTEAD_OF_V
#endif // !defined ENABLE_FP8
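
Following the advice in the comment above, opting out of pre-scaling is a local edit to these defines (a sketch of the suggested opt-out, not a change made by this commit; the kernels must be rebuilt afterwards):

#ifdef ENABLE_FP8
// Disable pre-scaling of Q and P if FP8 KV cache accuracy regresses.
// #define MMHA_FP8_SCALE_Q_INSTEAD_OF_K
// #define MMHA_FP8_SCALE_P_INSTEAD_OF_V
#endif // ENABLE_FP8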
@ -1428,8 +1432,9 @@ __global__ void masked_multihead_attention_kernel(
// Quant/Dequant scales for 8bits kv cache.
using T_scale = typename kv_cache_scale_type_t<T, Tcache>::Type;
T_scale kv_scale_orig_quant;
float kv_scale_quant_orig = (ENABLE_8BITS_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
T_scale kv_scale_orig_quant, kv_scale_quant_orig;
const float kv_scale_quant_orig_f = (ENABLE_8BITS_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
convert_from_float(&kv_scale_quant_orig, kv_scale_quant_orig_f);
convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_CACHE ? params.kv_scale_orig_quant[0] : 1.0f));
// Up to QK_VECS_PER_Dh_MAX threads load Q and K + the bias values for the current timestep.
@ -1640,7 +1645,7 @@ __global__ void masked_multihead_attention_kernel(
zero(scaled_q);
if (is_valid_qk_vec)
{
scaled_q = mul<Qk_vec_k, float, Qk_vec_k>(kv_scale_quant_orig, q);
scaled_q = mul<Qk_vec_k, Tk, Qk_vec_k>(kv_scale_quant_orig, q);
}
reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = scaled_q;
}
@ -1895,7 +1900,8 @@ __global__ void masked_multihead_attention_kernel(
{
if constexpr (ENABLE_8BITS_CACHE)
{
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, kv_scale_quant_orig) * params.inv_sqrt_dh;
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, kv_scale_quant_orig_f)
* params.inv_sqrt_dh;
}
else
{
@ -2011,7 +2017,8 @@ __global__ void masked_multihead_attention_kernel(
{
if constexpr (ENABLE_8BITS_CACHE)
{
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, kv_scale_quant_orig) * params.inv_sqrt_dh;
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, kv_scale_quant_orig_f)
* params.inv_sqrt_dh;
}
else
{
@ -2136,7 +2143,7 @@ __global__ void masked_multihead_attention_kernel(
// Normalize the logits.
#ifdef MMHA_FP8_SCALE_P_INSTEAD_OF_V
float logit_scale = (FP8_KV_CACHE ? kv_scale_quant_orig : 1.0f);
float logit_scale = (FP8_KV_CACHE ? kv_scale_quant_orig_f : 1.0f);
#else
float logit_scale = 1.f;
#endif // MMHA_FP8_SCALE_P_INSTEAD_OF_V
@ -2265,7 +2272,7 @@ __global__ void masked_multihead_attention_kernel(
// Load the logits from shared memory.
// Note that fma will convert 8bit vec to the accumulation data type (float by default).
Logit_value_fma<Tk, V_vec_accum, V_vec_m, INT8_KV_CACHE, FP8_KV_CACHE>(
out, reinterpret_cast<Tk*>(logits_smem + local_time_idx), v_vec, kv_scale_quant_orig, is_mask);
out, reinterpret_cast<Tk*>(logits_smem + local_time_idx), v_vec, kv_scale_quant_orig_f, is_mask);
}
}
@ -2295,7 +2302,7 @@ __global__ void masked_multihead_attention_kernel(
// Load the logits from shared memory.
// Note that fma will convert 8bit vec to the accumulation data type (float by default).
Logit_value_fma<Tk, V_vec_accum, V_vec_m, INT8_KV_CACHE, FP8_KV_CACHE>(
out, reinterpret_cast<Tk*>(logits_smem + local_time_idx), v_vec, kv_scale_quant_orig, false);
out, reinterpret_cast<Tk*>(logits_smem + local_time_idx), v_vec, kv_scale_quant_orig_f, false);
}
}
}


@ -1481,7 +1481,8 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, int64_t b, Float8_ fc)
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b)
{
return Acc{}; // for compile
// This will fail to compile when the multiply operation is not supported.
return Acc(a * b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -2195,6 +2196,15 @@ inline __device__ float4 mul(float4 fa, fp8_4_t b)
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
inline __device__ Float4_ mul(float4 fa, fp8_4_t b)
{
float4 fc = mul<float4, float4, fp8_4_t>(fa, b);
return reinterpret_cast<Float4_&>(fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////


@ -278,6 +278,8 @@ void IGptDecoder::gatherTree(ITensor& finalOutputIds, DecodingOutput const& deco
auto const beamWidth = finalOutputIdsShape.d[1];
auto const maxSeqLength = finalOutputIdsShape.d[2];
TLLM_CHECK_WITH_INFO(beamWidth > 1, "gatherTree is only needed for beam search.");
TLLM_CHECK_WITH_INFO(decodingOutputIdsShape.d[0] == batchSize,
common::fmtstr(
"Decoder batch size (%d) does not match final batch size (%d)", decodingOutputIdsShape.d[0], batchSize));
@ -290,48 +292,41 @@ void IGptDecoder::gatherTree(ITensor& finalOutputIds, DecodingOutput const& deco
auto const& stream = manager.getStream();
if (beamWidth > 1)
{
tensorrt_llm::kernels::invokeInitializeOutput(bufferCast<TokenIdType>(finalOutputIds),
bufferCast<TokenIdType>(*decodingInput.endIds), batchSize * beamWidth, maxSeqLength, stream.get());
sync_check_cuda_error();
tensorrt_llm::kernels::invokeInitializeOutput(bufferCast<TokenIdType>(finalOutputIds),
bufferCast<TokenIdType>(*decodingInput.endIds), batchSize * beamWidth, maxSeqLength, stream.get());
sync_check_cuda_error();
tensorrt_llm::kernels::BeamHypotheses beamHypotheses;
beamHypotheses.sequence_lengths_src = bufferCast<SizeType>(*decodingOutput.lengths);
beamHypotheses.parent_ids_src = bufferCast<TokenIdType>(*decodingOutput.parentIds);
beamHypotheses.output_ids_src = bufferCast<TokenIdType>(*decodingOutput.ids);
beamHypotheses.log_probs_src = nullptr;
beamHypotheses.max_seq_len = maxSeqLength;
beamHypotheses.length_penalties
= nullptr; // TODO (bhsueh) should set length penalties, this should be a gpu tensor When it is set as
// nullptr, the kernel will use default value (1.0f) automatically.
tensorrt_llm::kernels::BeamHypotheses beamHypotheses;
beamHypotheses.sequence_lengths_src = bufferCast<SizeType>(*decodingOutput.lengths);
beamHypotheses.parent_ids_src = bufferCast<TokenIdType>(*decodingOutput.parentIds);
beamHypotheses.output_ids_src = bufferCast<TokenIdType>(*decodingOutput.ids);
beamHypotheses.log_probs_src = nullptr;
beamHypotheses.max_seq_len = maxSeqLength;
beamHypotheses.length_penalties
= nullptr; // TODO (bhsueh) should set length penalties, this should be a gpu tensor When it is set as
// nullptr, the kernel will use default value (1.0f) automatically.
beamHypotheses.output_ids_tgt = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsTgt);
beamHypotheses.sequence_lengths_tgt = bufferCast<SizeType>(*decodingOutput.beamHypotheses.sequenceLengthsTgt);
beamHypotheses.cum_log_probs = bufferCast<float>(*decodingOutput.beamHypotheses.cumLogProbs);
beamHypotheses.normed_scores = bufferCast<float>(*decodingOutput.beamHypotheses.normedScores);
beamHypotheses.log_probs = bufferCast<float>(*decodingOutput.beamHypotheses.logProbs);
beamHypotheses.min_normed_scores = bufferCast<float>(*decodingOutput.beamHypotheses.minNormedScores);
beamHypotheses.num_beams = bufferCast<SizeType>(*decodingOutput.beamHypotheses.numBeams);
beamHypotheses.is_done = bufferCast<bool>(*decodingOutput.beamHypotheses.isDone);
beamHypotheses.input_lengths = bufferCast<SizeType>(*decodingInput.lengths);
beamHypotheses.output_ids_tgt = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsTgt);
beamHypotheses.sequence_lengths_tgt = bufferCast<SizeType>(*decodingOutput.beamHypotheses.sequenceLengthsTgt);
beamHypotheses.cum_log_probs = bufferCast<float>(*decodingOutput.beamHypotheses.cumLogProbs);
beamHypotheses.normed_scores = bufferCast<float>(*decodingOutput.beamHypotheses.normedScores);
beamHypotheses.log_probs = bufferCast<float>(*decodingOutput.beamHypotheses.logProbs);
beamHypotheses.min_normed_scores = bufferCast<float>(*decodingOutput.beamHypotheses.minNormedScores);
beamHypotheses.num_beams = bufferCast<SizeType>(*decodingOutput.beamHypotheses.numBeams);
beamHypotheses.is_done = bufferCast<bool>(*decodingOutput.beamHypotheses.isDone);
beamHypotheses.input_lengths = bufferCast<SizeType>(*decodingInput.lengths);
tensorrt_llm::kernels::invokeInsertUnfinishedPath(beamHypotheses, bufferCast<bool>(*decodingOutput.finished),
bufferCast<float>(*decodingOutput.cumLogProbs), batchSize, beamWidth, stream.get());
sync_check_cuda_error();
tensorrt_llm::kernels::invokeInsertUnfinishedPath(beamHypotheses, bufferCast<bool>(*decodingOutput.finished),
bufferCast<float>(*decodingOutput.cumLogProbs), batchSize, beamWidth, stream.get());
sync_check_cuda_error();
tensorrt_llm::kernels::invokeFinalize(bufferCast<TokenIdType>(finalOutputIds),
bufferCast<SizeType>(*decodingOutput.lengths), bufferCast<float>(*decodingOutput.cumLogProbs),
nullptr, // output_logs
beamHypotheses.output_ids_tgt, beamHypotheses.sequence_lengths_tgt, beamHypotheses.normed_scores,
beamHypotheses.cum_log_probs, beamHypotheses.log_probs, beamHypotheses.num_beams, beamHypotheses.input_lengths,
beamWidth, maxSeqLength, batchSize, stream.get());
sync_check_cuda_error();
tensorrt_llm::kernels::invokeFinalize(bufferCast<TokenIdType>(finalOutputIds),
bufferCast<SizeType>(*decodingOutput.lengths), bufferCast<float>(*decodingOutput.cumLogProbs),
nullptr, // output_logs
beamHypotheses.output_ids_tgt, beamHypotheses.sequence_lengths_tgt, beamHypotheses.normed_scores,
beamHypotheses.cum_log_probs, beamHypotheses.log_probs, beamHypotheses.num_beams,
beamHypotheses.input_lengths, beamWidth, maxSeqLength, batchSize, stream.get());
sync_check_cuda_error();
}
else
{
manager.copy(*decodingOutput.ids, finalOutputIds);
sync_check_cuda_error();
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}


@ -464,16 +464,13 @@ void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
bool GptDecoderBatch::isFinishedSync()
void GptDecoderBatch::forwardSync()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
forwardSync(*mForwardToken);
auto const finished
= std::all_of(mFinished.begin(), mFinished.begin() + mActualBatchSize, [](bool x) { return x; });
// wait for mFinishedSum to be updated
mStream->wait(mForwardEvent);
mForwardEvent.synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return finished;
}
IStatefulGptDecoder::TensorPtr GptDecoderBatch::getFinalOutputIds() const


@ -33,8 +33,7 @@
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <algorithm>
#include <cstdint>
#include <fstream>
#include <limits>
#include <memory>
using namespace tensorrt_llm::runtime;
@ -42,8 +41,8 @@ using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace bmkv = tensorrt_llm::batch_manager::kv_cache_manager;
GptSession::GptSession(GptModelConfig const& modelConfig, WorldConfig const& worldConfig, void const* engineBuffer,
std::size_t engineSize, LoggerPtr logger)
GptSession::GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
void const* engineBuffer, std::size_t engineSize, LoggerPtr logger)
: mModelConfig{modelConfig}
, mWorldConfig{worldConfig}
, mDevice{utils::initDevice(worldConfig)}
@ -61,6 +60,8 @@ GptSession::GptSession(GptModelConfig const& modelConfig, WorldConfig const& wor
}
// TODO compare expected and runtime tensor names?
setup(sessionConfig);
}
nvinfer1::ILogger& GptSession::getLogger() const
@ -73,25 +74,39 @@ BufferManager& GptSession::getBufferManager() const
return mRuntime->getBufferManager();
}
void GptSession::createContexts(SizeType numMicroBatches)
void GptSession::createContexts(SizeType numMicroBatches, bool useCudaGraphs)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
mRuntime->clearContexts();
// Instantiate multiple contexts for flip-flopping
auto const numContextsPerPhase = std::max(2, numMicroBatches);
if (useCudaGraphs)
{
// Instantiate multiple graph instances for flip-flopping
mCudaGraphInstances.resize(2 * numMicroBatches);
}
auto const numProfiles = mRuntime->getNbProfiles();
TLLM_CHECK_WITH_INFO(
numProfiles == 1 || numProfiles == 2, "GPT only expects one optimization profile or two optimization profiles");
auto constexpr ctxContextId = 0;
auto constexpr genContextId = 1;
if (numProfiles == 2)
{
for (auto i = 0; i < numContextsPerPhase; ++i)
auto constexpr ctxContextId = 0;
auto constexpr genContextId = 1;
// Instantiate 2 contexts for flip-flopping
for (auto i = 0; i < 2 * numMicroBatches; ++i)
mRuntime->addContext(genContextId);
// Instantiate 1 context for context phase
for (auto i = 0; i < numMicroBatches; ++i)
mRuntime->addContext(ctxContextId);
}
else
{
auto constexpr contextId = 0;
// Instantiate 2 contexts for flip-flopping
for (auto i = 0; i < 2 * numMicroBatches; ++i)
mRuntime->addContext(contextId);
}
for (auto i = 0; i < numContextsPerPhase; ++i)
mRuntime->addContext(ctxContextId);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
@ -130,8 +145,8 @@ void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::createKvCacheManagers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
SizeType numMicroBatches, std::optional<SizeType> maxTokensInPagedKvCache)
void GptSession::createKvCacheManager(
SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, KvCacheConfig const& config)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism());
@ -140,11 +155,6 @@ void GptSession::createKvCacheManagers(SizeType batchSize, SizeType beamWidth, S
auto const hiddenSize = mModelConfig.getHiddenSize();
auto const tokensPerBlock = mModelConfig.getTokensPerBlock();
auto const maxBlocksPerSeq = tc::divUp(maxSequenceLength, tokensPerBlock);
auto const maxNumTokens
= maxTokensInPagedKvCache.value_or(batchSize * beamWidth * maxBlocksPerSeq * tokensPerBlock);
auto const maxNumBlocks = tc::divUp(maxNumTokens, tokensPerBlock);
nvinfer1::DataType kvDtype;
if (mModelConfig.getQuantMode().hasFp8KvCache())
{
@ -159,14 +169,13 @@ void GptSession::createKvCacheManagers(SizeType batchSize, SizeType beamWidth, S
kvDtype = mModelConfig.getDataType();
}
mKvCacheManagers.clear();
auto const maxNumTokens = bmkv::KVCacheManager::getMaxNumTokens(config, kvDtype, mModelConfig, mWorldConfig);
TLLM_LOG_INFO("Using %d tokens in paged KV cache.", maxNumTokens);
auto const maxNumBlocks = tc::ceilDiv(maxNumTokens, tokensPerBlock);
auto const maxBlocksPerSeq = tc::ceilDiv(maxSequenceLength, tokensPerBlock);
for (SizeType i = 0; i < numMicroBatches; ++i)
{
mKvCacheManagers.emplace_back(
std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbHeads, nbKvHeads, hiddenSize, tokensPerBlock,
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, kvDtype, mRuntime->getStreamPtr()));
}
mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbHeads, nbKvHeads, hiddenSize,
tokensPerBlock, maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, kvDtype, mRuntime->getStreamPtr());
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
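
For intuition about the sizing above, a tiny worked example with made-up numbers (the real maxNumTokens comes from KVCacheManager::getMaxNumTokens):

// tokensPerBlock    = 64
// maxNumTokens      = 10000                      // e.g. derived from free GPU memory
// maxNumBlocks      = ceilDiv(10000, 64) = 157   // blocks allocated in the pool
// maxSequenceLength = 900
// maxBlocksPerSeq   = ceilDiv(900, 64)   = 15    // block table entries reserved per sequence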
@ -183,14 +192,19 @@ void GptSession::createCustomAllReduceWorkspace(
}
}
void GptSession::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, bool decoderPerRequest,
std::optional<SizeType> maxTokensInPagedKvCache, std::optional<SizeType> numMicroBatches)
void GptSession::setup(Config const& sessionConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
if (numMicroBatches)
mNumMicroBatches = numMicroBatches.value();
createContexts(mNumMicroBatches);
mCudaGraphMode = sessionConfig.cudaGraphMode;
auto const maxBatchSize = sessionConfig.maxBatchSize;
auto const maxBeamWidth = sessionConfig.maxBeamWidth;
auto const maxSequenceLength = sessionConfig.maxSequenceLength;
if (sessionConfig.numMicroBatches)
mNumMicroBatches = sessionConfig.numMicroBatches.value();
createContexts(mNumMicroBatches, sessionConfig.cudaGraphMode);
createBuffers(mNumMicroBatches);
auto const microBatchSize = tc::ceilDiv(maxBatchSize, mNumMicroBatches);
@ -202,18 +216,17 @@ void GptSession::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType ma
if (mModelConfig.usePagedKvCache())
{
createKvCacheManagers(
microBatchSize, maxBeamWidth, maxSequenceLength, mNumMicroBatches, maxTokensInPagedKvCache);
createKvCacheManager(maxBatchSize, maxBeamWidth, maxSequenceLength, sessionConfig.kvCacheConfig);
}
if (mWorldConfig.isLastPipelineParallelRank())
{
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
createDecoders(
microBatchSize, maxBeamWidth, maxSequenceLength, logitsType, decoderPerRequest, mNumMicroBatches);
createDecoders(microBatchSize, maxBeamWidth, maxSequenceLength, logitsType, sessionConfig.decoderPerRequest,
mNumMicroBatches);
}
if (mWorldConfig.isPipelineParallel())
if (mWorldConfig.isPipelineParallel() || mNumMicroBatches > 1)
{
mReceivedEvents.clear();
for (SizeType i = 0; i < mNumMicroBatches; ++i)
@ -234,173 +247,18 @@ void GptSession::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType ma
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::generateSingleBatch(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(inputs.packed == mModelConfig.usePackedInput(),
"The chosen model requires a packed input tensor (did you set packed?).");
auto const& inputLengths = inputs.lengths;
TLLM_CHECK_WITH_INFO(inputLengths->getShape().nbDims == 1, "Input lengths tensor must be one-dimensional.");
auto constexpr microBatchId = 0;
auto& manager = mRuntime->getBufferManager();
// Initialize and reshape buffers
auto& buffers = *mBuffers.at(microBatchId);
TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
buffers.initContextLengths(inputLengths, manager);
auto const generationConfig = RuntimeBuffers::GenerationConfig::fromInput(*inputs.ids, *buffers.contextLengthsHost,
inputs.packed, samplingConfig.beamWidth, mDecoderMaxSequenceLength, inputs.maxNewTokens);
auto const batchSize = generationConfig.batchSize;
auto const beamWidth = generationConfig.beamWidth;
auto const maxInputLength = generationConfig.maxInputLength;
auto const maxNewTokens = generationConfig.maxNewTokens;
if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.computeContextLogits())
{
auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());
TLLM_CHECK_WITH_INFO(outputs.contextLogits,
"outputs.contextLogits is nullptr. It must be allocated when computeContextLogits() is enabled.");
outputs.contextLogits->reshape(ITensor::makeShape({batchSize, maxInputLength, vocabSizePadded}));
auto const contextLogitsShape = outputs.contextLogits->getShape();
TLLM_CHECK_WITH_INFO(contextLogitsShape.d[0] == batchSize, "Invalid dim[0]");
TLLM_CHECK_WITH_INFO(contextLogitsShape.d[1] == maxInputLength, "Invalid dim[1]");
TLLM_CHECK_WITH_INFO(contextLogitsShape.d[2] == vocabSizePadded, "Invalid dim[2]");
buffers.logits = outputs.contextLogits;
}
buffers.reshape(generationConfig, mModelConfig, mWorldConfig);
kvCacheAddSequences(beamWidth, microBatchId);
ITensor::SharedPtr newTokens{initNewTokens(inputs, samplingConfig, microBatchId)};
auto& onTokenGenerated = outputs.onTokenGenerated;
outputs.ids->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));
outputs.lengths->reshape(ITensor::makeShape({batchSize, beamWidth}));
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
RuntimeBuffers::TensorMap inputBuffers[2];
RuntimeBuffers::TensorMap outputBuffers[2];
for (SizeType step = 0; step < maxNewTokens; ++step)
{
auto const contextId = step % 2;
if (step == 0)
{
SizeType const contextIdForContextPhase
= mRuntime->getNbProfiles() == 2 ? mRuntime->getNbContexts() / 2 : 0;
buffers.prepareContextStep(
inputs.ids, inputs.padId, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextIdForContextPhase, inputBuffers[contextId]);
mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffers[contextId]);
if (isCudaGraphMode())
{
for (auto& instance : mCudaGraphInstances)
{
instance.clear();
}
}
TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine in context phase failed!");
}
else
{
if (isCudaGraphMode() && mCudaGraphInstances[contextId].hasInstance())
{
mCudaGraphInstances[contextId].launch(mRuntime->getStream());
}
else
{
TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextId), "Executing TRT engine in generation phase failed!");
}
}
sync_check_cuda_error();
if (step == 0)
{
buffers.postContextStep(manager, generationConfig, mModelConfig, mWorldConfig);
}
std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);
if (step < maxNewTokens - 1) // this is not the last step
{ // preparing the next step
auto const nextStep = step + 1;
auto const nextContextId = nextStep % 2;
auto nextInputIds = buffers.prepareNextStep(
step, newTokens, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(inputBuffers[nextContextId], outputBuffers[nextContextId], nextStep, nextInputIds,
mModelConfig, mWorldConfig);
mRuntime->setInputTensors(nextContextId, inputBuffers[nextContextId]);
mRuntime->setOutputTensors(nextContextId, outputBuffers[nextContextId]);
if (isCudaGraphMode())
{
mCudaGraphInstances[nextContextId].prepareNextGraph(*mRuntime, nextContextId);
}
}
sync_check_cuda_error();
// FIXME: this synchronize is important to get logits right
// manager.getStream().synchronize();
decoderStepAsync(outputs.ids, newTokens, maxInputLength + step, microBatchId);
auto const shouldStop = shouldStopSync(batchSize, beamWidth, microBatchId);
if (mWorldConfig.isFirstPipelineParallelRank())
{
if (onTokenGenerated)
{
// TODO use getNewTokens(), remove step from Callback?
ITensor::SharedPtr outputIds
= mWorldConfig.isPipelineParallel() ? outputs.ids : mDecoders.at(microBatchId)->getOutputIds();
onTokenGenerated(outputIds, step, shouldStop || step == maxNewTokens - 1);
}
}
if (shouldStop)
{
mLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, "GPT decoding finished early");
break;
}
}
if (mModelConfig.usePagedKvCache())
{
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
kvCacheManager->removeSequence(batchIdx);
}
}
finalizeOutputIds(*outputs.ids, microBatchId);
manager.copy(*buffers.sequenceLengths, *outputs.lengths);
manager.getStream().synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId)
void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId, SizeType firstBatchIdx)
{
if (mModelConfig.usePagedKvCache())
{
auto& kvCacheManager = mKvCacheManagers.at(microBatchId);
TLLM_CHECK(kvCacheManager);
TLLM_CHECK(mKvCacheManager);
auto contextLengthsHost = mBuffers.at(microBatchId)->contextLengthsHost;
TLLM_CHECK(contextLengthsHost);
auto const contextLengthsPtr = bufferCast<SizeType const>(*contextLengthsHost);
auto const contextLengthsSize = static_cast<SizeType>(contextLengthsHost->getSize());
for (SizeType batchIdx = 0; batchIdx < contextLengthsSize; ++batchIdx)
for (SizeType batchIdx = firstBatchIdx; batchIdx < firstBatchIdx + contextLengthsSize; ++batchIdx)
{
kvCacheManager->addSequence(batchIdx, contextLengthsPtr[batchIdx], beamWidth);
mKvCacheManager->addSequence(batchIdx, contextLengthsPtr[batchIdx], beamWidth);
}
}
}
@ -483,8 +341,8 @@ std::vector<GenerationInput> splitInputs(
return inputBatches;
}
void updateOutputIds(
ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, CudaStream const& stream)
void updateOutputIds(ITensor::SharedPtr const& outputIds, ITensor::SharedPtr const& newTokens, SizeType decoderStep,
CudaStream const& stream)
{ // assemble outputIds of all micro batches
auto const& newTokensShape = newTokens->getShape();
auto newTokensView = ITensor::view(newTokens, ITensor::makeShape({1, newTokensShape.d[0] * newTokensShape.d[1]}));
@ -492,10 +350,11 @@ void updateOutputIds(
auto outputIdsView = ITensor::view(
outputIds, ITensor::makeShape({outputIdsShape.d[0] * outputIdsShape.d[1], outputIdsShape.d[2]}));
kernels::invokeTransposeWithOutputOffset(*outputIdsView, *newTokensView, decoderStep, stream);
sync_check_cuda_error();
}
} // namespace
void GptSession::generateMultiBatch(
void GptSession::generate(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@ -511,173 +370,241 @@ void GptSession::generateMultiBatch(
auto const beamWidth = samplingConfig.beamWidth;
outputs.ids->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));
outputs.lengths->reshape(ITensor::makeShape({batchSize, beamWidth}));
auto& onTokenGenerated = outputs.onTokenGenerated;
if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.computeContextLogits())
{
TLLM_CHECK_WITH_INFO(outputs.contextLogits,
"outputs.contextLogits is nullptr. It must be allocated when computeContextLogits() is enabled.");
auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());
auto const inputLengthsHost = manager.copyFrom(*inputLengths, MemoryType::kCPU);
auto const inputLengthsRange = BufferRange<SizeType>(*inputLengthsHost);
auto const maxInputLength = *std::max_element(inputLengthsRange.begin(), inputLengthsRange.end());
outputs.contextLogits->reshape(ITensor::makeShape({batchSize, maxInputLength, vocabSizePadded}));
}
auto const numMicroBatches = std::min(batchSize, mNumMicroBatches);
auto microBatches = splitInputs(inputs, numMicroBatches, manager);
if (numMicroBatches == 1)
{
std::vector<GenerationInput> microBatches{inputs};
generateBatched(outputs, microBatches, samplingConfig);
}
else
{
auto const microBatches = splitInputs(inputs, numMicroBatches, manager);
generateBatched(outputs, microBatches, samplingConfig);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
std::function<void(SizeType microBatchId, SizeType step, bool finished)> GptSession::createOnTokenGeneratedCallback(
GenerationOutput& outputs, SizeType numMicroBatches)
{
if (outputs.onTokenGenerated && mWorldConfig.isFirstPipelineParallelRank())
{
ITensor::SharedPtr outputIds{mWorldConfig.isPipelineParallel() || mNumMicroBatches > 1
? outputs.ids
: mDecoders.front()->getOutputIds()};
auto const lastMicroBatchId = numMicroBatches - 1;
return [onTokenGenerated = outputs.onTokenGenerated, outputIds = std::move(outputIds), lastMicroBatchId](
SizeType microBatchId, SizeType step, bool finished)
{
if (microBatchId == lastMicroBatchId)
onTokenGenerated(outputIds, step, finished);
};
}
else
{
return [](SizeType microBatchId, SizeType step, bool finished) {};
}
}
void GptSession::generateBatched(
GenerationOutput& outputs, std::vector<GenerationInput> const& microBatches, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& manager = mRuntime->getBufferManager();
auto const numMicroBatches = static_cast<SizeType>(microBatches.size());
TLLM_CHECK(numMicroBatches > 0);
TLLM_CHECK(numMicroBatches <= mNumMicroBatches);
SizeType const beamWidth{samplingConfig.beamWidth};
// Initialize and reshape buffers
std::vector<RuntimeBuffers::GenerationConfig> generationConfigs;
std::vector<ITensor::SharedPtr> newTokensPerBatch;
std::vector<ITensor::SharedPtr> outputIdsPerBatch;
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto const& microBatchInputs = microBatches.at(microBatchId);
// Initialize and reshape buffers
auto& buffers = *mBuffers.at(microBatchId);
TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
buffers.initContextLengths(microBatchInputs.lengths, manager);
generationConfigs.emplace_back(RuntimeBuffers::GenerationConfig::fromInput(*microBatchInputs.ids,
*buffers.contextLengthsHost, microBatchInputs.packed, samplingConfig.beamWidth, mDecoderMaxSequenceLength,
microBatchInputs.maxNewTokens));
auto const& generationConfig = generationConfigs.back();
auto const beamWidth = generationConfig.beamWidth;
buffers.reshape(generationConfig, mModelConfig, mWorldConfig);
kvCacheAddSequences(beamWidth, microBatchId);
newTokensPerBatch.emplace_back(initNewTokens(microBatchInputs, samplingConfig, microBatchId));
generationConfigs.emplace_back(
RuntimeBuffers::GenerationConfig::fromInput(*microBatchInputs.ids, *buffers.contextLengthsHost,
microBatchInputs.packed, beamWidth, mDecoderMaxSequenceLength, microBatchInputs.maxNewTokens));
buffers.reshape(generationConfigs.back(), mModelConfig, mWorldConfig);
}
auto maxNewTokens = generationConfigs.front().maxNewTokens;
auto microBatchSize = generationConfigs.front().batchSize;
auto offset = 0;
outputIdsPerBatch.emplace_back(ITensor::slice(outputs.ids, offset, microBatchSize));
offset += microBatchSize;
for (auto microBatchId = 1; microBatchId < numMicroBatches; ++microBatchId)
auto minMaxNewTokens = std::numeric_limits<SizeType>::max();
std::vector<SizeType> microBatchOffsets(1, 0);
microBatchOffsets.reserve(numMicroBatches + 1);
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
maxNewTokens = std::min(maxNewTokens, generationConfigs.at(microBatchId).maxNewTokens);
auto microBatchSize = generationConfigs.at(microBatchId).batchSize;
outputIdsPerBatch.emplace_back(ITensor::slice(outputs.ids, offset, microBatchSize));
offset += microBatchSize;
auto const& generationConfig = generationConfigs.at(microBatchId);
minMaxNewTokens = std::min(minMaxNewTokens, generationConfig.maxNewTokens);
microBatchOffsets.emplace_back(microBatchOffsets.back() + generationConfig.batchSize);
}
// TODO(micro batching) do we need 1 or 2 per micro batch?
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto& buffers = *mBuffers.at(microBatchId);
auto const& generationConfig = generationConfigs.at(microBatchId);
auto const batchOffset = microBatchOffsets.at(microBatchId);
kvCacheAddSequences(beamWidth, microBatchId, batchOffset);
auto const& microBatchInputs = microBatches.at(microBatchId);
buffers.newTokens = initNewTokens(microBatchInputs, samplingConfig, microBatchId);
auto const microBatchSize = generationConfig.batchSize;
buffers.outputIds = ITensor::slice(outputs.ids, batchOffset, microBatchSize);
buffers.outputLengths = ITensor::slice(outputs.lengths, batchOffset, microBatchSize);
if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.computeContextLogits())
{
buffers.logits = ITensor::slice(outputs.contextLogits, batchOffset, microBatchSize);
}
}
// Prepare the onTokenGenerated callback
auto const onTokenGenerated = createOnTokenGeneratedCallback(outputs, numMicroBatches);
if (useCudaGraphs())
{
for (auto& instance : mCudaGraphInstances)
{
instance.clear();
}
}
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;
std::vector<RuntimeBuffers::TensorMap> inputBuffers(numMicroBatches * 2);
std::vector<RuntimeBuffers::TensorMap> outputBuffers(numMicroBatches * 2);
std::vector<bool> microBatchesFinished(numMicroBatches, false);
for (SizeType step = 0; step < maxNewTokens; ++step)
{
if (std::all_of(microBatchesFinished.begin(), microBatchesFinished.end(), [](bool x) { return x; }))
break;
auto notFinished = [&microBatchesFinished]()
{ return std::any_of(microBatchesFinished.begin(), microBatchesFinished.end(), [](bool x) { return !x; }); };
for (SizeType step = 0; step < minMaxNewTokens && notFinished(); ++step)
{
auto const flipFlopId = step % 2;
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto& buffers = *mBuffers.at(microBatchId);
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
auto& newTokens = newTokensPerBatch.at(microBatchId);
auto& generationConfig = generationConfigs.at(microBatchId);
auto& outputIds = outputIdsPerBatch.at(microBatchId);
if (microBatchesFinished.at(microBatchId))
continue;
if (step > 0)
{
auto const microBatchSize = generationConfig.batchSize;
auto const beamWidth = generationConfig.beamWidth;
auto const shouldStop = shouldStopSync(microBatchSize, beamWidth, microBatchId);
auto& buffers = *mBuffers.at(microBatchId);
auto& generationConfig = generationConfigs.at(microBatchId);
if (mWorldConfig.isFirstPipelineParallelRank() && onTokenGenerated
&& microBatchId == numMicroBatches - 1)
auto const contextId = flipFlopId * numMicroBatches + microBatchId;
auto& inputBuffer = inputBuffers[contextId];
auto& outputBuffer = outputBuffers[contextId];
if (step == 0)
{
SizeType const contextIdForContextPhase
= (mRuntime->getNbProfiles() == 2 ? 2 * mNumMicroBatches : 0) + microBatchId;
auto const& microBatchInputs = microBatches.at(microBatchId);
buffers.prepareContextStep(microBatchInputs.ids, microBatchInputs.padId, manager, kvCacheManager,
microBatchOffsets.at(microBatchId), generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffer, outputBuffer, step, microBatchInputs.ids, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextIdForContextPhase, inputBuffer);
mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffer);
TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine in context step failed!");
sync_check_cuda_error();
buffers.postContextStep(manager, generationConfig, mModelConfig, mWorldConfig);
sync_check_cuda_error();
}
else
{
auto nextInputIds = buffers.prepareNextStep(step - 1, manager, kvCacheManager,
microBatchOffsets.at(microBatchId), generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(inputBuffer, outputBuffer, step, nextInputIds, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextId, inputBuffer);
mRuntime->setOutputTensors(contextId, outputBuffer);
if (useCudaGraphs())
{
onTokenGenerated(outputs.ids, step - 1, shouldStop);
mCudaGraphInstances.at(contextId).prepareNextGraph(*mRuntime, contextId);
}
// check decoder result of previous iteration
auto const microBatchSize = generationConfig.batchSize;
auto const shouldStop = shouldStopSync(microBatchSize, beamWidth, microBatchId);
onTokenGenerated(microBatchId, step - 1, shouldStop);
if (shouldStop)
{
mLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, "GPT decoding finished early");
microBatchesFinished.at(microBatchId) = true;
continue;
}
if (useCudaGraphs())
{
auto& cudaGraphInstance = mCudaGraphInstances.at(contextId);
TLLM_CHECK(cudaGraphInstance.hasInstance());
cudaGraphInstance.launch(mRuntime->getStream());
}
else
{
TLLM_CHECK_WITH_INFO(mRuntime->executeContext(contextId),
tc::fmtstr("Executing TRT engine in step %d failed!", step));
}
sync_check_cuda_error();
}
auto const contextId = microBatchId;
if (step == 0)
{
SizeType const contextIdForContextPhase
= contextId + (mRuntime->getNbProfiles() == 2 ? mNumMicroBatches : 0);
auto const& inputs = microBatches.at(microBatchId);
buffers.prepareContextStep(
inputs.ids, inputs.padId, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextIdForContextPhase, inputBuffers[contextId]);
mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffers[contextId]);
TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine failed!");
buffers.postContextStep(manager, generationConfig, mModelConfig, mWorldConfig);
}
else
{
auto nextInputIds = buffers.prepareNextStep(
step - 1, newTokens, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffers[contextId], outputBuffers[contextId], step, nextInputIds, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextId, inputBuffers[contextId]);
mRuntime->setOutputTensors(contextId, outputBuffers[contextId]);
TLLM_CHECK_WITH_INFO(mRuntime->executeContext(contextId), "Executing TRT engine failed!");
}
sync_check_cuda_error();
std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);
auto const maxInputLength = generationConfigs.at(microBatchId).maxInputLength;
auto const decoderStep = maxInputLength + step;
decoderStepAsync(outputIds, newTokens, decoderStep, microBatchId);
if (!mWorldConfig.isPipelineParallel() && mNumMicroBatches > 1)
{
updateOutputIds(outputIds, newTokens, decoderStep, mRuntime->getStream());
}
decoderStepAsync(decoderStep, microBatchId);
}
}
offset = 0;
// TODO(micro batching) move into loop above?
// Collect the results for the last step
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto const& generationConfig = generationConfigs.at(microBatchId);
auto const microBatchSize = generationConfig.batchSize;
auto const shouldStop = shouldStopSync(microBatchSize, beamWidth, microBatchId);
onTokenGenerated(microBatchId, minMaxNewTokens - 1, shouldStop);
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
auto& outputIds = outputIdsPerBatch.at(microBatchId);
// TODO(micro batching) sync receive event
if (mWorldConfig.isFirstPipelineParallelRank() && onTokenGenerated && microBatchId == numMicroBatches - 1)
{
onTokenGenerated(outputs.ids, maxNewTokens - 1, true);
}
auto const firstBatchIdx = microBatchOffsets.at(microBatchId);
if (mModelConfig.usePagedKvCache())
{
for (auto batchIdx = 0; batchIdx < microBatchSize; ++batchIdx)
for (auto batchIdx = firstBatchIdx; batchIdx < firstBatchIdx + microBatchSize; ++batchIdx)
{
kvCacheManager->removeSequence(batchIdx);
}
}
// TODO(micro batching) use mCommStream?
finalizeOutputIds(*outputIds, microBatchId);
auto& buffers = *mBuffers.at(microBatchId);
auto outputLengths = ITensor::slice(outputs.lengths, offset, microBatchSize);
manager.copy(*buffers.sequenceLengths, *outputLengths);
offset += microBatchSize;
if (beamWidth > 1)
finalizeOutputIds(microBatchId);
else if (!mWorldConfig.isPipelineParallel())
manager.copy(*mDecoders.at(microBatchId)->getOutputIds(), *mBuffers.at(microBatchId)->outputIds);
}
manager.getStream().synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::decoderStepAsync(
ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, SizeType microBatchId)
void GptSession::decoderStepAsync(SizeType decoderStep, SizeType microBatchId)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& stream = mRuntime->getStream();
auto& buffers = *mBuffers.at(microBatchId);
auto const& outputIds = buffers.outputIds;
auto const& newTokens = buffers.newTokens;
if (mWorldConfig.isLastPipelineParallelRank())
{
@ -734,6 +661,14 @@ void GptSession::decoderStepAsync(
}
mCommStream->record(mReceivedEvents.at(microBatchId).get());
}
if (!mWorldConfig.isPipelineParallel() && mNumMicroBatches > 1)
{
updateOutputIds(outputIds, newTokens, decoderStep, stream);
stream.record(mReceivedEvents.at(microBatchId).get());
}
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
@ -746,12 +681,18 @@ bool GptSession::shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType
if (mWorldConfig.isLastPipelineParallelRank())
{ // read the Finished flag from the decoder
auto& decoder = *mDecoders.at(microBatchId);
decoder.isFinishedSync();
decoder.forwardSync();
nbFinished = *bufferCast<SizeType>(*decoder.getNbFinished());
if (!mWorldConfig.isPipelineParallel() && mNumMicroBatches > 1)
{
// ensure outputIds have been updated
mReceivedEvents.at(microBatchId).synchronize();
}
}
else
{ // ensure all information has been received
TLLM_CUDA_CHECK(cudaEventSynchronize(mReceivedEvents.at(microBatchId).get()));
mReceivedEvents.at(microBatchId).synchronize();
nbFinished = *bufferCast<SizeType>(*mBuffers.at(microBatchId)->nbFinished);
}
sync_check_cuda_error();
@ -759,10 +700,12 @@ bool GptSession::shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType
return nbFinished == batchSize * beamWidth;
}
void GptSession::finalizeOutputIds(ITensor& outputIds, SizeType microBatchId)
void GptSession::finalizeOutputIds(SizeType microBatchId)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& manager = mRuntime->getBufferManager();
auto& outputIds = *mBuffers.at(microBatchId)->outputIds;
auto& sequenceLengths = *mBuffers.at(microBatchId)->sequenceLengths;
if (mWorldConfig.isPipelineParallel())
{
@ -773,18 +716,20 @@ void GptSession::finalizeOutputIds(ITensor& outputIds, SizeType microBatchId)
{ // send ids from last to first
auto const peer = pipelineGroup.front();
auto const finalOutputIds = mDecoders.at(microBatchId)->getFinalOutputIds();
mPipelineComm->send(
bufferCast<std::int32_t>(*finalOutputIds), finalOutputIds->getSize(), peer, stream, *mLogger);
mPipelineComm->send<TokenIdType>(*finalOutputIds, peer, stream, *mLogger);
mPipelineComm->send<SizeType>(sequenceLengths, peer, stream, *mLogger);
}
else if (mWorldConfig.isFirstPipelineParallelRank())
{ // receive ids from last on first
auto const peer = pipelineGroup.back();
mPipelineComm->receive(bufferCast<std::int32_t>(outputIds), outputIds.getSize(), peer, stream, *mLogger);
mPipelineComm->receive<TokenIdType>(outputIds, peer, stream, *mLogger);
mPipelineComm->receive<SizeType>(sequenceLengths, peer, stream, *mLogger);
}
}
else
{
manager.copy(*mDecoders.at(microBatchId)->getFinalOutputIds(), outputIds);
// sequenceLengths are already updated by decoder
}
sync_check_cuda_error();

View File

@ -1,6 +1,3 @@
//
// Created by martinma on 5/24/23.
//
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -19,9 +16,6 @@
#include "tensorrt_llm/runtime/runtimeBuffers.h"
#include <algorithm>
#include <iostream>
#include "ipcUtils.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/stlUtils.h"
@ -29,6 +23,9 @@
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <algorithm>
#include <iostream>
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
@ -107,7 +104,6 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
}
contextLengthsHost = manager.emptyTensor(MemoryType::kPINNED, nvinfer1::DataType::kINT32);
sequenceLengths = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
lastTokenIds = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
@ -224,7 +220,6 @@ void RuntimeBuffers::reshape(
logits->reshape(ITensor::makeShape({batchSize, 1, vocabSizePadded}));
}
sequenceLengths->reshape(ITensor::makeShape({batchSize}));
lastTokenIds->reshape(ITensor::makeShape({batchSize}));
auto kvCacheShape
@ -317,7 +312,6 @@ void RuntimeBuffers::tile(BufferManager& manager, GenerationConfig const& genera
}
utils::tileBufferReplace(contextLengthsDevice, beamWidth, manager);
utils::tileBufferReplace(sequenceLengths, beamWidth, manager);
if (modelConfig.useGptAttentionPlugin())
{
@ -363,6 +357,10 @@ void RuntimeBuffers::postContextStep(BufferManager& manager, GenerationConfig co
tile(manager, generationConfig, modelConfig, worldConfig);
}
// use output lengths after context step
manager.copy(*contextLengthsDevice, *outputLengths);
sequenceLengths = ITensor::view(outputLengths);
sequenceLengths->reshape(ITensor::makeShape({batchSize * beamWidth}));
// no need to copy data in lastTokenIds because it is overwritten in prepareNextStep
lastTokenIds->reshape(ITensor::makeShape({batchSize * beamWidth}));
@ -377,15 +375,16 @@ void RuntimeBuffers::postContextStep(BufferManager& manager, GenerationConfig co
}
void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType const padId, BufferManager& manager,
KvCacheManager const* kvCacheManager, GenerationConfig const& generationConfig, GptModelConfig const& modelConfig,
WorldConfig const& worldConfig)
KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GenerationConfig const& generationConfig,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& stream = manager.getStream();
SizeType const batchSize = generationConfig.batchSize;
SizeType const maxInputLength = generationConfig.maxInputLength;
manager.copy(*contextLengthsDevice, *sequenceLengths);
// use context lengths only in context step
sequenceLengths = contextLengthsDevice;
if (modelConfig.useGptAttentionPlugin())
{
@ -457,7 +456,8 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
if (modelConfig.useGptAttentionPlugin() && modelConfig.usePagedKvCache())
{
auto constexpr contextBeamWidth = 1;
kvCacheManager->getBlockPointersOfBatch(kvCacheBlockPointersHost, batchSize, contextBeamWidth);
kvCacheManager->getBlockPointersOfBatch(
*kvCacheBlockPointersHost, firstBatchSlotIdx, batchSize, contextBeamWidth);
manager.copy(*kvCacheBlockPointersHost, *kvCacheBlockPointersDevice);
}
@ -475,8 +475,8 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, TensorPtr const& outputIds,
BufferManager& manager, KvCacheManager* kvCacheManager, GenerationConfig const& generationConfig,
RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, BufferManager& manager,
KvCacheManager* kvCacheManager, SizeType firstBatchSlotIdx, GenerationConfig const& generationConfig,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@ -495,7 +495,7 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, T
// batch in first dim
inputShape = ITensor::makeShape({batchSize * beamWidth, 1});
}
auto nextInputIds = outputIds ? ITensor::view(outputIds, inputShape) : TensorPtr{};
auto nextInputIds = newTokens ? ITensor::view(newTokens, inputShape) : TensorPtr{};
if (modelConfig.useGptAttentionPlugin())
{
@ -570,11 +570,11 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, T
if (modelConfig.usePagedKvCache())
{
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
for (auto batchIdx = firstBatchSlotIdx; batchIdx < firstBatchSlotIdx + batchSize; ++batchIdx)
{
kvCacheManager->addToken(batchIdx);
}
kvCacheManager->getBlockPointersOfBatch(kvCacheBlockPointersHost, batchSize, beamWidth);
kvCacheManager->getBlockPointersOfBatch(*kvCacheBlockPointersHost, firstBatchSlotIdx, batchSize, beamWidth);
manager.copy(*kvCacheBlockPointersHost, *kvCacheBlockPointersDevice);
}

View File

@ -16,7 +16,6 @@
#pragma once
#include "ipcUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/iTensor.h"
@ -29,6 +28,7 @@ class KVCacheManager;
namespace tensorrt_llm::runtime
{
class IpcMemory;
class TllmRuntime;
class RuntimeBuffers
@ -59,6 +59,11 @@ public:
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
// References to tmp buffers
TensorPtr newTokens;
TensorPtr outputIds;
TensorPtr outputLengths;
// beam search (shared between engine and decoder)
TensorPtr cacheIndirectionDecoderInput;
TensorPtr cacheIndirectionDecoderOutput;
@ -74,6 +79,7 @@ public:
bool allocated{false};
private:
std::vector<std::shared_ptr<IpcMemory>> mIpcMemoryHandles;
public:
@ -119,10 +125,10 @@ public:
GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
void prepareContextStep(TensorPtr const& inputIds, TokenIdType padId, BufferManager& manager,
KvCacheManager const* kvCacheManager, GenerationConfig const& generationConfig,
KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GenerationConfig const& generationConfig,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
TensorPtr prepareNextStep(SizeType step, TensorPtr const& outputIds, BufferManager& manager,
KvCacheManager* kvCacheManager, GenerationConfig const& generationConfig, GptModelConfig const& modelConfig,
TensorPtr prepareNextStep(SizeType step, BufferManager& manager, KvCacheManager* kvCacheManager,
SizeType firstBatchSlotIdx, GenerationConfig const& generationConfig, GptModelConfig const& modelConfig,
WorldConfig const& worldConfig);
void getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outputBuffers, SizeType step, TensorPtr const& inputIds,

View File

@ -273,18 +273,11 @@ void StatefulGptDecoder::forwardAsync(decoder::Output& output, decoder::Input co
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
bool StatefulGptDecoder::isFinishedSync()
void StatefulGptDecoder::forwardSync()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
mDecodedEvent.synchronize();
auto& dOutput = *mDecodingOutput;
auto finished = mNbSteps >= mMaxNewTokens
// This condition requires the synchronization above
|| *bufferCast<SizeType>(*dOutput.finishedSum) == static_cast<SizeType>(dOutput.finished->getSize());
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return finished;
}
IStatefulGptDecoder::TensorPtr StatefulGptDecoder::getFinalOutputIds() const

View File

@ -47,7 +47,7 @@ public:
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
bool isFinishedSync() override;
void forwardSync() override;
//! @brief Gather final results for all requests.
[[nodiscard]] TensorPtr getFinalOutputIds() const override;

View File

@ -44,35 +44,25 @@ def find_root_dir(start_dir: _tp.Optional[_pl.Path] = None) -> _pl.Path:
return find_dir_containing(("scripts", "examples", "cpp"), start_dir)
def run_tests(cuda_architectures: _tp.Optional[str] = None,
build_dir: _tp.Optional[str] = None,
dist_dir: _tp.Optional[str] = None,
model_cache: _tp.Optional[str] = None,
skip_gptj=False,
skip_llama=False,
skip_chatglm6b=False,
skip_chatglm2_6b=False,
only_fp8=False,
trt_root: _tp.Optional[str] = None) -> None:
root_dir = find_root_dir()
_log.info("Using root directory: %s", str(root_dir))
def run_command(command: _tp.Sequence[str],
cwd: _pl.Path,
*,
shell=False,
env=None) -> None:
_log.info("Running: cd %s && %s", str(cwd), " ".join(command))
_sp.check_call(command, cwd=cwd, shell=shell, env=env)
def run_command(command: _tp.Sequence[str],
*,
cwd=root_dir,
shell=False,
env=None) -> None:
_log.info("Running: cd %s && %s", str(cwd), " ".join(command))
_sp.check_call(command, cwd=cwd, shell=shell, env=env)
python_exe = _sys.executable
def build_trt_llm(python_exe: str,
root_dir: _pl.Path,
build_dir: _pl.Path,
cuda_architectures: _tp.Optional[str] = None,
dist_dir: _tp.Optional[str] = None,
trt_root: _tp.Optional[str] = None):
    # Build the wheel again to work around (WAR) the issue that the "google-tests" target needs the cmake-generated files,
    # which were not packaged when running the build job.
    # Eventually they should be packaged in the build job, and only the tests should run on the test node.
cuda_architectures = cuda_architectures if cuda_architectures is not None else "80"
build_dir = _pl.Path(
build_dir) if build_dir is not None else _pl.Path("cpp") / "build"
dist_dir = _pl.Path(dist_dir) if dist_dir is not None else _pl.Path("build")
build_wheel = [
python_exe, "scripts/build_wheel.py", "--cuda_architectures",
@ -83,95 +73,171 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
if trt_root is not None:
build_wheel += ["--trt_root", str(trt_root)]
run_command(build_wheel)
run_command(build_wheel, cwd=root_dir)
dist_dir = dist_dir if dist_dir.is_absolute() else root_dir / dist_dir
wheels = _gl.glob(str(dist_dir / "tensorrt_llm-*.whl"))
assert len(wheels) > 0, "No wheels found"
install_wheel = [python_exe, "-m", "pip", "install", "--upgrade", *wheels]
run_command(install_wheel)
run_command(install_wheel, cwd=root_dir)
def run_tests(cuda_architectures: _tp.Optional[str] = None,
build_dir: _tp.Optional[str] = None,
dist_dir: _tp.Optional[str] = None,
model_cache: _tp.Optional[str] = None,
skip_gptj=False,
skip_llama=False,
skip_chatglm6b=False,
skip_chatglm2_6b=False,
only_fp8=False,
only_multi_gpu=False,
trt_root: _tp.Optional[str] = None) -> None:
root_dir = find_root_dir()
_log.info("Using root directory: %s", str(root_dir))
python_exe = _sys.executable
build_dir = _pl.Path(
build_dir) if build_dir is not None else _pl.Path("cpp") / "build"
build_trt_llm(python_exe=python_exe,
root_dir=root_dir,
build_dir=build_dir,
cuda_architectures=cuda_architectures,
dist_dir=dist_dir,
trt_root=trt_root)
build_dir = build_dir if build_dir.is_absolute() else root_dir / build_dir
resources_dir = _pl.Path("cpp") / "tests" / "resources"
scripts_dir = resources_dir / "scripts"
model_cache = ["--model_cache", model_cache] if model_cache else []
if not only_multi_gpu:
prepare_all_model_tests(python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir,
model_cache=model_cache,
skip_gptj=skip_gptj,
skip_llama=skip_llama,
skip_chatglm6b=skip_chatglm6b,
skip_chatglm2_6b=skip_chatglm2_6b,
only_fp8=only_fp8)
run_google_tests(build_dir=build_dir,
skip_gptj=skip_gptj,
skip_llama=skip_llama,
skip_chatglm6b=skip_chatglm6b,
skip_chatglm2_6b=skip_chatglm2_6b,
only_fp8=only_fp8)
run_benchmarks(python_exe=python_exe,
root_dir=root_dir,
build_dir=build_dir,
resources_dir=resources_dir)
else:
prepare_multi_gpu_model_tests(python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir,
model_cache=model_cache)
run_multi_gpu_tests(build_dir=build_dir)
def prepare_all_model_tests(python_exe: str,
root_dir: _pl.Path,
resources_dir: _pl.Path,
model_cache: _tp.Optional[str] = None,
skip_gptj=False,
skip_llama=False,
skip_chatglm6b=False,
skip_chatglm2_6b=False,
only_fp8=False):
model_cache_arg = ["--model_cache", model_cache] if model_cache else []
only_fp8_arg = ["--only_fp8"] if only_fp8 else []
gpt_env = {**_os.environ, "PYTHONPATH": "examples/gpt"}
build_gpt_engines = [python_exe,
str(scripts_dir / "build_gpt_engines.py")
] + model_cache
run_command(build_gpt_engines, env=gpt_env)
generate_expected_gpt_output = [
python_exe,
str(scripts_dir / "generate_expected_gpt_output.py")
]
run_command(generate_expected_gpt_output, env=gpt_env)
prepare_model_tests(model_name="gpt",
python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir,
model_cache_arg=model_cache_arg)
if not skip_gptj:
build_gptj_engines = [
python_exe, str(scripts_dir / "build_gptj_engines.py")
] + model_cache + only_fp8_arg
run_command(build_gptj_engines)
gptj_env = {**_os.environ, "PYTHONPATH": "examples/gptj"}
generate_expected_gptj_output = [
python_exe,
str(scripts_dir / "generate_expected_gptj_output.py")
] + only_fp8_arg
run_command(generate_expected_gptj_output, env=gptj_env)
prepare_model_tests(model_name="gptj",
python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir,
model_cache_arg=model_cache_arg,
only_fp8_arg=only_fp8_arg)
else:
_log.info("Skipping GPT-J tests")
if not skip_llama:
build_llama_engines = [
python_exe, str(scripts_dir / "build_llama_engines.py")
] + model_cache
run_command(build_llama_engines)
llama_env = {**_os.environ, "PYTHONPATH": "examples/llama"}
generate_expected_llama_output = [
python_exe,
str(scripts_dir / "generate_expected_llama_output.py")
]
run_command(generate_expected_llama_output, env=llama_env)
prepare_model_tests(model_name="llama",
python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir,
model_cache_arg=model_cache_arg)
else:
_log.info("Skipping Lllama tests")
if not skip_chatglm6b:
build_chatglm6b_engines = [
python_exe,
str(scripts_dir / "build_chatglm6b_engines.py")
]
run_command(build_chatglm6b_engines)
chatglm6b_env = {**_os.environ, "PYTHONPATH": "examples/chatglm6b"}
generate_expected_chatglm6b_output = [
python_exe,
str(scripts_dir / "generate_expected_chatglm6b_output.py")
] # only_fp8 is not supported by ChatGLM-6B now
run_command(generate_expected_chatglm6b_output, env=chatglm6b_env)
prepare_model_tests(model_name="chatglm6b",
python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir)
else:
_log.info("Skipping ChatGLM6B tests")
if not skip_chatglm2_6b:
build_chatglm2_6b_engines = [
python_exe,
str(scripts_dir / "build_chatglm2-6b_engines.py")
]
run_command(build_chatglm2_6b_engines)
chatglm2_6b_env = {**_os.environ, "PYTHONPATH": "examples/chatglm2-6b"}
generate_expected_chatglm2_6b_output = [
python_exe,
str(scripts_dir / "generate_expected_chatglm2-6b_output.py")
] # only_fp8 is not supported by ChatGLM2-6B now
run_command(generate_expected_chatglm2_6b_output, env=chatglm2_6b_env)
prepare_model_tests(model_name="chatglm2-6b",
python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir)
else:
_log.info("Skipping ChatGLM2-6B tests")
build_dir = build_dir if build_dir.is_absolute() else root_dir / build_dir
def prepare_multi_gpu_model_tests(python_exe: str,
root_dir: _pl.Path,
resources_dir: _pl.Path,
model_cache: _tp.Optional[str] = None):
model_cache_arg = ["--model_cache", model_cache] if model_cache else []
only_multi_gpu_arg = ["--only_multi_gpu"]
prepare_model_tests(model_name="llama",
python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir,
model_cache_arg=model_cache_arg,
only_multi_gpu_arg=only_multi_gpu_arg)
def prepare_model_tests(model_name: str,
python_exe: str,
root_dir: _pl.Path,
resources_dir: _pl.Path,
model_cache_arg=[],
only_fp8_arg=[],
only_multi_gpu_arg=[]):
scripts_dir = resources_dir / "scripts"
model_env = {**_os.environ, "PYTHONPATH": f"examples/{model_name}"}
build_engines = [
python_exe,
str(scripts_dir / f"build_{model_name}_engines.py")
] + model_cache_arg + only_fp8_arg + only_multi_gpu_arg
run_command(build_engines, cwd=root_dir, env=model_env)
generate_expected_output = [
python_exe,
str(scripts_dir / f"generate_expected_{model_name}_output.py")
] + only_fp8_arg + only_multi_gpu_arg
if only_multi_gpu_arg:
generate_expected_output = ["mpirun", "-n", "4"
] + generate_expected_output
run_command(generate_expected_output, cwd=root_dir, env=model_env)
def run_google_tests(build_dir: _pl.Path, skip_gptj, skip_llama, skip_chatglm6b,
skip_chatglm2_6b, only_fp8):
make_google_tests = [
"cmake", "--build", ".", "--config", "Release", "-j", "--target",
"google-tests"
@ -197,6 +263,26 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
ctest.extend(["-E", "|".join(excluded_tests)])
run_command(ctest, cwd=build_dir, env=cpp_env)
def run_multi_gpu_tests(build_dir: _pl.Path):
make_google_tests = [
"cmake", "--build", ".", "--config", "Release", "-j", "--target",
"google-tests"
]
run_command(make_google_tests, cwd=build_dir)
tests_dir = build_dir / "tests"
cpp_env = {**_os.environ}
session_test = [
"mpirun", "-n", "4", "gptSessionTest", "--gtest_filter=*TP*:*PP*"
]
run_command(session_test, cwd=tests_dir, env=cpp_env)
def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path,
resources_dir: _pl.Path):
scripts_dir = resources_dir / "scripts"
make_benchmarks = [
"cmake", "--build", ".", "--config", "Release", "-j", "--target",
"benchmarks"
@ -211,13 +297,13 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
str(gpt_engine_dir / "fp16-plugin" / "tp1-pp1-gpu"), "--batch_size",
"8", "--input_output_len", "10,20", "--duration", "10"
]
run_command(benchmark)
run_command(benchmark, cwd=root_dir)
generate_batch_manager_data = [
python_exe,
str(scripts_dir / "generate_batch_manager_data.py")
]
run_command(generate_batch_manager_data)
run_command(generate_batch_manager_data, cwd=root_dir)
benchmark_src_dir = _pl.Path("benchmarks") / "cpp"
data_dir = resources_dir / "data"
@ -229,7 +315,7 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
str(resources_dir / "models" / "gpt2"), "--output",
str(data_dir / "prepared_dummy_cnn.json")
]
run_command(prepare_dataset)
run_command(prepare_dataset, cwd=root_dir)
benchmark = [
str(benchmark_exe_dir / "gptManagerBenchmark"), "--model", "gpt",
@ -238,7 +324,7 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
"--type", "IFB", "--dataset",
str(data_dir / "prepared_dummy_cnn.json")
]
run_command(benchmark)
run_command(benchmark, cwd=root_dir)
benchmark = [
str(benchmark_exe_dir / "gptManagerBenchmark"), "--model", "gpt",
"--engine_dir",
@ -246,7 +332,7 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
"--type", "V1", "--dataset",
str(data_dir / "prepared_dummy_cnn.json")
]
run_command(benchmark)
run_command(benchmark, cwd=root_dir)
if __name__ == "__main__":
@ -282,4 +368,9 @@ if __name__ == "__main__":
"--only_fp8",
action="store_true",
help="Run only FP8 tests. Implemented for H100 runners.")
parser.add_argument(
"--only_multi_gpu",
action="store_true",
help="Run only mulit-GPU tests. Implemented for 4 GPUs.")
run_tests(**vars(parser.parse_args()))

View File

@ -205,7 +205,6 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model
auto const json = GptJsonConfig::parse(modelPath / "config.json");
auto const modelConfig = json.getModelConfig();
verifyModelConfig(modelConfig, modelSpec);
auto const decoderPerRequest = modelSpec.mDecoderPerRequest;
const int worldSize = modelSpec.mTPSize * modelSpec.mPPSize;
auto const worldConfig = WorldConfig::mpi(*logger, worldSize, modelSpec.mTPSize, modelSpec.mPPSize);
@ -273,15 +272,17 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model
}
}
GptSession session{modelConfig, worldConfig, enginePath.string(), logger};
session.setCudaGraphMode(cudaGraphMode);
auto const maxBatchSize = *std::max_element(batchSizes.begin(), batchSizes.end());
GptSession::Config sessionConfig{maxBatchSize, beamWidth, maxSeqLength};
sessionConfig.decoderPerRequest = modelSpec.mDecoderPerRequest;
sessionConfig.numMicroBatches = numMicroBatches;
sessionConfig.cudaGraphMode = cudaGraphMode;
GptSession session{sessionConfig, modelConfig, worldConfig, enginePath.string(), logger};
EXPECT_EQ(session.getDevice(), worldConfig.getDevice());
// Use bufferManager for copying data to and from the GPU
auto& bufferManager = session.getBufferManager();
auto maxBatchSize = *std::max_element(batchSizes.begin(), batchSizes.end());
session.setup(maxBatchSize, beamWidth, maxSeqLength, decoderPerRequest, std::nullopt, numMicroBatches);
for (auto const batchSize : batchSizes)
{
std::cout << "=== batchSize:" << batchSize << " ===\n";
@ -724,15 +725,17 @@ void testGlm6bSession(fs::path const& modelPath, std::string const& modelName, M
givenInputLengths[i] = std::distance(seqBegin, it);
}
GptSession session{modelConfig, worldConfig, enginePath.string(), logger};
session.setCudaGraphMode(cudaGraphMode);
auto const maxBatchSize = *std::max_element(batchSizes.begin(), batchSizes.end());
GptSession::Config sessionConfig{maxBatchSize, beamWidth, maxSeqLength};
sessionConfig.decoderPerRequest = decoderPerRequest;
sessionConfig.numMicroBatches = numMicroBatches;
sessionConfig.cudaGraphMode = cudaGraphMode;
GptSession session{sessionConfig, modelConfig, worldConfig, enginePath.string(), logger};
EXPECT_EQ(session.getDevice(), worldConfig.getDevice());
// Use bufferManager for copying data to and from the GPU
auto& bufferManager = session.getBufferManager();
auto maxBatchSize = *std::max_element(batchSizes.begin(), batchSizes.end());
session.setup(maxBatchSize, beamWidth, maxSeqLength, decoderPerRequest, std::nullopt, numMicroBatches);
for (auto const batchSize : batchSizes)
{
std::vector<SizeType> inputLenghtsHost(batchSize);

View File

@ -67,6 +67,7 @@ DOCKER_RUN_ARGS ?=
GPU_OPTS ?= --gpus=all
SOURCE_DIR ?= $(shell readlink -f ..)
CODE_DIR ?= /code/tensorrt_llm
CCACHE_DIR ?= ${CODE_DIR}/cpp/.ccache
RUN_CMD ?=
CONTAINER_NAME ?= tensorrt_llm
@ -77,6 +78,8 @@ endif
docker run $(DOCKER_RUN_OPTS) $(DOCKER_RUN_ARGS) \
$(GPU_OPTS) \
--volume $(SOURCE_DIR):$(CODE_DIR) \
--env "CCACHE_DIR=${CCACHE_DIR}" \
--env "CCACHE_BASEDIR=${CODE_DIR}" \
--workdir $(CODE_DIR) \
--hostname $(shell hostname)-$* \
--name $(CONTAINER_NAME)-$*-$(USER_NAME) \

View File

@ -16,7 +16,15 @@ set_bash_env() {
init_ubuntu() {
apt-get update
apt-get install -y --no-install-recommends wget gdb git-lfs python3-pip python3-dev python-is-python3 libffi-dev
apt-get install -y --no-install-recommends \
ccache \
gdb \
git-lfs \
libffi-dev \
python3-dev \
python3-pip \
python-is-python3 \
wget
if ! command -v mpirun &> /dev/null; then
DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends openmpi-bin libopenmpi-dev
fi

View File

@ -87,12 +87,24 @@ callback:
using ReturnBatchManagerStatsCallback = std::function<void(const std::string&)>;
```
The statistics are packaged as a JSON string. That string contains three fields:
The statistics are packaged as a JSON string. That string contains the following fields (a short parsing sketch is shown after the field lists):
* `Timestamp`, the timestamp of the request (obtained using
`std::put_time(&tm, "%m-%d-%Y %H:%M:%S")`),
* `Iteration Counter`, a counter value that corresponds to the execution of a
given request,
* `Active Request Count`, the number of active requests.
* `Iteration Counter`, a global step counter value that increases monotonically over time
* `Active Request Count`, the number of active requests in the batch manager
* `Max Request Count`, the maximum number of requests the batch manager can support at a time
When using in-flight batching, the following additional statistics are reported:
* `Max KV cache blocks`, the maximum number of KV cache blocks per GPU
* `Free KV cache blocks`, number of free KV cache blocks per GPU
* `Used KV cache blocks`, number of used KV cache blocks per GPU
* `Tokens per KV cache block`, number of tokens per KV cache block
* `Scheduled Requests`, number of requests scheduled this iteration
* `Context Requests`, number of requests in Context phase
* `Total Context Tokens`, total number of tokens across requests in context phase
* `Generation Requests`, number of requests in Context phase
* `Generation Requests`, number of requests in Generation phase
* `MicroBatch ID`, the ID of the micro batch
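For illustration, here is a minimal sketch of consuming such a statistics string. The callback itself is C++ (`ReturnBatchManagerStatsCallback` above), so this Python snippet assumes the JSON string has already been captured on the Python side (for example, written to a log by the callback); the payload and the derived utilization figures are hypothetical and only reuse the field names documented in this section.
```python
import json

# Hypothetical payload built from the field names documented above.
stats_json = (
    '{"Timestamp": "10-30-2023 12:00:00", "Iteration Counter": 42, '
    '"Active Request Count": 3, "Max Request Count": 8, '
    '"Scheduled Requests": 3, "Context Requests": 1, "Generation Requests": 2, '
    '"Free KV cache blocks": 120, "Used KV cache blocks": 8, "MicroBatch ID": 0}'
)

stats = json.loads(stats_json)

# Derive simple utilization figures from the reported counters.
request_utilization = stats["Active Request Count"] / stats["Max Request Count"]
total_blocks = stats["Free KV cache blocks"] + stats["Used KV cache blocks"]
kv_cache_utilization = stats["Used KV cache blocks"] / total_blocks

print(f'iteration {stats["Iteration Counter"]}: '
      f'{request_utilization:.0%} of request slots and '
      f'{kv_cache_utilization:.0%} of KV cache blocks in use')
```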
### GptManager Design

View File

@ -1,186 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for exporting a model to our custom format.
"""
import numpy as np
def save_val(val, dir, key, tp_num=None):
suffix = "bin" if tp_num is None else f"{tp_num}.bin"
val.tofile(dir / f"model.{key}.{suffix}")
def save_split(split_vals, dir, key, i, factor):
for j, val in enumerate(split_vals):
save_val(val, dir, key, i * factor + j)
def generate_int8(weights, act_range, is_qkv=False):
"""
This function has two purposes:
- compute quantized weights, scaled either per-tensor or per-column
- compute scaling factors
Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ.
CUTLASS uses two sets of scaling factors: one for the activation X, one for the weight W.
CUBLAS only has one (we can't do per-row scaling), so we must provide a pre-multiplied scaling factor.
Here is the list of what we need (T means per-tensor, C per-column):
- scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T)
- scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T)
- scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C)
- scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32)
to quant range (int8) (used for CUBLAS) (T, C)
Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too,
but then the model would change depending on the number of GPUs used.
For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it
as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V.
"""
# compute weight scaling factors for fp->int8 and int8->fp
if is_qkv:
scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max(
dim=-1, keepdims=True)[0].cpu().numpy()
scale_w_orig_quant_c = 127. / act_range["w"].reshape(3,
-1).cpu().numpy()
else:
scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy()
scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy()
scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t
scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c
# compute the rest of needed scaling factors
scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item())
scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item())
scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.)
scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t *
scale_w_orig_quant_t)
scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t *
scale_w_orig_quant_c)
if is_qkv:
scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t,
scale_w_orig_quant_c.shape)
scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t,
scale_w_orig_quant_c.shape)
to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8)
return {
"weight.int8": to_i8(weights * scale_w_orig_quant_t),
"weight.int8.col": to_i8(weights * scale_w_orig_quant_c),
"scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32),
"scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32),
"scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32),
"scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32),
"scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32),
"scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32),
}
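As a concrete illustration of the per-tensor ("T") versus per-column ("C") weight scaling described in the docstring of `generate_int8`, the following self-contained numpy sketch mirrors the `127. / max(|W|)` formulas used above on a made-up toy matrix; it is not part of the converter itself.
```python
import numpy as np

# Toy weight matrix standing in for one GEMM weight (rows: input dim, columns: output dim).
weights = np.array([[0.5, -2.0, 0.1],
                    [1.5, 0.4, -0.3]], dtype=np.float32)

# Per-tensor scale ("T"): a single factor for the whole matrix.
scale_w_orig_quant_t = 127.0 / np.abs(weights).max()
# Per-column scale ("C"): one factor per output column.
scale_w_orig_quant_c = 127.0 / np.abs(weights).max(axis=0)

to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8)
weight_int8_t = to_i8(weights * scale_w_orig_quant_t)  # per-tensor quantized weights
weight_int8_c = to_i8(weights * scale_w_orig_quant_c)  # per-column quantized weights

# Dequantize with the inverse factors (the scale_w_quant_orig_* values above) and
# compare the reconstruction errors of the two schemes.
err_t = np.abs(weight_int8_t / scale_w_orig_quant_t - weights).max()
err_c = np.abs(weight_int8_c / scale_w_orig_quant_c - weights).max()
print(f"max dequantization error: per-tensor {err_t:.4f}, per-column {err_c:.4f}")
```
Per-column factors keep a single large outlier in one output channel from coarsening the quantization step of every other channel, which is the motivation for the `.col` variants stored above.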
def write_int8(vals, dir, base_key, split_dim, i, factor):
save_split(np.split(vals["weight.int8"], factor, axis=split_dim), dir,
f"{base_key}.weight.int8", i, factor)
save_split(np.split(vals["weight.int8.col"], factor, axis=split_dim), dir,
f"{base_key}.weight.int8.col", i, factor)
saved_keys_once = [
"scale_x_orig_quant", "scale_w_quant_orig", "scale_y_accum_quant",
"scale_y_quant_orig"
]
# per-column scaling factors are loaded per-gpu for ColumnParallel GEMMs (QKV, FC1)
if split_dim == -1:
save_split(
np.split(vals["scale_w_quant_orig.col"], factor, axis=split_dim),
dir, f"{base_key}.scale_w_quant_orig.col", i, factor)
save_split(
np.split(vals["scale_y_accum_quant.col"], factor, axis=split_dim),
dir, f"{base_key}.scale_y_accum_quant.col", i, factor)
else:
saved_keys_once += ["scale_w_quant_orig.col", "scale_y_accum_quant.col"]
if i == 0:
for save_key in saved_keys_once:
save_val(vals[save_key], dir, f"{base_key}.{save_key}")
def str_to_np_dtype(type_str):
convert_dict = {
"fp32": np.float32,
"fp16": np.float16,
}
dtype = convert_dict.get(type_str)
if dtype is None:
raise ValueError(f"{type_str} is an invalid storage type")
return dtype
def split_and_save_weight(i, saved_dir, factor, key, args, val, act_range):
save_int8 = act_range is not None
if "input_layernorm.weight" in key or "input_layernorm.bias" in key or \
"attention.dense.bias" in key or "post_attention_layernorm.weight" in key or \
"post_attention_layernorm.bias" in key or "mlp.dense_4h_to_h.bias" in key or \
"final_layernorm.weight" in key or "final_layernorm.bias" in key:
# shared weights, only need to convert the weights of rank 0
if i == 0:
save_val(val, saved_dir, key)
elif "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
split_dim = 0
split_vals = np.split(val, factor, axis=split_dim)
save_split(split_vals, saved_dir, key, i, factor)
if save_int8:
base_key = key.replace(".weight", "")
vals_i8 = generate_int8(val, act_range)
write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor)
elif "mlp.dense_h_to_4h.weight" in key:
split_dim = -1
split_vals = np.split(val, factor, axis=split_dim)
save_split(split_vals, saved_dir, key, i, factor)
if save_int8:
base_key = key.replace(".weight", "")
vals_i8 = generate_int8(val, act_range)
write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor)
elif "mlp.dense_h_to_4h.bias" in key:
split_vals = np.split(val, factor, axis=-1)
save_split(split_vals, saved_dir, key, i, factor)
elif "attention.query_key_value.bias" in key:
local_dim = val.shape[-1] // 3
val = val.reshape(3, local_dim)
split_vals = np.split(val, factor, axis=-1)
save_split(split_vals, saved_dir, key, i, factor)
elif "attention.query_key_value.weight" in key:
hidden_dim = val.shape[0] // 3
local_dim = val.shape[-1]
val = val.reshape(3, hidden_dim, local_dim)
split_dim = -1
split_vals = np.split(val, factor, axis=split_dim)
save_split(split_vals, saved_dir, key, i, factor)
if save_int8:
base_key = key.replace(".weight", "")
vals_i8 = generate_int8(val, act_range, is_qkv=True)
write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor)
else:
print(f"[WARNING] {key} not handled by converter")

View File

@ -1,25 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import transformers
prompt = "Could you introduce NVIDIA Corporation for me?"
# run the original model to export LM
tokenizer = transformers.AutoTokenizer.from_pretrained("pyTorchModel",
trust_remote_code=True)
model = transformers.AutoModel.from_pretrained(
"pyTorchModel", trust_remote_code=True).half().cuda()
response, history = model.chat(tokenizer, prompt, history=[])

View File

@ -1,212 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Convert a Hugging Face ChatGLM-6B model. Uses https://huggingface.co/THUDM/chatglm-6b as the demo checkpoint.
'''
import argparse
import configparser
import os
from pathlib import Path
import numpy as np
import torch
import torch.multiprocessing as multiprocessing
from convert import split_and_save_weight, str_to_np_dtype
from smoothquant import capture_activation_range, smooth_gemm
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
@torch.no_grad()
def smooth_gpt_model(model, scales, alpha):
# Smooth the activation and weights with smoother = $\diag{s}$
for name, module in model.named_modules():
if not isinstance(module, GPT2Block):
continue
# qkv_proj
layer_name = name + ".attn.c_attn"
smoother = smooth_gemm(module.attn.c_attn.weight.T,
scales[layer_name]["x"], module.ln_1.weight,
module.ln_1.bias, alpha)
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=0)[0]
# fc1
layer_name = name + ".mlp.c_fc"
smoother = smooth_gemm(module.mlp.c_fc.weight.T,
scales[layer_name]["x"], module.ln_2.weight,
module.ln_2.bias, alpha)
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
scales[layer_name]["w"] = module.mlp.c_fc.weight.abs().max(dim=0)[0]
def gpt_to_ft_name(orig_name):
global_weights = { \
"transformer.final_layernorm.bias": "model.final_layernorm.bias", \
"transformer.final_layernorm.weight": "model.final_layernorm.weight", \
}
if orig_name in global_weights:
return global_weights[orig_name]
return ".".join(orig_name.split(".")[1:])
@torch.no_grad()
def hf_chatglm6b_converter(args):
infer_tp = args.tensor_parallelism
saved_dir = Path(args.out_dir) / f"{infer_tp}-gpu"
saved_dir.mkdir(parents=True, exist_ok=True)
# load position_embedding from rank 0
model = AutoModel.from_pretrained(args.in_file, trust_remote_code=True)
act_range = {}
if args.smoothquant is not None or args.calibrate_kv_cache:
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
"TOKENIZERS_PARALLELISM", "false")
act_range = capture_activation_range(
model, AutoTokenizer.from_pretrained(args.in_file))
if args.smoothquant is not None:
smooth_gpt_model(model, act_range, args.smoothquant)
config = configparser.ConfigParser()
config["chatglm6b"] = {}
for key in vars(args):
config["chatglm6b"][key] = f"{vars(args)[key]}"
for k, v in vars(model.config).items():
config["chatglm6b"][k] = f"{v}"
config["chatglm6b"]["weight_data_type"] = args.storage_type
with open(saved_dir / "config.ini", 'w') as configfile:
config.write(configfile)
storage_type = str_to_np_dtype(args.storage_type)
if args.calibrate_kv_cache:
pass
if args.smoothquant is not None:
pass
'''
# list all named parameters
for name, param in model.named_parameters():
print(name,param.shape)
'''
# add weight of LM
data = np.load("lm.npy")
data.astype(storage_type).tofile(saved_dir / "model.lm.weight.bin")
print("Save model.lm.weight.bin")
# add weight of position embedding
nMaxSL = 2048
inv_freq = 10**(-1 / 16 * np.arange(0, 64, 2, dtype=np.float32))
valueTable = np.matmul(
np.arange(nMaxSL, dtype=np.float32).reshape(-1, 1),
np.concatenate([inv_freq, inv_freq],
axis=0).reshape(1, -1)).reshape(nMaxSL,
len(inv_freq) * 2)
np.cos(valueTable).astype(storage_type).tofile(saved_dir /
"model.cosTable.weight.bin")
np.sin(valueTable).astype(storage_type).tofile(saved_dir /
"model.sinTable.weight.bin")
print("Save model.cosTable.weight.bin")
print("Save model.sinTable.weight.bin")
starmap_args = []
for name, param in model.named_parameters():
if "weight" not in name and "bias" not in name:
print("Skip %s" % name)
continue
elif name in [
"transformer.word_embeddings.weight",
"transformer.final_layernorm.weight",
"transformer.final_layernorm.bias"
]:
param.detach().cpu().numpy().astype(storage_type).tofile(
saved_dir / (name.replace("transformer", "model") + ".bin"))
print("Save %s" % name)
continue
ft_name = gpt_to_ft_name(name)
param = param.detach().cpu().numpy().astype(storage_type)
starmap_args.append((0, saved_dir, infer_tp, ft_name, args, param,
act_range.get(name.replace(".weight", ""))))
starmap_args = tqdm(starmap_args, desc="saving weights")
if args.processes > 1:
with multiprocessing.Pool(args.processes) as pool:
pool.starmap(split_and_save_weight, starmap_args)
else:
# simpler for debug situations
for starmap_arg in starmap_args:
split_and_save_weight(*starmap_arg)
print("Save %s" % starmap_arg[3])
if __name__ == "__main__":
torch.multiprocessing.set_start_method("spawn")
parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('--out-dir',
'-o',
type=str,
help='file name of output directory',
required=True)
parser.add_argument('--in-file',
'-i',
type=str,
help='file name of input checkpoint file',
required=True)
parser.add_argument('--tensor-parallelism',
'-tp',
type=int,
help='Requested tensor parallelism for inference',
default=1)
parser.add_argument(
"--processes",
"-p",
type=int,
help="How many processes to spawn for conversion (default: 4)",
default=4)
parser.add_argument(
"--calibrate-kv-cache",
"-kv",
action="store_true",
help=
"Generate scaling factors for KV cache. Used for storing KV cache in int8."
)
parser.add_argument(
"--smoothquant",
"-sq",
type=float,
default=None,
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
" to Smoothquant the model, and output int8 weights."
" A good first try is 0.5. Must be in [0, 1]")
parser.add_argument("--storage-type",
"-t",
type=str,
default="fp32",
choices=["fp32", "fp16"])
args = parser.parse_args()
print("\n=============== Argument ===============")
for key in vars(args):
print("{}: {}".format(key, vars(args)[key]))
print("========================================")
hf_chatglm6b_converter(args)

File diff suppressed because it is too large

View File

@ -442,7 +442,11 @@ def build_rank_engine(builder: Builder,
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if not args.enable_fp8:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
else:
logger.info(
"Gemm plugin does not support FP8. Disabled Gemm plugin.")
if args.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
dtype=args.use_layernorm_plugin)

View File

@ -470,7 +470,11 @@ def build_rank_engine(builder: Builder,
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if not args.enable_fp8:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
else:
logger.info(
"Gemm plugin does not support FP8. Disabled Gemm plugin.")
if args.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
dtype=args.use_layernorm_plugin)
@ -575,6 +579,7 @@ def build(rank, args):
use_prompt_tuning=args.max_prompt_embedding_table_size > 0,
gather_all_token_logits=args.gather_all_token_logits,
fp8=args.enable_fp8,
quant_mode=args.quant_mode,
use_parallel_embedding=args.use_parallel_embedding)
engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,

View File

@ -22,6 +22,7 @@ import torch
from transformers import AutoTokenizer, T5Tokenizer
import tensorrt_llm
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from build import get_engine_name # isort:skip
@ -52,6 +53,7 @@ def read_config(config_path: Path):
gather_all_token_logits = config['builder_config'][
'gather_all_token_logits']
use_custom_all_reduce = config['plugin_config']['use_custom_all_reduce']
quant_mode = QuantMode(config['builder_config']['quant_mode'])
model_config = ModelConfig(num_heads=num_heads,
num_kv_heads=num_kv_heads,
@ -64,6 +66,7 @@ def read_config(config_path: Path):
tokens_per_block=tokens_per_block,
use_prompt_tuning=use_prompt_tuning,
dtype=dtype,
quant_mode=quant_mode,
gather_all_token_logits=gather_all_token_logits,
use_custom_all_reduce=use_custom_all_reduce)
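Both runtime scripts now round-trip the quantization mode through config.json. A minimal sketch of that step in isolation; the config path is hypothetical, and QuantMode with its KV-cache accessors is assumed to come from tensorrt_llm.quantization as imported above.

import json
from pathlib import Path

from tensorrt_llm.quantization import QuantMode

def read_quant_mode(config_path: Path) -> QuantMode:
    with open(config_path, 'r') as f:
        config = json.load(f)
    # Engines built before this change may not record quant_mode; default to 0.
    return QuantMode(config['builder_config'].get('quant_mode', 0))

quant_mode = read_quant_mode(Path('gpt_outputs/config.json'))  # hypothetical engine dir
print(quant_mode.has_int8_kv_cache(), quant_mode.has_fp8_kv_cache())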

View File

@ -25,6 +25,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, T5Tokenizer
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization import QuantMode
from build import find_engines # isort:skip
@ -48,6 +49,7 @@ def TRTGPT(args, config):
paged_kv_cache = config['plugin_config']['paged_kv_cache']
tokens_per_block = config['plugin_config']['tokens_per_block']
use_custom_all_reduce = config['plugin_config']['use_custom_all_reduce']
quant_mode = QuantMode(config['builder_config'].get('quant_mode', 0))
model_config = tensorrt_llm.runtime.ModelConfig(
vocab_size=vocab_size,
@ -60,6 +62,7 @@ def TRTGPT(args, config):
tokens_per_block=tokens_per_block,
paged_kv_cache=paged_kv_cache,
dtype=dtype,
quant_mode=quant_mode,
use_custom_all_reduce=use_custom_all_reduce,
)

View File

@ -348,7 +348,11 @@ def build_rank_engine(builder: Builder,
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if not args.enable_fp8:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
else:
logger.info(
"Gemm plugin does not support FP8. Disabled Gemm plugin.")
if args.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
dtype=args.use_layernorm_plugin)

View File

@ -585,7 +585,11 @@ def build_rank_engine(builder: Builder,
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if not args.enable_fp8:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
else:
logger.info(
"Gemm plugin does not support FP8. Disabled Gemm plugin.")
if args.use_rmsnorm_plugin:
network.plugin_config.set_rmsnorm_plugin(dtype=args.use_rmsnorm_plugin)

View File

@ -140,7 +140,6 @@ def main(args):
config_path = os.path.join(args.engine_dir, 'config.json')
with open(config_path, 'r') as f:
config = json.load(f)
tensorrt_llm_llama = TRTLLaMA(args, config)
if test_hf:
@ -156,6 +155,7 @@ def main(args):
def summarize_tensorrt_llm(datapoint):
batch_size = len(datapoint['article'])
assert batch_size > 0, f"Validation dataset is corrupt (0 samples found). The dataset is loaded from ~/.cache/huggingface/datasets/ccdv___cnn_dailymail"
line = copy.copy(datapoint['article'])
line_encoded = []

View File

@ -466,7 +466,11 @@ def build_rank_engine(builder: Builder,
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if not args.enable_fp8:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
else:
logger.info(
"Gemm plugin does not support FP8. Disabled Gemm plugin.")
if args.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
dtype=args.use_layernorm_plugin)

View File

@ -160,11 +160,9 @@ class Session(object):
context = self.context
for i in inputs:
if self.engine.get_tensor_mode(i.name) != trt.TensorIOMode.INPUT:
logger.error(f"Tensor:{i.name} is not an input tensor")
return None
raise ValueError(f"Tensor:{i.name} is not an input tensor")
if self.engine.get_tensor_dtype(i.name) != i.dtype:
logger.error(f"Tensor:{i.name} has wrong dtype")
return None
raise ValueError(f"Tensor:{i.name} has wrong dtype")
context.set_input_shape(i.name, i.shape)
outputs = []
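With the change above, invalid inputs raise instead of silently returning None, so callers switch from a None check to exception handling. A minimal sketch, assuming Session.infer_shapes is the method wrapping the code shown here:

def safe_infer_shapes(session, input_info, logger):
    # Session.infer_shapes (assumed name) now raises ValueError on a name or
    # dtype mismatch instead of returning None, so callers catch the exception.
    try:
        return session.infer_shapes(input_info)
    except ValueError as err:
        logger.error(f"Shape inference failed: {err}")
        raise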

View File

@ -53,8 +53,8 @@ class TestFunctional(unittest.TestCase):
test_cases = []
test_cases += list(
product(['gpt2_attention', 'llama_attention', 'gptj_attention'],
[ContextFMHAType.disabled], ['float16', 'bfloat16'], [2],
[128], [4], [64], [0], [False], [False], [False], [1, 4],
[ContextFMHAType.disabled], ['float16', 'bfloat16'], [None],
[2], [128], [4], [64], [0], [False], [False], [1, 4],
[True, False], [True, False]))
# Test cases for input padding
@ -62,22 +62,22 @@ class TestFunctional(unittest.TestCase):
product(['llama_attention'], [
ContextFMHAType.disabled,
ContextFMHAType.enabled,
], ['float16', 'bfloat16'], [2], [128], [4], [64], [False], [False],
], ['float16', 'bfloat16'], [None], [2], [128], [4], [64], [False],
[False], [False, True], [1], [False], [False]))
# Test cases for fused context MHAs
test_cases += list(
product(['llama_attention'], [
ContextFMHAType.enabled, ContextFMHAType.enabled_with_fp32_acc
], ['float16', 'bfloat16'], [2], [90, 1024], [4],
[32, 64, 80, 112, 128], [0], [False], [False],
[False, True], [1], [False], [False]))
], ['float16', 'bfloat16'], [None], [2], [90, 1024], [4],
[32, 64, 80, 112, 128], [0], [False], [False, True], [1],
[False], [False]))
# Test cases of float32 d=256 case (for testing MMHA key loops).
test_cases += list(
product(['gptj_attention'], [
ContextFMHAType.enabled,
], ['float32'], [2], [128], [2], [256], [False], [False], [False],
], ['float32'], [None], [2], [128], [2], [256], [False], [False],
[True], [1], [False], [True, False]))
# Test cases for the multi-block MMHA.
@ -85,32 +85,32 @@ class TestFunctional(unittest.TestCase):
test_cases += list(
product(['llama_attention'], [
ContextFMHAType.enabled, ContextFMHAType.enabled_with_fp32_acc
], ['float16', 'bfloat16'], [2], [2048], [4], [64], [0], [True],
[False], [False], [1], [False], [False]))
], ['float16', 'bfloat16'], [None], [2], [2048], [4], [64], [0],
[True], [False], [1], [False], [False]))
# Test cases for the int8 K/V cache.
# Test cases for the 8-bit K/V cache.
test_cases += list(
product(['gpt2_attention'], [ContextFMHAType.disabled],
['float16', 'float32'], [2], [128], [4], [64], [0], [False],
[True], [False], [1, 4], [False], [False]))
['float16', 'float32'], ['int8', 'fp8'], [2], [128], [4],
[64], [0], [False], [False], [1, 4], [False], [False]))
#test cases for multi-query attention
# test cases for multi-query attention
test_cases += list(
product(['gpt_bigcode_attention'], [
ContextFMHAType.disabled, ContextFMHAType.enabled,
ContextFMHAType.enabled_with_fp32_acc
], ['float16', 'bfloat16'], [2], [128], [4], [64], [1], [False],
], ['float16', 'bfloat16'], [None], [2], [128], [4], [64], [1],
[False], [False], [1, 4], [False], [False]))
# test cases for grouped-query attention
test_cases += list(
product(['llama_attention'], [ContextFMHAType.disabled],
['bfloat16', 'float16'], [2], [4], [8], [32], [2, 4],
[False], [False], [False], [1], [False], [False]))
['bfloat16', 'float16'], [None], [2], [4], [8], [32],
[2, 4], [False], [False], [1], [False], [False]))
test_cases += list(
product(['llama_attention'], [ContextFMHAType.enabled], ['float32'],
[1], [165], [32], [128], [4], [False], [False], [False],
[1], [False], [False]))
[None], [1], [165], [32], [128], [4], [False], [False], [1],
[False], [False]))
# test cases for RoPE base and scaling
test_cases += list(
@ -118,6 +118,7 @@ class TestFunctional(unittest.TestCase):
['llama_attention'],
[ContextFMHAType.disabled],
['bfloat16', 'float32'],
[None],
[2],
[4],
[8],
@ -125,7 +126,6 @@ class TestFunctional(unittest.TestCase):
[2, 4],
[False],
[False],
[False],
[1],
[False],
[False],
@ -145,6 +145,7 @@ class TestFunctional(unittest.TestCase):
['llama_attention'],
[ContextFMHAType.enabled],
['float32'],
[None],
[1],
[165],
[32],
@ -152,7 +153,6 @@ class TestFunctional(unittest.TestCase):
[4],
[False],
[False],
[False],
[1],
[False],
[False],
@ -181,13 +181,13 @@ class TestFunctional(unittest.TestCase):
attention_type,
context_fmha_type,
dtype,
kv_cache_dtype,
batch_size,
in_len,
num_heads,
head_size,
num_kv_heads,
enable_multi_block_mmha,
use_int8_kv_cache,
enable_remove_input_padding,
beam_width,
paged_kv_cache,
@ -197,6 +197,10 @@ class TestFunctional(unittest.TestCase):
# if attention_type != "gpt_bigcode_attention" and attention_type != "llama_attention":
# assert num_kv_heads == 0 # safe guard against bad test case configs
use_int8_kv_cache = True if kv_cache_dtype == 'int8' else False
use_fp8_kv_cache = True if kv_cache_dtype == 'fp8' else False
if kv_cache_dtype is None:
kv_cache_dtype = dtype
# Skip tests that are not supported in pre-ampere architecture
if getSMVersion() < 80:
if context_fmha_type == ContextFMHAType.enabled:
@ -211,6 +215,10 @@ class TestFunctional(unittest.TestCase):
pytest.skip(
"bfloat16 is not supported in pre-ampere architecture")
if getSMVersion() < 89:
if use_fp8_kv_cache:
pytest.skip("FP8 is not supported on pre-Ada architectures")
if num_kv_heads == 0:
num_kv_heads = num_heads
# Skip duplicated tests.
@ -220,8 +228,7 @@ class TestFunctional(unittest.TestCase):
so it has been tested with ContextFMHAType.enabled")
session = None
kv_cache_dtype = 'int8' if use_int8_kv_cache else dtype
if use_int8_kv_cache:
if use_int8_kv_cache or use_fp8_kv_cache or True:
# Fixing seed to avoid flakiness in tests with quantization
torch.manual_seed(42)
@ -293,15 +300,15 @@ class TestFunctional(unittest.TestCase):
shape=tuple(past_key_value.shape),
dtype=tensorrt_llm.str_dtype_to_trt(kv_cache_dtype))
kv_int8_quant_scale_tensor = None
kv_int8_dequant_scale_tensor = None
if use_int8_kv_cache:
kv_int8_quant_scale_tensor = Tensor(
name='kv_int8_quant_scale',
kv_quant_scale_tensor = None
kv_dequant_scale_tensor = None
if use_int8_kv_cache or use_fp8_kv_cache:
kv_quant_scale_tensor = Tensor(
name='kv_quant_scale',
shape=(1, ),
dtype=tensorrt_llm.str_dtype_to_trt('float32'))
kv_int8_dequant_scale_tensor = Tensor(
name='kv_int8_dequant_scale',
kv_dequant_scale_tensor = Tensor(
name='kv_dequant_scale',
shape=(1, ),
dtype=tensorrt_llm.str_dtype_to_trt('float32'))
@ -374,11 +381,12 @@ class TestFunctional(unittest.TestCase):
max_position_embeddings,
position_embedding_type=position_embedding_type,
multi_block_mode=enable_multi_block_mmha,
kv_orig_quant_scale=kv_int8_quant_scale_tensor,
kv_quant_orig_scale=kv_int8_dequant_scale_tensor,
kv_orig_quant_scale=kv_quant_scale_tensor,
kv_quant_orig_scale=kv_dequant_scale_tensor,
host_context_lengths=host_context_lengths_tensor,
kv_cache_quant_mode=QuantMode.from_description(
use_int8_kv_cache=use_int8_kv_cache),
use_int8_kv_cache=use_int8_kv_cache,
use_fp8_kv_cache=use_fp8_kv_cache),
kv_cache_block_pointers=pointer_array_tensor,
max_context_length=max_context_length,
qkv_bias=qkv_bias)
@ -405,9 +413,9 @@ class TestFunctional(unittest.TestCase):
else:
inputs['past_key_value'] = past_key_value
if use_int8_kv_cache:
inputs['kv_int8_quant_scale'] = kv_int8_quant_scale
inputs['kv_int8_dequant_scale'] = kv_int8_dequant_scale
if use_int8_kv_cache or use_fp8_kv_cache:
inputs['kv_quant_scale'] = kv_quant_scale
inputs['kv_dequant_scale'] = kv_dequant_scale
if enable_remove_input_padding:
inputs['host_context_lengths'] = host_context_lengths
@ -417,11 +425,13 @@ class TestFunctional(unittest.TestCase):
outputs['present_key_value'] = past_key_value
stream = torch.cuda.current_stream()
# NOTE: when int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT
# NOTE: when 8-bit kv cache is used together with paged kv cache no 8-bit tensors are exposed to TRT
int8_trt_flag = use_int8_kv_cache and not paged_kv_cache
fp8_trt_flag = use_fp8_kv_cache and not paged_kv_cache
builder_config = builder.create_builder_config(name=attention_type,
precision=dtype,
int8=int8_trt_flag)
int8=int8_trt_flag,
fp8=fp8_trt_flag)
if session is None:
engine = builder.build_engine(net, builder_config)
session = tensorrt_llm.runtime.Session.from_serialized_engine(
@ -452,8 +462,8 @@ class TestFunctional(unittest.TestCase):
'host_past_key_value_lengths': (batch_size, ),
'sequence_length': (batch_size, ),
'context_lengths': (batch_size, ),
'kv_int8_quant_scale': (1, ),
'kv_int8_dequant_scale': (1, ),
'kv_quant_scale': (1, ),
'kv_dequant_scale': (1, ),
'cache_indirection': (batch_size, 1, max_seq_len),
'host_request_types': (batch_size)
}
@ -467,10 +477,14 @@ class TestFunctional(unittest.TestCase):
if enable_remove_input_padding:
shape_dict['host_context_lengths'] = (batch_size, )
present_key_value = torch.zeros(
shape_dict['past_key_value'],
dtype=tensorrt_llm._utils.str_dtype_to_torch(kv_cache_dtype),
device='cuda')
# HACK: pytorch does not have fp8 dtype yet
torch_kv_cache_dtype = tensorrt_llm._utils.str_dtype_to_torch(
'int8'
) if kv_cache_dtype == 'fp8' else tensorrt_llm._utils.str_dtype_to_torch(
kv_cache_dtype)
present_key_value = torch.zeros(shape_dict['past_key_value'],
dtype=torch_kv_cache_dtype,
device='cuda')
# Init KV cache block manager
if paged_kv_cache:
manager = KVCacheManager([present_key_value],
@ -496,13 +510,12 @@ class TestFunctional(unittest.TestCase):
device='cuda') * 1e-2
torch_present = None
kv_int8_dequant_scale = torch.randint(
1,
10,
shape_dict['kv_int8_dequant_scale'],
dtype=str_dtype_to_torch(kv_cache_dtype),
device='cuda') * 0.0001
kv_int8_quant_scale = 1.0 / kv_int8_dequant_scale
kv_dequant_scale = torch.randint(1,
10,
shape_dict['kv_dequant_scale'],
dtype=torch.float32,
device='cuda') * 0.0001
kv_quant_scale = 1.0 / kv_dequant_scale
ConfigCls = None
AttentionCls = None
@ -709,7 +722,7 @@ class TestFunctional(unittest.TestCase):
device='cuda')
def verify_kv_cache(torch_present):
if not use_int8_kv_cache and num_kv_heads == num_heads and beam_width == 1:
if not use_int8_kv_cache and not use_fp8_kv_cache and num_kv_heads == num_heads and beam_width == 1:
if paged_kv_cache:
kv_cache_cont = manager.blocks_manager.get_continous_caches(
0)
@ -862,8 +875,8 @@ class TestFunctional(unittest.TestCase):
host_past_key_value_lengths, input_lengths,
host_context_lengths, cache_indirection, host_request_types,
num_heads, hidden_size, num_kv_heads, output, dtype,
max_context_length, shape_dict, kv_int8_quant_scale,
kv_int8_dequant_scale, configuration)
max_context_length, shape_dict, kv_quant_scale,
kv_dequant_scale, configuration)
del session
session = None
@ -1007,8 +1020,8 @@ class TestFunctional(unittest.TestCase):
tiled_input_lengths, tiled_host_context_lengths,
cache_indirection, tiled_host_request_types, num_heads,
hidden_size, num_kv_heads, tiled_output, dtype,
max_context_length, shape_dict, kv_int8_quant_scale,
kv_int8_dequant_scale, configuration)
max_context_length, shape_dict, kv_quant_scale,
kv_dequant_scale, configuration)
del session
session = None
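Taken together, the single kv_cache_dtype parameter now drives both the int8 and the fp8 KV-cache paths. A minimal sketch of how the flag resolves, mirroring the test logic above; sm_version stands in for the value returned by the test's getSMVersion helper.

import pytest
from tensorrt_llm.quantization import QuantMode

def resolve_kv_cache_dtype(kv_cache_dtype, dtype, sm_version):
    use_int8_kv_cache = kv_cache_dtype == 'int8'
    use_fp8_kv_cache = kv_cache_dtype == 'fp8'
    if use_fp8_kv_cache and sm_version < 89:
        pytest.skip("FP8 is not supported on pre-Ada architectures")
    # None means: keep the KV cache in the activation dtype.
    resolved_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
    quant_mode = QuantMode.from_description(use_int8_kv_cache=use_int8_kv_cache,
                                            use_fp8_kv_cache=use_fp8_kv_cache)
    return resolved_dtype, quant_mode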