This commit is contained in:
Kevin Xie 2023-10-10 23:22:17 -07:00
parent 279e329b22
commit 027cd518e3
283 changed files with 12052 additions and 9953 deletions

View File

@ -251,7 +251,14 @@ corresponds to your CUDA version. As an example, for CUDA 12.1, use:
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
```
[CUDA_ARCHITECTURES in CMake]: https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html#prop_tgt:CUDA_ARCHITECTURES
- When building models, you might encounter memory-related issues such as:
```
[09/23/2023-03:13:00] [TRT] [E] 9: GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types
[09/23/2023-03:13:00] [TRT] [E] 9: [pluginV2Builder.cpp::reportPluginError::24] Error Code 9: Internal Error (GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types)
```
You can reduce the memory pressure by lowering the maximum batch size, input length, and output length. Another option is to enable plugins, for example `--use_gpt_attention_plugin`, as in the sketch below.
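A minimal sketch of such a rebuild, assuming the GPT example's `build.py` and its `--max_batch_size`, `--max_input_len`, and `--max_output_len` flags match your checkout; the values and the plugin dtype are illustrative only.
```
# Hypothetical rebuild with smaller limits and the GPT attention plugin enabled;
# adjust the script path, flag names, and values to the example you are building.
python3 examples/gpt/build.py \
    --max_batch_size 8 \
    --max_input_len 512 \
    --max_output_len 128 \
    --use_gpt_attention_plugin float16
```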
- [CUDA_ARCHITECTURES in CMake]: https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html#prop_tgt:CUDA_ARCHITECTURES
## Release notes

View File

@ -25,9 +25,17 @@ add_subdirectory(${CXXOPTS_SRC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/cxxopts)
function(add_benchmark test_name test_src)
add_executable(${test_name} ${test_src})
target_link_libraries(
${test_name} PUBLIC ${SHARED_TARGET} nvinfer_plugin_tensorrt_llm
cxxopts::cxxopts)
if(NOT WIN32) # Linux
target_link_libraries(
${test_name} PUBLIC ${SHARED_TARGET} nvinfer_plugin_tensorrt_llm
cxxopts::cxxopts)
else()
# Use STATIC_TARGET on Windows because MSVC is picky about duplicate symbols
# if the shared and static libs both get linked
target_link_libraries(
${test_name} PUBLIC ${STATIC_TARGET} nvinfer_plugin_tensorrt_llm
cxxopts::cxxopts)
endif()
target_compile_features(${test_name} PRIVATE cxx_std_17)
target_compile_definitions(${test_name}
@ -37,3 +45,4 @@ endfunction()
add_benchmark(gptSessionBenchmark gptSessionBenchmark.cpp)
add_benchmark(bertBenchmark bertBenchmark.cpp)
add_benchmark(gptManagerBenchmark gptManagerBenchmark.cpp)

View File

@ -15,7 +15,7 @@ cd cpp/build
make -j benchmarks
```
### 2. Launch C++ benchmarking
### 2. Launch C++ benchmarking (Fixed BatchSize/InputLen/OutputLen)
Before you launch C++ benchmarking, please make sure that you have already built engine(s) using the TensorRT-LLM API; the C++ benchmarking code cannot generate engine(s) for you.
@ -55,3 +55,49 @@ mpirun -n 8 ./benchmarks/gptSessionBenchmark \
```
*Please note that the expected outputs in that document are only for reference; specific performance numbers depend on the GPU you're using.*
### 3. Launch Batch Manager benchmarking (Inflight/V1 batching)
#### Prepare dataset
Run the preprocessing script to prepare the dataset. This script converts the prompts (strings) in the dataset to input_ids.
```
python3 prepare_dataset.py \
--dataset <path/to/dataset> \
--max_input_len 300 \
--tokenizer_dir <path/to/tokenizer> \
--tokenizer_type auto \
--output preprocessed_dataset.json
```
For `tokenizer_dir`, you can specify either the path to a local tokenizer that has already been downloaded, or simply the name of a tokenizer on HuggingFace such as `gpt2`; in the latter case the tokenizer is downloaded automatically. An example of the latter is shown below.
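For instance, the same command as above but pulling the `gpt2` tokenizer by name from HuggingFace (the dataset path placeholder is left as-is):
```
python3 prepare_dataset.py \
    --dataset <path/to/dataset> \
    --max_input_len 300 \
    --tokenizer_dir gpt2 \
    --tokenizer_type auto \
    --output preprocessed_dataset.json
```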
#### Prepare TensorRT-LLM engines
If you'd like to benchmark inflight batching, please make sure that the engines are built with the arguments `--use_inflight_batching` and `--remove_input_padding`; for more details, please see the documentation in the TensorRT-LLM examples. A sketch of such a build is shown below.
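A minimal build sketch for a 2-GPU inflight-batching engine; only `--use_inflight_batching` and `--remove_input_padding` come from the text above, the remaining flags, values, and paths are assumptions added to make the example self-contained.
```
# Hypothetical 2-GPU GPT engine build for inflight batching benchmarking;
# verify flag names against the GPT example in your checkout.
cd examples/gpt
python3 build.py \
    --dtype float16 \
    --world_size 2 \
    --use_inflight_batching \
    --remove_input_padding \
    --output_dir trt_engine/gpt2-ib/fp16/2-gpu/
```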
#### Launch benchmarking
For detailed usage, run the following:
```
cd cpp/build
# You can directly execute the binary for help information
./benchmarks/gptManagerBenchmark --help
```
Take GPT-350M as an example for single-GPU V1 batching:
```
./benchmarks/gptManagerBenchmark \
--model gpt \
--engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
--type V1 \
--dataset ../../benchmarks/cpp/preprocessed_dataset.json
```
Take GPT-350M as an example for 2-GPU inflight batching:
```
mpirun -n 2 ./benchmarks/gptManagerBenchmark \
--model gpt \
--engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \
--type IFB \
--dataset ../../benchmarks/cpp/preprocessed_dataset.json
```
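The new `gptManagerBenchmark` binary also exposes kv-cache and scheduler options; a possible invocation combining them is sketched below. The option names are taken from `gptManagerBenchmark.cpp` in this commit, while the numeric values are purely illustrative.
```
# Illustrative only: option names come from gptManagerBenchmark.cpp; the
# values below are arbitrary and should be tuned for your GPU and engine.
mpirun -n 2 ./benchmarks/gptManagerBenchmark \
    --model gpt \
    --engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \
    --type IFB \
    --dataset ../../benchmarks/cpp/preprocessed_dataset.json \
    --max_tokens_in_paged_kvcache 4096 \
    --kv_cache_free_gpu_mem_fraction 0.9 \
    --scheduler_policy max_utilization \
    --log_level info
```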

View File

@ -78,7 +78,7 @@ void benchmarkBert(std::string const& modelName, std::filesystem::path const& da
{
auto const worldConfig = WorldConfig::mpi(*logger);
auto const enginePath = dataPath / engineFilename(dataPath, worldConfig, modelName);
auto engineBlob = loadEngine(enginePath);
auto engineBlob = loadEngine(enginePath.string());
auto rt = std::make_shared<TllmRuntime>(engineBlob.data(), engineBlob.size(), *logger);
rt->addContext(0);
@ -180,7 +180,8 @@ int main(int argc, char* argv[])
if (!result.count("engine_dir"))
{
std::cout << options.help() << std::endl;
throw std::invalid_argument("Please specify engine directory.");
TLLM_LOG_ERROR("Please specify engine directory.");
return 1;
}
// Argument: Batch sizes
@ -226,11 +227,20 @@ int main(int argc, char* argv[])
}
else
{
throw std::invalid_argument("Unexpected log level: " + logLevel);
TLLM_LOG_ERROR("Unexpected log level: " + logLevel);
return 1;
}
initTrtLlmPlugins(logger.get());
benchmarkBert(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inLens, logger,
result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>());
try
{
benchmarkBert(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inLens,
logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>());
}
catch (const std::exception& e)
{
TLLM_LOG_ERROR(e.what());
return 1;
}
return 0;
}

View File

@ -0,0 +1,645 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/NamedTensor.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/gptSession.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <chrono>
#include <cxxopts.hpp>
#include <iostream>
#include <nlohmann/json.hpp>
#include <string>
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::batch_manager;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::mpi;
namespace tc = tensorrt_llm::common;
namespace trt = nvinfer1;
// Class holding all information regarding a single work item.
// This includes the original request, associated response factor
// and state.
class WorkItem
{
public:
WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t RequestId)
: mInferenceRequest(ir)
, mRequestId(RequestId)
{
}
~WorkItem() {}
uint64_t requestId() const
{
return mRequestId;
}
std::shared_ptr<InferenceRequest> getInferenceRequest() const
{
return mInferenceRequest;
}
private:
std::shared_ptr<InferenceRequest> mInferenceRequest;
uint64_t mRequestId;
};
/// @brief Thread-safe queue of work items
class WorkItemsQueue
{
public:
void clear()
{
std::lock_guard<std::mutex> lk(mMutex);
mPendingWorkItems.clear();
mPendingWorkItemsReqIds.clear();
mInProgressWorkItems.clear();
}
// Note: this function must only be called while holding the lock
bool hasInProgressReqId(const uint64_t reqId) const
{
return (mInProgressWorkItems.find(reqId) != mInProgressWorkItems.end());
}
// Note: this function must only be called while holding the lock
bool hasPendingReqId(const uint64_t reqId) const
{
return (mPendingWorkItemsReqIds.find(reqId) != mPendingWorkItemsReqIds.end());
}
bool empty() const
{
return mPendingWorkItems.empty() && mInProgressWorkItems.empty() && mPendingWorkItemsReqIds.empty();
}
/// @brief Add a new work item to the queue
/// Throws an error if requestId already exists
void push(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
{
std::lock_guard<std::mutex> lk(mMutex);
if (hasInProgressReqId(requestId) || hasPendingReqId(requestId))
{
std::string errStr
= "requestId " + std::to_string(requestId) + " is already in progress, request is ignored.";
throw std::runtime_error(errStr);
}
else
{
auto workItem = std::make_shared<WorkItem>(request, requestId);
mPendingWorkItems.push_back(workItem);
mPendingWorkItemsReqIds.insert(workItem->requestId());
}
}
/// @brief Get a new work item from the queue, and move it to the list of
/// in progress work items if it hasn't been stopped
/// @return A tuple of the workItem and a boolean flag indicating if the work item
/// has been marked in progress
std::tuple<std::shared_ptr<WorkItem>, bool> pop()
{
std::lock_guard<std::mutex> lk(mMutex);
auto workItem = mPendingWorkItems.front();
mPendingWorkItems.pop_front();
mPendingWorkItemsReqIds.erase(workItem->requestId());
bool markedInProgress;
mInProgressWorkItems.emplace(std::make_pair(workItem->requestId(), workItem));
markedInProgress = true;
return {workItem, markedInProgress};
}
size_t numPendingWorkItems() const
{
std::lock_guard<std::mutex> lk(mMutex);
return mPendingWorkItems.size();
}
size_t numInProgressWorkItems() const
{
std::lock_guard<std::mutex> lk(mMutex);
return mInProgressWorkItems.size();
}
size_t size() const
{
return numPendingWorkItems() + numInProgressWorkItems();
}
/// @brief Mark a request as being finished
/// @param requestId
void markFinished(const uint64_t requestId)
{
std::lock_guard<std::mutex> lk(mMutex);
if (hasInProgressReqId(requestId))
{
mInProgressWorkItems.erase(requestId);
}
}
private:
/// Queue of work items
std::list<std::shared_ptr<WorkItem>> mPendingWorkItems;
/// requestIds of work items in the queue
std::set<uint64_t> mPendingWorkItemsReqIds;
/// work items currently in progress
std::unordered_map<uint64_t, std::shared_ptr<WorkItem>> mInProgressWorkItems;
mutable std::mutex mMutex;
};
struct BenchInfo
{
BenchInfo() {}
BenchInfo(int _inputLength, int _outputLength, std::chrono::time_point<std::chrono::steady_clock> _start)
: inputLength(_inputLength)
, outputLength(_outputLength)
, start(_start)
{
}
int inputLength;
int outputLength;
std::chrono::time_point<std::chrono::steady_clock> start;
std::chrono::time_point<std::chrono::steady_clock> end;
float latency; // millisecond
};
class Recorder
{
public:
Recorder() {}
void initialize()
{
mStart = std::chrono::steady_clock::now();
}
void finalize()
{
mEnd = std::chrono::steady_clock::now();
}
void recordStart(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
{
const auto& input_ids_tensor = request->getInputTensor("input_ids");
// input_ids has shape [1, inputLength]; read the length from the tensor shape
auto const& inputShape = input_ids_tensor->getShape();
auto const inputLength = inputShape.d[1];
auto const [specified, outputLength]
= request->getScalarValueFromTensor<int>("request_output_len", {1, 1}, false);
assert(specified);
auto const start = std::chrono::steady_clock::now();
mRequestBenchInfos[requestId] = BenchInfo(inputLength, outputLength, start);
}
void recordEnd(uint64_t requestId)
{
mRequestBenchInfos[requestId].end = std::chrono::steady_clock::now();
mRequestBenchInfos[requestId].latency = std::chrono::duration<float, std::milli>(
mRequestBenchInfos[requestId].end - mRequestBenchInfos[requestId].start)
.count();
}
void calculateMetrics()
{
mNumSamples = mRequestBenchInfos.size();
mTotalLatency = std::chrono::duration<float, std::milli>(mEnd - mStart).count();
mSeqThroughput = mNumSamples / (mTotalLatency / 1000);
mAvgSeqLatency = 0;
int totalOutputTokens = 0;
for (auto reqInfo : mRequestBenchInfos)
{
mAvgSeqLatency += reqInfo.second.latency;
totalOutputTokens += reqInfo.second.outputLength;
}
mAvgSeqLatency /= mNumSamples;
mTokenThroughput = totalOutputTokens / (mTotalLatency / 1000);
}
void report()
{
printf("[BENCHMARK] num_samples(ms) %d\n", mNumSamples);
printf("[BENCHMARK] total_latency(ms) %.2f\n", mTotalLatency);
printf("[BENCHMARK] seq_throughput(seq/sec) %.2f\n", mSeqThroughput);
printf("[BENCHMARK] avg_sequence_latency(ms) %.2f\n", mAvgSeqLatency);
printf("[BENCHMARK] token_throughput(token/sec) %.2f\n", mTokenThroughput);
}
private:
std::unordered_map<uint64_t, BenchInfo> mRequestBenchInfos;
std::chrono::time_point<std::chrono::steady_clock> mStart;
std::chrono::time_point<std::chrono::steady_clock> mEnd;
int mNumSamples;
float mTotalLatency;
float mSeqThroughput;
float mAvgSeqLatency;
float mTokenThroughput;
}; // class Recorder
class GptServer
{
public:
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
batch_scheduler::SchedulerPolicy schedulerPolicy, std::optional<int32_t> maxNumSequences,
std::optional<int32_t> maxTokensInPagedKvCache, std::optional<float> kvCacheFreeGpuMemFraction,
std::shared_ptr<Recorder> recorder)
{
const TrtGptModelOptionalParams& optionalParams
= TrtGptModelOptionalParams(maxNumSequences, maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction);
mBatchManager = std::make_shared<GptManager>(
trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
[this](int max_num_requests) { return getInferenceRequests(max_num_requests); },
[this](uint64_t requestId, std::list<NamedTensor> response_tensors, bool final_response,
const std::string& errMsg)
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
nullptr, nullptr, optionalParams);
mRecorder = recorder;
}
~GptServer()
{
mWorkItemsQueue.clear();
}
void enqueue(std::vector<NamedTensor> tensors, uint64_t requestId, bool streaming)
{
// Create InferenceRequest from a set of tensors
auto request = std::make_shared<InferenceRequest>(requestId);
if (requestId == -1)
{
mWorkItemsQueue.push(request, requestId);
return;
}
for (auto t : tensors)
{
request->emplaceInputTensor(t.name, std::move(t.tensor));
}
request->setIsStreaming(streaming);
// Enqueue
try
{
mRecorder->recordStart(request, requestId);
mWorkItemsQueue.push(request, requestId);
}
catch (const std::exception& e)
{
throw std::runtime_error(e.what());
}
return;
}
void waitForEmpty() const
{
while (mWorkItemsQueue.size() > 0)
{
}
}
void waitBatchManager() const
{
mBatchManager->waitUntilTerminate();
}
// Return up to max_num_requests inference requests.
std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(const int max_num_requests)
{
std::list<std::shared_ptr<InferenceRequest>> rval;
if (max_num_requests > 0)
{
auto world_size = getCommWorldSize();
auto rank = getCommWorldRank();
if (rank == 0)
{
int64_t num_new_work_items = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
static_cast<int64_t>(max_num_requests));
if (world_size > 1)
{
bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD);
}
if (num_new_work_items > 0)
{
int count = 0;
while (count < num_new_work_items)
{
auto [workItem, markedInProgress] = mWorkItemsQueue.pop();
if (markedInProgress)
{
rval.emplace_back(workItem->getInferenceRequest());
count++;
}
else
{
std::string warnStr = std::string("request Id ") + std::to_string(workItem->requestId())
+ std::string(" has been stopped. Request is ignored.");
TLLM_LOG_WARNING(warnStr);
sendResponse(workItem->requestId(), {}, true, warnStr);
}
}
if (world_size > 1)
{
std::vector<int64_t> packed;
for (auto ir : rval)
{
auto vpacked = ir->serialize();
packed.push_back(static_cast<int64_t>(vpacked.size()));
packed.insert(
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
}
bcast(packed, 0, COMM_WORLD);
}
}
}
else
{
// subordinate ranks hang until master rank sends work
int64_t num_new_work_items;
bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD);
if (num_new_work_items > 0)
{
std::vector<int64_t> packed;
bcast(packed, 0, COMM_WORLD);
int64_t* packed_ptr = packed.data();
for (int64_t count = 0; count < num_new_work_items; ++count)
{
int64_t n = *(packed_ptr++);
auto ir = InferenceRequest::deserialize(packed_ptr);
packed_ptr += n;
rval.emplace_back(ir);
}
}
}
}
return rval;
}
void sendResponse(uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
const std::string& errMsg)
{
std::string errStr = std::string("Failed to send response for requestId: ") + std::to_string(requestId);
try
{
if (final_response)
{
mWorkItemsQueue.markFinished(requestId);
mRecorder->recordEnd(requestId);
}
}
catch (const std::exception& e)
{
TLLM_LOG_ERROR(errStr);
}
}
private:
std::shared_ptr<GptManager> mBatchManager;
std::shared_ptr<Recorder> mRecorder;
WorkItemsQueue mWorkItemsQueue;
}; // class GptServer
namespace
{
std::pair<std::vector<std::vector<int32_t>>, std::vector<int32_t>> parseDataset(
std::filesystem::path const& datasetPath)
{
auto constexpr allowExceptions = true;
auto constexpr ignoreComments = true;
TLLM_CHECK_WITH_INFO(
std::filesystem::exists(datasetPath), std::string("File does not exist: ") + datasetPath.string());
std::ifstream jsonStream(datasetPath);
auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ignoreComments);
std::vector<std::vector<int32_t>> input_ids_list;
std::vector<int32_t> output_ids_list;
for (auto& sample : json)
{
input_ids_list.push_back(sample["input_ids"]);
output_ids_list.push_back(sample["output_len"]);
}
return std::make_pair(input_ids_list, output_ids_list);
}
void benchmarkGptManager(std::string const& modelName, std::filesystem::path const& engineDir, std::string const& type,
std::string const& datasetPath, std::shared_ptr<nvinfer1::ILogger> const& logger,
std::optional<int32_t> maxNumSequences, std::optional<int32_t> maxTokensInPagedKvCache,
std::optional<float> kvCacheFreeGpuMemFraction, batch_scheduler::SchedulerPolicy schedulerPolicy)
{
auto const worldConfig = WorldConfig::mpi(*logger);
TrtGptModelType modelType;
if (type == "V1")
{
modelType = TrtGptModelType::V1;
}
else if (type == "IFB")
{
modelType = TrtGptModelType::InflightFusedBatching;
}
else
{
const std::string errStr = std::string("Unexpected batching type: ") + type;
TLLM_LOG_ERROR(errStr);
// Bail out instead of continuing with an uninitialized modelType.
return;
}
const int maxBeamWidth = 1;
auto recorder = std::make_shared<Recorder>();
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, maxNumSequences,
maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, recorder);
auto dataset = parseDataset(datasetPath);
std::vector<std::vector<NamedTensor>> tensors_list;
const auto num_samples = dataset.first.size();
for (int i = 0; i < num_samples; ++i)
{
const auto input_ids = dataset.first[i];
const auto request_output_len = dataset.second[i];
std::vector<int64_t> input_ids_shape = {1, static_cast<int64_t>(input_ids.size())};
auto input_ids_tensor = NamedTensor(nvinfer1::DataType::kINT32, input_ids_shape, "input_ids", input_ids.data());
auto request_output_len_tensor
= NamedTensor(nvinfer1::DataType::kINT32, {1, 1}, "request_output_len", &request_output_len);
std::vector<NamedTensor> tensors = {input_ids_tensor, request_output_len_tensor};
tensors_list.push_back(tensors);
}
if (worldConfig.getRank() == 0)
{
recorder->initialize();
for (int i = 0; i < tensors_list.size(); ++i)
{
gptServer->enqueue(tensors_list[i], 1 + i, false);
}
gptServer->waitForEmpty();
recorder->finalize();
recorder->calculateMetrics();
recorder->report();
gptServer->enqueue({}, -1, false);
}
gptServer->waitBatchManager();
}
} // namespace
int main(int argc, char* argv[])
{
cxxopts::Options options(
"TensorRT-LLM BatchManager Benchmark", "TensorRT-LLM BatchManager Benchmark for GPT and GPT-like models.");
options.add_options()("h,help", "Print usage");
options.add_options()(
"m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
options.add_options()(
"type", "Batching type: IFB or V1(non-IFB) batching.", cxxopts::value<std::string>()->default_value("IFB"));
options.add_options()("dataset", "Dataset that is used for benchmarking BatchManager.",
cxxopts::value<std::string>()->default_value(""));
options.add_options()("max_num_sequences", "Max number of Sequences.", cxxopts::value<int>()->default_value("-1"));
options.add_options()(
"max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>()->default_value("-1"));
options.add_options()("kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.",
cxxopts::value<float>()->default_value("-1"));
options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_completion.",
cxxopts::value<std::string>()->default_value("guaranteed_completion"));
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
cxxopts::value<std::string>()->default_value("error"));
auto result = options.parse(argc, argv);
if (result.count("help"))
{
std::cout << options.help() << std::endl;
return 0;
}
// Argument: Engine directory
if (!result.count("engine_dir"))
{
std::cout << options.help() << std::endl;
TLLM_LOG_ERROR("Please specify engine directory.");
return 1;
}
// Argument: Batching Type
auto const type = result["type"].as<std::string>();
// Argument: Dataset
auto const datasetPath = result["dataset"].as<std::string>();
// Argument: Max Num Sequences
std::optional<int32_t> maxNumSequences = std::nullopt;
if (result["max_num_sequences"].as<int>() != -1)
{
maxNumSequences = result["max_num_sequences"].as<int>();
}
// Argument: Max tokens in paged K-V Cache
std::optional<int32_t> maxTokensInPagedKvCache = std::nullopt;
if (result["max_tokens_in_paged_kvcache"].as<int>() != -1)
{
maxTokensInPagedKvCache = result["max_tokens_in_paged_kvcache"].as<int>();
}
// Argument: K-V Cache Free Gpu Mem Fraction
std::optional<float> kvCacheFreeGpuMemFraction = std::nullopt;
if (result["kv_cache_free_gpu_mem_fraction"].as<float>() != -1)
{
kvCacheFreeGpuMemFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
}
// Argument: Scheduler policy
batch_scheduler::SchedulerPolicy schedulerPolicy;
auto const schedulerPolicyArg = result["scheduler_policy"].as<std::string>();
if (schedulerPolicyArg == "max_utilization")
{
schedulerPolicy = batch_scheduler::SchedulerPolicy::MAX_UTILIZATION;
}
else if (schedulerPolicyArg == "guaranteed_completion")
{
schedulerPolicy = batch_scheduler::SchedulerPolicy::GUARANTEED_COMPLETION;
}
else
{
TLLM_LOG_ERROR("Unexpected scheduler policy: " + schedulerPolicyArg);
return 1;
}
// Argument: Log level
auto logger = std::make_shared<TllmLogger>();
auto const logLevel = result["log_level"].as<std::string>();
if (logLevel == "verbose")
{
logger->setLevel(trt::ILogger::Severity::kVERBOSE);
}
else if (logLevel == "info")
{
logger->setLevel(trt::ILogger::Severity::kINFO);
}
else if (logLevel == "warning")
{
logger->setLevel(trt::ILogger::Severity::kWARNING);
}
else if (logLevel == "error")
{
logger->setLevel(trt::ILogger::Severity::kERROR);
}
else if (logLevel == "internal_error")
{
logger->setLevel(trt::ILogger::Severity::kINTERNAL_ERROR);
}
else
{
TLLM_LOG_ERROR("Unexpected log level: " + logLevel);
return 1;
}
initTrtLlmPlugins(logger.get());
try
{
benchmarkGptManager(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), type,
datasetPath, logger, maxNumSequences, maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, schedulerPolicy);
}
catch (const std::exception& e)
{
TLLM_LOG_ERROR(e.what());
return 1;
}
return 0;
}

View File

@ -58,7 +58,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
samplingConfig.topK = std::vector{1};
samplingConfig.topP = std::vector{0.0f};
GptSession session{modelConfig, worldConfig, enginePath, logger};
GptSession session{modelConfig, worldConfig, enginePath.string(), logger};
// Use bufferManager for copying data to and from the GPU
auto& bufferManager = session.getBufferManager();
session.setCudaGraphMode(cudaGraphMode);
@ -178,7 +178,8 @@ int main(int argc, char* argv[])
if (!result.count("engine_dir"))
{
std::cout << options.help() << std::endl;
throw std::invalid_argument("Please specify engine directory.");
TLLM_LOG_ERROR("Please specify engine directory.");
return 1;
}
// Argument: Batch sizes
@ -230,7 +231,8 @@ int main(int argc, char* argv[])
}
else
{
throw std::invalid_argument("Unexpected log level: " + logLevel);
TLLM_LOG_ERROR("Unexpected log level: " + logLevel);
return 1;
}
// Argument: Enable CUDA graph
@ -238,8 +240,16 @@ int main(int argc, char* argv[])
initTrtLlmPlugins(logger.get());
benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inOutLen,
logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(),
enableCudaGraph);
try
{
benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes,
inOutLen, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(),
enableCudaGraph);
}
catch (const std::exception& e)
{
TLLM_LOG_ERROR(e.what());
return 1;
}
return 0;
}

View File

@ -0,0 +1,63 @@
#!/usr/bin/python
import argparse
import json
from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset',
type=str,
required=True,
help='Dataset path used for the test.')
parser.add_argument('--max_input_len',
type=int,
required=True,
help='Specify max input length')
parser.add_argument('--tokenizer_dir',
type=str,
required=True,
help='Specify tokenizer directory')
parser.add_argument('--tokenizer_type',
type=str,
default='auto',
required=False,
choices=['auto', 't5', 'llama'],
help='Specify tokenizer type')
parser.add_argument('--output',
type=str,
default='preprocessed_dataset.json',
help='Preprocessed dataset path.')
FLAGS = parser.parse_args()
if FLAGS.tokenizer_type == 't5':
tokenizer = T5Tokenizer(vocab_file=FLAGS.tokenizer_dir,
padding_side='left')
elif FLAGS.tokenizer_type == 'auto':
tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer_dir,
padding_side='left')
elif FLAGS.tokenizer_type == 'llama':
tokenizer = LlamaTokenizer.from_pretrained(FLAGS.tokenizer_dir,
legacy=False,
padding_side='left')
else:
raise AttributeError(
f'Unexpected tokenizer type: {FLAGS.tokenizer_type}')
tokenizer.pad_token = tokenizer.eos_token
results = []
with open(FLAGS.dataset, 'r') as f:
data_dict = json.load(f)
for req in data_dict:
prompt = req['input'] + ' ' + req['instruction']
output = req['output']
line = tokenizer.encode(prompt)
if len(line) > FLAGS.max_input_len:
continue
# 1.3 is a rough words-to-tokens ratio used to estimate the output token count
output_len = int(len(output.split(' ')) * 1.3)
results.append({'input_ids': line, 'output_len': output_len})
with open(FLAGS.output, 'w') as f:
json.dump(results, f)

View File

@ -190,6 +190,13 @@ def parse_arguments():
default=False,
action='store_true',
help='Execute GPT session with CUDA graph.')
parser.add_argument(
'--enable_custom_all_reduce',
default=False,
action='store_true',
help=
'Use latency-optimized all-reduce for tensor parallelism. Gives better performance with NVLink.'
)
return parser.parse_args()
@ -209,25 +216,27 @@ def main(args):
for io in in_out_len_options]
if args.model in get_allowed_models(benchmark_type="gpt"):
benchmarker = GPTBenchmark(args.engine_dir,
args.model,
args.mode,
batch_size_options,
in_out_len_options,
args.dtype,
args.refit,
args.num_beams,
args.top_k,
args.top_p,
args.output_dir,
args.n_positions,
args.max_input_len,
args.max_output_len,
args.max_batch_size,
force_num_layer_1=args.force_num_layer_1,
enable_fp8=args.enable_fp8,
fp8_kv_cache=args.fp8_kv_cache,
enable_cuda_graph=args.enable_cuda_graph)
benchmarker = GPTBenchmark(
args.engine_dir,
args.model,
args.mode,
batch_size_options,
in_out_len_options,
args.dtype,
args.refit,
args.num_beams,
args.top_k,
args.top_p,
args.output_dir,
args.n_positions,
args.max_input_len,
args.max_output_len,
args.max_batch_size,
force_num_layer_1=args.force_num_layer_1,
enable_fp8=args.enable_fp8,
fp8_kv_cache=args.fp8_kv_cache,
enable_cuda_graph=args.enable_cuda_graph,
enable_custom_all_reduce=args.enable_custom_all_reduce)
elif args.model in get_allowed_models(benchmark_type="bert"):
benchmarker = BERTBenchmark(args.engine_dir,
args.model,

View File

@ -49,6 +49,7 @@ class GPTBenchmark(BaseBenchmark):
max_input_len=None,
max_output_len=None,
max_batch_size=None,
enable_custom_all_reduce=None,
**kwargs):
super().__init__(engine_dir, model_name, dtype, output_dir)
self.batch_sizes = batch_sizes
@ -60,6 +61,7 @@ class GPTBenchmark(BaseBenchmark):
self.fuse_bias = True
self.cuda_graph_mode = kwargs.get('enable_cuda_graph', False)
self.enable_custom_all_reduce = enable_custom_all_reduce
if engine_dir is not None:
# Get build configs from engine directory is done in base class
@ -85,12 +87,9 @@ class GPTBenchmark(BaseBenchmark):
plg_dtype = dtype if is_plugin_mode else False
self.use_gpt_attention_plugin = plg_dtype
self.use_gemm_plugin = plg_dtype
self.use_layernorm_plugin = plg_dtype
# Enable RMS Norm plugin for the LLaMA family.
if is_plugin_mode and 'llama' in model_name:
self.use_rmsnorm_plugin = dtype
else:
self.use_rmsnorm_plugin = False
# Starting with TRT 9.1, the OOTB norm layer performs better than the plugin norm layer
self.use_layernorm_plugin = False
self.use_rmsnorm_plugin = False
self.use_lookup_plugin = plg_dtype
self.enable_context_fmha = True
self.quant_mode = QuantMode(0)
@ -408,7 +407,8 @@ class GPTBenchmark(BaseBenchmark):
network.plugin_config.set_rmsnorm_quantization_plugin()
if self.world_size > 1:
network.plugin_config.set_nccl_plugin(self.dtype)
network.plugin_config.set_nccl_plugin(self.dtype,
self.enable_custom_all_reduce)
# Use the plugin for the embedding parallelism and sharing
network.plugin_config.set_lookup_plugin(dtype=self.use_lookup_plugin)
@ -434,8 +434,7 @@ class GPTBenchmark(BaseBenchmark):
self.build_time = round(end - start, 2)
if self.output_dir is not None:
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
os.makedirs(self.output_dir, exist_ok=True)
self.serialize_path = os.path.join(self.output_dir,
self.engine_name)
serialize_engine(engine, self.serialize_path)

View File

@ -101,9 +101,6 @@ if(CMAKE_CUDA_COMPILER)
else()
message(FATAL_ERROR "Failed to determine CUDA version")
endif()
# Export shared libs as both `.lib` and `.dll` to avoid linking errors.
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
else()
message(FATAL_ERROR "No CUDA compiler found")
@ -166,7 +163,8 @@ get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_SOURCE_DIR} PATH)
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
include_directories(
${CUDA_INCLUDE_DIRS} ${CUDNN_ROOT_DIR}/include ${NCCL_INCLUDE_DIR}
${3RDPARTY_DIR}/cutlass/include ${3RDPARTY_DIR}/NVTX/include)
${3RDPARTY_DIR}/cutlass/include ${3RDPARTY_DIR}/NVTX/include
${3RDPARTY_DIR}/json/include)
# TRT dependencies
set_ifndef(TRT_LIB_DIR ${CMAKE_BINARY_DIR})
@ -195,6 +193,7 @@ endif()
# it's not called before "CMAKE_CXX_FLAGS" is set, it breaks on Windows for some
# reason, so we just call it here as a workaround.
find_package(MPI REQUIRED)
add_definitions("-DOMPI_SKIP_MPICXX")
# C++17
set(CMAKE_CXX_STANDARD 17)
@ -214,6 +213,12 @@ else()
set(CMAKE_CXX_FLAGS "/wd4996 ${CMAKE_CXX_FLAGS}")
endif()
# A Windows header file defines max() and min() macros, which break our macro
# declarations.
if(WIN32)
set(CMAKE_CXX_FLAGS "/DNOMINMAX ${CMAKE_CXX_FLAGS}")
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
@ -226,10 +231,12 @@ if(BUILD_PYT)
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
if(CUDA_ARCH MATCHES "^([0-9])([0-9])(-real)*$")
set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}")
elseif(CUDA_ARCH STREQUAL "native")
set(TORCH_ARCH "Auto")
else()
message(FATAL_ERROR "${CUDA_ARCH} is not supported")
endif()
if(NOT CUDA_ARCH MATCHES "-real$")
if(NOT CUDA_ARCH MATCHES "-real$" AND NOT CUDA_ARCH STREQUAL "native")
string(APPEND TORCH_ARCH "+PTX")
endif()
list(APPEND TORCH_CUDA_ARCH_LIST ${TORCH_ARCH})
@ -303,8 +310,9 @@ message(
"Building for TensorRT version: ${TRT_VERSION}, library version: ${TRT_SOVERSION}"
)
list(APPEND COMMON_HEADER_DIRS ${TORCH_INCLUDE_DIRS} ${TRT_INCLUDE_DIR})
list(APPEND COMMON_HEADER_DIRS)
include_directories(${COMMON_HEADER_DIRS})
include_directories(SYSTEM ${TORCH_INCLUDE_DIRS} ${TRT_INCLUDE_DIR})
add_subdirectory(tensorrt_llm)

View File

@ -32,7 +32,8 @@ enum class BatchManagerErrorCode_t
{
STATUS_SUCCESS = 0,
STATUS_FAILED = 1,
STATUS_NO_WORK = 2
STATUS_NO_WORK = 2,
STATUS_TERMINATE = 3
};
enum class TrtGptModelType

View File

@ -39,8 +39,8 @@ namespace tensorrt_llm::batch_manager
class InferenceRequest;
class TrtGptModel;
/* Responsible for shepherding Triton requests through to completion
using TRT Backend. Each Triton backend should have just one of these. */
/* Responsible for shepherding requests through to completion
using TRT Backend. */
class GptManager
{
public:
@ -50,6 +50,7 @@ public:
GptManager(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr,
ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
const TrtGptModelOptionalParams& optionalParams = TrtGptModelOptionalParams());
/* Wraps the user-provided callback for requests.
@ -57,20 +58,21 @@ public:
Invoked every generation loop iteration. */
BatchManagerErrorCode_t fetchNewRequests();
/* Does the following:
1. Returns completed requests to Triton
2. Deletes entry from activeRequests */
/* Returns completed requests.
Deletes entry from activeRequests */
BatchManagerErrorCode_t returnCompletedRequests();
BatchManagerErrorCode_t pollStopSignals();
BatchManagerErrorCode_t returnBatchManagerStats();
BatchManagerErrorCode_t waitUntilTerminate();
virtual ~GptManager();
protected:
/* Does the following:
1. Maps batch manager requests to backend request
2. Invokes one step of backend
3. Updates state of all requests */
/* Invokes one step of backend
Updates state of all requests */
virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds);
private:
@ -84,6 +86,8 @@ private:
SizeType mMaxOutputLen;
SizeType mMaxNumSequences;
// Iteration counter - incremented every iteration of the generation loop
int64_t mIterationCounter;
// List of live requests
RequestList mActiveRequests;
// IDs of live requests
@ -92,6 +96,7 @@ private:
GetInferenceRequestsCallback mGetInferenceRequestsCb;
SendResponseCallback mSendResponseCb;
PollStopSignalCallback mPollStopSignalCb;
ReturnBatchManagerStatsCallback mReturnBatchManagerStatsCb;
std::atomic<bool> destructor_called_;
void decoupled_execution_loop();

View File

@ -31,5 +31,7 @@ class NamedTensor;
using GetInferenceRequestsCallback = std::function<std::list<std::shared_ptr<InferenceRequest>>(int32_t)>;
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
using PollStopSignalCallback = std::function<std::unordered_set<uint64_t>()>;
// json of stats as a string
using ReturnBatchManagerStatsCallback = std::function<void(const std::string&)>;
} // namespace tensorrt_llm::batch_manager

View File

@ -27,6 +27,7 @@
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <tuple>
#include <vector>

View File

@ -21,6 +21,7 @@
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <NvInferRuntime.h>
#include <cstdint>
@ -175,8 +176,8 @@ public:
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
KVCacheManager(SizeType numLayers, SizeType numHeads, SizeType numKvHeads, SizeType hiddenSize,
SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxBatchSize, nvinfer1::DataType dtype,
CudaStreamPtr stream);
SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxBlocksPerSeq, nvinfer1::DataType dtype, CudaStreamPtr stream);
[[nodiscard]] SizeType getTokensPerBlock() const
{
@ -221,11 +222,10 @@ public:
void removeSequence(SizeType batchSlotIdx);
[[nodiscard]] std::vector<runtime::ITensor::UniquePtr> getBlockPointersOfSlot(
SizeType batchSlotIdx, SizeType beamWidth, SizeType maxBlocksPerSeq) const;
void getBlockPointersOfBatch(runtime::ITensor::SharedPtr dstPointers, SizeType batchSize, SizeType beamWidth) const;
[[nodiscard]] runtime::ITensor::UniquePtr getBlockPointersOfBatch(
SizeType batchSize, SizeType beamWidth, SizeType maxBlocksPerSeq) const;
void copyBlockPointers(runtime::ITensor::SharedPtr dstPointers, SizeType dstSlotOffset, SizeType batchSlotIdx,
SizeType beamWidth) const;
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
[[nodiscard]] static SizeType constexpr calculatePageSize(tensorrt_llm::runtime::GptModelConfig const& modelConfig)
@ -235,11 +235,15 @@ public:
// numLayers * 2 * numKvHeads * sizePerHead
[[nodiscard]] static SizeType constexpr calculateCacheSizePerToken(
tensorrt_llm::runtime::GptModelConfig const& modelConfig)
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig)
{
return modelConfig.getNbLayers() * 2 * modelConfig.getNbKvHeads() * modelConfig.getSizePerHead();
return modelConfig.getNbLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads()
* modelConfig.getSizePerHead();
}
private:
void cacheNewBlockPointer(const GenerationRequest& seq, SizeType batchSlotIdx);
private:
// Number of elements per block
SizeType mBlockSize;
@ -247,14 +251,20 @@ private:
SizeType mTokensPerBlock;
// Total maximum number of blocks
SizeType mMaxNumBlocks;
// Maximum size of batch
SizeType mMaxBatchSize;
// Maximum beam width
SizeType mMaxBeamWidth;
// Maximum number of blocks per sequence
SizeType mMaxBlocksPerSeq;
// Pools
std::vector<runtime::ITensor::SharedPtr> mPools;
// Block manager
BlockManager mBlockManager;
// List of all sequences
std::vector<SequencesPtr> mSequences;
// buffer for block pointers for all batch slots
std::vector<runtime::ITensor::UniquePtr> mAllBlockPointers;
// buffer for block pointers for all managed sequences
runtime::ITensor::SharedPtr mSequenceBlockPointers;
runtime::BufferManager mManager;
};

View File

@ -37,22 +37,24 @@ enum LlmRequestState_t
class LlmRequest
{
public:
using BeamTokens = std::vector<std::vector<int32_t>>;
using SizeType = runtime::SizeType;
using TokenIdType = runtime::TokenIdType;
using RequestIdType = std::uint64_t;
using BeamTokens = std::vector<std::vector<TokenIdType>>;
LlmRequest(uint64_t requestId, int32_t maxNewTokens, std::shared_ptr<std::vector<int32_t>> input_tokens,
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> input_tokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt)
: mRequestId(requestId)
, mPromptLen(input_tokens->size())
, mMaxNewTokens(maxNewTokens)
, mSamplingConfig(samplingConfig)
, mPromptLen(input_tokens->size())
, mNumGeneratedTokens(0)
, mState(REQUEST_STATE_CONTEXT_INIT)
, mIsStreaming(isStreaming)
, mEndId(endId)
, mPadId(padId)
, mBatchSlot(-1)
, mNumGeneratedTokens(0)
{
mMaxSentTokenPos = mPromptLen - 1;
// Scatter the input tokens to other beam
@ -62,7 +64,7 @@ public:
/// @brief Get total number of tokens for this req (prompt + generated)
/// @param beam The beam index
/// @return The number of tokens
int32_t getNumTokens(int beam) const
SizeType getNumTokens(SizeType beam) const
{
return mTokens->at(beam).size();
}
@ -71,7 +73,7 @@ public:
/// @param beam The beam index
/// @param pos The position of the token relative to beginning of the prompt
/// @return The token index
int32_t getToken(int beam, int pos) const
TokenIdType getToken(SizeType beam, SizeType pos) const
{
return mTokens->at(beam).at(pos);
}
@ -79,14 +81,14 @@ public:
/// @brief Get the tokens at a given beam index
/// @param beam The beam index
/// @return A vector of tokens for this beam index, includes the prompt
std::vector<int32_t> getTokens(int beam) const
std::vector<TokenIdType> getTokens(SizeType beam) const
{
return mTokens->at(beam);
}
/// @brief Get the number of generated tokens
/// @return The number of generated tokens (doesn't include the prompt tokens)
int32_t getNumGeneratedTokens() const
SizeType getNumGeneratedTokens() const
{
return mNumGeneratedTokens;
}
@ -94,10 +96,10 @@ public:
/// @brief Add new generated tokens to the vector of tokens
/// @param beamTokens A vector containing the tokens to add for each beam index
/// beamTokens is expected to be of size beamWidth
void addNewTokens(const std::vector<int32_t>& beamTokens)
void addNewTokens(const std::vector<TokenIdType>& beamTokens)
{
assert(mSamplingConfig.beamWidth == beamTokens.size());
for (int beam = 0; beam < beamTokens.size(); ++beam)
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
{
mTokens->at(beam).push_back(beamTokens[beam]);
}
@ -109,7 +111,7 @@ public:
void setGeneratedTokens(const BeamTokens& generatedBeamTokens)
{
assert(generatedBeamTokens.size() == mSamplingConfig.beamWidth);
for (int beam = 0; beam < generatedBeamTokens.size(); ++beam)
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
{
auto& beamTokens = (*mTokens)[beam];
beamTokens.resize(mPromptLen);
@ -151,7 +153,7 @@ public:
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to client
/// duplicated token positions.
/// @return The maximum position of the tokens sent to the client
int32_t getMaxSentTokenPos() const
SizeType getMaxSentTokenPos() const
{
return mMaxSentTokenPos;
}
@ -159,28 +161,26 @@ public:
/// @brief Sets the maximum position of the tokens returned to the client. Use to ensure we don't return to client
/// duplicated token positions.
/// @param pos The maximum position
void setMaxSentTokenPos(int32_t pos)
void setMaxSentTokenPos(SizeType pos)
{
mMaxSentTokenPos = pos;
}
uint64_t mRequestId;
int32_t mMaxNewTokens;
RequestIdType mRequestId;
SizeType mPromptLen;
SizeType mMaxNewTokens;
// Tokens [beam_size, mPromptLen + mNumGeneratedTokens]
runtime::SamplingConfig mSamplingConfig;
int32_t mPromptLen;
LlmRequestState_t mState;
bool mIsStreaming;
std::optional<SizeType> mEndId;
std::optional<SizeType> mPadId;
int32_t mBatchSlot;
~LlmRequest() {}
SizeType mBatchSlot;
private:
std::shared_ptr<BeamTokens> mTokens;
int32_t mNumGeneratedTokens;
int32_t mMaxSentTokenPos;
SizeType mNumGeneratedTokens;
SizeType mMaxSentTokenPos;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -90,10 +90,22 @@ public:
void setZero(IBuffer& buffer) const;
//! \brief Copy `src` to `dst`.
void copy(void const* src, IBuffer& dst) const;
void copy(void const* src, IBuffer& dst, MemoryType srcType) const;
//! \brief Copy `src` to `dst`.
void copy(IBuffer const& src, void* dst) const;
void copy(IBuffer const& src, void* dst, MemoryType dstType) const;
//! \brief Copy `src` to `dst`.
void copy(void const* src, IBuffer& dst) const
{
return copy(src, dst, IBuffer::memoryType(src));
}
//! \brief Copy `src` to `dst`.
void copy(IBuffer const& src, void* dst) const
{
return copy(src, dst, IBuffer::memoryType(dst));
}
//! \brief Copy `src` to `dst`.
void copy(IBuffer const& src, IBuffer& dst) const;

View File

@ -0,0 +1,102 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <cuda_runtime_api.h>
#include <memory>
namespace tensorrt_llm::runtime
{
class CudaEvent
{
public:
using pointer = cudaEvent_t;
//! Creates a new cuda event. The event will be destroyed in the destructor.
//!
//! \param flags Flags for event creation. By default, event timing is disabled.
explicit CudaEvent(unsigned int flags = cudaEventDisableTiming)
{
pointer event;
TLLM_CUDA_CHECK(::cudaEventCreate(&event, flags));
TLLM_LOG_TRACE("Created event %p", event);
bool constexpr ownsEvent{true};
mEvent = EventPtr{event, Deleter{ownsEvent}};
}
//! Pass an existing cuda event to this object.
//!
//! \param event The event to pass to this object.
//! \param ownsEvent Whether this object owns the event and destroys it in the destructor.
explicit CudaEvent(pointer event, bool ownsEvent = true)
{
TLLM_CHECK_WITH_INFO(event != nullptr, "event is nullptr");
mEvent = EventPtr{event, Deleter{ownsEvent}};
}
//! Returns the event associated with this object.
[[nodiscard]] pointer get() const
{
return mEvent.get();
}
//! \brief Synchronizes the event.
void synchronize() const
{
TLLM_CUDA_CHECK(::cudaEventSynchronize(get()));
}
private:
class Deleter
{
public:
explicit Deleter(bool ownsEvent)
: mOwnsEvent{ownsEvent}
{
}
explicit Deleter()
: Deleter{true}
{
}
constexpr void operator()(pointer event) const
{
if (mOwnsEvent && event != nullptr)
{
TLLM_CUDA_CHECK(::cudaEventDestroy(event));
TLLM_LOG_TRACE("Destroyed event %p", event);
}
}
private:
bool mOwnsEvent;
};
using element_type = std::remove_pointer_t<pointer>;
using EventPtr = std::unique_ptr<element_type, Deleter>;
EventPtr mEvent;
};
} // namespace tensorrt_llm::runtime

View File

@ -19,6 +19,7 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include <cuda_runtime_api.h>
@ -78,17 +79,29 @@ public:
}
//! \brief Record an event on the stream.
void record(tensorrt_llm::common::EventPtr::pointer event)
void record(CudaEvent::pointer event) const
{
TLLM_CUDA_CHECK(::cudaEventRecord(event, get()));
}
//! \brief Record an event on the stream.
void record(CudaEvent const& event) const
{
record(event.get());
}
//! \brief Wait for an event.
void wait(tensorrt_llm::common::EventPtr::pointer event)
void wait(CudaEvent::pointer event) const
{
TLLM_CUDA_CHECK(::cudaStreamWaitEvent(get(), event));
}
//! \brief Wait for an event.
void wait(CudaEvent const& event) const
{
wait(event.get());
}
private:
class Deleter
{

View File

@ -43,7 +43,8 @@ public:
TensorPtr ids; // [batchSize, beamWidth, maxInputLength + maxNewTokens]
// optional parameters
TensorPtr logProbs; // [request_output_length, batch_size * beam_width], must be float*, on gpu
TensorPtr logProbs; // [request_output_length, batch_size * beam_width], must be float*, on gpu
TensorPtr contextLogits; // [batch_size, max_input_length, vocab_size_padded]
// callbacks
Callback onTokenGenerated;

View File

@ -18,6 +18,7 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/runtime/iGptDecoderBatch.h"
@ -51,19 +52,27 @@ public:
void newBatch(GenerationInput const& inputs, SamplingConfig const& samplingConfig) override;
//! @brief Run one step for all requests.
//! Note that this method will synchronize with the stream associated with the decoder.
void forward(decoder_batch::Output& output, decoder_batch::Input const& input) override;
TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
bool forward(decoder::Output& output, decoder::Input const& input) override;
void forwardSync(decoder_batch::Token const& e) override;
//! @brief Gather final results for request `batchIdx`.
void postProcessRequest(SizeType batchIdx) const override;
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
bool isFinishedSync() override;
//! @return [batchSize], indicators of finished requests
[[nodiscard]] std::vector<bool> getFinished() const override
{
return std::vector<bool>(mFinished.begin(), mFinished.begin() + mActualBatchSize);
return {mFinished.begin(), mFinished.begin() + mActualBatchSize};
}
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
//! padding for request `batchIdx`, on gpu
[[nodiscard]] TensorPtr getOutputIds(SizeType batchIdx) const override
{
auto tensor = ITensor::slice(mJointDecodingOutput->ids, batchIdx, 1);
tensor->squeeze(0);
return tensor;
}
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
@ -73,6 +82,14 @@ public:
return ITensor::slice(mJointDecodingOutput->ids, 0, mActualBatchSize);
}
//! Execute postProcessRequest and returns OutputIds for request `batchIdx`.
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
//! padding for request `batchIdx`, on gpu
[[nodiscard]] TensorPtr getFinalOutputIds(SizeType batchIdx) const override;
//! Execute postProcessRequest and returns OutputIds.
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
//! ids without padding, on gpu
[[nodiscard]] TensorPtr getFinalOutputIds() const override;
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains parent ids collected during beam
@ -112,15 +129,25 @@ public:
return std::vector<SizeType>(mNbSteps.begin(), mNbSteps.begin() + mActualBatchSize);
}
//! @returns [1], number of finished sequences, in pinned host memory
[[nodiscard]] TensorPtr getNbFinished() const override
{
return mFinishedSum;
}
private:
//! @brief Gather final results for request `batchIdx`.
void postProcessRequest(SizeType batchIdx) const;
private:
std::size_t const mVocabSize;
std::size_t const mVocabSizePadded;
CudaStreamPtr mStream;
BufferManager mBufferManager;
tensorrt_llm::common::EventPtr mEventStart, mEventStop;
TokenPtr mForwardToken;
CudaEvent mForwardEvent;
std::vector<CudaStreamPtr> mStreams;
std::vector<tensorrt_llm::common::EventPtr> mEvents;
using GptDecoderPtr = std::unique_ptr<IGptDecoder>;
std::vector<GptDecoderPtr> mDecoders;
using DecodingInputPtr = std::unique_ptr<DecodingInput>;
@ -133,6 +160,7 @@ private:
std::vector<SizeType> mNbSteps;
std::vector<bool> mFinished;
TensorPtr mFinishedSum;
std::vector<SizeType> mMaxNewTokens;
std::vector<SizeType> mBeamWidths;
SizeType mMaxSequenceLength{};

View File

@ -42,6 +42,7 @@ public:
, mMaxBatchSize(0)
, mMaxInputLen(0)
, mMaxOutputLen(0)
, mComputeContextLogits(false)
{
}
@ -176,6 +177,16 @@ public:
mMaxOutputLen = maxOutputLen;
}
[[nodiscard]] bool constexpr computeContextLogits() const noexcept
{
return mComputeContextLogits;
}
void constexpr computeContextLogits(bool computeContextLogits) noexcept
{
mComputeContextLogits = computeContextLogits;
}
private:
SizeType mVocabSize;
SizeType mNbLayers;
@ -191,6 +202,7 @@ private:
SizeType mMaxBatchSize;
SizeType mMaxInputLen;
SizeType mMaxOutputLen;
bool mComputeContextLogits;
};
} // namespace tensorrt_llm::runtime

View File

@ -18,6 +18,7 @@
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
@ -99,19 +100,52 @@ public:
mCudaGraphMode = value;
}
void setup(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, bool decoderPerRequest,
std::optional<SizeType> maxTokensInPagedKvCache = std::nullopt);
//! @brief Initialize buffers for the given sizes.
//! `generate` may be called with batch size and beam width smaller than the setup parameters.
//! @details `maxBatchSize` will be divided by the number of micro batches to initialize each batch buffer.
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, bool decoderPerRequest,
std::optional<SizeType> maxTokensInPagedKvCache = std::nullopt,
std::optional<SizeType> numMicroBatches = std::nullopt);
void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig);
void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
if (mNumMicroBatches == 1)
generateSingleBatch(outputs, inputs, samplingConfig);
else
generateMultiBatch(outputs, inputs, samplingConfig);
}
private:
void generateSingleBatch(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig);
void generateMultiBatch(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig);
using KvCacheManager = batch_manager::kv_cache_manager::KVCacheManager;
void createContexts();
void createDecoder(bool decoderPerRequest);
void createContexts(SizeType numMicroBatches);
void createBuffers(SizeType numMicroBatches);
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
void createKvCacheManagers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
SizeType numMicroBatches, std::optional<SizeType> maxTokensInPagedKvCache);
bool executeDecoderStep(ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep);
void finalizeOutputIds(ITensor& outputIds);
//! @brief Execute decoder on last PP rank, receive decoder output on other PP ranks.
void decoderStepAsync(
ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, SizeType microBatchId);
//! @brief Synchronize with the decoder and return the `shouldStop` flag.
bool shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType microBatchId);
//! @brief Collect final output ids on last PP rank and send them to first PP rank.
//! @details Receives are asynchronous on host, so synchronization is required before access.
void finalizeOutputIds(ITensor& outputIds, SizeType microBatchId);
void kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId);
ITensor::SharedPtr initNewTokens(
GenerationInput const& inputs, SamplingConfig const& samplingConfig, SizeType microBatchId);
class CudaGraphExecutor
{
@ -153,15 +187,20 @@ private:
WorldConfig const mWorldConfig;
int mDevice{-1};
std::shared_ptr<NcclCommunicator> mPipelineComm;
std::shared_ptr<CudaStream> mCommStream;
CudaEvent mCommEvent{};
SizeType mDecoderMaxSequenceLength{};
LoggerPtr mLogger;
std::shared_ptr<TllmRuntime> mRuntime;
std::shared_ptr<IStatefulGptDecoder> mDecoder;
std::shared_ptr<RuntimeBuffers> mBuffers;
std::shared_ptr<KvCacheManager> mKvCacheManager;
SizeType mNumMicroBatches;
// for each micro batch
std::vector<std::shared_ptr<IStatefulGptDecoder>> mDecoders;
std::vector<std::shared_ptr<RuntimeBuffers>> mBuffers;
std::vector<std::shared_ptr<KvCacheManager>> mKvCacheManagers;
std::vector<CudaEvent> mReceivedEvents;
bool mCudaGraphMode{false};
// ping-pong instances

View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
@ -41,7 +42,6 @@ public:
: ids{std::move(ids)}
, maxNewTokens{maxNewTokens}
, endId{endId}
, padId{padId}
{
}
@ -51,7 +51,6 @@ public:
// optional parameters
std::optional<SizeType> maxNewTokens; // maximum number of tokens to generate for this request
std::optional<SizeType> endId; // end token id
std::optional<SizeType> padId; // pad token id
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
TensorPtr badWordsList; // [2, badWordsLength], on gpu
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
@@ -82,6 +81,19 @@ public:
};
using Output = decoder::Output;
class Token
{
public:
explicit Token(CudaEvent&& event, std::vector<bool> const& active)
: event(std::move(event))
, active(active)
{
}
CudaEvent event;
std::vector<bool> active;
};
} // namespace decoder_batch
//! GPT decoder class with support for in-flight batching
@@ -90,17 +102,33 @@ class IGptDecoderBatch : public virtual IStatefulGptDecoder
public:
using CudaStreamPtr = std::shared_ptr<CudaStream>;
using TensorPtr = std::shared_ptr<ITensor>;
using TokenPtr = std::unique_ptr<decoder_batch::Token const>;
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
virtual void newRequest(
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig)
= 0;
//! @brief Run one step for all requests.
virtual void forward(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
//! @brief Gather final results for request `batchIdx`.
virtual void postProcessRequest(SizeType batchIdx) const = 0;
//! @brief Wait for the call to `forwardAsync` associated with a token to complete.
virtual void forwardSync(decoder_batch::Token const& token) = 0;
//! @brief Run one step for all requests and wait for completion on the host.
virtual void forward(decoder_batch::Output& output, decoder_batch::Input const& input)
{
forwardSync(*forwardAsync(output, input));
}
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
//! ids without padding for request `batchIdx`, on gpu
virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;
//! @brief Execute postProcessRequest and return the output ids for request `batchIdx`.
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
//! padding for request `batchIdx`, on gpu
virtual TensorPtr getFinalOutputIds(SizeType batchIdx) const = 0;
//! @returns [batchSize, beamWidth], marks finished requests (per beam), on gpu
virtual TensorPtr getFinishedBeams() const = 0;
@@ -108,6 +136,9 @@ public:
//! @returns [batchSize, beamWidth], total sequence lengths (per beam), on gpu
virtual TensorPtr getOutputLengths() const = 0;
//! @returns [batchSize (actual)], marks finished requests (per batch)
virtual std::vector<bool> getFinished() const = 0;
//! @returns [batchSize, beamWidth], cumulative log probabilities (per beam), on gpu
virtual TensorPtr getCumLogProbs() const = 0;
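The `forwardAsync`/`forwardSync` pair above replaces a purely blocking step. A minimal sketch of the call pattern it enables, assuming `decoder` is a concrete `IGptDecoderBatch` and `output`/`input` are already prepared; `doHostWork()` is a hypothetical placeholder for work overlapped with the device:
```
// Sketch only: one in-flight-batching decode step with host/device overlap.
auto token = decoder.forwardAsync(output, input); // enqueue the step, returns a Token handle
doHostWork();                                     // hypothetical: overlapped host-side work
decoder.forwardSync(*token);                      // wait for the step recorded by `token`
// The default forward() above is exactly forwardSync(*forwardAsync(output, input)).
```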

View File

@@ -80,21 +80,31 @@ public:
//! @brief Initialize the decoder with new batch of inputs.
virtual void newBatch(GenerationInput const& inputs, SamplingConfig const& samplingConfig) = 0;
//! @brief Run one step for all requests without blocking the host thread.
virtual void forwardAsync(decoder::Output& output, decoder::Input const& input) = 0;
//! @brief Wait for the last call to `forwardAsync` to complete and return whether all sequences have finished.
virtual bool isFinishedSync() = 0;
//! @brief Run one step for all requests.
virtual bool forward(decoder::Output& output, decoder::Input const& input) = 0;
virtual bool forward(decoder::Output& output, decoder::Input const& input)
{
forwardAsync(output, input);
return isFinishedSync();
}
//! @brief Gather final results for all requests.
virtual TensorPtr getFinalOutputIds() const = 0;
// TODO: do we need that?
virtual std::vector<bool> getFinished() const = 0;
//! @returns [batchSize, beamWidth, maxSequenceLength], all token ids, on gpu
virtual TensorPtr getOutputIds() const = 0;
//! @returns [batchSize, beamWidth], latest generated tokens (per beam), on gpu
virtual TensorPtr getNewTokens() const = 0;
//! @returns [1], number of finished sequences, in pinned host memory
virtual TensorPtr getNbFinished() const = 0;
protected:
IStatefulGptDecoder() = default;
};
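A hedged sketch of the generation loop these methods imply, assuming `decoder` is a concrete `IStatefulGptDecoder` and that `inputs`, `samplingConfig`, `output`, and `input` have been prepared by the caller:
```
// Sketch only: drive the stateful decoder until all sequences have finished.
decoder.newBatch(inputs, samplingConfig);          // initialize with a new batch of inputs
bool finished = false;
while (!finished)
{
    decoder.forwardAsync(output, input);           // enqueue one step without blocking the host
    // ... overlapped host-side work could go here ...
    finished = decoder.isFinishedSync();           // synchronize, then check completion
}
auto finalOutputIds = decoder.getFinalOutputIds(); // gather final results for all requests
```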

View File

@@ -30,7 +30,7 @@ class MemoryCounters
{
public:
using SizeType = std::size_t;
using DiffType = std::int64_t;
using DiffType = std::ptrdiff_t;
MemoryCounters() = default;

View File

@@ -95,6 +95,8 @@ public:
[[nodiscard]] std::vector<SizeType> getPipelineParallelGroup() const;
static bool validConfig(nvinfer1::ILogger& logger, SizeType tensorParallelism, SizeType pipelineParallelism);
static WorldConfig mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode = kDefaultGpusPerNode,
std::optional<SizeType> tensorParallelism = std::nullopt,
std::optional<SizeType> pipelineParallelism = std::nullopt);
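The new optional arguments appear to let callers fix the tensor/pipeline split explicitly instead of leaving it implicit. A small usage sketch; the values are illustrative and `logger` is assumed to exist:
```
// Sketch only: request a 4-way tensor-parallel, 2-way pipeline-parallel world.
auto worldConfig = WorldConfig::mpi(logger, /*gpusPerNode=*/8,
    /*tensorParallelism=*/4, /*pipelineParallelism=*/2);
auto ppGroup = worldConfig.getPipelineParallelGroup(); // ranks in this rank's pipeline group
```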

View File

@@ -35,10 +35,11 @@ add_subdirectory(kernels)
add_subdirectory(layers)
add_subdirectory(runtime)
set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static)
if(BUILD_BATCH_MANAGER)
add_subdirectory(batch_manager)
else()
add_library(tensorrt_llm_batch_manager_static STATIC IMPORTED)
add_library(${BATCH_MANAGER_TARGET} STATIC IMPORTED)
execute_process(
COMMAND ${Python3_EXECUTABLE} "-c"
"import torch; print(torch.compiled_with_cxx11_abi(),end='');"
@@ -48,14 +49,14 @@ else()
message(STATUS "USE_CXX11_ABI: ${USE_CXX11_ABI}")
if(USE_CXX11_ABI)
set_property(
TARGET tensorrt_llm_batch_manager_static
TARGET ${BATCH_MANAGER_TARGET}
PROPERTY
IMPORTED_LOCATION
"${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/libtensorrt_llm_batch_manager_static.a"
)
else()
set_property(
TARGET tensorrt_llm_batch_manager_static
TARGET ${BATCH_MANAGER_TARGET}
PROPERTY
IMPORTED_LOCATION
"${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/libtensorrt_llm_batch_manager_static.pre_cxx11.a"
@@ -69,16 +70,19 @@ set(TRTLLM_LINK_LIBS
${CUDNN_LIB}
${CMAKE_DL_LIBS}
${MPI_CXX_LIBRARIES}
${NCCL_LIB}
${TRT_LIB}
common_src
kernels_src
layers_src
runtime_src
tensorrt_llm_batch_manager_static)
${BATCH_MANAGER_TARGET})
# ################################# SHARED LIBRARY
# ##############################################################################
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
add_library(${SHARED_TARGET} SHARED)
set_target_properties(
@@ -105,6 +109,9 @@ set_target_properties(
target_link_libraries(${STATIC_TARGET} PUBLIC ${TRTLLM_LINK_LIBS})
# Cyclic dependency of batch manager on TRT-LLM
target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE ${STATIC_TARGET})
if(BUILD_PYT)
add_subdirectory(thop)
endif()

View File

@@ -1,220 +0,0 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cublasAlgoMap.h"
namespace tensorrt_llm
{
namespace common
{
cublasAlgoMap::cublasAlgoMap(const std::string filename, const std::string sp_config_filename)
: config_filename_(filename)
, sp_config_filename_(sp_config_filename)
{
loadGemmConfig();
loadSpGemmConfig();
}
cublasAlgoMap::cublasAlgoMap(const cublasAlgoMap& algo_map)
: config_filename_(algo_map.config_filename_)
, sp_config_filename_(algo_map.sp_config_filename_)
, algo_map_(algo_map.algo_map_)
, sp_algo_map_(algo_map.sp_algo_map_)
{
}
cublasAlgoMap::~cublasAlgoMap()
{
algo_map_.clear();
}
void cublasAlgoMap::loadGemmConfig()
{
FILE* fd;
fd = fopen(config_filename_.c_str(), "r");
if (fd == NULL)
{
return;
}
int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val;
int batch_size, seq_len, head_num, size_per_head, dataType;
int swizzle, reductionScheme, workspaceSize, stages;
int inner_shapeId, cluster_shapeId, mma_shapeId, cga_shapeId, sche_mode;
float exec_time;
char tmp[1024];
if (!fgets(tmp, 1024, fd))
{
printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__);
exit(-1);
}
while (fscanf(fd,
"%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
"%d %d "
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
"%d %d %d "
#endif
"%f\n",
&batch_size, &seq_len, &head_num, &size_per_head, &dataType, &batchCount2, &n2, &m2, &k2, &algoId,
&customOption, &tile, &splitK_val, &swizzle, &reductionScheme, &workspaceSize, &stages,
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
&inner_shapeId, &cluster_shapeId,
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
&mma_shapeId, &cga_shapeId, &sche_mode,
#endif
&exec_time)
!= EOF)
{
if (dataType != FLOAT_DATATYPE && dataType != HALF_DATATYPE && dataType != BFLOAT16_DATATYPE
&& dataType != INT8_DATATYPE && dataType != FP8_DATATYPE)
{
printf("[WARNING][readAlgoFromConfig] wrong dataType %d!\n", dataType);
continue;
}
char mark[256];
sprintf(mark, "%d_%d_%d_%d_%d", batchCount2, m2, n2, k2, dataType);
std::string markStr(mark);
// workspaceSize should be zero
if (algo_map_.find(markStr) == algo_map_.end())
{
algo_map_[markStr].algoId = algoId;
algo_map_[markStr].customOption = customOption;
algo_map_[markStr].tile = tile;
algo_map_[markStr].splitK_val = splitK_val;
algo_map_[markStr].swizzle = swizzle;
algo_map_[markStr].reductionScheme = reductionScheme;
algo_map_[markStr].workspaceSize = workspaceSize;
algo_map_[markStr].stages = stages;
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
algo_map_[markStr].inner_shapeId = (uint16_t) inner_shapeId;
algo_map_[markStr].cluster_shapeId = (uint16_t) cluster_shapeId;
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
algo_map_[markStr].mma_shapeId = (uint16_t) mma_shapeId;
algo_map_[markStr].cga_shapeId = (uint16_t) cga_shapeId;
algo_map_[markStr].sche_mode = (uint16_t) sche_mode;
#endif
algo_map_[markStr].exec_time = exec_time;
}
}
fclose(fd);
}
bool cublasAlgoMap::isExist(
const int batch_count, const int m, const int n, const int k, const CublasDataType data_type)
{
char mark[256];
sprintf(mark, "%d_%d_%d_%d_%d", batch_count, n, m, k, data_type);
return algo_map_.find(mark) != algo_map_.end();
}
cublasLtMatmulAlgo_info cublasAlgoMap::getAlgo(
const int batch_count, const int m, const int n, const int k, const CublasDataType data_type)
{
char mark[256];
sprintf(mark, "%d_%d_%d_%d_%d", batch_count, n, m, k, data_type);
if (algo_map_.find(mark) != algo_map_.end())
{
return algo_map_[mark];
}
else
{
cublasLtMatmulAlgo_info tmp_algo;
tmp_algo.algoId
= static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP);
tmp_algo.customOption = -1;
tmp_algo.tile = -1;
tmp_algo.splitK_val = -1;
tmp_algo.swizzle = -1;
tmp_algo.reductionScheme = -1;
tmp_algo.workspaceSize = -1;
tmp_algo.stages = -1;
tmp_algo.exec_time = -1.0f;
return tmp_algo;
}
}
void cublasAlgoMap::loadSpGemmConfig()
{
if (sp_config_filename_.empty())
{
return;
}
FILE* fd = fopen(sp_config_filename_.c_str(), "r");
if (fd == NULL)
{
return;
}
sp_algo_map_.clear();
int batch_size, seq_len, head_num, size_per_head, data_type;
int batchCount, m, n, k, algoId;
float exec_time;
char tmp[1024];
if (!fgets(tmp, 1024, fd))
{
printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__);
exit(-1);
}
while (fscanf(fd, "%d %d %d %d %d ### %d %d %d %d %d %f\n", &batch_size, &seq_len, &head_num, &size_per_head,
&data_type, &batchCount, &m, &n, &k, &algoId, &exec_time)
!= EOF)
{
char mark[256];
sprintf(mark, "%d_%d_%d_%d", batchCount, m, n, k);
std::string markStr(mark);
sp_algo_map_[markStr] = algoId;
}
fclose(fd);
}
int cublasAlgoMap::getSpAlgo(const int batch_count, const int m, const int n, const int k)
{
char mark[256];
sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k);
if (sp_algo_map_.find(mark) != sp_algo_map_.end())
{
return sp_algo_map_[mark];
}
else
{
// for remove padding, select algo 1 for simplicity
return 0;
}
}
bool cublasAlgoMap::isUseSparse(const int batch_count, const int m, const int n, const int k)
{
// cusparseLt cannot be used unless m, n, and k are all multiples of 8.
if (m % 8 != 0 || n % 8 != 0 || k % 8 != 0)
{
return false;
}
char mark[256];
sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k);
if (sp_algo_map_.find(mark) != sp_algo_map_.end())
{
return sp_algo_map_[mark] != -1;
}
else
{
// no gemm test case, choose sparse according to sparse flag
return true;
}
}
} // namespace common
} // namespace tensorrt_llm

View File

@@ -1,90 +0,0 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cublasVersionCheck.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <map>
#include <string>
#include <utility>
namespace tensorrt_llm
{
namespace common
{
#define GEMM_NUM 6
#define GEMM_CONFIG "gemm_config.in"
#define IGEMM_CONFIG "igemm_config.in"
#define SPGEMM_CONFIG "spgemm_config.in"
#define SPIGEMM_CONFIG "spigemm_config.in"
typedef struct
{
int algoId, customOption, tile, splitK_val;
int swizzle, reductionScheme, workspaceSize;
// only used in cublasLt >= 11.0
int stages;
#if TLLM_CUBLAS_VER_GE(11, 11, 3)
uint16_t inner_shapeId, cluster_shapeId;
#else
uint16_t mma_shapeId, cga_shapeId, sche_mode;
#endif
float exec_time;
} cublasLtMatmulAlgo_info;
/* Structure to store information about different run trials */
typedef struct
{
cublasLtMatmulAlgo_t algo;
cublasStatus_t status;
float time;
size_t workspaceSize; // actual memory workspace needed
cublasMath_t mathMode;
cublasLtReductionScheme_t reductionScheme;
int customOption;
float wavesCount;
} customMatmulPerf_t;
class cublasAlgoMap
{
private:
std::map<std::string, cublasLtMatmulAlgo_info> algo_map_;
std::string config_filename_;
std::string sp_config_filename_;
std::map<std::string, int> sp_algo_map_;
public:
cublasAlgoMap(){};
explicit cublasAlgoMap(const std::string filename, const std::string sp_config_filename = "");
cublasAlgoMap(const cublasAlgoMap& map);
~cublasAlgoMap();
void loadGemmConfig();
void loadSpGemmConfig();
int getSpAlgo(const int batch_count, const int m, const int n, const int k);
bool isUseSparse(const int batch_count, const int m, const int n, const int k);
bool isExist(const int batch_count, const int m, const int n, const int k, const CublasDataType data_type);
cublasLtMatmulAlgo_info getAlgo(
const int batch_count, const int m, const int n, const int k, const CublasDataType data_type);
};
} // namespace common
} // namespace tensorrt_llm

View File

@@ -28,299 +28,218 @@ namespace tensorrt_llm
namespace common
{
cublasMMWrapper::cublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, cublasAlgoMap* cublas_algo_map,
std::mutex* mu, void* workspace)
: cublas_handle_(cublasHandle)
, cublaslt_handle_(cublasltHandle)
, stream_(stream)
, cublas_algo_map_(cublas_algo_map)
, mu_(mu)
, cublas_workspace_(workspace)
CublasMMWrapper::CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
: mCublasHandle(cublasHandle)
, mCublasLtHandle(cublasltHandle)
, mStream(stream)
, mCublasWorkspace(workspace)
{
}
cublasMMWrapper::~cublasMMWrapper()
CublasMMWrapper::~CublasMMWrapper()
{
mu_ = nullptr;
mMutex = nullptr;
}
cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper& wrapper)
: cublas_handle_(wrapper.cublas_handle_)
, cublaslt_handle_(wrapper.cublaslt_handle_)
, stream_(wrapper.stream_)
, cublas_algo_map_(wrapper.cublas_algo_map_)
, mu_(wrapper.mu_)
CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper)
: mCublasHandle(wrapper.mCublasHandle)
, mCublasLtHandle(wrapper.mCublasLtHandle)
, mStream(wrapper.mStream)
, mMutex(wrapper.mMutex)
{
}
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* alpha, const void* A, cudaDataType_t Atype, int lda, const void* B, cudaDataType_t Btype, int ldb,
const void* beta, void* C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType, cublasGemmAlgo_t algo)
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const int lda, const int ldb, const int ldc)
{
mu_->lock();
check_cuda_error(cublasGemmEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta,
C, Ctype, ldc, computeType, algo));
sync_check_cuda_error();
mu_->unlock();
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(
cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
check_cuda_error(
cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc));
check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType));
check_cuda_error(cublasLtMatmulDescSetAttribute(
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
}
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
void CublasMMWrapper::destroyDescriptors()
{
check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc));
mOperationDesc = NULL;
mADesc = NULL;
mBDesc = NULL;
mCDesc = NULL;
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
}
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc,
const std::optional<cublasLtMatmulHeuristicResult_t>& heuristic)
{
if (heuristic)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, (*heuristic).algo,
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE);
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo,
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
/* usingCublasLt */ true);
}
else
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, false);
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false,
/* usingCublasLt */ true);
}
}
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta)
{
bool usingCublasLt = Atype_ == CUDA_R_16F;
bool isFp16ComputeType = computeType_ == CUDA_R_16F;
bool usingCublasLt = mAType == CUDA_R_16F;
int batch_count = 1;
int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(Atype_));
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
cublasLtMatmulAlgo_t algo;
void* workSpace = cublas_workspace_;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
if (findAlgo)
{
if (info.stages != -1)
{
usingCublasLt = true;
}
else
{
usingCublasLt = false;
}
}
if (usingCublasLt)
{
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaDataType_t scaleType;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType;
#else
cudaDataType_t computeType;
#endif
if (isFp16ComputeType)
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_16F;
#else
computeType = CUDA_R_16F;
#endif
scaleType = CUDA_R_16F;
}
else
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_32F;
#else
computeType = CUDA_R_32F;
#endif
scaleType = CUDA_R_32F;
}
if (findAlgo)
{
if (info.workspaceSize > workspaceSize)
{
findAlgo = 0;
}
else
{
cublasLtMatmulAlgoInit(
*cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo);
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
#endif
}
}
}
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, algo, findAlgo);
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
/* usingCublasLt */ usingCublasLt);
}
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta,
const cublasLtMatmulAlgo_t& algo, bool hasAlgo)
const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt)
{
half h_alpha = (half) (f_alpha);
half h_beta = (half) (f_beta);
std::lock_guard<std::mutex> lock(*mu_);
std::lock_guard<std::mutex> lock(*mMutex);
// TODO: default cublas libs
bool usingCublasLt = Atype_ == CUDA_R_16F;
bool isFp16ComputeType = computeType_ == CUDA_R_16F;
usingCublasLt = usingCublasLt && mAType == CUDA_R_16F;
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
int batch_count = 1;
// fp32 uses cuBLAS by default
// fp16 uses cublasLt by default
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
if (hasAlgo)
{
int32_t stages;
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
if (stages != -1)
{
usingCublasLt = true;
}
else
{
usingCublasLt = false;
}
}
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
if (usingCublasLt)
{
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaDataType_t scaleType;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType;
#else
cudaDataType_t computeType;
#endif
if (isFp16ComputeType)
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_16F;
#else
computeType = CUDA_R_16F;
#endif
scaleType = CUDA_R_16F;
}
else
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_32F;
#else
computeType = CUDA_R_32F;
#endif
scaleType = CUDA_R_32F;
}
// --------------------------------------
// Create descriptors for the original matrices
cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
#if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
#else
cublasLtMatmulDescCreate(&operationDesc, computeType);
#endif
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
void* workSpace = cublas_workspace_;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
if (hasAlgo)
{
cublasLtMatmulHeuristicResult_t heurResult;
// We have to check if the heuristic is correct given the current shape sizes
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
getCublasLtHandle(), operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &algo, &heurResult);
if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
|| heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
{
// Rely on the runtime-based heuristic
hasAlgo = false;
}
hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo);
}
check_cuda_error(cublasLtMatmul(*cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C,
Cdesc, (hasAlgo ? (&algo) : NULL), workSpace, workspaceSize, stream_));
check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C,
mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream));
cublasLtMatmulDescDestroy(operationDesc);
cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy(Cdesc);
sync_check_cuda_error();
}
else
{
check_cuda_error(cublasSetStream(getCublasHandle(), mStream));
check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize));
// Go with the default heuristic to choose the tactic, as cuBLAS does not allow choosing tactics on Ampere+
cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT;
check_cuda_error(cublasGemmEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, B, Btype_, ldb,
beta, C, Ctype_, ldc, computeType_, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb,
beta, C, mCType, ldc, mComputeType, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
sync_check_cuda_error();
}
}
void cublasMMWrapper::setWorkspace(void* workspace)
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const void* A, const int lda, const int64_t strideA, const void* B, const int ldb,
const int64_t strideB, void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha,
const float f_beta)
{
cublas_workspace_ = workspace;
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
std::lock_guard<std::mutex> lock(*mMutex);
int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void cublasMMWrapper::setFP32GemmConfig()
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA,
const void* B, cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C,
cudaDataType_t CType, const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
std::lock_guard<std::mutex> lock(*mMutex);
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void CublasMMWrapper::setWorkspace(void* workspace)
{
mCublasWorkspace = workspace;
}
void CublasMMWrapper::setFP32GemmConfig()
{
setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F);
}
void cublasMMWrapper::setFP16GemmConfig()
void CublasMMWrapper::setFP16GemmConfig()
{
setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F);
}
#ifdef ENABLE_BF16
void cublasMMWrapper::setBF16GemmConfig()
void CublasMMWrapper::setBF16GemmConfig()
{
setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_32F);
}
#endif
#ifdef ENABLE_FP8
void cublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F);
}
#endif
void cublasMMWrapper::setGemmConfig(
void CublasMMWrapper::setGemmConfig(
cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
{
Atype_ = aType;
Btype_ = bType;
Ctype_ = cType;
computeType_ = computeType;
mAType = aType;
mBType = bType;
mCType = cType;
bool isFp16ComputeType = computeType == CUDA_R_16F;
if (isFp16ComputeType)
{
mComputeType = CUBLAS_COMPUTE_16F;
mScaleType = CUDA_R_16F;
}
else
{
mComputeType = CUBLAS_COMPUTE_32F;
mScaleType = CUDA_R_32F;
}
}
CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
{
if (data_type == CUDA_R_16F)
{
@@ -343,279 +262,22 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
return FLOAT_DATATYPE;
}
#if (CUDART_VERSION >= 11000)
// input, weight, output are row-major
// only works for cublas 11.x
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const void* B, const int ldb, const void* bias, void* C, const int ldc)
void CublasMMWrapper::setStream(cudaStream_t stream)
{
cudaDataType_t Atype, Btype, Ctype;
cublasComputeType_t computeType;
cudaDataType_t scaleType;
float alpha_float = 1.0f;
float beta_float = 0.0f;
half alpha_half = half(1.0f);
half beta_half = half(0.0f);
void *alpha, *beta;
// int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
if (Atype_ == CUDA_R_32F)
{
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype = CUDA_R_32F;
Btype = CUDA_R_32F;
Ctype = CUDA_R_32F;
scaleType = CUDA_R_32F;
alpha = &alpha_float;
beta = &beta_float;
}
else if (Atype_ == CUDA_R_16BF)
{
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
Atype = CUDA_R_16BF;
Btype = CUDA_R_16BF;
Ctype = CUDA_R_16BF;
scaleType = CUDA_R_32F;
alpha = &alpha_float;
beta = &beta_float;
}
else
{
computeType = CUBLAS_COMPUTE_16F;
Atype = CUDA_R_16F;
Btype = CUDA_R_16F;
Ctype = CUDA_R_16F;
scaleType = CUDA_R_16F;
alpha = &alpha_half;
beta = &beta_half;
}
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda);
cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb);
cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc);
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
check_cuda_error(cublasLtMatmul(
*cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy(Cdesc);
cublasLtMatmulDescDestroy(operationDesc);
}
#endif
void cublasMMWrapper::setStream(cudaStream_t stream)
{
stream_ = stream;
mStream = stream;
}
void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const void* A, const int lda, const int64_t strideA, const void* B, const int ldb,
const int64_t strideB, void* C, const int ldc, const int64_t strideC, const int batch_count, const float f_alpha,
const float f_beta)
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
mu_->lock();
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
const void* alpha
= is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
check_cuda_error(cublasGemmStridedBatchedEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda,
strideA, B, Btype_, ldb, strideB, beta, C, Ctype_, ldc, strideC, batch_count, computeType_,
static_cast<cublasGemmAlgo_t>(info.algoId)));
mu_->unlock();
}
void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA,
const void* B, cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C,
cudaDataType_t CType, const int ldc, const int64_t strideC, const int batch_count, cudaDataType_t computeType)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
mu_->lock();
int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0;
const void* alpha
= is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
check_cuda_error(cublasGemmStridedBatchedEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, AType, lda, strideA,
B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batch_count, computeType,
static_cast<cublasGemmAlgo_t>(info.algoId)));
mu_->unlock();
}
void cublasMMWrapper::batchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const void* const* A, const int lda, const void* const* B, const int ldb, void* const* C,
const int ldc, const int batch_count)
{
float f_alpha = static_cast<float>(1.0f);
float f_beta = static_cast<float>(0.0f);
half h_alpha = (half) 1.0f;
half h_beta = (half) 0.0f;
mu_->lock();
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
check_cuda_error(cublasGemmBatchedEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, B, Btype_,
ldb, beta, C, Ctype_, ldc, batch_count, computeType_, static_cast<cublasGemmAlgo_t>(info.algoId)));
mu_->unlock();
}
bool cublasMMWrapper::isFuseBatchGemm(const int batch_count, const int m, const int k, const int n)
{
CublasDataType data_type = getCublasDataType(Atype_);
if (cublas_algo_map_->isExist(batch_count, m, k, n, data_type) == false
|| cublas_algo_map_->isExist(1, m, k, n, data_type) == false)
{
return false;
}
else
{
return cublas_algo_map_->getAlgo(batch_count, m, k, n, data_type).exec_time
< 3 * cublas_algo_map_->getAlgo(1, m, k, n, data_type).exec_time;
}
}
std::vector<cublasLtMatmulHeuristicResult_t> cublasMMWrapper::getTactics(cublasOperation_t transa,
cublasOperation_t transb, const int m, const int n, const int k, const int lda, const int ldb, const int ldc)
{
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL;
cudaDataType_t scaleType;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType;
#else
cudaDataType_t computeType;
#endif
if (is_fp16_computeType)
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_16F;
#else
computeType = CUDA_R_16F;
#endif
scaleType = CUDA_R_16F;
}
else
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_32F;
#else
computeType = CUDA_R_32F;
#endif
scaleType = CUDA_R_32F;
}
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(
cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
check_cuda_error(
cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc));
#if (CUDART_VERSION >= 11000)
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
#else
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType));
#endif
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
void* workSpace = cublas_workspace_;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
const auto heuristics = getTactics(getCublasLtHandle(), operationDesc, Adesc, Bdesc, Cdesc, Cdesc);
check_cuda_error(cublasLtMatmulDescDestroy(operationDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(Adesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(Bdesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(Cdesc));
sync_check_cuda_error();
return heuristics;
}
bool cublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
const int k, const int lda, const int ldb, const int ldc, const cublasLtMatmulHeuristicResult_t& heuristic) const
{
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL;
cudaDataType_t scaleType;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType;
#else
cudaDataType_t computeType;
#endif
if (is_fp16_computeType)
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_16F;
#else
computeType = CUDA_R_16F;
#endif
scaleType = CUDA_R_16F;
}
else
{
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_32F;
#else
computeType = CUDA_R_32F;
#endif
scaleType = CUDA_R_32F;
}
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(
cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
check_cuda_error(
cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc));
#if (CUDART_VERSION >= 11000)
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
#else
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType));
#endif
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
void* workSpace = cublas_workspace_;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
cublasLtMatmulHeuristicResult_t heurResult;
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
getCublasLtHandle(), operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &heuristic.algo, &heurResult);
getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult);
if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
|| heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
@@ -623,16 +285,25 @@ bool cublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t tr
return false;
}
check_cuda_error(cublasLtMatmulDescDestroy(operationDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(Adesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(Bdesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(Cdesc));
sync_check_cuda_error();
return true;
}
std::vector<cublasLtMatmulHeuristicResult_t> cublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
cublasOperation_t transb, const int m, const int n, const int k, const int lda, const int ldb, const int ldc)
{
TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
const auto heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
sync_check_cuda_error();
return heuristics;
}
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc)
{
@@ -648,7 +319,7 @@ std::vector<cublasLtMatmulHeuristicResult_t> cublasMMWrapper::getTactics(cublasL
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
// Restrict reduction algorithms for numerical stability and better determinism
uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_INPLACE;
uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask)));
#if TLLM_CUBLAS_VER_LT(12, 0, 0)
@@ -666,215 +337,6 @@ std::vector<cublasLtMatmulHeuristicResult_t> cublasMMWrapper::getTactics(cublasL
#endif
}
std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B,
cublasLtMatrixLayout_t Bdesc, const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D,
cublasLtMatrixLayout_t Ddesc, cudaStream_t stream)
{
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
return {false, cublasLtMatmulAlgo_t{}};
#else
size_t returnSize;
int32_t pointer_mode;
cublasLtMatmulDescGetAttribute(
computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode), &returnSize);
const auto heuristics = getTactics(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc);
std::map<int, std::vector<float>> algo_results;
for (const auto& heuristic : heuristics)
{
cublasLtMatmulAlgo_t algo = heuristic.algo;
int32_t algo_id;
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
cudaEvent_t start_event, stop_event;
cudaEventCreate(&start_event);
cudaEventCreate(&stop_event);
float my_alpha = 1.0f;
float my_beta = 0.0f;
for (int i = 0; i < 11; i++)
{
float duration_ms;
cudaEventRecord(start_event, stream);
check_cuda_error(cublasLtMatmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D,
Ddesc, &algo, cublas_workspace_, CUBLAS_WORKSPACE_SIZE, stream));
cudaEventRecord(stop_event, stream);
cudaEventSynchronize(stop_event);
cudaEventElapsedTime(&duration_ms, start_event, stop_event);
algo_results[algo_id].push_back(duration_ms);
}
std::sort(algo_results[algo_id].begin(), algo_results[algo_id].end());
}
cublasLtMatmulHeuristicResult_t result;
float best_time = INFINITY;
for (const auto& heuristic : heuristics)
{
cublasLtMatmulAlgo_t algo = heuristic.algo;
int32_t algo_id;
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
const auto& results = algo_results[algo_id];
if (results.size() > 0 && results[5] < best_time)
{
best_time = results[5];
result = heuristic;
}
}
return {best_time != INFINITY, result.algo};
#endif
}
cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrixLayout_t Mdesc)
{
size_t returnSize;
MatrixLayout m_layout;
cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &std::get<0>(m_layout), sizeof(std::get<0>(m_layout)), &returnSize);
cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &std::get<1>(m_layout), sizeof(std::get<1>(m_layout)), &returnSize);
cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_ROWS, &std::get<2>(m_layout), sizeof(std::get<2>(m_layout)), &returnSize);
cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_COLS, &std::get<3>(m_layout), sizeof(std::get<3>(m_layout)), &returnSize);
return m_layout;
}
cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc,
const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B, cublasLtMatrixLayout_t Bdesc,
const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D, cublasLtMatrixLayout_t Ddesc,
const cublasLtMatmulAlgo_t* algo, void* workspace, size_t workspaceSizeInBytes, cudaStream_t stream)
{
cache_idx_t cache_idx{computeDesc,
{createMatrixLayout(Adesc), createMatrixLayout(Bdesc), createMatrixLayout(Cdesc), createMatrixLayout(Ddesc)}};
cublasLtMatmulAlgo_t algo_value;
bool found_algo = false;
if (algo == nullptr)
{
if (algo_cache.find(cache_idx) == algo_cache.end())
{
auto result
= findBestAlgo(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, stream);
if (result.first)
{
algo_cache[cache_idx] = result.second;
algo_value = result.second;
found_algo = true;
}
}
else
{
algo_value = algo_cache[cache_idx];
found_algo = true;
}
}
return cublasLtMatmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc,
found_algo ? &algo_value : algo, workspace, workspaceSizeInBytes, stream);
}
void cublasMMWrapper::_Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
const int ldb, void* C, const int ldc, const void* alpha, const int mode, const bool per_column_scaling)
{
/* mode:
* - 0: int8 * int8 -> int32 -> int8
* - 1: int8 * int8 -> int32 -> int32
*/
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
#else
mu_->lock();
const auto op_a = CUBLAS_OP_T;
const auto op_b = CUBLAS_OP_N;
const auto dataType = CUDA_R_8I;
const auto resultType = mode == 0 ? CUDA_R_8I : CUDA_R_32I;
const auto computeType = CUBLAS_COMPUTE_32I;
const auto scaleType = mode == 0 ? CUDA_R_32F : CUDA_R_32I;
const int batch_count = 1;
const void* beta;
int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(dataType));
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(dataType));
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, dataType, k, m, lda));
check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, dataType, k, n, ldb));
check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, resultType, m, n, ldc));
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
auto pointer_mode = CUBLASLT_POINTER_MODE_HOST;
if (mode == 0)
{
pointer_mode
= per_column_scaling ? CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST : CUBLASLT_POINTER_MODE_DEVICE;
}
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSC, &op_b, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
const int32_t int_one = 1;
const int32_t int_zero = 0;
const float float_zero = 0;
if (mode == 0)
{
beta = per_column_scaling ? &float_zero : NULL;
}
else
{
alpha = &int_one;
beta = &int_zero;
}
void* workSpace = cublas_workspace_;
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
sync_check_cuda_error();
auto ret = cublasLtMatmulWrapper(*cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C,
Cdesc, NULL, workSpace, workspaceSize, stream_);
check_cuda_error(ret);
sync_check_cuda_error();
cublasLtMatmulDescDestroy(operationDesc);
cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy(Cdesc);
sync_check_cuda_error();
mu_->unlock();
#endif
}
void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
const int ldb, int8_t* C, const int ldc, const float* alpha, const bool per_column_scaling)
{
return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, alpha, 0, per_column_scaling);
}
void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
const int ldb, int32_t* C, const int ldc)
{
return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, (float*) nullptr, 1, false);
}
} // namespace common
} // namespace tensorrt_llm
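Taken together, the refactor moves descriptor handling out of the individual GEMM calls: a caller creates the cublasLt descriptors once, can query and validate tactics against them, runs the GEMM, and finally destroys them. A hedged end-to-end sketch, assuming an already-constructed `CublasMMWrapper wrapper`, valid problem sizes `m, n, k, lda, ldb, ldc`, and device pointers `A`, `B`, `C`:
```
// Sketch only: the descriptor-based flow suggested by the new CublasMMWrapper API.
wrapper.setFP16GemmConfig(); // fp16 A/B/C with fp32 accumulation
wrapper.createDescriptors(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, lda, ldb, ldc);

auto tactics = wrapper.getTactics(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, lda, ldb, ldc);
if (!tactics.empty()
    && wrapper.checkTactic(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, lda, ldb, ldc, tactics.front().algo))
{
    // Run with the validated heuristic result.
    wrapper.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, A, lda, B, ldb, C, ldc, tactics.front());
}
else
{
    // Fall back to the default path (alpha = 1, beta = 0).
    wrapper.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, A, lda, B, ldb, C, ldc);
}
wrapper.destroyDescriptors();
```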

View File

@@ -16,7 +16,6 @@
#pragma once
#include "tensorrt_llm/common/cublasAlgoMap.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <cublasLt.h>
#include <cublas_v2.h>
@@ -31,66 +30,44 @@ namespace tensorrt_llm
namespace common
{
class cublasMMWrapper
class CublasMMWrapper
{
protected:
std::shared_ptr<cublasHandle_t> cublas_handle_;
std::shared_ptr<cublasLtHandle_t> cublaslt_handle_;
std::shared_ptr<cublasHandle_t> mCublasHandle;
std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;
cudaDataType_t Atype_;
cudaDataType_t Btype_;
cudaDataType_t Ctype_;
cudaDataType_t computeType_;
cudaDataType_t mAType{};
cudaDataType_t mBType{};
cudaDataType_t mCType{};
cublasComputeType_t mComputeType{};
cudaDataType_t mScaleType{};
cudaStream_t stream_;
cublasAlgoMap* cublas_algo_map_;
std::mutex* mu_;
cublasLtMatmulDesc_t mOperationDesc{NULL};
cublasLtMatrixLayout_t mADesc{NULL};
cublasLtMatrixLayout_t mBDesc{NULL};
cublasLtMatrixLayout_t mCDesc{NULL};
void* cublas_workspace_ = nullptr;
cudaStream_t mStream;
//@fixme: we may not need the mutex if we copy the wrapper instead of sharing in GemmPlugin::clone()
std::shared_ptr<std::mutex> mMutex{std::make_shared<std::mutex>()};
friend class cublasINT8MMWrapper;
void* mCublasWorkspace = nullptr;
void _Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
const int ldb, void* C, const int ldc, const void* alpha, const int mode, const bool per_column_scaling);
private:
bool descriptorsCreated() const
{
return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL;
}
public:
cublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
cudaStream_t stream, cublasAlgoMap* map, std::mutex* mu, void* workspace);
CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
cudaStream_t stream, void* workspace);
~cublasMMWrapper();
~CublasMMWrapper();
cublasMMWrapper(const cublasMMWrapper& wrapper);
cublasStatus_t cublasLtMatmulWrapper(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc,
const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B, cublasLtMatrixLayout_t Bdesc,
const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D, cublasLtMatrixLayout_t Ddesc,
const cublasLtMatmulAlgo_t* algo, void* workspace, size_t workspaceSizeInBytes, cudaStream_t stream);
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const int lda, const int ldb, const int ldc, const cublasLtMatmulHeuristicResult_t& algo) const;
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
const int m, const int n, const int k, const int lda, const int ldb, const int ldc);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);
std::pair<bool, cublasLtMatmulAlgo_t> findBestAlgo(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc,
const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B, cublasLtMatrixLayout_t Bdesc,
const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D, cublasLtMatrixLayout_t Ddesc,
cudaStream_t stream);
using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;
std::map<cache_idx_t, cublasLtMatmulAlgo_t> algo_cache;
MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* alpha, const void* A, cudaDataType_t Atype, int lda, const void* B, cudaDataType_t Btype, int ldb,
const void* beta, void* C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType, cublasGemmAlgo_t algo);
CublasMMWrapper(const CublasMMWrapper& wrapper);
/********************** GEMMs **********************/
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
const int lda, const void* B, const int ldb, void* C, const int ldc);
@@ -103,16 +80,37 @@ public:
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta,
const cublasLtMatmulAlgo_t& algo, bool hasAlgo);
const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, const int64_t strideB,
void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha = 1.0f,
const float f_beta = 0.0f);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, const void* B,
cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, cudaDataType_t CType,
const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType);
/********************** Tactic selection helpers **********************/
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
const int m, const int n, const int k, const int lda, const int ldb, const int ldc);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);
using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;
MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);
/********************** Utils **********************/
void setWorkspace(void* workspace);
void Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B, const int ldb,
int8_t* C, const int ldc, const float* alpha, const bool per_column_scaling = false);
void Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B, const int ldb,
int32_t* C, const int ldc);
void setFP32GemmConfig();
void setFP16GemmConfig();
#ifdef ENABLE_BF16
@ -128,35 +126,18 @@ public:
CublasDataType getCublasDataType(cudaDataType_t data_type);
#if (CUDART_VERSION >= 11000)
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
const int lda, const void* B, const int ldb, const void* bias, void* C, const int ldc);
#endif
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, const int64_t strideB,
void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha = 1.0f,
const float f_beta = 0.0f);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, const void* B,
cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, cudaDataType_t CType,
const int ldc, const int64_t strideC, const int batch_count, cudaDataType_t computeType);
void batchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const void* const* A, const int lda, const void* const* B, const int ldb, void* const* C, const int ldc,
const int batch_count);
bool isFuseBatchGemm(const int batch_count, const int m, const int k, const int n);
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
const int lda, const int ldb, const int ldc);
void destroyDescriptors();
cublasHandle_t getCublasHandle()
{
return *(this->cublas_handle_);
return *(this->mCublasHandle);
}
cublasLtHandle_t getCublasLtHandle() const
{
return *(this->cublaslt_handle_);
return *(this->mCublasLtHandle);
}
};
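For orientation, a minimal, hypothetical usage sketch of the renamed wrapper (not part of this commit). It relies only on the constructor, `setFP16GemmConfig`, and the plain `Gemm` overload declared above, and assumes the class remains in `tensorrt_llm::common` and that `Gemm` forwards to cuBLAS's column-major GEMM; handles, shapes, and the workspace are illustrative.

```
// Hypothetical sketch only: exercises CublasMMWrapper as declared above.
#include <memory>
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "tensorrt_llm/common/cublasMMWrapper.h"

// C[m x n] = A[m x k] * B[k x n] in fp16, column-major (cuBLAS convention assumed).
void runFp16Gemm(const half* A, const half* B, half* C, int m, int n, int k, cudaStream_t stream, void* workspace)
{
    auto cublasHandle = std::make_shared<cublasHandle_t>();
    cublasCreate(cublasHandle.get());
    cublasSetStream(*cublasHandle, stream);

    auto cublasLtHandle = std::make_shared<cublasLtHandle_t>();
    cublasLtCreate(cublasLtHandle.get());

    {
        // workspace is caller-allocated device memory for cublasLt (sizing left to the caller here).
        tensorrt_llm::common::CublasMMWrapper wrapper(cublasHandle, cublasLtHandle, stream, workspace);
        wrapper.setFP16GemmConfig();
        wrapper.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, A, m, B, k, C, m);
    }

    cublasLtDestroy(*cublasLtHandle);
    cublasDestroy(*cublasHandle);
}
```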

View File

@ -287,25 +287,6 @@ inline int getMultiProcessorCount()
return multi_processor_count;
}
class CudaEventDeleter
{
public:
constexpr void operator()(cudaEvent_t stream) const
{
if (stream != nullptr)
check_cuda_error(::cudaEventDestroy(stream));
}
};
using EventPtr = std::unique_ptr<std::remove_pointer_t<cudaEvent_t>, CudaEventDeleter>;
inline EventPtr CreateEvent(unsigned int flags = cudaEventDisableTiming)
{
cudaEvent_t event;
check_cuda_error(::cudaEventCreate(&event, flags));
return EventPtr{event};
}
inline int divUp(int a, int n)
{
return (a + n - 1) / n;

View File

@ -30,11 +30,10 @@ namespace tensorrt_llm::common
class Logger
{
#if _WIN32
// On Windows, the file wingdi.h is included, which has
// #define ERROR 0
// This breaks every place where ERROR is used in the Level enum.
// An alternative, untested solution to #undef: compile with the NOGDI flag defined.
#if _WIN32
#undef ERROR
#endif // _WIN32

View File

@ -0,0 +1,142 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/mpiUtils.h"
namespace tensorrt_llm
{
namespace mpi
{
MPI_Datatype getMpiDtype(MpiType dtype)
{
static const std::unordered_map<MpiType, MPI_Datatype> dtype_map{
{MPI_TYPE_BYTE, MPI_BYTE},
{MPI_TYPE_CHAR, MPI_CHAR},
{MPI_TYPE_INT, MPI_INT},
{MPI_TYPE_INT64_T, MPI_INT64_T},
{MPI_TYPE_UINT32_T, MPI_UINT32_T},
{MPI_TYPE_UNSIGNED_LONG_LONG, MPI_UNSIGNED_LONG_LONG},
};
return dtype_map.at(dtype);
}
MPI_Op getMpiOp(MpiOp op)
{
static const std::unordered_map<MpiOp, MPI_Op> op_map{
{MPI_OP_NULLOP, MPI_OP_NULL},
{MPI_OP_MAX, MPI_MAX},
{MPI_OP_MIN, MPI_MIN},
{MPI_OP_SUM, MPI_SUM},
{MPI_OP_PROD, MPI_PROD},
{MPI_OP_LAND, MPI_LAND},
{MPI_OP_BAND, MPI_BAND},
{MPI_OP_LOR, MPI_LOR},
{MPI_OP_BOR, MPI_BOR},
{MPI_OP_LXOR, MPI_LXOR},
{MPI_OP_BXOR, MPI_BXOR},
{MPI_OP_MINLOC, MPI_MINLOC},
{MPI_OP_MAXLOC, MPI_MAXLOC},
{MPI_OP_REPLACE, MPI_REPLACE},
};
return op_map.at(op);
}
void initialize(int* argc, char*** argv)
{
MPICHECK(MPI_Init(argc, argv));
}
void finalize()
{
MPICHECK(MPI_Finalize());
}
bool isInitialized()
{
int mpi_initialized = 0;
MPICHECK(MPI_Initialized(&mpi_initialized));
return static_cast<bool>(mpi_initialized);
}
void initThread(int* argc, char*** argv, MpiThreadSupport required, int* provided)
{
switch (required)
{
case THREAD_SINGLE: MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_SINGLE, provided)); break;
case THREAD_FUNNELED: MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_FUNNELED, provided)); break;
case THREAD_SERIALIZED: MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_SERIALIZED, provided)); break;
case THREAD_MULTIPLE: MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_MULTIPLE, provided)); break;
default: break;
}
}
int getCommWorldRank()
{
int rank = 0;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return rank;
}
int getCommWorldSize()
{
int world_size = 1;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
return world_size;
}
void barrier(MpiComm comm)
{
MPICHECK(MPI_Barrier(comm.group));
}
void barrier()
{
MPICHECK(MPI_Barrier(MPI_COMM_WORLD));
}
void bcast(void* buffer, size_t size, MpiType dtype, int root, MpiComm comm)
{
MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, comm.group));
}
void bcast(std::vector<int64_t>& packed, int root, MpiComm comm)
{
int64_t nWords1;
if (getCommWorldRank() == root)
{
nWords1 = static_cast<int64_t>(packed.size());
}
bcast(&nWords1, 1, MPI_TYPE_INT64_T, root, comm);
if (getCommWorldRank() != root)
{
packed.resize(nWords1);
}
bcast(packed.data(), packed.size(), MPI_TYPE_INT64_T, root, comm);
}
void comm_split(MpiComm comm, int color, int key, MpiComm* newcomm)
{
MPICHECK(MPI_Comm_split(comm.group, color, key, &newcomm->group));
}
void allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op, MpiComm comm)
{
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), comm.group));
}
} // namespace mpi
} // namespace tensorrt_llm

View File

@ -0,0 +1,105 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdlib>
#include <mpi.h>
#include <stdio.h>
#include <unordered_map>
#include <vector>
#define COMM_WORLD MpiComm(MPI_COMM_WORLD)
#define MPICHECK(cmd) \
do \
{ \
int e = cmd; \
if (e != MPI_SUCCESS) \
{ \
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
exit(EXIT_FAILURE); \
} \
} while (0)
// A wrapper module of the MPI library.
namespace tensorrt_llm::mpi
{
// A wrapper of MPI data type. MPI_TYPE_{data_type}
enum MpiType
{
MPI_TYPE_BYTE,
MPI_TYPE_CHAR,
MPI_TYPE_INT,
MPI_TYPE_INT64_T,
MPI_TYPE_UINT32_T,
MPI_TYPE_UNSIGNED_LONG_LONG,
};
// A wrapper of MPI_Op type.
enum MpiOp
{
MPI_OP_NULLOP,
MPI_OP_MAX,
MPI_OP_MIN,
MPI_OP_SUM,
MPI_OP_PROD,
MPI_OP_LAND,
MPI_OP_BAND,
MPI_OP_LOR,
MPI_OP_BOR,
MPI_OP_LXOR,
MPI_OP_BXOR,
MPI_OP_MINLOC,
MPI_OP_MAXLOC,
MPI_OP_REPLACE,
};
// A wrapper of the level of MPI thread support
enum MpiThreadSupport
{
THREAD_SINGLE,
THREAD_FUNNELED,
THREAD_SERIALIZED,
THREAD_MULTIPLE
};
struct MpiComm
{
MPI_Comm group;
MpiComm(){};
MpiComm(MPI_Comm g)
: group(g){};
};
MPI_Datatype getMpiDtype(MpiType dtype);
void initialize(int* argc, char*** argv);
void initThread(int* argc, char*** argv, MpiThreadSupport required, int* provided);
void finalize();
bool isInitialized();
void barrier(MpiComm comm);
void barrier();
int getCommWorldRank();
int getCommWorldSize();
void bcast(void* buffer, size_t size, MpiType dtype, int root, MpiComm comm);
void bcast(std::vector<int64_t>& packed, int root, MpiComm comm);
void comm_split(MpiComm comm, int color, int key, MpiComm* newcomm);
void allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op, MpiComm comm);
} // namespace tensorrt_llm::mpi
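For orientation, a minimal, hypothetical usage sketch of this wrapper (not part of this commit), assuming the program is launched under mpirun; the broadcast payload is illustrative.

```
// Hypothetical sketch only: exercises the tensorrt_llm::mpi wrapper declared above.
#include <cstdint>
#include <vector>
#include "tensorrt_llm/common/mpiUtils.h"

int main(int argc, char* argv[])
{
    namespace mpi = tensorrt_llm::mpi;
    mpi::initialize(&argc, &argv);

    const int rank = mpi::getCommWorldRank();
    const int worldSize = mpi::getCommWorldSize();

    // Rank 0 fills the packed metadata; the other ranks receive it (bcast resizes for them).
    std::vector<int64_t> packed;
    if (rank == 0)
    {
        packed = {worldSize, 42};
    }
    mpi::bcast(packed, 0, COMM_WORLD);

    mpi::barrier();
    mpi::finalize();
    return 0;
}
```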

View File

@ -30,7 +30,7 @@ inline nvtx3::color nextColor()
nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}};
constexpr auto numColors = kColors.size();
static thread_local int colorId = 0;
static thread_local std::size_t colorId = 0;
auto const color = kColors[colorId];
colorId = colorId + 1 >= numColors ? 0 : colorId + 1;
return color;

View File

@ -80,7 +80,7 @@ __inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
return val;
}

View File

@ -361,6 +361,8 @@ FusedMHARunnerV2::FusedMHARunnerV2(
{
}
FusedMHARunnerV2::~FusedMHARunnerV2() = default;
void FusedMHARunnerV2::setup(
const int b, const int s, const int total_seqlen, const bool has_alibi, const int tp_size, const int tp_rank)
{

View File

@ -78,7 +78,7 @@ class FusedMHARunnerV2 : public MHARunner
public:
FusedMHARunnerV2(const Data_type dataType, const int numHeads, const int headSize, const float qScaling);
~FusedMHARunnerV2() = default; // for pimpl
~FusedMHARunnerV2(); // for pimpl
void setup(const int b, const int s, const int total_seqlen, const bool has_alibi = false, const int tp_size = 1,
const int tp_rank = 0) override;

View File

@ -0,0 +1,453 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "customAllReduceKernels.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
namespace tensorrt_llm::kernels
{
////////////////////////////////////////////////////////////////////////////////////////////////////
using tensorrt_llm::common::hadd2;
static inline __device__ uint32_t myHadd2(const uint32_t& a, const uint32_t& b)
{
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t fadd(const uint32_t& a, const uint32_t& b)
{
uint32_t c;
asm volatile("add.f32 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
__threadfence_system();
asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Type converter that packs the data format into a 128-bit data type
template <typename T>
struct ARTypeConverter
{
using Type = uint4;
};
#ifdef ENABLE_BF16
template <>
struct ARTypeConverter<__nv_bfloat16>
{
using Type = bf168;
};
#endif
// add two 128-bit values
template <typename T_IN, typename T_COMP>
inline __device__ T_IN add128b(T_IN a, T_IN b);
template <>
inline __device__ uint4 add128b<uint4, uint16_t>(uint4 a, uint4 b)
{
uint4 c;
c.x = myHadd2(a.x, b.x);
c.y = myHadd2(a.y, b.y);
c.z = myHadd2(a.z, b.z);
c.w = myHadd2(a.w, b.w);
return c;
}
template <>
inline __device__ uint4 add128b<uint4, uint32_t>(uint4 a, uint4 b)
{
uint4 c;
c.x = fadd(a.x, b.x);
c.y = fadd(a.y, b.y);
c.z = fadd(a.z, b.z);
c.w = fadd(a.w, b.w);
return c;
}
#ifdef ENABLE_BF16
template <>
inline __device__ bf168 add128b<bf168, __nv_bfloat16>(bf168 a, bf168 b)
{
bf168 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
c.z = hadd2(a.z, b.z);
c.w = hadd2(a.w, b.w);
return c;
}
#endif
// initialize 128-bit data with 0
template <typename T>
inline __device__ T init_packed_type();
template <>
inline __device__ uint4 init_packed_type()
{
return make_uint4(0u, 0u, 0u, 0u);
}
#ifdef ENABLE_BF16
template <>
inline __device__ bf168 init_packed_type()
{
bf168 val;
uint4& val_u = reinterpret_cast<uint4&>(val);
val_u = make_uint4(0u, 0u, 0u, 0u);
return val;
}
#endif
template <typename T, int RANKS_PER_NODE>
static __global__ void oneShotAllReduceKernel(AllReduceParams params)
{
// The block index.
const int bidx = blockIdx.x;
// The thread index within the block.
const int tidx = threadIdx.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
// Packed data type for comms
using PackedType = typename ARTypeConverter<T>::Type;
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
// The end of the segment computed by that block.
size_t max_offset = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
// Synchronize the ranks.
volatile uint32_t* barrier_d = params.peer_barrier_ptrs[params.local_rank];
if (tidx < RANKS_PER_NODE)
{
// The 1st block notifies the other ranks.
if (bidx == 0)
{
params.peer_barrier_ptrs[tidx][params.local_rank] = params.barrier_flag;
}
// Busy-wait until all ranks are ready.
while (barrier_d[tidx] != params.barrier_flag)
{
}
}
// Make sure we can move on...
__syncthreads();
// The source pointers. Distributed round-robin for the different warps.
const T* src_d[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
}
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = offset; iter_offset < max_offset; iter_offset += blockDim.x * NUM_ELTS)
{
// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][iter_offset])[0];
}
// Sum the values from the different ranks.
PackedType sums = init_packed_type<PackedType>();
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
sums = add128b<PackedType, T>(sums, vals[ii]);
}
// Store to the destination buffer.
reinterpret_cast<PackedType*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset])[0] = sums;
}
}
template <typename T, int RANKS_PER_NODE>
static __global__ void twoShotAllReduceKernel(AllReduceParams params)
{
// The block index.
const int bidx = blockIdx.x;
// The thread index within the block.
const int tidx = threadIdx.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
// Packed data type for comms
using PackedType = typename ARTypeConverter<T>::Type;
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
const size_t block_start = params.rank_offset;
const size_t block_offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
// The end of the segment computed by that block.
size_t max_offset = min(block_start + block_offset + params.elts_per_block, params.elts_total);
// Synchronize the ranks.
volatile uint32_t* barrier_d = params.peer_barrier_ptrs[params.local_rank];
if (tidx < RANKS_PER_NODE)
{
// The 1st block notifies the other ranks.
if (bidx == 0)
{
params.peer_barrier_ptrs[tidx][params.local_rank] = params.barrier_flag;
}
// Busy-wait until all ranks are ready.
while (barrier_d[tidx] != params.barrier_flag)
{
}
}
// Make sure we can move on...
__syncthreads();
// The source pointers. Distributed round-robin for the different warps.
T* src_d[RANKS_PER_NODE];
// The destination ranks for round-robin gathering
size_t dst_rank[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
dst_rank[ii] = rank;
}
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = block_start + block_offset; local_offset < max_offset;
local_offset += blockDim.x * NUM_ELTS)
{
// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][local_offset])[0];
}
// Sum the values from the different ranks.
PackedType sums = init_packed_type<PackedType>();
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
sums = add128b<PackedType, T>(sums, vals[ii]);
}
// Store to the local buffer.
reinterpret_cast<PackedType*>(&src_d[0][local_offset])[0] = sums;
}
// sync threads to make sure all block threads have the sums
__syncthreads();
// barriers among the blocks with the same idx (release-acquire semantics)
if (tidx < RANKS_PER_NODE)
{
// All blocks notify the other ranks.
uint32_t flag_block_offset = RANKS_PER_NODE + bidx * RANKS_PER_NODE;
st_flag_release(params.barrier_flag, params.peer_barrier_ptrs[tidx] + flag_block_offset + params.local_rank);
// Busy-wait until all ranks are ready.
uint32_t rank_barrier = 0;
uint32_t* peer_barrier_d = params.peer_barrier_ptrs[params.local_rank] + flag_block_offset + tidx;
do
{
ld_flag_acquire(rank_barrier, peer_barrier_d);
} while (rank_barrier != params.barrier_flag);
}
// sync threads to make sure all other ranks have the final partial results
__syncthreads();
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = block_offset; local_offset < params.elts_per_rank; local_offset += blockDim.x * NUM_ELTS)
{
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
// use round-robin gathering from other ranks
size_t offset_rank = dst_rank[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total)
{
continue;
}
reinterpret_cast<PackedType*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank])[0]
= reinterpret_cast<PackedType*>(&src_d[ii][offset_rank])[0];
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void kernelLaunchConfig(int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo,
size_t data_type_bytes, int ranks_per_node)
{
assert(data_type_bytes == 2 || data_type_bytes == 4);
size_t elts_per_thread = 16 / data_type_bytes;
size_t elts_per_warp = (16 * WARP_SIZE) / data_type_bytes;
switch (kernel_algo)
{
case 0:
{ // one stage all reduce algo
assert(elts % elts_per_warp == 0);
if (elts < (elts_per_thread * DEFAULT_BLOCK_SIZE))
{ // local reduce
threads_per_block = ((elts + elts_per_warp - 1) / elts_per_warp) * WARP_SIZE;
blocks_per_grid = 1;
}
else
{ // local reduce
if (elts % (elts_per_thread * threads_per_block) == 0)
{
blocks_per_grid
= (elts + elts_per_thread * threads_per_block - 1) / (elts_per_thread * threads_per_block);
// NOTE: need to adjust here
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS)
{
size_t iter_factor = 1;
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor)
{
iter_factor += 1;
}
blocks_per_grid /= iter_factor;
}
}
else
{
size_t total_threads = elts / elts_per_thread;
blocks_per_grid = 1;
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE)
{
blocks_per_grid += 1;
}
threads_per_block = total_threads / blocks_per_grid;
}
}
break;
}
case 1:
{ // two stage all reduce algo
size_t total_threads = elts / ranks_per_node / elts_per_thread;
assert(elts / ranks_per_node % elts_per_thread == 0 && total_threads % WARP_SIZE == 0);
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE)
{
blocks_per_grid += 1;
}
threads_per_block = total_threads / blocks_per_grid;
// NOTE: need to adjust here
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS)
{
size_t iter_factor = 1;
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor)
{
iter_factor += 1;
}
blocks_per_grid /= iter_factor;
}
break;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(RANKS_PER_NODE) \
\
if (kernel_algo == 0) \
{ \
param.elts_per_rank = elts_total; \
param.elts_per_block = param.elts_per_rank / blocks_per_grid; \
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param); \
} \
else \
{ \
param.elts_per_rank = param.elts_total / RANKS_PER_NODE; \
param.elts_per_block = param.elts_per_rank / blocks_per_grid; \
param.rank_offset = param.rank * param.elts_per_rank; \
twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param); \
}
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream)
{
size_t elts_total = param.elts_total;
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
int kernel_algo = 1;
if (elts_total * sizeof(T) <= DEFAULT_ALGO_AR_SIZE_THRESHOLD)
{
kernel_algo = 0;
}
kernelLaunchConfig(blocks_per_grid, threads_per_block, elts_total, kernel_algo, sizeof(T), param.ranks_per_node);
switch (param.ranks_per_node)
{
case 2: CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(2); break;
case 4: CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(4); break;
case 6: CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(6); break;
case 8: CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(8); break;
default: break;
}
}
// Template instantiation
template void invokeOneOrTwoShotAllReduceKernel<uint16_t>(AllReduceParams& param, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams& param, cudaStream_t stream);
#endif
template void invokeOneOrTwoShotAllReduceKernel<uint32_t>(AllReduceParams& param, cudaStream_t stream);
} // namespace tensorrt_llm::kernels
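For orientation, a small host-side re-derivation (not part of this commit) of the one-shot small-message launch configuration computed by kernelLaunchConfig above; the constants mirror the companion header and the element count is an example.

```
// Hypothetical sketch only: mirrors the one-shot (kernel_algo == 0) small-message path.
#include <cstddef>
#include <cstdio>

int main()
{
    constexpr int kWarpSize = 32;
    constexpr int kDefaultBlockSize = 1024;

    const std::size_t elts = 4096;                              // example: 4096 fp16 elements
    const std::size_t dataTypeBytes = 2;                        // fp16
    const std::size_t eltsPerThread = 16 / dataTypeBytes;       // 128-bit loads -> 8 fp16 per thread
    const std::size_t eltsPerWarp = (16 * kWarpSize) / dataTypeBytes;

    int blocksPerGrid = 1;
    int threadsPerBlock = kDefaultBlockSize;
    if (elts < eltsPerThread * kDefaultBlockSize)
    {
        // Small message: a single block, rounded up to whole warps.
        threadsPerBlock = static_cast<int>((elts + eltsPerWarp - 1) / eltsPerWarp) * kWarpSize;
        blocksPerGrid = 1;
    }

    // 4096 elements / 256 per warp = 16 warps -> 512 threads, one block.
    std::printf("blocks=%d threads=%d\n", blocksPerGrid, threadsPerBlock);
    return 0;
}
```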

View File

@ -0,0 +1,73 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <iostream>
#include "tensorrt_llm/common/cudaUtils.h"
#define CUSTOM_AR_SIZE_THRESHOLD 50331648
#define MAX_ALL_REDUCE_BLOCKS 24
#define FLAG(a) ((uint32_t) ((a) % 0x146))
#define MAX_RANKS_PER_NODE 8
#define WARP_SIZE 32
#define DEFAULT_BLOCK_SIZE 1024
#define DEFAULT_ALGO_AR_SIZE_THRESHOLD 393216
namespace tensorrt_llm::kernels
{
#ifdef ENABLE_BF16
typedef struct bf168
{
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
} bf168;
#endif
struct AllReduceIpcMemHandles
{
cudaIpcMemHandle_t peer_barrier_ipc_handles[MAX_RANKS_PER_NODE];
cudaIpcMemHandle_t peer_comm_buffer_ipc_handles[MAX_RANKS_PER_NODE];
};
struct AllReduceParams
{
size_t elts_total;
size_t elts_per_rank;
size_t elts_per_block;
size_t rank_offset;
size_t ranks_per_node, rank, local_rank, node_id;
uint32_t barrier_flag;
uint32_t* peer_barrier_ptrs[MAX_RANKS_PER_NODE];
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
void* local_output_buffer_ptr;
AllReduceIpcMemHandles ipc_mem_handles;
};
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream);
void kernelLaunchConfig(int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo,
    size_t data_type_bytes, int ranks_per_node);
} // namespace tensorrt_llm::kernels

View File

@ -116,7 +116,9 @@ struct Multihead_attention_params_base
// The per-head latent space reserved for rotary embeddings.
int rotary_embedding_dim = 0;
float rotary_embedding_base = 0.0f;
RotaryScalingType rotary_embedding_scale_type = RotaryScalingType::kNONE;
float rotary_embedding_scale = 0.0f;
int rotary_embedding_max_positions = 0;
// The current timestep. TODO(bhsueh): check whether we only need this param in cross attention.
int timestep = 0;
// The current timestep of each sentence (supports different timesteps for different sentences)

View File

@ -1165,8 +1165,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
zero(k);
zero(q_bias);
zero(k_bias);
float rotary_embedding_base = params.rotary_embedding_base;
float rotary_embedding_scale = params.rotary_embedding_scale;
if (is_valid_qk_vec)
{
mmha::update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale,
params.rotary_embedding_scale_type, params.rotary_embedding_dim, params.rotary_embedding_max_positions,
tlength);
// Query
// The stride between tokens. We may be able to always use params.stride.
uint32_t q_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads * Dh);
@ -1280,7 +1285,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim,
params.rotary_embedding_base, params.rotary_embedding_scale, tlength);
rotary_embedding_base, rotary_embedding_scale, tlength);
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
mmha::write_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);

View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/kernels/gptKernels.h"
#include <stdint.h>
#ifdef ENABLE_BF16
@ -122,6 +123,12 @@ struct num_elems<Float8_>
static constexpr int value = 8;
};
template <>
struct num_elems<half>
{
static constexpr int value = 1;
};
template <>
struct num_elems<uint32_t>
{
@ -141,6 +148,12 @@ struct num_elems<uint4>
};
#ifdef ENABLE_BF16
template <>
struct num_elems<__nv_bfloat16>
{
static constexpr int value = 1;
};
template <>
struct num_elems<__nv_bfloat162>
{
@ -197,6 +210,12 @@ struct packed_type<T, 1>
using type = T;
};
template <>
struct packed_type<int8_t, 1>
{
using type = int8_t;
};
template <>
struct packed_type<int8_t, 2>
{
@ -216,6 +235,13 @@ struct packed_type<int8_t, 8>
};
#ifdef ENABLE_FP8
template <>
struct packed_type<__nv_fp8_e4m3, 1>
{
using type = __nv_fp8_e4m3;
};
template <>
struct packed_type<__nv_fp8_e4m3, 2>
{
@ -1508,6 +1534,32 @@ inline __device__ void zero(T& dst)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float update_rotary_base(
const int kv_seq_len, const int max_positions, const int embed_dim, const float base, const float scale)
{
const float b = (scale * kv_seq_len / max_positions) - (scale - 1);
const float p = static_cast<float>(embed_dim) / (embed_dim - 2);
return base * pow(b, p);
}
inline __device__ void update_rotary_base_n_scale(float& base, float& scale, RotaryScalingType const scale_type,
const int rot_embed_dim, const int max_positions, const int seq_len)
{
// only update the base and/or scale if needed based on scale_type
if (scale_type == RotaryScalingType::kDYNAMIC)
{
if (seq_len > max_positions)
{
base = update_rotary_base(seq_len, max_positions, rot_embed_dim, base, scale);
}
scale = 1.0f; // scale is only used in base for dynamic scaling
}
else if (scale_type == RotaryScalingType::kLINEAR)
{
scale = 1.0f / scale;
}
}
inline __device__ float2 rotary_embedding_coefficient(
const int zid, const int rot_embed_dim, const float base, const float scale, const float t_step)
{
@ -2006,6 +2058,14 @@ inline __device__ float2 convert_to_float(uint32_t u)
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
inline __device__ float convert_to_float(half u)
{
return static_cast<float>(u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_FP8
inline __device__ void convert_from_fp8(uint16_t* v, const __nv_fp8_e4m3 u)
{
@ -2209,6 +2269,13 @@ inline __device__ void convert_to_fp8(fp8_8_t* v, const bf16_8_t u)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const half u)
{
v[0] = __nv_fp8_e4m3(u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const uint16_t u)
{
v[0] = __nv_fp8_e4m3(reinterpret_cast<const half&>(u));

View File

@ -111,8 +111,7 @@ __global__ void length_criterion(bool* finished, int* finished_sum, const uint32
{
const int batch_idx = index / beam_width;
// sequence_lengths is updated, so we need to subtract 1
finished[index] |= sequence_lengths[index] - 1 >= sequence_limit_length[batch_idx];
finished[index] |= sequence_lengths[index] >= sequence_limit_length[batch_idx];
thread_finished_count += finished[index] ? 1 : 0;
}

View File

@ -1252,8 +1252,9 @@ struct Vec_t<__nv_bfloat16>
template <typename T, bool ADD_BIAS, bool USING_CONTEXT_FMHA>
__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* __restrict qkv_bias,
const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int head_num,
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
const float rotary_embedding_scale, PositionEmbeddingType const position_embedding_type)
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, float rotary_embedding_base,
RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions,
PositionEmbeddingType const position_embedding_type)
{
// This kernel adds bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head]. QKV is
// split into 3 buffers q, k, v, which are transposed to [batch_size, head_num, seq_len, size_per_head].
@ -1323,6 +1324,11 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf,
Vec_t q, k, v, zero;
Vec_t q_bias, k_bias, v_bias;
if (valid_seq)
{
mmha::update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale, rotary_scale_type,
rotary_embedding_dim, rotary_embedding_max_positions, actual_seq_len);
}
#pragma unroll
for (int i = 0; i < sizeof(Vec_t) / sizeof(uint32_t); i++)
@ -1456,14 +1462,16 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf,
#define FUSED_QKV_BIAS_ROTARY_TRANSPOSE_LAUNCH(T, ADD_BIAS, USING_CONTEXT_FMHA) \
add_fusedQKV_bias_transpose_kernel<T, ADD_BIAS, USING_CONTEXT_FMHA><<<grid, block, smem_size, stream>>>(q_buf, \
k_buf, v_buf, QKV, qkv_bias, seq_lens, padding_offset, batch_size, seq_len, head_num, kv_head_num, \
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, position_embedding_type);
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, \
rotary_embedding_max_positions, position_embedding_type);
template <typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, const int* seq_lens,
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
const int kv_head_num, const int size_per_head, const bool using_context_fmha, const int rotary_embedding_dim,
const float rotary_embedding_base, const float rotary_embedding_scale,
const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, cudaStream_t stream)
const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
const int int8_mode, cudaStream_t stream)
{
// [bs, seq_len, 3, head, Dh]
if (rotary_embedding_dim == 0)
@ -1535,7 +1543,8 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const
template void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, \
const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int token_num, \
const int head_num, const int kv_head_num, const int size_per_head, const bool using_context_fmha, \
const int rotary_embedding_dim, const float rotary_embedding_base, const float rotary_embedding_scale, \
const int rotary_embedding_dim, const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, \
const float rotary_embedding_scale, const int rotary_embedding_max_positions, \
const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, \
cudaStream_t stream)
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float);

View File

@ -76,8 +76,9 @@ template <typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, const int* seq_lens,
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
const int kv_head_num, const int size_per_head, const bool using_context_fmha, const int rotary_embedding_dim,
const float rotary_embedding_base, const float rotary_embedding_scale,
PositionEmbeddingType const position_embedding_type, const float* scale, const int int8_mode, cudaStream_t stream);
float rotary_embedding_base, const RotaryScalingType rotary_scale_type, float rotary_embedding_scale,
const int rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, const float* scale,
const int int8_mode, cudaStream_t stream);
template <typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, const int* seq_lens,
@ -92,12 +93,14 @@ template <typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const int* seq_lens,
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
const int kv_head_num, const int size_per_head, const bool using_context_fmha, const int rotary_embedding_dim,
const float rotary_embedding_base, const float rotary_embedding_scale,
PositionEmbeddingType const position_embedding_type, const float* scale, const int int8_mode, cudaStream_t stream)
float rotary_embedding_base, const RotaryScalingType rotary_scale_type, float rotary_embedding_scale,
const int rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, const float* scale,
const int int8_mode, cudaStream_t stream)
{
invokeAddFusedQKVBiasTranspose(q_buf, k_buf, v_buf, QKV, (const T*) nullptr, seq_lens, padding_offset, batch_size,
seq_len, token_num, head_num, kv_head_num, size_per_head, using_context_fmha, rotary_embedding_dim,
rotary_embedding_base, rotary_embedding_scale, position_embedding_type, scale, int8_mode, stream);
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
position_embedding_type, scale, int8_mode, stream);
}
template <typename T, typename KVCacheBuffer>
@ -105,5 +108,13 @@ void invokeTranspose4dBatchMajor(const T* k_src, const T* v_src, KVCacheBuffer&
const int seq_len, const int max_seq_len, const int size_per_head, const int local_head_num,
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream);
template <typename T, typename KVCacheBuffer>
void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens,
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,371 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Separate from unfusedAttentionKernel to accelerate compiling.
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
using namespace tensorrt_llm::common;
namespace tensorrt_llm
{
namespace kernels
{
template <typename T>
struct Vec_t
{
static constexpr int size = 0;
};
template <>
struct Vec_t<float>
{
using Type = float2;
static constexpr int size = 2;
};
template <>
struct Vec_t<half>
{
using Type = uint32_t;
static constexpr int size = 2;
};
#ifdef ENABLE_BF16
template <>
struct Vec_t<__nv_bfloat16>
{
using Type = __nv_bfloat162;
static constexpr int size = 2;
};
#endif
template <typename T, typename T_cache, bool ADD_BIAS, typename KVCacheBuffer>
__global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer, const T* __restrict qkv_bias,
const int* seq_lens, const int* padding_offset, const float* kvScaleOrigQuant, const int batch_size,
const int seq_len, const int head_num, const int kv_head_num, const int size_per_head,
const int rotary_embedding_dim, float rotary_embedding_base, RotaryScalingType const rotary_scale_type,
float rotary_embedding_scale, const int rotary_embedding_max_positions,
PositionEmbeddingType const position_embedding_type)
{
// This kernel adds bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head]. QKV is
// split into 3 buffers q, k, v, which are transposed to [batch_size, head_num, seq_len, size_per_head].
// For q and k, also apply the rotary embedding.
// NOTE:
// head_num == kv_head_num
// QKV src shape (batch_size, seq_len, 3, head_num, size_per_head)
// ^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^
// m n
// QKV dst shape (3, batch_size, head_num, seq_len, size_per_head)
// head_num != kv_head_num
// QKV src shape: (batch_size, seq_len, head_num * size_per_head + 2 * kv_head_num * size_per_head)
// ^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
// m n
// Q dst shape: (batch_size, head_num, seq_len, size_per_head)
// KV dst shape: (batch_size, kv_head_num, seq_len, size_per_head)
extern __shared__ __align__(sizeof(float2)) char smem_[]; // align on largest vector type
constexpr int vec_size = Vec_t<T>::size;
using Vec_t = typename Vec_t<T>::Type;
const int token_idx = blockIdx.x;
const bool has_padding = padding_offset == nullptr;
constexpr bool ENABLE_8BITS_CACHE = sizeof(T_cache) == 1;
constexpr int X_ELEMS = vec_size;
const int sizePerHeadDivX = size_per_head / X_ELEMS;
using T_dst = T_cache;
// The index of the token in the batch. It includes "virtual" padding (even if the input is not padded),
// so that the sequence index and the position in the sequence can be obtained from the max sequence length.
const int token_padding_offset = has_padding ? 0 : padding_offset[token_idx];
const int global_token_idx = token_idx + token_padding_offset;
const int batch_idx = global_token_idx / seq_len;
const int token_idx_in_seq = global_token_idx % seq_len;
const int actual_seq_len = seq_lens[batch_idx];
const bool valid_seq = token_idx_in_seq < actual_seq_len || !has_padding;
const int head_idx = blockIdx.y;
const int tidx = threadIdx.x;
const bool is_seq_masked = !valid_seq;
const bool is_head_size_masked = tidx * vec_size >= size_per_head;
const bool is_masked = is_head_size_masked || is_seq_masked;
const int hidden_size = head_num * size_per_head;
const int hidden_idx = head_idx * size_per_head + tidx * vec_size;
const int qheads_per_kv_head = head_num / kv_head_num;
const int kv_head_idx = head_idx / qheads_per_kv_head;
const int hidden_idx_kv = kv_head_idx * size_per_head + tidx * vec_size;
const int n = (head_num + 2 * kv_head_num) * size_per_head;
const int dst_kv_seq_idx = token_idx_in_seq;
const int src_k_offset = hidden_size;
const int src_v_offset = hidden_size + kv_head_num * size_per_head;
// NOTE: q has seq len excluding prefix prompt
// head_num == kv_head_num:
// src QKV: [batch, time, 3, head_num, size_per_head]
// head_num != kv_head_num:
// src QKV: [batch, time, head_num * size_per_head + 2 * kv_head_num * size_per_head]
const int src_q_idx = token_idx * n + hidden_idx;
const int src_k_idx = token_idx * n + src_k_offset + hidden_idx_kv;
const int src_v_idx = token_idx * n + src_v_offset + hidden_idx_kv;
Vec_t q, k, v, zero;
Vec_t q_bias, k_bias, v_bias;
if (valid_seq)
{
mmha::update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale, rotary_scale_type,
rotary_embedding_dim, rotary_embedding_max_positions, actual_seq_len);
}
#pragma unroll
for (int i = 0; i < sizeof(Vec_t) / sizeof(uint32_t); i++)
{
reinterpret_cast<uint32_t*>(&zero)[i] = 0u;
}
// load q,k,v and add bias
if (!is_masked)
{
q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
if (ADD_BIAS)
{
q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx_kv + src_k_offset]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx_kv + src_v_offset]);
q = mmha::add(q, q_bias);
k = mmha::add(k, k_bias);
v = mmha::add(v, v_bias);
}
}
switch (position_embedding_type)
{
case PositionEmbeddingType::kROPE_GPTJ:
{
mmha::apply_rotary_embedding(
q, k, tidx, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, dst_kv_seq_idx);
break;
}
case PositionEmbeddingType::kROPE_GPT_NEOX:
{
const bool do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim;
T* q_smem = reinterpret_cast<T*>(smem_);
T* k_smem = q_smem + rotary_embedding_dim;
const int half_rotary_dim = rotary_embedding_dim / 2;
const int half_idx = (tidx * vec_size) / half_rotary_dim;
const int intra_half_idx = (tidx * vec_size) % half_rotary_dim;
const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts?
if (do_rotary)
{
*reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
*reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
}
__syncthreads();
const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
constexpr int tidx_factor = vec_size / 2;
if (do_rotary)
{
mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, rotary_embedding_dim, rotary_embedding_base,
rotary_embedding_scale, dst_kv_seq_idx);
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
}
__syncthreads();
if (do_rotary)
{
q = *reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx);
k = *reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx);
}
break;
}
}
const int channelIdx{tidx};
auto kDst = reinterpret_cast<T_dst*>(kvCacheBuffer.getKBlockPtr(batch_idx, token_idx_in_seq));
auto vDst = reinterpret_cast<T_dst*>(kvCacheBuffer.getVBlockPtr(batch_idx, token_idx_in_seq));
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(token_idx_in_seq, kv_head_idx, sizePerHeadDivX, channelIdx);
if (!is_masked)
{
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q;
if ((head_num == kv_head_num) || (head_idx == (kv_head_idx * qheads_per_kv_head)))
{
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;
if (ENABLE_8BITS_CACHE)
{
inBlockIdx = inBlockIdx * vec_size;
// Cast float scale to dst data type.
using T_scale = typename mmha::kv_cache_scale_type_t<T, T_cache>::Type;
T_scale scaleOrigQuant;
mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]);
// Store 8bits kv cache.
mmha::store_8bits_kv_cache_vec(kDst, k, inBlockIdx, scaleOrigQuant);
mmha::store_8bits_kv_cache_vec(vDst, v, inBlockIdx, scaleOrigQuant);
}
else
{
reinterpret_cast<Vec_t*>(kDst)[inBlockIdx] = k;
reinterpret_cast<Vec_t*>(vDst)[inBlockIdx] = v;
}
}
}
else if (is_seq_masked && !is_head_size_masked)
{
// Set padding to zero to avoid potentially generating NaNs.
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = zero;
if ((head_num == kv_head_num) || (head_idx == (kv_head_idx * qheads_per_kv_head)))
{
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = zero;
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = zero;
if (ENABLE_8BITS_CACHE)
{
inBlockIdx = inBlockIdx * vec_size;
// Cast float scale to dst data type.
using T_scale = typename mmha::kv_cache_scale_type_t<T, T_cache>::Type;
T_scale scaleOrigQuant;
mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]);
// Store 8bits kv cache.
mmha::store_8bits_kv_cache_vec(kDst, zero, inBlockIdx, scaleOrigQuant);
mmha::store_8bits_kv_cache_vec(vDst, zero, inBlockIdx, scaleOrigQuant);
}
else
{
reinterpret_cast<Vec_t*>(kDst)[inBlockIdx] = zero;
reinterpret_cast<Vec_t*>(vDst)[inBlockIdx] = zero;
}
}
}
}
template <typename T, typename T_cache, typename KVCacheBuffer>
void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens,
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
const float* kvScaleOrigQuant, const int int8_mode, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with RoPE"); // TODO(mseznec)
// To implement rotary embeddings, each thread processes two QKV elems:
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
dim3 grid(token_num, head_num);
size_t smem_size
= (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX ? 2 * rotary_embedding_dim * sizeof(T) : 0);
// NOTE: add offset for rotary embedding
if (qkv_bias != nullptr)
{
applyBiasRopeUpdateKVCache<T, T_cache, true, KVCacheBuffer><<<grid, block, smem_size, stream>>>(QKV, kvTable,
qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, head_num, kv_head_num,
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
rotary_embedding_max_positions, position_embedding_type);
}
else
{
applyBiasRopeUpdateKVCache<T, T_cache, false, KVCacheBuffer><<<grid, block, smem_size, stream>>>(QKV, kvTable,
qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, head_num, kv_head_num,
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
rotary_embedding_max_positions, position_embedding_type);
}
}
template <typename T, typename KVCacheBuffer>
void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens,
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream)
{
// Block handles both K and V tile.
constexpr int x = (sizeof(T) == 4) ? 4 : 8;
TLLM_CHECK_WITH_INFO(size_per_head % x == 0, "Size per head is not a multiple of X");
if (cache_type == KvCacheDataType::INT8)
{
invokeApplyBiasRopeUpdateKVCacheDispatch<T, int8_t, KVCacheBuffer>(QKV, kvTable, qkv_bias, seq_lens,
padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim,
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
}
#ifdef ENABLE_FP8
else if (cache_type == KvCacheDataType::FP8)
{
invokeApplyBiasRopeUpdateKVCacheDispatch<T, __nv_fp8_e4m3, KVCacheBuffer>(QKV, kvTable, qkv_bias, seq_lens,
padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim,
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
}
#endif // ENABLE_FP8
else
{
invokeApplyBiasRopeUpdateKVCacheDispatch<T, T, KVCacheBuffer>(QKV, kvTable, qkv_bias, seq_lens, padding_offset,
batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim,
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
}
}
#define INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(T, KVCacheBuffer) \
template void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, \
const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int token_num, \
const int head_num, const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, \
const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, \
const float rotary_embedding_scale, const int rotary_embedding_max_positions, \
const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, \
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream)
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float, KVBlockArray);
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float, KVLinearBuffer);
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(half, KVBlockArray);
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(half, KVLinearBuffer);
#ifdef ENABLE_BF16
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(__nv_bfloat16, KVBlockArray);
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(__nv_bfloat16, KVLinearBuffer);
#endif
#undef INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE
} // namespace kernels
} // namespace tensorrt_llm
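For orientation, an illustrative computation (not part of this commit) of the launch geometry chosen by invokeApplyBiasRopeUpdateKVCacheDispatch above, assuming half precision with size_per_head = 128 and rotary_embedding_dim = 128.

```
// Hypothetical sketch only: reproduces the block/grid/smem arithmetic used above.
#include <cstddef>
#include <cstdio>

int main()
{
    const int sizePerHead = 128;
    const int vecSize = 2;                  // Vec_t<half>::size
    const int rotaryEmbeddingDim = 128;
    const int sizeofT = 2;                  // sizeof(half)

    // One thread handles vec_size elements of a head, rounded up to whole warps.
    const int blockDimX = (sizePerHead / vecSize + 31) / 32 * 32;        // -> 64 threads
    // The grid is (token_num, head_num); shared memory is only needed for GPT-NeoX style RoPE.
    const std::size_t smemBytes = 2 * rotaryEmbeddingDim * sizeofT;      // -> 512 bytes

    std::printf("blockDim.x=%d smem=%zu bytes\n", blockDimX, smemBytes);
    return 0;
}
```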

View File

@ -65,8 +65,8 @@ void update_indir_cache_kernelLauncher(int* tgt_indir_cache, const int* src_indi
template <typename T>
BaseBeamSearchLayer<T>::BaseBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward)
: BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr)
IAllocator* allocator, bool is_free_buffer_after_forward)
: BaseLayer(stream, allocator, is_free_buffer_after_forward, nullptr)
, vocab_size_(vocab_size)
, vocab_size_padded_(vocab_size_padded)
{

View File

@ -42,8 +42,8 @@ class BaseBeamSearchLayer : public BaseLayer
public:
using SetupParams = DecodingSetupParams;
BaseBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
tc::cublasMMWrapper* cublas_wrapper, tc::IAllocator* allocator, bool is_free_buffer_after_forward);
BaseBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, tc::IAllocator* allocator,
bool is_free_buffer_after_forward);
BaseBeamSearchLayer(BaseBeamSearchLayer<T> const& beam_search_layer);

View File

@ -17,7 +17,6 @@
#pragma once
#include "tensorrt_llm/common/allocator.h"
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/tensor.h"
namespace tensorrt_llm
@ -28,11 +27,9 @@ namespace layers
class BaseLayer
{
public:
BaseLayer(cudaStream_t stream, tensorrt_llm::common::cublasMMWrapper* cublas_wrapper,
tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward,
BaseLayer(cudaStream_t stream, tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop = nullptr)
: stream_(stream)
, cublas_wrapper_(cublas_wrapper)
, allocator_(allocator)
, cuda_device_prop_(cuda_device_prop)
, is_free_buffer_after_forward_(is_free_buffer_after_forward){};
@ -51,7 +48,6 @@ public:
protected:
// device environments
cudaStream_t stream_;
tensorrt_llm::common::cublasMMWrapper* cublas_wrapper_;
tensorrt_llm::common::IAllocator* allocator_;
cudaDeviceProp* cuda_device_prop_ = nullptr;

View File

@ -69,9 +69,8 @@ void BaseSamplingLayer<T>::freeBuffer()
template <typename T>
BaseSamplingLayer<T>::BaseSamplingLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop)
: BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop)
IAllocator* allocator, bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop)
: BaseLayer(stream, allocator, is_free_buffer_after_forward, cuda_device_prop)
, vocab_size_(vocab_size)
, vocab_size_padded_(vocab_size_padded)
{

View File

@ -36,8 +36,8 @@ class BaseSamplingLayer : public BaseLayer
{
public:
BaseSamplingLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
tensorrt_llm::common::cublasMMWrapper* cublas_wrapper, tensorrt_llm::common::IAllocator* allocator,
bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop);
tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop);
BaseSamplingLayer(BaseSamplingLayer const& sampling_layer);

View File

@ -37,20 +37,18 @@ void DynamicDecodeLayer<T>::initialize()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
mOnlineBeamsearchDecode = std::make_unique<OnlineBeamSearchLayer<T>>(
vocab_size_, vocab_size_padded_, stream_, cublas_wrapper_, allocator_, is_free_buffer_after_forward_);
vocab_size_, vocab_size_padded_, stream_, allocator_, is_free_buffer_after_forward_);
mTopKDecode = std::make_unique<TopKSamplingLayer<T>>(
vocab_size_, vocab_size_padded_, stream_, cublas_wrapper_, allocator_, false);
mTopKDecode = std::make_unique<TopKSamplingLayer<T>>(vocab_size_, vocab_size_padded_, stream_, allocator_, false);
mTopPDecode = std::make_unique<TopPSamplingLayer<T>>(
vocab_size_, vocab_size_padded_, stream_, cublas_wrapper_, allocator_, false, cuda_device_prop_);
vocab_size_, vocab_size_padded_, stream_, allocator_, false, cuda_device_prop_);
}
template <typename T>
DynamicDecodeLayer<T>::DynamicDecodeLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop)
: BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward)
IAllocator* allocator, bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop)
: BaseLayer(stream, allocator, is_free_buffer_after_forward)
, vocab_size_(vocab_size)
, vocab_size_padded_(vocab_size_padded)
, cuda_device_prop_(cuda_device_prop)

View File

@ -43,9 +43,8 @@ template <typename T>
class DynamicDecodeLayer : public BaseLayer
{
public:
DynamicDecodeLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
tc::cublasMMWrapper* cublas_wrapper, tc::IAllocator* allocator, bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop);
DynamicDecodeLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, tc::IAllocator* allocator,
bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop);
~DynamicDecodeLayer() override;
DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_decode_layer);

View File

@ -154,9 +154,8 @@ void OnlineBeamSearchLayer<T>::allocateBuffer(size_t batch_size, size_t beam_wid
template <typename T>
OnlineBeamSearchLayer<T>::OnlineBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward)
: BaseBeamSearchLayer<T>(
vocab_size, vocab_size_padded, stream, cublas_wrapper, allocator, is_free_buffer_after_forward)
IAllocator* allocator, bool is_free_buffer_after_forward)
: BaseBeamSearchLayer<T>(vocab_size, vocab_size_padded, stream, allocator, is_free_buffer_after_forward)
{
}

View File

@ -42,8 +42,8 @@ public:
std::optional<float> length_penalty;
};
OnlineBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
tc::cublasMMWrapper* cublas_wrapper, tc::IAllocator* allocator, bool is_free_buffer_after_forward);
OnlineBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, tc::IAllocator* allocator,
bool is_free_buffer_after_forward);
OnlineBeamSearchLayer(OnlineBeamSearchLayer<T> const& beam_search_layer);

View File

@ -203,9 +203,8 @@ void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
template <typename T>
TopKSamplingLayer<T>::TopKSamplingLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward)
: BaseSamplingLayer<T>(
vocab_size, vocab_size_padded, stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr)
IAllocator* allocator, bool is_free_buffer_after_forward)
: BaseSamplingLayer<T>(vocab_size, vocab_size_padded, stream, allocator, is_free_buffer_after_forward, nullptr)
{
}

View File

@ -34,8 +34,7 @@ public:
using SetupParams = typename Base::SetupParams;
TopKSamplingLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
tensorrt_llm::common::cublasMMWrapper* cublas_wrapper, tensorrt_llm::common::IAllocator* allocator,
bool is_free_buffer_after_forward);
tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward);
TopKSamplingLayer(TopKSamplingLayer<T> const& top_k_sampling_layer);
~TopKSamplingLayer();

View File

@ -266,10 +266,9 @@ void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
template <typename T>
TopPSamplingLayer<T>::TopPSamplingLayer(std::size_t vocab_size, std::size_t vocab_size_padded, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop)
: BaseSamplingLayer<T>(vocab_size, vocab_size_padded, stream, cublas_wrapper, allocator,
is_free_buffer_after_forward, cuda_device_prop)
IAllocator* allocator, bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop)
: BaseSamplingLayer<T>(
vocab_size, vocab_size_padded, stream, allocator, is_free_buffer_after_forward, cuda_device_prop)
{
}

View File

@ -42,8 +42,8 @@ public:
};
TopPSamplingLayer(std::size_t vocab_size, std::size_t vocab_size_padded, cudaStream_t stream,
tensorrt_llm::common::cublasMMWrapper* cublas_wrapper, tensorrt_llm::common::IAllocator* allocator,
bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop);
tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop);
TopPSamplingLayer(TopPSamplingLayer<T> const& top_p_sampling_layer);
~TopPSamplingLayer();

View File

@ -19,7 +19,8 @@ set(PLUGIN_TARGET_NAME nvinfer_plugin_tensorrt_llm)
set(PLUGIN_SHARED_TARGET ${PLUGIN_TARGET_NAME})
set(TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PLUGIN_EXPORT_MAP ${TARGET_DIR}/exports.map)
set(PLUGIN_EXPORT_MAP ${TARGET_DIR}/exports.map) # Linux
set(PLUGIN_EXPORT_DEF ${TARGET_DIR}/exports.def) # Windows
if(${CMAKE_BUILD_TYPE} MATCHES "Debug")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
@ -89,7 +90,11 @@ set_target_properties(
LIBRARY_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
RUNTIME_OUTPUT_DIRECTORY "${TRT_OUT_DIR}")
if(NOT MSVC)
if(MSVC)
set_target_properties(
${PLUGIN_SHARED_TARGET}
PROPERTIES LINK_FLAGS "/DEF:${PLUGIN_EXPORT_DEF} ${UNDEFINED_FLAG}")
else()
set_target_properties(
${PLUGIN_SHARED_TARGET}
PROPERTIES

View File

@ -225,7 +225,8 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
// Padding offset = nullptr here (remove padding is not supported).
invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, const_cast<T*>(attention_input), input_lengths,
nullptr, request_batch_size, request_seq_len, batch_size * input_seq_len, mNumHeads, mNumHeads, mHeadSize,
mEnableContextFMHA, 0, 0.0f, 0.0f, PositionEmbeddingType::kLEARNED_ABSOLUTE, (float*) nullptr, 0, stream);
mEnableContextFMHA, 0, 0.0f, RotaryScalingType::kNONE, 0.0f, 0, PositionEmbeddingType::kLEARNED_ABSOLUTE,
(float*) nullptr, 0, stream);
const auto gemm_data_type = tc::CudaDataType<T>::value;
const int attention_seq_len_1 = request_seq_len; // q length
@ -363,13 +364,10 @@ int BertAttentionPlugin::initialize() noexcept
{
auto cublasHandle = getCublasHandle();
auto cublasLtHandle = getCublasLtHandle();
mCublasAlgoMap = new tc::cublasAlgoMap(GEMM_CONFIG);
mCublasWrapperMutex = new std::mutex();
mCublasWrapper
= new tc::cublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, mCublasAlgoMap, mCublasWrapperMutex, nullptr);
mCublasWrapper.reset(new tc::CublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, nullptr));
if (mEnableContextFMHA)
{
mFMHARunner = new FusedMHARunnerV2(DATA_TYPE_FP16, mNumHeads, mHeadSize, mQScaling);
mFMHARunner.reset(new FusedMHARunnerV2(DATA_TYPE_FP16, mNumHeads, mHeadSize, mQScaling));
// set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads = num_heads
mFMHARunner->setup_flags(mFMHAForceFP32Acc, true, false, mNumHeads);
}
@ -379,18 +377,6 @@ int BertAttentionPlugin::initialize() noexcept
void BertAttentionPlugin::destroy() noexcept
{
delete mCublasAlgoMap;
delete mCublasWrapperMutex;
delete mCublasWrapper;
if (mEnableContextFMHA)
{
delete mFMHARunner;
}
mCublasAlgoMap = nullptr;
mCublasWrapperMutex = nullptr;
mCublasWrapper = nullptr;
mFMHARunner = nullptr;
delete this;
}

View File

@ -88,11 +88,10 @@ private:
bool mEnableContextFMHA = false;
bool mFMHAForceFP32Acc = false;
bool mSM = tensorrt_llm::common::getSMVersion();
tensorrt_llm::kernels::MHARunner* mFMHARunner;
tensorrt_llm::common::cublasAlgoMap* mCublasAlgoMap;
std::mutex* mCublasWrapperMutex;
tensorrt_llm::common::cublasMMWrapper* mCublasWrapper;
// The default copy constructor will leave them as nullptr. clone() shall initialize them.
UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
};
class BertAttentionPluginCreator : public BaseCreator

View File

@ -0,0 +1,282 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/plugins/common/gemmPluginProfiler.h"
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h"
namespace tensorrt_llm::plugins
{
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::GemmPluginProfiler()
{
mMNKProfileMap = std::make_shared<MNKProfileMap>();
// set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings
const auto skipEnv = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS");
mSkip = (skipEnv != NULL && std::stoi(skipEnv));
if (mSkip)
{
TLLM_LOG_DEBUG(
"SKIP_GEMM_PLUGIN_PROFILINGS is set. Skipping GEMM plugin profilings. It could result in runtime error "
"if default tactic is not defined.");
}
}
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::serialize(
char*& buffer, const GemmIdType& gemmId) const
{
auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId);
// Save number of profiles for given GEMM ID
write(buffer, static_cast<int>(mProfileMap->size()));
for (const auto& pair : *mProfileMap)
{
// Save pair of M to the best GEMM config
write(buffer, pair);
}
}
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::deserialize(
const char*& data, GemmDims& dims, const GemmIdType& gemmId)
{
// NOTE(nkorobov): this mutex is not needed since each thread owns its private map, but we keep it here for
// consistency
writer_lock lock(mMNKProfileMap->mutex);
mDims = dims;
// GemmId gemmId(dims.n, dims.k);
if (!mMNKProfileMap->existsMProfileMap(gemmId))
{
// Create GEMM with GEMM ID if it does not exist
mMNKProfileMap->createMProfileMap(gemmId);
}
// Populate map with profiles of GEMM ID
auto profileMap = mMNKProfileMap->getMProfileMap(gemmId);
int selectedMapSize;
read(data, selectedMapSize);
for (int ii = 0; ii < selectedMapSize; ++ii)
{
std::pair<int, std::optional<Config>> config;
read(data, config);
profileMap->insert(config);
}
}
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
size_t GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::getSerializationSize(
const GemmIdType& gemmId) const
{
reader_lock lock(mMNKProfileMap->mutex);
return sizeof(int) + // size of the tactics map
mMNKProfileMap->getMProfileMap(gemmId)->size()
* sizeof(std::pair<int, std::optional<Config>>); // size of the tactics map
}
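
As a reading aid, here is a minimal standalone sketch of the byte layout implied by `serialize`/`deserialize`/`getSerializationSize`, assuming the `write`/`read` helpers (defined in the plugin common utilities, not shown in this hunk) are plain `memcpy`-style copies of trivially copyable values; `DummyConfig` is a hypothetical stand-in for the real GEMM config type:

```
#include <cstring>
#include <map>
#include <optional>
#include <vector>

// Assumed memcpy-style helpers mirroring the plugin's write/read utilities.
template <typename T>
void write(char*& buffer, const T& val)
{
    std::memcpy(buffer, &val, sizeof(T));
    buffer += sizeof(T);
}

template <typename T>
void read(const char*& buffer, T& val)
{
    std::memcpy(&val, buffer, sizeof(T));
    buffer += sizeof(T);
}

struct DummyConfig // stand-in for a real GEMM config; must stay trivially copyable
{
    int tactic;
};

int main()
{
    std::map<int, std::optional<DummyConfig>> profileMap{{16, DummyConfig{3}}, {32, std::nullopt}};

    // Layout: [int count][count x pair<int, optional<Config>>], matching getSerializationSize().
    std::vector<char> storage(sizeof(int) + profileMap.size() * sizeof(std::pair<int, std::optional<DummyConfig>>));

    char* w = storage.data();
    write(w, static_cast<int>(profileMap.size()));
    for (const auto& entry : profileMap)
    {
        write(w, entry); // each entry is one M mapped to its best (optional) config
    }

    const char* r = storage.data();
    int count = 0;
    read(r, count);
    std::map<int, std::optional<DummyConfig>> restored;
    for (int i = 0; i < count; ++i)
    {
        std::pair<int, std::optional<DummyConfig>> config;
        read(r, config);
        restored.insert(config);
    }
    return restored.size() == profileMap.size() ? 0 : 1;
}
```
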
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileTactics(
const RunnerPtr& runner, const nvinfer1::DataType& type, const GemmDims& dims, const GemmIdType& gemmId)
{
writer_lock lock(mMNKProfileMap->mutex);
if (!dims.isInitialized())
{
return;
}
mRunner = runner;
mType = type;
const int maxM = std::min(nextPowerOfTwo(dims.maxM), MAX_PROFILE_M);
computeTmpSize(maxM, dims.n, dims.k);
if (!mMNKProfileMap->existsMProfileMap(gemmId))
{
// Create map for GEMM ID
mMNKProfileMap->createMProfileMap(gemmId);
}
if (mSkip)
{
return;
}
auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId);
auto profileTactics = [&mProfileMap, this](int m, int n, int k)
{
if (mProfileMap->count(m) == 0)
{
const auto tactics = this->getTactics(m, n, k);
// Profile different tactics for particular m and insert best config to the map
mProfileMap->insert({m, this->profileTacticsForProblem(m, n, k, tactics)});
}
};
// Allocate tmp data to run GEMMs
allocateTmpData();
const int startMinMRounded = nextPowerOfTwo(dims.minM);
for (int m = startMinMRounded; m < maxM; m *= 2)
{
profileTactics(m, dims.n, dims.k);
}
profileTactics(maxM, dims.n, dims.k);
// Free tmp data
freeTmpData();
}
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
std::optional<Config> GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::getBestConfig(
int m, const GemmIdType& gemmId) const
{
reader_lock lock(mMNKProfileMap->mutex);
if (mSkip)
{
return std::nullopt;
}
const int mRounded = std::min(nextPowerOfTwo(m), MAX_PROFILE_M);
return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded);
}
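
For intuition about the bucketing above, a small self-contained sketch of how `profileTactics` fills power-of-two M buckets and how `getBestConfig` rounds a runtime M up to one of them; the `nextPowerOfTwo` body and the `MAX_PROFILE_M` value below are assumptions standing in for the definitions that live elsewhere in this header:

```
#include <algorithm>
#include <cstdio>

// Assumed to mirror the profiler's nextPowerOfTwo helper (defined elsewhere in the header).
static int nextPowerOfTwo(int v)
{
    --v;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    return ++v;
}

int main()
{
    constexpr int MAX_PROFILE_M = 8192; // placeholder cap standing in for the profiler's constant
    const int minM = 3;
    const int maxM = 100;

    // profileTactics() profiles power-of-two buckets from nextPowerOfTwo(minM) up to the rounded maxM.
    const int maxProfiled = std::min(nextPowerOfTwo(maxM), MAX_PROFILE_M);
    for (int m = nextPowerOfTwo(minM); m < maxProfiled; m *= 2)
    {
        std::printf("profiled bucket m=%d\n", m); // 4, 8, 16, 32, 64
    }
    std::printf("profiled bucket m=%d\n", maxProfiled); // 128

    // getBestConfig() rounds a runtime M up to the bucket that was profiled.
    const int runtimeM = 37;
    std::printf("runtime M=%d looks up bucket m=%d\n", runtimeM, std::min(nextPowerOfTwo(runtimeM), MAX_PROFILE_M));
    return 0;
}
```
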
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::allocateTmpData()
{
TLLM_CHECK_WITH_INFO(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0");
const auto status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes);
TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling.");
}
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::freeTmpData()
{
const auto status = cudaFree(mWorkspaceTmp);
TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling.");
}
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
std::optional<Config> GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileTacticsForProblem(
int m, int n, int k, const std::vector<Config>& tactics)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
float bestTime = std::numeric_limits<float>::max();
Config bestConfig;
bool foundOne = false;
// Iterate over all tactics for given M, N and K
for (int ii = 0; ii < tactics.size(); ++ii)
{
const Config& candidateConfig = tactics[ii];
float time = std::numeric_limits<float>::max();
try
{
if (!checkTactic(m, n, k, candidateConfig))
{
continue;
}
// Profile a particular tactic for the given M, N and K
time = profileTacticForProblem(m, n, k, candidateConfig);
foundOne = true;
}
catch (const std::exception& e)
{
std::ostringstream msg;
msg << "Cannot profile configuration " << ii << " (for"
<< " m=" << m << ", n=" << n << ", k=" << k << "). Skipped";
TLLM_LOG_WARNING(msg.str());
continue;
}
// Choose the fastest tactic
if (time < bestTime)
{
bestConfig = candidateConfig;
bestTime = time;
}
}
if (!foundOne)
{
std::ostringstream msg;
msg << "Have not found any valid GEMM config for shape ("
<< "m=" << m << ", n=" << n << ", k=" << k << "). Will try to use default or fail at runtime";
TLLM_LOG_WARNING(msg.str());
return std::nullopt;
}
return {bestConfig};
}
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
float GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileTacticForProblem(
int m, int n, int k, const Config& tactic)
{
constexpr int warmup = 5;
constexpr int runs = 10;
cudaStream_t stream = cudaStreamDefault;
// Warmup the execution
for (int i = 0; i < warmup; ++i)
{
runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
}
cudaEvent_t start;
cudaEvent_t stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaDeviceSynchronize();
cudaEventRecord(start, 0);
// Profile GEMM
for (int i = 0; i < runs; ++i)
{
runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
}
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
float elapsed;
cudaEventElapsedTime(&elapsed, start, stop);
cudaEventDestroy(start);
cudaEventDestroy(stop);
return elapsed / runs;
}
template class GemmPluginProfiler<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassInt8GemmRunnerInterface>, GemmIdCore,
GemmIdCoreHash>;
template class GemmPluginProfiler<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface>, GemmIdCore,
GemmIdCoreHash>;
template class GemmPluginProfiler<cublasLtMatmulHeuristicResult_t,
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper>, GemmIdCublas, GemmIdCublasHash>;
} // namespace tensorrt_llm::plugins

View File

@ -88,7 +88,7 @@ public:
bool operator==(const GemmIdCore& id) const
{
return n == id.n && k == id.k && dtype == id.dtype;
return isEqual(id);
}
friend std::ostream& operator<<(std::ostream& out, const GemmIdCore& id)
@ -97,6 +97,12 @@ public:
out << " type=" << static_cast<int>(id.dtype);
return out;
}
protected:
bool isEqual(const GemmIdCore& id) const
{
return n == id.n && k == id.k && dtype == id.dtype;
}
};
// Hash of GemmId
@ -111,6 +117,50 @@ struct GemmIdCoreHash
}
};
class GemmIdCublas : public GemmIdCore
{
public:
bool transA{};
bool transB{};
GemmIdCublas(int n_, int k_, const nvinfer1::DataType& dtype_, bool transA_, bool transB_)
: GemmIdCore(n_, k_, dtype_)
, transA(transA_)
, transB(transB_)
{
}
GemmIdCublas() {}
bool operator==(const GemmIdCublas& id) const
{
return isEqual(id) && transA == id.transA && transB == id.transB;
}
friend std::ostream& operator<<(std::ostream& out, const GemmIdCublas& id)
{
out << "(N;K)=(" << id.n << ";" << id.k << "),";
out << " type=" << static_cast<int>(id.dtype);
out << " transA=" << id.transA;
out << " transB=" << id.transB;
return out;
}
};
// Hash of GemmIdCublas
struct GemmIdCublasHash
{
std::size_t operator()(const GemmIdCublas& id) const
{
auto h1 = std::hash<int>{}(id.n);
auto h2 = std::hash<int>{}(id.k);
auto h3 = std::hash<int>{}(static_cast<int>(id.dtype));
auto h4 = std::hash<bool>{}(id.transA);
auto h5 = std::hash<bool>{}(id.transB);
return h1 ^ h2 ^ h3 ^ h4 ^ h5;
}
};
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
class GemmPluginProfiler
{
@ -160,118 +210,15 @@ public:
using MNKProfileMapPtr = std::shared_ptr<MNKProfileMap>;
GemmPluginProfiler()
{
mMNKProfileMap = std::make_shared<MNKProfileMap>();
GemmPluginProfiler();
// set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings
const auto skip = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS");
mSkip = (skip != NULL && std::stoi(skip));
if (mSkip)
{
TLLM_LOG_DEBUG(
"SKIP_GEMM_PLUGIN_PROFILINGS is set. Skipping GEMM plugin profilings. It could result in runtime error "
"if default tactic is not defined.");
}
}
void serialize(char*& buffer, const GemmIdType& gemmId) const;
void serialize(char* buffer, const GemmIdType& gemmId) const
{
auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId);
void deserialize(const char*& data, GemmDims& dims, const GemmIdType& gemmId);
size_t getSerializationSize(const GemmIdType& gemmId) const;
// Save number of profiles for given GEMM ID
write(buffer, static_cast<int>(mProfileMap->size()));
for (const auto& pair : *mProfileMap)
{
// Save pair of M to the best GEMM config
write(buffer, pair);
}
}
void deserialize(const char*& data, GemmDims& dims, const GemmIdType& gemmId)
{
// NOTE(nkorobov): this mutex is not needed since each thread owns its own map, but we keep it here for
// consistency
writer_lock lock(mMNKProfileMap->mutex);
mDims = dims;
// GemmId gemmId(dims.n, dims.k);
if (!mMNKProfileMap->existsMProfileMap(gemmId))
{
// Create GEMM with GEMM ID if it does not exist
mMNKProfileMap->createMProfileMap(gemmId);
}
// Populate map with profiles of GEMM ID
auto profileMap = mMNKProfileMap->getMProfileMap(gemmId);
int selectedMapSize;
read(data, selectedMapSize);
for (int ii = 0; ii < selectedMapSize; ++ii)
{
std::pair<int, std::optional<Config>> config;
read(data, config);
profileMap->insert(config);
}
}
size_t getSerializationSize(const GemmIdType& gemmId) const
{
reader_lock lock(mMNKProfileMap->mutex);
return sizeof(int) + // size of the tactics map
mMNKProfileMap->getMProfileMap(gemmId)->size()
* sizeof(std::pair<int, std::optional<Config>>); // size of the tactics map
}
void profileTactics(const std::vector<Config>& tactics, const RunnerPtr& runner, const nvinfer1::DataType& type,
const GemmDims& dims, const GemmIdType& gemmId)
{
writer_lock lock(mMNKProfileMap->mutex);
if (!dims.isInitialized())
{
return;
}
mRunner = runner;
mType = type;
const int maxM = std::min(nextPowerOfTwo(dims.maxM), MAX_PROFILE_M);
computeTmpSize(maxM, dims.n, dims.k);
if (!mMNKProfileMap->existsMProfileMap(gemmId))
{
// Create map for GEMM ID
mMNKProfileMap->createMProfileMap(gemmId);
}
if (mSkip)
{
return;
}
auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId);
auto profileTactics = [&tactics, &mProfileMap, this](int m, int n, int k)
{
if (mProfileMap->count(m) == 0)
{
// Profile different tactics for particular m and insert best config to the map
mProfileMap->insert({m, this->profileTacticsForProblem(m, n, k, tactics)});
}
};
// Allocate tmp data to run GEMMs
allocateTmpData();
const int startMinMRounded = nextPowerOfTwo(dims.minM);
for (int m = startMinMRounded; m < maxM; m *= 2)
{
profileTactics(m, dims.n, dims.k);
}
profileTactics(maxM, dims.n, dims.k);
// Free tmp data
freeTmpData();
}
void profileTactics(
const RunnerPtr& runner, const nvinfer1::DataType& type, const GemmDims& dims, const GemmIdType& gemmId);
void setSelectionTactics(const MNKProfileMapPtr& map)
{
@ -283,19 +230,13 @@ public:
mTmpWorkspaceSizeInBytes = bytes;
}
std::optional<Config> getBestConfig(int m, const GemmIdType& gemmId) const
void setSkip(bool skip)
{
reader_lock lock(mMNKProfileMap->mutex);
if (mSkip)
{
return std::nullopt;
}
const int mRounded = std::min(nextPowerOfTwo(m), MAX_PROFILE_M);
return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded);
mSkip = mSkip || skip;
}
std::optional<Config> getBestConfig(int m, const GemmIdType& gemmId) const;
protected:
virtual void runTactic(int m, int n, int k, const Config& tactic, char* workspace, const cudaStream_t& stream) = 0;
@ -306,108 +247,16 @@ protected:
return true;
}
virtual std::vector<Config> getTactics(int m, int n, int k) const = 0;
private:
void allocateTmpData()
{
TLLM_CHECK_WITH_INFO(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0");
const auto status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes);
TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling.");
}
void allocateTmpData();
void freeTmpData()
{
const auto status = cudaFree(mWorkspaceTmp);
TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling.");
}
void freeTmpData();
std::optional<Config> profileTacticsForProblem(int m, int n, int k, const std::vector<Config>& tactics)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
std::optional<Config> profileTacticsForProblem(int m, int n, int k, const std::vector<Config>& tactics);
float bestTime = std::numeric_limits<float>::max();
Config bestConfig;
bool foundOne = false;
// Iterate over all tactics for given M, N and K
for (int ii = 0; ii < tactics.size(); ++ii)
{
const Config& candidateConfig = tactics[ii];
float time = std::numeric_limits<float>::max();
try
{
if (!checkTactic(m, n, k, candidateConfig))
{
continue;
}
// Profile a particular tactic for the given M, N and K
time = profileTacticForProblem(m, n, k, candidateConfig);
foundOne = true;
}
catch (const std::exception& e)
{
std::ostringstream msg;
msg << "Cannot profile configuration " << ii << " (for"
<< " m=" << m << ", n=" << n << ", k=" << k << "). Skipped";
TLLM_LOG_WARNING(msg.str());
continue;
}
// Choose the fastest tactic
if (time < bestTime)
{
bestConfig = candidateConfig;
bestTime = time;
}
}
if (!foundOne)
{
std::ostringstream msg;
msg << "Have not found any valid GEMM config for shape ("
<< "m=" << m << ", n=" << n << ", k=" << k << "). Will try to use default or fail at runtime";
TLLM_LOG_WARNING(msg.str());
return std::nullopt;
}
return {bestConfig};
}
float profileTacticForProblem(int m, int n, int k, const Config& tactic)
{
constexpr int warmup = 5;
constexpr int runs = 10;
cudaStream_t stream = cudaStreamDefault;
// Warmup the execution
for (int i = 0; i < warmup; ++i)
{
runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
}
cudaEvent_t start;
cudaEvent_t stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaDeviceSynchronize();
cudaEventRecord(start, 0);
// Profile GEMM
for (int i = 0; i < runs; ++i)
{
runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
}
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
float elapsed;
cudaEventElapsedTime(&elapsed, start, stop);
cudaEventDestroy(start);
cudaEventDestroy(stop);
return elapsed / runs;
}
float profileTacticForProblem(int m, int n, int k, const Config& tactic);
int nextPowerOfTwo(int v) const
{
@ -450,9 +299,10 @@ public:
mMNKProfileMap = std::make_shared<MNKProfileMap>();
}
GemmPluginProfilerPtr createGemmPluginProfiler(bool inference)
GemmPluginProfilerPtr createGemmPluginProfiler(bool inference, bool skip = false)
{
auto profiler = std::make_shared<GemmPluginProfilerType>();
profiler->setSkip(skip);
// If the profiler is created during the engine build,
// mMNKProfileMap is shared between different profilers to minimize the time spent on the profiling
// and do not repeat profiling for the GEMMs of the same shape.

View File

@ -119,6 +119,29 @@ int8_t* nextWorkspacePtrWithAlignment(int8_t* ptr, uintptr_t previousWorkspaceSi
size_t calculateTotalWorkspaceSize(size_t* workspaces, int count, const uintptr_t alignment = kCudaMemAlign);
// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members.
// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and
// your clone() implementation is responsible for initializing such data members.
// With this we can simplify the clone() implementation when there are many data members including at least one unique_ptr.
template <typename T, typename Del = std::default_delete<T>>
class UniqPtrWNullCopy : public std::unique_ptr<T, Del>
{
public:
using std::unique_ptr<T, Del>::unique_ptr;
// for compatibility with std::make_unique
explicit UniqPtrWNullCopy(std::unique_ptr<T, Del>&& src)
: std::unique_ptr<T, Del>::unique_ptr{std::move(src)}
{
}
// copy constructor produces nullptr
UniqPtrWNullCopy(UniqPtrWNullCopy const&)
: std::unique_ptr<T, Del>::unique_ptr{}
{
}
};
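
To make the copy semantics concrete, here is a minimal standalone sketch (a hypothetical `FooPlugin` and `Runner`, not part of this commit; the helper is re-declared in trimmed form with an explicit default constructor so the snippet compiles on its own) showing how the default copy constructor leaves the member empty and `clone()` re-initializes it:

```
#include <iostream>
#include <memory>

// Trimmed stand-in for the UniqPtrWNullCopy helper above: copying yields nullptr.
template <typename T, typename Del = std::default_delete<T>>
class UniqPtrWNullCopy : public std::unique_ptr<T, Del>
{
public:
    UniqPtrWNullCopy() = default;

    UniqPtrWNullCopy(UniqPtrWNullCopy const&)
        : std::unique_ptr<T, Del>{}
    {
    }
};

struct Runner
{
    int value;
};

// Hypothetical plugin-like class. The defaulted copy constructor copies plain
// members member-wise but leaves mRunner empty, so clone() re-creates the owned state.
class FooPlugin
{
public:
    explicit FooPlugin(int v)
    {
        initialize(v);
    }

    FooPlugin(FooPlugin const&) = default; // mRunner becomes nullptr in the copy

    FooPlugin* clone() const
    {
        auto* copy = new FooPlugin(*this);              // shallow copy: copy->mRunner == nullptr
        copy->initialize(mRunner ? mRunner->value : 0); // clone() re-initializes the owned state
        return copy;
    }

    void initialize(int v)
    {
        mRunner.reset(new Runner{v});
    }

private:
    UniqPtrWNullCopy<Runner> mRunner;
};

int main()
{
    FooPlugin original(7);
    std::unique_ptr<FooPlugin> copy(original.clone());
    std::cout << "clone is usable after re-initialization" << std::endl;
    return 0;
}
```
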
} // namespace tensorrt_llm::plugins
inline bool isBuilding()
@ -128,17 +151,6 @@ inline bool isBuilding()
return val != nullptr && std::string(val) == "1";
}
#define MPICHECK(cmd) \
do \
{ \
int e = cmd; \
if (e != MPI_SUCCESS) \
{ \
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
exit(EXIT_FAILURE); \
} \
} while (0)
#if ENABLE_MULTI_DEVICE
#define NCCLCHECK(cmd) \
do \

View File

@ -0,0 +1,19 @@
; SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
; SPDX-License-Identifier: Apache-2.0
;
; Licensed under the Apache License, Version 2.0 (the "License");
; you may not use this file except in compliance with the License.
; You may obtain a copy of the License at
;
; http://www.apache.org/licenses/LICENSE-2.0
;
; Unless required by applicable law or agreed to in writing, software
; distributed under the License is distributed on an "AS IS" BASIS,
; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
; See the License for the specific language governing permissions and
; limitations under the License.
LIBRARY nvinfer_plugin_tensorrt_llm
EXPORTS
getPluginRegistry
initLibNvInferPlugins

View File

@ -47,15 +47,17 @@ void runGemm(const int M, const int N, const int K, const bool transA, const boo
const CublasGemmWrapperPtr& cublasWrapperPtr, const void* act, const void* weight, void* output,
const std::optional<cublasLtMatmulHeuristicResult_t>& heuristic, void* workspace, cudaStream_t stream)
{
auto cublasHandle = cublasWrapperPtr->getCublasHandle();
TLLM_CUDA_CHECK(cublasSetStream(cublasHandle, stream));
cublasWrapperPtr->setStream(stream);
cublasWrapperPtr->setWorkspace(workspace);
cublasOperation_t transa, transb;
int m, n, k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, transA, transB, M, N, K);
cublasWrapperPtr->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
cublasWrapperPtr->Gemm(transa, transb, m, n, k, weight, lda, act, ldb, output, ldc, heuristic);
cublasWrapperPtr->destroyDescriptors();
}
void CublasLtGemmPluginProfiler::runTactic(
@ -80,11 +82,17 @@ void CublasLtGemmPluginProfiler::runTactic(
bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, const Config& tactic) const
{
cublasOperation_t transa, transb;
int M, N, K;
int M = m, N = n, K = k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, n, m, k);
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K);
return mRunner->checkTactic(transa, transb, m, n, k, lda, ldb, ldc, tactic);
mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
const auto checkResult = mRunner->checkTactic(transa, transb, m, n, k, lda, ldb, ldc, tactic.algo);
mRunner->destroyDescriptors();
return checkResult;
}
void CublasLtGemmPluginProfiler::computeTmpSize(int maxM, int n, int k)
@ -105,6 +113,20 @@ void CublasLtGemmPluginProfiler::computeTmpSize(int maxM, int n, int k)
setTmpWorkspaceSizeInBytes(bytes);
}
std::vector<CublasLtGemmPluginProfiler::Config> CublasLtGemmPluginProfiler::getTactics(int M, int N, int K) const
{
cublasOperation_t transa, transb;
int m, n, k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K);
mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
const auto heruistics = mRunner->getTactics(transa, transb, m, n, k, lda, ldb, ldc);
mRunner->destroyDescriptors();
return heruistics;
}
GemmPlugin::GemmPlugin(
int transA, int transB, nvinfer1::DataType type, bool useFp8, const GemmPlugin::PluginProfilerPtr& pluginProfiler)
: mTransA(transA)
@ -138,14 +160,11 @@ void GemmPlugin::init()
{
auto cublasHandle = getCublasHandle();
auto cublasLtHandle = getCublasLtHandle();
mCublasAlgoMap = std::make_shared<cublasAlgoMap>(GEMM_CONFIG);
mCublasWrapperMutex = std::make_shared<std::mutex>();
mCublasWrapper = std::make_shared<cublasMMWrapper>(
cublasHandle, cublasLtHandle, nullptr, mCublasAlgoMap.get(), mCublasWrapperMutex.get(), nullptr);
mCublasWrapper = std::make_shared<CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
mPluginProfiler->setTranspose(mTransA, mTransB);
mGemmId = GemmIdCublas(GemmIdCore(mDims.n, mDims.k, mType), mTransA, mTransB);
mGemmId = GemmIdCublas(mDims.n, mDims.k, mType, mTransA, mTransB);
}
void GemmPlugin::setGemmConfig()
@ -182,18 +201,7 @@ void GemmPlugin::configGemm()
setGemmConfig();
std::vector<cublasLtMatmulHeuristicResult_t> totalHeruistics;
for (int mCur = mDims.minM; mCur < mDims.maxM; mCur *= 2)
{
cublasOperation_t transa, transb;
int m, n, k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, mCur, mDims.n, mDims.k);
const auto heruistics = mCublasWrapper->getTactics(transa, transb, m, n, k, lda, ldb, ldc);
totalHeruistics.insert(totalHeruistics.end(), heruistics.begin(), heruistics.end());
}
mPluginProfiler->profileTactics(totalHeruistics, mCublasWrapper, mType, mDims, mGemmId);
mPluginProfiler->profileTactics(mCublasWrapper, mType, mDims, mGemmId);
}
// IPluginV2DynamicExt Methods
@ -313,7 +321,8 @@ void GemmPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, in
{
mDims = {minM, maxM, N, K};
}
mGemmId.gemmIdCore = {N, K, mType};
mGemmId.n = N;
mGemmId.k = K;
}
size_t GemmPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
@ -339,9 +348,7 @@ int GemmPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
const auto N = computeNDimension(mTransB, nbDimsB, inputDesc[1].dims.d);
const int K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1];
// FIXME(nkorobov): enable best config selection
// const auto& bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
const std::optional<CublasLtGemmPluginProfiler::Config> bestTactic = {};
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0], bestTactic, workspace,
stream);
return 0;
@ -465,7 +472,8 @@ IPluginV2* GemmPluginCreator::createPlugin(const char* name, const PluginFieldCo
{
// GemmPluginCreator is unique and shared for an engine generation
// Create plugin profiler with shared tactics map
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false);
// FIXME(nkorobov) enable tactic profiler
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false, /* skip */ true);
auto* obj = new GemmPlugin(transA, transB, type, useFp8, pluginProfiler);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
@ -485,7 +493,8 @@ IPluginV2* GemmPluginCreator::deserializePlugin(const char* name, const void* se
{
// GemmPluginCreator is unique and shared for an engine generation
// Create plugin profiler with shared tactics map
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ true);
// FIXME(nkorobov) enable tactic profiler
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ true, /* skip */ true);
auto* obj = new GemmPlugin(serialData, serialLength, pluginProfiler);
obj->setPluginNamespace(mNamespace.c_str());
return obj;

View File

@ -27,51 +27,9 @@
namespace tensorrt_llm::plugins
{
using CublasGemmWrapper = tensorrt_llm::common::cublasMMWrapper;
using CublasGemmWrapper = tensorrt_llm::common::CublasMMWrapper;
using CublasGemmWrapperPtr = std::shared_ptr<CublasGemmWrapper>;
class GemmIdCublas
{
public:
GemmIdCore gemmIdCore{};
bool transA{};
bool transB{};
GemmIdCublas(const GemmIdCore& gemmIdCore_, bool transA_, bool transB_)
: gemmIdCore(gemmIdCore_)
, transA(transA_)
, transB(transB_)
{
}
GemmIdCublas() {}
bool operator==(const GemmIdCublas& id) const
{
return gemmIdCore == id.gemmIdCore && transA == id.transA && transB == id.transB;
}
friend std::ostream& operator<<(std::ostream& out, const GemmIdCublas& id)
{
out << "Core ID = {" << id.gemmIdCore << "}";
out << " transA=" << id.transA;
out << " transB=" << id.transB;
return out;
}
};
// Hash of GemmIdCublas
struct GemmIdCublasHash
{
std::size_t operator()(const GemmIdCublas& id) const
{
auto h1 = GemmIdCoreHash()(id.gemmIdCore);
auto h2 = std::hash<bool>{}(id.transA);
auto h3 = std::hash<bool>{}(id.transB);
return h1 ^ h2 ^ h3;
}
};
class CublasLtGemmPluginProfiler
: public GemmPluginProfiler<cublasLtMatmulHeuristicResult_t, CublasGemmWrapperPtr, GemmIdCublas, GemmIdCublasHash>
{
@ -91,6 +49,8 @@ protected:
bool checkTactic(int m, int n, int k, const Config& tactic) const override;
std::vector<Config> getTactics(int m, int n, int k) const override;
private:
bool mTransA;
bool mTransB;
@ -150,8 +110,8 @@ private:
int mTransB;
nvinfer1::DataType mType;
std::shared_ptr<tensorrt_llm::common::cublasAlgoMap> mCublasAlgoMap;
std::shared_ptr<std::mutex> mCublasWrapperMutex;
// @fixme: seems this is shared across multiple clones.
// If we deep copy the wrapper inside clone(), then we may avoid the mutex inside the wrapper?
CublasGemmWrapperPtr mCublasWrapper;
GemmDims mDims{};

View File

@ -79,7 +79,9 @@ struct FusedQKVMaskedAttentionDispatchParams
int size_per_head;
int rotary_embedding_dim;
float rotary_embedding_base;
RotaryScalingType rotary_embedding_scale_type;
float rotary_embedding_scale;
int rotary_embedding_max_positions;
PositionEmbeddingType position_embedding_type;
int max_seq_len;
const int* input_lengths;
@ -160,7 +162,9 @@ void fusedQKV_masked_attention_dispatch(
params.hidden_size_per_head = input_params.size_per_head;
params.rotary_embedding_dim = input_params.rotary_embedding_dim;
params.rotary_embedding_base = input_params.rotary_embedding_base;
params.rotary_embedding_scale_type = input_params.rotary_embedding_scale_type;
params.rotary_embedding_scale = input_params.rotary_embedding_scale;
params.rotary_embedding_max_positions = input_params.rotary_embedding_max_positions;
params.position_embedding_type = input_params.position_embedding_type;
// Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
params.inv_sqrt_dh = 1.F / (sqrtf((float) params.hidden_size_per_head) * input_params.q_scaling);
@ -215,17 +219,17 @@ template void fusedQKV_masked_attention_dispatch(
template void fusedQKV_masked_attention_dispatch(
const FusedQKVMaskedAttentionDispatchParams<half, KVBlockArray>&, cudaStream_t stream);
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int unidirectional, float q_scaling,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional,
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, int kv_cache_quant_mode,
bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, bool paged_kv_cache,
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled)
int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled)
: mNumHeads(num_heads)
, mNumKVHeads(num_kv_heads)
, mHeadSize(-1)
, mHeadSize(head_size)
, mUnidirectional(unidirectional)
, mQScaling(q_scaling)
, mRotaryEmbeddingDim(rotary_embedding_dim)
@ -242,6 +246,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_hea
, mKVCacheQuantMode(kv_cache_quant_mode)
, mRemovePadding(remove_input_padding)
, mPagedKVCache(paged_kv_cache)
, mTokensPerBlock(tokens_per_block)
, mTpSize(tp_size)
, mTpRank(tp_rank)
, mMaxContextLength(max_context_length)
@ -288,6 +293,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(const void* data, size_t leng
read(d, mRemovePadding);
read(d, mMaskType);
read(d, mPagedKVCache);
read(d, mTokensPerBlock);
read(d, mType);
read(d, mMaxContextLength);
read(d, mQKVBiasEnabled);
@ -313,9 +319,9 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(
const int batch_size = nbReq;
const size_t attention_mask_size = mEnableContextFMHA ? 0 : size * batch_size * max_input_length * max_input_length;
const size_t cu_seqlens_size = sizeof(int) * (batch_size + 1);
const size_t q_buf_2_size = size * batch_size * input_seq_length * local_hidden_units_qo;
const size_t k_buf_2_size = size * batch_size * input_seq_length * local_hidden_units_kv;
const size_t v_buf_2_size = size * batch_size * input_seq_length * local_hidden_units_kv;
const size_t q_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_qo;
const size_t k_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_kv;
const size_t v_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_kv;
const size_t qk_buf_size
= mEnableContextFMHA ? 0 : size * batch_size * mNumHeads * input_seq_length * input_seq_length;
const size_t qkv_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_qo;
@ -403,8 +409,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
if (mPagedKVCache)
{
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
kv_cache_buffer = KVCacheBuffer(params.batch_size, params.max_blocks_per_sequence, params.tokens_per_block,
num_kv_heads * head_size * elem_size);
kv_cache_buffer = KVCacheBuffer(
params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock, num_kv_heads * head_size * elem_size);
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.block_pointers);
}
else
@ -455,9 +461,12 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
const size_t attention_mask_size
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * params.input_seq_length;
const size_t cu_seqlens_size = sizeof(int) * (params.batch_size + 1);
const size_t q_buf_2_size = sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_qo;
const size_t k_buf_2_size = sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_kv;
const size_t v_buf_2_size = sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_kv;
const size_t q_buf_2_size
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_qo;
const size_t k_buf_2_size
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_kv;
const size_t v_buf_2_size
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_kv;
const size_t qk_buf_size = mEnableContextFMHA
? 0
: sizeof(T) * params.batch_size * mNumHeads * params.input_seq_length * params.input_seq_length;
@ -498,30 +507,9 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
invokeBuildDecoderInfo(decoder_params, stream);
sync_check_cuda_error();
// FIXME(qijun): a temporary solution to make sure the padding part of key/value buffer is 0
// NOTE: pointer subtraction is used below since there could be some extra gap due to alignment.
// Otherwise, we could do cudaMemsetAsync(k_buf_2_, 0, k_buf_2_size + v_buf_2_size, stream);
cudaMemsetAsync(
k_buf_2_, 0, reinterpret_cast<int8_t*>(v_buf_2_) - reinterpret_cast<int8_t*>(k_buf_2_) + v_buf_2_size, stream);
float rotary_base, rotary_scale;
const int32_t kv_seq_len = params.input_seq_length;
update_rotary_params(kv_seq_len, rotary_base, rotary_scale);
invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, const_cast<T*>(params.attention_input),
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
request_batch_size, request_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
mEnableContextFMHA, mRotaryEmbeddingDim, rotary_base, rotary_scale, position_embedding_type, (float*) nullptr,
0, stream);
sync_check_cuda_error();
const KvCacheDataType cache_type = mKVCacheQuantMode.hasInt8KvCache()
? KvCacheDataType::INT8
: (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
invokeTranspose4dBatchMajor(k_buf_2_, v_buf_2_, kv_cache_buffer, request_batch_size, request_seq_length,
params.max_seq_length, getHeadSize(), mNumKVHeads, cache_type, params.kv_scale_orig_quant,
params.context_lengths, stream);
sync_check_cuda_error();
const cudaDataType_t gemm_data_type = tc::CudaDataType<T>::value;
const int attention_seq_len_1 = request_seq_length; // q length
@ -530,11 +518,34 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
if (mEnableContextFMHA)
{
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), kv_cache_buffer,
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
request_batch_size, request_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, cache_type,
params.kv_scale_orig_quant, stream);
mFMHARunner->setup(request_batch_size, request_seq_length, params.num_tokens, isALiBi(), mTpSize, mTpRank);
mFMHARunner->run(const_cast<T*>(params.attention_input), cu_seqlens, params.context_buf, stream);
}
else
{
// FIXME(qijun): a temporary solution to make sure the padding part of key/value buffer is 0
// NOTE: pointer subtraction is used below since there could be some extra gap due to alignment.
// Otherwise, we could do cudaMemsetAsync(k_buf_2_, 0, k_buf_2_size + v_buf_2_size, stream);
cudaMemsetAsync(k_buf_2_, 0,
reinterpret_cast<int8_t*>(v_buf_2_) - reinterpret_cast<int8_t*>(k_buf_2_) + v_buf_2_size, stream);
invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, const_cast<T*>(params.attention_input),
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
request_batch_size, request_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
mEnableContextFMHA, mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType,
mRotaryEmbeddingScale, mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, stream);
sync_check_cuda_error();
invokeTranspose4dBatchMajor(k_buf_2_, v_buf_2_, kv_cache_buffer, request_batch_size, request_seq_length,
params.max_seq_length, getHeadSize(), mNumKVHeads, cache_type, params.kv_scale_orig_quant,
params.context_lengths, stream);
sync_check_cuda_error();
const T* linear_bias_slopes = isALiBi() ? params.alibi_slopes : nullptr;
cudaDataType_t gemm_out_data_type = is_qk_buf_float_ ? CUDA_R_32F : gemm_data_type;
void* gemm_out_buf_ = is_qk_buf_float_ ? static_cast<void*>(qk_buf_float_) : static_cast<void*>(qk_buf_);
@ -787,7 +798,7 @@ int GPTAttentionPluginCommon::enqueueGeneration(
{
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
kv_cache_buffer = KVCacheBuffer(
batch_beam, params.max_blocks_per_sequence, params.tokens_per_block, num_kv_heads * head_size * elem_size);
batch_beam, params.max_blocks_per_sequence, mTokensPerBlock, num_kv_heads * head_size * elem_size);
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.block_pointers);
}
else
@ -841,8 +852,10 @@ int GPTAttentionPluginCommon::enqueueGeneration(
dispatch_params.kv_scale_quant_orig = params.kv_scale_quant_orig;
dispatch_params.kv_block_array = kv_cache_buffer;
dispatch_params.multi_processor_count = mMultiProcessorCount;
const int32_t kv_seq_len = step;
update_rotary_params(kv_seq_len, dispatch_params.rotary_embedding_base, dispatch_params.rotary_embedding_scale);
dispatch_params.rotary_embedding_base = mRotaryEmbeddingBase;
dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType;
dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale;
dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions;
fusedQKV_masked_attention_dispatch(dispatch_params, stream);
sync_check_cuda_error();
return 0;
@ -875,10 +888,7 @@ int GPTAttentionPluginCommon::initialize() noexcept
auto cublasHandle = getCublasHandle();
auto cublasLtHandle = getCublasLtHandle();
mCublasAlgoMap = new tc::cublasAlgoMap(GEMM_CONFIG);
mCublasWrapperMutex = new std::mutex();
mCublasWrapper
= new tc::cublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, mCublasAlgoMap, mCublasWrapperMutex, nullptr);
mCublasWrapper.reset(new tc::CublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, nullptr));
if (mEnableContextFMHA)
{
// Pre-checked during constructing.
@ -896,7 +906,7 @@ int GPTAttentionPluginCommon::initialize() noexcept
TLLM_CHECK_WITH_INFO(false, "GPTAttentionPlugin received wrong data type.");
}
mFMHARunner = new FusedMHARunnerV2(data_type, mNumHeads, getHeadSize(false), mQScaling);
mFMHARunner.reset(new FusedMHARunnerV2(data_type, mNumHeads, getHeadSize(false), mQScaling));
// set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads.
mFMHARunner->setup_flags(mFMHAForceFP32Acc, !mRemovePadding, true, mNumKVHeads);
}
@ -906,19 +916,6 @@ int GPTAttentionPluginCommon::initialize() noexcept
void GPTAttentionPluginCommon::destroy() noexcept
{
delete mCublasAlgoMap;
delete mCublasWrapperMutex;
delete mCublasWrapper;
if (mEnableContextFMHA)
{
delete mFMHARunner;
}
mCublasAlgoMap = nullptr;
mCublasWrapperMutex = nullptr;
mCublasWrapper = nullptr;
mFMHARunner = nullptr;
delete this;
}
@ -929,8 +926,8 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
+ sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingMaxPositions)
+ sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc)
+ sizeof(mMultiBlockMode) + sizeof(unsigned int) // mKVCacheQuantMode
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mType) + sizeof(mMaxContextLength)
+ sizeof(mQKVBiasEnabled);
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType)
+ sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled);
}
void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
@ -956,6 +953,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
write(d, mRemovePadding);
write(d, mMaskType);
write(d, mPagedKVCache);
write(d, mTokensPerBlock);
write(d, mType);
write(d, mMaxContextLength);
write(d, mQKVBiasEnabled);
@ -975,6 +973,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("num_heads", nullptr, PluginFieldType::kINT32, -1));
mPluginAttributes.emplace_back(PluginField("num_kv_heads", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("head_size", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("unidirectional", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("q_scaling", nullptr, PluginFieldType::kFLOAT32, 1.0));
mPluginAttributes.emplace_back(PluginField("position_embedding_type", nullptr, PluginFieldType::kINT8, 0));
@ -991,6 +990,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("mask_type", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("paged_kv_cache", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("tokens_per_block", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("max_context_length", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("qkv_bias_enabled", nullptr, PluginFieldType::kINT8, 0));

View File

@ -35,14 +35,14 @@ class GPTAttentionPluginCommon : public BasePlugin
public:
GPTAttentionPluginCommon() = delete;
GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int unidirectional, float q_scaling,
GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, int kv_cache_quant_mode,
bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, bool paged_kv_cache,
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled);
int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled);
GPTAttentionPluginCommon(const void* data, size_t length);
@ -93,7 +93,6 @@ protected:
void* block_pointers;
int32_t batch_size;
int32_t num_tokens;
int32_t tokens_per_block;
int32_t max_blocks_per_sequence;
void* workspace;
};
@ -118,7 +117,6 @@ protected:
void* block_pointers;
int32_t max_seq_lengths; // cache capacity
int32_t num_requests;
int32_t tokens_per_block;
int32_t max_blocks_per_sequence;
int32_t const* cache_indir;
void* workspace;
@ -138,24 +136,6 @@ protected:
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX;
}
inline void update_rotary_params(int32_t kv_seq_len, float& base, float& scale)
{
base = mRotaryEmbeddingBase;
scale = 1.0f / mRotaryEmbeddingScale; // do the division here so that we can avoid it in the kernel
if (mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX
&& mRotaryEmbeddingScaleType == tensorrt_llm::kernels::RotaryScalingType::kDYNAMIC)
{
if (kv_seq_len > mRotaryEmbeddingMaxPositions)
{
const float b
= (mRotaryEmbeddingScale * kv_seq_len / mRotaryEmbeddingMaxPositions) - (mRotaryEmbeddingScale - 1);
const float p = static_cast<float>(mRotaryEmbeddingDim) / (mRotaryEmbeddingDim - 2);
base = mRotaryEmbeddingBase * pow(b, p);
}
scale = 1.0f; // scale factor is already used in updated base
}
}
protected:
const std::string mLayerName;
@ -173,6 +153,7 @@ protected:
bool mRemovePadding = false;
tensorrt_llm::kernels::AttentionMaskType mMaskType;
bool mPagedKVCache = false;
int mTokensPerBlock;
tensorrt_llm::common::QuantMode mKVCacheQuantMode;
int mTpSize = 1;
int mTpRank = 0;
@ -186,13 +167,13 @@ protected:
bool mFMHAForceFP32Acc = false;
int mSM = tensorrt_llm::common::getSMVersion();
int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
tensorrt_llm::kernels::MHARunner* mFMHARunner;
// The default copy constructor will leave it as nullptr. clone() shall initialize it.
UniqPtrWNullCopy<tensorrt_llm::kernels::MHARunner> mFMHARunner;
bool mMultiBlockMode;
int mDeviceId = -1;
tensorrt_llm::common::cublasAlgoMap* mCublasAlgoMap;
std::mutex* mCublasWrapperMutex;
tensorrt_llm::common::cublasMMWrapper* mCublasWrapper;
// The default copy constructor will leave it as nullptr. clone() shall initialize it.
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
};
class GPTAttentionPluginCreatorCommon : public BaseCreator

View File

@ -35,18 +35,18 @@ using tensorrt_llm::plugins::GPTAttentionPlugin;
static const char* GPT_ATTENTION_PLUGIN_VERSION{"1"};
static const char* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int unidirectional, float q_scaling,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional,
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, int kv_cache_quant_mode,
bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, bool paged_kv_cache,
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled)
: GPTAttentionPluginCommon(num_heads, num_kv_heads, unidirectional, q_scaling, position_embedding_type,
int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled)
: GPTAttentionPluginCommon(num_heads, num_kv_heads, head_size, unidirectional, q_scaling, position_embedding_type,
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
rotary_embedding_max_positions, tp_size, tp_rank, context_fmha_type, multi_block_mode, kv_cache_quant_mode,
remove_input_padding, mask_type, paged_kv_cache, type, max_context_length, qkv_bias_enabled)
remove_input_padding, mask_type, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled)
{
}
@ -63,17 +63,17 @@ GPTAttentionPlugin* GPTAttentionPlugin::clone() const noexcept
// outputs
// output_tensor [batch_size, seq_len, local_hidden_size]
// present_key_value_pool [blocks, 2, local_num_kv_heads, tokens_per_block, head_size] if paged_kv_attention
// or [batch_size, 2, local_num_kv_heads, max_seq_len, head_size]
// present_key_value_pool (optional if mPagedKVCache is false) [batch_size, 2, local_num_kv_heads, max_seq_len,
// head_size]
nvinfer1::DimsExprs GPTAttentionPlugin::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
TLLM_CHECK(outputIndex == 0 || outputIndex == 1);
TLLM_CHECK(outputIndex == 0 || (!mPagedKVCache && outputIndex == 1));
if (outputIndex == 0)
{
auto ret = inputs[getInputTensorIdx()];
ret.d[2] = exprBuilder.operation(
DimensionOperation::kPROD, *inputs[getPastKeyValueIdx()].d[4], *exprBuilder.constant(mNumHeads));
DimensionOperation::kPROD, *exprBuilder.constant(mHeadSize), *exprBuilder.constant(mNumHeads));
return ret;
}
return inputs[getPastKeyValueIdx()];
@ -96,13 +96,20 @@ bool GPTAttentionPlugin::supportsFormatCombination(
else if (mPagedKVCache && pos == getKVCacheBlockPointersIdx())
{
// pointers to kv cache blocks
return inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
return inOut[pos].type == nvinfer1::DataType::kINT64 && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (mKVCacheQuantMode.hasInt8KvCache() && (pos == getPastKeyValueIdx() || pos == nbInputs + 1))
else if (mKVCacheQuantMode.hasInt8KvCache()
&& (!mPagedKVCache && (pos == getPastKeyValueIdx() || pos == nbInputs + 1)))
{
// If use Int8 K/V cache we require I/O KV values to int8
return (inOut[pos].type == nvinfer1::DataType::kINT8) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else if (mKVCacheQuantMode.hasFp8KvCache()
&& (!mPagedKVCache && (pos == getPastKeyValueIdx() || pos == nbInputs + 1)))
{
// If use FP8 K/V cache we require I/O KV values to FP8
return (inOut[pos].type == nvinfer1::DataType::kFP8) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else if (mRemovePadding && (pos == getHostContextLengthsIdx()))
{
return inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
@ -117,7 +124,6 @@ bool GPTAttentionPlugin::supportsFormatCombination(
void GPTAttentionPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
{
mHeadSize = in[getPastKeyValueIdx()].desc.dims.d[4];
TLLM_CHECK(mHeadSize > 0);
// pre-check whether FMHA is supported in order to save memory allocation
@ -242,17 +248,13 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
}
int max_blocks_per_sequence = 0;
int tokens_per_block = 0;
void* block_pointers = nullptr;
if (mPagedKVCache)
{
auto& kvCacheBlockPointers = inputDesc[getKVCacheBlockPointersIdx()];
auto& kvCacheBlockPointersShape = inputDesc[getKVCacheBlockPointersIdx()].dims;
// Div by 2 because we reinterpret int32 input as int64
max_blocks_per_sequence = kvCacheBlockPointersShape.d[kvCacheBlockPointersShape.nbDims - 1] / 2;
tokens_per_block = inputDesc[getPastKeyValueIdx()].dims.d[3];
// Div by 2 because we reinterpret int32 input as int64
auto offset = getStride(kvCacheBlockPointersShape, 0) / 2 * seqIdxBeg;
max_blocks_per_sequence = kvCacheBlockPointersShape.d[kvCacheBlockPointersShape.nbDims - 1];
auto offset = getStride(kvCacheBlockPointersShape, 0) * seqIdxBeg;
auto const typed_block_pointers = static_cast<void* const*>(inputs[getKVCacheBlockPointersIdx()]) + offset;
block_pointers = const_cast<void*>(static_cast<void const*>(typed_block_pointers));
}
@ -273,7 +275,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
enqueueContext<T, KVCacheBuffer>(
EnqueueContextParams<T, KVCacheBuffer>{attention_input, qkv_bias, max_context_len, maxSeqLen,
context_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache,
block_pointers, batch_size, localNbTokens, tokens_per_block, max_blocks_per_sequence, workspace},
block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace},
stream);
}
else // generation stage; input_seq_len == 1
@ -290,8 +292,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
enqueueGeneration<T, KVCacheBuffer>(
EnqueueGenerationParams<T, KVCacheBuffer>{attention_input, qkv_bias, sequence_length, past_kv_len,
beamWidth, context_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_,
key_value_cache, block_pointers, maxSeqLen, num_requests, tokens_per_block, max_blocks_per_sequence,
cache_indir, workspace},
key_value_cache, block_pointers, maxSeqLen, num_requests, max_blocks_per_sequence, cache_indir,
workspace},
stream);
}
@ -339,8 +341,15 @@ int GPTAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
nvinfer1::DataType GPTAttentionPlugin::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
{
TLLM_CHECK(index == 0 || index == 1);
return inputTypes[index];
TLLM_CHECK(index == 0 || (!mPagedKVCache && index == 1));
if (index == 0)
{
return inputTypes[getInputTensorIdx()];
}
else
{
return inputTypes[getPastKeyValueIdx()];
}
}
// IPluginV2 Methods
@ -357,7 +366,7 @@ const char* GPTAttentionPlugin::getPluginVersion() const noexcept
int GPTAttentionPlugin::getNbOutputs() const noexcept
{
return 2;
return mPagedKVCache ? 1 : 2;
}
size_t GPTAttentionPlugin::getSerializationSize() const noexcept
@ -403,8 +412,8 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const Plugi
try
{
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("num_heads").value(),
p.getScalar<int32_t>("num_kv_heads").value(), p.getScalar<int32_t>("unidirectional").value(),
p.getScalar<float>("q_scaling").value(),
p.getScalar<int32_t>("num_kv_heads").value(), p.getScalar<int32_t>("head_size").value(),
p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(),
static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),
@ -418,6 +427,7 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const Plugi
static_cast<bool>(p.getScalar<int8_t>("remove_input_padding").value()),
static_cast<AttentionMaskType>(p.getScalar<int32_t>("mask_type").value()),
static_cast<bool>(p.getScalar<int32_t>("paged_kv_cache").value()),
p.getScalar<int32_t>("tokens_per_block").value(),
static_cast<nvinfer1::DataType>(p.getScalar<int32_t>("type_id").value()),
p.getScalar<int32_t>("max_context_length").value(),
static_cast<bool>(p.getScalar<int8_t>("qkv_bias_enabled").value()));

View File

@ -38,40 +38,40 @@ namespace tensorrt_llm::plugins
// Context sequences have to appear first, generation sequences after
// inputs
// input_tensor [batch_size, seq_len, local_hidden_size + 2 * local_num_kv_heads * head_size]
// [1, num_tokens, local_hidden_size + 2 * local_num_kv_heads * head_size] when
// enable_remove_input_padding
// past_key_value_pool [blocks, 2, local_num_kv_heads, tokens_per_block, head_size] if paged_kv_attention
// or [batch_size, 2, local_num_kv_heads, max_seq_len, head_size]
// sequence_length [batch_size]
// host_past_key_value_lengths [batch_size] (int32)
// context_lengths [batch_size]
// cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch)
// host_request_types [batch_size] int32. 0: context; 1: generation; 2: none. When not in inflight-batching mode,
// all elements must be identical.
// kv_cache_quantization_scale [1] (optional)
// kv_cache_dequantization_scale [1] (optional)
// block_pointers [batch_size, 2, max_blocks_per_seq] (optional if paged kv cache)
// alibi_slopes [num_heads] (optional for ALiBi position embedding)
// host_context_lengths [batch_size] int32. (optional, required when remove_input_padding is true)
// qkv_bias (optional) [local_hidden_size * 3]
// 0. input_tensor [batch_size, seq_len, local_hidden_size + 2 * local_num_kv_heads * head_size] or
// [1, num_tokens, local_hidden_size + 2 * local_num_kv_heads * head_size] when
// enable_remove_input_padding
// 1. sequence_length [batch_size]
// 2. host_past_key_value_lengths [batch_size] (int32)
// 3. context_lengths [batch_size]
// 4. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch)
// 5. host_request_types [batch_size] int32. 0: context; 1: generation; 2: none. When not in inflight-batching
// mode,
// all elements must be identical.
// 6. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or
// block_pointers [batch_size, 2, max_blocks_per_seq] if paged kv cache
// 7. kv_cache_quantization_scale [1] (optional)
// 8. kv_cache_dequantization_scale [1] (optional)
// 9. alibi_slopes [num_heads] (optional for ALiBi position embedding)
// 10. host_context_lengths [batch_size] int32. (optional, required when remove_input_padding is true)
// 11. qkv_bias (optional) [local_hidden_size * 3]
//
// outputs
// output_tensor [batch_size, seq_len, local_hidden_size]
// present_key_value_pool [blocks, 2, local_num_kv_heads, tokens_per_block, head_size] if paged_kv_attention
// or [batch_size, 2, local_num_kv_heads, max_seq_len, head_size]
// present_key_value_pool (optional if not paged kv cache) [batch_size, 2, local_num_kv_heads, max_seq_len,
// head_size]
class GPTAttentionPlugin : public GPTAttentionPluginCommon
{
public:
GPTAttentionPlugin(int num_heads, int num_kv_heads, int unidirectional, float q_scaling,
GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, int kv_cache_quant_mode,
bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, bool paged_kv_cache,
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled);
int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled);
GPTAttentionPlugin(const void* data, size_t length);
@ -134,33 +134,40 @@ private:
return 0;
}
IndexType getPastKeyValueIdx() const
IndexType getSequenceLengthIdx() const
{
return 1;
}
IndexType getSequenceLengthIdx() const
IndexType getHostPastKeyValueLengthsIdx() const
{
return 2;
}
IndexType getHostPastKeyValueLengthsIdx() const
IndexType getContextLengthsIdx() const
{
return 3;
}
IndexType getContextLengthsIdx() const
IndexType getCacheIndirIdx() const
{
return 4;
}
IndexType getCacheIndirIdx() const
IndexType getRequestTypesIdx() const
{
return 5;
}
IndexType getRequestTypesIdx() const
IndexType getKVCacheBlockPointersIdx() const
{
// NOTE: this index carries the KV cache block pointers when mPagedKVCache is true; otherwise it carries PastKeyValue
return 6;
}
IndexType getPastKeyValueIdx() const
{
// NOTE: this index carries the past key/value tensor when mPagedKVCache is false; otherwise it carries KVCacheBlockPointers
return 6;
}
@ -174,27 +181,21 @@ private:
return 8;
}
IndexType getKVCacheBlockPointersIdx() const
{
return mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7;
}
IndexType getAlibiSlopesIdx() const
{
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (mPagedKVCache ? 1 : 0);
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7);
}
IndexType getHostContextLengthsIdx() const
{
TLLM_CHECK(mRemovePadding);
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (mPagedKVCache ? 1 : 0) + (isALiBi() ? 1 : 0);
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (isALiBi() ? 1 : 0);
}
IndexType getQKVBiasTensorIdx() const
{
TLLM_CHECK(mQKVBiasEnabled);
return (mKVCacheQuantMode.hasInt8KvCache() ? 9 : 7) + (mPagedKVCache ? 1 : 0) + (isALiBi() ? 1 : 0)
+ (mRemovePadding ? 1 : 0);
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (isALiBi() ? 1 : 0) + (mRemovePadding ? 1 : 0);
}
};
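Putting the index getters above together: the first seven inputs are mandatory, the KV-cache quantization scales (when enabled) occupy indices 7 and 8, and the remaining optional tensors are packed after them in a fixed order. A small worked example for a hypothetical configuration with an INT8 KV cache, no ALiBi, and remove_input_padding enabled (the constant names are illustrative; the values follow directly from the getters):
```
// Indices 0-6: input_tensor, sequence_length, host_past_key_value_lengths, context_lengths,
//              cache_indir, host_request_types, past_key_value_pool / block_pointers.
// Indices 7-8: kv_cache_quantization_scale, kv_cache_dequantization_scale (quant enabled).
constexpr int kHostContextLengthsIdx = 9;  // (quant ? 9 : 7) + (ALiBi ? 1 : 0)        = 9 + 0
constexpr int kQKVBiasIdx            = 10; // previous base + (remove_padding ? 1 : 0) = 9 + 0 + 1
```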

View File

@ -169,7 +169,7 @@ int LayernormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* input
nvinfer1::DataType LayernormQuantizationPlugin::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
{
assert((mDynActScaling && index < 2) || (~mDynActScaling && index == 0));
assert((mDynActScaling && index < 2) || (!mDynActScaling && index == 0));
if (index == 0)
{
// Output 0 quantized output of layer norm

View File

@ -0,0 +1,249 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/plugins/ncclPlugin/FTCustomAR.h"
#include "NvInferRuntimeBase.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/mpiUtils.h"
namespace tensorrt_llm
{
CustomAllReduceComm::CustomAllReduceComm(size_t TPSize, size_t PPSize, int deviceId, size_t bufferSize)
: mTPSize(TPSize)
, mPPSize(PPSize)
, mTPRank(deviceId % TPSize)
, mPPRank(deviceId / TPSize)
, mDeviceId(deviceId)
, mBufferSize(bufferSize)
{
if (mPPSize == 0)
{
group_comm_ = mpi::COMM_WORLD;
}
else
{
mpi::comm_split(mpi::COMM_WORLD, mPPRank, mTPRank, &group_comm_);
}
param_.barrier_flag = 0;
// NOTE: assume All Reduce happens within the node (DGX A100)
param_.ranks_per_node = mTPSize;
param_.rank = mTPRank;
param_.local_rank = mTPRank;
param_.node_id = 0;
allocate();
IpcSyncMemHandle();
}
CustomAllReduceComm::~CustomAllReduceComm()
{
if (is_ipc_handle_opened_)
{
IpcCloseMemHandle();
}
mpi::barrier(); // wait for others to stop using resources before freeing them
if (mTPRank == 0)
{
for (int rank = 0; rank < mTPSize; rank++)
{
size_t device_id = mPPRank * mTPSize + rank;
cudaSetDevice(device_id);
cudaPointerAttributes comm_buffer_attributes, barrier_attributes;
common::check_cuda_error(
cudaPointerGetAttributes(&comm_buffer_attributes, param_.peer_comm_buffer_ptrs[rank]));
common::check_cuda_error(cudaPointerGetAttributes(&barrier_attributes, param_.peer_barrier_ptrs[rank]));
if (comm_buffer_attributes.type == 2)
{
common::check_cuda_error(cudaFree(param_.peer_comm_buffer_ptrs[rank]));
}
if (barrier_attributes.type == 2)
{
common::check_cuda_error(cudaFree(param_.peer_barrier_ptrs[rank]));
}
}
cudaSetDevice(mDeviceId);
setP2P(false);
}
}
void CustomAllReduceComm::IpcGetMemHandle()
{
for (int rank = 0; rank < mTPSize; rank++)
{
common::check_cuda_error(cudaIpcGetMemHandle(
&(param_.ipc_mem_handles.peer_barrier_ipc_handles[rank]), param_.peer_barrier_ptrs[rank]));
common::check_cuda_error(cudaIpcGetMemHandle(
&(param_.ipc_mem_handles.peer_comm_buffer_ipc_handles[rank]), param_.peer_comm_buffer_ptrs[rank]));
}
}
void CustomAllReduceComm::IpcSyncMemHandle()
{
if (mTPRank == 0)
{
IpcGetMemHandle();
}
mpi::bcast(reinterpret_cast<char*>(&(param_.ipc_mem_handles)), sizeof(kernels::AllReduceIpcMemHandles),
mpi::MPI_TYPE_CHAR, 0, group_comm_);
if (mTPRank != 0)
{
IpcOpenMemHandle();
}
common::check_cuda_error(cudaSetDevice(mDeviceId));
}
void CustomAllReduceComm::IpcOpenMemHandle()
{
if (is_ipc_handle_opened_)
{
IpcCloseMemHandle();
is_ipc_handle_opened_ = false;
}
if (!is_ipc_handle_opened_)
{
for (int rank = 0; rank < mTPSize; rank++)
{
common::check_cuda_error(cudaIpcOpenMemHandle((void**) (&(param_.peer_barrier_ptrs[rank])),
param_.ipc_mem_handles.peer_barrier_ipc_handles[rank], cudaIpcMemLazyEnablePeerAccess));
common::check_cuda_error(cudaIpcOpenMemHandle((void**) (&(param_.peer_comm_buffer_ptrs[rank])),
param_.ipc_mem_handles.peer_comm_buffer_ipc_handles[rank], cudaIpcMemLazyEnablePeerAccess));
}
param_.local_output_buffer_ptr = param_.peer_comm_buffer_ptrs[mTPRank];
is_ipc_handle_opened_ = true;
}
}
void CustomAllReduceComm::IpcCloseMemHandle()
{
if (is_ipc_handle_opened_)
{
for (int rank = 0; rank < mTPSize; rank++)
{
common::check_cuda_error(cudaIpcCloseMemHandle(param_.peer_barrier_ptrs[rank]));
common::check_cuda_error(cudaIpcCloseMemHandle(param_.peer_comm_buffer_ptrs[rank]));
}
is_ipc_handle_opened_ = false;
}
}
void CustomAllReduceComm::customAllReduce(
void* data, size_t elts, size_t size_per_elem, nvinfer1::DataType dataType, cudaStream_t stream)
{
param_.local_output_buffer_ptr = data;
param_.elts_total = elts;
param_.barrier_flag = FLAG(param_.barrier_flag + 1);
if (dataType == nvinfer1::DataType::kFLOAT)
{
using T = tensorrt_llm::CustomARCommTypeConverter<float>::Type;
kernels::invokeOneOrTwoShotAllReduceKernel<T>(param_, stream);
}
else if (dataType == nvinfer1::DataType::kHALF)
{
using T = tensorrt_llm::CustomARCommTypeConverter<half>::Type;
kernels::invokeOneOrTwoShotAllReduceKernel<T>(param_, stream);
}
else if (dataType == nvinfer1::DataType::kBF16)
{
using T = tensorrt_llm::CustomARCommTypeConverter<__nv_bfloat16>::Type;
kernels::invokeOneOrTwoShotAllReduceKernel<T>(param_, stream);
}
else
{
TLLM_CHECK_WITH_INFO(false, "Unsupported dataType for customAllReduce");
}
}
void CustomAllReduceComm::allocate()
{
if (mTPRank != 0)
return;
setP2P();
for (size_t i = 0; i < mTPSize; i++)
{
size_t device_id = mPPRank * mTPSize + i;
common::check_cuda_error(cudaSetDevice(device_id));
common::check_cuda_error(cudaMalloc(&(param_.peer_comm_buffer_ptrs[i]), mBufferSize));
common::check_cuda_error(
cudaMalloc(&(param_.peer_barrier_ptrs[i]), mTPSize * (MAX_ALL_REDUCE_BLOCKS + 1) * sizeof(uint32_t)));
common::check_cuda_error(
cudaMemset(param_.peer_barrier_ptrs[i], 0, mTPSize * (MAX_ALL_REDUCE_BLOCKS + 1) * sizeof(uint32_t)));
}
cudaSetDevice(mDeviceId);
}
bool CustomAllReduceComm::isAvailable()
{
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
#else
return false;
#endif
if (!mpi::isInitialized())
{
return false;
}
auto worldSize = mpi::getCommWorldSize();
auto rank = mpi::getCommWorldRank();
if ((worldSize % 2 != 0) || (worldSize > MAX_RANKS_PER_NODE) || (worldSize == 0))
{
return false;
}
return true;
}
void CustomAllReduceComm::setP2P(bool activate)
{
int peer_access_available = 0;
size_t device_offset = mPPRank * mTPSize;
for (int i = 0; i < mTPSize; i++)
{
cudaSetDevice(device_offset + i);
for (int j = 0; j < mTPSize; j++)
{
if (i == j)
{
continue;
}
cudaDeviceCanAccessPeer(&peer_access_available, device_offset + i, device_offset + j);
assert(peer_access_available);
if (activate)
{
cudaDeviceEnablePeerAccess(device_offset + j, 0);
cudaError_t result = cudaGetLastError();
if (result == cudaErrorPeerAccessAlreadyEnabled)
{
result = cudaSuccess;
}
common::check_cuda_error(result);
}
else
{
cudaDeviceDisablePeerAccess(device_offset + j);
}
}
}
cudaSetDevice(mDeviceId);
}
void* CustomAllReduceComm::getShareBuffer()
{
return reinterpret_cast<void*>(param_.peer_comm_buffer_ptrs[mTPRank]);
}
} // namespace tensorrt_llm

View File

@ -0,0 +1,87 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <memory>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "NvInferRuntimeBase.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
namespace tensorrt_llm
{
class CustomAllReduceComm
{
public:
CustomAllReduceComm(size_t TPSize, size_t PPSize, int deviceId, size_t bufferSize);
~CustomAllReduceComm();
void customAllReduce(
void* data, size_t elts, size_t size_per_elem, nvinfer1::DataType dataType, cudaStream_t stream);
void* getShareBuffer();
static bool isAvailable();
private:
void setP2P(bool activate = true);
void IpcGetMemHandle();
void IpcOpenMemHandle();
void IpcCloseMemHandle();
void IpcSyncMemHandle();
void allocate();
kernels::AllReduceParams param_;
bool is_ipc_handle_opened_ = false;
size_t mTPSize;
size_t mTPRank;
size_t mPPSize;
size_t mPPRank;
size_t mBufferSize;
int mDeviceId;
mpi::MpiComm group_comm_;
};
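A hedged usage sketch of the class above: each rank constructs one communicator for its tensor-parallel group (rank 0 allocates the buffers and publishes the IPC handles from the constructor, the others open them), stages its contribution into the shared buffer, and launches the reduction. Function and variable names are illustrative; stream and data setup are omitted:
```
#include <cuda_fp16.h>
#include "tensorrt_llm/plugins/ncclPlugin/FTCustomAR.h"

// `comm` would be created once per rank, e.g.
//   tensorrt_llm::CustomAllReduceComm comm(/*TPSize=*/4, /*PPSize=*/0, /*deviceId=*/rank, /*bufferSize=*/64 << 20);
void exampleAllReduce(tensorrt_llm::CustomAllReduceComm& comm, const void* input, void* output,
    size_t numElems, cudaStream_t stream)
{
    // Stage the local contribution into the shared IPC buffer, then reduce in place into `output`.
    cudaMemcpyAsync(comm.getShareBuffer(), input, numElems * sizeof(half), cudaMemcpyDeviceToDevice, stream);
    comm.customAllReduce(output, numElems, sizeof(half), nvinfer1::DataType::kHALF, stream);
}
```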
template <typename T>
struct CustomARCommTypeConverter
{
using Type = uint32_t;
};
template <>
struct CustomARCommTypeConverter<half>
{
using Type = uint16_t;
};
#ifdef ENABLE_BF16
template <>
struct CustomARCommTypeConverter<__nv_bfloat16>
{
using Type = __nv_bfloat16;
};
#endif
} // namespace tensorrt_llm

View File

@ -161,7 +161,7 @@ int AllgatherPlugin::initialize() noexcept
MPI_Status status;
MPICHECK(MPI_Recv(&id, sizeof(id), MPI_BYTE, *mGroup.begin(), 0, MPI_COMM_WORLD, &status));
}
(*commMap)[mGroup] == nullptr;
(*commMap)[mGroup] = nullptr;
NCCLCHECK(ncclCommInitRank(&((*commMap)[mGroup]), mGroup.size(), id, groupRank));
return 0;
}

View File

@ -16,6 +16,7 @@
*/
#pragma once
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include <cassert>
#include <mpi.h>

View File

@ -15,6 +15,8 @@
* limitations under the License.
*/
#include "allreducePlugin.h"
#include "mpi.h"
#include "plugin.h"
using namespace nvinfer1;
using tensorrt_llm::plugins::AllreducePluginCreator;
@ -25,9 +27,11 @@ static const char* ALLREDUCE_PLUGIN_NAME{"AllReduce"};
PluginFieldCollection AllreducePluginCreator::mFC{};
std::vector<nvinfer1::PluginField> AllreducePluginCreator::mPluginAttributes;
AllreducePlugin::AllreducePlugin(std::set<int> group, nvinfer1::DataType type)
AllreducePlugin::AllreducePlugin(std::set<int> group, nvinfer1::DataType type, AllReduceStrategyType strategy)
: mGroup(group)
, mType(type)
, mStrategy(strategy)
, mCustomARBufferSize(0)
{
}
@ -36,6 +40,8 @@ AllreducePlugin::AllreducePlugin(const void* data, size_t length)
{
const char *d = reinterpret_cast<const char*>(data), *a = d;
read(d, mType);
read(d, mStrategy);
read(d, mCustomARBufferSize);
mGroup.clear();
int groupItem = 0;
while (d != a + length)
@ -69,6 +75,32 @@ bool AllreducePlugin::supportsFormatCombination(
void AllreducePlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
{
if (mStrategy == AllReduceStrategyType::NCCL)
{
return;
}
size_t sizePerElem = 0;
switch (mType)
{
case DataType::kFLOAT: sizePerElem = sizeof(float); break;
case DataType::kHALF: sizePerElem = sizeof(half); break;
#ifdef ENABLE_BF16
case DataType::kBF16: sizePerElem = sizeof(__nv_bfloat16); break;
#endif
default: break;
}
size_t inputSize = 1;
for (int i = 0; i < in[0].max.nbDims; i++)
{
inputSize *= in[0].max.d[i];
}
if (isBuilding())
{
mCustomARBufferSize = std::max(mCustomARBufferSize, inputSize * sizePerElem);
}
}
size_t AllreducePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
@ -77,6 +109,13 @@ size_t AllreducePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* input
return 0;
}
size_t AllreducePlugin::ncclEstimatedThreshold(int worldSize) const noexcept
{
// returns the message size above which NCCL is expected to outperform the custom all-reduce kernel
// 0.60 * TP_SIZE * 10MB
return 0.60 * (10 * 1000 * 1000 * worldSize);
}
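The heuristic above estimates the message size at which NCCL's algorithms start to beat the one- or two-shot custom kernel; under the AUTO strategy it is compared against the payload of each all-reduce. A quick worked example, with numbers that simply follow the formula rather than measured data:
```
#include <cstdio>

// Illustrative only: reproduces the 0.60 * worldSize * 10 MB heuristic.
size_t ncclEstimatedThreshold(int worldSize)
{
    return static_cast<size_t>(0.60 * (10 * 1000 * 1000) * worldSize);
}

int main()
{
    // For an 8-rank group the crossover is 48 MB: under AUTO, smaller all-reduces
    // take the CUSTOM path and larger ones fall back to NCCL.
    std::printf("threshold = %zu bytes\n", ncclEstimatedThreshold(8)); // 48000000
}
```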
int AllreducePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
@ -89,9 +128,36 @@ int AllreducePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const
{
size *= inputDesc[0].dims.d[i];
}
size_t sizePerElem = 0;
switch (mType)
{
case DataType::kFLOAT: sizePerElem = sizeof(float); break;
case DataType::kHALF: sizePerElem = sizeof(half); break;
#ifdef ENABLE_BF16
case DataType::kBF16: sizePerElem = sizeof(__nv_bfloat16); break;
#endif
default: break;
}
NCCLCHECK(ncclAllReduce(
inputs[0], outputs[0], size, (*getDtypeMap())[inputDesc[0].type], ncclSum, (*getCommMap())[mGroup], stream));
auto runtimeStrategy = mStrategy;
if (runtimeStrategy == AllReduceStrategyType::AUTO)
{
runtimeStrategy = size * sizePerElem > ncclEstimatedThreshold(mGroup.size()) ? AllReduceStrategyType::NCCL
: AllReduceStrategyType::CUSTOM;
}
if (runtimeStrategy == AllReduceStrategyType::NCCL)
{
NCCLCHECK(ncclAllReduce(inputs[0], outputs[0], size, (*getDtypeMap())[inputDesc[0].type], ncclSum,
(*getCommMap())[mGroup], stream));
}
else if (runtimeStrategy == AllReduceStrategyType::CUSTOM)
{
auto shareBuffer = mCustomAllReduceContext->getShareBuffer();
cudaMemcpyAsync(shareBuffer, inputs[0], size * sizePerElem, cudaMemcpyDeviceToDevice, stream);
mCustomAllReduceContext->customAllReduce(outputs[0], size, sizePerElem, mType, stream);
}
return 0;
}
@ -121,68 +187,105 @@ int AllreducePlugin::getNbOutputs() const noexcept
return 1;
}
bool AllreducePlugin::isCustomAllReduceSuported(int ranks_per_node) const noexcept
{
constexpr bool isCudaVersionSupported =
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
true;
#else
false;
#endif
return isCudaVersionSupported && (ranks_per_node % 2 == 0) && (ranks_per_node <= MAX_RANKS_PER_NODE)
&& (ranks_per_node > 0);
}
int AllreducePlugin::initialize() noexcept
{
auto* commMap = getCommMap();
// [] operator inserts T() if it does not exist
if (isBuilding() || (*commMap)[mGroup] != nullptr)
if (isBuilding())
{
return 0;
}
int myRank, nRanks;
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
int groupRank = 0;
for (auto it = mGroup.begin(); it != mGroup.end(); ++it)
int deviceId;
cudaGetDevice(&deviceId);
TLLM_CHECK_WITH_INFO(myRank == deviceId, "MPI rank != cudaDeviceId, check if cudaSetDevice has been called");
if (mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::AUTO)
{
if (*it == myRank)
auto* commMap = getCommMap();
// [] operator inserts T() if it does not exist
if ((*commMap)[mGroup] == nullptr)
{
break;
int groupRank = 0;
for (auto it = mGroup.begin(); it != mGroup.end(); ++it)
{
if (*it == myRank)
{
break;
}
++groupRank;
}
ncclUniqueId id;
if (myRank == *mGroup.begin())
{
ncclGetUniqueId(&id);
for (auto it = std::next(std::begin(mGroup), 1); it != mGroup.end(); ++it)
{
MPICHECK(MPI_Send(&id, sizeof(id), MPI_BYTE, *it, 0, MPI_COMM_WORLD));
}
}
else
{
MPI_Status status;
MPICHECK(MPI_Recv(&id, sizeof(id), MPI_BYTE, *mGroup.begin(), 0, MPI_COMM_WORLD, &status));
}
(*commMap)[mGroup] = nullptr;
NCCLCHECK(ncclCommInitRank(&((*commMap)[mGroup]), mGroup.size(), id, groupRank));
}
++groupRank;
}
ncclUniqueId id;
if (myRank == *mGroup.begin())
if (mStrategy == AllReduceStrategyType::CUSTOM || mStrategy == AllReduceStrategyType::AUTO)
{
ncclGetUniqueId(&id);
for (auto it = std::next(std::begin(mGroup), 1); it != mGroup.end(); ++it)
{
MPICHECK(MPI_Send(&id, sizeof(id), MPI_BYTE, *it, 0, MPI_COMM_WORLD));
}
TLLM_CHECK_WITH_INFO(tensorrt_llm::CustomAllReduceComm::isAvailable(), "Custom all reduce isn't available.");
auto allocSize = mCustomARBufferSize > 0 ? mCustomARBufferSize : CUSTOM_AR_SIZE_THRESHOLD;
mCustomAllReduceContext = std::make_shared<tensorrt_llm::CustomAllReduceComm>(nRanks, 0, myRank, allocSize);
}
else
{
MPI_Status status;
MPICHECK(MPI_Recv(&id, sizeof(id), MPI_BYTE, *mGroup.begin(), 0, MPI_COMM_WORLD, &status));
}
(*commMap)[mGroup] == nullptr;
NCCLCHECK(ncclCommInitRank(&((*commMap)[mGroup]), mGroup.size(), id, groupRank));
return 0;
}
void AllreducePlugin::terminate() noexcept
{
auto* commMap = getCommMap();
// [] operator inserts T() if it does not exist
if (isBuilding() || (*commMap)[mGroup] == nullptr)
if (mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::AUTO)
{
return;
auto* commMap = getCommMap();
// [] operator inserts T() if it does not exist
if (isBuilding() || (*commMap)[mGroup] == nullptr)
{
return;
}
NCCLCHECK(ncclCommDestroy((*commMap)[mGroup]));
(*commMap)[mGroup] = nullptr;
}
NCCLCHECK(ncclCommDestroy((*commMap)[mGroup]));
(*commMap)[mGroup] = nullptr;
}
size_t AllreducePlugin::getSerializationSize() const noexcept
{
return sizeof(int) * mGroup.size() + sizeof(mType);
return sizeof(int) * mGroup.size() + sizeof(mType) + sizeof(mStrategy) + sizeof(mCustomARBufferSize);
}
void AllreducePlugin::serialize(void* buffer) const noexcept
{
char *d = static_cast<char*>(buffer), *a = d;
write(d, mType);
write(d, mStrategy);
write(d, mCustomARBufferSize);
for (auto it = mGroup.begin(); it != mGroup.end(); ++it)
{
write(d, *it);
@ -204,6 +307,7 @@ AllreducePluginCreator::AllreducePluginCreator()
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("group", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("strategy", nullptr, PluginFieldType::kINT8, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
@ -228,6 +332,7 @@ IPluginV2* AllreducePluginCreator::createPlugin(const char* name, const PluginFi
const PluginField* fields = fc->fields;
std::set<int> group;
nvinfer1::DataType type;
AllreducePlugin::AllReduceStrategyType strategy;
// Read configurations from each fields
for (int i = 0; i < fc->nbFields; ++i)
{
@ -247,11 +352,16 @@ IPluginV2* AllreducePluginCreator::createPlugin(const char* name, const PluginFi
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
type = static_cast<nvinfer1::DataType>(*(static_cast<const nvinfer1::DataType*>(fields[i].data)));
}
else if (!strcmp(attrName, "strategy"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
strategy = static_cast<AllreducePlugin::AllReduceStrategyType>(*static_cast<const int8_t*>(fields[i].data));
}
}
try
{
auto* obj = new AllreducePlugin(group, type);
auto* obj = new AllreducePlugin(group, type, strategy);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}

View File

@ -16,8 +16,11 @@
*/
#pragma once
#include "FTCustomAR.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include <cassert>
#include <memory>
#include <mpi.h>
#include <nccl.h>
#include <set>
@ -30,7 +33,15 @@ namespace tensorrt_llm::plugins
class AllreducePlugin : public BasePlugin
{
public:
AllreducePlugin(std::set<int> group, nvinfer1::DataType type);
enum class AllReduceStrategyType : int8_t
{
NCCL = 0,
CUSTOM = 1,
AUTO = 2,
};
AllreducePlugin(
std::set<int> group, nvinfer1::DataType type, AllReduceStrategyType strategy = AllReduceStrategyType::NCCL);
AllreducePlugin(const void* data, size_t length);
@ -63,10 +74,16 @@ public:
void serialize(void* buffer) const noexcept override;
void destroy() noexcept override;
bool isCustomAllReduceSuported(int ranks_per_node) const noexcept;
private:
size_t ncclEstimatedThreshold(int worldSize) const noexcept;
const std::string mLayerName;
std::set<int> mGroup;
nvinfer1::DataType mType;
AllReduceStrategyType mStrategy;
std::shared_ptr<tensorrt_llm::CustomAllReduceComm> mCustomAllReduceContext;
size_t mCustomARBufferSize;
};
class AllreducePluginCreator : public BaseCreator

View File

@ -16,6 +16,7 @@
*/
#pragma once
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include <cassert>
#include <mpi.h>

View File

@ -85,7 +85,6 @@ int SendPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
}
NCCLCHECK(ncclSend(inputs[0], size, (*getDtypeMap())[inputDesc[0].type], 1, mComm, stream));
return 0;
}

View File

@ -16,6 +16,7 @@
*/
#pragma once
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include <cassert>
#include <mpi.h>

View File

@ -164,7 +164,7 @@ int RmsnormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDe
nvinfer1::DataType RmsnormQuantizationPlugin::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
{
assert((mDynActScaling && index < 2) || (~mDynActScaling && index == 0));
assert((mDynActScaling && index < 2) || (!mDynActScaling && index == 0));
if (index == 0)
{
// Output 0 quantized output of layer norm

View File

@ -64,6 +64,11 @@ void SmoothQuantGemmPluginProfiler::computeTmpSize(int maxM, int n, int k)
setTmpWorkspaceSizeInBytes(bytes);
}
std::vector<SmoothQuantGemmPluginProfiler::Config> SmoothQuantGemmPluginProfiler::getTactics(int m, int n, int k) const
{
return mRunner->getConfigs();
}
SmoothQuantGemmPlugin::SmoothQuantGemmPlugin(
QuantMode quantMode, nvinfer1::DataType type, const SmoothQuantGemmPlugin::PluginProfilerPtr& pluginProfiler)
: mQuantMode(quantMode)
@ -302,7 +307,7 @@ void SmoothQuantGemmPlugin::destroy() noexcept
void SmoothQuantGemmPlugin::configGemm()
{
mPluginProfiler->profileTactics(m_sqGemmRunner->getConfigs(), m_sqGemmRunner, mType, mDims, mGemmId);
mPluginProfiler->profileTactics(m_sqGemmRunner, mType, mDims, mGemmId);
}
///////////////

View File

@ -48,6 +48,8 @@ protected:
void computeTmpSize(int maxM, int n, int k) override;
std::vector<Config> getTactics(int m, int n, int k) const override;
private:
tensorrt_llm::common::QuantMode mQuantMode;
};
@ -103,7 +105,7 @@ private:
SqGemmRunnerPtr m_sqGemmRunner;
tensorrt_llm::common::QuantMode mQuantMode;
int m_workspaceMaxSize;
size_t m_workspaceMaxSize;
GemmDims mDims{};
GemmIdCore mGemmId{};

View File

@ -86,6 +86,12 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(int maxM, int n,
setTmpWorkspaceSizeInBytes(bytes);
}
std::vector<WeightOnlyGroupwiseQuantGemmPluginProfiler::Config> WeightOnlyGroupwiseQuantGemmPluginProfiler::getTactics(
int m, int n, int k) const
{
return mRunner->getConfigs();
}
WeightOnlyGroupwiseQuantMatmulPlugin::WeightOnlyGroupwiseQuantMatmulPlugin(nvinfer1::DataType type, int quant_algo,
int group_size, const WeightOnlyGroupwiseQuantMatmulPlugin::PluginProfilerPtr& pluginProfiler)
: mPluginProfiler(pluginProfiler)
@ -164,8 +170,7 @@ nvinfer1::IPluginV2DynamicExt* WeightOnlyGroupwiseQuantMatmulPlugin::clone() con
void WeightOnlyGroupwiseQuantMatmulPlugin::configGemm()
{
mPluginProfiler->profileTactics(
m_weightOnlyGroupwiseGemmRunner->getConfigs(), m_weightOnlyGroupwiseGemmRunner, mType, mDims, mGemmId);
mPluginProfiler->profileTactics(m_weightOnlyGroupwiseGemmRunner, mType, mDims, mGemmId);
}
nvinfer1::DimsExprs WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDimensions(
@ -250,7 +255,8 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::configurePlugin(const nvinfer1::Dynam
}
mGemmId = {N, K, mType};
int smoothedActSize = maxM * maxK * (in[0].desc.type == nvinfer1::DataType::kFLOAT ? 4 : 2);
size_t smoothedActSize = static_cast<size_t>(maxM) * static_cast<size_t>(maxK)
* (in[0].desc.type == nvinfer1::DataType::kFLOAT ? 4 : 2);
m_workspaceMaxSize = smoothedActSize + m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(maxM, maxN, maxK);
}

View File

@ -63,6 +63,8 @@ protected:
void computeTmpSize(int maxM, int n, int k) override;
std::vector<Config> getTactics(int m, int n, int k) const override;
private:
int mQuantAlgo;
int mGroupSize;
@ -119,7 +121,7 @@ private:
const std::string mLayerName;
WeightOnlyGemmRunnerPtr m_weightOnlyGroupwiseGemmRunner;
int m_workspaceMaxSize;
size_t m_workspaceMaxSize;
nvinfer1::DataType mType;
int mSM = tensorrt_llm::common::getSMVersion();

View File

@ -71,6 +71,12 @@ void WeightOnlyQuantGemmPluginProfiler::computeTmpSize(int maxM, int n, int k)
setTmpWorkspaceSizeInBytes(bytes);
}
std::vector<WeightOnlyQuantGemmPluginProfiler::Config> WeightOnlyQuantGemmPluginProfiler::getTactics(
int m, int n, int k) const
{
return mRunner->getConfigs();
}
WeightOnlyQuantMatmulPlugin::WeightOnlyQuantMatmulPlugin(
nvinfer1::DataType type, int weightTypeId, const WeightOnlyQuantMatmulPlugin::PluginProfilerPtr& pluginProfiler)
: mPluginProfiler(pluginProfiler)
@ -130,8 +136,7 @@ nvinfer1::IPluginV2DynamicExt* WeightOnlyQuantMatmulPlugin::clone() const noexce
void WeightOnlyQuantMatmulPlugin::configGemm()
{
mPluginProfiler->profileTactics(
m_weightOnlyGemmRunner->getConfigs(), m_weightOnlyGemmRunner, mType, mDims, mGemmId);
mPluginProfiler->profileTactics(m_weightOnlyGemmRunner, mType, mDims, mGemmId);
}
nvinfer1::DimsExprs WeightOnlyQuantMatmulPlugin::getOutputDimensions(

View File

@ -55,6 +55,8 @@ protected:
void computeTmpSize(int maxM, int n, int k) override;
std::vector<Config> getTactics(int m, int n, int k) const override;
private:
int mWeightTypeId;
};
@ -111,7 +113,7 @@ private:
const std::string mLayerName;
WeightOnlyGemmRunnerPtr m_weightOnlyGemmRunner;
int m_workspaceMaxSize;
size_t m_workspaceMaxSize;
nvinfer1::DataType mType;
int mWeightTypeId;
int mSM = tensorrt_llm::common::getSMVersion();

View File

@ -35,19 +35,20 @@ set(SRCS
include_directories(${API_INCLUDE_DIR}/tensorrt_llm/runtime)
add_compile_options(-Wall)
if(NOT MSVC)
# additional warnings
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall")
else() # Windows
# warning level 4
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4")
endif()
add_library(runtime_src OBJECT ${SRCS})
set_property(TARGET runtime_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET runtime_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_include_directories(runtime_src PRIVATE ${MPI_INCLUDE_PATH})
set(JSON_SRC_DIR ${PROJECT_SOURCE_DIR}/../3rdparty/json)
add_subdirectory(${JSON_SRC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/json)
if(ENABLE_MULTI_DEVICE EQUAL 1)
target_link_libraries(runtime_src PUBLIC nlohmann_json::nlohmann_json
${NCCL_LIB})
else()
target_link_libraries(runtime_src PUBLIC nlohmann_json::nlohmann_json)
target_link_libraries(runtime_src PUBLIC ${NCCL_LIB})
endif()

View File

@ -82,11 +82,11 @@ void BufferManager::setZero(IBuffer& buffer) const
}
}
void BufferManager::copy(void const* src, IBuffer& dst) const
void BufferManager::copy(void const* src, IBuffer& dst, MemoryType srcType) const
{
if (dst.getSizeInBytes() > 0)
{
if (IBuffer::memoryType(src) != MemoryType::kGPU && dst.getMemoryType() != MemoryType::kGPU)
if (srcType != MemoryType::kGPU && dst.getMemoryType() != MemoryType::kGPU)
{
std::memcpy(dst.data(), src, dst.getSizeInBytes());
}
@ -97,11 +97,11 @@ void BufferManager::copy(void const* src, IBuffer& dst) const
}
}
void BufferManager::copy(IBuffer const& src, void* dst) const
void BufferManager::copy(IBuffer const& src, void* dst, MemoryType dstType) const
{
if (src.getSizeInBytes() > 0)
{
if (IBuffer::memoryType(dst) != MemoryType::kGPU && src.getMemoryType() != MemoryType::kGPU)
if (src.getMemoryType() != MemoryType::kGPU && dstType != MemoryType::kGPU)
{
std::memcpy(dst, src.data(), src.getSizeInBytes());
}
@ -117,7 +117,7 @@ void BufferManager::copy(IBuffer const& src, IBuffer& dst) const
TLLM_CHECK_WITH_INFO(src.getDataType() == dst.getDataType(), "Incompatible data types");
TLLM_CHECK_WITH_INFO(src.getSizeInBytes() == dst.getSizeInBytes(),
tc::fmtstr("Incompatible buffer sizes: %lu != %lu", src.getSizeInBytes(), dst.getSizeInBytes()));
copy(src, dst.data());
copy(src, dst.data(), dst.getMemoryType());
}
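With the overloads above, the caller now states the memory space of a raw pointer explicitly instead of the manager inferring it, which matters when a bare void* cannot be classified reliably. A hedged usage sketch (buffer names are illustrative, and MemoryType::kCPU is assumed to be the enumerator for ordinary host memory):
```
// Copy host data into a device buffer; the caller declares that the raw pointer is CPU memory.
std::vector<float> hostData(deviceBuffer.getSize());
manager.copy(hostData.data(), deviceBuffer, MemoryType::kCPU);

// Copy a device buffer back into a caller-provided host pointer.
manager.copy(deviceBuffer, hostData.data(), MemoryType::kCPU);
```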
BufferManager::IBufferPtr BufferManager::allocate(

View File

@ -35,13 +35,12 @@ GptDecoder<T>::GptDecoder(size_t vocabSize, size_t vocabSizePadded, CudaStreamPt
: mManager{stream}
, mAllocator{mManager}
{
tc::cublasMMWrapper* cublasWrapper = nullptr;
bool isFreeBufferAfterForward{false};
cudaDeviceProp prop;
tc::check_cuda_error(cudaGetDeviceProperties(&prop, 0));
mDynamicDecodeLayer = std::make_shared<tensorrt_llm::layers::DynamicDecodeLayer<T>>(
vocabSize, vocabSizePadded, stream->get(), cublasWrapper, &mAllocator, isFreeBufferAfterForward, &prop);
vocabSize, vocabSizePadded, stream->get(), &mAllocator, isFreeBufferAfterForward, &prop);
}
template <typename T>

View File

@ -17,6 +17,7 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include <algorithm>
@ -73,8 +74,6 @@ GptDecoderBatch::GptDecoderBatch(
, mVocabSizePadded{vocabSizePadded}
, mStream{std::move(stream)}
, mBufferManager{mStream}
, mEventStart(tc::CreateEvent())
, mEventStop(tc::CreateEvent())
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
@ -98,6 +97,7 @@ GptDecoderBatch::GptDecoderBatch(
dOutput->finished = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
// use batchSize entries instead of the usual single entry
dOutput->finishedSum = mBufferManager.emptyTensor(MemoryType::kPINNED, nvSizeType);
mFinishedSum = mBufferManager.pinned(ITensor::makeShape({1}), nvSizeType);
dOutput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
dOutput->cumLogProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
dOutput->beamHypotheses.empty(mBufferManager);
@ -155,7 +155,6 @@ void GptDecoderBatch::setup(
}
mStreams.resize(maxBatchSize);
mEvents.resize(maxBatchSize);
mDecoders.resize(maxBatchSize);
mDecodingInputs.resize(maxBatchSize);
mDecodingOutputs.resize(maxBatchSize);
@ -168,7 +167,6 @@ void GptDecoderBatch::setup(
{
mStreams[i] = std::make_shared<CudaStream>();
TLLM_CHECK(mStreams[i]->getDevice() == device);
mEvents[i] = tc::CreateEvent();
mDecoders[i] = IGptDecoder::create(dtype, mVocabSize, mVocabSizePadded, mStreams[i]);
mDecodingInputs[i].reset();
mDecodingOutputs[i].reset();
@ -201,7 +199,6 @@ void GptDecoderBatch::newRequest(
maxNewTokens, mMaxSequenceLength));
TLLM_CHECK(requestIds->getDataType() == TRTDataType<TokenIdType>::value);
auto const endId = request.endId.value_or(mVocabSize - 1);
auto const padId = request.padId.value_or(mVocabSize - 1);
auto constexpr localBatchSize = 1;
@ -273,7 +270,8 @@ void GptDecoderBatch::newRequest(
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatch::forward(decoder_batch::Output& output, decoder_batch::Input const& input)
GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
decoder_batch::Output& output, decoder_batch::Input const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& logits = input.logits;
@ -298,14 +296,15 @@ void GptDecoderBatch::forward(decoder_batch::Output& output, decoder_batch::Inpu
TLLM_CHECK(sequenceLengths);
auto constexpr singleRequest = 1;
mStream->record(mEventStart.get());
CudaEvent eventStart{};
mStream->record(eventStart);
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
{
if (mFinished[i] || !input.active.at(i))
continue;
auto& stream = mStreams[i];
stream->wait(mEventStart.get());
stream->wait(eventStart.get());
auto& dInput = *mDecodingInputs[i];
auto& dOutput = *mDecodingOutputs[i];
auto logitsView = std::shared_ptr(ITensor::slice(logits, i, singleRequest));
@ -349,21 +348,34 @@ void GptDecoderBatch::forward(decoder_batch::Output& output, decoder_batch::Inpu
manager.copy(*dOutput.parentIds, *jointOutputParentIdsView);
}
auto& event = mEvents[i];
stream->record(event.get());
mStream->wait(event.get());
CudaEvent event{};
stream->record(event);
mStream->wait(event);
dInput.step += 1;
mNbSteps[i] += 1;
mFinished[i] = mNbSteps[i] >= mMaxNewTokens[i];
}
mStream->record(mEventStop.get());
TLLM_CUDA_CHECK(::cudaEventSynchronize(mEventStop.get()));
CudaEvent eventStop{};
mStream->record(eventStop);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return std::make_unique<decoder_batch::Token>(std::move(eventStop), input.active);
}
void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
token.event.synchronize();
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
{
auto& dOutput = *mDecodingOutputs[i];
mFinished[i] = mNbSteps[i] >= mMaxNewTokens[i]
// This condition requires the synchronization above
|| *bufferCast<SizeType>(*dOutput.finishedSum) == static_cast<SizeType>(dOutput.finished->getSize());
if (token.active[i] && !mFinished[i])
{
auto& dOutput = *mDecodingOutputs[i];
mFinished[i] = mFinished[i]
// This condition requires the synchronization above
|| *bufferCast<SizeType>(*dOutput.finishedSum) == static_cast<SizeType>(dOutput.finished->getSize());
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
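The two functions above split decoding into an asynchronous launch and a later synchronization: forwardAsync records the per-request work on the side streams and returns a token carrying the completion event plus the set of active requests, and forwardSync blocks on that event before refreshing the per-request finished flags. A minimal usage sketch (surrounding setup omitted, names illustrative):
```
// decoder is a GptDecoderBatch; batchOutput/batchInput are prepared decoder_batch structs.
auto token = decoder.forwardAsync(batchOutput, batchInput);

// ... host-side work for the next step can overlap with decoding here ...

// Waits on the recorded event, then updates the finished flags from finishedSum.
decoder.forwardSync(*token);
```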
@ -375,7 +387,6 @@ void GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
auto& stream = mStreams[batchIdx];
auto manager = BufferManager{stream};
stream->wait(mEventStart.get());
auto& dInput = *mDecodingInputs[batchIdx];
auto& dOutput = *mDecodingOutputs[batchIdx];
@ -385,9 +396,9 @@ void GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
IGptDecoder::gatherTree(*finalOutputIds, dOutput, dInput, manager);
manager.copy(*finalOutputIds, *outputIds);
auto& event = mEvents[batchIdx];
stream->record(event.get());
mStream->wait(event.get());
CudaEvent event{};
stream->record(event);
mStream->wait(event);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
@ -434,7 +445,7 @@ void GptDecoderBatch::newBatch(GenerationInput const& inputs, SamplingConfig con
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
bool GptDecoderBatch::forward(decoder::Output& output, decoder::Input const& input)
void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
decoder_batch::Input batchInput{input.logits};
@ -444,12 +455,24 @@ bool GptDecoderBatch::forward(decoder::Output& output, decoder::Input const& inp
batchOutput.cacheIndirection = output.cacheIndirection;
batchOutput.sequenceLengths = output.sequenceLengths;
forward(batchOutput, batchInput);
auto finished = getFinished();
mForwardToken = forwardAsync(batchOutput, batchInput);
mBufferManager.setZero(*mFinishedSum);
kernels::reduce(*mFinishedSum, *ITensor::slice(mJointDecodingOutput->finishedSum, 0, mActualBatchSize), *mStream);
mStream->record(mForwardEvent);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return std::all_of(finished.begin(), finished.end(), [](bool x) { return x; });
}
bool GptDecoderBatch::isFinishedSync()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
forwardSync(*mForwardToken);
auto const finished
= std::all_of(mFinished.begin(), mFinished.begin() + mActualBatchSize, [](bool x) { return x; });
// wait for mFinishedSum to be updated
mStream->wait(mForwardEvent);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return finished;
}
IStatefulGptDecoder::TensorPtr GptDecoderBatch::getFinalOutputIds() const
@ -460,5 +483,13 @@ IStatefulGptDecoder::TensorPtr GptDecoderBatch::getFinalOutputIds() const
postProcessRequest(batchIdx);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return ITensor::slice(getOutputIds(), 0, mActualBatchSize);
return getOutputIds();
}
IStatefulGptDecoder::TensorPtr GptDecoderBatch::getFinalOutputIds(SizeType batchIdx) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
postProcessRequest(batchIdx);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return getOutputIds(batchIdx);
}

View File

@ -72,8 +72,6 @@ GptJsonConfig parseJson(InputType&& i)
else
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));
auto const pagedKvCache = parseJsonFieldOr(builderConfig, "paged_kv_cache", false);
auto const tokensPerBlock = parseJsonFieldOr(builderConfig, "tokens_per_block", 0);
auto const quantMode = tc::QuantMode(parseJsonFieldOr(builderConfig, "quant_mode", tc::QuantMode::none().value()));
auto const numKvHeads
= parseJsonFieldOr(builderConfig, "num_kv_heads", numHeads * tensorParallelism) / tensorParallelism;
@ -81,7 +79,11 @@ GptJsonConfig parseJson(InputType&& i)
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
auto const maxOutputLen = parseJsonFieldOr(builderConfig, "max_output_len", 0);
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
auto const& pluginConfig = json.at("plugin_config");
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
auto const& gptAttentionPlugin = pluginConfig.at("gpt_attention_plugin");
auto const useGptAttentionPlugin = !gptAttentionPlugin.is_boolean() || gptAttentionPlugin.template get<bool>();
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
@ -93,6 +95,7 @@ GptJsonConfig parseJson(InputType&& i)
modelConfig.setTokensPerBlock(tokensPerBlock);
modelConfig.setQuantMode(quantMode);
modelConfig.setNbKvHeads(numKvHeads);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.setMaxBatchSize(maxBatchSize);
modelConfig.setMaxInputLen(maxInputLen);

View File

@ -21,7 +21,6 @@
#include "iBuffer.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/runtime/gptDecoderBatch.h"
#include "tensorrt_llm/runtime/ncclCommunicator.h"
@ -32,6 +31,7 @@
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <algorithm>
#include <cstdint>
#include <fstream>
@ -47,16 +47,15 @@ GptSession::GptSession(GptModelConfig const& modelConfig, WorldConfig const& wor
, mDevice{utils::initDevice(worldConfig)}
, mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
, mRuntime{std::make_shared<TllmRuntime>(engineBuffer, engineSize, *mLogger)}
, mDecoder{}
, mBuffers{std::make_shared<RuntimeBuffers>()}
, mNumMicroBatches{worldConfig.getPipelineParallelism()}
, mDecoders{}
, mBuffers{}
, mCudaGraphInstances{}
{
createContexts();
mBuffers->create(*mRuntime, mModelConfig, mWorldConfig);
if (mWorldConfig.isPipelineParallel())
{
mPipelineComm = NcclCommunicator::createPipelineComm(mWorldConfig, *mLogger);
mCommStream = std::make_shared<CudaStream>();
}
// TODO compare expected and runtime tensor names?
@ -72,47 +71,114 @@ BufferManager& GptSession::getBufferManager() const
return mRuntime->getBufferManager();
}
void GptSession::createContexts()
void GptSession::createContexts(SizeType numMicroBatches)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
mRuntime->clearContexts();
auto numProfiles = mRuntime->getNbProfiles();
// Instantiate multiple contexts for flip-flopping
auto const numContextsPerPhase = std::max(2, numMicroBatches);
auto const numProfiles = mRuntime->getNbProfiles();
TLLM_CHECK_WITH_INFO(
numProfiles == 1 || numProfiles == 2, "GPT only expects one optimization profile or two optimization profiles");
// Instantiate two contexts for flip-flopping
if (numProfiles == 1)
auto constexpr ctxContextId = 0;
auto constexpr genContextId = 1;
if (numProfiles == 2)
{
mRuntime->addContext(0);
mRuntime->addContext(0);
for (auto i = 0; i < numContextsPerPhase; ++i)
mRuntime->addContext(genContextId);
}
else
for (auto i = 0; i < numContextsPerPhase; ++i)
mRuntime->addContext(ctxContextId);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::createBuffers(SizeType numMicroBatches)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
mBuffers.clear();
for (SizeType i = 0; i < numMicroBatches; ++i)
{
mRuntime->addContext(1);
mRuntime->addContext(1);
mRuntime->addContext(0);
mBuffers.emplace_back(std::make_shared<RuntimeBuffers>());
mBuffers.back()->create(*mRuntime, mModelConfig, mWorldConfig);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::createDecoder(bool decoderPerRequest)
void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto const vocabSize = mModelConfig.getVocabSize();
auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());
auto const& stream = mRuntime->getStreamPtr();
if (decoderPerRequest)
mDecoder = std::make_shared<GptDecoderBatch>(vocabSize, vocabSizePadded, stream);
else
mDecoder = std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream);
mDecoders.clear();
for (SizeType i = 0; i < numMicroBatches; ++i)
{
if (decoderPerRequest)
mDecoders.emplace_back(std::make_shared<GptDecoderBatch>(vocabSize, vocabSizePadded, stream));
else
mDecoders.emplace_back(std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream));
mDecoders.back()->setup(batchSize, beamWidth, maxSequenceLength, logitsType);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::setup(SizeType const batchSize, SizeType const beamWidth, SizeType const maxSequenceLength,
bool decoderPerRequest, std::optional<SizeType> maxTokensInPagedKvCache)
void GptSession::createKvCacheManagers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
SizeType numMicroBatches, std::optional<SizeType> maxTokensInPagedKvCache)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism());
auto const nbHeads = mModelConfig.getNbHeads();
auto const nbKvHeads = mModelConfig.getNbKvHeads();
auto const hiddenSize = mModelConfig.getHiddenSize();
auto const tokensPerBlock = mModelConfig.getTokensPerBlock();
auto const maxBlocksPerSeq = tc::divUp(maxSequenceLength, tokensPerBlock);
auto const maxNumTokens
= maxTokensInPagedKvCache.value_or(batchSize * beamWidth * maxBlocksPerSeq * tokensPerBlock);
auto const maxNumBlocks = tc::divUp(maxNumTokens, tokensPerBlock);
nvinfer1::DataType kvDtype;
if (mModelConfig.getQuantMode().hasFp8KvCache())
{
kvDtype = nvinfer1::DataType::kFP8;
}
else if (mModelConfig.getQuantMode().hasInt8KvCache())
{
kvDtype = nvinfer1::DataType::kINT8;
}
else
{
kvDtype = mModelConfig.getDataType();
}
mKvCacheManagers.clear();
for (SizeType i = 0; i < numMicroBatches; ++i)
{
mKvCacheManagers.emplace_back(
std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbHeads, nbKvHeads, hiddenSize, tokensPerBlock,
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, kvDtype, mRuntime->getStreamPtr()));
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
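The paged KV cache above is sized in whole blocks: the per-sequence block count rounds the maximum sequence length up to a multiple of tokensPerBlock, and the default token budget covers every beam of every sequence in the micro batch. A small numeric illustration, with values chosen for the example rather than taken from a real engine:
```
#include <cstdio>

// Ceil-division helper, equivalent to tc::divUp used above.
constexpr int divUp(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int maxSequenceLength = 1024, tokensPerBlock = 64;
    const int batchSize = 8, beamWidth = 2;

    const int maxBlocksPerSeq  = divUp(maxSequenceLength, tokensPerBlock);                 // 16
    const int defaultMaxTokens = batchSize * beamWidth * maxBlocksPerSeq * tokensPerBlock; // 16384
    const int maxNumBlocks     = divUp(defaultMaxTokens, tokensPerBlock);                  // 256

    std::printf("%d blocks per sequence, %d blocks in total\n", maxBlocksPerSeq, maxNumBlocks);
}
```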
void GptSession::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, bool decoderPerRequest,
std::optional<SizeType> maxTokensInPagedKvCache, std::optional<SizeType> numMicroBatches)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
if (numMicroBatches)
mNumMicroBatches = numMicroBatches.value();
createContexts(mNumMicroBatches);
createBuffers(mNumMicroBatches);
auto const microBatchSize = tc::ceilDiv(maxBatchSize, mNumMicroBatches);
// Store this param related to decoder buffer size and kv cache manager to check against
// the input shape with the params given in generate().
// gptDecoderBatch does not resize buffers, but allows smaller batchSize and beamWidth.
@ -121,55 +187,51 @@ void GptSession::setup(SizeType const batchSize, SizeType const beamWidth, SizeT
if (mModelConfig.usePagedKvCache())
{
auto const numLayers = mModelConfig.getNbLayers();
auto const nbHeads = mModelConfig.getNbHeads();
auto const nbKvHeads = mModelConfig.getNbKvHeads();
auto const hiddenSize = mModelConfig.getHiddenSize();
auto const tokensPerBlock = mModelConfig.getTokensPerBlock();
auto const maxBlocksPerSeq = tc::divUp(maxSequenceLength, tokensPerBlock);
auto const maxNumTokens
= maxTokensInPagedKvCache.value_or(batchSize * beamWidth * maxBlocksPerSeq * tokensPerBlock);
auto const maxNumBlocks = tc::divUp(maxNumTokens, tokensPerBlock);
auto const kvDtype = mBuffers->presentKeysVals.at(0)->getDataType();
// init KV cache block manager
mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(numLayers, nbHeads, nbKvHeads, hiddenSize,
tokensPerBlock, maxNumBlocks, batchSize, kvDtype, mRuntime->getStreamPtr());
createKvCacheManagers(
microBatchSize, maxBeamWidth, maxSequenceLength, mNumMicroBatches, maxTokensInPagedKvCache);
}
if (mWorldConfig.isLastPipelineParallelRank())
{
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
createDecoder(decoderPerRequest);
mDecoder->setup(batchSize, beamWidth, maxSequenceLength, logitsType);
createDecoders(
microBatchSize, maxBeamWidth, maxSequenceLength, logitsType, decoderPerRequest, mNumMicroBatches);
}
// reshape does not care about maxInputLength or maxNewTokens
auto const generationConfig = RuntimeBuffers::GenerationConfig{batchSize, beamWidth, 0, 0, maxSequenceLength};
mBuffers->reshape(generationConfig, mModelConfig, mWorldConfig);
if (mWorldConfig.isPipelineParallel())
{
mReceivedEvents.clear();
for (SizeType i = 0; i < mNumMicroBatches; ++i)
mReceivedEvents.emplace_back();
}
// we don't know maxInputLength and maxNewTokens yet and ignore those for pre-allocation
auto const generationConfig
= RuntimeBuffers::GenerationConfig{microBatchSize, maxBeamWidth, 0, 0, maxSequenceLength};
for (auto& buffers : mBuffers)
buffers->reshape(generationConfig, mModelConfig, mWorldConfig);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::generate(
void GptSession::generateSingleBatch(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(inputs.packed == mModelConfig.usePackedInput(),
"The chosen model requires a packed input tensor (did you set packed?).");
TLLM_CHECK_WITH_INFO(inputs.lengths->getShape().nbDims == 1, "Input lengths tensor must be one-dimensional.");
auto const& inputLengths = inputs.lengths;
TLLM_CHECK_WITH_INFO(inputLengths->getShape().nbDims == 1, "Input lengths tensor must be one-dimensional.");
auto constexpr microBatchId = 0;
auto& manager = mRuntime->getBufferManager();
auto& stream = mRuntime->getStream();
auto& buffers = *mBuffers;
buffers.contextLengthsDevice = inputs.lengths;
buffers.contextLengthsHost->reshape(inputs.lengths->getShape());
manager.copy(*buffers.contextLengthsDevice, *buffers.contextLengthsHost);
manager.getStream().synchronize();
auto const generationConfig = RuntimeBuffers::GenerationConfig::fromInput(inputs.ids, buffers.contextLengthsHost,
// Initialize and reshape buffers
auto& buffers = *mBuffers.at(microBatchId);
TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
buffers.initContextLengths(inputLengths, manager);
auto const generationConfig = RuntimeBuffers::GenerationConfig::fromInput(*inputs.ids, *buffers.contextLengthsHost,
inputs.packed, samplingConfig.beamWidth, mDecoderMaxSequenceLength, inputs.maxNewTokens, manager);
auto const batchSize = generationConfig.batchSize;
@@ -177,49 +239,42 @@ void GptSession::generate(
auto const maxInputLength = generationConfig.maxInputLength;
auto const maxNewTokens = generationConfig.maxNewTokens;
TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.computeContextLogits())
{
auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());
TLLM_CHECK_WITH_INFO(outputs.contextLogits,
"outputs.contextLogits is nullptr. It must be allocated when computeContextLogits() is enabled.");
outputs.contextLogits->reshape(ITensor::makeShape({batchSize, maxInputLength, vocabSizePadded}));
auto const contextLogitsShape = outputs.contextLogits->getShape();
TLLM_CHECK_WITH_INFO(contextLogitsShape.d[0] == batchSize, "Invalid dim[0]");
TLLM_CHECK_WITH_INFO(contextLogitsShape.d[1] == maxInputLength, "Invalid dim[1]");
TLLM_CHECK_WITH_INFO(contextLogitsShape.d[2] == vocabSizePadded, "Invalid dim[2]");
buffers.logits = outputs.contextLogits;
}
buffers.reshape(generationConfig, mModelConfig, mWorldConfig);
kvCacheAddSequences(beamWidth, microBatchId);
ITensor::SharedPtr newTokens{initNewTokens(inputs, samplingConfig, microBatchId)};
if (mModelConfig.usePagedKvCache())
{
auto const contextLengthsHost = bufferCast<SizeType const>(*buffers.contextLengthsHost);
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
mKvCacheManager->addSequence(batchIdx, contextLengthsHost[batchIdx], beamWidth);
}
}
RuntimeBuffers::TensorMap inputBuffers[2];
RuntimeBuffers::TensorMap outputBuffers[2];
auto& onTokenGenerated = outputs.onTokenGenerated;
outputs.ids->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));
ITensor::SharedPtr newTokens;
if (mWorldConfig.isLastPipelineParallelRank())
{
mDecoder->newBatch(inputs, samplingConfig);
newTokens = mDecoder->getNewTokens();
}
else if (mWorldConfig.isFirstPipelineParallelRank())
{
newTokens = manager.gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
}
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
RuntimeBuffers::TensorMap inputBuffers[2];
RuntimeBuffers::TensorMap outputBuffers[2];
for (SizeType step = 0; step < maxNewTokens; ++step)
{
auto const contextId = step % 2;
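// Even and odd steps ping-pong between two buffer sets and TRT execution contexts.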
bool enqueueSuccessful = false;
if (step == 0)
{
SizeType contextIdForContextPhase = 0;
if (mRuntime->getNbProfiles() == 2)
{
contextIdForContextPhase = 2;
}
SizeType const contextIdForContextPhase
= mRuntime->getNbProfiles() == 2 ? mRuntime->getNbContexts() / 2 : 0;
buffers.prepareContextStep(
inputs.ids, inputs.padId, manager, *mKvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids,
*mKvCacheManager, mModelConfig, mWorldConfig);
inputs.ids, inputs.padId, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextIdForContextPhase, inputBuffers[contextId]);
mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffers[contextId]);
@@ -230,22 +285,22 @@ void GptSession::generate(
instance.clear();
}
}
enqueueSuccessful = mRuntime->executeContext(contextIdForContextPhase);
TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine in context phase failed!");
}
else
{
if (isCudaGraphMode() && mCudaGraphInstances[contextId].hasInstance())
{
mCudaGraphInstances[contextId].launch(stream);
enqueueSuccessful = true;
mCudaGraphInstances[contextId].launch(mRuntime->getStream());
}
else
{
enqueueSuccessful = mRuntime->executeContext(contextId);
TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextId), "Executing TRT engine in generation phase failed!");
}
}
TLLM_CHECK_WITH_INFO(enqueueSuccessful, "Executing TRT engine failed!");
sync_check_cuda_error();
if (step == 0)
@@ -255,14 +310,14 @@ void GptSession::generate(
std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);
if (step < maxNewTokens - 1)
{
if (step < maxNewTokens - 1) // this is not the last step
{ // preparing the next step
auto const nextStep = step + 1;
auto const nextContextId = nextStep % 2;
auto nextInputIds = buffers.prepareNextStep(
step, newTokens, manager, *mKvCacheManager, generationConfig, mModelConfig, mWorldConfig);
step, newTokens, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(inputBuffers[nextContextId], outputBuffers[nextContextId], nextStep, nextInputIds,
*mKvCacheManager, mModelConfig, mWorldConfig);
mModelConfig, mWorldConfig);
mRuntime->setInputTensors(nextContextId, inputBuffers[nextContextId]);
mRuntime->setOutputTensors(nextContextId, outputBuffers[nextContextId]);
@@ -277,7 +332,8 @@ void GptSession::generate(
// FIXME(nkorobov): this synchronize is important to get logits right
// manager.getStream().synchronize();
auto shouldStop = executeDecoderStep(outputs.ids, newTokens, maxInputLength + step);
decoderStepAsync(outputs.ids, newTokens, maxInputLength + step, microBatchId);
auto const shouldStop = shouldStopSync(batchSize, beamWidth, microBatchId);
if (mWorldConfig.isFirstPipelineParallelRank())
{
@@ -285,7 +341,7 @@ void GptSession::generate(
{
// TODO(rkobus) use getNewTokens(), remove step from Callback?
ITensor::SharedPtr outputIds
= mWorldConfig.isPipelineParallel() ? outputs.ids : mDecoder->getOutputIds();
= mWorldConfig.isPipelineParallel() ? outputs.ids : mDecoders.at(microBatchId)->getOutputIds();
onTokenGenerated(outputIds, step, shouldStop || step == maxNewTokens - 1);
}
}
@@ -301,97 +357,407 @@ void GptSession::generate(
{
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
mKvCacheManager->removeSequence(batchIdx);
kvCacheManager->removeSequence(batchIdx);
}
}
finalizeOutputIds(*outputs.ids);
finalizeOutputIds(*outputs.ids, microBatchId);
manager.getStream().synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
bool GptSession::executeDecoderStep(ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep)
void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId)
{
if (mModelConfig.usePagedKvCache())
{
auto& kvCacheManager = mKvCacheManagers.at(microBatchId);
TLLM_CHECK(kvCacheManager);
auto contextLengthsHost = mBuffers.at(microBatchId)->contextLengthsHost;
TLLM_CHECK(contextLengthsHost);
auto const contextLengthsPtr = bufferCast<SizeType const>(*contextLengthsHost);
auto const contextLengthsSize = static_cast<SizeType>(contextLengthsHost->getSize());
for (SizeType batchIdx = 0; batchIdx < contextLengthsSize; ++batchIdx)
{
kvCacheManager->addSequence(batchIdx, contextLengthsPtr[batchIdx], beamWidth);
}
}
}
ITensor::SharedPtr GptSession::initNewTokens(
GenerationInput const& inputs, SamplingConfig const& samplingConfig, SizeType microBatchId)
{
if (mWorldConfig.isLastPipelineParallelRank())
{
auto& decoder = mDecoders.at(microBatchId);
decoder->newBatch(inputs, samplingConfig);
return decoder->getNewTokens();
}
else if (mWorldConfig.isFirstPipelineParallelRank())
{
auto const beamWidth = samplingConfig.beamWidth;
auto const batchSize = static_cast<SizeType>(inputs.lengths->getSize());
return mRuntime->getBufferManager().gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
}
else
{
return ITensor::SharedPtr{};
}
}
namespace
{
std::vector<GenerationInput> splitInputs(
GenerationInput const& inputs, SizeType numMicroBatches, BufferManager& manager)
{
std::vector<GenerationInput> inputBatches;
auto const numRequests = inputs.lengths->getShape().d[0];
auto const microBatchSize = tc::ceilDiv(numRequests, numMicroBatches);
if (inputs.packed)
{
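// Packed input: the token ids of all requests are concatenated, so each micro batch
// is a slice of the flat token stream whose length is the sum of its context lengths.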
auto contextLengthsHost = manager.copyFrom(*inputs.lengths, MemoryType::kCPU);
ITensor::SharedPtr inputIdsView = ITensor::view(inputs.ids);
inputIdsView->squeeze(0);
auto contextLengthsRange = BufferRange<SizeType>(*contextLengthsHost);
auto tokensBegin = 0;
for (auto offset = 0; offset < numRequests; offset += microBatchSize)
{
auto batchSize = std::min(microBatchSize, numRequests - offset);
auto numTokens = std::accumulate(
contextLengthsRange.begin() + offset, contextLengthsRange.begin() + offset + batchSize, 0);
ITensor::SharedPtr batchInputs = ITensor::slice(inputIdsView, tokensBegin, numTokens);
batchInputs->reshape(ITensor::makeShape({1, numTokens}));
inputBatches.emplace_back(inputs.endId, inputs.padId, batchInputs,
ITensor::slice(inputs.lengths, offset, batchSize), inputs.packed);
tokensBegin += numTokens;
}
}
else
{
for (auto offset = 0; offset < numRequests; offset += microBatchSize)
{
auto batchSize = std::min(microBatchSize, numRequests - offset);
inputBatches.emplace_back(inputs.endId, inputs.padId, ITensor::slice(inputs.ids, offset, batchSize),
ITensor::slice(inputs.lengths, offset, batchSize), inputs.packed);
}
}
for (auto& batch : inputBatches)
{
if (inputs.embeddingBiasOpt)
batch.embeddingBiasOpt = inputs.embeddingBiasOpt;
if (inputs.badWordsList)
batch.badWordsList = inputs.badWordsList;
if (inputs.stopWordsList)
batch.stopWordsList = inputs.stopWordsList;
if (inputs.maxNewTokens)
batch.maxNewTokens = inputs.maxNewTokens;
}
return inputBatches;
}
void updateOutputIds(
ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, CudaStream const& stream)
{ // assemble outputIds of all micro batches
auto const& newTokensShape = newTokens->getShape();
auto newTokensView = ITensor::view(newTokens, ITensor::makeShape({1, newTokensShape.d[0] * newTokensShape.d[1]}));
auto const& outputIdsShape = outputIds->getShape();
auto outputIdsView = ITensor::view(
outputIds, ITensor::makeShape({outputIdsShape.d[0] * outputIdsShape.d[1], outputIdsShape.d[2]}));
kernels::invokeTransposeWithOutputOffset(*outputIdsView, *newTokensView, decoderStep, stream);
}
} // namespace
void GptSession::generateMultiBatch(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(inputs.packed == mModelConfig.usePackedInput(),
"The chosen model requires a packed input tensor (did you set packed?).");
auto const& inputLengths = inputs.lengths;
TLLM_CHECK_WITH_INFO(inputLengths->getShape().nbDims == 1, "Input lengths tensor must be one-dimensional.");
auto& manager = mRuntime->getBufferManager();
auto const batchSize = static_cast<SizeType>(inputLengths->getSize());
auto const beamWidth = samplingConfig.beamWidth;
outputs.ids->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));
auto& onTokenGenerated = outputs.onTokenGenerated;
auto const numMicroBatches = std::min(batchSize, mNumMicroBatches);
auto microBatches = splitInputs(inputs, numMicroBatches, manager);
std::vector<RuntimeBuffers::GenerationConfig> generationConfigs;
std::vector<ITensor::SharedPtr> newTokensPerBatch;
std::vector<ITensor::SharedPtr> outputIdsPerBatch;
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto const& microBatchInputs = microBatches.at(microBatchId);
// Initialize and reshape buffers
auto& buffers = *mBuffers.at(microBatchId);
TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
buffers.initContextLengths(microBatchInputs.lengths, manager);
generationConfigs.emplace_back(RuntimeBuffers::GenerationConfig::fromInput(*microBatchInputs.ids,
*buffers.contextLengthsHost, microBatchInputs.packed, samplingConfig.beamWidth, mDecoderMaxSequenceLength,
microBatchInputs.maxNewTokens, manager));
auto const& generationConfig = generationConfigs.back();
auto const beamWidth = generationConfig.beamWidth;
buffers.reshape(generationConfig, mModelConfig, mWorldConfig);
kvCacheAddSequences(beamWidth, microBatchId);
newTokensPerBatch.emplace_back(initNewTokens(microBatchInputs, samplingConfig, microBatchId));
}
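// Slice the output ids per micro batch and use the smallest maxNewTokens across
// micro batches as the common number of generation steps.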
auto maxNewTokens = generationConfigs.front().maxNewTokens;
auto microBatchSize = generationConfigs.front().batchSize;
auto offset = 0;
outputIdsPerBatch.emplace_back(ITensor::slice(outputs.ids, offset, microBatchSize));
offset += microBatchSize;
for (auto microBatchId = 1; microBatchId < numMicroBatches; ++microBatchId)
{
maxNewTokens = std::min(maxNewTokens, generationConfigs.at(microBatchId).maxNewTokens);
auto microBatchSize = generationConfigs.at(microBatchId).batchSize;
outputIdsPerBatch.emplace_back(ITensor::slice(outputs.ids, offset, microBatchSize));
offset += microBatchSize;
}
// TODO(micro batching) do we need 1 or 2 per micro batch?
std::vector<RuntimeBuffers::TensorMap> inputBuffers(numMicroBatches * 2);
std::vector<RuntimeBuffers::TensorMap> outputBuffers(numMicroBatches * 2);
std::vector<bool> microBatchesFinished(numMicroBatches, false);
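// Every generation step iterates over all micro batches; a micro batch is skipped
// once its decoding has finished, and the loop exits when all of them are done.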
for (SizeType step = 0; step < maxNewTokens; ++step)
{
if (std::all_of(microBatchesFinished.begin(), microBatchesFinished.end(), [](bool x) { return x; }))
break;
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto& buffers = *mBuffers.at(microBatchId);
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
auto& newTokens = newTokensPerBatch.at(microBatchId);
auto& generationConfig = generationConfigs.at(microBatchId);
auto& outputIds = outputIdsPerBatch.at(microBatchId);
if (microBatchesFinished.at(microBatchId))
continue;
if (step > 0)
{
auto const microBatchSize = generationConfig.batchSize;
auto const beamWidth = generationConfig.beamWidth;
auto const shouldStop = shouldStopSync(microBatchSize, beamWidth, microBatchId);
if (mWorldConfig.isFirstPipelineParallelRank() && onTokenGenerated
&& microBatchId == numMicroBatches - 1)
{
onTokenGenerated(outputs.ids, step - 1, shouldStop);
}
if (shouldStop)
{
mLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, "GPT decoding finished early");
microBatchesFinished.at(microBatchId) = true;
continue;
}
}
auto const contextId = microBatchId;
if (step == 0)
{
SizeType const contextIdForContextPhase
= contextId + (mRuntime->getNbProfiles() == 2 ? mNumMicroBatches : 0);
auto const& inputs = microBatches.at(microBatchId);
buffers.prepareContextStep(
inputs.ids, inputs.padId, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextIdForContextPhase, inputBuffers[contextId]);
mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffers[contextId]);
TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine failed!");
buffers.postContextStep(manager, generationConfig, mModelConfig, mWorldConfig);
}
else
{
auto nextInputIds = buffers.prepareNextStep(
step - 1, newTokens, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffers[contextId], outputBuffers[contextId], step, nextInputIds, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextId, inputBuffers[contextId]);
mRuntime->setOutputTensors(contextId, outputBuffers[contextId]);
TLLM_CHECK_WITH_INFO(mRuntime->executeContext(contextId), "Executing TRT engine failed!");
}
sync_check_cuda_error();
std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);
auto const maxInputLength = generationConfigs.at(microBatchId).maxInputLength;
auto const decoderStep = maxInputLength + step;
decoderStepAsync(outputIds, newTokens, decoderStep, microBatchId);
if (!mWorldConfig.isPipelineParallel() && mNumMicroBatches > 1)
{
updateOutputIds(outputIds, newTokens, decoderStep, mRuntime->getStream());
}
}
}
// TODO(micro batching) move into loop above?
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto const& generationConfig = generationConfigs.at(microBatchId);
auto const microBatchSize = generationConfig.batchSize;
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
auto& outputIds = outputIdsPerBatch.at(microBatchId);
// TODO(micro batching) sync receive event
if (mWorldConfig.isFirstPipelineParallelRank() && onTokenGenerated && microBatchId == numMicroBatches - 1)
{
onTokenGenerated(outputs.ids, maxNewTokens - 1, true);
}
if (mModelConfig.usePagedKvCache())
{
for (auto batchIdx = 0; batchIdx < microBatchSize; ++batchIdx)
{
kvCacheManager->removeSequence(batchIdx);
}
}
// TODO(micro batching) use mCommStream?
finalizeOutputIds(*outputIds, microBatchId);
}
manager.getStream().synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::decoderStepAsync(
ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, SizeType microBatchId)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& stream = mRuntime->getStream();
auto& buffers = *mBuffers;
auto& buffers = *mBuffers.at(microBatchId);
auto shouldStopPtr = bufferCast<std::uint8_t>(*buffers.shouldStop);
auto& shouldStop = *shouldStopPtr;
shouldStop = false;
if (mWorldConfig.isLastPipelineParallelRank())
{
auto& decoder = *mDecoders.at(microBatchId);
decoder::Input decodingInput{buffers.logits};
decoder::Output decodingOutput{};
decodingInput.cacheIndirection = buffers.cacheIndirectionDecoderInput;
decodingOutput.cacheIndirection = buffers.cacheIndirectionDecoderOutput;
decodingOutput.sequenceLengths = buffers.sequenceLengths;
shouldStop = mDecoder->forward(decodingOutput, decodingInput);
decoder.forwardAsync(decodingOutput, decodingInput);
if (mWorldConfig.isPipelineParallel())
{ // send shouldStop to all previous ranks and newTokens to the first rank
stream.record(mCommEvent.get());
mCommStream->wait(mCommEvent.get());
auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();
auto& cacheIndirection = *buffers.cacheIndirectionDecoderOutput;
auto& sequenceLengths = *buffers.sequenceLengths;
auto const beamWidth = cacheIndirection.getShape().d[1];
for (auto peerIdx = 0; peerIdx < mWorldConfig.getPipelineParallelism() - 1; ++peerIdx)
{
mPipelineComm->send<SizeType>(*decoder.getNbFinished(), pipelineGroup[peerIdx], *mCommStream, *mLogger);
if (beamWidth > 1)
{
mPipelineComm->send<SizeType>(cacheIndirection, pipelineGroup[peerIdx], *mCommStream, *mLogger);
}
mPipelineComm->send<SizeType>(sequenceLengths, pipelineGroup[peerIdx], *mCommStream, *mLogger);
}
mPipelineComm->send<TokenIdType>(*decoder.getNewTokens(), pipelineGroup.front(), *mCommStream, *mLogger);
}
}
else // pipeline parallel mode
{ // receive shouldStop from the last rank on a separate stream
stream.record(mCommEvent.get());
mCommStream->wait(mCommEvent.get());
auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();
auto const peer = pipelineGroup.back();
mPipelineComm->receive<SizeType>(*buffers.nbFinished, peer, *mCommStream, *mLogger);
if (mWorldConfig.isPipelineParallel())
{
if (mWorldConfig.isLastPipelineParallelRank())
auto& cacheIndirection = *buffers.cacheIndirectionDecoderOutput;
auto& sequenceLengths = *buffers.sequenceLengths;
auto const beamWidth = cacheIndirection.getShape().d[1];
if (beamWidth > 1)
{
for (auto peer = 0; peer < mWorldConfig.getPipelineParallelism() - 1; ++peer)
{
mPipelineComm->send(shouldStopPtr, 1, peer, stream, *mLogger);
}
mPipelineComm->send(bufferCast<std::int32_t>(*newTokens), newTokens->getSize(), 0, stream, *mLogger);
mPipelineComm->receive<SizeType>(cacheIndirection, peer, *mCommStream, *mLogger);
}
else
{
auto const peer = mWorldConfig.getPipelineParallelism() - 1;
mPipelineComm->receive(shouldStopPtr, 1, peer, stream, *mLogger);
if (mWorldConfig.isFirstPipelineParallelRank())
{
mPipelineComm->receive(
bufferCast<std::int32_t>(*newTokens), newTokens->getSize(), peer, stream, *mLogger);
auto const& newTokensShape = newTokens->getShape();
auto newTokensView
= ITensor::view(outputIds, ITensor::makeShape({1, newTokensShape.d[0] * newTokensShape.d[1]}));
auto const& outputIdsShape = outputIds->getShape();
auto outputIdsView = ITensor::view(
outputIds, ITensor::makeShape({outputIdsShape.d[0] * outputIdsShape.d[1], outputIdsShape.d[2]}));
kernels::invokeTransposeWithOutputOffset(*outputIdsView, *newTokensView, decoderStep, stream);
}
mPipelineComm->receive<SizeType>(sequenceLengths, peer, *mCommStream, *mLogger);
if (mWorldConfig.isFirstPipelineParallelRank())
{ // receive newTokens from the last rank on a separate stream
mPipelineComm->receive<TokenIdType>(*newTokens, peer, *mCommStream, *mLogger);
updateOutputIds(outputIds, newTokens, decoderStep, *mCommStream);
}
mCommStream->record(mReceivedEvents.at(microBatchId).get());
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
bool GptSession::shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType microBatchId)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
SizeType nbFinished = 0;
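// Decoding stops once every sequence and beam of the micro batch has finished.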
if (mWorldConfig.isLastPipelineParallelRank())
{ // read the Finished flag from the decoder
auto& decoder = *mDecoders.at(microBatchId);
decoder.isFinishedSync();
nbFinished = *bufferCast<SizeType>(*decoder.getNbFinished());
}
else
{ // ensure all information has been received
TLLM_CUDA_CHECK(cudaEventSynchronize(mReceivedEvents.at(microBatchId).get()));
nbFinished = *bufferCast<SizeType>(*mBuffers.at(microBatchId)->nbFinished);
}
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return shouldStop;
return nbFinished == batchSize * beamWidth;
}
void GptSession::finalizeOutputIds(ITensor& outputIds)
void GptSession::finalizeOutputIds(ITensor& outputIds, SizeType microBatchId)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& manager = mRuntime->getBufferManager();
auto& stream = mRuntime->getStream();
ITensor::SharedPtr finalOutputIds;
if (mWorldConfig.isLastPipelineParallelRank())
if (mWorldConfig.isPipelineParallel())
{
finalOutputIds = mDecoder->getFinalOutputIds();
if (mWorldConfig.isPipelineParallel())
{
auto& stream = mRuntime->getStream();
auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();
if (mWorldConfig.isLastPipelineParallelRank())
{ // send ids from last to first
auto const peer = pipelineGroup.front();
auto const finalOutputIds = mDecoders.at(microBatchId)->getFinalOutputIds();
mPipelineComm->send(
bufferCast<std::int32_t>(*finalOutputIds), finalOutputIds->getSize(), 0, stream, *mLogger);
bufferCast<std::int32_t>(*finalOutputIds), finalOutputIds->getSize(), peer, stream, *mLogger);
}
}
if (mWorldConfig.isFirstPipelineParallelRank())
{
if (mWorldConfig.isPipelineParallel())
{
auto const peer = mWorldConfig.getPipelineParallelism() - 1;
else if (mWorldConfig.isFirstPipelineParallelRank())
{ // receive ids from the last rank on the first rank
auto const peer = pipelineGroup.back();
mPipelineComm->receive(bufferCast<std::int32_t>(outputIds), outputIds.getSize(), peer, stream, *mLogger);
}
else
{
manager.copy(*finalOutputIds, outputIds);
}
}
else
{
manager.copy(*mDecoders.at(microBatchId)->getFinalOutputIds(), outputIds);
}
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

Some files were not shown because too many files have changed in this diff.