Update TensorRT-LLM (#1122)

* Update TensorRT-LLM

---------

Co-authored-by: Eddie-Wang1120 <wangjinheng1120@163.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Kaiyu Xie 2024-02-21 21:30:55 +08:00 committed by GitHub
parent 0f041b7b57
commit eb8f26c7e4
95 changed files with 6958 additions and 460 deletions

View File

@ -154,3 +154,5 @@ Take GPT-350M as an example for single GPU with static batching
--static_emulated_timeout 100 \
--dataset ../../benchmarks/cpp/preprocessed_dataset.json
```
`gptManagerBenchmark` can also exercise the high-level C++ API defined by the `executor::Executor` class (see `cpp/include/tensorrt_llm/executor/executor.h`) by passing the argument `--api executor`. Note that the Executor class is still under development and does not yet support models with tensor parallelism (TP) or pipeline parallelism (PP) greater than 1.
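For orientation, here is a minimal sketch (not part of the benchmark itself) of the workflow that the `--api executor` path drives, based on the declarations in `cpp/include/tensorrt_llm/executor/executor.h`. The engine path and token ids are placeholders, and `VecTokens` is assumed to be constructible from a braced list of token ids.

```cpp
#include "tensorrt_llm/executor/executor.h"

#include <filesystem>
#include <iostream>

namespace texec = tensorrt_llm::executor;

int main()
{
    // Placeholder engine directory; point this at an engine built for the target model.
    std::filesystem::path const enginePath{"/path/to/engine_dir"};

    // Default executor configuration: beam width 1, default scheduler and KV-cache settings.
    texec::ExecutorConfig executorConfig(1);
    texec::Executor executor(enginePath, texec::ModelType::kDECODER_ONLY, executorConfig);

    // Enqueue a single request: placeholder input token ids, up to 32 new tokens.
    texec::Request request({1, 2, 3, 4}, 32);
    auto const requestId = executor.enqueueRequest(std::move(request));

    // Poll until the final response for this request arrives.
    bool done = false;
    while (!done)
    {
        for (auto const& response : executor.awaitResponses(requestId))
        {
            if (response.hasError())
            {
                std::cerr << "Request failed: " << response.getErrorMsg() << std::endl;
                return 1;
            }
            auto const result = response.getResult();
            done = result.isFinal;
            // result.outputTokenIds holds one vector of generated token ids per beam.
        }
    }
    executor.shutdown();
    return 0;
}
```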

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/worldConfig.h"
@ -39,9 +40,24 @@ using namespace tensorrt_llm::batch_manager;
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace texec = tensorrt_llm::executor;
namespace mpi = tensorrt_llm::mpi;
namespace trt = nvinfer1;
namespace
{
struct BenchmarkParams
{
std::optional<SizeType> maxTokensInPagedKvCache = std::nullopt;
std::optional<float> freeGpuMemoryFraction = std::nullopt;
bool enableTrtOverlap = false;
bool enableBlockReuse = false;
bool enableChunkedContext = false;
bool streaming = false;
};
} // namespace
// Class holding all information regarding a single work item.
// This includes the original request, the associated response factory,
// and its state.
@ -223,6 +239,12 @@ public:
mRequestBenchInfos[requestId] = BenchInfo(inputLength, outputLength, start);
}
void recordStart(SizeType inputLength, SizeType maxNewTokens, uint64_t requestId,
std::chrono::time_point<std::chrono::steady_clock> const& start)
{
mRequestBenchInfos[requestId] = BenchInfo(inputLength, maxNewTokens, start);
}
void recordEnd(uint64_t requestId)
{
mRequestBenchInfos[requestId].end = std::chrono::steady_clock::now();
@ -296,13 +318,107 @@ private:
std::string mOpCsvFile;
}; // class Recorder
class ExecutorServer
{
public:
ExecutorServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
batch_scheduler::SchedulerPolicy schedulerPolicy, BenchmarkParams const& benchmarkParams,
std::shared_ptr<Recorder> recorder, std::chrono::milliseconds waitSleep,
std::optional<uint64_t> const staticEmulatedBatchSize, bool logIterationData)
: mRecorder(std::move(recorder))
, mWaitSleep(waitSleep)
, mStaticEmulatedBatchSize(staticEmulatedBatchSize)
, mActiveCount(0)
{
texec::SchedulerConfig schedulerConfig(batch_scheduler::batchManagerToExecSchedPolicy(schedulerPolicy));
texec::KvCacheConfig kvCacheConfig(benchmarkParams.enableBlockReuse, benchmarkParams.maxTokensInPagedKvCache,
std::nullopt, std::nullopt, benchmarkParams.freeGpuMemoryFraction, false);
texec::ExecutorConfig executorConfig(maxBeamWidth, schedulerConfig, kvCacheConfig,
benchmarkParams.enableChunkedContext, true, benchmarkParams.enableTrtOverlap);
mExecutor = std::make_shared<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
}
~ExecutorServer() {}
void enqueue(std::vector<texec::Request> requests, bool warmup = false)
{
try
{
std::vector<SizeType> inputLengths, maxNewTokens;
for (auto const& request : requests)
{
inputLengths.push_back(request.getInputTokenIds().size());
maxNewTokens.push_back(request.getMaxNewTokens());
}
auto const start = std::chrono::steady_clock::now();
auto reqIds = mExecutor->enqueueRequests(std::move(requests));
for (int req = 0; req < reqIds.size(); ++req)
{
if (!warmup)
{
mRecorder->recordStart(inputLengths.at(req), maxNewTokens.at(req), reqIds.at(req), start);
}
mActiveCount++;
}
}
catch (const std::exception& e)
{
TLLM_THROW("%s", e.what());
}
return;
}
void waitForResponses(std::optional<SizeType> numRequests, bool warmup = false)
{
SizeType numFinished = 0;
while (mActiveCount || (numRequests && numFinished < numRequests.value()))
{
auto responses = mExecutor->awaitResponses(std::nullopt, mWaitSleep);
for (auto const& response : responses)
{
if (response.hasError())
{
// This request failed for some reason, get error msg
std::string errStr = "Request id " + std::to_string(response.getRequestId()) + " failed with err "
+ response.getErrorMsg();
TLLM_THROW("%s", errStr.c_str());
}
else if (response.getResult().isFinal)
{
auto reqId = response.getRequestId();
mActiveCount--;
numFinished++;
if (!warmup)
{
mRecorder->recordEnd(reqId);
}
}
}
}
}
void shutdown()
{
mExecutor->shutdown();
}
private:
std::shared_ptr<texec::Executor> mExecutor;
std::shared_ptr<Recorder> mRecorder;
std::chrono::milliseconds mWaitSleep;
std::optional<int> mStaticEmulatedBatchSize;
std::atomic<uint64_t> mActiveCount;
}; // class ExecutorServer
class GptServer
{
public:
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
std::optional<uint64_t> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs)
std::optional<uint64_t> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs, bool logIterationData)
: mRecorder(std::move(recorder))
, mTerminateReqId(terminateReqId)
, mWaitSleep(waitSleep)
@ -312,9 +428,9 @@ public:
, mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
, mActiveCount(0)
{
ReturnBatchManagerStatsCallback iterationDataCallback = [this, &optionalParams](std::string const& log)
ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
{
if (optionalParams.logIterationData)
if (logIterationData)
{
TLLM_LOG_INFO(log);
}
@ -396,12 +512,8 @@ public:
auto rank = comm.getRank();
if (rank == 0)
{
auto numNewWorkItems = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
auto const numNewWorkItems = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
static_cast<int64_t>(max_num_requests));
if (world_size > 1)
{
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
}
bool readyForNextBatch = numNewWorkItems > 0;
if (mStaticEmulatedBatchSize)
@ -446,18 +558,21 @@ public:
sendResponse(workItem->requestId(), {}, true, warnStr);
}
}
if (world_size > 1)
}
if (world_size > 1)
{
auto numNewWorkItems = static_cast<int64_t>(rval.size());
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
std::vector<int64_t> packed;
for (auto const& ir : rval)
{
std::vector<int64_t> packed;
for (auto const& ir : rval)
{
auto vpacked = ir->serialize();
packed.push_back(static_cast<int64_t>(vpacked.size()));
packed.insert(
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
}
comm.bcast(packed, 0);
auto vpacked = ir->serialize();
packed.push_back(static_cast<int64_t>(vpacked.size()));
packed.insert(
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
}
comm.bcast(packed, 0);
}
}
else
@ -581,15 +696,38 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const&
return request;
}
texec::Request makeExecutorRequest(Sample const& sample, SizeType const& beamWidth,
std::optional<SizeType> const& eosId, std::optional<SizeType> const& padId, bool streaming = false,
bool const& returnContextLogits = false, bool const& returnGenerationLogits = false)
{
auto samplingConfig = texec::SamplingConfig{beamWidth};
auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false};
return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId);
}
void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType,
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
const std::optional<int32_t>& eosId, const std::optional<int32_t>& padId,
TrtGptModelOptionalParams const& optionalParams, batch_scheduler::SchedulerPolicy schedulerPolicy,
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
std::optional<int> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs)
std::optional<int32_t> const& eosId, std::optional<int32_t> const& padId, BenchmarkParams const& benchmarkParams,
batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits,
bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs,
bool logIterationData)
{
auto const worldConfig = WorldConfig::mpi();
TrtGptModelOptionalParams optionalParams;
if (benchmarkParams.maxTokensInPagedKvCache)
{
optionalParams.kvCacheConfig.maxTokens = benchmarkParams.maxTokensInPagedKvCache;
}
if (benchmarkParams.freeGpuMemoryFraction)
{
optionalParams.kvCacheConfig.freeGpuMemoryFraction = benchmarkParams.freeGpuMemoryFraction;
}
optionalParams.kvCacheConfig.enableBlockReuse = benchmarkParams.enableBlockReuse;
optionalParams.enableChunkedContext = benchmarkParams.enableChunkedContext;
optionalParams.enableTrtOverlap = benchmarkParams.enableTrtOverlap;
BufferManager bufferManager{std::make_shared<CudaStream>()}; // the stream is not used
ITensor::SharedPtr beamWidthTensor{
@ -603,7 +741,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
auto recorder = std::make_shared<Recorder>(opCsvFile);
uint64_t terminateReqId = numSamples + 1;
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams,
recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs);
recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs, logIterationData);
ITensor::SharedPtr eosIdTensor{
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
@ -660,6 +798,109 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
gptServer->waitBatchManager();
}
void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType modelType,
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
std::optional<int32_t> const& eosId, std::optional<int32_t> const& padId, BenchmarkParams const& benchmarkParams,
batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits,
bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, bool logIterationData)
{
// Check that mpi size is 1 for now
auto const worldConfig = WorldConfig::mpi();
if (worldConfig.getSize() > 1)
{
TLLM_THROW("benchmarkExecutor does not yet support mpiSize > 1");
}
// Load dataset
const auto samples = parseWorkloadJson(datasetPath, maxNumSamples);
const auto numSamples = samples.size();
auto recorder = std::make_shared<Recorder>(opCsvFile);
auto executorServer = std::make_shared<ExecutorServer>(engineDir, modelType, beamWidth, schedulerPolicy,
benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData);
if (worldConfig.getRank() == 0)
{
// Warm up
{
std::vector<texec::Request> requests;
for (auto i = 0; i < warmUp; ++i)
{
requests.emplace_back(makeExecutorRequest(samples[0], beamWidth, eosId, padId,
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits));
}
executorServer->enqueue(std::move(requests), true);
executorServer->waitForResponses(warmUp, true);
}
// Benchmark
{
// Create requests
recorder->initialize();
std::vector<texec::Request> requests;
std::vector<int> delays;
for (std::size_t i = 0; i < numSamples; ++i)
{
requests.emplace_back(makeExecutorRequest(samples[i], beamWidth, eosId, padId,
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits));
delays.push_back(static_cast<int>(samples[i].delay * 1000));
}
bool hasDelay = std::any_of(delays.begin(), delays.end(), [](const auto& delay) { return delay > 0; });
if (hasDelay && staticEmulatedBatchSize)
{
TLLM_THROW("Executor benchmark doesn't support delays with emulated static batch sizes");
}
if (!hasDelay)
{
if (!staticEmulatedBatchSize)
{
executorServer->enqueue(std::move(requests));
executorServer->waitForResponses(numSamples);
}
else
{
SizeType numRequests = requests.size();
SizeType maxBatchSize = staticEmulatedBatchSize.value();
for (SizeType req = 0; req < numRequests; req += maxBatchSize)
{
auto batchSize = std::min(maxBatchSize, numRequests - req);
std::vector<texec::Request> requestsBatch(std::make_move_iterator(requests.begin() + req),
std::make_move_iterator(requests.begin() + req + batchSize));
// Enqueue in batches
executorServer->enqueue(std::move(requestsBatch));
// Wait for current batch to be done
executorServer->waitForResponses(batchSize);
}
}
}
else
{
// Launch a thread that will wait for responses
std::thread waitThread(
[numSamples, executorServer]() { executorServer->waitForResponses(numSamples); });
// Enqueue requests one by one
for (std::size_t i = 0; i < numSamples; ++i)
{
executorServer->enqueue({std::move(requests.at(i))});
std::this_thread::sleep_for(std::chrono::milliseconds(delays.at(i)));
}
waitThread.join();
}
}
recorder->finalize();
recorder->calculateMetrics();
recorder->report();
recorder->writeOpMetricsToCsv();
// Send terminateReqId to terminate servers on all ranks
// Server on rank 0 will broadcast the terminate signal to other servers in multi-GPU cases
// gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
}
}
} // namespace
int main(int argc, char* argv[])
@ -671,6 +912,8 @@ int main(int argc, char* argv[])
options.add_options()(
"m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
options.add_options()(
"api", "API type: gptManager or executor.", cxxopts::value<std::string>()->default_value("gptManager"));
options.add_options()(
"type", "Batching type: IFB or V1(non-IFB) batching.", cxxopts::value<std::string>()->default_value("IFB"));
options.add_options()("dataset", "Dataset that is used for benchmarking BatchManager.",
@ -689,14 +932,17 @@ int main(int argc, char* argv[])
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
options.add_options()(
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
options.add_options()("enable_trt_overlap", "Overlap TRT context preparation and execution",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("streaming", "Operate in streaming mode", cxxopts::value<bool>()->default_value("false"));
options.add_options()(
"enable_trt_overlap", "Overlap TRT context preparation and execution", cxxopts::value<bool>());
options.add_options()("enable_kv_cache_reuse", "Enables the KV cache reuse.", cxxopts::value<bool>());
options.add_options()("enable_chunked_context", "Whether to enable context chunking.", cxxopts::value<bool>());
"enable_kv_cache_reuse", "Enables the KV cache reuse.", cxxopts::value<bool>()->default_value("false"));
options.add_options()("enable_chunked_context", "Whether to enable context chunking.",
cxxopts::value<bool>()->default_value("false"));
options.add_options()(
"return_context_logits", "Whether to return context logits.", cxxopts::value<bool>()->default_value("0"));
options.add_options()(
"return_generation_logits", "Whether to return generation logits.", cxxopts::value<bool>()->default_value("0"));
"return_context_logits", "Whether to return context logits.", cxxopts::value<bool>()->default_value("false"));
options.add_options()("return_generation_logits", "Whether to return generation logits.",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));
@ -708,8 +954,8 @@ int main(int argc, char* argv[])
cxxopts::value<int>()->default_value("500"));
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
cxxopts::value<std::string>()->default_value("error"));
options.add_options()(
"log_iteration_data", "On each decoder iteration, print batch state metadata.", cxxopts::value<bool>());
options.add_options()("log_iteration_data", "On each decoder iteration, print batch state metadata.",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("wait_sleep", "Specify how many milliseconds to sleep each iteration of waitForEmpty loop.",
cxxopts::value<int>()->default_value("25"));
@ -729,6 +975,9 @@ int main(int argc, char* argv[])
return 1;
}
// Argument: API
auto const api = result["api"].as<std::string>();
// Argument: Batching Type
auto const type = result["type"].as<std::string>();
TrtGptModelType modelType{TrtGptModelType::V1};
@ -758,50 +1007,38 @@ int main(int argc, char* argv[])
// Argument: wait_sleep
auto const waitSleep = std::chrono::milliseconds(result["wait_sleep"].as<int>());
BenchmarkParams benchmarkParams;
TrtGptModelOptionalParams optionalParams;
// Argument: Max tokens in paged K-V Cache
if (result.count("max_tokens_in_paged_kvcache"))
{
optionalParams.kvCacheConfig.maxTokens = result["max_tokens_in_paged_kvcache"].as<int>();
benchmarkParams.maxTokensInPagedKvCache = result["max_tokens_in_paged_kvcache"].as<int>();
}
// Argument: K-V Cache Free Gpu Mem Fraction
if (result.count("kv_cache_free_gpu_mem_fraction"))
{
optionalParams.kvCacheConfig.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
benchmarkParams.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
}
// Argument: Enable TRT overlap
if (result.count("enable_trt_overlap"))
{
optionalParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
}
benchmarkParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
// Argument: Enable KV cache reuse
if (result.count("enable_kv_cache_reuse"))
{
optionalParams.kvCacheConfig.enableBlockReuse = result["enable_kv_cache_reuse"].as<bool>();
}
benchmarkParams.enableBlockReuse = result["enable_kv_cache_reuse"].as<bool>();
// Argument: streaming
benchmarkParams.streaming = result["streaming"].as<bool>();
// Argument: Enable batch stats output
if (result.count("log_iteration_data"))
{
optionalParams.logIterationData = result["log_iteration_data"].as<bool>();
}
bool logIterationData = result["log_iteration_data"].as<bool>();
// Argument: Enable chunked context
if (result.count("enable_chunked_context"))
{
optionalParams.enableChunkedContext = result["enable_chunked_context"].as<bool>();
}
benchmarkParams.enableChunkedContext = result["enable_chunked_context"].as<bool>();
// Argument: Enable return context logits
bool returnContextLogits = false;
if (result.count("return_context_logits"))
{
returnContextLogits = result["return_context_logits"].as<bool>();
}
bool returnContextLogits = result["return_context_logits"].as<bool>();
// Argument: Enable return generation logits
bool returnGenerationLogits = false;
if (result.count("return_generation_logits"))
{
returnGenerationLogits = result["return_generation_logits"].as<bool>();
}
bool returnGenerationLogits = result["return_generation_logits"].as<bool>();
std::optional<int32_t> padId;
// Argument: Padding token id
@ -873,16 +1110,40 @@ int main(int argc, char* argv[])
initTrtLlmPlugins(logger.get());
try
if (api == "gptManager")
{
benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile, maxNumSamples,
beamWidth, result["warm_up"].as<int>(), eosId, padId, optionalParams, schedulerPolicy, waitSleep,
returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout);
try
{
benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile,
maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy,
waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout,
logIterationData);
}
catch (const std::exception& e)
{
TLLM_LOG_ERROR(e.what());
return 1;
}
}
catch (const std::exception& e)
else if (api == "executor")
{
TLLM_LOG_ERROR(e.what());
try
{
benchmarkExecutor(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile, maxNumSamples,
beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy, waitSleep,
returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData);
}
catch (const std::exception& e)
{
TLLM_LOG_ERROR(e.what());
return 1;
}
}
else
{
TLLM_LOG_ERROR("api parameter must be gptManager or executor");
return 1;
}
return 0;
}

View File

@ -60,6 +60,20 @@ else()
message(STATUS "Importing batch manager")
endif()
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/tensorrt_llm/executor/CMakeLists.txt")
set(BUILD_EXECUTOR_DEFAULT ON)
else()
set(BUILD_EXECUTOR_DEFAULT OFF)
endif()
option(BUILD_EXECUTOR "Build executor from source" ${BUILD_EXECUTOR_DEFAULT})
if(BUILD_EXECUTOR)
message(STATUS "Building executor")
else()
message(STATUS "Importing executor")
endif()
if(BUILD_PYT)
message(STATUS "Building PyTorch")
else()

View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include <optional>
@ -42,6 +43,13 @@ public:
{
}
explicit KvCacheConfig(executor::KvCacheConfig const& kvCacheConfig)
: KvCacheConfig(kvCacheConfig.getMaxTokens(), kvCacheConfig.getMaxAttentionWindow(),
kvCacheConfig.getSinkTokenLength(), kvCacheConfig.getFreeGpuMemoryFraction(),
kvCacheConfig.getEnableBlockReuse(), kvCacheConfig.getUseUvm())
{
}
std::optional<SizeType> maxTokens;
std::optional<SizeType> maxAttentionWindow;
std::optional<SizeType> sinkTokenLength;

View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
@ -58,7 +59,7 @@ public:
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt)
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
@ -68,7 +69,8 @@ public:
, mEndId(endId)
, mPadId(padId)
, mSeqSlot(-1)
, mOrigPromptLen(inputTokens->size())
, mOrigPromptLen(mPromptLen)
, mMaxSentTokenPos(mPromptLen - 1)
, mEmbeddingBias(embeddingBias)
, mBadWordsList(badWordsList)
, mStopWordsList(stopWordsList)
@ -85,27 +87,112 @@ public:
, mDraftLogits(draftLogits)
, mReturnContextLogits(returnContextLogits)
, mReturnGenerationLogits(returnGenerationLogits)
, mExcludeInputFromOutput(excludeInputFromOutput)
{
mMaxSentTokenPos = mPromptLen - 1;
// Scatter the input tokens to other beam
mTokens = BeamTokens(mSamplingConfig.beamWidth, *inputTokens);
initialize(*inputTokens);
}
if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value())
|| (!mPromptEmbeddingTable.has_value() && mPromptVocabSize.has_value()))
GenericLlmRequest(RequestIdType requestId, executor::Request const& req)
: mRequestId(requestId)
, mPromptLen(req.getInputTokenIds().size())
, mMaxNewTokens(req.getMaxNewTokens())
, mSamplingConfig(req.getSamplingConfig(), req.getSpeculativeDecodingConfig())
, mState(REQUEST_STATE_CONTEXT_INIT)
, mIsStreaming(req.getStreaming())
, mEndId(req.getEndId())
, mPadId(req.getPadId())
, mSeqSlot(-1)
, mOrigPromptLen(mPromptLen)
, mMaxSentTokenPos(mPromptLen - 1)
, mReturnLogProbs(req.getOutputConfig().returnLogProbs)
, mContextChunkSize(std::nullopt)
, mContextCurrentPosition(0)
, mLogProbs(mSamplingConfig.beamWidth)
, mCumLogProbs(mSamplingConfig.beamWidth)
, mDraftTokens(std::make_shared<VecTokens>())
, mReturnContextLogits(req.getOutputConfig().returnContextLogits)
, mReturnGenerationLogits(req.getOutputConfig().returnGenerationLogits)
, mExcludeInputFromOutput(req.getOutputConfig().excludeInputFromOutput)
{
if (req.getEmbeddingBias())
{
std::string errStr
= "Prompt embedding table and prompt vocab size tensors must both be provided for requests with prompt "
"tuning enabled.";
TLLM_LOG_ERROR(errStr);
throw std::runtime_error(errStr);
mEmbeddingBias = executor::detail::toITensor(*(req.getEmbeddingBias().value()));
// Add leading 1 dimension since that's what IFB code expects
mEmbeddingBias.value()->unsqueeze(0);
}
if (req.getBadWords())
{
mBadWordsList = createListTensor(req.getBadWords().value());
}
if (req.getStopWords())
{
mStopWordsList = createListTensor(req.getStopWords().value());
}
if (draftLogits.has_value() && !draftTokens.has_value())
auto pTuningConfig = req.getPromptTuningConfig();
if (pTuningConfig)
{
std::string errStr = "Draft tokens must be specified when draft logits are given.";
TLLM_LOG_ERROR(errStr);
throw std::runtime_error(errStr);
mPromptEmbeddingTable = executor::detail::toITensor(*pTuningConfig.value().getEmbeddingTable());
TLLM_CHECK(mPromptEmbeddingTable.value()->getShape().nbDims == 2);
mPromptVocabSize = mPromptEmbeddingTable.value()->getShape().d[0];
mPromptEmbeddingTable.value()->unsqueeze(0);
}
auto loraConfig = req.getLoraConfig();
if (loraConfig)
{
mLoraWeights = executor::detail::toITensor(*loraConfig.value().getWeights());
mLoraWeights.value()->unsqueeze(0);
mLoraConfig = executor::detail::toITensor(*loraConfig.value().getConfig());
mLoraConfig.value()->unsqueeze(0);
}
auto speculativeDecodingConfig = req.getSpeculativeDecodingConfig();
if (speculativeDecodingConfig)
{
mDraftTokens = std::make_shared<VecTokens>(speculativeDecodingConfig.value().getTokens());
if (speculativeDecodingConfig.value().getLogits())
{
mDraftLogits = executor::detail::toITensor(*speculativeDecodingConfig.value().getLogits().value());
}
// NOTE: Draft acceptance threshold is stored in mSamplingConfig
}
initialize(req.getInputTokenIds());
}
void validate(SizeType maxInputLen, SizeType maxSequenceLen)
{
if (mPromptLen > maxInputLen)
{
TLLM_THROW("Prompt length (%d) exceeds maximum input length (%d).", mPromptLen, maxInputLen);
}
if (mPromptLen + mMaxNewTokens > maxSequenceLen)
{
auto const maxNewTokens = maxSequenceLen - mPromptLen;
TLLM_LOG_WARNING(
"Number of requested output tokens (%d) exceeds maximum sequence length (%d). "
"Number of requested output tokens is changed to (%d).",
mMaxNewTokens, maxSequenceLen, maxNewTokens);
mMaxNewTokens = maxNewTokens;
}
if (mSamplingConfig.beamWidth <= 0)
{
TLLM_THROW(
"Requested value: %d for beamWidth is invalid. To de-activate beam searching "
"set beamWidth to 1 instead.",
mSamplingConfig.beamWidth);
}
}
void setExcludeInputFromOutput(bool exclude)
{
mExcludeInputFromOutput = exclude;
}
/// @brief Get total number of tokens for this req (prompt + generated)
@ -236,7 +323,6 @@ public:
else
{
SizeType newPromptLen = std::min(maxInputLen, mPromptLen + getMaxNumGeneratedTokens());
TLLM_LOG_DEBUG("pause: id %lu, mPromptLen %d, newPromptLen %d", mRequestId, mPromptLen, newPromptLen);
for (std::size_t beam = 0; beam < mTokens.size(); ++beam)
{
auto& beamTokens = mTokens.at(beam);
@ -288,11 +374,31 @@ public:
return mLoraWeights;
}
void setLoraWeights(TensorPtr weights)
{
mLoraWeights = weights;
}
void clearLoraWeights()
{
mLoraWeights = std::nullopt;
}
std::optional<TensorPtr> getLoraConfig() const
{
return mLoraConfig;
}
void setLoraConfig(TensorPtr config)
{
mLoraConfig = config;
}
void clearLoraConfig()
{
mLoraConfig = std::nullopt;
}
std::optional<TensorPtr> getEmbeddingBias() const
{
return mEmbeddingBias;
@ -389,6 +495,12 @@ public:
mContextLogitsHost = std::move(contextLogitsHost);
}
void allocContextLogitsHost(SizeType vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mContextLogitsHost = runtime::BufferManager::pinned(
runtime::ITensor::makeShape({mPromptLen, vocabSizePadded}), logitsDataType);
}
TensorPtr const& getGenerationLogitsHost() const
{
return mGenerationLogitsHost;
@ -399,6 +511,12 @@ public:
mGenerationLogitsHost = std::move(generationLogitsHost);
}
void allocGenerationLogitsHost(SizeType vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mGenerationLogitsHost = runtime::BufferManager::pinned(
runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType);
}
std::vector<TensorPtr> const& getGenerationLogitsFragments() const
{
return mGenerationLogitsFragments;
@ -498,6 +616,84 @@ public:
}
}
/// @brief Create a Response from the current state of the request
/// @return An optional Response
std::optional<executor::Response> createResponse()
{
if (mState == batch_manager::REQUEST_STATE_GENERATION_COMPLETE
|| (mIsStreaming && mState == batch_manager::REQUEST_STATE_GENERATION_IN_PROGRESS))
{
executor::Result result;
result.isFinal = mState == batch_manager::REQUEST_STATE_GENERATION_COMPLETE ? true : false;
auto nbBeams = mSamplingConfig.beamWidth;
auto maxNbTokens = getMaxBeamNumTokens();
// FIXME(nkorobov): For streaming we do not allow beam search and
// streaming index calculation here applies only for sampling
int nbTokensOut = mIsStreaming ? 1 : maxNbTokens;
if (mExcludeInputFromOutput && !mIsStreaming)
{
nbTokensOut -= getOrigPromptLen();
}
result.outputTokenIds.resize(nbBeams);
SizeType tokenPos = maxNbTokens - nbTokensOut;
bool shouldSendResponse = (mState == batch_manager::REQUEST_STATE_GENERATION_COMPLETE)
|| (mIsStreaming && tokenPos > getMaxSentTokenPos());
if (!shouldSendResponse)
{
return std::nullopt;
}
else
{
for (SizeType beam = 0; beam < nbBeams; ++beam)
{
auto tokens = getTokens(beam);
auto nbTokens = mIsStreaming ? (tokenPos - getMaxSentTokenPos()) : tokens.size();
if (mExcludeInputFromOutput && !mIsStreaming)
{
nbTokens -= getOrigPromptLen();
}
if (nbTokens > 0)
{
result.outputTokenIds.at(beam).assign(
tokens.data() + tokenPos, tokens.data() + tokenPos + nbTokens);
}
}
if (returnLogProbs())
{
result.cumLogProbs = getCumLogProbs();
result.logProbs = getLogProbs();
}
if (getReturnContextLogits())
{
result.contextLogits
= std::make_shared<executor::Tensor>(executor::detail::ofITensor(getContextLogitsHost()));
}
if (getReturnGenerationLogits())
{
result.generationLogits
= std::make_shared<executor::Tensor>(executor::detail::ofITensor(getGenerationLogitsHost()));
}
// Update position of last sent response
mMaxSentTokenPos = tokenPos;
auto response = executor::Response(mRequestId, std::move(result));
return response;
}
}
else
{
return std::nullopt;
}
}
RequestIdType mRequestId;
SizeType mPromptLen;
SizeType mMaxNewTokens;
@ -545,6 +741,55 @@ protected:
TensorPtr mGenerationLogits; // [beam_size, mMaxNewTokens, vocab_size_padded]
TensorPtr mGenerationLogitsHost;
std::vector<TensorPtr> mGenerationLogitsFragments;
bool mExcludeInputFromOutput;
private:
void initialize(VecTokens const& inputTokens)
{
// Scatter the input tokens to all beams
mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens);
if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value())
|| (!mPromptEmbeddingTable.has_value() && mPromptVocabSize.has_value()))
{
TLLM_THROW(
"Prompt embedding table and prompt vocab size tensors must both be provided for requests with prompt "
"tuning enabled.");
}
if (mDraftLogits.has_value() && mDraftTokens->empty())
{
TLLM_THROW("Draft tokens must be specified when draft logits are given.");
}
}
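/// @brief Pack a list of "words" (each a vector of tokens) into the flattened bad/stop words layout:
/// a [1, 2, numTokens] tensor whose first row holds the concatenated tokens and whose second row holds
/// the inclusive end offset of each word, padded with -1.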
TensorPtr createListTensor(std::list<VecTokens> const& wordsList)
{
std::vector<SizeType> offsets;
VecTokens words;
SizeType offsetCnt = 0;
for (auto const& tokens : wordsList)
{
offsetCnt += tokens.size();
offsets.push_back(offsetCnt);
words.insert(words.end(), tokens.begin(), tokens.end());
}
offsets.resize(words.size(), -1);
SizeType numWords = static_cast<SizeType>(words.size());
auto shape = runtime::ITensor::makeShape({2, numWords});
auto tensor = runtime::BufferManager::pinnedPool(shape, nvinfer1::DataType::kINT32);
auto data = runtime::bufferCast<int32_t>(*tensor);
std::memcpy(data, words.data(), numWords * sizeof(int32_t));
std::memcpy(data + numWords, offsets.data(), numWords * sizeof(int32_t));
// Add leading dim of 1
tensor->unsqueeze(0);
return tensor;
}
};
class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
@ -568,10 +813,15 @@ public:
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt)
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false)
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, embeddingBias,
badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, loraWeights, loraConfig, returnLogProbs,
returnContextLogits, returnGenerationLogits, draftTokens, draftLogits)
returnContextLogits, returnGenerationLogits, draftTokens, draftLogits, excludeInputFromOutput)
{
}
LlmRequest(RequestIdType requestId, executor::Request const& req)
: Base(requestId, req)
{
}

View File

@ -16,6 +16,8 @@
#pragma once
#include "tensorrt_llm/executor/types.h"
namespace tensorrt_llm::batch_manager::batch_scheduler
{
@ -25,4 +27,8 @@ enum class SchedulerPolicy
GUARANTEED_NO_EVICT,
};
SchedulerPolicy execToBatchManagerSchedPolicy(executor::SchedulerPolicy policy);
executor::SchedulerPolicy batchManagerToExecSchedPolicy(SchedulerPolicy policy);
} // namespace tensorrt_llm::batch_manager::batch_scheduler

View File

@ -18,6 +18,7 @@
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/decodingMode.h"
@ -36,23 +37,29 @@ public:
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
bool enableTrtOverlap = false, std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt,
bool normalizeLogProbs = true, bool logIterationData = false, bool enableChunkedContext = false,
bool normalizeLogProbs = true, bool enableChunkedContext = false,
std::optional<runtime::DecodingMode> const& decodingMode = std::nullopt)
: kvCacheConfig{kvCacheConfig}
, enableTrtOverlap{enableTrtOverlap}
, deviceIds(deviceIds)
, normalizeLogProbs{normalizeLogProbs}
, logIterationData{logIterationData}
, enableChunkedContext{enableChunkedContext}
, decodingMode{decodingMode}
{
}
explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig)
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()),
executorConfig.getEnableTrtOverlap(), executorConfig.getDeviceIds(), executorConfig.getNormalizeLogProbs(),
executorConfig.getEnableChunkedContext())
{
}
KvCacheConfig kvCacheConfig;
bool enableTrtOverlap;
std::optional<std::vector<SizeType>> deviceIds;
bool normalizeLogProbs;
bool logIterationData;
bool enableChunkedContext;
std::optional<runtime::DecodingMode> decodingMode;
};

View File

@ -0,0 +1,96 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
namespace tensorrt_llm::common
{
//!
//! \brief A very rudimentary implementation of std::span.
//!
template <typename T>
class ArrayView
{
public:
using value_type = T;
using size_type = std::size_t;
using reference = value_type&;
using const_reference = value_type const&;
using pointer = T*;
using const_pointer = T const*;
using iterator = pointer;
using const_iterator = const_pointer;
ArrayView(T* data, size_type size)
: mData{data}
, mSize{size}
{
}
[[nodiscard]] iterator begin()
{
return mData;
}
[[nodiscard]] iterator end()
{
return mData + mSize;
}
[[nodiscard]] const_iterator begin() const
{
return mData;
}
[[nodiscard]] const_iterator end() const
{
return mData + mSize;
}
[[nodiscard]] const_iterator cbegin() const
{
return mData;
}
[[nodiscard]] const_iterator cend() const
{
return mData + mSize;
}
[[nodiscard]] size_type size() const
{
return mSize;
}
[[nodiscard]] reference operator[](size_type index)
{
return mData[index];
}
[[nodiscard]] const_reference operator[](size_type index) const
{
return mData[index];
}
private:
T* mData;
size_type mSize;
};
} // namespace tensorrt_llm::common
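As a brief usage sketch of the `ArrayView` defined above: it wraps existing storage without owning or copying it, so iteration and indexing read and write the original buffer. The values below are arbitrary.

```cpp
#include "tensorrt_llm/common/arrayView.h"

#include <numeric>
#include <vector>

int main()
{
    std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f};

    // Wrap the vector's storage; the view neither owns nor copies the data.
    tensorrt_llm::common::ArrayView<float> view(values.data(), values.size());

    // Iterate and index exactly like a span.
    float const sum = std::accumulate(view.begin(), view.end(), 0.0f);
    view[0] = sum; // writes through to values[0]
    return 0;
}
```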

View File

@ -0,0 +1,416 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include <chrono>
#include <cstdint>
#include <deque>
#include <filesystem>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <variant>
#include <vector>
namespace tensorrt_llm::executor
{
/// @brief Sampling configuration
class SamplingConfig
{
public:
SamplingConfig(SizeType beamWidth = 1, std::optional<SizeType> topK = std::nullopt,
std::optional<FloatType> topP = std::nullopt, std::optional<FloatType> topPMin = std::nullopt,
std::optional<SizeType> topPResetIds = std::nullopt, std::optional<FloatType> topPDecay = std::nullopt,
std::optional<RandomSeedType> randomSeed = std::nullopt, std::optional<FloatType> temperature = std::nullopt,
std::optional<SizeType> minLength = std::nullopt,
std::optional<FloatType> beamSearchDiversityRate = std::nullopt,
std::optional<FloatType> repetitionPenalty = std::nullopt,
std::optional<FloatType> presencePenalty = std::nullopt,
std::optional<FloatType> frequencyPenalty = std::nullopt,
std::optional<FloatType> lengthPenalty = std::nullopt);
~SamplingConfig();
[[nodiscard]] SizeType getBeamWidth() const;
[[nodiscard]] std::optional<SizeType> getTopK() const;
[[nodiscard]] std::optional<FloatType> getTopP() const;
[[nodiscard]] std::optional<FloatType> getTopPMin() const;
[[nodiscard]] std::optional<SizeType> getTopPResetIds() const;
[[nodiscard]] std::optional<FloatType> getTopPDecay() const;
[[nodiscard]] std::optional<RandomSeedType> getRandomSeed() const;
[[nodiscard]] std::optional<FloatType> getTemperature() const;
[[nodiscard]] std::optional<SizeType> getMinLength() const;
[[nodiscard]] std::optional<FloatType> getBeamSearchDiversityRate() const;
[[nodiscard]] std::optional<FloatType> getRepetitionPenalty() const;
[[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
[[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
private:
SizeType mBeamWidth;
std::optional<SizeType> mTopK;
std::optional<FloatType> mTopP;
std::optional<FloatType> mTopPMin;
std::optional<SizeType> mTopPResetIds;
std::optional<FloatType> mTopPDecay;
std::optional<RandomSeedType> mRandomSeed;
std::optional<FloatType> mTemperature;
std::optional<SizeType> mMinLength;
std::optional<FloatType> mBeamSearchDiversityRate;
std::optional<FloatType> mRepetitionPenalty;
std::optional<FloatType> mPresencePenalty;
std::optional<FloatType> mFrequencyPenalty;
std::optional<FloatType> mLengthPenalty;
};
/// @brief Configuration that controls the outputs of a Result
struct OutputConfig
{
bool returnLogProbs{false};
bool returnContextLogits{false};
bool returnGenerationLogits{false};
bool excludeInputFromOutput{false};
};
/// @brief Configuration for speculative decoding. Allows to include draft tokens, draft logits and specify acceptance
/// threshold
class SpeculativeDecodingConfig
{
public:
explicit SpeculativeDecodingConfig(VecTokens tokens, std::optional<TensorPtr> logits = std::nullopt,
std::optional<FloatType> acceptanceThreshold = std::nullopt);
~SpeculativeDecodingConfig();
[[nodiscard]] VecTokens getTokens() const;
[[nodiscard]] std::optional<TensorPtr> getLogits() const;
[[nodiscard]] std::optional<FloatType> getAcceptanceThreshold() const;
private:
VecTokens mTokens;
std::optional<TensorPtr> mLogits;
std::optional<FloatType> mAcceptanceThreshold;
};
/// @brief Configuration for prompt tuning
class PromptTuningConfig
{
public:
/// @brief Constructor for the prompt tuning configuration
/// @param embeddingTable The prompt embedding table. Data type must match model weights. Shape [vocabSize,
/// hiddenSize]
PromptTuningConfig(TensorPtr embeddingTable);
~PromptTuningConfig();
[[nodiscard]] TensorPtr getEmbeddingTable() const;
private:
TensorPtr mEmbeddingTable;
};
/// @brief Configuration for LoRA
class LoraConfig
{
public:
LoraConfig(TensorPtr weights, TensorPtr config);
~LoraConfig();
[[nodiscard]] TensorPtr getWeights() const;
[[nodiscard]] TensorPtr getConfig() const;
private:
TensorPtr mWeights;
TensorPtr mConfig;
};
/// @brief A class that holds information about the request
class Request
{
public:
/// @brief Constructor for the Request class
/// @param inputTokenIds The input token ids
/// @param maxNewTokens The maximum number of tokens to generate
/// @param streaming Indicates if the responses should be streamed or not
/// @param samplingConfig The sampling configuration
/// @param outputConfig The output configuration
/// @param endId The end token id
/// @param padId The pad token id
/// @param badWords A list of bad words; each "word" can be composed of multiple tokens
/// @param stopWords A list of stop words; each "word" can be composed of multiple tokens
/// @param embeddingBias The embedding bias tensor. Expected type is kFP32 and shape is [vocab_size]
/// @param speculativeDecodingConfig The speculative decoding configuration
/// @param pTuningConfig The prompt tuning configuration
/// @param loraConfig The LoRA configuration
Request(VecTokens inputTokenIds, SizeType maxNewTokens, bool streaming = false,
SamplingConfig samplingConfig = SamplingConfig(), OutputConfig outputConfig = OutputConfig(),
std::optional<SizeType> endId = std::nullopt, std::optional<SizeType> padId = std::nullopt,
std::optional<std::list<VecTokens>> badWords = std::nullopt,
std::optional<std::list<VecTokens>> stopWords = std::nullopt,
std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<SpeculativeDecodingConfig> speculativeDecodingConfig = std::nullopt,
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
std::optional<LoraConfig> loraConfig = std::nullopt);
Request(Request const& other);
Request(Request&& other) noexcept;
Request& operator=(Request const& other);
Request& operator=(Request&& other) noexcept;
~Request();
[[nodiscard]] VecTokens getInputTokenIds() const;
[[nodiscard]] SizeType getMaxNewTokens() const;
[[nodiscard]] bool getStreaming() const;
[[nodiscard]] SamplingConfig getSamplingConfig() const;
[[nodiscard]] OutputConfig getOutputConfig() const;
[[nodiscard]] std::optional<SizeType> getEndId() const;
[[nodiscard]] std::optional<SizeType> getPadId() const;
[[nodiscard]] std::optional<std::list<VecTokens>> getBadWords() const;
[[nodiscard]] std::optional<std::list<VecTokens>> getStopWords() const;
[[nodiscard]] std::optional<TensorPtr> getEmbeddingBias() const;
[[nodiscard]] std::optional<SpeculativeDecodingConfig> getSpeculativeDecodingConfig() const;
[[nodiscard]] std::optional<PromptTuningConfig> getPromptTuningConfig() const;
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
void setStreaming(bool streaming);
void setSamplingConfig(SamplingConfig config);
void setOutputConfig(OutputConfig outputConfig);
void setEndId(SizeType endId);
void setPadId(SizeType padId);
void setBadWords(std::list<VecTokens> badWords);
void setStopWords(std::list<VecTokens> stopWords);
void setEmbeddingBias(TensorPtr);
void setSpeculativeDecodingConfig(SpeculativeDecodingConfig specDecodingConfig);
void setPromptTuningConfig(PromptTuningConfig pTuningConfig);
void setLoraConfig(LoraConfig loraConfig);
private:
class Impl;
std::unique_ptr<Impl> mImpl;
};
/// @brief Struct that holds the generation result
struct Result
{
// Indicates if this is the final result for the request
bool isFinal;
/// @brief The output tokens for each beam
BeamTokens outputTokenIds;
std::optional<VecLogProbs> cumLogProbs; // [beamSize]
std::optional<std::vector<VecLogProbs>> logProbs; // [beamSize, seqLen]
std::optional<TensorPtr> contextLogits; // [promptLen, vocab_size_padded]
std::optional<TensorPtr> generationLogits; // [beam_size, mMaxNewTokens, vocab_size_padded]
};
/// @brief Class that holds either an error or a result
class Response
{
public:
Response(IdType requestId, std::string errorMsg);
Response(IdType requestId, Result result);
~Response();
Response(Response const& other);
Response(Response&& other) noexcept;
Response& operator=(Response const& other);
Response& operator=(Response&& other) noexcept;
// Get the id of the request for which this response was generated
IdType getRequestId() const;
// Indicates if this response has an error or not
bool hasError() const;
// Get the error msg for this response
// Will throw an exception if hasError is false
std::string getErrorMsg() const;
// Get the result for this response
// Will throw an exception if hasError() is true
Result getResult() const;
private:
class Impl;
std::unique_ptr<Impl> mImpl;
};
/// @brief Configuration class for the scheduler
class SchedulerConfig
{
public:
explicit SchedulerConfig(SchedulerPolicy policy = SchedulerPolicy::kGUARANTEED_NO_EVICT);
~SchedulerConfig();
[[nodiscard]] SchedulerPolicy getPolicy() const;
private:
SchedulerPolicy mPolicy;
};
/// @brief Configuration class for the KV cache
class KvCacheConfig
{
public:
KvCacheConfig(bool enableBlockReuse = false, std::optional<SizeType> maxTokens = std::nullopt,
std::optional<SizeType> maxAttentionWindow = std::nullopt,
std::optional<SizeType> sinkTokenLength = std::nullopt,
std::optional<FloatType> freeGpuMemoryFraction = std::nullopt, bool useUvm = false);
[[nodiscard]] bool getEnableBlockReuse() const;
[[nodiscard]] std::optional<SizeType> getMaxTokens() const;
[[nodiscard]] std::optional<SizeType> getMaxAttentionWindow() const;
[[nodiscard]] std::optional<SizeType> getSinkTokenLength() const;
[[nodiscard]] std::optional<FloatType> getFreeGpuMemoryFraction() const;
[[nodiscard]] bool getUseUvm() const;
private:
bool mEnableBlockReuse;
std::optional<SizeType> mMaxTokens;
std::optional<SizeType> mMaxAttentionWindow;
std::optional<SizeType> mSinkTokenLength;
std::optional<FloatType> mFreeGpuMemoryFraction;
bool mUseUvm;
};
SizeType const kDefaultIterStatsMaxIterations = 1000;
/// @brief Configuration class for the model executor
class ExecutorConfig
{
public:
ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig schedulerConfig = SchedulerConfig(),
KvCacheConfig kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false, bool normalizeLogProbs = true,
bool enableTrtOverlap = false, std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
BatchingType batchingType = BatchingType::kINFLIGHT);
[[nodiscard]] SizeType getMaxBeamWidth() const;
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
[[nodiscard]] KvCacheConfig getKvCacheConfig() const;
[[nodiscard]] bool getEnableChunkedContext() const;
[[nodiscard]] bool getNormalizeLogProbs() const;
[[nodiscard]] bool getEnableTrtOverlap() const;
[[nodiscard]] std::optional<std::vector<SizeType>> getDeviceIds() const;
[[nodiscard]] SizeType getIterStatsMaxIterations() const;
[[nodiscard]] BatchingType getBatchingType() const;
void setMaxBeamWidth(SizeType maxBeamWidth);
void setSchedulerConfig(SchedulerConfig schedulerConfig);
void setKvCacheConfig(KvCacheConfig kvCacheConfig);
void setEnableChunkedContext(bool enableChunkedContext);
void setNormalizeLogProbs(bool normalizeLogProbs);
void setEnableTrtOverlap(bool enableTrtOverlap);
void setDeviceIds(std::optional<std::vector<SizeType>> deviceIds);
void setIterStatsMaxIterations(SizeType iterStatsMaxIterations);
void setBatchingType(BatchingType batchingType);
private:
SizeType mMaxBeamWidth;
SchedulerConfig mSchedulerConfig;
KvCacheConfig mKvCacheConfig;
bool mEnableChunkedContext;
bool mNormalizeLogProbs;
bool mEnableTrtOverlap;
std::optional<std::vector<SizeType>> mDeviceIds;
SizeType mIterStatsMaxIterations;
BatchingType mBatchingType;
};
/// TODO:
/// @brief A class to identify processes involved in the execution of a model
/// Currently only supports MPI communication
class Communicator
{
public:
Communicator(CommunicatorType commType, CommMode mode, SizeType currentId, std::vector<SizeType> const& commIds,
std::optional<SizeType> orchestratorId){};
~Communicator() = default;
};
class Model;
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
class Executor
{
using RequestPtr = std::shared_ptr<Request>;
public:
/// @brief Constructs an executor for the model located at the given path
/// @param modelPath Path to the folder that defines the model to run
/// @param modelType The type of model
/// @param executorConfig The configuration for the executor
/// @param comm An optional inter-process communicator configuration
Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig,
std::optional<Communicator> comm = std::nullopt);
Executor(std::vector<uint8_t> const& engineBuffer, std::string const& jsonConfigStr, ModelType modelType,
ExecutorConfig executorConfig, std::optional<Communicator> comm = std::nullopt);
Executor(
std::shared_ptr<Model> model, ExecutorConfig executorConfig, std::optional<Communicator> comm = std::nullopt);
~Executor();
/// @brief Enqueue a new request
/// @param request The LLM request which contains input tokens and request parameters
/// @return A unique id that identifies the request
IdType enqueueRequest(Request request);
/// @brief Enqueue a batch of requests
/// @param requests The batch of LLM requests
/// @return A vector of unique ids, one per request
std::vector<IdType> enqueueRequests(std::vector<Request> requests);
/// @brief Await for ready responses
/// @param id An optional request id. If not specified, responses for any request can be returned
/// @param timeout The maximum time to wait for new responses
/// @return A vector of responses
std::vector<Response> awaitResponses(
std::optional<IdType> id = std::nullopt, std::optional<std::chrono::milliseconds> timeout = std::nullopt);
/// @brief Get the number of ready responses
/// @param id The request id
/// @return The number of ready responses
SizeType getNumResponsesReady(std::optional<IdType> id = std::nullopt);
/// @brief Cancel the request with provided request id
/// @param id The request id for which to cancel the response
void cancelRequest(IdType id);
/// @brief Signals the server to shut down
/// This call is blocking. Only returns when all requests have terminated or the timeout has been reached
void shutdown();
/// @brief Returns the per-iteration statistics computed since the last call to getLatestIterationStats
/// Contains at most iterStatsMaxIterations iterations
/// Will block until stats for at least one iteration are available
/// TODO: Should we use a class for iterationStats, i.e. std::deque<IterationStats>
/// @return A deque of strings, one per iteration, each containing that iteration's statistics
std::deque<std::string> getLatestIterationStats();
private:
class Impl;
std::unique_ptr<Impl> mImpl;
};
} // namespace tensorrt_llm::executor
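As a hedged illustration of how the options declared above compose, the sketch below builds a request with a custom sampling configuration, stop words, and output options. The `SamplingConfig` parameter order and the `Request` signature are taken from the declarations above; the token ids are placeholders, and `VecTokens`/`FloatType` are assumed to be the vector-of-token-ids and float aliases from `types.h`.

```cpp
#include "tensorrt_llm/executor/executor.h"

#include <list>
#include <optional>

namespace texec = tensorrt_llm::executor;

texec::Request makeSampleRequest()
{
    // Beam width 1 with top-k / top-p sampling and a fixed seed; remaining options keep their defaults.
    texec::SamplingConfig const samplingConfig(
        /*beamWidth=*/1, /*topK=*/40, /*topP=*/0.9f, /*topPMin=*/std::nullopt,
        /*topPResetIds=*/std::nullopt, /*topPDecay=*/std::nullopt, /*randomSeed=*/42);

    // Ask for log probabilities; context and generation logits stay disabled.
    texec::OutputConfig outputConfig;
    outputConfig.returnLogProbs = true;

    // Placeholder token ids; a real caller would produce these with a tokenizer.
    texec::VecTokens const inputTokens{101, 7592, 2088, 102};
    std::list<texec::VecTokens> const stopWords{{102}};

    return texec::Request(inputTokens, /*maxNewTokens=*/64, /*streaming=*/false, samplingConfig, outputConfig,
        /*endId=*/std::nullopt, /*padId=*/std::nullopt, /*badWords=*/std::nullopt, stopWords);
}
```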

View File

@ -0,0 +1,272 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/common/arrayView.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <memory>
#include <type_traits>
#include <vector>
namespace tensorrt_llm::runtime
{
class ITensor;
class CudaStream;
} // namespace tensorrt_llm::runtime
namespace tensorrt_llm::executor
{
class Tensor;
namespace detail
{
std::shared_ptr<runtime::ITensor> const& toITensor(Tensor const& tensor);
Tensor ofITensor(std::shared_ptr<runtime::ITensor> tensor);
} // namespace detail
// A thin wrapper around span that supports constructions with an initializer list.
class Shape : public tensorrt_llm::common::ArrayView<std::int32_t const>
{
public:
using Base = tensorrt_llm::common::ArrayView<std::int32_t const>;
using DimType = typename std::remove_cv_t<Base::value_type>;
Shape()
: Base{nullptr, 0} {};
Shape(DimType const* data, Base::size_type size)
: Base{data, size}
{
}
Shape(std::initializer_list<DimType> dims) // NOLINT(*-explicit-constructor)
: Base{dims.begin(), dims.size()}
{
}
};
class Tensor
{
public:
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
//! Allocate a cpu tensor with the given shape and data type.
//!
//! \param shape The shape of the tensor.
//! \param dataType The data type of the tensor.
static Tensor cpu(DataType dataType, Shape shape = {});
template <typename T>
static Tensor cpu(Shape shape = {})
{
return Tensor::cpu(getRuntimeType<T>(), shape);
}
[[nodiscard]] Tensor copyToCpu(Tensor::CudaStreamPtr stream = nullptr) const;
//! Allocate a cpu tensor in pinned memory with the given shape and data type.
//!
//! \param shape The shape of the tensor.
//! \param dataType The data type of the tensor.
static Tensor pinned(DataType dataType, Shape shape = {});
template <typename T>
static Tensor pinned(Shape shape = {})
{
return Tensor::pinned(getRuntimeType<T>(), shape);
}
[[nodiscard]] Tensor copyToPinned(Tensor::CudaStreamPtr stream = nullptr) const;
//! Allocate a cpu tensor in pooled pinned memory with the given shape and data type.
//!
//! \param shape The shape of the tensor.
//! \param dataType The data type of the tensor.
static Tensor pooledPinned(DataType dataType, Shape shape = {});
template <typename T>
static Tensor pooledPinned(Shape shape = {})
{
return Tensor::pooledPinned(getRuntimeType<T>(), shape);
}
[[nodiscard]] Tensor copyToPooledPinned(Tensor::CudaStreamPtr stream = nullptr) const;
//! Allocate a tensor in managed memory (UVM) with the given shape and data type.
//!
//! \param shape The shape of the tensor.
//! \param dataType The data type of the tensor.
static Tensor managed(DataType dataType, Shape shape = {});
template <typename T>
static Tensor managed(Shape shape = {})
{
return Tensor::managed(getRuntimeType<T>(), shape);
}
[[nodiscard]] Tensor copyToManaged(Tensor::CudaStreamPtr stream = nullptr) const;
//! Allocate a gpu tensor with the given shape and data type on a particular cuda stream.
//!
//! \param shape The shape of the tensor.
//! \param stream Specifies the CUDA stream on which to allocate the tensor for GPU memory.
//! \param dataType The data type of the tensor.
static Tensor gpu(DataType dataType, CudaStreamPtr stream, Shape shape = {});
template <typename T>
static Tensor gpu(CudaStreamPtr stream, Shape shape = {})
{
return Tensor::gpu(getRuntimeType<T>(), std::move(stream), shape);
}
[[nodiscard]] Tensor copyToGpu(Tensor::CudaStreamPtr stream) const;
//! Wrap a data pointer into a tensor without taking ownership.
//!
//! \param dataType The data type of the tensor.
//! \param data The pointer to wrap.
//! \param shape The shape of the tensor.
static Tensor of(DataType dataType, void* data, Shape shape);
//! Wrap a data pointer into a tensor without taking ownership.
//! The data type is deduced from the pointer's element type.
//!
//! \param data The pointer to wrap.
//! \param shape The shape of the tensor.
template <typename T>
static Tensor of(T* data, Shape shape)
{
return of(getRuntimeType<T>(), static_cast<void*>(data), shape);
}
//! Wrap any contiguous container into a 1D tensor without taking ownership.
//! The data type is deduced from the container's element type.
//!
//! \param data The container to wrap.
template <typename T>
static Tensor of(T& data)
{
using DimType = Shape::DimType;
if constexpr (!std::is_same_v<DimType, decltype(data.size())>)
{
TLLM_CHECK(data.size() <= std::numeric_limits<DimType>::max());
}
return of(data.data(), {static_cast<Shape::DimType const>(data.size())});
}
Tensor() noexcept = default;
~Tensor() = default;
Tensor(const Tensor& other) noexcept = default;
Tensor(Tensor&& other) noexcept = default;
Tensor& operator=(const Tensor& other) noexcept = default;
Tensor& operator=(Tensor&& other) noexcept = default;
//!
//! \brief Returns a pointer to the underlying data.
//!
[[nodiscard]] void* getData();
//!
//! \brief Returns a const pointer to the underlying data.
//!
[[nodiscard]] void const* getData() const;
//!
//! \brief Returns the data type of the buffer.
//!
[[nodiscard]] DataType getDataType() const;
//!
//! \brief Returns the memory type of the buffer.
//!
[[nodiscard]] MemoryType getMemoryType() const;
//!
//! \brief Returns the tensor dimensions.
//!
[[nodiscard]] Shape getShape() const;
//!
//! \brief Returns the number of elements in the tensor.
//!
[[nodiscard]] std::size_t getSize() const;
//!
//! \brief Returns the size of the tensor in bytes.
//!
[[nodiscard]] std::size_t getSizeInBytes() const;
//!
//! \brief Set the entire memory to zero.
//!
//! \param stream Must be a valid CUDA stream if the memory type is GPU.
void setZero(CudaStreamPtr stream = nullptr);
//!
//! \brief Copy the data and shape from another tensor.
//!
//! \param other A tensor to copy from.
//! \param stream Must be a valid CUDA stream if the memory type is GPU.
void setFrom(Tensor const& other, CudaStreamPtr stream = nullptr);
explicit operator bool() const
{
return static_cast<bool>(mTensor);
}
bool operator==(Tensor const& rhs) const
{
return mTensor == rhs.mTensor;
}
bool operator!=(Tensor const& rhs) const
{
return !(rhs == *this);
}
private:
using Impl = runtime::ITensor;
explicit Tensor(std::shared_ptr<runtime::ITensor> tensor);
template <typename T>
static DataType getRuntimeType()
{
return TypeTraits<std::remove_cv_t<T>>::value;
}
[[nodiscard]] Tensor copyTo(std::shared_ptr<Impl> tensor, CudaStreamPtr stream) const;
std::shared_ptr<Impl> mTensor;
friend std::shared_ptr<runtime::ITensor> const& detail::toITensor(Tensor const& tensor);
friend Tensor detail::ofITensor(std::shared_ptr<runtime::ITensor> tensor);
};
} // namespace tensorrt_llm::executor
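The `Tensor` factory functions above cover CPU, pinned, pooled-pinned, UVM, and GPU allocations, plus non-owning wrapping of existing data. A minimal usage sketch (the include path and the availability of a `CudaStream` are assumptions, not part of this diff):

```cpp
#include "tensorrt_llm/executor/tensor.h" // assumed header location for the Tensor class above

#include <cstdint>
#include <vector>

using tensorrt_llm::executor::Shape;
using tensorrt_llm::executor::Tensor;

void tensorUsageSketch(Tensor::CudaStreamPtr stream)
{
    // Typed host allocation: the element type selects the DataType via TypeTraits.
    Tensor hostLogits = Tensor::cpu<float>(Shape{2, 4});

    // Wrap an existing container without taking ownership; the vector must outlive the tensor.
    std::vector<std::int32_t> tokenIds{1, 2, 3, 4};
    Tensor wrapped = Tensor::of(tokenIds);

    // Device allocation on a given CUDA stream, then copy the host data over.
    Tensor deviceLogits = Tensor::gpu<float>(stream, hostLogits.getShape());
    deviceLogits.setFrom(hostLogits, stream);
}
```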

View File

@ -0,0 +1,175 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <memory>
#include <vector>
#ifdef ENABLE_FP8
#include <cuda_fp8.h>
#endif
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
namespace tensorrt_llm::executor
{
class Request;
class Tensor;
using TensorPtr = std::shared_ptr<Tensor>;
using SizeType = std::int32_t;
using FloatType = float;
using TokenIdType = std::int32_t;
using VecTokens = std::vector<TokenIdType>;
using BeamTokens = std::vector<VecTokens>;
using IdType = std::uint64_t;
using RandomSeedType = std::uint64_t;
using VecLogProbs = std::vector<FloatType>;
enum class DataType
{
kBOOL,
kUINT8,
kINT8,
kINT32,
kINT64,
kBF16,
kFP8,
kFP16,
kFP32,
kUNKNOWN
};
//! \brief Maps a C++ data type to the corresponding `DataType` enum value.
template <typename T, bool = false>
struct TypeTraits
{
};
template <>
struct TypeTraits<float>
{
static constexpr auto value = DataType::kFP32;
};
template <>
struct TypeTraits<half>
{
static constexpr auto value = DataType::kFP16;
};
template <>
struct TypeTraits<std::int8_t>
{
static constexpr auto value = DataType::kINT8;
};
template <>
struct TypeTraits<std::int32_t>
{
static constexpr auto value = DataType::kINT32;
};
template <>
struct TypeTraits<std::int64_t>
{
static constexpr auto value = DataType::kINT64;
};
template <>
struct TypeTraits<bool>
{
static constexpr auto value = DataType::kBOOL;
};
template <>
struct TypeTraits<std::uint8_t>
{
static constexpr auto value = DataType::kUINT8;
};
#ifdef ENABLE_BF16
template <>
struct TypeTraits<__nv_bfloat16>
{
static constexpr auto value = DataType::kBF16;
};
#endif
#ifdef ENABLE_FP8
template <>
struct TypeTraits<__nv_fp8_e4m3>
{
static constexpr auto value = DataType::kFP8;
};
#endif
template <typename T>
struct TypeTraits<T*>
{
// Pointers are stored as int64_t.
static constexpr auto value = DataType::kINT64;
};
enum class MemoryType
{
kCPU,
kCPU_PINNED,
kGPU,
kUVM,
kUNKNOWN
};
enum class ModelType
{
kDECODER_ONLY = 0,
};
enum class BatchingType
{
kSTATIC = 0,
kINFLIGHT = 1,
kINFLIGHT_UNFUSED = 2,
};
enum class SchedulerPolicy
{
kMAX_UTILIZATION = 0,
kGUARANTEED_NO_EVICT = 1,
};
enum class CommunicatorType
{
kMPI = 0
};
enum class CommMode
{
kLEADER,       // With the leader mode, only the leader returns from the executor constructor, and
               // therefore only the leader can enqueue requests and get responses.
kORCHESTRATOR, // With the orchestrator mode, only the orchestrator returns from the executor constructor,
               // and therefore only the orchestrator can enqueue requests and get responses. The
               // orchestrator does not participate in the computations.
kALL,          // With the ALL mode, all participants are expected to make the same calls to the executor
               // API, so they all need to send the same requests. Responses will be the same for all
               // participants.
};
} // namespace tensorrt_llm::executor
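`TypeTraits` is what lets the templated `Tensor` factories deduce a `DataType` from a C++ element type; pointers are carried as 64-bit integers. A few compile-time checks illustrating the mapping (a sketch; the `half`, BF16, and FP8 specializations are only available with the corresponding CUDA headers and macros):

```cpp
#include "tensorrt_llm/executor/types.h" // assumed header location for the definitions above

#include <cstdint>

namespace texec = tensorrt_llm::executor;

// Compile-time sanity checks of the C++ type to DataType mapping.
static_assert(texec::TypeTraits<float>::value == texec::DataType::kFP32);
static_assert(texec::TypeTraits<std::int32_t>::value == texec::DataType::kINT32);
static_assert(texec::TypeTraits<bool>::value == texec::DataType::kBOOL);
static_assert(texec::TypeTraits<float*>::value == texec::DataType::kINT64); // pointers travel as int64
```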

View File

@ -73,10 +73,10 @@ public:
[[nodiscard]] static ITensorPtr pinnedPool(nvinfer1::Dims dims, nvinfer1::DataType type = kBYTE_TYPE);
//! \brief Allocates an `IBuffer` of the given size in UVM.
[[nodiscard]] IBufferPtr managed(std::size_t size, nvinfer1::DataType type = kBYTE_TYPE) const;
[[nodiscard]] static IBufferPtr managed(std::size_t size, nvinfer1::DataType type = kBYTE_TYPE);
//! \brief Allocates an `ITensor` of the given dimensions in UVM.
[[nodiscard]] ITensorPtr managed(nvinfer1::Dims dims, nvinfer1::DataType type = kBYTE_TYPE) const;
[[nodiscard]] static ITensorPtr managed(nvinfer1::Dims dims, nvinfer1::DataType type = kBYTE_TYPE);
//! \brief Allocates an `IBuffer` of the given size and memory type.
[[nodiscard]] IBufferPtr allocate(
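With the hunk above, the pool-pinned and managed (UVM) helpers become `static`, so callers no longer need an instance just to allocate such buffers. A hedged sketch of the intended call pattern (the class name `BufferManager` and its namespace are inferred from the surrounding code, not shown in this hunk):

```cpp
// Allocate 1 MiB of managed (UVM) bytes and a pinned-pool float tensor without
// constructing a BufferManager first. Types follow the declarations above.
auto uvmBuffer = tensorrt_llm::runtime::BufferManager::managed(1 << 20);
auto pinnedTensor = tensorrt_llm::runtime::BufferManager::pinnedPool(
    nvinfer1::Dims2{8, 128}, nvinfer1::DataType::kFLOAT);
```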

View File

@ -60,6 +60,12 @@ public:
mStream = StreamPtr{stream, Deleter{ownsStream}};
}
//! Construct with an existing cuda stream or the default stream by passing nullptr.
explicit CudaStream(cudaStream_t stream)
: CudaStream{stream, tensorrt_llm::common::getDevice(), false}
{
}
//! Returns the device on which the stream was created.
[[nodiscard]] int getDevice() const
{
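The new single-argument constructor above wraps an existing `cudaStream_t` without taking ownership: it records the current device and will not destroy the stream. A minimal sketch of that use (the header path is an assumption):

```cpp
#include "tensorrt_llm/runtime/cudaStream.h" // assumed header location

#include <cuda_runtime_api.h>

void wrapExistingStream()
{
    cudaStream_t rawStream = nullptr;
    cudaStreamCreate(&rawStream);
    {
        // Non-owning wrapper: records the current device, does not destroy the stream.
        tensorrt_llm::runtime::CudaStream wrapped{rawStream};
        // ... hand `wrapped` to runtime components that expect a CudaStream ...
    }
    cudaStreamDestroy(rawStream); // the raw stream is still owned by the caller
}
```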

View File

@ -16,6 +16,9 @@
#pragma once
#include "tensorrt_llm/common/arrayView.h"
#include "tensorrt_llm/common/dataType.h"
#include <NvInferRuntime.h>
#include <cstdint>
@ -33,8 +36,6 @@
#include <typeinfo>
#include <vector>
#include "tensorrt_llm/common/dataType.h"
namespace tensorrt_llm::runtime
{
@ -561,21 +562,14 @@ T* bufferCast(IBuffer& buffer)
}
template <typename T>
class BufferRange
class BufferRange : public tensorrt_llm::common::ArrayView<T>
{
public:
using value_type = T;
using size_type = std::size_t;
using reference = value_type&;
using const_reference = value_type const&;
using pointer = T*;
using const_pointer = T const*;
using iterator = pointer;
using const_iterator = const_pointer;
using Base = tensorrt_llm::common::ArrayView<T>;
using typename Base::size_type;
BufferRange(T* data, size_type size)
: mData{data}
, mSize{size}
: Base{data, size}
{
}
@ -583,65 +577,6 @@ public:
: BufferRange(bufferCast<T>(buffer), buffer.getSize())
{
}
iterator begin()
{
return mData;
}
iterator end()
{
return mData + mSize;
}
const_iterator begin() const
{
return mData;
}
const_iterator end() const
{
return mData + mSize;
}
const_iterator cbegin()
{
return mData;
}
const_iterator cend()
{
return mData + mSize;
}
const_iterator cbegin() const
{
return mData;
}
const_iterator cend() const
{
return mData + mSize;
}
[[nodiscard]] size_type size() const
{
return mSize;
}
reference operator[](size_type index)
{
return mData[index];
}
const_reference operator[](size_type index) const
{
return mData[index];
}
private:
T* mData;
size_type mSize;
};
//! \brief Utility function to print a buffer.
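With `BufferRange` now deriving from `common::ArrayView`, iterators, `size()`, and indexing come from the base class instead of the hand-written members removed above; usage over an `IBuffer` is unchanged. A short sketch, assuming a host-visible float buffer and the `IBuffer` constructor shown above:

```cpp
#include "tensorrt_llm/runtime/iBuffer.h" // assumed header for IBuffer and BufferRange

using tensorrt_llm::runtime::BufferRange;
using tensorrt_llm::runtime::IBuffer;

float sumBuffer(IBuffer& buffer)
{
    // The constructor calls bufferCast<float>, so the buffer's data type must be float
    // and its memory must be accessible from the host.
    BufferRange<float> range(buffer);
    float sum = 0.0F;
    for (auto value : range) // iteration provided by common::ArrayView
    {
        sum += value;
    }
    return sum;
}
```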

View File

@ -16,8 +16,10 @@
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include <functional>
#include <optional>
#include <vector>
@ -57,6 +59,9 @@ private:
return std::make_optional<std::vector<T>>(values);
}
template <typename T>
using Vec = std::vector<T>;
public:
explicit SamplingConfig(SizeType beamWidth = 1)
: beamWidth{beamWidth}
@ -86,6 +91,39 @@ public:
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
}
explicit SamplingConfig(executor::SamplingConfig const& samplingConfig,
std::optional<executor::SpeculativeDecodingConfig> const& specDecodingConfig)
: beamWidth{samplingConfig.getBeamWidth()}
{
if (specDecodingConfig && specDecodingConfig.value().getAcceptanceThreshold())
{
draftAcceptanceThreshold = Vec<FloatType>{specDecodingConfig.value().getAcceptanceThreshold().value()};
}
#define SET_FROM_OPTIONAL(varName, VarName, VarType) \
\
if (samplingConfig.get##VarName()) \
{ \
varName = Vec<VarType>{samplingConfig.get##VarName().value()}; \
}
SET_FROM_OPTIONAL(topK, TopK, SizeType)
SET_FROM_OPTIONAL(topP, TopP, FloatType)
SET_FROM_OPTIONAL(topPMin, TopPMin, FloatType)
SET_FROM_OPTIONAL(topPResetIds, TopPResetIds, SizeType)
SET_FROM_OPTIONAL(topPDecay, TopPDecay, FloatType)
SET_FROM_OPTIONAL(randomSeed, RandomSeed, uint64_t)
SET_FROM_OPTIONAL(temperature, Temperature, FloatType)
SET_FROM_OPTIONAL(minLength, MinLength, SizeType)
SET_FROM_OPTIONAL(beamSearchDiversityRate, BeamSearchDiversityRate, FloatType)
SET_FROM_OPTIONAL(repetitionPenalty, RepetitionPenalty, FloatType)
SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType)
SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType)
SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType)
#undef SET_FROM_OPTIONAL
}
public:
SizeType beamWidth;
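For readers tracing the `SET_FROM_OPTIONAL` macro above: for a single field, e.g. `SET_FROM_OPTIONAL(topK, TopK, SizeType)`, the expansion is simply:

```cpp
if (samplingConfig.getTopK())
{
    topK = Vec<SizeType>{samplingConfig.getTopK().value()};
}
```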

View File

@ -34,6 +34,9 @@ add_subdirectory(runtime)
set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static)
set(BATCH_MANAGER_TARGET_ARCH "unknown")
set(EXECUTOR_TARGET tensorrt_llm_executor_static)
set(EXECUTOR_TARGET_ARCH "unknown")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if(NOT WIN32) # Linux
execute_process(
@ -52,8 +55,10 @@ if(NOT WIN32) # Linux
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
set(BATCH_MANAGER_TARGET_ARCH "x86_64-linux-gnu")
set(EXECUTOR_TARGET_ARCH "x86_64-linux-gnu")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
set(BATCH_MANAGER_TARGET_ARCH "aarch64-linux-gnu")
set(EXECUTOR_TARGET_ARCH "aarch64-linux-gnu")
if(NOT ${OS_ID} MATCHES "ubuntu" OR ${OS_VERSION_ID} VERSION_LESS 22.04)
message(
FATAL_ERROR
@ -68,6 +73,7 @@ else() # Windows
# AMD64, IA64, ARM64, EM64T, X86
if(CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64")
set(BATCH_MANAGER_TARGET_ARCH "x86_64-windows-msvc")
set(EXECUTOR_TARGET_ARCH "x86_64-windows-msvc")
else()
message(
FATAL_ERROR
@ -105,8 +111,39 @@ else()
endif()
endif()
if(BUILD_EXECUTOR)
add_subdirectory(executor)
else()
add_library(${EXECUTOR_TARGET} STATIC IMPORTED)
if(NOT WIN32) # Linux
if(USE_CXX11_ABI)
set(EXECUTOR_LIB_LOC
"${CMAKE_CURRENT_SOURCE_DIR}/executor/${EXECUTOR_TARGET_ARCH}/libtensorrt_llm_executor_static.a"
)
else()
set(EXECUTOR_LIB_LOC
"${CMAKE_CURRENT_SOURCE_DIR}/executor/${EXECUTOR_TARGET_ARCH}/libtensorrt_llm_executor_static.pre_cxx11.a"
)
endif()
else() # Windows
set(EXECUTOR_LIB_LOC
"${CMAKE_CURRENT_SOURCE_DIR}/executor/${EXECUTOR_TARGET_ARCH}/tensorrt_llm_executor_static.lib"
)
endif()
set_property(TARGET ${EXECUTOR_TARGET} PROPERTY IMPORTED_LOCATION
${EXECUTOR_LIB_LOC})
file(SIZE ${EXECUTOR_LIB_LOC} EXECUTOR_LIB_SIZE)
if(EXECUTOR_LIB_SIZE LESS 1024)
message(
FATAL_ERROR
"The executor library is truncated or incomplete. This is usually caused by using Git LFS (Large File Storage) incorrectly. Please try running command `git lfs install && git lfs pull`."
)
endif()
endif()
find_package(Threads REQUIRED)
target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE Threads::Threads)
target_link_libraries(${EXECUTOR_TARGET} INTERFACE Threads::Threads)
if(NOT WIN32)
if(USE_CXX11_ABI)
@ -128,6 +165,26 @@ else()
add_custom_target(check_symbol)
endif()
if(NOT WIN32)
if(USE_CXX11_ABI)
add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol_executor"
COMMAND nm -C $<TARGET_FILE:${EXECUTOR_TARGET}> | grep -q 'std::__cxx11::'
DEPENDS ${EXECUTOR_TARGET})
else()
add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol_executor"
COMMAND nm -C $<TARGET_FILE:${EXECUTOR_TARGET}> | grep -qv
'std::__cxx11::'
DEPENDS ${EXECUTOR_TARGET})
endif()
add_custom_target(
check_symbol_executor
DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol_executor")
else()
add_custom_target(check_symbol_executor)
endif()
set(TRTLLM_LINK_LIBS
${CUBLAS_LIB}
${CUBLASLT_LIB}
@ -175,11 +232,28 @@ else()
"-Wl,--no-whole-archive")
endif()
if(WIN32)
target_link_libraries(${SHARED_TARGET}
PUBLIC $<TARGET_FILE:${EXECUTOR_TARGET}>)
set_target_properties(
${SHARED_TARGET} PROPERTIES LINK_FLAGS "/WHOLEARCHIVE:${EXECUTOR_TARGET}")
else()
# Assume everything else is like gcc
target_link_libraries(
${SHARED_TARGET}
PRIVATE "-Wl,--whole-archive" $<TARGET_FILE:${EXECUTOR_TARGET}>
"-Wl,--no-whole-archive")
endif()
add_dependencies(${SHARED_TARGET} check_symbol)
add_dependencies(${SHARED_TARGET} check_symbol_executor)
# Cyclic dependency of batch manager on TRT-LLM
target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE ${SHARED_TARGET})
# Cyclic dependency of executor on TRT-LLM
target_link_libraries(${EXECUTOR_TARGET} INTERFACE ${SHARED_TARGET})
if(BUILD_PYT)
add_subdirectory(thop)
endif()

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ba61c04ed7623fc44b5364802c1893fa824467455f4a9fe8245d5d51fef97e6
size 2172266

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf4afdfd281029c8e4bf0af548529b94a4a6d0f9bb5148ae10423e5e0275db06
size 2195822

View File

@ -0,0 +1,3 @@
4c405d39a0cbb93d44a5758480a1a223 libtensorrt_llm_batch_manager_static.a
68aea75a2ed5b219eec5a0f77ce33482 libtensorrt_llm_batch_manager_static.pre_cxx11.a
9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0268f64b0c2540e07bf05ad458f7aa33c9d6e65fc4f5c85cd8d0946d658ffeb8
size 2092012
oid sha256:39835ca321e9c45d3b554ebceb1734966b75f83dbe8c550cc44846fb4fae8f72
size 2110728

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:89ae0be676e7aa9b562f6745636f7d77198f87b83ec6295aff74273767e4fca7
size 2071180
oid sha256:789c2eba349161e84a76b95b23f8294cf3bdcf855871672d76722c4ae858d81b
size 2091842

View File

@ -1,2 +1,2 @@
63c3f64faa14f9d5d66b7e186a6cc80b libtensorrt_llm_batch_manager_static.a
dbcc1bbe80d977c1655d32ef69b36578 libtensorrt_llm_batch_manager_static.pre_cxx11.a
30a6c963121b3cfda21dc0117b7984e1 libtensorrt_llm_batch_manager_static.a
0d2d2e3157201f6336d749b3e6f994bc libtensorrt_llm_batch_manager_static.pre_cxx11.a

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:49c84a22cee9e6c3a975db08d8d0d8dbe88867e2eb4fc12a4b3ff6c1c90e8c21
size 586202

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fea93ae7d09e74b073a65d5d0ac34aec9ccc8f8299af1abd6826e97e9c8427f4
size 589652

View File

@ -0,0 +1,3 @@
73999f4c2b3a4328db454b7ab6fe86d3 libtensorrt_llm_executor_static.a
df53aa83848b5ed75550a7b536ca02a4 libtensorrt_llm_executor_static.pre_cxx11.a
9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:643e546711fd33a85073560e3428c6a2f60525f7592aa3328043dfad61631c30
size 586532

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28059131a9325c88bd362cb12c57a2b2e47d3e0aac140e5d1cf9a7020a81999e
size 570860

View File

@ -0,0 +1,2 @@
fa89714705a1915f052c635a07dc4c73 libtensorrt_llm_executor_static.a
83cbfaf10bedd7d8edeab33552dcf3df libtensorrt_llm_executor_static.pre_cxx11.a

View File

@ -472,31 +472,35 @@ void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, const int* sequ
}
__global__ void transposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, const int* sequenceLengths,
const int* batchSlots, int batchSize, int beamWidth, int maxSeqLen)
const int* batchSlots, int batchSize, int maxBatchSize, int beamWidth, int maxSeqLen)
{
int index = blockIdx.x * blockDim.x + threadIdx.x;
const int batchIdx = index / (beamWidth * maxSeqLen);
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
const int tmpIdx = index % (beamWidth * maxSeqLen);
const int beamIdx = tmpIdx / maxSeqLen;
const int pos = tmpIdx % maxSeqLen;
int const batchIdx = index / (beamWidth * maxSeqLen);
int const tmpIdx = index % (beamWidth * maxSeqLen);
int const beamIdx = tmpIdx / maxSeqLen;
int const pos = tmpIdx % maxSeqLen;
if (batchIdx >= batchSize)
{
return;
}
if (batchIdx < batchSize && pos < sequenceLengths[batchSlot])
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
if (pos < sequenceLengths[batchSlot])
{
auto const batchBeamIdx = batchSlot * beamWidth * maxSeqLen + beamIdx * maxSeqLen + pos;
outputLogProbs[batchBeamIdx]
= outputLogProbsTiled[pos * batchSize * beamWidth + batchSlot * beamWidth + beamIdx];
= outputLogProbsTiled[pos * maxBatchSize * beamWidth + batchSlot * beamWidth + beamIdx];
}
}
void invokeTransposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, const int* sequenceLengths,
const int* batchSlots, int batchSize, int beamWidth, int maxSeqLen, cudaStream_t stream)
const int* batchSlots, int batchSize, int maxBatchSize, int beamWidth, int maxSeqLen, cudaStream_t stream)
{
dim3 block(256);
dim3 grid(divUp(batchSize * beamWidth * maxSeqLen, block.x));
transposeLogProbs<<<grid, block, 0, stream>>>(
outputLogProbs, outputLogProbsTiled, sequenceLengths, batchSlots, batchSize, beamWidth, maxSeqLen);
transposeLogProbs<<<grid, block, 0, stream>>>(outputLogProbs, outputLogProbsTiled, sequenceLengths, batchSlots,
batchSize, maxBatchSize, beamWidth, maxSeqLen);
}
__global__ void acceptDraftTokensByIds(int32_t const* draftIds, int32_t const* targetIds, int32_t const* contextLengths,
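The hunk above threads `maxBatchSize` into the kernel so the tiled log-probs buffer is indexed by the maximum batch size rather than the per-step batch size. As a host-side illustration of the two layouts involved (index math copied from the kernel; function names are for exposition only):

```cpp
#include <cstddef>

// Tiled layout written during decoding: [maxSeqLen, maxBatchSize, beamWidth].
std::size_t tiledIndex(int pos, int batchSlot, int beamIdx, int maxBatchSize, int beamWidth)
{
    return static_cast<std::size_t>(pos) * maxBatchSize * beamWidth + batchSlot * beamWidth + beamIdx;
}

// Transposed layout expected by the output tensor: [maxBatchSize, beamWidth, maxSeqLen].
std::size_t transposedIndex(int pos, int batchSlot, int beamIdx, int beamWidth, int maxSeqLen)
{
    return static_cast<std::size_t>(batchSlot) * beamWidth * maxSeqLen + beamIdx * maxSeqLen + pos;
}
```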

View File

@ -129,7 +129,8 @@ void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs,
int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);
void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, int32_t const* sequence_lengths,
int32_t const* batchSlots, int32_t batch_size, int32_t beam_width, int32_t max_seq_len, cudaStream_t stream);
int32_t const* batchSlots, int32_t batch_size, int32_t max_batch_size, int32_t beam_width, int32_t max_seq_len,
cudaStream_t stream);
void invokeAcceptTokens(int32_t const* draft_tokens, int32_t const* target_tokens, int32_t const* context_lengths,
int32_t const* nums_draft_tokens, int32_t* sequence_lengths, bool const* finished, bool* finished_final,

View File

@ -138,7 +138,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ int selected_beams;
__shared__ float old_cum_log_probs[MAX_K2];
__shared__ cub_kvp cta_topk[MAX_K2];
__shared__ char cta_topk_store[MAX_K2 * sizeof(cub_kvp)];
auto* cta_topk = reinterpret_cast<cub_kvp*>(cta_topk_store);
if (thread_id == 0)
{
@ -687,7 +688,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
MD partial_md{-MAX_T_VAL, 0.0f};
cub_kvp total_topk{V - 1, -MAX_T_VAL};
__shared__ cub_kvp buf_smem_kv[MAX_K2];
__shared__ char buf_smem_kv_store[MAX_K2 * sizeof(cub_kvp)];
auto* buf_smem_kv = reinterpret_cast<cub_kvp*>(buf_smem_kv_store);
// load and unpack into registers through smem
for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE)
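The two hunks above replace typed `__shared__` arrays of `cub_kvp` with raw character storage plus a `reinterpret_cast`, so no constructors run for the key-value type in shared memory. A host-side analogue of the pattern, with a hypothetical `KeyValue` stand-in type:

```cpp
#include <cstddef>
#include <new>

struct KeyValue // stand-in for cub_kvp
{
    int key;
    float value;
};

void rawStorageSketch()
{
    constexpr std::size_t kMaxK2 = 8;
    // Raw, correctly aligned bytes; no KeyValue constructors are invoked here.
    alignas(KeyValue) unsigned char storage[kMaxK2 * sizeof(KeyValue)];
    auto* items = reinterpret_cast<KeyValue*>(storage);
    // Elements are written explicitly before they are read, as the kernel does.
    for (std::size_t i = 0; i < kMaxK2; ++i)
    {
        new (&items[i]) KeyValue{static_cast<int>(i), 0.0F};
    }
}
```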

View File

@ -308,7 +308,7 @@ template <typename T, typename IdxT, typename AccT, int BitsPerPass>
__device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* inIdxBuf, T* outBuf, IdxT* outIdxBuf,
int previousLen, Counter<T, IdxT, AccT>* counter, AccT* histogram, IdxT* countHistogram, int pass,
float* outputLogProbs, float* cumLogProbs, IdxT** ids, IdxT const* endIds, IdxT* sequenceLengths,
FinishedState* finishedOutput, int const batchId, bool earlyStop)
FinishedState* finishedOutput, int const batchId, int maxBatchSize, bool earlyStop)
{
static_assert(std::is_same_v<T, half> | std::is_same_v<T, float>, "T needs to be either half or float");
static_assert(std::is_same_v<AccT, float>, "AccT needs to be float");
@ -359,7 +359,7 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i
// See the remark above on the distributed execution of `f` using
// vectorizedProcess.
auto f = [inIdxBuf, outBuf, outIdxBuf, selectMin, startBit, mask, previousStartBit, kthValueBits, pFilterCnt,
outputLogProbs, cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchId,
outputLogProbs, cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchId, maxBatchSize,
earlyStop](T value, IdxT i)
{
auto const previousBits = (twiddleIn(value, selectMin) >> previousStartBit) << previousStartBit;
@ -370,8 +370,8 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i
int const currentStep = sequenceLengths[batchId];
IdxT index = inIdxBuf ? inIdxBuf[i] : i;
ids[batchId][currentStep] = index;
epilogue(
value, index, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, batchId);
epilogue(value, index, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput,
batchId, maxBatchSize);
}
if (outBuf)
{
@ -506,14 +506,14 @@ __device__ void chooseBucket(
*/
template <typename T, typename IdxT>
__device__ void epilogue(T const value, IdxT const index, float* outputLogProbs, float* cumLogProbs, IdxT const* endIds,
IdxT* sequenceLengths, FinishedState* finishedOutput, int const batchId)
IdxT* sequenceLengths, FinishedState* finishedOutput, int const batchId, int maxBatchSize)
{
if (outputLogProbs != nullptr || cumLogProbs != nullptr)
{
float res = logf(value);
if (outputLogProbs)
{
outputLogProbs[batchId] = res;
outputLogProbs[sequenceLengths[batchId] * maxBatchSize + batchId] = res;
}
if (cumLogProbs)
{
@ -542,7 +542,7 @@ __device__ void epilogue(T const value, IdxT const index, float* outputLogProbs,
template <typename T, typename IdxT, typename AccT, int BitsPerPass>
__device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen, Counter<T, IdxT, AccT>* counter,
float* outputLogProbs, float* cumLogProbs, IdxT** ids, IdxT const* endIds, IdxT* sequenceLengths,
FinishedState* finishedOutput, int const batchId)
FinishedState* finishedOutput, int const batchId, int maxBatchSize)
{
auto const kthValueBits = counter->kthValueBits;
auto const equalValue = twiddleOut<T>(kthValueBits, false);
@ -565,7 +565,8 @@ __device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen
if (threadIdx.x == 0)
{
epilogue(equalValue, *outIdx, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, batchId);
epilogue(equalValue, *outIdx, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, batchId,
maxBatchSize);
}
}
@ -609,7 +610,7 @@ __device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen
template <typename T, typename IdxT, typename AccT, int BitsPerPass, int BlockSize, bool is_fused_filter = false>
__global__ void airTopPSampling(Counter<T, IdxT, AccT>* counters, AccT* histograms, IdxT* countHistograms, IdxT** ids,
int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, IdxT const* endIds, int const batchSize, bool const* skipDecode, int const pass, T* buf1,
float* outputLogProbs, IdxT const* endIds, int const maxBatchSize, bool const* skipDecode, int const pass, T* buf1,
IdxT* idxBuf1, T* buf2, IdxT* idxBuf2, int32_t const* batchSlots)
{
assert(sequenceLengths != nullptr);
@ -698,7 +699,7 @@ __global__ void airTopPSampling(Counter<T, IdxT, AccT>* counters, AccT* histogra
filterAndHistogram<T, IdxT, AccT, BitsPerPass>(inBuf, inIdxBuf, outBuf, outIdxBuf, previousLen, counter, histogram,
countHistogram, pass, outputLogProbs, cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchSlot,
earlyStop);
maxBatchSize, earlyStop);
__syncthreads();
__threadfence();
@ -779,7 +780,7 @@ __global__ void airTopPSampling(Counter<T, IdxT, AccT>* counters, AccT* histogra
{
lastFilter<T, IdxT, AccT, BitsPerPass>(outBuf ? outBuf : inBuf, outIdxBuf ? outIdxBuf : inIdxBuf,
outBuf ? currentLen : counter->oriLen, counter, outputLogProbs, cumLogProbs, ids, endIds,
sequenceLengths, finishedOutput, batchSlot);
sequenceLengths, finishedOutput, batchSlot, maxBatchSize);
__syncthreads();
}
@ -891,9 +892,9 @@ unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt)
template <typename T>
void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
T const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds,
float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, bool const* skipDecode,
int32_t const* batchSlots)
T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded,
int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum,
bool const* skipDecode, int32_t const* batchSlots)
{
using IdxT = int;
using AccT = float;
@ -953,46 +954,47 @@ void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** ou
}
kernel<<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(counters, histograms, countHistograms, outputIds,
sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, endIds, batchSize, skipDecode,
pass, buf1, idxBuf1, buf2, idxBuf2, batchSlots);
sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, endIds, maxBatchSize,
skipDecode, pass, buf1, idxBuf1, buf2, idxBuf2, batchSlots);
sync_check_cuda_error();
}
}
template void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
float const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded,
int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum,
bool const* skipDecode, int32_t const* batchSlots);
float const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize,
size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream,
int blockNum, bool const* skipDecode, int32_t const* batchSlots);
template void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
half const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded,
int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum,
bool const* skipDecode, int32_t const* batchSlots);
half const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize,
size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream,
int blockNum, bool const* skipDecode, int32_t const* batchSlots);
template <typename T>
void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
T const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds,
float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots)
T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded,
int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode,
int32_t const* batchSlots)
{
invokeBatchAirTopPSampling(workspace, workspaceSize, outputIds, sequenceLength, finishedInput, finishedOutput,
cumLogProbs, outputLogProbs, logProbs, curandstate, batchSize, vocabSizePadded, endIds, topP, nullptr, stream,
blockNum, skipDecode, batchSlots);
cumLogProbs, outputLogProbs, logProbs, curandstate, batchSize, maxBatchSize, vocabSizePadded, endIds, topP,
nullptr, stream, blockNum, skipDecode, batchSlots);
}
template void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
float const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded,
int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode,
int32_t const* batchSlots);
float const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize,
size_t const vocabSizePadded, int const* endIds, float const topP, cudaStream_t stream, int blockNum,
bool const* skipDecode, int32_t const* batchSlots);
template void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
half const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded,
int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode,
int32_t const* batchSlots);
half const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize,
size_t const vocabSizePadded, int const* endIds, float const topP, cudaStream_t stream, int blockNum,
bool const* skipDecode, int32_t const* batchSlots);
template unsigned calcAirTopPBlockNum<float, int, float>(int batchSize, int len, int smCnt);
template unsigned calcAirTopPBlockNum<half, int, float>(int batchSize, int len, int smCnt);

View File

@ -123,7 +123,7 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, const int maxTopK, const int* topKs, const float topP, const float* topPs,
curandState_t* curandstate, const int* endIds, const int vocabSize, const bool* skipDecode, const int* batchSlots,
const bool normalizeLogProbs, const bool logitHasProbs)
int maxBatchSize, const bool normalizeLogProbs, const bool logitHasProbs)
{
bool const IS_FP16 = std::is_same<T, half>::value;
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
@ -210,7 +210,8 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
// If idx is -1 here, force the output token to the last entry of the vocabulary as a clear
// indicator that something went wrong, for debugging purposes.
auto outputId = idx != -1 ? topKTmpIdBuf[batchIdx * stride + idx] % vocabSize : vocabSize - 1;
ids[batchSlot][sequenceLengths[batchSlot]] = outputId;
auto const curSeqLen = sequenceLengths[batchSlot];
ids[batchSlot][curSeqLen] = outputId;
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
float logProb = logf(expLogit);
@ -225,7 +226,8 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
// log_prob = log P(i | i is in top-k) = log(expLogit)
// normalized:
// log_prob = log P(i | i is in top-k) = log(expLogit / sum)
outputLogProbs[batchSlot] = normalizeLogProbs ? logProb - logf(s_sum) : logProb;
outputLogProbs[curSeqLen * maxBatchSize + batchSlot]
= normalizeLogProbs ? logProb - logf(s_sum) : logProb;
}
}
break;
@ -256,8 +258,8 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
topKStage2Sampling<T, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_> \
<<<batchSize, BLOCK_SIZE_2_, K_MAX * sizeof(int) + K_MAX * sizeof(float), stream>>>(topKTmpIdBuf, \
topKTmpValBuf, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, maxTopK, \
topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, normalizeLogProbs, \
logitsHasProbs); \
topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, maxBatchSize, \
normalizeLogProbs, logitsHasProbs); \
break;
template <typename T>
@ -265,7 +267,7 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize,
const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs)
int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@ -330,37 +332,39 @@ template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, co
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP,
const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream,
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
const int batchSize, int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs,
const bool logitsHasProbs);
template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP,
const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream,
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
const int batchSize, int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs,
const bool logitsHasProbs);
template <typename T>
void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, const bool* skipDecode,
const bool normalizeLogProbs, const bool logitsHasProbs)
{
invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finishedInput, finishedOutput,
cumLogProbs, outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, batchSlots,
stream, batchSize, skipDecode, normalizeLogProbs, logitsHasProbs);
stream, batchSize, maxBatchSize, skipDecode, normalizeLogProbs, logitsHasProbs);
}
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids,
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded,
const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
const bool normalizeLogProbs, const bool logitsHasProbs);
const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize,
const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded,
const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
const bool normalizeLogProbs, const bool logitsHasProbs);
const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize,
const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -59,6 +59,7 @@ namespace kernels
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
//! \param stream cuda stream
//! \param batchSize batch size
//! \param maxBatchSize maximum batch size
//! \param skipDecode input buffer [maxBatchSize]. Flags whether to skip decoding per request
//! \param normalizeLogProbs when set to True outputLogProbs are normalized to TopK
//! \param logitsHasProbs flag to highlight that logProbs contains probabilities
@ -68,14 +69,14 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize,
const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
//! \brief Specialization of invokeBatchTopKSampling with topPs=nullptr and topKs=nullptr
template <typename T>
void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** outputIds, int* sequenceLength,
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, const bool* skipDecode,
const bool normalizeLogProbs, const bool logitsHasProbs);
} // namespace kernels
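As with the other sampling kernels, the required workspace size is obtained by a first call with a null workspace pointer; `TopKSamplingLayer::allocateBuffer` further down in this commit uses exactly this pattern. A sketch of that query with the new `maxBatchSize` argument (the header path and the worst-case top-K value are placeholders):

```cpp
#include "tensorrt_llm/kernels/samplingTopKKernels.h" // assumed header path

#include <cuda_runtime_api.h>

#include <cstddef>

// Query-only call: with workspace == nullptr the function fills in workspaceSize and returns.
std::size_t queryTopKWorkspaceSize(int batchSize, int maxBatchSize, int vocabSizePadded, cudaStream_t stream)
{
    std::size_t workspaceSize = 0;
    tensorrt_llm::kernels::invokeTopKSampling<float>(
        /*workspace=*/nullptr, workspaceSize,
        /*logProbs=*/nullptr, /*outputIds=*/nullptr, /*sequenceLength=*/nullptr,
        /*finishedInput=*/nullptr, /*finishedOutput=*/nullptr,
        /*cumLogProbs=*/nullptr, /*outputLogProbs=*/nullptr, /*curandstate=*/nullptr,
        /*topK=*/1024, /*topP=*/1.0f, vocabSizePadded,
        /*endIds=*/nullptr, /*batchSlots=*/nullptr, stream, batchSize, maxBatchSize,
        /*skipDecode=*/nullptr, /*normalizeLogProbs=*/false, /*logitsHasProbs=*/false);
    return workspaceSize;
}
```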

View File

@ -164,7 +164,8 @@ struct BlockPrefixCallbackOp
template <typename T>
__device__ void epilogue(int batchId, int currentStep, int offset, int** ids, int* sortedIdVals, T* sortedLogProbs,
float* cumLogProbs, float* outputLogProbs, int const* endIds, int* sequenceLengths, FinishedState* finishedOutput)
float* cumLogProbs, float* outputLogProbs, int const* endIds, int* sequenceLengths, FinishedState* finishedOutput,
int maxBatchSize)
{
ids[batchId][currentStep] = sortedIdVals[offset];
@ -177,7 +178,7 @@ __device__ void epilogue(int batchId, int currentStep, int offset, int** ids, in
}
if (outputLogProbs != nullptr)
{
outputLogProbs[batchId] = lprob;
outputLogProbs[sequenceLengths[batchId] * maxBatchSize + batchId] = lprob;
}
}
if (sequenceLengths != nullptr && finishedOutput != nullptr)
@ -199,7 +200,7 @@ template <typename T, int blockSize>
__global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, int* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
int const* beginOffsetBuf, int const* offsetBuf, int const vocabSize, curandState_t* curandstate, float const topP,
float const* topPs, int const* endIds, int const batchSize, bool const* skipDecode, int const* batchSlots)
float const* topPs, int const* endIds, int maxBatchSize, bool const* skipDecode, int const* batchSlots)
{
/**
* Each block processes one request row sorted in descending order by probabilities.
@ -258,7 +259,7 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
{
int offset = batchId * vocabSize;
epilogue(batchSlot, currentStep, offset, ids, sortedIdVals, sortedLogProbs, cumLogProbs, outputLogProbs,
endIds, sequenceLength, finishedOutput);
endIds, sequenceLength, finishedOutput, maxBatchSize);
}
return;
}
@ -299,7 +300,7 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
{
epilogue(batchSlot, currentStep, offset + selectedTokenId, ids, sortedIdVals, sortedLogProbs, cumLogProbs,
outputLogProbs, endIds, sequenceLength, finishedOutput);
outputLogProbs, endIds, sequenceLength, finishedOutput, maxBatchSize);
}
}
@ -307,7 +308,7 @@ template <typename T>
void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf,
curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds,
curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds,
float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@ -354,7 +355,7 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
// Sample with Top P given sorted tokens
topPSsampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sortedLogProbs, sortedIdVals,
outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, beginOffsetBuf,
offsetBuf + 1, vocabSize, curandstate, maxTopP, topPs, endIds, batchSize, skipDecode, batchSlots);
offsetBuf + 1, vocabSize, curandstate, maxTopP, topPs, endIds, maxBatchSize, skipDecode, batchSlots);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -363,40 +364,40 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize,
int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput,
float* cumLogProbs, float* outputLogProbs, float const* logProbs, int const* idVals, int* offsetBuf,
int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded,
int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode,
int const* batchSlots);
int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, int maxBatchSize,
size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream,
bool const* skipDecode, int const* batchSlots);
template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize,
int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput,
float* cumLogProbs, float* outputLogProbs, half const* logProbs, int const* idVals, int* offsetBuf,
int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded,
int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode,
int const* batchSlots);
int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, int maxBatchSize,
size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream,
bool const* skipDecode, int const* batchSlots);
template <typename T>
void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf,
curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, float const topP,
cudaStream_t stream, bool const* skipDecode, int const* batchSlots)
curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds,
float const topP, cudaStream_t stream, bool const* skipDecode, int const* batchSlots)
{
invokeBatchTopPSampling(workspace, workspaceSize, cubTempStorageSize, outputIds, sequenceLength, finishedInput,
finishedOutput, cumLogProbs, outputLogProbs, logProbs, idVals, offsetBuf, beginOffsetBuf, curandstate,
batchSize, vocabSizePadded, endIds, topP, nullptr, stream, skipDecode, batchSlots);
batchSize, maxBatchSize, vocabSizePadded, endIds, topP, nullptr, stream, skipDecode, batchSlots);
}
template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, float const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf,
curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, float const topP,
cudaStream_t stream, bool const* skipDecode, int const* batchSlots);
curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds,
float const topP, cudaStream_t stream, bool const* skipDecode, int const* batchSlots);
template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, half const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf,
curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, float const topP,
cudaStream_t stream, bool const* skipDecode, int const* batchSlots);
curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds,
float const topP, cudaStream_t stream, bool const* skipDecode, int const* batchSlots);
__global__ void computeToppDecay(float* runtimeTopP, float const* runtimeInitialTopP, int const** outputIds,
float const* topPDecay, float const* topPMin, int32_t const* topPResetIds, int const* sequenceLengths,

View File

@ -63,6 +63,7 @@ void invokeTopPInitialize(int* topPIdValBuf, int* topPOffsetBuf, int* beginTopPO
//! \param curandstate input buffer [maxBatchSize]. Curand states properly initialized using
//! invokeCurandInitialize per request.
//! \param batchSize batch size
//! \param maxBatchSize maximum batch size
//! \param vocabSizePadded size of padded vocab
//! \param endIds input buffer [maxBatchSize]. EOS token ids per request
//! \param maxTopP maximum among all topPs P for topP sampling
@ -77,7 +78,7 @@ template <typename T>
void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf,
curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds,
curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds,
float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots);
//! \brief Specialization of invokeBatchTopPSampling with topPs=nullptr
@ -85,8 +86,8 @@ template <typename T>
void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf,
curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, float const topPp,
cudaStream_t stream, bool const* skipDecode, int const* batchSlots);
curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds,
float const topPp, cudaStream_t stream, bool const* skipDecode, int const* batchSlots);
//! \brief Given logProbs, performs top P sampling.
//! Note different from invokeTopPSampling() and invokeBatchTopPSampling() there two functions invokeAirTopPSampling
@ -116,6 +117,7 @@ void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempS
//! \param curandstate input buffer [batchSize]. Curand states properly initialized using invokeCurandInitialize per
//! request.
//! \param batchSize batch size
//! \param maxBatchSize max batch size
//! \param vocabSizePadded size of padded vocab
//! \param endIds input buffer [batchSize]. EOS token ids per request
//! \param maxTopP maximum among all topPs P for topP sampling
@ -128,16 +130,17 @@ void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempS
template <typename T>
void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
T const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds,
float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, bool const* skipDecode,
int32_t const* batchSlots);
T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded,
int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum,
bool const* skipDecode, int32_t const* batchSlots);
//! \brief Specialization of invokeBatchAirTopPSampling with topPs=nullptr
template <typename T>
void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
T const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds,
float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots);
T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded,
int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode,
int32_t const* batchSlots);
//! \brief Calculate the number of blocks based on the number of multiprocessors, batchSize and vocabSize.
//! \tparam T the data type of value

View File

@ -372,7 +372,8 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
checkStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, mStream);
// Copy nextIds and transpose logits when needed
prepareOutputData(outputs, params, mIdsPtrHost, batchSlots, batchSize, beamWidth, maxSeqLen, mCyclicStep, mStream);
prepareOutputData(
outputs, params, mIdsPtrHost, batchSlots, batchSize, mMaxBatchSize, beamWidth, maxSeqLen, mCyclicStep, mStream);
mCyclicStep += 1;
@ -489,10 +490,8 @@ void DynamicDecodeLayer<T>::layersForward(Tensor& logits, OutputParams& outputs,
}
if (outputs.output_log_probs_tiled)
{
TLLM_CHECK(0 <= mCyclicStep && mCyclicStep < maxSeqLen);
Tensor& output_log_probs = outputs.output_log_probs_tiled.value();
size_t step_offset = mCyclicStep * batchSize * beamWidth;
decode_outputs.output_log_probs = output_log_probs.slice({1, localBatchSize * beamWidth}, step_offset);
decode_outputs.output_log_probs = output_log_probs.slice({1, localBatchSize * beamWidth}, 0);
}
// Run TopK + TopP decode layers.
@ -697,8 +696,8 @@ void DynamicDecodeLayer<T>::prepareIdsPtrs(
template <typename T>
void DynamicDecodeLayer<T>::prepareOutputData(OutputParams& outputs, ForwardParams const& params,
runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth,
size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream)
runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t maxBatchSize,
size_t beamWidth, size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto idsPtrHostSlice = ITensor::slice(idsPtrsHost, cyclicStep, 1);
@ -713,8 +712,8 @@ void DynamicDecodeLayer<T>::prepareOutputData(OutputParams& outputs, ForwardPara
invokeTransposeLogProbs(outputs.output_log_probs.value().template getPtr<float>(),
outputs.output_log_probs_tiled.value().template getPtr<float>(),
outputs.sequence_length->template getPtr<int>(), batchSlots, batchSize, beamWidth, logProbsMaxSeqLen,
stream);
outputs.sequence_length->template getPtr<int>(), batchSlots, batchSize, maxBatchSize, beamWidth,
logProbsMaxSeqLen, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

View File

@ -150,7 +150,7 @@ public:
std::optional<tc::Tensor>
output_log_probs_tiled; // [request_output_length, batch_size, beam_width], must be float*, optional
std::optional<tc::Tensor>
output_log_probs; // [batchSize, beam_width, request_ouptut_length], must be float*, optional
output_log_probs; // [batch_size, beam_width, request_output_length], must be float*, optional
std::optional<tc::Tensor>
tgt_cache_indirection; // [local_batch_size, beam_width, max_seq_len], the k/v cache index for beam search
std::shared_ptr<kernels::BeamHypotheses>
@ -164,6 +164,11 @@ public:
void allocateBuffer();
void freeBuffer();
T* getRuntimeLogitsDevice()
{
return mRuntimeLogitsDevice;
}
private:
void initialize();
void initializeLayers();
@ -197,8 +202,8 @@ private:
void prepareIdsPtrs(
OutputParams& outputs, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen);
static void prepareOutputData(OutputParams& outputs, ForwardParams const& params,
runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth,
size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream);
runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize,
size_t maxBatchSize, size_t beamWidth, size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream);
private:
std::unique_ptr<OnlineBeamSearchLayer<T>> mOnlineBeamSearchDecode;

View File

@ -74,8 +74,8 @@ void TopKSamplingLayer<T>::allocateBuffer(size_t const batchSize)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
invokeTopKSampling<T>(nullptr, mSamplingWorkspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, TOP_K_MAX, 1.0f, mVocabSizePadded, nullptr, nullptr, mStream, batchSize, nullptr,
mNormalizeLogProbs, false);
nullptr, nullptr, TOP_K_MAX, 1.0f, mVocabSizePadded, nullptr, nullptr, mStream, batchSize, mMaxBatchSize,
nullptr, mNormalizeLogProbs, false);
std::array<size_t, 4> deviceBufferSizes;
deviceBufferSizes[0] = sizeof(uint32_t) * batchSize;
@ -213,7 +213,7 @@ void TopKSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams&
invokeBatchTopKSampling(samplingWorkspaceDevice, mSamplingWorkspaceSize, logits,
outputs.output_ids_ptr.template getPtr<int*>(), sequenceLength, finishedInput, finishedOutput, cumLogProbs,
outputLogProbs, curandStatesDevice, (int) mRuntimeMaxTopK, (int*) (mRuntimeTopKDevice), 1.0f,
mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, mStream, batchSize, mSkipDecodeDevice,
mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, mStream, batchSize, mMaxBatchSize, mSkipDecodeDevice,
mNormalizeLogProbs, probsComputed);
sync_check_cuda_error();
}

View File

@ -78,8 +78,8 @@ void TopPSamplingLayer<T>::allocateBuffer(size_t batchSize)
nullptr, // cum_log_probs
nullptr, // output_log_probs
nullptr, // log_probs
mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, nullptr, batchSize, mVocabSizePadded, nullptr,
0.f, mStream, nullptr, nullptr);
mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, nullptr, batchSize, mMaxBatchSize,
mVocabSizePadded, nullptr, 0.f, mStream, nullptr, nullptr);
}
else
{
@ -91,7 +91,8 @@ void TopPSamplingLayer<T>::allocateBuffer(size_t batchSize)
nullptr, // cum_log_probs
nullptr, // output_log_probs
nullptr, // log_probs)
nullptr, batchSize, mVocabSizePadded, nullptr, 0.f, mStream, mAirTopPBlockNum, nullptr, nullptr);
nullptr, batchSize, mMaxBatchSize, mVocabSizePadded, nullptr, 0.f, mStream, mAirTopPBlockNum, nullptr,
nullptr);
}
std::array<size_t, 11> deviceBufferSizes;
@ -315,8 +316,8 @@ void TopPSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams&
invokeBatchTopPSampling<T>(samplingWorkspaceDevice, mSamplingWorkspaceSize, mCubTempStorageSize,
outputs.output_ids_ptr.template getPtr<int*>(), sequenceLength, finishedInput, finishedOutput, cumLogProbs,
outputLogProbs, probs, mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, curandStatesDevice,
batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, mSkipDecodeDevice,
batchSlots);
batchSize, mMaxBatchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream,
mSkipDecodeDevice, batchSlots);
sync_check_cuda_error();
invokeComputeToppDecay(mRuntimeTopPDevice, mInitialTopPDevice,
outputs.output_ids_ptr.template getPtr<const int*>(), mTopPDecayDevice, mTopPMinDevice, mTopPResetIdsDevice,
@ -327,8 +328,8 @@ void TopPSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams&
{
invokeBatchAirTopPSampling<T>(samplingWorkspaceDevice, mSamplingWorkspaceSize,
outputs.output_ids_ptr.template getPtr<int*>(), sequenceLength, finishedInput, finishedOutput, cumLogProbs,
outputLogProbs, probs, curandStatesDevice, batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP,
mRuntimeTopPDevice, mStream, mAirTopPBlockNum, mSkipDecodeDevice, batchSlots);
outputLogProbs, probs, curandStatesDevice, batchSize, mMaxBatchSize, mVocabSizePadded, endIds,
mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, mAirTopPBlockNum, mSkipDecodeDevice, batchSlots);
sync_check_cuda_error();
}
}
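
Throughout these sampling-layer hunks the kernels gain a `maxBatchSize` argument alongside `batchSize`: persistent per-request buffers (runtime top-k/top-p values, skip flags, curand states) are sized by the maximum batch size and addressed through `batchSlots`, while a forward pass only iterates over the `batchSize` currently active requests. A small stand-alone sketch of that indexing pattern (illustrative only; variable names are mine, not TensorRT-LLM code):

```cpp
// Stand-alone illustration of the batch-slot indexing that motivates passing
// both batchSize and maxBatchSize to the sampling kernels above.
#include <cstddef>
#include <cstdint>
#include <vector>

void batchSlotIndexingSketch()
{
    std::size_t const maxBatchSize = 8; // capacity of the persistent, slot-indexed buffers
    std::size_t const batchSize = 3;    // requests active in this forward pass

    std::vector<float> runtimeTopP(maxBatchSize, 0.9f); // one entry per slot
    std::vector<bool> skipDecode(maxBatchSize, false);  // one entry per slot
    std::vector<int32_t> const batchSlots{1, 4, 6};     // slots of the active requests

    for (std::size_t bi = 0; bi < batchSize; ++bi)
    {
        auto const slot = batchSlots[bi];
        if (skipDecode[slot])
        {
            continue;
        }
        float const topP = runtimeTopP[slot]; // read is slot-indexed, never bi-indexed
        static_cast<void>(topP);
    }
}
```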

View File

@ -47,8 +47,27 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s
, mRemovePadding(remove_padding)
{
// pre-check whether FMHA is supported in order to save memory allocation
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF) && MHARunner::fmha_supported(mHeadSize, mSM)
&& !mRelativeAttention;
if (mEnableContextFMHA)
{
mEnableContextFMHA = false;
if (!(mType == DataType::kHALF || mType == DataType::kBF16))
{
TLLM_LOG_WARNING("Fall back to unfused MHA because of unsupported data type.");
}
else if (!MHARunner::fmha_supported(mHeadSize, mSM))
{
TLLM_LOG_WARNING(
"Fall back to unfused MHA because of unsupported head size %d in sm_{%d}.", mHeadSize, mSM);
}
else if (mRelativeAttention)
{
TLLM_LOG_WARNING("Fall back to unfused MHA because of relative position embedding.");
}
else
{
mEnableContextFMHA = true;
}
}
}
// Parameterized constructor
@ -450,9 +469,23 @@ int BertAttentionPlugin::initialize() noexcept
mCublasWrapper.reset(new tc::CublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, nullptr));
if (mEnableContextFMHA)
{
mFMHARunner.reset(new FusedMHARunnerV2(DATA_TYPE_FP16, mNumHeads, mHeadSize, mQScaling));
// Pre-checked during constructing.
Data_type data_type;
if (mType == DataType::kHALF)
{
data_type = DATA_TYPE_FP16;
}
else if (mType == DataType::kBF16)
{
data_type = DATA_TYPE_BF16;
}
else
{
TLLM_CHECK_WITH_INFO(false, "GPTAttentionPlugin received wrong data type.");
}
mFMHARunner.reset(new FusedMHARunnerV2(data_type, mNumHeads, mHeadSize, mQScaling));
// set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads = num_heads
mFMHARunner->setup_flags(mFMHAForceFP32Acc, true, false, mNumHeads);
mFMHARunner->setup_flags(mFMHAForceFP32Acc, !mRemovePadding, false, mNumHeads);
}
return 0;

View File

@ -80,12 +80,12 @@ BufferManager::ITensorPtr BufferManager::pinnedPool(nvinfer1::Dims dims, nvinfer
return std::make_unique<PinnedPoolTensor>(dims, type);
}
BufferManager::IBufferPtr BufferManager::managed(std::size_t size, nvinfer1::DataType type) const
BufferManager::IBufferPtr BufferManager::managed(std::size_t size, nvinfer1::DataType type)
{
return std::make_unique<UVMBuffer>(size, type);
}
BufferManager::ITensorPtr BufferManager::managed(nvinfer1::Dims dims, nvinfer1::DataType type) const
BufferManager::ITensorPtr BufferManager::managed(nvinfer1::Dims dims, nvinfer1::DataType type)
{
return std::make_unique<UVMTensor>(dims, type);
}
@ -149,8 +149,10 @@ BufferManager::IBufferPtr BufferManager::allocate(
case MemoryType::kCPU: return cpu(size, type);
case MemoryType::kGPU: return gpu(size, type);
case MemoryType::kPINNED: return pinned(size, type);
default: TLLM_THROW("Unknown memory type");
case MemoryType::kUVM: return managed(size, type);
}
TLLM_THROW("Unknown memory type");
}
BufferManager::ITensorPtr BufferManager::allocate(
@ -161,8 +163,10 @@ BufferManager::ITensorPtr BufferManager::allocate(
case MemoryType::kCPU: return cpu(dims, type);
case MemoryType::kGPU: return gpu(dims, type);
case MemoryType::kPINNED: return pinned(dims, type);
default: TLLM_THROW("Unknown memory type");
case MemoryType::kUVM: return managed(dims, type);
}
TLLM_THROW("Unknown memory type");
}
BufferManager::IBufferPtr BufferManager::copyFrom(IBuffer const& src, MemoryType memoryType) const

View File

@ -297,6 +297,12 @@ void GptDecoderBatch::newRequest(
= ITensor::slice(constPointerCast(dJointInput.embeddingBias), batchIdx, localBatchSize);
if (request.embeddingBias)
{
TLLM_CHECK(request.embeddingBias->getShape().nbDims == 2);
TLLM_CHECK(request.embeddingBias->getShape().d[0] == 1);
TLLM_CHECK_WITH_INFO(request.embeddingBias->getShape().d[1] == static_cast<SizeType>(mVocabSize),
"The embedding bias shape is not as expected. Expected last dimension to be same as vocab size: %lu.",
mVocabSize);
manager.copy(*request.embeddingBias, *embeddingBiasSlice);
dInput->embeddingBias = embeddingBiasSlice;
}

View File

@ -35,11 +35,12 @@ MemoryType IBuffer::memoryType(void const* data)
switch (attributes.type)
{
case cudaMemoryTypeHost: return MemoryType::kPINNED;
case cudaMemoryTypeDevice:
case cudaMemoryTypeManaged: return MemoryType::kGPU;
case cudaMemoryTypeDevice: return MemoryType::kGPU;
case cudaMemoryTypeManaged: return MemoryType::kUVM;
case cudaMemoryTypeUnregistered: return MemoryType::kCPU;
default: TLLM_THROW("Unsupported memory type");
}
TLLM_THROW("Unsupported memory type");
}
IBuffer::UniquePtr IBuffer::slice(IBuffer::SharedPtr buffer, std::size_t offset, std::size_t size)

View File

@ -89,7 +89,12 @@ ITensor::UniquePtr ITensor::wrap(void* data, nvinfer1::DataType type, nvinfer1::
new GenericTensor<GpuBorrowingAllocator>(
shape, capacity, type, GpuBorrowingAllocator(data, capacityInBytes)));
break;
default: TLLM_THROW("Unknown memory type");
case MemoryType::kUVM:
result.reset( // NOLINT(modernize-make-unique)
new GenericTensor<ManagedBorrowingAllocator>(
shape, capacity, type, ManagedBorrowingAllocator(data, capacityInBytes)));
break;
default: TLLM_THROW("Invalid memory type."); break;
}
return result;
}

View File

@ -240,6 +240,7 @@ private:
using CpuBorrowingAllocator = BorrowingAllocator<MemoryType::kCPU>;
using GpuBorrowingAllocator = BorrowingAllocator<MemoryType::kGPU>;
using PinnedBorrowingAllocator = BorrowingAllocator<MemoryType::kPINNED>;
using ManagedBorrowingAllocator = BorrowingAllocator<MemoryType::kUVM>;
// using UVMBorrowingAllocator = BorrowingAllocator<MemoryType::kUVM>;
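
The BufferManager/IBuffer/ITensor hunks above add first-class CUDA unified (managed) memory: `managed(...)` allocations, a distinct `MemoryType::kUVM` classification, and a `ManagedBorrowingAllocator` for wrapping. A hedged usage sketch, assuming the signatures shown in these hunks (this is not code from the commit):

```cpp
// Hedged sketch: a managed allocation is now classified as kUVM and can be
// wrapped into an ITensor via the new ManagedBorrowingAllocator path.
#include <cstddef>
#include <cuda_runtime_api.h>
#include "tensorrt_llm/runtime/iTensor.h"

using namespace tensorrt_llm::runtime;

void uvmSketch()
{
    void* data{nullptr};
    std::size_t const numElems = 1024;
    cudaMallocManaged(&data, numElems * sizeof(float)); // unified (managed) memory

    // IBuffer::memoryType now maps cudaMemoryTypeManaged to MemoryType::kUVM
    // instead of kGPU.
    auto const memType = IBuffer::memoryType(data);
    static_cast<void>(memType);

    // ITensor::wrap handles MemoryType::kUVM through ManagedBorrowingAllocator
    // (signature assumed to match the hunk above).
    auto tensor = ITensor::wrap(data, nvinfer1::DataType::kFLOAT, ITensor::makeShape({1024}), numElems);

    tensor.reset();
    cudaFree(data);
}
```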

View File

@ -144,12 +144,13 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> parent_ids_opt,
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt, th::optional<th::Tensor> beam_hyps_cum_log_probs_opt,
th::optional<th::Tensor> beam_hyps_normed_scores_opt, th::optional<th::Tensor> beam_hyps_log_probs_opt,
th::optional<th::Tensor> beam_hyps_min_normed_scores_opt, th::optional<th::Tensor> beam_hyps_num_beams_opt,
th::optional<th::Tensor> beam_hyps_is_done_opt, bool use_beam_hyps)
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> output_log_probs_tiled_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt, th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_num_beams_opt, th::optional<th::Tensor> beam_hyps_is_done_opt,
bool use_beam_hyps)
{
auto const& logits_converted = convert_tensor<float>(logits);
@ -190,6 +191,7 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
safeUpdate<int>(parent_ids_opt, outputParams.parent_ids);
safeUpdate<float>(cum_log_probs_opt, outputParams.cum_log_probs);
safeUpdate<float>(output_log_probs_opt, outputParams.output_log_probs);
safeUpdate<float>(output_log_probs_tiled_opt, outputParams.output_log_probs_tiled);
safeUpdate<int>(tgt_cache_indirection_opt, outputParams.tgt_cache_indirection);
if (use_beam_hyps)
@ -297,12 +299,12 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> seuqence_lengths_opt, // length of the current sequences.
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt, th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_num_beams_opt, th::optional<th::Tensor> beam_hyps_is_done_opt,
bool use_beam_hyps)
th::optional<th::Tensor> output_log_probs_tiled_opt, th::optional<th::Tensor> parent_ids_opt,
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt, th::optional<th::Tensor> beam_hyps_cum_log_probs_opt,
th::optional<th::Tensor> beam_hyps_normed_scores_opt, th::optional<th::Tensor> beam_hyps_log_probs_opt,
th::optional<th::Tensor> beam_hyps_min_normed_scores_opt, th::optional<th::Tensor> beam_hyps_num_beams_opt,
th::optional<th::Tensor> beam_hyps_is_done_opt, bool use_beam_hyps)
{
// Input Arguments:
// logits: [batch_size, beam_width, vocab_size_padded], T
@ -349,6 +351,7 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
CHECK_OPTIONAL_INPUT(seuqence_lengths_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(cum_log_probs_opt, torch::kFloat32);
CHECK_OPTIONAL_INPUT(output_log_probs_opt, torch::kFloat32);
CHECK_OPTIONAL_INPUT(output_log_probs_tiled_opt, torch::kFloat32);
CHECK_OPTIONAL_INPUT(parent_ids_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(tgt_cache_indirection_opt, torch::kInt32);
@ -363,7 +366,7 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
static_cast<int32_t>(max_bad_words_len), no_repeat_ngram_size_opt, src_cache_indirection_opt,
// Outputs
output_token_ids, newTokens, should_stop, finished_input, finished_output, seuqence_lengths_opt,
cum_log_probs_opt, output_log_probs_opt, parent_ids_opt, tgt_cache_indirection_opt,
cum_log_probs_opt, output_log_probs_opt, output_log_probs_tiled_opt, parent_ids_opt, tgt_cache_indirection_opt,
beam_hyps_output_ids_tgt_opt, beam_hyps_sequence_lengths_tgt_opt, beam_hyps_cum_log_probs_opt,
beam_hyps_normed_scores_opt, beam_hyps_log_probs_opt, beam_hyps_min_normed_scores_opt, beam_hyps_num_beams_opt,
beam_hyps_is_done_opt, use_beam_hyps);

View File

@ -48,8 +48,9 @@ public:
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> parent_ids_opt,
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> output_log_probs_tiled_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
@ -87,8 +88,9 @@ public:
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> parent_ids_opt,
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> output_log_probs_tiled_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
@ -134,8 +136,8 @@ public:
th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> sequence_lengths_opt, // length of the current sequences.
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> output_log_probs_tiled_opt, th::optional<th::Tensor> parent_ids_opt,
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,

View File

@ -69,6 +69,7 @@ add_gtest(tllmBuffersTest runtime/tllmBuffersTest.cpp)
add_gtest(bufferManagerTest runtime/bufferManagerTest.cpp)
add_gtest(runtimeKernelTest runtime/runtimeKernelTest.cpp)
add_gtest(samplingTest runtime/samplingTest.cpp)
add_gtest(samplingConfigTest runtime/samplingConfigTest.cpp)
add_gtest(iTensorTest runtime/iTensorTest.cpp)
add_gtest(worldConfigTest runtime/worldConfigTest.cpp)
add_gtest(medusaModuleTest runtime/medusaModuleTest.cpp)
@ -101,3 +102,7 @@ if(BUILD_BATCH_MANAGER)
add_subdirectory(batch_manager)
endif()
endif()
if(BUILD_EXECUTOR)
add_subdirectory(executor)
endif()

View File

@ -43,6 +43,7 @@ protected:
private:
size_t getWorkspaceSize(const SamplingKernelTestParam& params) override
{
auto const maxBatchSize = 2 * params.batchSize;
size_t sampling_workspace_size_;
tk::invokeAirTopPSampling<T>(nullptr, sampling_workspace_size_,
nullptr, // output_ids
@ -52,7 +53,7 @@ private:
nullptr, // cum_log_probs
nullptr, // output_log_probs
nullptr, // log_probs)
this->mCurandStatesDevice, params.batchSize, params.vocabSize, nullptr, this->mMaxTopP,
this->mCurandStatesDevice, params.batchSize, maxBatchSize, params.vocabSize, nullptr, this->mMaxTopP,
this->mStream->get(), 0, nullptr, nullptr);
return sampling_workspace_size_;
}
@ -65,6 +66,7 @@ private:
int smCnt;
TLLM_CUDA_CHECK(cudaGetDevice(&dev));
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCnt, cudaDevAttrMultiProcessorCount, dev));
auto const maxBatchSize = 2 * params.batchSize;
int blockNum = tk::calcAirTopPBlockNum<T, int, float>(params.batchSize, params.vocabSize, smCnt);
// Perform batched TopP sampling
@ -79,8 +81,8 @@ private:
// log-prob if cum_log_probs or output_log_probs are
// provided. It's because the sampling layer already
// preprocesses log_prob_buf when those are provided.
bufferCast<T>(*this->mProbsDevice), this->mCurandStatesDevice, params.batchSize, params.vocabSize,
bufferCast<int32_t>(*this->mEndIdsDevice), this->mMaxTopP,
bufferCast<T>(*this->mProbsDevice), this->mCurandStatesDevice, params.batchSize, maxBatchSize,
params.vocabSize, bufferCast<int32_t>(*this->mEndIdsDevice), this->mMaxTopP,
hasDiffRuntimeArgs ? bufferCast<float>(*this->mTopPsDevice) : nullptr, this->mStream->get(), blockNum,
bufferCast<bool>(*this->mSkipDecodeDevice), bufferCast<int32_t>(*this->mBatchSlots));
}

View File

@ -64,7 +64,7 @@ void SamplingKernelTest<T>::allocateBuffers(
mCumLogProbsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT);
mOutputLogProbsDevice
= mBufferManager->gpu(ITensor::makeShape({maxBatchSize, outputLen}), nvinfer1::DataType::kFLOAT);
= mBufferManager->gpu(ITensor::makeShape({maxSeqLen, maxBatchSize}), nvinfer1::DataType::kFLOAT);
mZeroParentIdsDevice
= mBufferManager->gpu(ITensor::makeShape({maxBatchSize, maxSeqLen}), nvinfer1::DataType::kINT32);

View File

@ -42,16 +42,18 @@ protected:
size_t getWorkspaceSize(const SamplingKernelTestParam& params) override
{
auto const maxBatchSize = 2 * params.batchSize;
size_t workspaceSize;
tk::invokeTopKSampling<T>(nullptr, workspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, this->mMaxTopK, 1.0f, params.vocabSize, nullptr, nullptr, this->mStream->get(), params.batchSize,
nullptr, true, false);
maxBatchSize, nullptr, true, false);
return workspaceSize;
}
void callTestedFunction(const SamplingKernelTestParam& params, bool hasDiffRuntimeArgs, size_t workspaceSize,
tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override
{
auto const maxBatchSize = 2 * params.batchSize;
// Perform batched TopK sampling
tk::invokeBatchTopKSampling(workspaceDevice->data(), workspaceSize,
// Note that the kernel needs vocab probs instead of
@ -69,7 +71,7 @@ protected:
hasDiffRuntimeArgs ? bufferCast<int32_t>(*this->mTopKsDevice) : nullptr, params.topP,
hasDiffRuntimeArgs ? bufferCast<float>(*this->mTopPsDevice) : nullptr, params.vocabSize,
bufferCast<int32_t>(*this->mEndIdsDevice), bufferCast<int32_t>(*this->mBatchSlots), this->mStream->get(),
params.batchSize, bufferCast<bool>(*this->mSkipDecodeDevice), params.normalizeLogProbs,
params.batchSize, maxBatchSize, bufferCast<bool>(*this->mSkipDecodeDevice), params.normalizeLogProbs,
params.logitsHasProbs);
}
};

View File

@ -43,6 +43,7 @@ protected:
private:
size_t getWorkspaceSize(const SamplingKernelTestParam& params) override
{
auto const maxBatchSize = 2 * params.batchSize;
size_t workspaceSize;
size_t cubTempStorageSize;
tk::invokeBatchTopPSampling<T>(nullptr, // workspace
@ -55,7 +56,7 @@ private:
nullptr, // output_log_probs
nullptr, // log_probs
bufferCast<int32_t>(*this->mTopPIdValsDevice), bufferCast<int32_t>(*this->mEndOffsetsDevice),
bufferCast<int32_t>(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize,
bufferCast<int32_t>(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize, maxBatchSize,
params.vocabSize, nullptr, this->mMaxTopP, bufferCast<float>(*this->mTopPsDevice), this->mStream->get(),
nullptr, nullptr);
return workspaceSize;
@ -64,6 +65,7 @@ private:
void callTestedFunction(const SamplingKernelTestParam& params, bool hasDiffRuntimeArgs, size_t workspaceSize,
tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override
{
auto const maxBatchSize = 2 * params.batchSize;
size_t cubTempStorageSize;
tk::invokeBatchTopPSampling<T>(nullptr, // workspace
workspaceSize, cubTempStorageSize,
@ -75,7 +77,7 @@ private:
nullptr, // output_log_probs
nullptr, // log_probs
bufferCast<int32_t>(*this->mTopPIdValsDevice), bufferCast<int32_t>(*this->mEndOffsetsDevice),
bufferCast<int32_t>(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize,
bufferCast<int32_t>(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize, maxBatchSize,
params.vocabSize, nullptr, this->mMaxTopP, bufferCast<float>(*this->mTopPsDevice), this->mStream->get(),
nullptr, nullptr);
@ -98,8 +100,9 @@ private:
// preprocesses log_prob_buf when those are provided.
bufferCast<T>(*this->mProbsDevice), bufferCast<int32_t>(*this->mTopPIdValsDevice),
bufferCast<int32_t>(*this->mEndOffsetsDevice), bufferCast<int32_t>(*this->mBeginOffsetsDevice),
this->mCurandStatesDevice, params.batchSize, params.vocabSize, bufferCast<int32_t>(*this->mEndIdsDevice),
this->mMaxTopP, hasDiffRuntimeArgs ? bufferCast<float>(*this->mTopPsDevice) : nullptr, this->mStream->get(),
this->mCurandStatesDevice, params.batchSize, maxBatchSize, params.vocabSize,
bufferCast<int32_t>(*this->mEndIdsDevice), this->mMaxTopP,
hasDiffRuntimeArgs ? bufferCast<float>(*this->mTopPsDevice) : nullptr, this->mStream->get(),
bufferCast<bool>(*this->mSkipDecodeDevice), bufferCast<int32_t>(*this->mBatchSlots));
}
};

View File

@ -15,6 +15,7 @@
*/
#include "tests/layers/dynamicDecodeLayerTest.h"
#include <algorithm>
namespace tensorrt_llm::tests::layers::sampling
{
@ -25,7 +26,7 @@ namespace tensorrt_llm::tests::layers::sampling
// - finished sum
// - max length
// - repeat n grams
// - output logits
// - padded vocab
// - beam search
using namespace tensorrt_llm::runtime;
@ -129,17 +130,19 @@ void DynamicDecodeLayerTest<T>::setup(uint64_t seed, SamplingParams const& param
// clang-format off
// prob = (0.0, 0.0, 0.0, 0.0, 0.4, 0.3, 0.2, 0.1)
// prob = (0.0, 0.0, 0.0, 0.0, 0.4, 0.3, 0.2, 0.1, 0.0)
mTestLogitsInit = {
-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, // step 0
-0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 1
-FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // step 2
-0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX // step 3
-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // step 0
-0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 1
-FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 2
-0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX // step 3
};
// clang-format on
mLogitsDevice = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mBeamWidth, mVocabSizePadded}), dataType);
mRuntimeLogitsHost
= mBufferManager->pinned(ITensor::makeShape({mBatchSize, mBeamWidth, mVocabSizePadded}), dataType);
mSeqLengthsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
mContextLengthDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32);
@ -154,6 +157,13 @@ void DynamicDecodeLayerTest<T>::setup(uint64_t seed, SamplingParams const& param
mEmbeddingBiasHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType);
mEmbeddingBiasDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType);
mRefLogProbsHost
= mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kFLOAT);
mOutputLogProbsDevice
= mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kFLOAT);
mOutputLogProbsTiledDevice
= mBufferManager->gpu(ITensor::makeShape({mMaxSeqLen, mMaxBatchSize}), nvinfer1::DataType::kFLOAT);
mCumLogProbsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kFLOAT);
mMaxBadWordsLen = getMaxWordsLen(params.badWords);
@ -177,6 +187,9 @@ void DynamicDecodeLayerTest<T>::setup(uint64_t seed, SamplingParams const& param
trk::invokeFill(*mOutputIdsDevice, int32_t{0}, *mStream);
trk::invokeFill(*mEmbeddingBiasDevice, T{0.0f}, *mStream);
trk::invokeFill(*mCumLogProbsDevice, float{0.0f}, *mStream);
trk::invokeFill(*mOutputLogProbsDevice, float{0.0f}, *mStream);
trk::invokeFill(*mOutputLogProbsTiledDevice, float{0.0f}, *mStream);
trk::invokeFill(*mRefLogProbsHost, float{0.0f}, *mStream);
trk::invokeFill(*mEndIdsDevice, int32_t{mEndId}, *mStream);
auto batchSlotsPtr = bufferCast<int32_t>(*mBatchSlots);
@ -229,6 +242,7 @@ void DynamicDecodeLayerTest<T>::setup(uint64_t seed, SamplingParams const& param
= params.minTopP.size() ? std::make_optional<std::vector<float>>(params.minTopP) : std::nullopt;
setupParams.top_p_reset_ids
= params.topPResetIds.size() ? std::make_optional<std::vector<int32_t>>(params.topPResetIds) : std::nullopt;
setupParams.normalize_log_probs = {false};
initXWordsTensors(batchSlotsPtr, bufferCast<SizeType>(*mBadWords),
reinterpret_cast<SizeType**>(bufferCast<int64_t>(*mBadWordsPtrs)), bufferCast<SizeType>(*mBadWordsLens),
@ -350,10 +364,12 @@ typename DynamicDecodeLayer<T>::OutputParams DynamicDecodeLayerTest<T>::createOu
outputParams.newTokens = tcc::toTllmTensor(*mNewTokens);
outputParams.output_log_probs = tcc::toTllmTensor(*mOutputLogProbsDevice);
outputParams.output_log_probs_tiled = tcc::toTllmTensor(*mOutputLogProbsTiledDevice);
// TODO(nkorobov): extend to
// std::optional<tc::Tensor> parent_ids;
// std::optional<tc::Tensor> output_log_probs_tiled;
// std::optional<tc::Tensor> output_log_probs;
// std::optional<tc::Tensor> tgt_cache_indirection;
// std::shared_ptr<kernels::BeamHypotheses> beamHypotheses;
@ -375,7 +391,7 @@ void DynamicDecodeLayerTest<T>::batchCopy(int32_t step)
}
template <typename T>
bool DynamicDecodeLayerTest<T>::checkResult(int32_t* outputIds, std::vector<std::set<int32_t>>& expectedIds,
bool DynamicDecodeLayerTest<T>::checkResult(int32_t* outputIds, std::vector<std::set<int32_t>> const& expectedIds,
int32_t* seqLens, int32_t leadingDim, int32_t stride, int32_t step)
{
assert(expectedIds.size() == leadingDim * stride);
@ -416,10 +432,34 @@ bool DynamicDecodeLayerTest<T>::checkResult(int32_t* outputIds, std::vector<std:
}
template <typename T>
void DynamicDecodeLayerTest<T>::runTestImpl(
std::vector<std::set<int32_t>> expectedOutputIds, SamplingParams const& params, int32_t endId)
void DynamicDecodeLayerTest<T>::fillRefLogits(
int32_t const* seqLenHost, std::vector<std::set<int32_t>> const& expectedOutputIds, SizeType step)
{
mEndId = endId;
auto const batchSlotsPtr = bufferCast<int32_t>(*mBatchSlots);
auto const runtimeLogitsHost = bufferCast<T>(*mRuntimeLogitsHost);
for (SizeType bi = 0; bi < mBatchBeam; ++bi)
{
auto const batchSlot = batchSlotsPtr[bi];
if (seqLenHost[batchSlot] <= step)
{
continue;
}
auto& expectedSet = expectedOutputIds[step * mBatchBeam + bi];
TLLM_CHECK(expectedSet.size() == 1);
auto expectedToken = *expectedSet.begin();
bufferCast<float>(*mRefLogProbsHost)[batchSlot * mMaxSeqLen + step]
= logf(runtimeLogitsHost[bi * mVocabSizePadded + expectedToken]);
}
}
template <typename T>
void DynamicDecodeLayerTest<T>::runTestImpl(
std::vector<std::set<int32_t>> const& expectedOutputIds, SamplingParams const& params, int32_t endId)
{
mEndId = endId == -1 ? mVocabSize - 1 : endId;
bool greedySearch
= std::all_of(expectedOutputIds.begin(), expectedOutputIds.end(), [](auto v) { return v.size() == 1; });
for (uint64_t seed = 0; seed < mMaxSeed; ++seed)
{
setup(seed, params);
@ -439,6 +479,14 @@ void DynamicDecodeLayerTest<T>::runTestImpl(
auto const seqLenHost
= mBufferManager->copyFrom(*mSeqLengthsDevice, tensorrt_llm::runtime::MemoryType::kCPU);
auto const logitsHost = mBufferManager->copyFrom(*mLogitsDevice, tensorrt_llm::runtime::MemoryType::kCPU);
mBufferManager->copy(
mDecodeLayer->getRuntimeLogitsDevice(), *mRuntimeLogitsHost, tensorrt_llm::runtime::MemoryType::kGPU);
mStream->synchronize();
if (greedySearch)
{
fillRefLogits(bufferCast<int32_t>(*seqLenHost), expectedOutputIds, step);
}
{
bool passed = checkResult(bufferCast<int32_t>(*newTokensHost), expectedOutputIds,
@ -462,24 +510,35 @@ void DynamicDecodeLayerTest<T>::runTestImpl(
mStream->synchronize();
const auto outputIdsHost = mBufferManager->copyFrom(*mOutputIdsDevice, tensorrt_llm::runtime::MemoryType::kCPU);
const auto seqLenHost = mBufferManager->copyFrom(*mSeqLengthsDevice, tensorrt_llm::runtime::MemoryType::kCPU);
auto const outputIdsHost = mBufferManager->copyFrom(*mOutputIdsDevice, tensorrt_llm::runtime::MemoryType::kCPU);
auto const seqLenHost = mBufferManager->copyFrom(*mSeqLengthsDevice, tensorrt_llm::runtime::MemoryType::kCPU);
auto const logProbsHost
= mBufferManager->copyFrom(*mOutputLogProbsDevice, tensorrt_llm::runtime::MemoryType::kCPU);
bool passed = checkResult(bufferCast<int32_t>(*outputIdsHost), expectedOutputIds,
bufferCast<int32_t>(*seqLenHost), mMaxSeqLen, mBatchBeam, /* step */ 0);
EXPECT_TRUE(passed) << "Output Ids check failed at seed " << seed;
if (!passed)
{
std::stringstream ss;
ss << "Actual output ids:" << std::endl << *outputIdsHost;
TLLM_LOG_DEBUG(ss.str());
bool passed = checkResult(bufferCast<int32_t>(*outputIdsHost), expectedOutputIds,
bufferCast<int32_t>(*seqLenHost), mMaxSeqLen, mBatchBeam, /* step */ 0);
EXPECT_TRUE(passed) << "Output Ids check failed at seed " << seed;
if (!passed)
{
std::stringstream ss;
ss << "Actual output ids:" << std::endl << *outputIdsHost;
TLLM_LOG_DEBUG(ss.str());
}
}
if (greedySearch)
{
bool passed = compareValues(
bufferCast<float>(*logProbsHost), bufferCast<float>(*mRefLogProbsHost), mMaxSeqLen * mMaxBatchSize);
EXPECT_TRUE(passed) << "Log probs check failed at seed " << seed;
}
}
}
template <typename T>
void DynamicDecodeLayerTest<T>::runTest(
std::vector<std::set<int32_t>> expectedOutputIds, SamplingParams const& params, int32_t endId)
std::vector<std::set<int32_t>> const& expectedOutputIds, SamplingParams const& params, int32_t endId)
{
TLLM_LOG_DEBUG("Run test with linear logits");
mUseLogitsVec = false;

View File

@ -70,7 +70,7 @@ private:
int32_t const mMaxBatchSize = 2 * mBatchSize;
int32_t const mBeamWidth = 1;
int32_t const mBatchBeam = mBatchSize * mBeamWidth;
int32_t const mVocabSize = 8;
int32_t const mVocabSize = 9;
int32_t const mVocabSizePadded = mVocabSize;
int32_t const mMaxInputLen = 0; // has no effect.
@ -82,6 +82,7 @@ private:
bool mUseLogitsVec = false;
TensorPtr mLogitsDevice;
TensorPtr mRuntimeLogitsHost;
TensorPtr mLogitsRefHost;
TensorPtr mContextLengthDevice;
TensorPtr mSeqLengthsDevice;
@ -103,6 +104,10 @@ private:
TensorPtr mEmbeddingBiasHost;
TensorPtr mEmbeddingBiasDevice;
TensorPtr mRefLogProbsHost;
TensorPtr mOutputLogProbsDevice;
TensorPtr mOutputLogProbsTiledDevice;
TensorPtr mCumLogProbsDevice;
std::vector<tensorrt_llm::common::Tensor> mLogitsVec;
@ -134,14 +139,18 @@ private:
typename tensorrt_llm::layers::DynamicDecodeLayer<T>::OutputParams createOutputTensors();
void batchCopy(int32_t step);
bool checkResult(int32_t* outputIds, std::vector<std::set<int32_t>>& expectedIds, int32_t* seqLens,
bool checkResult(int32_t* outputIds, std::vector<std::set<int32_t>> const& expectedIds, int32_t* seqLens,
int32_t leadingDim, int32_t stride, int32_t step);
void runTestImpl(
std::vector<std::set<int32_t>> expectedOutputIds, SamplingParams const& params, int32_t endId = -1);
std::vector<std::set<int32_t>> const& expectedOutputIds, SamplingParams const& params, int32_t endId = -1);
void fillRefLogits(
int32_t const* seqLenHost, std::vector<std::set<int32_t>> const& expectedOutputIds, int32_t step);
public:
void runTest(std::vector<std::set<int32_t>> expectedOutputIds, SamplingParams const& params, int32_t endId = -1);
void runTest(
std::vector<std::set<int32_t>> const& expectedOutputIds, SamplingParams const& params, int32_t endId = -1);
};
typedef testing::Types<float, half> FloatAndHalfTypes;

View File

@ -0,0 +1,77 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
using ::testing::_;
using ::testing::Invoke;
namespace tr = tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace texec = tensorrt_llm::executor;
TEST(samplingConfigTest, validInputs)
{
{
texec::SamplingConfig execSamplingCfg(1);
tr::SamplingConfig samplingCfg(execSamplingCfg, std::nullopt);
EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth());
EXPECT_EQ(samplingCfg.draftAcceptanceThreshold, std::nullopt);
}
{
texec::SamplingConfig execSamplingCfg(1);
texec::SpeculativeDecodingConfig specCfg({1}, std::nullopt, 0.5);
tr::SamplingConfig samplingCfg(execSamplingCfg, specCfg);
EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth());
EXPECT_TRUE(samplingCfg.draftAcceptanceThreshold.has_value());
EXPECT_THAT(samplingCfg.draftAcceptanceThreshold.value(), testing::ElementsAre(0.5f));
}
{
texec::SizeType topK = 1;
texec::FloatType topP = 0.5;
texec::FloatType topPMin = 0.1;
texec::SizeType topPResetIds = 1;
texec::FloatType topPDecay = 0.6;
uint64_t randomSeed = 7777;
texec::FloatType temperature = 0.245;
texec::SizeType minLength = 1234;
texec::FloatType beamSearchDiversityRate = 0.9999;
texec::FloatType repetitionPenalty = 0.11;
texec::FloatType presencePenalty = 0.22;
texec::FloatType frequencyPenalty = 0.33;
texec::FloatType lengthPenalty = 0.44;
texec::SamplingConfig execSamplingCfg(1, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature,
minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty);
texec::SpeculativeDecodingConfig specCfg({1}, std::nullopt, 0.5);
tr::SamplingConfig samplingCfg(execSamplingCfg, specCfg);
EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth());
EXPECT_THAT(samplingCfg.draftAcceptanceThreshold.value(), testing::ElementsAre(0.5f));
EXPECT_THAT(samplingCfg.temperature.value(), testing::ElementsAre(temperature));
EXPECT_THAT(samplingCfg.minLength.value(), testing::ElementsAre(minLength));
EXPECT_THAT(samplingCfg.repetitionPenalty.value(), testing::ElementsAre(repetitionPenalty));
EXPECT_THAT(samplingCfg.presencePenalty.value(), testing::ElementsAre(presencePenalty));
EXPECT_THAT(samplingCfg.frequencyPenalty.value(), testing::ElementsAre(frequencyPenalty));
EXPECT_THAT(samplingCfg.topK.value(), testing::ElementsAre(topK));
EXPECT_THAT(samplingCfg.topP.value(), testing::ElementsAre(topP));
EXPECT_THAT(samplingCfg.randomSeed.value(), testing::ElementsAre(randomSeed));
EXPECT_THAT(samplingCfg.topPMin.value(), testing::ElementsAre(topPMin));
EXPECT_THAT(samplingCfg.topPResetIds.value(), testing::ElementsAre(topPResetIds));
EXPECT_THAT(samplingCfg.beamSearchDiversityRate.value(), testing::ElementsAre(beamSearchDiversityRate));
EXPECT_THAT(samplingCfg.lengthPenalty.value(), testing::ElementsAre(lengthPenalty));
}
}

View File

@ -113,6 +113,7 @@ python build.py --model_type t5 \
--max_beam_width 3
# Example 4: build bart-large-cnn using a single GPU, FP32, running greedy search
# Note: non-T5 models can enable FMHA for the encoder part, for FP16/BF16
python build.py --model_type bart \
--weight_dir tmp/trt_models/bart-large-cnn/tp1 \
-o tmp/trt_engines/bart-large-cnn/1-gpu \
@ -120,6 +121,7 @@ python build.py --model_type bart \
--remove_input_padding \
--use_bert_attention_plugin \
--use_gpt_attention_plugin \
--enable_context_fmha \
--use_gemm_plugin \
--dtype float32 \
--max_beam_width 1
@ -237,12 +239,14 @@ pushd tmp && (git clone https://github.com/facebookresearch/fairseq.git || true)
python nmt/convert.py -i tmp/fairseq_models/wmt14 -o tmp/trt_models/wmt14 --weight_data_type float32 --inference_tensor_para_size 1
# Build TensorRT engine(s)
# Note: non-T5 models can enable FMHA for the encoder part, although only FP16/BF16 precisions are valid
python build.py --model_type nmt \
--weight_dir tmp/trt_models/wmt14/tp1/ \
-o tmp/trt_engines/wmt14/1-gpu \
--engine_name wmt14 \
--use_bert_attention_plugin \
--use_gpt_attention_plugin \
--enable_context_fmha \
--dtype float32 \
--max_beam_width 1

View File

@ -27,6 +27,7 @@ from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from t5.weight import parse_t5_config, load_from_hf_t5, load_from_binary_t5 # isort:skip
from bart.weight import parse_bart_config, load_from_binary_bart # isort:skip
@ -185,6 +186,12 @@ def parse_arguments(component):
parser.add_argument('--enable_qk_half_accum',
default=False,
action='store_true')
parser.add_argument('--enable_context_fmha',
default=False,
action='store_true')
parser.add_argument('--enable_context_fmha_fp32_acc',
default=False,
action='store_true')
parser.add_argument('--builder_opt', type=int, default=None)
parser.add_argument('--remove_input_padding',
default=False,
@ -404,6 +411,14 @@ def build_rank_engine(builder: Builder,
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if args.enable_qk_half_accum:
network.plugin_config.enable_qk_half_accum()
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
if args.enable_context_fmha and not args.relative_attention:
logger.warning("Only non-T5 enc-dec models support FMHA")
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if args.enable_context_fmha_fp32_acc and not args.relative_attention:
logger.warning("Only non-T5 enc-dec models support FMHA")
network.plugin_config.set_context_fmha(
ContextFMHAType.enabled_with_fp32_acc)
if args.remove_input_padding:
network.plugin_config.enable_remove_input_padding()
if args.use_lookup_plugin:

View File

@ -417,6 +417,7 @@ class TRTLLMEncDecModel:
prompt_tasks=None,
prompt_vocab_size=None,
attention_mask=None,
time_encoder=False,
):
## ensure all externally provided tensors are on the correct device.
encoder_input_ids = encoder_input_ids.to(self.device)
@ -436,6 +437,8 @@ class TRTLLMEncDecModel:
if not self.skip_encoder:
logger.info(f"Rank {self.runtime_rank} Running encoder engine ...")
if time_encoder:
tik = time.time()
encoder_output = self.encoder_run(
encoder_input_ids,
encoder_input_lengths,
@ -445,6 +448,9 @@ class TRTLLMEncDecModel:
prompt_tasks=prompt_tasks,
prompt_vocab_size=prompt_vocab_size,
attention_mask=attention_mask)
if time_encoder:
tok = time.time()
print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms")
else:
encoder_output = prompt_embedding_table
if encoder_input_ids.dim() > 1:
@ -472,7 +478,10 @@ class TRTLLMEncDecModel:
sampling_config = SamplingConfig(end_id=eos_token_id,
pad_id=pad_token_id,
num_beams=num_beams,
min_length=1)
min_length=1,
return_dict=return_dict)
sampling_config.update(output_cum_log_probs=False,
output_log_probs=False)
# decoder autoregressive generation
self.decoder_session.setup(
@ -485,7 +494,7 @@ class TRTLLMEncDecModel:
)
torch.cuda.synchronize()
output_ids = self.decoder_session.decode(
output = self.decoder_session.decode(
decoder_input_ids,
decoder_input_lengths,
sampling_config,
@ -495,7 +504,7 @@ class TRTLLMEncDecModel:
cross_attention_mask=cross_attention_mask)
torch.cuda.synchronize()
return output_ids
return output
def test_fairseq_models(args):
@ -545,8 +554,9 @@ def test_fairseq_models(args):
inference_dtype = tllm_model.encoder_model_config.dtype
return_dict = False # when set return_dict=True, get outputs by key
tik = time.time()
tllm_output_ids = tllm_model.generate(
tllm_output = tllm_model.generate(
encoder_input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
max_new_tokens=max_new_tokens,
@ -557,6 +567,12 @@ def test_fairseq_models(args):
debug_mode=args.debug_mode,
)
tok = time.time()
if return_dict:
tllm_output_ids = tllm_output['output_ids']
else:
tllm_output_ids = tllm_output
output_ids = tllm_output_ids[:, 0, :]
output_ids = output_ids[output_ids != eos_token_id]
fairseq_output_ids = fairseq_output_ids[fairseq_output_ids != eos_token_id]
@ -680,8 +696,10 @@ if __name__ == "__main__":
tllm_model = TRTLLMEncDecModel.from_engine(args.engine_name,
args.engine_dir,
debug_mode=args.debug_mode)
return_dict = False # when set return_dict=True, get outputs by key
tik = time.time()
tllm_output_ids = tllm_model.generate(
tllm_output = tllm_model.generate(
encoder_input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
max_new_tokens=max_new_tokens,
@ -690,10 +708,16 @@ if __name__ == "__main__":
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug_mode=args.debug_mode,
return_dict=False, # when set return_dict=True, get outputs by key
attention_mask=tokenized_inputs.attention_mask)
return_dict=return_dict,
attention_mask=tokenized_inputs.attention_mask,
time_encoder=True)
tok = time.time()
if return_dict:
tllm_output_ids = tllm_output['output_ids']
else:
tllm_output_ids = tllm_output
inference_dtype = tllm_model.encoder_model_config.dtype
if tensorrt_llm.mpi_rank() == 0:

View File

@ -72,9 +72,8 @@ def parse_t5_config(config, component, args):
args.hidden_act = config.get(component, 'dense_act_fn')
args.gated_act = config.getboolean(component, 'is_gated_act')
args.mlp_type = mlp_type_map['GatedMLP' if args.gated_act else 'MLP']
args.relative_attention = config.getboolean(component,
'relative_attention',
fallback=True)
args.relative_attention = config.get(
'structure', 'position_embedding_type') == 'relative'
args.num_buckets = config.getint(component,
'relative_attention_num_buckets')
args.max_distance = config.getint(component,

734
examples/gemma/README.md Normal file
View File

@ -0,0 +1,734 @@
# Run Gemma on TensorRT-LLM
## Table Of Contents
- [Run Gemma on TensorRT-LLM](#run-gemma-on-tensorrt-llm)
- [Table Of Contents](#table-of-contents)
- [Support Matrix](#support-matrix)
- [Common scripts](#common-scripts)
- [Convert checkpoint](#convert-checkpoint)
- [Build engine](#build-engine)
- [Run inference](#run-inference)
- [Specific commands](#specific-commands)
- [Run Gemma 2B](#run-gemma-2b)
- [Run inference under bfloat16 for keras checkpoint](#run-inference-under-bfloat16-for-keras-checkpoint)
- [Run inference under FP8 for keras checkpoint](#run-inference-under-fp8-for-keras-checkpoint)
- [Run inference under SmoothQuant for jax checkpoint](#run-2b-inference-under-smoothquant-for-jax-checkpoint)
- [Run inference under weight only for jax checkpoint](#run-inference-under-weight-only-for-jax-checkpoint)
- [Run inference under INT8 KV caches for jax checkpoint](#run-inference-under-int8-kv-caches-for-jax-checkpoint)
- [Run Gemma 7B](#run-gemma-7b)
- [Run inference under bfloat16 for torch checkpoint](#run-inference-under-bfloat16-for-torch-checkpoint)
- [Run inference under FP8 for jax checkpoint](#run-inference-under-fp8-for-jax-checkpoint)
- [Run inference under SmoothQuant for jax checkpoint](#run-7b-inference-under-smoothquant-for-jax-checkpoint)
- [Run inference under weight only for keras checkpoint](#run-inference-under-weight-only-for-keras-checkpoint)
- [Run inference under INT8 KV caches for keras checkpoint](#run-inference-under-int8-kv-caches-for-keras-checkpoint)
- [Run AMMO Quantization](#run-ammo-quantization)
- [Requirements](#requirements)
- [Quantize Checkpoints](#quantize-checkpoints)
- [Build Engines](#build-engines)
- [Accuracy Results on MMLU](#accuracy-results-on-mmlu)
## Support Matrix
* FP32/FP16/BF16/INT8 Weight-Only/INT4 Weight-Only/SmoothQuant/FP8
* For SmoothQuant, TRT-LLM only supports FP16 as the higher-precision data type for now.
* checkpoint type: Jax, Torch, Keras
* STRONGLY TYPED
* python runtime and triton backend
## Common scripts
### Convert checkpoint
Users can use `convert_checkpoint.py` to convert checkpoints from the different source formats into the unified TensorRT-LLM checkpoint format. Set `--dtype` to choose the inference data type, and set quantization options such as `--enable_fp8`, `--fp8_kv_cache`, `--use_smooth_quant`, `--calibrate_kv_cache` (for INT8 KV cache) and `--use-weight-only-with-precision` (weight only). The source checkpoint type is controlled by `--ckpt-type`; currently supported checkpoint types are `jax`, `torch` and `keras`.
```bash
CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/bf16/tp1/
python3 ./convert_checkpoint.py \
--ckpt-type jax \
--model-dir ${CKPT_PATH} \
--dtype bfloat16 \
--world-size 1 \
--output-model-dir ${UNIFIED_CKPT_PATH}
```
### Build engine
After getting the checkpoint, we can use the `trtllm-build` command to build TensorRT-LLM engines from TensorRT-LLM checkpoints.
```bash
ENGINE_PATH=/tmp/gemma/2B/bf16/1-gpu/
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--output_dir ${ENGINE_PATH}
```
### Run inference
We provide three examples to run inference: `run.py`, `summarize.py` and `mmlu.py`. `run.py` only runs inference with `input_text` and shows the output.
`summarize.py` runs summarization on the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset, evaluates the model by [ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) scores, and uses the `ROUGE-1` score to validate the implementation.
`mmlu.py` runs MMLU to evaluate the model by accuracy.
Note that the MMLU dataset needs to be downloaded first, and the MMLU evaluation takes longer to run.
* run.py
```bash
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ../run.py --engine_dir ${ENGINE_PATH} \
--max_output_len 30 \
--vocab_file ${VOCAB_FILE_PATH} \
--no_add_special_tokens
[TensorRT-LLM] TensorRT-LLM version: 0.9.0.dev2024020600
Input [Text 0]: "<bos> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "chef in the renowned kitchens of Lyon. After honing his skills in various Michelin-starred establishments, he embarked on a solo venture, establishing his own restaurant"
```
* summarize.py
```bash
python3 ../summarize.py --test_trt_llm \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--vocab_file ${VOCAB_FILE_PATH} \
--no_add_special_tokens
[02/06/2024-10:08:54] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.2821836471557617 sec)
[02/06/2024-10:08:54] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1989)
[02/06/2024-10:08:54] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 605.9989975648089)
[02/06/2024-10:08:54] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/06/2024-10:08:55] [TRT-LLM] [I] rouge1 : 26.376388677070615
[02/06/2024-10:08:55] [TRT-LLM] [I] rouge2 : 7.468157586877296
[02/06/2024-10:08:55] [TRT-LLM] [I] rougeL : 17.953060795106556
[02/06/2024-10:08:55] [TRT-LLM] [I] rougeLsum : 22.410938121151652
```
* mmlu.py
Download the dataset first
```bash
mkdir data
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar -O data/mmlu.tar
tar -xf data/mmlu.tar -C data
mv data/data data/mmlu
```
Evaluate on MMLU dataset.
```bash
python3 ../mmlu.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH}
Average accuracy 0.358 - social sciences
Average accuracy 0.359 - other (business, health, misc.)
Average accuracy: 0.329
```
## Specific commands
In this section, we demonstrate the scripts for converting checkpoints, building engines and running inference under different settings. We do not cover all combinations here because there are too many cases; instead, we pick some important ones to demonstrate.
### Run Gemma 2B
#### Run inference under bfloat16 for keras checkpoint
```bash
CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_2b_en/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_en_tensorrt_llm/bf16/tp1/
ENGINE_PATH=/tmp/gemma/2B/bf16/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type keras \
--model-dir ${CKPT_PATH} \
--dtype bfloat16 \
--world-size 1 \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.96612286567688 sec)
[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2510)
[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 632.8598697034137)
[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-05:04:13] [TRT-LLM] [I] rouge1 : 20.40970022875146
[02/08/2024-05:04:13] [TRT-LLM] [I] rouge2 : 5.512437888775742
[02/08/2024-05:04:13] [TRT-LLM] [I] rougeL : 15.135998543979978
[02/08/2024-05:04:13] [TRT-LLM] [I] rougeLsum : 17.250431908889873
```
#### Run inference under FP8 for keras checkpoint
WARNING: Running FP8 this way introduces a noticeable accuracy drop. To avoid that, use the AMMO quantization described in this README.
In this example, we demonstrate how to run FP8 inference on Gemma. Note that `convert_checkpoint.py` only uses identity activation scales, so the accuracy might be a little worse than higher precision in some cases, but it is still quite good even though we don't do any calibration. This also shows the stability of FP8 compared to INT8.
```bash
CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_2b_en/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_en_tensorrt_llm/fp8/tp1/
ENGINE_PATH=/tmp/gemma/2B/fp8/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type keras \
--model-dir ${CKPT_PATH} \
--dtype bfloat16 \
--world-size 1 \
--enable_fp8 \
--fp8_kv_cache \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-10:37:15] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.116227149963379 sec)
[02/08/2024-10:37:15] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2419)
[02/08/2024-10:37:15] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 776.259201781368)
[02/08/2024-10:37:15] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-10:37:15] [TRT-LLM] [I] rouge1 : 20.206082692133098
[02/08/2024-10:37:15] [TRT-LLM] [I] rouge2 : 5.902141189518428
[02/08/2024-10:37:15] [TRT-LLM] [I] rougeL : 15.403458457907643
[02/08/2024-10:37:15] [TRT-LLM] [I] rougeLsum : 17.44535527417846
python3 ../mmlu.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH}
Average accuracy 0.390 - social sciences
Average accuracy 0.405 - other (business, health, misc.)
Average accuracy: 0.356
```
#### Run 2B inference under SmoothQuant for jax checkpoint
```bash
CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/sq/tp1
ENGINE_PATH=/tmp/gemma/2B/int8_sq/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
--model-dir ${CKPT_PATH} \
--dtype float16 \
--use_smooth_quant_plugin 0.5 \
--tokenizer_dir ${VOCAB_FILE_PATH} \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin float16 \
--gpt_attention_plugin float16 \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--enable_xqa enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-04:42:06] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.460859775543213 sec)
[02/08/2024-04:42:06] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1786)
[02/08/2024-04:42:06] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 516.0567361385428)
[02/08/2024-04:42:06] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-04:42:06] [TRT-LLM] [I] rouge1 : 22.534044843245525
[02/08/2024-04:42:06] [TRT-LLM] [I] rouge2 : 5.940093176022924
[02/08/2024-04:42:06] [TRT-LLM] [I] rougeL : 16.258991712579736
[02/08/2024-04:42:06] [TRT-LLM] [I] rougeLsum : 19.60977626046262
```
#### Run inference under weight only for jax checkpoint
Available precisions: `int8` and `int4`
* `int8`
```bash
CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/w8_a16/tp1/
ENGINE_PATH=/tmp/gemma/2B/w8_a16/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
--model-dir ${CKPT_PATH} \
--use-weight-only-with-precision int8 \
--dtype bfloat16 \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 32 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--enable_xqa enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-04:44:54] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.5987987518310547 sec)
[02/08/2024-04:44:54] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1797)
[02/08/2024-04:44:54] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 499.3332842203787)
[02/08/2024-04:44:54] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-04:44:54] [TRT-LLM] [I] rouge1 : 24.48521318679745
[02/08/2024-04:44:54] [TRT-LLM] [I] rouge2 : 7.240543314565931
[02/08/2024-04:44:54] [TRT-LLM] [I] rougeL : 17.857921729984078
[02/08/2024-04:44:54] [TRT-LLM] [I] rougeLsum : 21.214162155642896
```
* `int4`
```bash
CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/w4_a16/tp1/
ENGINE_PATH=/tmp/gemma/2B/w4_a16/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
--model-dir ${CKPT_PATH} \
--use-weight-only-with-precision int4 \
--dtype bfloat16 \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 32 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--enable_xqa enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-04:48:06] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.1938045024871826 sec)
[02/08/2024-04:48:06] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1462)
[02/08/2024-04:48:06] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 457.7612683749003)
[02/08/2024-04:48:06] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-04:48:06] [TRT-LLM] [I] rouge1 : 25.19118129834017
[02/08/2024-04:48:06] [TRT-LLM] [I] rouge2 : 6.284558232487986
[02/08/2024-04:48:06] [TRT-LLM] [I] rougeL : 18.133244708843726
[02/08/2024-04:48:06] [TRT-LLM] [I] rougeLsum : 20.562024727650662
```
#### Run inference under INT8 KV caches for jax checkpoint
```bash
CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/int8kv/tp1
ENGINE_PATH=/tmp/gemma/2B/int8kv/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
--model-dir ${CKPT_PATH} \
--world-size 1 \
--dtype bfloat16 \
--calibrate_kv_cache \
--tokenizer_dir ${VOCAB_FILE_PATH} \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 32 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--enable_xqa enable \
--strongly_typed \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-04:52:22] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.5348474979400635 sec)
[02/08/2024-04:52:22] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1819)
[02/08/2024-04:52:22] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 514.5907994786265)
[02/08/2024-04:52:22] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-04:52:22] [TRT-LLM] [I] rouge1 : 24.0397941580232
[02/08/2024-04:52:22] [TRT-LLM] [I] rouge2 : 7.325311340360227
[02/08/2024-04:52:22] [TRT-LLM] [I] rougeL : 17.54210044633271
[02/08/2024-04:52:22] [TRT-LLM] [I] rougeLsum : 20.627861723682177
```
### Run Gemma 7B
#### Run inference under bfloat16 for torch checkpoint
Since the torch checkpoint does not ship with a model config, you need to add one manually to `CKPT_PATH` with the file name `config.json`.
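For reference, `TorchParser.get_config` in `convert_checkpoint.py` reads the keys shown below. This is a minimal sketch of writing such a file; the numeric values are only illustrative for a 7B-class Gemma model and should be checked against the official model card.
```python
import json
import pathlib

# Keys read by TorchParser.get_config in convert_checkpoint.py.
# The values are illustrative for a 7B-class Gemma model; verify them
# against the configuration published with the checkpoint you downloaded.
config = {
    "num_hidden_layers": 28,
    "num_attention_heads": 16,
    "num_key_value_heads": 16,
    "hidden_size": 3072,
    "intermediate_size": 24576,
    "head_dim": 256,
    "vocab_size": 256000,
}
ckpt_dir = pathlib.Path("/tmp/models/pytorch/ckpt/")
(ckpt_dir / "config.json").write_text(json.dumps(config, indent=2) + "\n")
```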
```bash
CKPT_PATH=/tmp/models/pytorch/ckpt/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/bf16/tp1/
ENGINE_PATH=/tmp/gemma/7B/bf16/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type torch \
--model-dir ${CKPT_PATH} \
--dtype bfloat16 \
--world-size 1 \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
python3 ../mmlu.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH}
Average accuracy 0.739 - social sciences
Average accuracy 0.697 - other (business, health, misc.)
Average accuracy: 0.630
```
#### Run inference under FP8 for jax checkpoint
WARNING: Running FP8 this way introduces a noticeable accuracy drop. To avoid it, use the AMMO quantization described later in this README.
In this example, we demonstrate how to run FP8 inference on Gemma. Note that `convert_checkpoint.py` only uses identity activation scales, so the accuracy might be slightly worse than at higher precision in some cases, but it is still very good given that no calibration is performed. This also shows the stability of FP8 compared to INT8.
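As a rough illustration of what identity activation scales mean (a conceptual sketch only, not the converter's actual code): FP8 E4M3 represents magnitudes up to about 448, so even with a fixed scale of 1.0 most activations pass through unclipped, which is why the accuracy drop stays small without calibration.
```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3

def fake_quant_fp8(x: np.ndarray, act_scale: float = 1.0) -> np.ndarray:
    """Model only the clipping effect of FP8 quantization (rounding is ignored)."""
    # act_scale = 1.0 corresponds to identity scales (no calibration); a calibrated
    # scale would map large outliers back into the representable range.
    return np.clip(x / act_scale, -FP8_E4M3_MAX, FP8_E4M3_MAX) * act_scale

acts = np.array([0.3, 7.5, 120.0, 900.0], dtype=np.float32)
print(fake_quant_fp8(acts))                 # 900.0 is clipped to 448.0
print(fake_quant_fp8(acts, act_scale=2.5))  # a calibrated scale preserves 900.0
```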
```bash
CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_7b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/fp8/tp1/
ENGINE_PATH=/tmp/gemma/7B/fp8/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
--model-dir ${CKPT_PATH} \
--dtype bfloat16 \
--world-size 1 \
--enable_fp8 \
--fp8_kv_cache \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-06:42:13] [TRT-LLM] [I] TensorRT-LLM (total latency: 5.884302377700806 sec)
[02/08/2024-06:42:13] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2694)
[02/08/2024-06:42:13] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 457.8282737830064)
[02/08/2024-06:42:13] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-06:42:13] [TRT-LLM] [I] rouge1 : 27.18633861010837
[02/08/2024-06:42:13] [TRT-LLM] [I] rouge2 : 7.734928823230158
[02/08/2024-06:42:13] [TRT-LLM] [I] rougeL : 19.32537431798716
[02/08/2024-06:42:13] [TRT-LLM] [I] rougeLsum : 22.82522575944535
```
#### Run 7B inference under SmoothQuant for jax checkpoint
```bash
CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_7b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/sq/tp1
ENGINE_PATH=/tmp/gemma/7B/int8_sq/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
--model-dir ${CKPT_PATH} \
--dtype float16 \
--use_smooth_quant_plugin 0.5 \
--tokenizer_dir ${VOCAB_FILE_PATH} \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin float16 \
--gpt_attention_plugin float16 \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--enable_xqa enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/19/2024-10:02:53] [TRT-LLM] [I] ---------------------------------------------------------
[02/19/2024-10:03:09] [TRT-LLM] [I] TensorRT-LLM (total latency: 13.65670919418335 sec)
[02/19/2024-10:03:09] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 8351)
[02/19/2024-10:03:09] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 611.494312521266)
[02/19/2024-10:03:09] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/19/2024-10:03:09] [TRT-LLM] [I] rouge1 : 28.8107815115074
[02/19/2024-10:03:09] [TRT-LLM] [I] rouge2 : 8.623835512061866
[02/19/2024-10:03:09] [TRT-LLM] [I] rougeL : 19.7277195532959
[02/19/2024-10:03:09] [TRT-LLM] [I] rougeLsum : 23.434950511855114
```
#### Run inference under weight-only quantization for keras checkpoint
Available precisions: `int8` and `int4`
* `int8`
```bash
CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_7b_en/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/w8_a16/tp1/
ENGINE_PATH=/tmp/gemma/7B/w8_a16/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type keras \
--model-dir ${CKPT_PATH} \
--use-weight-only-with-precision int8 \
--dtype bfloat16 \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 32 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--enable_xqa enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-07:38:15] [TRT-LLM] [I] TensorRT-LLM (total latency: 8.49835753440857 sec)
[02/08/2024-07:38:15] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2654)
[02/08/2024-07:38:15] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 312.2956393931832)
[02/08/2024-07:38:15] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-07:38:16] [TRT-LLM] [I] rouge1 : 20.396209981234687
[02/08/2024-07:38:16] [TRT-LLM] [I] rouge2 : 5.73302850102211
[02/08/2024-07:38:16] [TRT-LLM] [I] rougeL : 16.001683776127507
[02/08/2024-07:38:16] [TRT-LLM] [I] rougeLsum : 18.36957526315223
```
* `int4`
```bash
CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_7b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/w4_a16/tp1/
ENGINE_PATH=/tmp/gemma/7B/w4_a16/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
--model-dir ${CKPT_PATH} \
--use-weight-only-with-precision int4 \
--dtype bfloat16 \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 32 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--enable_xqa enable \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-07:43:32] [TRT-LLM] [I] TensorRT-LLM (total latency: 7.282559156417847 sec)
[02/08/2024-07:43:32] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2253)
[02/08/2024-07:43:32] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 309.3692686333369)
[02/08/2024-07:43:32] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-07:43:32] [TRT-LLM] [I] rouge1 : 27.22556858171486
[02/08/2024-07:43:32] [TRT-LLM] [I] rouge2 : 6.889046653923549
[02/08/2024-07:43:32] [TRT-LLM] [I] rougeL : 19.07040336076859
[02/08/2024-07:43:32] [TRT-LLM] [I] rougeLsum : 22.840545705675858
```
#### Run inference under INT8 KV caches for keras checkpoint
```bash
CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_7b_en/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/int8kv/tp1
ENGINE_PATH=/tmp/gemma/7B/int8kv/1-gpu/
VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type keras \
--model-dir ${CKPT_PATH} \
--world-size 1 \
--dtype bfloat16 \
--calibrate_kv_cache \
--tokenizer_dir ${VOCAB_FILE_PATH} \
--output-model-dir ${UNIFIED_CKPT_PATH}
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--max_batch_size 32 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--enable_xqa enable \
--strongly_typed \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 8 \
--max_ite 5 \
--no_add_special_tokens
[02/08/2024-07:51:11] [TRT-LLM] [I] TensorRT-LLM (total latency: 8.73880124092102 sec)
[02/08/2024-07:51:11] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2771)
[02/08/2024-07:51:11] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 317.09154649544956)
[02/08/2024-07:51:11] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[02/08/2024-07:51:11] [TRT-LLM] [I] rouge1 : 20.934864626327627
[02/08/2024-07:51:11] [TRT-LLM] [I] rouge2 : 4.954721611692932
[02/08/2024-07:51:11] [TRT-LLM] [I] rougeL : 15.307592049634444
[02/08/2024-07:51:11] [TRT-LLM] [I] rougeLsum : 17.94213019528988
```
### Run AMMO Quantization
#### Requirements
The AMMO toolkit provides quantization solutions with better accuracy. To enable it, install the latest `ammo` and `transformers` Python packages with Gemma support, then run the following commands.
#### Quantize Checkpoints
```
python ../quantization/quantize.py --model_dir ${HF_GEMMA_PATH} \
--dtype float16 \
--qformat ${QUANT_TYPE} \
--output_dir ${UNIFIED_CKPT_PATH} \
--tp_size 1
```
`HF_GEMMA_PATH` can be either the HF model card name or the path to a downloaded model. `QUANT_TYPE` can be one of `fp8`, `int4_awq`, or `int8_sq`.
#### Build Engines
For fp8, build engines with:
```
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin float16 \
--gpt_attention_plugin float16 \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--output_dir ${ENGINE_PATH}
```
For int4_awq and int8_sq, build engines with:
```
trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin float16 \
--gpt_attention_plugin float16 \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
--context_fmha enable \
--enable_xqa enable \
--output_dir ${ENGINE_PATH}
```
#### Accuracy Results on MMLU
| Model | fp8 | int4_awq | int8_sq |
|---------------|-------|----------|---------|
| 2B Pretrained | 0.407 | 0.378 | 0.328 |
| 7B Pretrained | 0.643 | 0.615 | 0.480 |
View File
@ -0,0 +1,856 @@
#!/usr/bin/env python3
import argparse
import json
import logging
import math
import pathlib
import re
import time
import typing
import flax.traverse_util
import h5py
import numpy as np
import safetensors.numpy
import safetensors.torch
import sentencepiece as sp
import torch
import utils.params
import utils.transformer
from datasets import load_dataset
from easydict import EasyDict
import tensorrt_llm
from tensorrt_llm._utils import torch_to_numpy
from tensorrt_llm.models.gemma.smoothquant import *
from tensorrt_llm.models.gemma.weight import (dummy_weights_awq,
load_from_fp8_llama,
quantize_fp8_weigths)
LOGGER = logging.getLogger("convert_checkpoint")
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt-type",
type=str,
choices=["jax", "keras", "torch"])
parser.add_argument("--model-dir", type=pathlib.Path, required=True)
parser.add_argument("--output-model-dir", type=pathlib.Path, required=True)
parser.add_argument("--world-size",
type=int,
default=1,
help="world size, only support tensor parallelism now")
parser.add_argument(
"--use-weight-only-with-precision",
choices=["int8", "int4", "w4a8_awq", "w4a16_awq"],
help=
"help='Quantize weights for the various GEMMs to INT4/INT8. Define the precision for the weights.",
)
parser.add_argument("--dtype",
type=str,
choices=["float32", "bfloat16", "float16"])
parser.add_argument(
"--enable_fp8",
action="store_true",
help="Use FP8 Linear layer for Attention QKV/Dense and MLP.")
parser.add_argument(
"--fp8_kv_cache",
action="store_true",
help=
"By default, we use dtype for KV cache. fp8_kv_cache chooses int8 quantization for KV",
)
parser.add_argument(
"--ammo_quant_ckpt_path",
default=None,
help=
"Path of a directory to quantized model checkpoints in .safetensors format or \
path of a quantized model checkpoint in .npz format")
parser.add_argument('--use_smooth_quant',
default=False,
action="store_true",
help="Use smooth quant.")
parser.add_argument(
"--calibrate_kv_cache",
"-kv",
action="store_true",
help=
"Generate scaling factors for KV cache. Used for storing KV cache in int8."
)
parser.add_argument(
'--per_channel',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor for the GEMM\'s result. '
'per_channel instead uses a different static scaling factor for each channel. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--per_token',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor to scale activations in the int8 range. '
'per_token chooses at run time, and for each token, a custom scaling factor. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
"--use_smooth_quant_plugin",
"-sq",
type=float,
default=None,
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
" to Smoothquant the model, and output int8 weights."
" A good first try is 0.5. Must be in [0, 1]")
parser.add_argument(
'--tokenizer_dir',
default=None,
help='tokenizer path; defaults to jax_model_dir if left unspecified')
args = parser.parse_args()
return args
class JAXParser:
def load_parameters(self, checkpoint_path: pathlib.Path):
checkpoint_path = checkpoint_path.absolute()
return utils.params.nest_params(
utils.params.param_remapper(
utils.params.load_params(checkpoint_path)))
def embedding_weights(self, ckpt_params):
return ckpt_params["transformer"]["embedder"]["input_embedding"]
def get_config(self, checkpoint_path, ckpt_params, num_embed):
return utils.transformer.TransformerConfig.from_params(
ckpt_params, num_embed=num_embed)
def rename_to_trt_llm(self, name: str):
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix, name = name.split(".", maxsplit=1)
assert prefix == "transformer"
sub_patterns = (
(r"embedder.input_embedding", r"vocab_embedding.weight"),
(r"layer_(\d+).pre_attention_norm.scale",
r"layers.\1.input_layernorm.weight"),
(r"layer_(\d+).attn.q_einsum.w", r"layers.\1.attention.qkv.weight"),
(r"layer_(\d+).attn.kv_einsum.w",
None), # drop as kv will be concatenated with q
(r"layer_(\d+).attn.qkv_einsum.w",
r"layers.\1.attention.qkv.weight"),
(r"layer_(\d+).attn.attn_vec_einsum.w",
r"layers.\1.attention.dense.weight"),
(r"layer_(\d+).mlp.gating_einsum", r"layers.\1.mlp.fc.weight"),
(r"layer_(\d+).mlp.linear", r"layers.\1.mlp.proj.weight"),
(r"layer_(\d+).pre_ffw_norm.scale",
r"layers.\1.post_layernorm.weight"),
(r"final_norm.scale", r"ln_f.weight"),
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {prefix}.{name}")
def flatten_params(self, params):
return flax.traverse_util.flatten_dict(params, sep=".")
class KerasParser:
def load_parameters(self, checkpoint_path: pathlib.Path):
checkpoint_path = checkpoint_path.absolute()
config_file = "config.json"
weights_file = json.load(open(checkpoint_path / config_file))["weights"]
h5_path = checkpoint_path / weights_file
return h5py.File(h5_path, "r+")
def embedding_weights(self, ckpt_params):
return np.array(ckpt_params["layers/reversible_embedding/vars/0"])
def get_config(self, checkpoint_path, ckpt_params, num_embed):
checkpoint_path = checkpoint_path.absolute()
config_file = "config.json"
config_old = json.load(open(checkpoint_path / config_file))["config"]
config_new = {}
config_new["num_layers"] = config_old["num_layers"]
config_new["num_embed"] = config_old["vocabulary_size"]
config_new["embed_dim"] = config_old["hidden_dim"]
config_new["hidden_dim"] = config_old["intermediate_dim"] // 2
config_new["num_heads"] = config_old["num_query_heads"]
config_new["head_dim"] = config_old["head_dim"]
config_new["num_kv_heads"] = config_old["num_key_value_heads"]
return EasyDict(config_new)
def rename_to_trt_llm(self, name: str):
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix = "transformer"
name = name.replace("/gemma_decoder_block/", "/gemma_decoder_block_0/")
sub_patterns = (
(r"layers/reversible_embedding/vars/0", r"vocab_embedding.weight"),
(r"layers/gemma_decoder_block_(\d+)/pre_attention_norm/vars/0",
r"layers.\1.input_layernorm.weight"),
(r"layers/gemma_decoder_block_(\d+)/attention/query_dense/vars/0",
r"layers.\1.attention.qkv.weight"),
(r"layers/gemma_decoder_block_(\d+)/attention/key_dense/vars/0",
None), # drop as k will be concatenated with q
(r"layers/gemma_decoder_block_(\d+)/attention/value_dense/vars/0",
None), # drop as v will be concatenated with q
(r"layers/gemma_decoder_block_(\d+)/attention/output_dense/vars/0",
r"layers.\1.attention.dense.weight"),
(r"layers/gemma_decoder_block_(\d+)/gating_ffw/vars/0",
r"layers.\1.mlp.fc.weight"),
(r"layers/gemma_decoder_block_(\d+)/gating_ffw_2/vars/0",
None), # merged with above
(r"layers/gemma_decoder_block_(\d+)/ffw_linear/vars/0",
r"layers.\1.mlp.proj.weight"),
(r"layers/gemma_decoder_block_(\d+)/pre_ffw_norm/vars/0",
r"layers.\1.post_layernorm.weight"),
(r"layers/rms_normalization/vars/0", r"ln_f.weight"),
(r"optimizer/vars/(\d+)", None), # Not used
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {prefix}.{name}")
def flatten_params(self, params):
f_params = {}
def walk(name, obj):
if isinstance(obj, h5py.Dataset):
f_params[name] = np.array(obj)
params.visititems(walk)
return f_params
class TorchParser:
def load_parameters(self, checkpoint_path: pathlib.Path):
ckpt_path = list(checkpoint_path.glob('*.ckpt'))[0]
model_params = torch.load(ckpt_path)['model_state_dict']
model_params.pop('freqs_cis')
return model_params
def embedding_weights(self, ckpt_params):
return ckpt_params["embedder.weight"]
def get_config(self, checkpoint_path, ckpt_params, num_embed):
checkpoint_path = checkpoint_path.absolute()
config_file = "config.json"
with open(checkpoint_path / config_file, 'r') as f:
json_str = f.read()
json_str = json_str.replace("'", "\"")
json_str = json_str.replace(",\n}", "\n}")
config_old = json.loads(json_str)
config_new = {}
config_new["num_layers"] = config_old["num_hidden_layers"]
config_new["num_embed"] = config_old["vocab_size"]
config_new["embed_dim"] = config_old["hidden_size"]
config_new["hidden_dim"] = config_old["intermediate_size"]
config_new["num_heads"] = config_old["num_attention_heads"]
config_new["head_dim"] = config_old["head_dim"]
config_new["num_kv_heads"] = config_old["num_key_value_heads"]
return EasyDict(config_new)
def rename_to_trt_llm(self, name: str):
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix = "transformer"
sub_patterns = (
(r"embedder.weight", r"vocab_embedding.weight"),
(r"model.layers.(\d+).input_layernorm.weight",
r"layers.\1.input_layernorm.weight"),
(r"model.layers.(\d+).self_attn.qkv_proj.weight",
r"layers.\1.attention.qkv.weight"),
(r"model.layers.(\d+).self_attn.o_proj.weight",
r"layers.\1.attention.dense.weight"),
(r"model.layers.(\d+).mlp.gate_proj.weight",
r"layers.\1.mlp.fc.weight"),
(r"model.layers.(\d+).mlp.up_proj.weight",
None), # merged with above
(r"model.layers.(\d+).mlp.down_proj.weight",
r"layers.\1.mlp.proj.weight"),
(r"model.layers.(\d+).post_attention_layernorm.weight",
r"layers.\1.post_layernorm.weight"),
(r"model.norm.weight", r"ln_f.weight"),
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {name}")
def flatten_params(self, params):
f_params = {}
for k, v in params.items():
if v.dtype == torch.bfloat16:
v = v.float()
f_params[k] = torch_to_numpy(v)
return f_params
CKPT_PARSER = {'jax': JAXParser, 'keras': KerasParser, 'torch': TorchParser}
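# Tensor-parallel helpers: rank `idx` keeps the idx-th of `tp_size` equal slices
# of a weight along the given axis (the weight is returned unchanged when tp_size == 1).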
def split(v, tp_size, idx, dim=0):
if tp_size == 1:
return v
return np.split(v, tp_size, axis=dim)[idx]
def split_matrix_tp(v, tensor_parallel, rank, dim):
return split(v, tensor_parallel, rank, dim=dim)
def add_trt_llm_weight(weights: typing.Dict[str, np.ndarray],
name: str,
param: np.ndarray,
dtype: typing.Optional[np.dtype] = None):
assert name not in weights, f"{name} is already added."
if dtype is not None:
param = param.astype(dtype)
param = np.ascontiguousarray(param)
weights[name] = param
def quantize(param: np.ndarray,
quant_mode: tensorrt_llm.quantization.QuantMode):
if quant_mode.is_int8_weight_only():
quant_dtype = torch.int8
elif quant_mode.is_int4_weight_only():
quant_dtype = torch.quint4x2
else:
raise ValueError(f"Invalid configuration got quant_mode={quant_mode}")
if param.dtype == np.dtype("bfloat16"):
param = torch.from_numpy(param.astype(np.float32)).to(torch.bfloat16)
else:
param = torch.from_numpy(param)
param = param.t().contiguous()
# previously this fn was available in torch.ops.fastertransformer namespace
(
quantized_weights,
scales,
) = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
param, quant_dtype)
if scales.dtype == torch.bfloat16:
scales = scales.to(torch.float32).numpy().astype("bfloat16")
else:
scales = scales.numpy()
return quantized_weights.numpy(), scales
def convert_from_checkpoint(
trt_llm_config: tensorrt_llm.models.modeling_utils.PretrainedConfig,
model_dir: typing.Union[str, pathlib.Path],
ckpt_parser,
rank=0,
):
print("Loading weights...")
tik = time.time()
tp_rank = rank
tp_size = trt_llm_config.mapping.tp_size
hidden_size = trt_llm_config.hidden_size
head_dim = trt_llm_config.head_size
weights = {}
for model_file in [model_dir]:
LOGGER.debug(f"Loading directory {str(model_file)}...")
model_params = ckpt_parser.load_parameters(model_file)
model_params = ckpt_parser.flatten_params(model_params)
for name, param in model_params.items():
LOGGER.debug(f"Converting weight {name}...")
trt_llm_name = ckpt_parser.rename_to_trt_llm(name)
if trt_llm_name is None: # omit as used with other params
continue
if "attn.q_einsum" in name:
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
assert gqa_mode
# initial shape: (num_q_heads, hidden_size, head_dim)
q_param = param.transpose(1, 0, 2)
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)
# initial shape: (2, num_kv_heads, hidden_size, head_dim)
kv_name = name.replace("q_einsum", "kv_einsum")
kv_param = model_params[kv_name]
kv_param = kv_param.reshape(
trt_llm_config.num_key_value_heads * 2,
hidden_size,
head_dim,
).transpose(1, 0, 2)
# -> (hidden_size, num_q_heads / tp_size + 2, head_dim)
qkv_param = np.concatenate([q_param, kv_param], axis=1)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
qkv_param = qkv_param.transpose(1, 0)
# If int8 kv enabled, weight-only quantization will be done later.
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "self_attn.qkv_proj" in name:
q_param, k_param, v_param = np.split(param, [
trt_llm_config.num_attention_heads *
trt_llm_config.head_size,
trt_llm_config.num_attention_heads *
trt_llm_config.head_size +
trt_llm_config.num_key_value_heads *
trt_llm_config.head_size
],
axis=0)
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
if not gqa_mode:
k_param = split_matrix_tp(k_param, tp_size, tp_rank, dim=0)
v_param = split_matrix_tp(v_param, tp_size, tp_rank, dim=0)
qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_per_group_scaling():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "attn.qkv_einsum" in name:
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
assert not gqa_mode
# initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
qkv_param = param.transpose(0, 1, 3, 2)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
qkv_param.shape[3])
qkv_param = split_matrix_tp(qkv_param, tp_size, tp_rank, dim=1)
qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() \
and not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "attention/query_dense" in name:
# Keras specific KQV convert
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
if gqa_mode:
# initial shape: (num_q_heads, hidden_size, head_dim)
q_param = param.transpose(1, 0, 2)
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)
# initial shape: (2, num_kv_heads, hidden_size, head_dim)
k_name = name.replace("query", "key")
k_param = model_params[k_name]
v_name = name.replace("query", "value")
v_param = model_params[v_name]
kv_param = np.stack((k_param, v_param), axis=0)
kv_param = kv_param.reshape(
trt_llm_config.num_key_value_heads * 2,
hidden_size,
head_dim,
).transpose(1, 0, 2)
# -> (hidden_size, num_q_heads / tp_size + 2, head_dim)
qkv_param = np.concatenate([q_param, kv_param], axis=1)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
qkv_param = qkv_param.transpose(1, 0)
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight",
".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
else:
q_param = param
k_name = name.replace("query", "key")
k_param = model_params[k_name]
v_name = name.replace("query", "value")
v_param = model_params[v_name]
# initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
qkv_param = np.stack((q_param, k_param, v_param), axis=0)
qkv_param = qkv_param.transpose(0, 1, 3, 2)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
qkv_param.shape[3])
qkv_param = split_matrix_tp(qkv_param,
tp_size,
tp_rank,
dim=1)
qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight",
".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "attention.dense.weight" in trt_llm_name:
# initial shape: (num_heads, head_dim, hidden_size)
if len(param.shape) == 3:
param = param.reshape(-1, param.shape[2])
param = param.transpose(
1, 0)  # (hidden_size, num_heads * head_dim)
param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
param_quantized, param_scales = quantize(
param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif "mlp.fc.weight" in trt_llm_name:
if isinstance(ckpt_parser, KerasParser):
# initial shape: (hidden_size, intermediate_size)
fc_param, gate_param = param, model_params[name.replace(
"gating_ffw", "gating_ffw_2")]
elif isinstance(ckpt_parser, TorchParser):
# initial shape: (intermediate_size, hidden_size)
fc_param, gate_param = param, model_params[name.replace(
"mlp.gate_proj", "mlp.up_proj")]
fc_param = fc_param.transpose(1, 0)
gate_param = gate_param.transpose(1, 0)
else:
# initial shape: (2, hidden_size, intermediate_size)
fc_param, gate_param = param[0], param[1]
fc_param = fc_param.transpose(1, 0)
fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0)
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
fc_param_quantized, fc_param_scales = quantize(
fc_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
fc_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
fc_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, fc_param,
trt_llm_config.dtype)
gate_param = gate_param.transpose(1, 0)
gate_param = split_matrix_tp(gate_param,
tp_size,
tp_rank,
dim=0)
trt_llm_name = trt_llm_name.replace("mlp.fc.weight",
"mlp.gate.weight")
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
gate_param_quantized, gate_param_scales = quantize(
gate_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
gate_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
gate_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, gate_param,
trt_llm_config.dtype)
elif "mlp.proj.weight" in trt_llm_name:
if not isinstance(ckpt_parser, TorchParser):
# initial shape: (intermediate_size, hidden_size)
param = param.transpose(1, 0)
param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
param_quantized, param_scales = quantize(
param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif "embedder.input_embedding" in name or "reversible_embedding" in name or "embedder.weight" in name:
if not trt_llm_config.share_embedding_table:
# TODO: safetensor doesn't allow to save a shared tensor.
# Currently, we clone the weight but to save the disk, it
# would be better to skip saving lm_head weights and
# handle it at the loading phase.
lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0)
add_trt_llm_weight(weights, "lm_head.weight",
np.copy(lm_head), trt_llm_config.dtype)
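# The Gemma reference implementation multiplies token embeddings by sqrt(hidden_size)
# at run time (see Embedder.encode); the conversion folds that factor into the stored
# vocab_embedding weight here.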
param = np.multiply(
param.astype(np.float32),
math.sqrt(trt_llm_config.hidden_size),
)
if trt_llm_config.use_parallel_embedding:
assert trt_llm_config.vocab_size % tp_size == 0
param = split_matrix_tp(
param,
tp_size,
tp_rank,
dim=trt_llm_config.embedding_sharding_dim,
)
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif any(keyword in name for keyword in (
"pre_attention_norm.scale",
"pre_ffw_norm.scale",
"final_norm.scale",
"pre_attention_norm/vars/0",
"pre_ffw_norm/vars/0",
"rms_normalization/vars/0",
"input_layernorm",
"post_attention_layernorm",
"model.norm.weight",
)):
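# Gemma checkpoints store RMSNorm scales zero-centered (the effective gain is
# 1 + scale, see RMSNorm in the reference implementation), so add 1 before saving.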
param = param + 1.0 # upcasted to float32 in case of bfloat16
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
else:
raise RuntimeError(f"Unhandled {name} module weights")
del model_params
print(
f"Weights loaded. Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}"
)
return weights
def convert(worker_rank, args, convert_kwargs):
for rank in range(worker_rank, args.world_size):
weights = convert_from_checkpoint(rank=rank, **convert_kwargs)
trt_llm_config = convert_kwargs.get("trt_llm_config")
if args.use_smooth_quant_plugin is not None or args.calibrate_kv_cache:
qkv_para = {}
smoother = {}
dataset = load_dataset("ccdv/cnn_dailymail", '3.0.0')
tokenizer = sp.SentencePieceProcessor(model_file=args.tokenizer_dir)
hf_model = create_model_from_config(trt_llm_config, weights)
act_range = capture_activation_range(hf_model, tokenizer, dataset)
if args.use_smooth_quant_plugin is not None:
smooth_model(hf_model, act_range, args.use_smooth_quant_plugin,
qkv_para, smoother)
weights = convert_hf_model(
hf_model, trt_llm_config.mapping, trt_llm_config.vocab_size,
args.dtype, False, 0,
args.use_weight_only_with_precision != None,
torch.int8 if args.use_weight_only_with_precision == 'int8' else
torch.quint4x2, args.use_smooth_quant_plugin is not None,
args.per_channel, args.per_token, args.calibrate_kv_cache,
act_range, qkv_para, smoother)
safetensors.torch.save_file(
weights, args.output_model_dir / f"rank{rank}.safetensors")
return
use_awq = False
if args.use_weight_only_with_precision:
if args.use_weight_only_with_precision.endswith("awq"):
use_awq = True
if use_awq:
weights = dummy_weights_awq(
weights=weights,
precision=args.use_weight_only_with_precision,
trt_llm_config=trt_llm_config,
group_size=128)
elif args.enable_fp8 or args.fp8_kv_cache:
weight_scales = quantize_fp8_weigths(
weights, trt_llm_config.num_hidden_layers,
trt_llm_config.mapping)
scales = load_from_fp8_llama(args.ammo_quant_ckpt_path,
trt_llm_config.num_hidden_layers,
trt_llm_config.mapping,
args.fp8_kv_cache, weight_scales)
weights.update(scales)
safetensors.numpy.save_file(
weights, args.output_model_dir / f"rank{rank}.safetensors")
def main():
args = parse_arguments()
tik = time.time()
print(f"Loading source parameters from {args.model_dir.absolute()}")
ckpt_parser = CKPT_PARSER[args.ckpt_type]()
ckpt_params = ckpt_parser.load_parameters(args.model_dir)
input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params)
num_embed, _ = input_embedding_weights.shape
ckpt_params_dtype = str(
input_embedding_weights.dtype).split(".")[-1] # np.bfloat16 -> bfloat16
ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params, num_embed)
# 2B TransformerConfig(num_layers=18, num_embed=256128, embed_dim=2048, hidden_dim=16384, num_heads=8, head_dim=256, num_kv_heads=1)
# 7B TransformerConfig(...)
print(f"Source configuration determined from parameters: {ckpt_config}")
quant_mode = tensorrt_llm.quantization.QuantMode(0)
quant_kwargs = {}
quant_algo = None
kv_cache_quant_algo = None
if args.use_weight_only_with_precision:
quant_algo = {
"int8": "W8A16",
"int4": "W4A16",
"w4a8_awq": "W4A8_AWQ",
"w4a16_awq": "W4A16_AWQ",
}[args.use_weight_only_with_precision]
elif args.enable_fp8:
quant_algo = "FP8"
elif args.use_smooth_quant:
quant_algo = "W8A8_SQ_PER_CHANNEL"
if args.fp8_kv_cache:
kv_cache_quant_algo = "FP8"
if args.calibrate_kv_cache:
kv_cache_quant_algo = "INT8"
if args.use_smooth_quant:
quant_algo = "W8A8_SQ_PER_CHANNEL"
elif args.use_smooth_quant_plugin is not None:
if args.per_token and args.per_channel:
quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN'
elif not args.per_token and not args.per_channel:
quant_algo = 'W8A8_SQ_PER_TENSOR_PLUGIN'
elif not args.per_token and args.per_channel:
quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN'
elif args.per_token and not args.per_channel:
quant_algo = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN'
quant_kwargs.update(sq_use_plugin=True)
quant_kwargs.update(quant_algo=quant_algo,
kv_cache_quant_algo=kv_cache_quant_algo)
if quant_algo is not None or kv_cache_quant_algo is not None:
quant_mode = tensorrt_llm.quantization.QuantMode.from_quant_algo(
quant_algo,
kv_cache_quant_algo=kv_cache_quant_algo,
)
if args.use_weight_only_with_precision:
if args.use_weight_only_with_precision.endswith("awq"):
quant_kwargs.update(has_zero_point=False,
pre_quant_scale=True,
exclude_modules=["lm_head"])
trt_llm_config = tensorrt_llm.models.modeling_utils.PretrainedConfig(
architecture="GemmaForCausalLM",
dtype=args.dtype or ckpt_params_dtype,
logits_dtype="float32",
vocab_size=ckpt_config.num_embed,
max_position_embeddings=8192,
hidden_size=ckpt_config.embed_dim,
num_hidden_layers=ckpt_config.num_layers,
num_attention_heads=ckpt_config.num_heads,
num_key_value_heads=ckpt_config.num_kv_heads,
head_size=ckpt_config.head_dim,
hidden_act="gelu",
intermediate_size=ckpt_config.hidden_dim,
norm_epsilon=1e-6, # hard-coded in RMSNorm from gemma/layers.py
position_embedding_type="rope_gpt_neox",
world_size=args.world_size,
tp_size=args.world_size,
pp_size=1,
quant_mode=quant_mode,
quant_kwargs=quant_kwargs,
)
trt_llm_config_dict = trt_llm_config.to_dict()
print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}")
config_path = args.output_model_dir / "config.json"
config_path.parent.mkdir(exist_ok=True, parents=True)
LOGGER.debug(f"Saving TensorRT-LLM configuration to {config_path}")
with config_path.open("w") as config_file:
json.dump(trt_llm_config_dict, config_file, indent=4)
convert_args = dict(trt_llm_config=trt_llm_config,
model_dir=args.model_dir,
ckpt_parser=ckpt_parser)
convert(0, args, convert_args)
elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik))
print(f"Total time of converting checkpoints: {elapsed}")
if __name__ == "__main__":
main()
View File
@ -0,0 +1,9 @@
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
flax~=0.8.0
jax[cuda12_pip]~=0.4.19
safetensors~=0.4.1
sentencepiece~=0.1.99
h5py~=3.10.0
easydict~=1.11
rouge_score
nltk
View File
@ -0,0 +1,14 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
View File
@ -0,0 +1,39 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base layers."""
import jax
import jax.numpy as jnp
from flax import linen as nn
class Einsum(nn.Module):
shape: tuple[int, ...]
@nn.compact
def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
w = self.param('w', nn.initializers.zeros_init(), self.shape)
return jnp.einsum(eqn, x, w)
class RMSNorm(nn.Module):
@nn.compact
def __call__(self, x):
scale = self.param('scale', nn.initializers.zeros_init(), (x.shape[-1]))
var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
normed_inputs = normed_inputs * (1 + scale)
return normed_inputs
View File
@ -0,0 +1,206 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer sub-modules.
"""
import jax
import jax.numpy as jnp
from flax import linen as nn
from . import layers, positional_embeddings
K_MASK = -2.3819763e38 # Set to a large negative number.
LayerCache = dict[str, jax.Array]
def init_layer_cache(cache_size: int, num_heads: int, head_dim: int,
batch_size: int) -> LayerCache:
return {
'v':
jnp.zeros((batch_size, cache_size, num_heads, head_dim),
dtype=jnp.float32),
'k':
jnp.zeros((batch_size, cache_size, num_heads, head_dim),
dtype=jnp.float32),
}
class Embedder(nn.Module):
"""Embedder module."""
vocab_size: int
embed_dim: int
def setup(self):
self.input_embedding_table = self.param(
'input_embedding',
nn.initializers.zeros_init(),
(self.vocab_size, self.embed_dim),
)
def encode(self, x: jax.Array) -> jax.Array:
x = self.input_embedding_table[(x, )]
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
return x
def decode(self, x: jax.Array) -> jax.Array:
return jnp.dot(x, self.input_embedding_table.T)
class Attention(nn.Module):
"""Attention module."""
num_heads: int
num_kv_heads: int
features: int
head_dim: int
@property
def use_qkv_einsum(self):
return self.num_kv_heads == self.num_heads
def setup(self):
self.attn_vec_einsum = layers.Einsum(shape=(self.num_heads,
self.head_dim,
self.features), )
if self.use_qkv_einsum:
self.qkv_einsum = layers.Einsum(shape=(3, self.num_heads,
self.features,
self.head_dim), )
else:
self.q_einsum = layers.Einsum(shape=(self.num_heads, self.features,
self.head_dim), )
self.kv_einsum = layers.Einsum(shape=(2, self.num_kv_heads,
self.features,
self.head_dim), )
def __call__(
self,
x: jax.Array,
segment_pos: int,
cache: LayerCache,
attn_mask: jax.Array,
time_step: int,
) -> tuple[LayerCache, jax.Array]:
bsz = x.shape[0]
if self.use_qkv_einsum:
query_proj, key_proj, value_proj = self.qkv_einsum(
'BTD,SNDH->SBTNH', x)
else:
query_proj = self.q_einsum('BTD,NDH->BTNH', x)
key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x)
query_proj = positional_embeddings.apply_rope(
query_proj,
segment_pos,
head_dim=self.head_dim,
)
query_scaled = query_proj * self.head_dim**-0.5
key_proj = positional_embeddings.apply_rope(
key_proj,
segment_pos,
head_dim=self.head_dim,
)
# Cache is left aligned.
cache['v'] = (cache['v'].at[:bsz, [time_step], :, :].set(value_proj)
) # values
cache['k'] = (cache['k'].at[:bsz, [time_step], :, :].set(key_proj)
) # rotated_keys
logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, cache['k'])
logits = logits.astype(jnp.float32)
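# attn_mask holds 0 for visible positions and K_MASK for masked ones; keep the
# logit where the mask is above K_MASK/2 (i.e. visible), otherwise force it to
# K_MASK so softmax assigns the position ~zero probability.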
padded_logits = jnp.where(
(jnp.expand_dims(attn_mask, -2) >= K_MASK * 0.5), logits, K_MASK)
probs = jax.nn.softmax(padded_logits, axis=-1).astype(cache['k'].dtype)
encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, cache['v'])
attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
return cache, attn_output
class FeedForward(nn.Module):
"""Feed forward module."""
features: int
hidden_dim: int
@nn.compact
def __call__(self, x):
w_gating = self.param(
'gating_einsum',
nn.initializers.zeros_init(),
((2, self.features, self.hidden_dim)),
)
ff_gate = jnp.dot(x, w_gating[0])
gate_value = nn.gelu(ff_gate)
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1
w_linear = self.param(
'linear',
nn.initializers.zeros_init(),
(self.hidden_dim, self.features),
)
outputs = jnp.dot(activations, w_linear)
return outputs
class Block(nn.Module):
"""Transformer block."""
num_heads: int
num_kv_heads: int
embed_dim: int
head_dim: int
hidden_dim: int
def setup(self):
self.pre_attention_norm = layers.RMSNorm()
self.attn = Attention(
num_heads=self.num_heads,
features=self.embed_dim,
head_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
)
self.pre_ffw_norm = layers.RMSNorm()
self.mlp = FeedForward(features=self.embed_dim,
hidden_dim=self.hidden_dim)
def __call__(
self,
x: jax.Array,
segment_pos: int,
cache: LayerCache,
attn_mask: jax.Array,
time_step: int,
):
inputs_normalized = self.pre_attention_norm(x)
cache, attn_output = self.attn(inputs_normalized, segment_pos, cache,
attn_mask, time_step)
attn_output += x
residual = attn_output
attn_output = self.pre_ffw_norm(attn_output)
outputs = self.mlp(attn_output)
outputs = residual + outputs
return cache, outputs
View File
@ -0,0 +1,73 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for loading Gemma params.
These utilities are just helpers for current development. They will not be
needed once Gemma switches to Orbax and changes checkpoint formats ahead of
open sourcing.
"""
import functools
from typing import Any
import orbax.checkpoint
Params = dict[str, Any]
@functools.cache
def load_params(path: str) -> Params:
"""Loads parameters from a checkpoint path."""
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
params = checkpointer.restore(path)
return params
def param_remapper(orig_params: Params) -> Params:
"""Remaps params to new module layout.
This is needed here because the model definition does not have a separate
`mlp` module. For the real code release, we will just save the params in a
different format and this will not be needed.
Args:
orig_params: original dict of parameters in Gemma format.
Returns:
dict of params with different names.
"""
new_params = {}
for k, v in orig_params.items():
if 'mlp/' in k:
layer_name, param = k.rsplit('/', maxsplit=1)
if layer_name not in new_params:
new_params[layer_name] = {}
if 'w' in v:
new_params[layer_name][param] = v['w']
else:
new_params[k] = v
return new_params
def nest_params(params: Params) -> Params:
"""Nests params as a dict of dicts rather than a flat dict."""
nested_params = {}
for path, param in params.items():
*path, leaf = path.split('/')
subdict = nested_params
for key in path:
subdict = subdict.setdefault(key, {})
subdict[leaf] = param
return nested_params
View File
@ -0,0 +1,92 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for positional embeddings (including RoPE).
"""
import jax
import jax.numpy as jnp
_MAX_WAVELENGTH = 10_000
def add_positional_embedding(
input_embedding: jax.Array,
position: int,
max_wavelength: int = _MAX_WAVELENGTH,
) -> jax.Array:
"""Adds positional embeddings to input embeddings."""
embed_dim = input_embedding.shape[-1]
num_timescales = embed_dim // 2
log_timescale_increment = jnp.log(float(max_wavelength)) / jnp.maximum(
jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)
inv_timescales = jnp.exp(
jnp.arange(num_timescales, dtype=jnp.float32) *
-log_timescale_increment)
scaled_time = position * inv_timescales
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)])
signal = jnp.pad(signal, [[0, jnp.mod(embed_dim, 2)]])
position_embedding = signal.astype(jnp.float32)
return input_embedding + position_embedding
def _rotary_embed(
inputs: jax.Array, # [B, 1, H, D]
position: jax.Array, # [B,]
head_dim: int,
max_wavelength: int = _MAX_WAVELENGTH,
) -> jax.Array:
"""Helper for RoPE."""
fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
timescale = max_wavelength**fraction
timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
sinusoid_inp = position[:, jnp.newaxis, jnp.newaxis,
jnp.newaxis] / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
return jnp.concatenate([first_part, second_part], axis=-1)
def apply_rope(
inputs: jax.Array,
position: int,
head_dim: int,
max_wavelength: int = _MAX_WAVELENGTH,
) -> jax.Array:
"""Applies RoPE."""
batch_size, seq_length = inputs.shape[0:2]
position = jnp.broadcast_to(position, [batch_size])[:, jnp.newaxis]
prefix_position = jnp.arange(seq_length, dtype=jnp.int32)
prefix_position = (position - jnp.flip(prefix_position)[jnp.newaxis, :]
) # [B, seq_len]
prefix_position = jnp.where(prefix_position < 0,
jnp.zeros_like(prefix_position),
prefix_position).reshape((batch_size, ))
output = _rotary_embed(
inputs,
position=prefix_position,
head_dim=head_dim,
max_wavelength=max_wavelength,
)
return output
View File
@ -0,0 +1,190 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sampler for Gemma transformer.
An example of a sampling class for a Gemma model.
"""
import chex
import jax
import jax.numpy as jnp
import sentencepiece as spm
from . import modules
from . import params as params_lib
from . import transformer as transformer_lib
def _compute_attention_masks(time_step: jax.Array, seq_len: int,
input_mask: jax.Array) -> jax.Array:
"""Computes causal attention mask."""
bsz = input_mask.shape[0]
batch_time_step = jnp.full((bsz, 1), time_step, dtype=jnp.uint32)
causal_padding = jnp.greater(jnp.expand_dims(jnp.arange(seq_len), 0),
batch_time_step)
causal_padding = causal_padding * jnp.expand_dims(input_mask, axis=-1)
attention_mask = (
causal_padding[:, jnp.newaxis, jnp.newaxis, :].astype(jnp.float32) *
modules.K_MASK)
attention_mask = jnp.squeeze(attention_mask, axis=1)
return attention_mask
@chex.dataclass
class _SamplingState:
# Number of tokens in the prompt.
num_input_tokens: jnp.int32 # [B]
# Fixed-size buffer for accumulating the output tokens.
token_buffer: jnp.ndarray # [B, L]
# Model state for conditioning the model on autoregressively.
cache: dict[str, modules.LayerCache]
class Sampler:
"""Sampler for Gemma transformer."""
def __init__(
self,
transformer_config: transformer_lib.TransformerConfig,
vocab: spm.SentencePieceProcessor,
params: params_lib.Params,
cache_size: int,
buffer_size: int,
max_decode_steps: int,
):
self.transformer = transformer_lib.Transformer(
config=transformer_config)
self.vocab = vocab
self.params = params
self.cache_size = cache_size
self.buffer_size = buffer_size
self.max_decode_steps = max_decode_steps
self._compiled_sample_fn = jax.jit(self._sample_fn)
def _sample_step(self, params, time_step,
sampler_state: _SamplingState) -> _SamplingState:
"""Performs a single sampling step."""
time_step = jnp.asarray(time_step, dtype=jnp.int32)
last_token = sampler_state.token_buffer[:, time_step]
input_mask = last_token != self.vocab.pad_id()
attention_mask = _compute_attention_masks(
time_step, self.cache_size, input_mask).astype(jnp.float32)
logits, cache = self.transformer.apply(
{'params': params},
last_token,
time_step,
sampler_state.cache,
attention_mask,
time_step,
)
next_token_candidate = jnp.argmax(logits, axis=-1) # [B, 1]
next_token_candidate = next_token_candidate[:, 0] # [B,]
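# While still consuming the prompt, force the next prompt token (prefill /
# teacher forcing); once past the prompt, keep the greedy argmax prediction.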
next_token_candidate = jnp.where(
time_step < sampler_state.num_input_tokens - 1,
sampler_state.token_buffer[:, time_step + 1],
next_token_candidate,
)
token_buffer = sampler_state.token_buffer.at[:, time_step + 1].set(
next_token_candidate)
return _SamplingState(
num_input_tokens=sampler_state.num_input_tokens,
token_buffer=token_buffer,
cache=cache,
)
def init_cache(self, bsz) -> dict[str, modules.LayerCache]:
"""Initializes the attention cache for each layer."""
return {
f'layer_{i}': modules.init_layer_cache(
self.cache_size,
self.transformer.config.num_heads,
self.transformer.config.head_dim,
bsz,
)
for i in range(self.transformer.config.num_layers)
}
def init_sample_state(self,
all_input_ids: list[jax.Array]) -> _SamplingState:
"""Initializes the sampling state given input prompts."""
bsz = len(all_input_ids)
num_input_tokens = [len(input_ids) for input_ids in all_input_ids]
token_buffer = jnp.full(
(
bsz,
self.buffer_size,
),
self.vocab.pad_id(),
dtype=jnp.int32,
)
for i, (input_ids,
num_tokens) in enumerate(zip(all_input_ids, num_input_tokens)):
token_buffer = token_buffer.at[i, :num_tokens].set(input_ids)
return _SamplingState(
num_input_tokens=jnp.array(num_input_tokens, dtype=jnp.int32),
token_buffer=token_buffer,
cache=self.init_cache(bsz),
)
def tokenize(self, input_string: str) -> jax.Array:
"""Tokenizes the input string."""
input_ids = self.vocab.EncodeAsIds(input_string)
input_ids = jnp.array([self.vocab.bos_id()] +
jnp.array(input_ids).tolist(),
dtype=jnp.int32)
return input_ids
def _sample_fn(
self,
params: params_lib.Params,
initial_sampling_state: _SamplingState,
) -> _SamplingState:
def sample_with_params(time_step: int, sampler_state: _SamplingState):
return self._sample_step(params, time_step, sampler_state)
return jax.lax.fori_loop(0, self.max_decode_steps, sample_with_params,
initial_sampling_state)
def __call__(self, input_strings: list[str] | str) -> list[str]:
"""Samples a completion of the input string."""
if isinstance(input_strings, str):
input_strings = [input_strings]
all_input_ids = [self.tokenize(x) for x in input_strings]
initial_sampling_state = self.init_sample_state(all_input_ids)
sampling_state = self._compiled_sample_fn(self.params,
initial_sampling_state)
out_tokens = [
buffer[num_tokens:num_tokens + self.max_decode_steps]
for buffer, num_tokens in zip(sampling_state.token_buffer,
sampling_state.num_input_tokens)
]
decoded_outputs = [
self.vocab.DecodeIds(out_tokens.tolist())
for out_tokens in out_tokens
]
return decoded_outputs
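For reference, a minimal usage sketch of the `Sampler` above (not part of this commit): the tokenizer and checkpoint paths are placeholders, `load_and_format_params` is an assumed loader name, and splitting out `params["transformer"]` mirrors the accessors in `transformer.py` below.
```python
# Hypothetical usage of the Sampler defined above; paths, package name and the params loader are placeholders.
import sentencepiece as spm
from gemma import params as params_lib, transformer as transformer_lib  # package name assumed
from gemma.sampler import Sampler

vocab = spm.SentencePieceProcessor()
vocab.Load("/path/to/tokenizer.model")                        # placeholder path
params = params_lib.load_and_format_params("/path/to/ckpt")   # assumed loader name
config = transformer_lib.TransformerConfig.from_params(params, num_embed=vocab.GetPieceSize())

sampler = Sampler(
    transformer_config=config,
    vocab=vocab,
    params=params["transformer"],   # Sampler feeds this subtree directly to Transformer.apply
    cache_size=1024,
    buffer_size=1024,
    max_decode_steps=128,
)
print(sampler(["The capital of France is"])[0])
```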

View File

@ -0,0 +1,113 @@
"""Gemma transformer."""
import dataclasses
import jax
import jax.numpy as jnp
from flax import linen as nn
from . import layers, modules
from . import params as params_lib
Cache = dict[str, modules.LayerCache]
@dataclasses.dataclass
class TransformerConfig:
"""Configuration for the Gemma transformer."""
num_layers: int
num_embed: int
embed_dim: int
hidden_dim: int
num_heads: int
head_dim: int
num_kv_heads: int
@classmethod
def from_params(cls, params: params_lib.Params,
num_embed: int) -> 'TransformerConfig':
"""Creates a TransformerConfig from loaded parameters."""
num_layers = (max([
int(k.split('_')[1])
for k in params['transformer'].keys() if 'layer_' in k
]) + 1)
hidden_dim, embed_dim = (
params['transformer']['layer_0']['mlp']['linear'].shape)
num_heads, head_dim, _ = (params['transformer']['layer_0']['attn']
['attn_vec_einsum']['w'].shape)
use_qkv_einsum = 'qkv_einsum' in params['transformer']['layer_0'][
'attn']
if use_qkv_einsum:
num_kv_heads = num_heads
else:
num_kv_heads = params['transformer']['layer_0']['attn'][
'kv_einsum']['w'].shape[1]
return cls(
num_layers=num_layers,
num_embed=num_embed,
embed_dim=embed_dim,
hidden_dim=hidden_dim,
num_heads=num_heads,
head_dim=head_dim,
num_kv_heads=num_kv_heads,
)
def init_cache(config: TransformerConfig, cache_size: int,
batch_size: int) -> Cache:
"""Initializes a new Transformer cache."""
return {
f'layer_{i}': modules.init_layer_cache(cache_size, config.num_heads,
config.head_dim, batch_size)
for i in range(config.num_layers)
}
class Transformer(nn.Module):
"""Gemma transformer."""
config: TransformerConfig
def setup(self):
self.embedder = modules.Embedder(
vocab_size=self.config.num_embed,
embed_dim=self.config.embed_dim,
)
self.blocks = [
modules.Block(
name=f'layer_{i}',
num_heads=self.config.num_heads,
num_kv_heads=self.config.num_kv_heads,
embed_dim=self.config.embed_dim,
head_dim=self.config.head_dim,
hidden_dim=self.config.hidden_dim,
) for i in range(self.config.num_layers)
]
self.final_norm = layers.RMSNorm()
def __call__(
self,
last_tokens: jax.Array, # [B,]
current_token_position: int,
cache: Cache,
attention_mask: jax.Array, # [B, 1, L]
time_step: int,
) -> tuple[jax.Array, Cache]:
input_emb = self.embedder.encode(last_tokens)
x = jnp.expand_dims(input_emb, axis=1) # adding temporal dimension
for i, block in enumerate(self.blocks):
layer_name = f'layer_{i}'
cache[layer_name], x = block(
x,
current_token_position,
cache[layer_name],
attention_mask,
time_step,
)
x = self.final_norm(x)
logits = self.embedder.decode(x)
return logits, cache
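For orientation, a tiny self-contained sketch (not part of this diff) that exercises `TransformerConfig` and `init_cache` with made-up toy dimensions; real values come from `TransformerConfig.from_params` applied to a loaded checkpoint.
```python
# Minimal sketch of config and cache construction (toy sizes, not real Gemma dimensions).
from gemma import transformer as transformer_lib  # package name assumed

config = transformer_lib.TransformerConfig(
    num_layers=2, num_embed=256, embed_dim=128, hidden_dim=512,
    num_heads=4, head_dim=32, num_kv_heads=1,
)
cache = transformer_lib.init_cache(config, cache_size=64, batch_size=1)
assert sorted(cache) == ["layer_0", "layer_1"]   # one LayerCache per decoder block
```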

View File

@ -30,11 +30,6 @@ from tensorrt_llm.models.llama.weight import (load_from_gptq_llama,
from tensorrt_llm.models.modeling_utils import PretrainedConfig
from tensorrt_llm.runtime.lora_manager import LoraConfig
try:
from transformers import MixtralForCausalLM
except ImportError:
MixtralForCausalLM = None
try:
from transformers import LlavaConfig, LlavaForConditionalGeneration
except ImportError:

View File

@ -34,8 +34,7 @@ Here are some examples:
python convert_checkpoint.py --model_dir ./Mixtral-8x7B-v0.1 \
--output_dir ./tllm_checkpoint_mixtral_2gpu \
--dtype float16 \
--world_size 2 \
--Pp_size 2
--pp_size 2
trtllm-build --checkpoint_dir ./tllm_checkpoint_mixtral_2gpu \
--output_dir ./trt_engines/mixtral/pp2 \
--gemm_plugin float16
@ -47,7 +46,6 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_mixtral_2gpu \
python convert_checkpoint.py --model_dir ./Mixtral-8x7B-v0.1 \
--output_dir ./tllm_checkpoint_mixtral_2gpu \
--dtype float16 \
--world_size 2 \
--tp_size 2
trtllm-build --checkpoint_dir ./tllm_checkpoint_mixtral_2gpu \
--output_dir ./trt_engines/mixtral/tp2 \

View File

@ -248,7 +248,7 @@ class Pipeline:
def __call__(self, prompt):
# Run the model in batch size 1 and beam size 1
if self.model_name == 'SpecialForCausalLM':
if self.model_name == 'GemmaForCausalLM':
inputs = self.tokenizer.encode(prompt, add_special_tokens=False)
inputs = torch.tensor([self.tokenizer.bos_token_id] + inputs)
else:

View File

@ -20,8 +20,8 @@ docker run --gpus all --ipc=host --ulimit memlock=-1 --shm-size=20g -it <the doc
2. Install the quantization toolkit `ammo` and the related dependencies on top of the TensorRT-LLM installation or docker file.
```bash
# Obtain the python version from the system.
pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo~=0.5.0
# Install AMMO
pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo~=0.7.0
# Install the additional requirements
cd <this example folder>
pip install -r requirements.txt
@ -29,7 +29,7 @@ pip install -r requirements.txt
## APIs
[`ammo.py`](../../tensorrt_llm/models/quantized/ammo.py) uses the quantization toolkit to calibrate the PyTorch models, and generate a model config, saved as a json (for the model structure) and npz files (for the model weights) that TensorRT-LLM could parse. The model config includes everything needed by TensorRT-LLM to build the TensorRT inference engine, as explained below.
[`quantize.py`](./quantize.py) uses the quantization toolkit to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format). The checkpoints can be directly used by `trtllm-build` command to build TensorRT-LLM engines. See this [`doc`](../../docs/source/new_workflow.md) for more details on the TensorRT-LLM checkpoint format.
> *This quantization step may take a long time to finish and requires large GPU memory. Please use a server-grade GPU if a GPU out-of-memory error occurs.*
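For illustration, a small sketch (not part of this diff) of inspecting the checkpoint directory described above; the directory path and the weight-file naming are assumptions.
```python
# Hypothetical inspection of a TensorRT-LLM checkpoint produced by quantize.py.
from pathlib import Path

ckpt_dir = Path("./tllm_checkpoint")                            # placeholder output directory
print((ckpt_dir / "config.json").exists())                      # model structure and metadata
print(sorted(p.name for p in ckpt_dir.glob("*.safetensors")))   # one weight file per rank
```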
@ -41,33 +41,35 @@ pip install -r requirements.txt
PTQ can be achieved with simple calibration on a small set of training or evaluation data (typically 128-512 samples) after converting a regular PyTorch model to a quantized model.
```python
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
import ammo.torch.quantization as atq
model = AutoModelForCausalLM.from_pretrained("...")
model = AutoModelForCausalLM.from_pretrained(...)
# Select the quantization config, for example, FP8
config = atq.FP8_DEFAULT_CFG
# Prepare the calibration set and define a forward loop
def forward_loop():
for data in calib_set:
calib_dataloader = DataLoader(...)
def calibrate_loop():
for data in calib_dataloader:
model(data)
# PTQ with in-place replacement to quantized modules
with torch.no_grad():
atq.quantize(model, config, forward_loop)
atq.quantize(model, config, forward_loop=calibrate_loop)
```
### Export Quantized Model
After the model is quantized, the model config can be stored. The model config files include all the information needed by TensorRT-LLM to generate the deployable engine, including the quantized scaling factors.
After the model is quantized, it can be exported to a TensorRT-LLM checkpoint, which includes
The exported model config are stored as
- A single JSON file recording the model structure and metadata and
- A group of npz files each recording the model on a single tensor parallel rank (model weights, scaling factors per GPU).
- One json file recording the model structure and metadata, and
- One or several rank weight files storing quantized model weights and scaling factors.
The export API is
@ -80,6 +82,8 @@ with torch.inference_mode():
decoder_type, # The type of the model as str, e.g. gptj, llama or gptnext.
dtype, # The exported weights data type as torch.dtype.
export_dir, # The directory where the exported files will be stored.
inference_gpus, # The number of GPUs used in the inference time for tensor parallelism.
inference_tensor_parallel=tp_size, # The tensor parallelism size for inference.
inference_pipeline_parallel=pp_size, # The pipeline parallelism size for inference.
export_tensorrt_llm_config=True, # Enable exporting TensorRT-LLM checkpoint config file.
)
```
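For orientation, a hedged end-to-end sketch of the export call whose arguments are shown in the hunk above (not part of this diff); the import path and keyword names follow the comments above and should be treated as assumptions.
```python
# `model` is the quantized PyTorch model produced by atq.quantize(...) in the PTQ snippet above.
import torch
from ammo.torch.export import export_model_config  # import path is an assumption

with torch.inference_mode():
    export_model_config(
        model,
        decoder_type="llama",             # model family as a string, e.g. gptj, llama, gptnext
        dtype=torch.float16,              # exported weight data type
        export_dir="./tllm_ckpt",         # placeholder output directory
        inference_tensor_parallel=2,      # TP size used at inference time
        inference_pipeline_parallel=1,    # PP size used at inference time
        export_tensorrt_llm_config=True,  # also emit the TensorRT-LLM config.json
    )
```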

View File

@ -110,6 +110,7 @@ MODEL_NAME_PATTERN_MAP = {
"Bloom": "bloom",
"ChatGLM": "chatglm",
"QWen": "qwen",
"Gemma": "gemma",
}
@ -296,7 +297,7 @@ def main(args):
torch.save(model.state_dict(), export_path)
else:
export_npz = (model_type not in [
'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan'
'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan', 'gemma'
])
export_model_config(model,
model_type,
@ -320,19 +321,6 @@ def main(args):
with open(f"{export_path}/config.json", "w") as f:
json.dump(tensorrt_llm_config, f, indent=4)
# TODO(enweiz): Remove if a newer AMMO version is released
# Workaround for baichuan
if model_type == 'baichuan':
with open(f"{export_path}/config.json", 'r') as f:
tensorrt_llm_config = json.load(f)
if hasattr(model.model, "alibi_mask"):
tensorrt_llm_config["position_embedding_type"] = 'alibi'
else:
tensorrt_llm_config[
"position_embedding_type"] = 'rope_gpt_neox'
with open(f"{export_path}/config.json", "w") as f:
json.dump(tensorrt_llm_config, f, indent=4)
end_time = time.time()
print(
"Quantized model exported to {} \nTotal time used {:.2f} s.".format(

View File

@ -214,7 +214,7 @@ def parse_input(tokenizer,
else:
print('Input file format not supported.')
raise SystemExit
if model_name == 'SpecialForCausalLM':
if model_name == 'GemmaForCausalLM':
batch_input_ids[0] = [tokenizer.bos_token_id] + batch_input_ids[0]
if num_prepend_vtokens:

View File

@ -57,9 +57,9 @@ python3 convert_checkpoint.py --model_dir ./Skywork-13B-base \
```bash
# fp16
trtllm-build --checkpoint_dir ./skywork-13b-base/trt_ckpt/fp16 \
--use_gemm_plugin float16 \
--use_gpt_attention_plugin float16 \
--enable_context_fmha \
--gemm_plugin float16 \
--gpt_attention_plugin float16 \
--context_fmha enable \
--max_batch_size 32 \
--max_input_len 512 \
--max_output_len 512 \
@ -67,9 +67,9 @@ trtllm-build --checkpoint_dir ./skywork-13b-base/trt_ckpt/fp16 \
# bf16
trtllm-build --checkpoint_dir ./skywork-13b-base/trt_ckpt/bf16 \
--use_gemm_plugin bfloat16 \
--use_gpt_attention_plugin bfloat16 \
--enable_context_fmha \
--gemm_plugin bfloat16 \
--gpt_attention_plugin bfloat16 \
--context_fmha enable \
--max_batch_size 32 \
--max_input_len 512 \
--max_output_len 512 \
@ -85,23 +85,23 @@ After building TRT engines, we can use them to perform various tasks. TensorRT-L
python ../summarize.py --hf_model_dir ./Skywork-13B-base \
--test_hf \
--batch_size 32 \
--max_input_length 512
--max_input_length 512 \
--output_len 512 \
--test_trt_llm \
--engine_dir ./skywork-13b-base/trt_engine/fp16 \
--data_type fp16 \
-check_accuracy \
--check_accuracy \
--tensorrt_llm_rouge1_threshold=14
# bf16
python ../summarize.py --hf_model_dir ./Skywork-13B-base \
--test_hf \
--batch_size 32 \
--max_input_length 512
--max_input_length 512 \
--output_len 512 \
--test_trt_llm \
--engine_dir ./skywork-13b-base/trt_engine/bf16 \
--data_type bf16 \
-check_accuracy \
--check_accuracy \
--tensorrt_llm_rouge1_threshold=14
```

View File

@ -157,12 +157,13 @@ def main(args):
max_input_length=test_token_num,
)
input_ids = torch.tensor(input_id_list)
elif model_name == 'SpecialForCausalLM':
elif model_name == 'GemmaForCausalLM':
input_ids = tokenizer.encode(
curr_text,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=test_token_num)
max_length=test_token_num -
1) # minus 1 to add bos_token_id
input_ids = torch.tensor([tokenizer.bos_token_id] + input_ids)
else:
input_ids = tokenizer.encode(

View File

@ -21,25 +21,25 @@ from transformers import AutoTokenizer, T5Tokenizer
import tensorrt_llm
# TODO(enweiz): Update for refactered models
# TODO(enweiz): Update for refactored models
DEFAULT_HF_MODEL_DIRS = {
'baichuan': 'baichuan-inc/Baichuan-13B-Chat',
'BaichuanForCausalLM': 'baichuan-inc/Baichuan-13B-Chat',
'BloomForCausalLM': 'bigscience/bloom-560m',
'ChatGLMForCausalLM': 'THUDM/chatglm3-6b',
'FalconForCausalLM': 'tiiuae/falcon-rw-1b',
'gpt': 'gpt2-medium',
'GPTJForCausalLM': 'EleutherAI/gpt-j-6b',
'GPTNeoXForCausalLM': 'EleutherAI/gpt-neox-20b',
'internlm': 'internlm/internlm-chat-7b',
'llama': 'meta-llama/Llama-2-7b-hf',
'mpt': 'mosaicml/mpt-7b',
'InternLMForCausalLM': 'internlm/internlm-chat-7b',
'LlamaForCausalLM': 'meta-llama/Llama-2-7b-hf',
'MPTForCausalLM': 'mosaicml/mpt-7b',
'PhiForCausalLM': 'microsoft/phi-2',
'OPTForCausalLM': 'facebook/opt-350m',
'qwen': 'Qwen/Qwen-7B',
}
DEFAULT_PROMPT_TEMPLATES = {
'internlm':
'InternLMForCausalLM':
"<|User|>:{input_text}<eoh>\n<|Bot|>:",
'qwen':
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n",
@ -110,7 +110,7 @@ def load_tokenizer(tokenizer_dir: Optional[str] = None,
elif model_name == 'ChatGLMForCausalLM' and model_version == 'glm':
pad_id = tokenizer.pad_token_id
end_id = tokenizer.eop_token_id
elif model_name == 'SpecialForCausalLM':
elif model_name == 'GemmaForCausalLM':
tokenizer.eos_token_id = tokenizer.sp_model.eos_id()
tokenizer.bos_token_id = tokenizer.sp_model.bos_id()
pad_id = tokenizer.pad_token_id

View File

@ -37,10 +37,10 @@ TensorRT-LLM Whisper builds TensorRT engine(s) from the pytorch checkpoint.
pip install -r requirements.txt
# Build the large-v3 model using a single GPU with plugins.
python3 build.py --output_dir whisper_large_v3 --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin
python3 build.py --output_dir whisper_large_v3 --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin --enable_context_fmha
# Build the large-v3 model using a single GPU with plugins and weight-only quantization.
python3 build.py --output_dir whisper_large_weight_only --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin --use_weight_only
python3 build.py --output_dir whisper_large_weight_only --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin --enable_context_fmha --use_weight_only
```
### Run

View File

@ -26,6 +26,7 @@ from tensorrt_llm.functional import LayerNormPositionType, LayerNormType
from tensorrt_llm.logger import logger
from tensorrt_llm.models import quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
MODEL_ENCODER_NAME = "whisper_encoder"
@ -116,6 +117,9 @@ def parse_arguments():
action="store_true",
help='Quantize weights for the various GEMMs to INT4/INT8.'
'See --weight_only_precision to set the precision')
parser.add_argument('--enable_context_fmha',
default=False,
action='store_true')
parser.add_argument(
'--weight_only_precision',
const='int8',
@ -203,9 +207,11 @@ def build_encoder(model, args):
if args.use_weight_only:
tensorrt_llm_whisper_encoder = quantize_model(
tensorrt_llm_whisper_encoder, args.quant_mode)
use_gemm_woq_plugin = args.use_gemm_plugin and args.use_weight_only
load_encoder_weight(tensorrt_llm_whisper_encoder, model_metadata,
model_params, model_metadata['n_audio_layer'])
model_params, model_metadata['n_audio_layer'],
use_gemm_woq_plugin)
network = builder.create_network()
network.plugin_config.to_legacy_setting()
@ -215,6 +221,8 @@ def build_encoder(model, args):
if args.use_bert_attention_plugin:
network.plugin_config.set_bert_attention_plugin(
dtype=args.use_bert_attention_plugin)
if args.enable_context_fmha:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if args.remove_input_padding:
network.plugin_config.enable_remove_input_padding()
if args.use_weight_only:
@ -310,11 +318,10 @@ def build_decoder(model, args):
if args.use_weight_only:
tensorrt_llm_whisper_decoder = quantize_model(
tensorrt_llm_whisper_decoder, args.quant_mode)
use_gemm_woq_plugin = args.use_gemm_plugin and args.use_weight_only
load_decoder_weight(
tensorrt_llm_whisper_decoder,
model_params,
)
load_decoder_weight(tensorrt_llm_whisper_decoder, model_params,
use_gemm_woq_plugin)
network = builder.create_network()
network.plugin_config.to_legacy_setting()
@ -324,8 +331,13 @@ def build_decoder(model, args):
if args.use_gpt_attention_plugin:
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.enable_context_fmha:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if args.remove_input_padding:
network.plugin_config.enable_remove_input_padding()
if args.use_weight_only:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=args.dtype)
with net_guard(network):
inputs = tensorrt_llm_whisper_decoder.prepare_inputs(

View File

@ -3,3 +3,4 @@ datasets
kaldialign
openai-whisper
soundfile
safetensors

View File

@ -19,13 +19,14 @@ import time
from collections import OrderedDict
from pathlib import Path
import numpy as np
import torch
from datasets import load_dataset
from tokenizer import get_tokenizer
from torch.utils.data import DataLoader
from whisper.normalizers import EnglishTextNormalizer
from whisper_utils import (log_mel_spectrogram, store_transcripts,
write_error_stats)
from whisper_utils import (N_SAMPLES, log_mel_spectrogram, pad_or_trim,
store_transcripts, write_error_stats)
import tensorrt_llm
import tensorrt_llm.logger as logger
@ -291,12 +292,18 @@ def decode_wav_file(
def collate_wrapper(batch):
speeches, labels, ids = [], [], []
speeches, durations, labels, ids = [], [], [], []
for item in batch:
speeches.append(item["audio"]["array"])
speech = item["audio"]["array"]
duration = speech.shape[-1]
speech = pad_or_trim(speech, N_SAMPLES)
speech = speech.astype(np.float32)
speech = torch.from_numpy(speech)
speeches.append(speech)
durations.append(duration)
labels.append(item["text"])
ids.append(item["id"])
return speeches, labels, ids
return speeches, durations, labels, ids
def decode_dataset(
@ -319,9 +326,12 @@ def decode_dataset(
results = []
total_duration = 0
for batch in data_loader:
waveforms, texts, ids = batch
total_duration += sum([wave.shape[0]
for wave in waveforms]) / sample_rate
waveforms, durations, texts, ids = batch
total_duration += sum(durations) / sample_rate
for wave in waveforms:
assert wave.is_pinned()
features = [
log_mel_spectrogram(wave,
model.n_mels,

View File

@ -47,8 +47,11 @@ def trans_weight(weight):
return np.ascontiguousarray(weight)
def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict,
model_params: dict, n_layer: int):
def load_encoder_weight(tensorrt_llm_whisper,
model_metadata: dict,
model_params: dict,
n_layer: int,
use_gemm_woq_plugin=True):
tensorrt_llm.logger.info('Loading encoder weights from PT...')
quant_mode = getattr(tensorrt_llm_whisper, 'quant_mode', QuantMode(0))
@ -59,6 +62,8 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict,
use_weight_only = quant_mode.is_weight_only()
param_dtype = 'float16'
tensorrt_llm_whisper.positional_embedding.value = sinusoids(
model_metadata['n_audio_ctx'], model_metadata['n_audio_state']).numpy()
@ -92,7 +97,12 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict,
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_whisper.encoder_layers[
i].attention.qkv.per_channel_scale
scales.value = torch_weight_scales.numpy()
@ -120,7 +130,12 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict,
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_whisper.encoder_layers[
i].attention.dense.per_channel_scale
scales.value = torch_weight_scales.numpy()
@ -147,7 +162,12 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict,
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_whisper.encoder_layers[
i].mlp.fc.per_channel_scale
scales.value = torch_weight_scales.numpy()
@ -164,7 +184,12 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict,
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_whisper.encoder_layers[
i].mlp.proj.per_channel_scale
scales.value = torch_weight_scales.numpy()
@ -186,10 +211,9 @@ def fuse_qkv(q, k, v):
return qkv_weight
def load_decoder_weight(
tllm_model,
model_params: dict,
):
def load_decoder_weight(tllm_model,
model_params: dict,
use_gemm_woq_plugin=True):
tensorrt_llm.logger.info('Loading decoder weights from PT...')
quant_mode = getattr(tllm_model, 'quant_mode', QuantMode(0))
@ -201,6 +225,8 @@ def load_decoder_weight(
plugin_weight_only_quant_type = torch.quint4x2
use_weight_only = quant_mode.is_weight_only()
use_int8_kv_cache = quant_mode.has_int8_kv_cache()
tllm_model.embedding.vocab_embedding.weight.value = trans_weight(
model_params['decoder.token_embedding.weight'].numpy())
tllm_model.lm_head.weight.value = trans_weight(
@ -225,8 +251,12 @@ def load_decoder_weight(
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = torch.tensor(np.ascontiguousarray(t.transpose(
1, 0))).numpy().astype(str_dtype_to_np(param_dtype))
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = layer.self_attention.qkv.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
@ -241,8 +271,12 @@ def load_decoder_weight(
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = torch.tensor(np.ascontiguousarray(t.transpose(
1, 0))).numpy().astype(str_dtype_to_np(param_dtype))
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = layer.self_attention.dense.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
@ -263,6 +297,12 @@ def load_decoder_weight(
model_params['decoder.blocks.' + str(i) +
'.attn.out.bias'].numpy())
if use_int8_kv_cache:
t = fromfile(
"quantize/1-gpu", 'model.decoder.blocks.' + str(i) +
'.attn.query_key_value.scale_y_quant_orig.bin', [1], np.float32)
layer.self_attention.kv_cache_scaling_factor.value = t
layer.self_attention_layernorm.weight.value = trans_weight(
model_params['decoder.blocks.' + str(i) +
'.attn_ln.weight'].numpy())
@ -284,17 +324,17 @@ def load_decoder_weight(
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = torch.tensor(np.ascontiguousarray(t.transpose(
1, 0))).numpy().astype(str_dtype_to_np(param_dtype))
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = layer.cross_attention.qkv.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = t
layer.cross_attention.dense.weight.value = trans_weight(
model_params['decoder.blocks.' + str(i) +
'.cross_attn.out.weight'].numpy())
t = trans_weight(model_params['decoder.blocks.' + str(i) +
'.cross_attn.out.weight'].numpy())
@ -304,8 +344,12 @@ def load_decoder_weight(
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = torch.tensor(np.ascontiguousarray(t.transpose(
1, 0))).numpy().astype(str_dtype_to_np(param_dtype))
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = layer.cross_attention.dense.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
@ -329,6 +373,12 @@ def load_decoder_weight(
model_params['decoder.blocks.' + str(i) +
'.cross_attn.out.bias'].numpy())
if use_int8_kv_cache:
t = fromfile(
"quantize/1-gpu", 'model.decoder.blocks.' + str(i) +
'.attn.query_key_value.scale_y_quant_orig.bin', [1], np.float32)
layer.self_attention.kv_cache_scaling_factor.value = t
layer.cross_attention_layernorm.weight.value = trans_weight(
model_params['decoder.blocks.' + str(i) +
'.cross_attn_ln.weight'].numpy())
@ -345,8 +395,12 @@ def load_decoder_weight(
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = torch.tensor(np.ascontiguousarray(t.transpose(
1, 0))).numpy().astype(str_dtype_to_np(param_dtype))
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = layer.mlp.fc.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
@ -361,8 +415,12 @@ def load_decoder_weight(
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(np.ascontiguousarray(t.transpose(1, 0))),
plugin_weight_only_quant_type)
dst.value = torch.tensor(np.ascontiguousarray(t.transpose(
1, 0))).numpy().astype(str_dtype_to_np(param_dtype))
if not use_gemm_woq_plugin:
dst.value = torch.tensor(
np.ascontiguousarray(t.transpose(1, 0))).numpy().astype(
str_dtype_to_np(param_dtype))
else:
dst.value = processed_torch_weights.numpy()
scales = layer.mlp.proj.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
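The hunks above repeat one pattern in every attention and MLP branch; below is a compact sketch of that logic (not part of the commit; the `.value`-style destinations stand in for the TensorRT-LLM parameters used above).
```python
# Sketch of the weight-only branching used throughout load_encoder_weight/load_decoder_weight:
# with the GEMM weight-only-quant plugin, store the packed weights returned by the TRT-LLM op;
# without it, keep the plain transposed weights in the original dtype. Scales are set either way.
import numpy as np
import torch
import tensorrt_llm  # noqa: F401  (registers torch.ops.trtllm.*)

def set_weight_only(dst, scales_dst, t: np.ndarray, use_gemm_woq_plugin: bool,
                    quant_type=torch.int8, np_dtype=np.float16):
    """`dst` and `scales_dst` are assumed to expose `.value`, like the parameters above."""
    processed, scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
        torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), quant_type)
    if use_gemm_woq_plugin:
        dst.value = processed.numpy()                                          # packed for the plugin GEMM
    else:
        dst.value = np.ascontiguousarray(t.transpose(1, 0)).astype(np_dtype)   # plain weights
    scales_dst.value = scales.numpy()                                          # per-channel scales
```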

View File

@ -27,6 +27,8 @@ def _add_trt_llm_dll_directory():
_add_trt_llm_dll_directory()
import sys
import tensorrt_llm.functional as functional
import tensorrt_llm.models as models
import tensorrt_llm.quantization as quantization
@ -80,4 +82,6 @@ __all__ = [
_init(log_level="error")
print(f"[TensorRT-LLM] TensorRT-LLM version: {__version__}", end='')
print(f"[TensorRT-LLM] TensorRT-LLM version: {__version__}")
sys.stdout.flush()

View File

@ -19,6 +19,7 @@ from .bloom.model import BloomForCausalLM, BloomModel
from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.model import GemmaForCausalLM
from .gpt.model import GPTLMHeadModel, GPTModel
from .gptj.model import GPTJForCausalLM, GPTJModel
from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel
@ -68,6 +69,7 @@ __all__ = [
'MPTForCausalLM',
'MPTModel',
'SkyworkForCausalLM',
'GemmaForCausalLM',
]
MODEL_MAP = {
@ -87,4 +89,5 @@ MODEL_MAP = {
'MedusaForCausalLM': MedusaForCausalLm,
'BaichuanForCausalLM': BaichuanForCausalLM,
'SkyworkForCausalLM': LLaMAForCausalLM,
'GemmaForCausalLM': GemmaForCausalLM,
}
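The registry entry added above is what lets a checkpoint's `architecture` string resolve to a model class; a minimal sketch of that lookup follows (the config dict is a stand-in, not taken from this diff).
```python
from tensorrt_llm.models import MODEL_MAP

config = {"architecture": "GemmaForCausalLM"}    # stand-in for a checkpoint config.json
model_cls = MODEL_MAP[config["architecture"]]
print(model_cls.__name__)                        # -> GemmaForCausalLM
```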

View File

@ -1107,7 +1107,7 @@ class DecoderModel(Module, GenerationMixin):
# No enable_two_optimization_profiles support yet
encoder_input_len_range = [
0, (max_encoder_input_len + 1) // 2, max_encoder_input_len
1, (max_encoder_input_len + 1) // 2, max_encoder_input_len
]
past_key_value = []
sequence_length = None

View File

@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,456 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from pathlib import Path
from typing import Optional
from transformers import AutoConfig
from tensorrt_llm import profiler
from tensorrt_llm._utils import pad_vocab_size
from tensorrt_llm.functional import RotaryScalingType, Tensor, recv, send
from tensorrt_llm.layers import (MOE, Attention, AttentionMaskType,
ColumnLinear, Embedding, FusedGatedMLP,
GatedMLP, MoeConfig, PositionEmbeddingType,
PromptTuningEmbedding, RmsNorm)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import (DecoderLayerList,
DecoderModelForCausalLM)
from tensorrt_llm.module import Module
from tensorrt_llm.plugin import init_all_reduce_helper
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime.lora_manager import LoraConfig
from tensorrt_llm.top_model_mixin import TopModelMixin
from .weight import load_from_fp8_llama, load_from_hf_llama
class GemmaDecoderLayer(Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
eps=config.norm_epsilon,
dtype=config.dtype)
self.attention = Attention(
config.hidden_size,
config.num_attention_heads,
config.num_key_value_heads,
attention_head_size=config.head_size,
max_position_embeddings=config.max_position_embeddings,
dtype=config.dtype,
attention_mask_type=AttentionMaskType.causal,
bias=config.attn_bias,
position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
rotary_embedding_base=config.rotary_base,
rotary_embedding_scaling=config.rotary_scaling,
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
quant_mode=config.quant_mode,
enable_pos_shift=config.enable_pos_shift,
dense_context_fmha=config.dense_context_fmha,
)
# max_lora_rank=config.max_lora_rank)
mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size
ClsMLP = GatedMLP
mlp_kwargs = {}
if config.moe_num_experts > 1:
ClsMLP = MOE
mlp_kwargs = {
"moe_config":
MoeConfig(
config.moe_num_experts,
config.moe_top_k,
config.moe_tp_mode,
config.moe_normalization_mode,
),
"tp_rank":
config.mapping.tp_rank,
}
elif config.use_fused_mlp:
ClsMLP = FusedGatedMLP
self.mlp = ClsMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=mlp_hidden_size,
hidden_act=config.hidden_act,
dtype=config.dtype,
bias=config.mlp_bias,
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
quant_mode=config.quant_mode,
# max_lora_rank=config.max_lora_rank,
**mlp_kwargs)
self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
eps=config.norm_epsilon,
dtype=config.dtype)
def forward(
self,
hidden_states,
attention_mask=None,
medusa_packed_mask=None, # For Medusa support
medusa_position_offsets=None,
use_cache=False,
kv_cache_params=None,
attention_params=None,
lora_layer_params=None):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attention_output = self.attention(
hidden_states,
attention_mask=attention_mask,
medusa_packed_mask=medusa_packed_mask, # For Medusa support
medusa_position_offsets=medusa_position_offsets,
use_cache=use_cache,
kv_cache_params=kv_cache_params,
attention_params=attention_params,
lora_layer_params=lora_layer_params)
if use_cache:
attention_output, presents = attention_output
hidden_states = residual + attention_output
residual = hidden_states
hidden_states = self.post_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states,
lora_layer_params=lora_layer_params)
hidden_states = residual + hidden_states
if use_cache:
return (hidden_states, presents)
return hidden_states
class GemmaModel(Module):
def __init__(self, config) -> None:
super().__init__()
init_all_reduce_helper()
self.mapping = config.mapping
self.use_prompt_tuning = config.use_prompt_tuning
EmbeddingCls = PromptTuningEmbedding if config.use_prompt_tuning else Embedding
if self.mapping.is_first_pp_rank():
self.vocab_embedding = EmbeddingCls(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
dtype=config.dtype,
tp_size=self.mapping.tp_size
if config.use_parallel_embedding else 1,
tp_group=self.mapping.tp_group
if config.use_parallel_embedding else None,
sharding_dim=config.embedding_sharding_dim,
tp_rank=self.mapping.tp_rank,
)
self.layers = DecoderLayerList(GemmaDecoderLayer, config)
if self.mapping.is_last_pp_rank():
self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
eps=config.norm_epsilon,
dtype=config.dtype)
def forward(
self,
input_ids,
position_ids=None,
use_cache=False,
attention_mask=None,
medusa_position_offsets=None, # For Medusa support
medusa_packed_mask=None, # For Medusa support
kv_cache_params=None,
attention_params=None,
hidden_states=None,
prompt_embedding_table: Optional[Tensor] = None,
prompt_tasks: Optional[Tensor] = None,
prompt_vocab_size: Optional[Tensor] = None,
lora_params=None):
kv_cache_params.fill_none_tensor_list(len(self.layers))
if use_cache:
presents = []
ptuning_args = []
# if self.use_prompt_tuning:
# ptuning_args = [
# prompt_embedding_table, prompt_tasks, prompt_vocab_size
# ]
if self.mapping.is_first_pp_rank():
hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
else:
hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
hidden_states = self.layers.forward(
hidden_states,
use_cache=use_cache,
attention_mask=attention_mask,
kv_cache_params=kv_cache_params,
attention_params=attention_params,
# all_reduce_workspace=all_reduce_workspace,
lora_params=lora_params,
# medusa_position_offsets=medusa_position_offsets,
# medusa_packed_mask=medusa_packed_mask,
)
if use_cache:
hidden_states, presents = hidden_states
if self.mapping.is_last_pp_rank():
hidden_states = self.ln_f(hidden_states)
else:
hidden_states = send(hidden_states, self.mapping.next_pp_rank())
if use_cache:
return (hidden_states, tuple(presents))
return hidden_states
class GemmaForCausalLM(DecoderModelForCausalLM, TopModelMixin):
def __init__(self, config):
self.check_config(config)
transformer = GemmaModel(config)
vocab_size_padded = pad_vocab_size(config.vocab_size,
config.mapping.tp_size)
if config.mapping.is_last_pp_rank():
lm_head = ColumnLinear(config.hidden_size,
vocab_size_padded,
bias=False,
dtype=config.dtype,
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
gather_output=True)
else:
lm_head = None
self.quant_mode = config.quant_mode
self.mapping = config.mapping
super().__init__(config, transformer, lm_head)
@classmethod
def from_hugging_face(cls,
hf_model_dir,
dtype='float16',
mapping: Optional[Mapping] = None,
quant_mode: Optional[QuantMode] = None,
**kwargs):
import transformers
from transformers import LlamaConfig
from ...models.modeling_utils import PretrainedConfig
cfg = LlamaConfig.from_pretrained(hf_model_dir)
num_kv_heads = cfg.num_key_value_heads if hasattr(cfg, "num_key_value_heads") \
else cfg.num_attention_heads
if mapping is None:
mapping = Mapping()
if quant_mode is None:
quant_mode = QuantMode(0)
cfg.mapping = mapping
cfg.dtype = dtype
cfg.quant_mode = quant_mode
moe_config = kwargs.get("moe_config", MoeConfig())
cfg.norm_epsilon = cfg.rms_norm_eps
config = {
'architecture': cfg.architectures[0],
'dtype': cfg.dtype,
'logits_dtype': 'float32',
'num_hidden_layers': cfg.num_hidden_layers,
'num_attention_heads': cfg.num_attention_heads,
'hidden_size': cfg.hidden_size,
'intermediate_size': cfg.intermediate_size,
'num_key_value_heads': cfg.num_key_value_heads,
'vocab_size': cfg.vocab_size,
'position_embedding_type': 'rope_gpt_neox',
'max_position_embeddings': cfg.max_position_embeddings,
'hidden_act': cfg.hidden_act,
'rotary_base': getattr(cfg, 'rotary_base', 10000.0),
'rotary_scaling': getattr(cfg, 'rotary_scaling', None),
'norm_epsilon': cfg.rms_norm_eps,
'quantization': quant_mode.to_dict(),
'mapping': {
'world_size': mapping.world_size,
'tp_size': mapping.world_size,
},
'use_parallel_embedding': kwargs.get("use_parallel_embedding",
False),
'embedding_sharding_dim': kwargs.get("embedding_sharding_dim", 0),
'use_prompt_tuning': kwargs.get("use_prompt_tuning", False),
'moe_num_experts': moe_config.num_experts,
'moe_top_k': moe_config.top_k,
'moe_tp_mode': moe_config.tp_mode,
'moe_normalization_mode': moe_config.normalization_mode,
'use_fused_mlp': kwargs.get("use_fused_mlp", False),
'enable_pos_shift': kwargs.get("enable_pos_shift", False),
'dense_context_fmha': kwargs.get("dense_context_fmha", False),
}
if quant_mode.is_int4_weight_only_per_group():
config['quantization'].update({
'zero': False,
'pre_quant_scale': True,
'exclude_modules': [],
})
tllm_llama = GemmaForCausalLM(PretrainedConfig.from_dict(config))
q_weights = {}
if quant_mode.has_any_quant():
q_weights = tllm_llama._quantize(hf_model_dir, dtype, cfg, **kwargs)
# For debugging purposes, skip weight loading to be faster
if kwargs.get("skip_loading_weights", False):
return tllm_llama
# TODO: support mixtral
# weights already loaded in _quantize for int4 weight only
if not quant_mode.is_int4_weight_only_per_group():
hf_model = transformers.LlamaForCausalLM
profiler.start("Loading weights from HF")
hf_llama = hf_model.from_pretrained(
hf_model_dir,
device_map={
"model": "cpu",
"lm_head": "cpu",
"embed_tokens": "cpu",
"layers": "cpu",
"norm": "cpu",
}, # Load to CPU memory
torch_dtype='auto',
)
weights = load_from_hf_llama(
tllm_llama,
hf_llama,
mapping=mapping,
dtype=dtype,
# TODO: these shall be outside from_hugging_face too.
use_gemm_woq_plugin=kwargs.get("use_gemm_woq_plugin", False),
lora_config=kwargs.get("lora_config", LoraConfig()),
)
profiler.stop("Loading weights from HF")
del hf_llama
weights.update(q_weights)
tllm_llama.load(weights)
else:
tllm_llama.load(q_weights)
return tllm_llama
def _quantize(self, hf_model_dir, dtype, cfg, **kwargs):
'''Given the quant_mode set in the Module object, read from the given HF model,
call AMMO to generate quantization scales, and set the scales back on the module parameters.
'''
# Use a self-destructing temporary path if kwargs["quantization_cache_dir"] is not specified;
# sometimes the quantization checkpoint path needs to be kept for debugging purposes.
quantized_temp_dir = tempfile.TemporaryDirectory("llama-quantized")
quantized_checkpoint_path = kwargs.get("quantization_cache_dir",
quantized_temp_dir.name)
quantize_lm_head = kwargs.get("quantize_lm_head", False)
quant_mode = cfg.quant_mode
ammo_qformat = None
calib_size = None
if quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache():
ammo_qformat = 'fp8'
calib_size = 512
# TODO: how to distinguish int4_awq from int4_gptq based on quant_mode?
elif quant_mode.is_int4_weight_only_per_group():
ammo_qformat = 'int4_awq'
calib_size = 32
assert ammo_qformat is not None
# local import to avoid pytest issue when importing AMMO and transformers lib
from .quantize import quantize_llama_and_export
quantize_llama_and_export(hf_model_dir,
quantized_checkpoint_path,
ammo_qformat,
dtype,
calib_size=calib_size,
quantize_lm_head=quantize_lm_head)
ckpt = Path(quantized_checkpoint_path) / "llama_tp1_rank0.npz"
assert ckpt.exists(), f"The expected checkpoint path {ckpt} does not exist; " \
"quantization likely failed, please check the error logs"
hf_config = AutoConfig.from_pretrained(hf_model_dir,
trust_remote_code=True)
if ammo_qformat == 'fp8':
return load_from_fp8_llama(
str(ckpt),
hf_config,
cfg.mapping,
fp8_kv_cache=quant_mode.has_fp8_kv_cache())
else:
return load_from_awq_llama(str(ckpt),
hf_config,
cfg.mapping,
dtype=dtype)
# LLaMA-style setters: the user still has the chance to change these module attributes after the
# from_hugging_face factory method has created the model, when these attributes are not included in the Hugging Face checkpoint
def rotary_base(self, val):
for decoder in self.layers:
decoder.attention.rotary_embedding_base = val
return self
def rotary_scaling(self, scaling_type, factor):
# TODO: what if there are other behaviors triggered by these changes?
# should implement these assignment as setters of the Attention Module
assert scaling_type in ("linear", "dynamic"), f"Got {scaling_type}"
assert factor > 1.0, f"Got {factor}"
for decoder in self.layers:
decoder.attention.rotary_embedding_scale_type = RotaryScalingType.linear if scaling_type == "linear" else RotaryScalingType.dynamic
decoder.attention.rotary_embedding_scale = factor
return self
def default_plugin_config(self, **kwargs):
plugin_config = super().default_plugin_config(**kwargs)
if self.quant_mode.is_int4_weight_only_per_group():
plugin_config.set_weight_only_groupwise_quant_matmul_plugin()
return plugin_config
def check_config(self, config):
config.set_if_not_exist('use_parallel_embedding', False)
config.set_if_not_exist('embedding_sharding_dim', 0)
config.set_if_not_exist('mlp_bias', False)
config.set_if_not_exist('attn_bias', False)
config.set_if_not_exist('rotary_base', 10000.0)
config.set_if_not_exist('rotary_scaling', None)
config.set_if_not_exist('enable_pos_shift', False)
config.set_if_not_exist('dense_context_fmha', False)
config.set_if_not_exist('use_fused_mlp', False)
config.set_if_not_exist('moe_num_experts', 0)
config.set_if_not_exist('moe_top_k', 0)
config.set_if_not_exist('moe_tp_mode',
MoeConfig.ParallelismMode.TENSOR_PARALLEL)
config.set_if_not_exist(
'moe_normalization_mode',
MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE)
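A hedged usage sketch for the factory defined above (not part of the commit); the checkpoint directory is a placeholder and the single-rank `Mapping` arguments are illustrative.
```python
# Hypothetical construction of the TensorRT-LLM Gemma module from an HF-style checkpoint.
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import GemmaForCausalLM  # exported via models/__init__.py above

model = GemmaForCausalLM.from_hugging_face(
    "/path/to/gemma-hf-checkpoint",               # placeholder directory
    dtype="float16",
    mapping=Mapping(world_size=1, rank=0, tp_size=1, pp_size=1),
)
# The class also exposes rotary_base()/rotary_scaling() setters for attributes
# that are not stored in the Hugging Face checkpoint.
```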

File diff suppressed because it is too large

View File

@ -0,0 +1,681 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import configparser
import os
import time
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import tensorrt_llm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
from tensorrt_llm.quantization import QuantMode
def get_scaling_factors(
model_path: Union[str, Path],
num_layers: int,
quant_mode: Optional[QuantMode] = None,
) -> Optional[Dict[str, List[int]]]:
""" Get the scaling factors for LLaMA model
Returns a dictionary of scaling factors for the selected layers of the
LLaMA model.
Args:
model_path (str): Path to the quantized checkpoint (.npz) exported by AMMO.
num_layers (int): Number of decoder layers to collect scaling factors for.
quant_mode (QuantMode, optional): If the FP8 KV cache is enabled, a
'qkv_output' entry is populated as well.
Returns:
dict: Dictionary of scaling factors for the selected layers of the
LLaMA model.
example:
{
'qkv_act': qkv_act_scale,
'qkv_weights': qkv_weights_scale,
'qkv_output' : qkv_outputs_scale,
'dense_act': dense_act_scale,
'dense_weights': dense_weights_scale,
'fc_act': fc_act_scale,
'fc_weights': fc_weights_scale,
'gate_act': gate_act_scale,
'gate_weights': gate_weights_scale,
'proj_act': proj_act_scale,
'proj_weights': proj_weights_scale,
}
"""
if model_path is None:
tensorrt_llm.logger.warning(
f"--quantized_fp8_model_path not specified. "
f"Initialize quantization scales automatically.")
return get_dummy_quant_scales(num_layers)
weight_dict = np.load(model_path)
# yapf: disable
scaling_factor = {
'qkv_act': [],
'qkv_weights': [],
'dense_act': [],
'dense_weights': [],
'fc_act': [],
'fc_weights': [],
'gate_act': [],
'gate_weights': [],
'proj_act': [],
'proj_weights': [],
}
if quant_mode is not None and quant_mode.has_fp8_kv_cache():
scaling_factor['qkv_output'] = []
for layer in range(num_layers):
scaling_factor['qkv_act'].append(max(
weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
))
scaling_factor['qkv_weights'].append(max(
weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
))
if quant_mode is not None and quant_mode.has_fp8_kv_cache():
# Not calibrating the KV cache.
scaling_factor['qkv_output'].append(1.0)
scaling_factor['dense_act'].append(
weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
scaling_factor['dense_weights'].append(
weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
scaling_factor['gate_act'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:activation_scaling_factor'].item())
scaling_factor['gate_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:weights_scaling_factor'].item())
scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
# yapf: enable
for k, v in scaling_factor.items():
assert len(v) == num_layers, \
f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'
return scaling_factor
def gen_suffix(rank, use_smooth_quant, quant_per_channel):
suffix = f"{rank}.bin"
if use_smooth_quant:
sq_prefix = "int8."
if quant_per_channel:
sq_prefix += "col."
suffix = sq_prefix + suffix
return suffix
def extract_layer_idx(name):
ss = name.split('.')
for s in ss:
if s.isdigit():
return s
return None
def split(v: Union[np.ndarray, torch.Tensor],
tp_size: int,
tp_rank: int,
dim=0):
if tp_size == 1:
return v
assert len(v.shape) > 1 or dim == 0
if isinstance(v, np.ndarray):
return np.ascontiguousarray(
np.split(v, tp_size, axis=dim)[tp_rank].copy())
else:
assert v.shape[dim] % tp_size == 0, \
f'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'
split_size = v.shape[dim] // tp_size
return v.split(split_size, dim=dim)[tp_rank].clone().detach()
def dup_kv_weight(v, num_head, tp_size):
assert tp_size % num_head == 0
reps = tp_size // num_head
head_size = v.shape[0] // num_head
v = v.reshape(num_head, head_size,
-1)[:, None, :, :].expand(num_head, reps, head_size,
v.shape[1])
return v.reshape(num_head * reps * head_size, -1).clone().detach()
def parse_bin_config(ini_file):
model_config = configparser.ConfigParser()
model_config.read(ini_file)
n_embd = model_config.getint('gemma', 'hidden_size')
n_head = model_config.getint('gemma', 'num_attention_heads')
n_head_size = model_config.getint('gemma',
'head_size',
fallback=n_embd // n_head)
n_layer = model_config.getint('gemma', 'num_hidden_layers')
n_positions = model_config.getint('gemma', 'max_position_embeddings')
vocab_size = model_config.getint('gemma', 'vocab_size')
hidden_act = model_config.get('gemma', 'hidden_act')
inter_size = model_config.getint('gemma',
'intermediate_size',
fallback=None)
n_kv_head = model_config.getint('gemma',
'num_key_value_heads',
fallback=None)
if inter_size is None:
inter_size = 4 * n_embd
return n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head, n_head_size
def load_from_binary(tensorrt_llm_gemma,
dir_path,
mapping=Mapping(),
fp16=False,
multi_query_mode=False):
tensorrt_llm.logger.info('Loading weights from binary...')
tik = time.time()
quant_mode = getattr(tensorrt_llm_gemma, 'quant_mode', QuantMode(0))
n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head, n_head_size = parse_bin_config(
Path(dir_path) / 'config.ini')
np_dtype = np.float16 if fp16 else np.float32
def fromfile(dir_path, name, shape=None, dtype=None):
dtype = np_dtype if dtype is None else dtype
p = dir_path + '/' + name
if Path(p).exists():
t = np.fromfile(p, dtype=dtype)
if shape is not None:
t = t.reshape(shape)
return t
return None
def set_smoothquant_scale_factors(module,
pre_scale_weight,
dir_path,
basename,
shape,
per_tok_dyn,
per_channel,
is_qkv=False,
rank=None):
suffix = "bin"
if per_channel:
if rank is not None:
suffix = f"{rank}." + suffix
suffix = "col." + suffix
col_shape = shape if (per_channel or is_qkv) else [1, 1]
if per_tok_dyn:
if pre_scale_weight is not None:
pre_scale_weight.value = np.array([1.0], dtype=np.float32)
if is_qkv and not per_channel:
t = fromfile(dir_path,
f"{basename}scale_w_quant_orig.{rank}.{suffix}",
col_shape, np.float32)
else:
t = fromfile(dir_path, f"{basename}scale_w_quant_orig.{suffix}",
col_shape, np.float32)
module.per_channel_scale.value = t
else:
t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1],
np.float32)
pre_scale_weight.value = t
if is_qkv:
t = fromfile(dir_path,
f"{basename}scale_y_accum_quant.{rank}.{suffix}",
col_shape, np.float32)
else:
t = fromfile(dir_path,
f"{basename}scale_y_accum_quant.{suffix}",
col_shape, np.float32)
module.per_channel_scale.value = t
t = fromfile(dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1],
np.float32)
module.act_scale.value = t
def set_smoother(module, dir_path, base_name, shape, rank):
suffix = f"{rank}.bin"
t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape,
np.float32)
module.smoother.value = t
# Determine the quantization mode.
quant_mode = getattr(tensorrt_llm_gemma, "quant_mode", QuantMode(0))
if quant_mode.is_int8_weight_only():
plugin_weight_only_quant_type = torch.int8
elif quant_mode.is_int4_weight_only():
plugin_weight_only_quant_type = torch.quint4x2
# Do we use SmoothQuant?
use_smooth_quant = quant_mode.has_act_and_weight_quant()
# Do we use quantization per token?
quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling()
# Do we use quantization per channel?
quant_per_channel = quant_mode.has_per_channel_scaling()
# Do we use INT4/INT8 weight-only?
use_weight_only = quant_mode.is_weight_only()
# Int8 KV cache
use_int8_kv_cache = quant_mode.has_int8_kv_cache()
# Debug
suffix = gen_suffix(mapping.tp_rank, use_smooth_quant, quant_per_channel)
# The type of weights.
w_type = np_dtype if not use_smooth_quant else np.int8
if mapping.is_first_pp_rank():
tensorrt_llm_gemma.vocab_embedding.weight.value = (fromfile(
dir_path, 'vocab_embedding.weight.bin', [vocab_size, n_embd]))
if mapping.is_last_pp_rank():
tensorrt_llm_gemma.ln_f.weight.value = (fromfile(
dir_path, 'ln_f.weight.bin'))
# share input embedding
lm_head_weight = fromfile(dir_path, 'lm_head.weight.bin',
[vocab_size, n_embd])
if vocab_size % mapping.tp_size != 0:
# padding
vocab_size_padded = tensorrt_llm_gemma.lm_head.out_features * mapping.tp_size
pad_width = vocab_size_padded - vocab_size
lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)),
'constant',
constant_values=0)
if mapping.is_last_pp_rank():
tensorrt_llm_gemma.lm_head.weight.value = np.ascontiguousarray(
split(lm_head_weight, mapping.tp_size, mapping.tp_rank))
layers_per_pipeline_stage = tensorrt_llm_gemma.num_layers // mapping.pp_size
layers_range = list(
range(mapping.pp_rank * layers_per_pipeline_stage,
(mapping.pp_rank + 1) * layers_per_pipeline_stage, 1))
# This code does not support the case where the number of ranks is greater than the number of K/V heads for GQA.
assert (n_kv_head % mapping.tp_size == 0) or (n_kv_head == 1)
# Compute the number of K/V heads per rank. It's 1 for MQA.
kv_heads_per_rank = max(1, n_kv_head // mapping.tp_size)
# The N-dimension for each rank of the QKV matrix is number of columns for Q + 2 * number of columns for K/V.
if multi_query_mode:
c_attn_out_dim = n_head * n_head_size // mapping.tp_size + 2 * kv_heads_per_rank * n_head_size
else:
c_attn_out_dim = 3 * (n_head * n_head_size) // mapping.tp_size
for i in layers_range:
idx = i - mapping.pp_rank * layers_per_pipeline_stage
tensorrt_llm_gemma.layers[idx].input_layernorm.weight.value = (fromfile(
dir_path, 'model.layers.' + str(i) + '.input_layernorm.weight.bin'))
t = fromfile(
dir_path, 'model.layers.' + str(i) +
'.attention.query_key_value.weight.' + suffix,
[n_embd, c_attn_out_dim], w_type)
if t is not None:
dst = tensorrt_llm_gemma.layers[idx].attention.qkv.weight
if use_smooth_quant:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
set_smoothquant_scale_factors(
tensorrt_llm_gemma.layers[idx].attention.qkv,
tensorrt_llm_gemma.layers[idx].input_layernorm.scale_to_int,
dir_path,
'model.layers.' + str(i) + '.attention.query_key_value.',
[1, c_attn_out_dim],
quant_per_token_dyn,
quant_per_channel,
rank=mapping.tp_rank,
is_qkv=True)
elif use_weight_only:
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_gemma.layers[
idx].attention.qkv.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
dst = tensorrt_llm_gemma.layers[idx].attention.dense.weight
t = fromfile(
dir_path,
'model.layers.' + str(i) + '.attention.dense.weight.' + suffix,
[(n_head * n_head_size) // mapping.tp_size, n_embd], w_type)
if use_smooth_quant:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
dense_scale = getattr(tensorrt_llm_gemma.layers[idx].attention,
"quantization_scaling_factor", None)
set_smoothquant_scale_factors(
tensorrt_llm_gemma.layers[idx].attention.dense, dense_scale,
dir_path, 'model.layers.' + str(i) + '.attention.dense.',
[1, n_embd], quant_per_token_dyn, quant_per_channel)
set_smoother(tensorrt_llm_gemma.layers[idx].attention.dense,
dir_path,
'model.layers.' + str(i) + '.attention.dense',
[1, n_embd // mapping.tp_size], mapping.tp_rank)
elif use_weight_only:
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_gemma.layers[
idx].attention.dense.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
dst = tensorrt_llm_gemma.layers[idx].post_layernorm.weight
dst.value = fromfile(
dir_path, 'model.layers.' + str(i) + '.post_layernorm.weight.bin')
t = fromfile(dir_path,
'model.layers.' + str(i) + '.mlp.fc.weight.' + suffix,
[n_embd, inter_size // mapping.tp_size], w_type)
if use_smooth_quant:
tensorrt_llm_gemma.layers[
idx].mlp.fc.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
set_smoothquant_scale_factors(
tensorrt_llm_gemma.layers[idx].mlp.fc,
tensorrt_llm_gemma.layers[idx].post_layernorm.scale_to_int,
dir_path,
'model.layers.' + str(i) + '.mlp.fc.',
[1, inter_size // mapping.tp_size],
quant_per_token_dyn,
quant_per_channel,
rank=mapping.tp_rank)
elif use_weight_only:
dst = tensorrt_llm_gemma.layers[idx].mlp.fc.weight
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_gemma.layers[idx].mlp.fc.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
tensorrt_llm_gemma.layers[
idx].mlp.fc.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
t = fromfile(dir_path,
'model.layers.' + str(i) + '.mlp.gate.weight.' + suffix,
[n_embd, inter_size // mapping.tp_size], w_type)
if use_smooth_quant:
tensorrt_llm_gemma.layers[
idx].mlp.gate.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
set_smoothquant_scale_factors(
tensorrt_llm_gemma.layers[idx].mlp.gate,
tensorrt_llm_gemma.layers[idx].post_layernorm.scale_to_int,
dir_path,
'model.layers.' + str(i) + '.mlp.gate.',
[1, inter_size // mapping.tp_size],
quant_per_token_dyn,
quant_per_channel,
rank=mapping.tp_rank)
elif use_weight_only:
dst = tensorrt_llm_gemma.layers[idx].mlp.gate.weight
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_gemma.layers[idx].mlp.gate.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
tensorrt_llm_gemma.layers[
idx].mlp.gate.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
t = fromfile(dir_path,
'model.layers.' + str(i) + '.mlp.proj.weight.' + suffix,
[inter_size // mapping.tp_size, n_embd], w_type)
if use_smooth_quant:
tensorrt_llm_gemma.layers[
idx].mlp.proj.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
proj_scale = getattr(tensorrt_llm_gemma.layers[idx].mlp,
"quantization_scaling_factor", None)
set_smoothquant_scale_factors(
tensorrt_llm_gemma.layers[idx].mlp.proj, proj_scale, dir_path,
'model.layers.' + str(i) + '.mlp.proj.', [1, n_embd],
quant_per_token_dyn, quant_per_channel)
set_smoother(tensorrt_llm_gemma.layers[idx].mlp.proj, dir_path,
'model.layers.' + str(i) + '.mlp.proj',
[1, inter_size // mapping.tp_size], mapping.tp_rank)
elif use_weight_only:
dst = tensorrt_llm_gemma.layers[idx].mlp.proj.weight
processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_gemma.layers[idx].mlp.proj.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
tensorrt_llm_gemma.layers[idx].mlp.proj.weight.value = (
np.ascontiguousarray(np.transpose(t, [1, 0])))
if use_int8_kv_cache:
t = fromfile(
dir_path, 'model.layers.' + str(i) +
'.attention.query_key_value.scale_y_quant_orig.bin', [1],
np.float32)
tensorrt_llm_gemma.layers[
idx].attention.kv_cache_scaling_factor.value = t
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
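For reference, a minimal standalone sketch (all numbers are made up) of how the per-rank fused QKV width used by the loader above works out for MHA versus MQA:

```
# Illustrative only, not part of the loader; values are arbitrary.
n_head, n_head_size, tp_size = 16, 128, 4

# MHA: Q, K and V are all split across tensor-parallel ranks.
mha_qkv_dim = 3 * (n_head * n_head_size) // tp_size            # 1536

# MQA (a single K/V head): K and V are replicated, only Q is split.
n_kv_head = 1
kv_heads_per_rank = max(1, n_kv_head // tp_size)               # 1
mqa_qkv_dim = (n_head * n_head_size) // tp_size \
    + 2 * kv_heads_per_rank * n_head_size                      # 512 + 256 = 768

print(mha_qkv_dim, mqa_qkv_dim)
```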
def load_from_hf_llama():
    # Stub kept so that imports of this symbol elsewhere do not break.
    pass
def quantize_fp8_weigths(weights, num_layers, mapping):
def get_scaling_factor(weight):
amax = weight.max()
scale = 448.0 / amax
return scale
layers_range = mapping.pp_layers(num_layers)
scaling_factors = {}
scaled_weights = {}
trt_llm_prefix = "transformer.layers"
for l in layers_range:
# Scale each of the five projection weights in this layer.
for name in [
"attention.qkv", "attention.dense", "mlp.fc", "mlp.gate",
"mlp.proj"
]:
trt_llm_name = ".".join((trt_llm_prefix, str(l), name, "weight"))
scale_name = ".".join(
(trt_llm_prefix, str(l), name, "weights_scaling_factor"))
weight = weights[trt_llm_name]
dtype = weights[trt_llm_name].dtype
scale = get_scaling_factor(weight)
scaled_weights[trt_llm_name] = np.ascontiguousarray(
(weight * scale).astype(dtype))
scaling_factors[scale_name] = np.asarray([1 / scale
]).astype(np.float32)
return scaling_factors
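As a rough illustration of the per-tensor FP8 scaling above (448.0 being the largest finite value of float8_e4m3), here is a self-contained sketch; it uses the absolute maximum and arbitrary shapes, so treat it as a simplification rather than the exact routine:

```
import numpy as np

weight = np.random.uniform(-2.0, 2.0, size=(16, 16)).astype(np.float32)
amax = np.abs(weight).max()
scale = 448.0 / amax                                # multiply weights by this
weights_scaling_factor = np.float32(1.0 / scale)    # reciprocal is what gets stored

scaled = weight * scale                             # now fits the FP8 e4m3 range
recovered = scaled * weights_scaling_factor         # dequantized back
assert np.allclose(recovered, weight, atol=1e-5)
```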
def load_from_fp8_llama(quant_ckpt_path: str, num_layers: int, mapping: Mapping,
fp8_kv_cache: bool, weight_scales: dict):
"""
Get the fp8 scaling factors.
"""
fake_fp8_sf_dt = torch.float32
if quant_ckpt_path is not None and os.path.isfile(quant_ckpt_path):
fp8_llama = np.load(quant_ckpt_path)
else:
fp8_llama = None
tensorrt_llm.logger.info(
    "No quantized checkpoint was found; using dummy FP8 scaling factors instead."
)
weights = {}
def get_fp8_llama(name):
if fp8_llama is not None:
return fp8_llama[name]
else:
return torch.tensor([1.0], dtype=fake_fp8_sf_dt).numpy()
layers_range = mapping.pp_layers(num_layers)
for l in layers_range:
prefix = f'_np:layers:{l}'
tllm_prex = f'transformer.layers.{l-layers_range[0]}'
weights[f'{tllm_prex}.attention.qkv.activation_scaling_factor'] = max(
get_fp8_llama(
f'{prefix}:attention:qkv:q:activation_scaling_factor'),
get_fp8_llama(
f'{prefix}:attention:qkv:k:activation_scaling_factor'),
get_fp8_llama(
f'{prefix}:attention:qkv:v:activation_scaling_factor'))
weights[f'{tllm_prex}.attention.qkv.weights_scaling_factor'] = max(
get_fp8_llama(f'{prefix}:attention:qkv:q:weights_scaling_factor'),
get_fp8_llama(f'{prefix}:attention:qkv:k:weights_scaling_factor'),
get_fp8_llama(f'{prefix}:attention:qkv:v:weights_scaling_factor'))
weights[
f'{tllm_prex}.attention.dense.activation_scaling_factor'] = get_fp8_llama(
f'{prefix}:attention:dense:activation_scaling_factor')
weights[
f'{tllm_prex}.attention.dense.weights_scaling_factor'] = get_fp8_llama(
f'{prefix}:attention:dense:weights_scaling_factor')
weights[
f'{tllm_prex}.mlp.fc.activation_scaling_factor'] = get_fp8_llama(
f'{prefix}:mlp:fc:activation_scaling_factor')
weights[f'{tllm_prex}.mlp.fc.weights_scaling_factor'] = get_fp8_llama(
f'{prefix}:mlp:fc:weights_scaling_factor')
weights[
f'{tllm_prex}.mlp.gate.activation_scaling_factor'] = get_fp8_llama(
f'{prefix}:mlp:gate:activation_scaling_factor')
weights[f'{tllm_prex}.mlp.gate.weights_scaling_factor'] = get_fp8_llama(
f'{prefix}:mlp:gate:weights_scaling_factor')
weights[
f'{tllm_prex}.mlp.proj.activation_scaling_factor'] = get_fp8_llama(
f'{prefix}:mlp:proj:activation_scaling_factor')
weights[f'{tllm_prex}.mlp.proj.weights_scaling_factor'] = get_fp8_llama(
f'{prefix}:mlp:proj:weights_scaling_factor')
if fp8_kv_cache:
# The KV cache is not calibrated; use a default scaling factor of 1.0.
scaling_factor = 1.0
weights[
f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.tensor(
[scaling_factor], dtype=fake_fp8_sf_dt).numpy()
if fp8_llama is None:
weights.update(weight_scales)
return weights
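A small sketch of the key mapping performed above, showing why the fused QKV activation scale is the maximum over the per-projection Q/K/V scales; the checkpoint values here are placeholders:

```
import numpy as np

ckpt = {
    '_np:layers:0:attention:qkv:q:activation_scaling_factor': np.array([0.8], np.float32),
    '_np:layers:0:attention:qkv:k:activation_scaling_factor': np.array([1.1], np.float32),
    '_np:layers:0:attention:qkv:v:activation_scaling_factor': np.array([0.9], np.float32),
}
# The fused QKV GEMM uses one activation scale, so take the max over Q, K and V.
qkv_act_scale = max(
    ckpt[f'_np:layers:0:attention:qkv:{x}:activation_scaling_factor'] for x in 'qkv')
print(qkv_act_scale)   # [1.1]
```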
def dummy_scaling_facotr_sq(weights):
for name in list(weights):
if any([
_name in name for _name in [
'mlp.proj.weight', 'mlp.gate.weight', 'mlp.fc.weight',
'attention.qkv.weight', 'attention.dense.weight'
]
]):
print("Processing:", name)
weight = weights[name]
out_dim, in_dim = weight.shape
weights_scaling_factor = (np.abs(weight).max(1, keepdims=True) /
127.)
prequant_scaling_factor = np.ones([in_dim], dtype=weight.dtype)
activation_scaling_factor = np.array([0.1], dtype=np.float32)
int_weight = (weight / weights_scaling_factor).round().astype(
np.int8)
weights[name.replace(
'weight', 'prequant_scaling_factor')] = prequant_scaling_factor
weights[name.replace(
'weight',
'weights_scaling_factor')] = weights_scaling_factor.astype(
np.float32).squeeze(1)
weights[name.replace(
'weight',
'activation_scaling_factor')] = activation_scaling_factor
weights[name] = int_weight
return weights
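The per-channel symmetric INT8 step above can be checked in isolation with a short sketch (shapes and values are arbitrary):

```
import numpy as np

weight = np.random.randn(4, 8).astype(np.float32)            # [out_dim, in_dim]
scale = np.abs(weight).max(axis=1, keepdims=True) / 127.0     # one scale per output channel
int_weight = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
dequant = int_weight.astype(np.float32) * scale
print(np.abs(dequant - weight).max())                         # small quantization error
```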
def dummy_scaling_facotr_kv_cache(weights):
for name in list(weights):
if 'attention.qkv.weight' in name:
kv_cache_scaling_factor = np.array([0.1], dtype=np.float32)
weights[name.replace(
'qkv.weight',
'kv_cache_scaling_factor')] = kv_cache_scaling_factor
def dummy_weights_awq(weights, precision, trt_llm_config, group_size):
packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
use_fp8_kv_cache = trt_llm_config.quant_mode.has_fp8_kv_cache()
use_int8_kv_cache = trt_llm_config.quant_mode.has_int8_kv_cache()
num_layers = trt_llm_config.num_hidden_layers
for name in list(weights):
if any([
_name in name for _name in [
'mlp.proj.weight', 'mlp.gate.weight', 'mlp.fc.weight',
'attention.qkv.weight', 'attention.dense.weight'
]
]):
print("Processing:", name)
weight = np.ascontiguousarray(weights[name].T)
in_dim, out_dim = weight.shape
scale = np.amax(weight) / 7
weights_scaling_factor = np.ones([out_dim, in_dim // group_size
]) * scale.astype(np.float32)
weight_smoothed = (weight.astype(np.float32) / scale).astype(
np.int8)
weight_smoothed[weight_smoothed < -8] = -8
weight_smoothed[weight_smoothed > 7] = 7
prequant_scaling_factor = np.ones([in_dim], dtype=weight.dtype)
weights[name] = packer(
torch.from_numpy(weight_smoothed)).T.contiguous().numpy()
weights[name.replace(
'weight', 'prequant_scaling_factor')] = prequant_scaling_factor
weights[name.replace(
'weight',
'weights_scaling_factor')] = weights_scaling_factor.astype(
weight.dtype)
if precision == "w4a8_awq":
alpha = np.array([1], dtype=np.float32)
weights[name.replace('weight', 'alpha')] = alpha
if use_fp8_kv_cache or use_int8_kv_cache:
for l in range(num_layers):
t = np.array([1], dtype=np.float32)
weights[
f"transformer.layers.{l}.attention.kv_cache_scaling_factor"] = t
return weights
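For intuition, here is a rough numpy stand-in for the clamp-and-pack step above; the real code relies on `torch.ops.trtllm.pack_int8_tensor_to_packed_int4`, and the nibble order used below is an assumption for illustration only:

```
import numpy as np

w = np.random.randn(8, 4).astype(np.float32)
scale = np.amax(np.abs(w)) / 7.0
q = np.clip(np.round(w / scale), -8, 7).astype(np.int8)       # signed 4-bit range

lo = q[:, 0::2].astype(np.uint8) & 0x0F                       # even columns -> low nibble
hi = (q[:, 1::2].astype(np.uint8) & 0x0F) << 4                # odd columns  -> high nibble
packed = (lo | hi).astype(np.int8)                            # two weights per byte
print(q.shape, packed.shape)                                  # (8, 4) (8, 2)
```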

View File

@@ -992,7 +992,7 @@ class SmoothQuantAttention(Module):
         self.rotary_embedding_base = rotary_embedding_base
         self.rotary_embedding_dim = 0
         if self.position_embedding_type.is_rope():
-            self.rotary_embedding_dim = hidden_size // num_attention_heads
+            self.rotary_embedding_dim = self.attention_head_size
         self.quant_mode = quant_mode
         self.dtype = dtype

View File

@@ -1013,11 +1013,16 @@ class GenerationSession(object):
         if scfg.output_log_probs:
             self.log_probs = torch.zeros(
-                (self.max_new_tokens, batch_size, scfg.num_beams),
+                (batch_size, scfg.num_beams, self.max_seq_length),
                 dtype=torch.float32,
                 device=self.device)
+            self.log_probs_tiled = torch.zeros(
+                (self.max_seq_length, batch_size, scfg.num_beams),
+                dtype=torch.float32,
+                device=self.device)
         else:
             self.log_probs = None
+            self.log_probs_tiled = None
         self.finished = torch.zeros((batch_size, scfg.num_beams),
                                     dtype=torch.uint8,
@@ -2422,7 +2427,7 @@ class GenerationSession(object):
                 this_src_cache_indirection, self.output_ids,
                 self.new_tokens, self.finished, self.finished,
                 self.sequence_length_buffer, self.cum_log_probs,
-                self.log_probs, self.parent_ids,
+                self.log_probs, self.log_probs_tiled, self.parent_ids,
                 this_tgt_cache_indirection,
                 self.beam_hyps_output_ids_tgt,
                 self.beam_hyps_sequence_lengths_tgt,
@@ -2527,6 +2532,10 @@ class GenerationSession(object):
         def get_outputs_dict(output_ids):
             outputs = {}
             outputs['output_ids'] = output_ids
+            if scfg.output_log_probs:
+                outputs['log_probs'] = self.log_probs
+            if scfg.output_cum_log_probs:
+                outputs['cum_log_probs'] = self.cum_log_probs
             if output_sequence_lengths:
                 outputs[
                     'sequence_lengths'] = self.sequence_length_buffer.reshape(
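With these additions, a caller that enables `output_log_probs` / `output_cum_log_probs` in the sampling config can read the new entries back from the returned dictionary; a hedged usage sketch (variable names are placeholders):

```
# Hypothetical caller-side usage; `session`, `input_ids`, `input_lengths` and
# `sampling_config` are placeholders for objects created elsewhere.
outputs = session.decode(input_ids, input_lengths, sampling_config,
                         return_dict=True)
log_probs = outputs.get('log_probs')          # [batch_size, num_beams, max_seq_length]
cum_log_probs = outputs.get('cum_log_probs')  # [batch_size, num_beams]
```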

View File

@@ -451,7 +451,10 @@ class ModelRunner(ModelRunnerMixin):
             max_medusa_tokens=pretrained_config.max_draft_len if hasattr(
                 pretrained_config, 'max_draft_len') else 0,
             num_medusa_heads=pretrained_config.num_medusa_heads if hasattr(
-                pretrained_config, 'num_medusa_heads') else 0)
+                pretrained_config, 'num_medusa_heads') else 0,
+            use_custom_all_reduce=build_config.plugin_config.
+            use_custom_all_reduce,
+        )
         max_batch_size = build_config.max_batch_size
         max_input_len = build_config.max_input_len
         max_output_len = build_config.max_output_len

View File

@@ -12,4 +12,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "0.9.0.dev2024020600"
+__version__ = "0.9.0.dev2024022000"