Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

Commit 027cd518e3 ("Update"), parent 279e329b22
@@ -251,7 +251,14 @@ corresponds to your CUDA version. As an example, for CUDA 12.1, use:

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
```

[CUDA_ARCHITECTURES in CMake]: https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html#prop_tgt:CUDA_ARCHITECTURES

- When building models, you might encounter memory-related issues, such as:

  ```
  [09/23/2023-03:13:00] [TRT] [E] 9: GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types
  [09/23/2023-03:13:00] [TRT] [E] 9: [pluginV2Builder.cpp::reportPluginError::24] Error Code 9: Internal Error (GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types)
  ```

  You can reduce the memory pressure by lowering the maximum batch size and the maximum input and output lengths. Another option is to enable plugins, for example `--use_gpt_attention_plugin`.
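  For instance, a build invocation with smaller maximum shapes might look like the sketch below. This is only an illustration: the script path and the size flags are assumptions based on the GPT example, and only `--use_gpt_attention_plugin` is the flag mentioned above; check the build script of the model you are using for the exact options.

  ```
  # Illustrative only; flag names may differ between example models.
  python3 examples/gpt/build.py \
      --max_batch_size 8 \
      --max_input_len 512 \
      --max_output_len 128 \
      --use_gpt_attention_plugin float16
  ```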

- [CUDA_ARCHITECTURES in CMake]: https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html#prop_tgt:CUDA_ARCHITECTURES

## Release notes
@@ -25,9 +25,17 @@ add_subdirectory(${CXXOPTS_SRC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/cxxopts)

function(add_benchmark test_name test_src)
  add_executable(${test_name} ${test_src})

  target_link_libraries(
    ${test_name} PUBLIC ${SHARED_TARGET} nvinfer_plugin_tensorrt_llm
    cxxopts::cxxopts)
  if(NOT WIN32) # Linux
    target_link_libraries(
      ${test_name} PUBLIC ${SHARED_TARGET} nvinfer_plugin_tensorrt_llm
      cxxopts::cxxopts)
  else()
    # Use STATIC_TARGET on Windows because MSVC is picky about duplicate symbols
    # if the shared and static libs both get linked
    target_link_libraries(
      ${test_name} PUBLIC ${STATIC_TARGET} nvinfer_plugin_tensorrt_llm
      cxxopts::cxxopts)
  endif()

  target_compile_features(${test_name} PRIVATE cxx_std_17)
  target_compile_definitions(${test_name}
@@ -37,3 +45,4 @@ endfunction()

add_benchmark(gptSessionBenchmark gptSessionBenchmark.cpp)
add_benchmark(bertBenchmark bertBenchmark.cpp)
add_benchmark(gptManagerBenchmark gptManagerBenchmark.cpp)
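Since `add_benchmark` creates an ordinary executable target, the new benchmark can presumably also be built on its own once CMake has been configured (a sketch assuming the Makefile workflow used in the benchmarking guide below):

```
cd cpp/build
make -j gptManagerBenchmark
```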
@@ -15,7 +15,7 @@ cd cpp/build
make -j benchmarks
```

### 2. Launch C++ benchmarking
### 2. Launch C++ benchmarking (Fixed BatchSize/InputLen/OutputLen)

Before you launch C++ benchmarking, please make sure that you have already built the engine(s) using the TensorRT-LLM API; the C++ benchmarking code cannot generate engine(s) for you.
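As a quick orientation before the multi-GPU example in the next hunk, a single-GPU launch is sketched below. The flag names and values here are assumptions that do not appear in this excerpt, so consult `./benchmarks/gptSessionBenchmark --help` for the authoritative options.

```
./benchmarks/gptSessionBenchmark \
    --model gpt_350m \
    --engine_dir "../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/" \
    --batch_size "1" \
    --input_output_len "60,20"
```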
@@ -55,3 +55,49 @@ mpirun -n 8 ./benchmarks/gptSessionBenchmark \
```

*Please note that the expected outputs in that document are only for reference; the specific performance numbers depend on the GPU you're using.*
### 3. Launch Batch Manager benchmarking (Inflight/V1 batching)

#### Prepare dataset

Run the preprocessing script to prepare the dataset. The script converts the prompts (strings) in the dataset into input_ids.
```
python3 prepare_dataset.py \
    --dataset <path/to/dataset> \
    --max_input_len 300 \
    --tokenizer_dir <path/to/tokenizer> \
    --tokenizer_type auto \
    --output preprocessed_dataset.json
```
For `tokenizer_dir`, you can either specify the path to a local tokenizer that has already been downloaded, or simply the name of a tokenizer on HuggingFace, such as `gpt2`; in the latter case the tokenizer is downloaded automatically.
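Each entry in the resulting JSON file pairs the tokenized prompt with a target output length. The key names below come from `prepare_dataset.py` added in this commit; the token values are made up for illustration:

```
# Peek at the generated file (contents are illustrative).
python3 -c "import json; print(json.load(open('preprocessed_dataset.json'))[0])"
# {'input_ids': [8241, 318, 262, 1266, ...], 'output_len': 52}
```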
#### Prepare TensorRT-LLM engines

Please make sure that the engines are built with the arguments `--use_inflight_batching` and `--remove_input_padding` if you'd like to benchmark inflight batching. For more details, please see the documentation in the TensorRT-LLM examples.
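As a rough sketch of such a build (the script path, checkpoint directory, and dtype/plugin flags are assumptions borrowed from the GPT example; only `--use_inflight_batching` and `--remove_input_padding` are the arguments required above):

```
# Illustrative only; see the GPT example documentation for the exact build flags.
python3 examples/gpt/build.py \
    --model_dir ./c-model/gpt2/fp16/2-gpu \
    --dtype float16 \
    --use_gpt_attention_plugin float16 \
    --remove_input_padding \
    --use_inflight_batching \
    --output_dir examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/
```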
#### Launch benchmarking

For detailed usage, you can do the following:
```
cd cpp/build

# You can directly execute the binary for help information
./benchmarks/gptManagerBenchmark --help
```

Take GPT-350M as an example for single-GPU V1 batching:
```
./benchmarks/gptManagerBenchmark \
    --model gpt \
    --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
    --type V1 \
    --dataset ../../benchmarks/cpp/preprocessed_dataset.json
```

Take GPT-350M as an example for 2-GPU inflight batching:
```
mpirun -n 2 ./benchmarks/gptManagerBenchmark \
    --model gpt \
    --engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \
    --type IFB \
    --dataset ../../benchmarks/cpp/preprocessed_dataset.json
```
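`gptManagerBenchmark` also exposes a few batch-manager tuning options whose names come from the option definitions in `gptManagerBenchmark.cpp` added below (`--max_num_sequences`, `--max_tokens_in_paged_kvcache`, `--kv_cache_free_gpu_mem_fraction`, `--scheduler_policy`, `--log_level`). The values in this sketch are purely illustrative:

```
mpirun -n 2 ./benchmarks/gptManagerBenchmark \
    --model gpt \
    --engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \
    --type IFB \
    --dataset ../../benchmarks/cpp/preprocessed_dataset.json \
    --kv_cache_free_gpu_mem_fraction 0.9 \
    --scheduler_policy max_utilization \
    --log_level info
```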
@@ -78,7 +78,7 @@ void benchmarkBert(std::string const& modelName, std::filesystem::path const& da
{
    auto const worldConfig = WorldConfig::mpi(*logger);
    auto const enginePath = dataPath / engineFilename(dataPath, worldConfig, modelName);
    auto engineBlob = loadEngine(enginePath);
    auto engineBlob = loadEngine(enginePath.string());

    auto rt = std::make_shared<TllmRuntime>(engineBlob.data(), engineBlob.size(), *logger);
    rt->addContext(0);
@@ -180,7 +180,8 @@ int main(int argc, char* argv[])
    if (!result.count("engine_dir"))
    {
        std::cout << options.help() << std::endl;
        throw std::invalid_argument("Please specify engine directory.");
        TLLM_LOG_ERROR("Please specify engine directory.");
        return 1;
    }

    // Argument: Batch sizes
@@ -226,11 +227,20 @@ int main(int argc, char* argv[])
    }
    else
    {
        throw std::invalid_argument("Unexpected log level: " + logLevel);
        TLLM_LOG_ERROR("Unexpected log level: " + logLevel);
        return 1;
    }
    initTrtLlmPlugins(logger.get());

    benchmarkBert(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inLens, logger,
        result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>());
    try
    {
        benchmarkBert(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inLens,
            logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>());
    }
    catch (const std::exception& e)
    {
        TLLM_LOG_ERROR(e.what());
        return 1;
    }
    return 0;
}
benchmarks/cpp/gptManagerBenchmark.cpp (new file, 645 lines)
@@ -0,0 +1,645 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tensorrt_llm/batch_manager/GptManager.h"
|
||||
#include "tensorrt_llm/batch_manager/NamedTensor.h"
|
||||
#include "tensorrt_llm/batch_manager/callbacks.h"
|
||||
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
|
||||
#include "tensorrt_llm/runtime/gptJsonConfig.h"
|
||||
#include "tensorrt_llm/runtime/gptSession.h"
|
||||
#include "tensorrt_llm/runtime/tllmLogger.h"
|
||||
|
||||
#include <NvInfer.h>
|
||||
#include <NvInferPlugin.h>
|
||||
#include <chrono>
|
||||
#include <cxxopts.hpp>
|
||||
#include <iostream>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <string>
|
||||
|
||||
using namespace tensorrt_llm::batch_manager;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::mpi;
|
||||
|
||||
namespace tc = tensorrt_llm::common;
|
||||
namespace trt = nvinfer1;
|
||||
|
||||
// Class holding all information regarding a single work item.
// This includes the original request, associated response factor
// and state.
|
||||
class WorkItem
|
||||
{
|
||||
public:
|
||||
WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t RequestId)
|
||||
: mInferenceRequest(ir)
|
||||
, mRequestId(RequestId)
|
||||
{
|
||||
}
|
||||
|
||||
~WorkItem() {}
|
||||
|
||||
uint64_t requestId() const
|
||||
{
|
||||
return mRequestId;
|
||||
}
|
||||
|
||||
std::shared_ptr<InferenceRequest> getInferenceRequest() const
|
||||
{
|
||||
return mInferenceRequest;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<InferenceRequest> mInferenceRequest;
|
||||
uint64_t mRequestId;
|
||||
};
|
||||
|
||||
/// @brief Thread-safe queue of work items
|
||||
class WorkItemsQueue
|
||||
{
|
||||
public:
|
||||
void clear()
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
mPendingWorkItems.clear();
|
||||
mPendingWorkItemsReqIds.clear();
|
||||
mInProgressWorkItems.clear();
|
||||
}
|
||||
|
||||
// Note: this function must only be called while holding the lock
|
||||
bool hasInProgressReqId(const uint64_t reqId) const
|
||||
{
|
||||
return (mInProgressWorkItems.find(reqId) != mInProgressWorkItems.end());
|
||||
}
|
||||
|
||||
// Note: this function must only be called while holding the lock
|
||||
bool hasPendingReqId(const uint64_t reqId) const
|
||||
{
|
||||
return (mPendingWorkItemsReqIds.find(reqId) != mPendingWorkItemsReqIds.end());
|
||||
}
|
||||
|
||||
bool empty() const
|
||||
{
|
||||
return mPendingWorkItems.empty() && mInProgressWorkItems.empty() && mPendingWorkItemsReqIds.empty();
|
||||
}
|
||||
|
||||
/// @brief Add a new work item to the queue
|
||||
/// Throws an error if requestId already exists
|
||||
|
||||
void push(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
if (hasInProgressReqId(requestId) || hasPendingReqId(requestId))
|
||||
{
|
||||
std::string errStr
|
||||
= "requestId " + std::to_string(requestId) + " is already in progress, request is ignored.";
|
||||
throw std::runtime_error(errStr);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto workItem = std::make_shared<WorkItem>(request, requestId);
|
||||
mPendingWorkItems.push_back(workItem);
|
||||
mPendingWorkItemsReqIds.insert(workItem->requestId());
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Get a new work item from the queue, and move it to the list of
|
||||
/// in progress work items if it hasn't been stopped
|
||||
/// @return A tuple of the workItem and a boolean flag indicating if the work item
|
||||
/// has been marked in progress
|
||||
std::tuple<std::shared_ptr<WorkItem>, bool> pop()
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
|
||||
auto workItem = mPendingWorkItems.front();
|
||||
mPendingWorkItems.pop_front();
|
||||
mPendingWorkItemsReqIds.erase(workItem->requestId());
|
||||
|
||||
bool markedInProgress;
|
||||
mInProgressWorkItems.emplace(std::make_pair(workItem->requestId(), workItem));
|
||||
markedInProgress = true;
|
||||
|
||||
return {workItem, markedInProgress};
|
||||
}
|
||||
|
||||
size_t numPendingWorkItems() const
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
return mPendingWorkItems.size();
|
||||
}
|
||||
|
||||
size_t numInProgressWorkItems() const
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
return mInProgressWorkItems.size();
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
return numPendingWorkItems() + numInProgressWorkItems();
|
||||
}
|
||||
|
||||
/// @brief Mark a request as being finished
|
||||
/// @param requestId
|
||||
void markFinished(const uint64_t requestId)
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
if (hasInProgressReqId(requestId))
|
||||
{
|
||||
mInProgressWorkItems.erase(requestId);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/// Queue of work items
|
||||
std::list<std::shared_ptr<WorkItem>> mPendingWorkItems;
|
||||
/// requestIds of work items in the queue
|
||||
std::set<uint64_t> mPendingWorkItemsReqIds;
|
||||
|
||||
/// work items currently in progress
|
||||
std::unordered_map<uint64_t, std::shared_ptr<WorkItem>> mInProgressWorkItems;
|
||||
|
||||
mutable std::mutex mMutex;
|
||||
};
|
||||
|
||||
struct BenchInfo
|
||||
{
|
||||
BenchInfo() {}
|
||||
|
||||
BenchInfo(int _inputLength, int _outputLength, std::chrono::time_point<std::chrono::steady_clock> _start)
|
||||
: inputLength(_inputLength)
|
||||
, outputLength(_outputLength)
|
||||
, start(_start)
|
||||
{
|
||||
}
|
||||
|
||||
int inputLength;
|
||||
int outputLength;
|
||||
std::chrono::time_point<std::chrono::steady_clock> start;
|
||||
std::chrono::time_point<std::chrono::steady_clock> end;
|
||||
float latency; // millisecond
|
||||
};
|
||||
|
||||
class Recorder
|
||||
{
|
||||
public:
|
||||
Recorder() {}
|
||||
|
||||
void initialize()
|
||||
{
|
||||
mStart = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
void finalize()
|
||||
{
|
||||
mEnd = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
void recordStart(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
|
||||
{
|
||||
const auto& input_ids_tensor = request->getInputTensor("input_ids");
|
||||
std::vector<int64_t> tensorShape(input_ids_tensor->getShape().nbDims);
|
||||
auto const inputLength = tensorShape[1];
|
||||
auto const [specified, outputLength]
|
||||
= request->getScalarValueFromTensor<int>("request_output_len", {1, 1}, false);
|
||||
assert(specified);
|
||||
auto const start = std::chrono::steady_clock::now();
|
||||
mRequestBenchInfos[requestId] = BenchInfo(inputLength, outputLength, start);
|
||||
}
|
||||
|
||||
void recordEnd(uint64_t requestId)
|
||||
{
|
||||
mRequestBenchInfos[requestId].end = std::chrono::steady_clock::now();
|
||||
mRequestBenchInfos[requestId].latency = std::chrono::duration<float, std::milli>(
|
||||
mRequestBenchInfos[requestId].end - mRequestBenchInfos[requestId].start)
|
||||
.count();
|
||||
}
|
||||
|
||||
void calculateMetrics()
|
||||
{
|
||||
mNumSamples = mRequestBenchInfos.size();
|
||||
mTotalLatency = std::chrono::duration<float, std::milli>(mEnd - mStart).count();
|
||||
mSeqThroughput = mNumSamples / (mTotalLatency / 1000);
|
||||
mAvgSeqLatency = 0;
|
||||
int totalOutputTokens = 0;
|
||||
for (auto reqInfo : mRequestBenchInfos)
|
||||
{
|
||||
mAvgSeqLatency += reqInfo.second.latency;
|
||||
totalOutputTokens += reqInfo.second.outputLength;
|
||||
}
|
||||
mAvgSeqLatency /= mNumSamples;
|
||||
mTokenThroughput = totalOutputTokens / (mTotalLatency / 1000);
|
||||
}
|
||||
|
||||
void report()
|
||||
{
|
||||
printf("[BENCHMARK] num_samples(ms) %d\n", mNumSamples);
|
||||
printf("[BENCHMARK] total_latency(ms) %.2f\n", mTotalLatency);
|
||||
printf("[BENCHMARK] seq_throughput(seq/sec) %.2f\n", mSeqThroughput);
|
||||
printf("[BENCHMARK] avg_sequence_latency(ms) %.2f\n", mAvgSeqLatency);
|
||||
printf("[BENCHMARK] token_throughput(token/sec) %.2f\n", mTokenThroughput);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<uint64_t, BenchInfo> mRequestBenchInfos;
|
||||
|
||||
std::chrono::time_point<std::chrono::steady_clock> mStart;
|
||||
std::chrono::time_point<std::chrono::steady_clock> mEnd;
|
||||
int mNumSamples;
|
||||
float mTotalLatency;
|
||||
float mSeqThroughput;
|
||||
float mAvgSeqLatency;
|
||||
float mTokenThroughput;
|
||||
}; // class Recorder
|
||||
|
||||
class GptServer
|
||||
{
|
||||
public:
|
||||
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, std::optional<int32_t> maxNumSequences,
|
||||
std::optional<int32_t> maxTokensInPagedKvCache, std::optional<float> kvCacheFreeGpuMemFraction,
|
||||
std::shared_ptr<Recorder> recorder)
|
||||
{
|
||||
const TrtGptModelOptionalParams& optionalParams
|
||||
= TrtGptModelOptionalParams(maxNumSequences, maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction);
|
||||
mBatchManager = std::make_shared<GptManager>(
|
||||
trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
|
||||
[this](int max_num_requests) { return getInferenceRequests(max_num_requests); },
|
||||
[this](uint64_t requestId, std::list<NamedTensor> response_tensors, bool final_response,
|
||||
const std::string& errMsg)
|
||||
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
|
||||
nullptr, nullptr, optionalParams);
|
||||
mRecorder = recorder;
|
||||
}
|
||||
|
||||
~GptServer()
|
||||
{
|
||||
mWorkItemsQueue.clear();
|
||||
}
|
||||
|
||||
void enqueue(std::vector<NamedTensor> tensors, uint64_t requestId, bool streaming)
|
||||
{
|
||||
// Create InferenceRequest from a set of tensors
|
||||
auto request = std::make_shared<InferenceRequest>(requestId);
|
||||
if (requestId == -1)
|
||||
{
|
||||
mWorkItemsQueue.push(request, requestId);
|
||||
return;
|
||||
}
|
||||
for (auto t : tensors)
|
||||
{
|
||||
request->emplaceInputTensor(t.name, std::move(t.tensor));
|
||||
}
|
||||
request->setIsStreaming(streaming);
|
||||
|
||||
// Enqueue
|
||||
try
|
||||
{
|
||||
mRecorder->recordStart(request, requestId);
|
||||
mWorkItemsQueue.push(request, requestId);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
throw std::runtime_error(e.what());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void waitForEmpty() const
|
||||
{
|
||||
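// Busy-wait until the batch manager has drained all pending and in-progress work items.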
while (mWorkItemsQueue.size() > 0)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
void waitBatchManager() const
|
||||
{
|
||||
mBatchManager->waitUntilTerminate();
|
||||
}
|
||||
|
||||
// Return up to max_num_requests inference requests.
|
||||
std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(const int max_num_requests)
|
||||
{
|
||||
std::list<std::shared_ptr<InferenceRequest>> rval;
|
||||
if (max_num_requests > 0)
|
||||
{
|
||||
auto world_size = getCommWorldSize();
|
||||
auto rank = getCommWorldRank();
|
||||
if (rank == 0)
|
||||
{
|
||||
int64_t num_new_work_items = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
|
||||
static_cast<int64_t>(max_num_requests));
|
||||
if (world_size > 1)
|
||||
{
|
||||
bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD);
|
||||
}
|
||||
|
||||
if (num_new_work_items > 0)
|
||||
{
|
||||
int count = 0;
|
||||
while (count < num_new_work_items)
|
||||
{
|
||||
auto [workItem, markedInProgress] = mWorkItemsQueue.pop();
|
||||
|
||||
if (markedInProgress)
|
||||
{
|
||||
rval.emplace_back(workItem->getInferenceRequest());
|
||||
count++;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string warnStr = std::string("request Id ") + std::to_string(workItem->requestId())
|
||||
+ std::string(" has been stopped. Request is ignored.");
|
||||
TLLM_LOG_WARNING(warnStr);
|
||||
sendResponse(workItem->requestId(), {}, true, warnStr);
|
||||
}
|
||||
}
|
||||
if (world_size > 1)
|
||||
{
|
||||
std::vector<int64_t> packed;
|
||||
for (auto ir : rval)
|
||||
{
|
||||
auto vpacked = ir->serialize();
|
||||
packed.push_back(static_cast<int64_t>(vpacked.size()));
|
||||
packed.insert(
|
||||
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
|
||||
}
|
||||
bcast(packed, 0, COMM_WORLD);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// subordinate ranks hang until master rank sends work
|
||||
int64_t num_new_work_items;
|
||||
bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD);
|
||||
if (num_new_work_items > 0)
|
||||
{
|
||||
std::vector<int64_t> packed;
|
||||
bcast(packed, 0, COMM_WORLD);
|
||||
int64_t* packed_ptr = packed.data();
|
||||
for (int64_t count = 0; count < num_new_work_items; ++count)
|
||||
{
|
||||
int64_t n = *(packed_ptr++);
|
||||
auto ir = InferenceRequest::deserialize(packed_ptr);
|
||||
packed_ptr += n;
|
||||
rval.emplace_back(ir);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return rval;
|
||||
}
|
||||
|
||||
void sendResponse(uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
|
||||
const std::string& errMsg)
|
||||
{
|
||||
std::string errStr = std::string("Failed to send response for requestId: ") + std::to_string(requestId);
|
||||
try
|
||||
{
|
||||
if (final_response)
|
||||
{
|
||||
mWorkItemsQueue.markFinished(requestId);
|
||||
mRecorder->recordEnd(requestId);
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
TLLM_LOG_ERROR(errStr);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<GptManager> mBatchManager;
|
||||
std::shared_ptr<Recorder> mRecorder;
|
||||
WorkItemsQueue mWorkItemsQueue;
|
||||
|
||||
}; // class GptServer
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
std::pair<std::vector<std::vector<int32_t>>, std::vector<int32_t>> parseDataset(
|
||||
std::filesystem::path const& datasetPath)
|
||||
{
|
||||
auto constexpr allowExceptions = true;
|
||||
auto constexpr ignoreComments = true;
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
std::filesystem::exists(datasetPath), std::string("File does not exist: ") + datasetPath.string());
|
||||
std::ifstream jsonStream(datasetPath);
|
||||
auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ignoreComments);
|
||||
|
||||
std::vector<std::vector<int32_t>> input_ids_list;
|
||||
std::vector<int32_t> output_ids_list;
|
||||
for (auto& sample : json)
|
||||
{
|
||||
input_ids_list.push_back(sample["input_ids"]);
|
||||
output_ids_list.push_back(sample["output_len"]);
|
||||
}
|
||||
return std::make_pair(input_ids_list, output_ids_list);
|
||||
}
|
||||
|
||||
void benchmarkGptManager(std::string const& modelName, std::filesystem::path const& engineDir, std::string const& type,
|
||||
std::string const& datasetPath, std::shared_ptr<nvinfer1::ILogger> const& logger,
|
||||
std::optional<int32_t> maxNumSequences, std::optional<int32_t> maxTokensInPagedKvCache,
|
||||
std::optional<float> kvCacheFreeGpuMemFraction, batch_scheduler::SchedulerPolicy schedulerPolicy)
|
||||
{
|
||||
auto const worldConfig = WorldConfig::mpi(*logger);
|
||||
|
||||
TrtGptModelType modelType;
|
||||
if (type == "V1")
|
||||
{
|
||||
modelType = TrtGptModelType::V1;
|
||||
}
|
||||
else if (type == "IFB")
|
||||
{
|
||||
modelType = TrtGptModelType::InflightFusedBatching;
|
||||
}
|
||||
else
|
||||
{
|
||||
const std::string errStr = std::string("Unexpected batching type: ") + type;
|
||||
TLLM_LOG_ERROR(errStr);
|
||||
}
|
||||
|
||||
const int maxBeamWidth = 1;
|
||||
auto recorder = std::make_shared<Recorder>();
|
||||
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, maxNumSequences,
|
||||
maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, recorder);
|
||||
|
||||
auto dataset = parseDataset(datasetPath);
|
||||
std::vector<std::vector<NamedTensor>> tensors_list;
|
||||
const auto num_samples = dataset.first.size();
|
||||
for (int i = 0; i < num_samples; ++i)
|
||||
{
|
||||
const auto input_ids = dataset.first[i];
|
||||
const auto request_output_len = dataset.second[i];
|
||||
std::vector<int64_t> input_ids_shape = {1, static_cast<int64_t>(input_ids.size())};
|
||||
auto input_ids_tensor = NamedTensor(nvinfer1::DataType::kINT32, input_ids_shape, "input_ids", input_ids.data());
|
||||
auto request_output_len_tensor
|
||||
= NamedTensor(nvinfer1::DataType::kINT32, {1, 1}, "request_output_len", &request_output_len);
|
||||
std::vector<NamedTensor> tensors = {input_ids_tensor, request_output_len_tensor};
|
||||
tensors_list.push_back(tensors);
|
||||
}
|
||||
|
||||
if (worldConfig.getRank() == 0)
|
||||
{
|
||||
recorder->initialize();
|
||||
for (int i = 0; i < tensors_list.size(); ++i)
|
||||
{
|
||||
gptServer->enqueue(tensors_list[i], 1 + i, false);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
recorder->finalize();
|
||||
recorder->calculateMetrics();
|
||||
recorder->report();
|
||||
gptServer->enqueue({}, -1, false);
|
||||
}
|
||||
gptServer->waitBatchManager();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
cxxopts::Options options(
|
||||
"TensorRT-LLM BatchManager Benchmark", "TensorRT-LLM BatchManager Benchmark for GPT and GPT-like models.");
|
||||
options.add_options()("h,help", "Print usage");
|
||||
options.add_options()(
|
||||
"m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
|
||||
options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
|
||||
options.add_options()(
|
||||
"type", "Batching type: IFB or V1(non-IFB) batching.", cxxopts::value<std::string>()->default_value("IFB"));
|
||||
options.add_options()("dataset", "Dataset that is used for benchmarking BatchManager.",
|
||||
cxxopts::value<std::string>()->default_value(""));
|
||||
|
||||
options.add_options()("max_num_sequences", "Max number of Sequences.", cxxopts::value<int>()->default_value("-1"));
|
||||
options.add_options()(
|
||||
"max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>()->default_value("-1"));
|
||||
options.add_options()("kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.",
|
||||
cxxopts::value<float>()->default_value("-1"));
|
||||
options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_completion.",
|
||||
cxxopts::value<std::string>()->default_value("guaranteed_completion"));
|
||||
|
||||
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
|
||||
cxxopts::value<std::string>()->default_value("error"));
|
||||
|
||||
auto result = options.parse(argc, argv);
|
||||
|
||||
if (result.count("help"))
|
||||
{
|
||||
std::cout << options.help() << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Argument: Engine directory
|
||||
if (!result.count("engine_dir"))
|
||||
{
|
||||
std::cout << options.help() << std::endl;
|
||||
TLLM_LOG_ERROR("Please specify engine directory.");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Argument: Batching Type
|
||||
auto const type = result["type"].as<std::string>();
|
||||
|
||||
// Argument: Dataset
|
||||
auto const datasetPath = result["dataset"].as<std::string>();
|
||||
|
||||
// Argument: Max Num Sequences
|
||||
std::optional<int32_t> maxNumSequences = std::nullopt;
|
||||
if (result["max_num_sequences"].as<int>() != -1)
|
||||
{
|
||||
maxNumSequences = result["max_num_sequences"].as<int>();
|
||||
}
|
||||
|
||||
// Argument: Max tokens in paged K-V Cache
|
||||
std::optional<int32_t> maxTokensInPagedKvCache = std::nullopt;
|
||||
if (result["max_tokens_in_paged_kvcache"].as<int>() != -1)
|
||||
{
|
||||
maxTokensInPagedKvCache = result["max_tokens_in_paged_kvcache"].as<int>();
|
||||
}
|
||||
|
||||
// Argument: K-V Cache Free Gpu Mem Fraction
|
||||
std::optional<float> kvCacheFreeGpuMemFraction = std::nullopt;
|
||||
if (result["kv_cache_free_gpu_mem_fraction"].as<float>() != -1)
|
||||
{
|
||||
kvCacheFreeGpuMemFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
|
||||
}
|
||||
|
||||
// Argument: Scheduler policy
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy;
|
||||
auto const schedulerPolicyArg = result["scheduler_policy"].as<std::string>();
|
||||
if (schedulerPolicyArg == "max_utilization")
|
||||
{
|
||||
schedulerPolicy = batch_scheduler::SchedulerPolicy::MAX_UTILIZATION;
|
||||
}
|
||||
else if (schedulerPolicyArg == "guaranteed_completion")
|
||||
{
|
||||
schedulerPolicy = batch_scheduler::SchedulerPolicy::GUARANTEED_COMPLETION;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR("Unexpected scheduler policy: " + schedulerPolicyArg);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Argument: Log level
|
||||
auto logger = std::make_shared<TllmLogger>();
|
||||
auto const logLevel = result["log_level"].as<std::string>();
|
||||
if (logLevel == "verbose")
|
||||
{
|
||||
logger->setLevel(trt::ILogger::Severity::kVERBOSE);
|
||||
}
|
||||
else if (logLevel == "info")
|
||||
{
|
||||
logger->setLevel(trt::ILogger::Severity::kINFO);
|
||||
}
|
||||
else if (logLevel == "warning")
|
||||
{
|
||||
logger->setLevel(trt::ILogger::Severity::kWARNING);
|
||||
}
|
||||
else if (logLevel == "error")
|
||||
{
|
||||
logger->setLevel(trt::ILogger::Severity::kERROR);
|
||||
}
|
||||
else if (logLevel == "internal_error")
|
||||
{
|
||||
logger->setLevel(trt::ILogger::Severity::kINTERNAL_ERROR);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR("Unexpected log level: " + logLevel);
|
||||
return 1;
|
||||
}
|
||||
|
||||
initTrtLlmPlugins(logger.get());
|
||||
|
||||
try
|
||||
{
|
||||
benchmarkGptManager(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), type,
|
||||
datasetPath, logger, maxNumSequences, maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, schedulerPolicy);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
TLLM_LOG_ERROR(e.what());
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@ -58,7 +58,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
samplingConfig.topK = std::vector{1};
|
||||
samplingConfig.topP = std::vector{0.0f};
|
||||
|
||||
GptSession session{modelConfig, worldConfig, enginePath, logger};
|
||||
GptSession session{modelConfig, worldConfig, enginePath.string(), logger};
|
||||
// Use bufferManager for copying data to and from the GPU
|
||||
auto& bufferManager = session.getBufferManager();
|
||||
session.setCudaGraphMode(cudaGraphMode);
|
||||
@ -178,7 +178,8 @@ int main(int argc, char* argv[])
|
||||
if (!result.count("engine_dir"))
|
||||
{
|
||||
std::cout << options.help() << std::endl;
|
||||
throw std::invalid_argument("Please specify engine directory.");
|
||||
TLLM_LOG_ERROR("Please specify engine directory.");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Argument: Batch sizes
|
||||
@ -230,7 +231,8 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("Unexpected log level: " + logLevel);
|
||||
TLLM_LOG_ERROR("Unexpected log level: " + logLevel);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Argument: Enable CUDA graph
|
||||
@ -238,8 +240,16 @@ int main(int argc, char* argv[])
|
||||
|
||||
initTrtLlmPlugins(logger.get());
|
||||
|
||||
benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inOutLen,
|
||||
logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(),
|
||||
enableCudaGraph);
|
||||
try
|
||||
{
|
||||
benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes,
|
||||
inOutLen, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(),
|
||||
enableCudaGraph);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
TLLM_LOG_ERROR(e.what());
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
benchmarks/cpp/prepare_dataset.py (new file, 63 lines)
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/python
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Dataset path used for the test.')
|
||||
parser.add_argument('--max_input_len',
|
||||
type=int,
|
||||
required=True,
|
||||
help='Specify max input length')
|
||||
parser.add_argument('--tokenizer_dir',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Specify tokenizer directory')
|
||||
parser.add_argument('--tokenizer_type',
|
||||
type=str,
|
||||
default='auto',
|
||||
required=False,
|
||||
choices=['auto', 't5', 'llama'],
|
||||
help='Specify tokenizer type')
|
||||
parser.add_argument('--output',
|
||||
type=str,
|
||||
default='preprocessed_dataset.json',
|
||||
help='Preprocessed dataset path.')
|
||||
FLAGS = parser.parse_args()
|
||||
|
||||
if FLAGS.tokenizer_type == 't5':
|
||||
tokenizer = T5Tokenizer(vocab_file=FLAGS.tokenizer_dir,
|
||||
padding_side='left')
|
||||
elif FLAGS.tokenizer_type == 'auto':
|
||||
tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer_dir,
|
||||
padding_side='left')
|
||||
elif FLAGS.tokenizer_type == 'llama':
|
||||
tokenizer = LlamaTokenizer.from_pretrained(FLAGS.tokenizer_dir,
|
||||
legacy=False,
|
||||
padding_side='left')
|
||||
else:
|
||||
raise AttributeError(
|
||||
f'Unexpected tokenizer type: {FLAGS.tokenizer_type}')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
results = []
|
||||
with open(FLAGS.dataset, 'r') as f:
|
||||
data_dict = json.load(f)
|
||||
for req in data_dict:
|
||||
prompt = req['input'] + ' ' + req['instruction']
|
||||
output = req['output']
|
||||
line = tokenizer.encode(prompt)
|
||||
if len(line) > FLAGS.max_input_len:
|
||||
continue
|
||||
# 1.3 is a magic number that converts number of words to number of tokens
|
||||
output_len = int(len(output.split(' ')) * 1.3)
|
||||
results.append({'input_ids': line, 'output_len': output_len})
|
||||
|
||||
with open(FLAGS.output, 'w') as f:
|
||||
json.dump(results, f)
|
||||
@ -190,6 +190,13 @@ def parse_arguments():
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='Execute GPT session with CUDA graph.')
|
||||
parser.add_argument(
|
||||
'--enable_custom_all_reduce',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help=
|
||||
'Use latency-optimized all-reduce for tensor parallelism. Gives better performance with NVLink.'
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@ -209,25 +216,27 @@ def main(args):
|
||||
for io in in_out_len_options]
|
||||
|
||||
if args.model in get_allowed_models(benchmark_type="gpt"):
|
||||
benchmarker = GPTBenchmark(args.engine_dir,
|
||||
args.model,
|
||||
args.mode,
|
||||
batch_size_options,
|
||||
in_out_len_options,
|
||||
args.dtype,
|
||||
args.refit,
|
||||
args.num_beams,
|
||||
args.top_k,
|
||||
args.top_p,
|
||||
args.output_dir,
|
||||
args.n_positions,
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
args.max_batch_size,
|
||||
force_num_layer_1=args.force_num_layer_1,
|
||||
enable_fp8=args.enable_fp8,
|
||||
fp8_kv_cache=args.fp8_kv_cache,
|
||||
enable_cuda_graph=args.enable_cuda_graph)
|
||||
benchmarker = GPTBenchmark(
|
||||
args.engine_dir,
|
||||
args.model,
|
||||
args.mode,
|
||||
batch_size_options,
|
||||
in_out_len_options,
|
||||
args.dtype,
|
||||
args.refit,
|
||||
args.num_beams,
|
||||
args.top_k,
|
||||
args.top_p,
|
||||
args.output_dir,
|
||||
args.n_positions,
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
args.max_batch_size,
|
||||
force_num_layer_1=args.force_num_layer_1,
|
||||
enable_fp8=args.enable_fp8,
|
||||
fp8_kv_cache=args.fp8_kv_cache,
|
||||
enable_cuda_graph=args.enable_cuda_graph,
|
||||
enable_custom_all_reduce=args.enable_custom_all_reduce)
|
||||
elif args.model in get_allowed_models(benchmark_type="bert"):
|
||||
benchmarker = BERTBenchmark(args.engine_dir,
|
||||
args.model,
|
||||
|
||||
@ -49,6 +49,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
max_input_len=None,
|
||||
max_output_len=None,
|
||||
max_batch_size=None,
|
||||
enable_custom_all_reduce=None,
|
||||
**kwargs):
|
||||
super().__init__(engine_dir, model_name, dtype, output_dir)
|
||||
self.batch_sizes = batch_sizes
|
||||
@ -60,6 +61,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
self.fuse_bias = True
|
||||
|
||||
self.cuda_graph_mode = kwargs.get('enable_cuda_graph', False)
|
||||
self.enable_custom_all_reduce = enable_custom_all_reduce
|
||||
|
||||
if engine_dir is not None:
|
||||
# Getting build configs from the engine directory is done in the base class
|
||||
@ -85,12 +87,9 @@ class GPTBenchmark(BaseBenchmark):
|
||||
plg_dtype = dtype if is_plugin_mode else False
|
||||
self.use_gpt_attention_plugin = plg_dtype
|
||||
self.use_gemm_plugin = plg_dtype
|
||||
self.use_layernorm_plugin = plg_dtype
|
||||
# Enable RMS Norm plugin for the LLaMA family.
|
||||
if is_plugin_mode and 'llama' in model_name:
|
||||
self.use_rmsnorm_plugin = dtype
|
||||
else:
|
||||
self.use_rmsnorm_plugin = False
|
||||
# Starting with TRT 9.1, the OOTB norm layer sees improvements over the plugin norm layer
|
||||
self.use_layernorm_plugin = False
|
||||
self.use_rmsnorm_plugin = False
|
||||
self.use_lookup_plugin = plg_dtype
|
||||
self.enable_context_fmha = True
|
||||
self.quant_mode = QuantMode(0)
|
||||
@ -408,7 +407,8 @@ class GPTBenchmark(BaseBenchmark):
|
||||
network.plugin_config.set_rmsnorm_quantization_plugin()
|
||||
|
||||
if self.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(self.dtype)
|
||||
network.plugin_config.set_nccl_plugin(self.dtype,
|
||||
self.enable_custom_all_reduce)
|
||||
|
||||
# Use the plugin for embedding parallelism and sharing
|
||||
network.plugin_config.set_lookup_plugin(dtype=self.use_lookup_plugin)
|
||||
@ -434,8 +434,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
self.build_time = round(end - start, 2)
|
||||
|
||||
if self.output_dir is not None:
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir)
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
self.serialize_path = os.path.join(self.output_dir,
|
||||
self.engine_name)
|
||||
serialize_engine(engine, self.serialize_path)
|
||||
|
||||
@ -101,9 +101,6 @@ if(CMAKE_CUDA_COMPILER)
|
||||
else()
|
||||
message(FATAL_ERROR "Failed to determine CUDA version")
|
||||
endif()
|
||||
|
||||
# Export shared libs as both `.lib` and `.dll` to avoid linking errors.
|
||||
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "No CUDA compiler found")
|
||||
@ -166,7 +163,8 @@ get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_SOURCE_DIR} PATH)
|
||||
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
|
||||
include_directories(
|
||||
${CUDA_INCLUDE_DIRS} ${CUDNN_ROOT_DIR}/include ${NCCL_INCLUDE_DIR}
|
||||
${3RDPARTY_DIR}/cutlass/include ${3RDPARTY_DIR}/NVTX/include)
|
||||
${3RDPARTY_DIR}/cutlass/include ${3RDPARTY_DIR}/NVTX/include
|
||||
${3RDPARTY_DIR}/json/include)
|
||||
|
||||
# TRT dependencies
|
||||
set_ifndef(TRT_LIB_DIR ${CMAKE_BINARY_DIR})
|
||||
@ -195,6 +193,7 @@ endif()
|
||||
# it's not called before "CMAKE_CXX_FLAGS" is set, it breaks on Windows for some
|
||||
# reason, so we just call it here as a workaround.
|
||||
find_package(MPI REQUIRED)
|
||||
add_definitions("-DOMPI_SKIP_MPICXX")
|
||||
|
||||
# C++17
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@ -214,6 +213,12 @@ else()
|
||||
set(CMAKE_CXX_FLAGS "/wd4996 ${CMAKE_CXX_FLAGS}")
|
||||
endif()
|
||||
|
||||
# A Windows header file defines max() and min() macros, which break our macro
|
||||
# declarations.
|
||||
if(WIN32)
|
||||
set(CMAKE_CXX_FLAGS "/DNOMINMAX ${CMAKE_CXX_FLAGS}")
|
||||
endif()
|
||||
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
|
||||
|
||||
@ -226,10 +231,12 @@ if(BUILD_PYT)
|
||||
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
|
||||
if(CUDA_ARCH MATCHES "^([0-9])([0-9])(-real)*$")
|
||||
set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}")
|
||||
elseif(CUDA_ARCH STREQUAL "native")
|
||||
set(TORCH_ARCH "Auto")
|
||||
else()
|
||||
message(FATAL_ERROR "${CUDA_ARCH} is not supported")
|
||||
endif()
|
||||
if(NOT CUDA_ARCH MATCHES "-real$")
|
||||
if(NOT CUDA_ARCH MATCHES "-real$" AND NOT CUDA_ARCH STREQUAL "native")
|
||||
string(APPEND TORCH_ARCH "+PTX")
|
||||
endif()
|
||||
list(APPEND TORCH_CUDA_ARCH_LIST ${TORCH_ARCH})
|
||||
@ -303,8 +310,9 @@ message(
|
||||
"Building for TensorRT version: ${TRT_VERSION}, library version: ${TRT_SOVERSION}"
|
||||
)
|
||||
|
||||
list(APPEND COMMON_HEADER_DIRS ${TORCH_INCLUDE_DIRS} ${TRT_INCLUDE_DIR})
|
||||
list(APPEND COMMON_HEADER_DIRS)
|
||||
include_directories(${COMMON_HEADER_DIRS})
|
||||
include_directories(SYSTEM ${TORCH_INCLUDE_DIRS} ${TRT_INCLUDE_DIR})
|
||||
|
||||
add_subdirectory(tensorrt_llm)
|
||||
|
||||
|
||||
@ -32,7 +32,8 @@ enum class BatchManagerErrorCode_t
|
||||
{
|
||||
STATUS_SUCCESS = 0,
|
||||
STATUS_FAILED = 1,
|
||||
STATUS_NO_WORK = 2
|
||||
STATUS_NO_WORK = 2,
|
||||
STATUS_TERMINATE = 3
|
||||
};
|
||||
|
||||
enum class TrtGptModelType
|
||||
|
||||
@ -39,8 +39,8 @@ namespace tensorrt_llm::batch_manager
|
||||
class InferenceRequest;
|
||||
class TrtGptModel;
|
||||
|
||||
/* Responsible for shepherding Triton requests through to completion
|
||||
using TRT Backend. Each Triton backend should have just one of these. */
|
||||
/* Responsible for shepherding requests through to completion
|
||||
using TRT Backend. */
|
||||
class GptManager
|
||||
{
|
||||
public:
|
||||
@ -50,6 +50,7 @@ public:
|
||||
GptManager(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
|
||||
SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr,
|
||||
ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
|
||||
const TrtGptModelOptionalParams& optionalParams = TrtGptModelOptionalParams());
|
||||
|
||||
/* Wraps the user-provided callback for requests.
|
||||
@ -57,20 +58,21 @@ public:
|
||||
Invoked every generation loop iteration. */
|
||||
BatchManagerErrorCode_t fetchNewRequests();
|
||||
|
||||
/* Does the following:
|
||||
1. Returns completed requests to Triton
|
||||
2. Deletes entry from activeRequests */
|
||||
/* Returns completed requests.
|
||||
Deletes entry from activeRequests */
|
||||
BatchManagerErrorCode_t returnCompletedRequests();
|
||||
|
||||
BatchManagerErrorCode_t pollStopSignals();
|
||||
|
||||
BatchManagerErrorCode_t returnBatchManagerStats();
|
||||
|
||||
BatchManagerErrorCode_t waitUntilTerminate();
|
||||
|
||||
virtual ~GptManager();
|
||||
|
||||
protected:
|
||||
/* Does the following:
|
||||
1. Maps batch manager requests to backend request
|
||||
2. Invokes one step of backend
|
||||
3. Updates state of all requests */
|
||||
/* Invokes one step of backend
|
||||
Updates state of all requests */
|
||||
virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds);
|
||||
|
||||
private:
|
||||
@ -84,6 +86,8 @@ private:
|
||||
SizeType mMaxOutputLen;
|
||||
SizeType mMaxNumSequences;
|
||||
|
||||
// Iteration counter - incremented every iteration of the generation loop
|
||||
int64_t mIterationCounter;
|
||||
// List of live requests
|
||||
RequestList mActiveRequests;
|
||||
// IDs of live requests
|
||||
@ -92,6 +96,7 @@ private:
|
||||
GetInferenceRequestsCallback mGetInferenceRequestsCb;
|
||||
SendResponseCallback mSendResponseCb;
|
||||
PollStopSignalCallback mPollStopSignalCb;
|
||||
ReturnBatchManagerStatsCallback mReturnBatchManagerStatsCb;
|
||||
|
||||
std::atomic<bool> destructor_called_;
|
||||
void decoupled_execution_loop();
|
||||
|
||||
@ -31,5 +31,7 @@ class NamedTensor;
|
||||
using GetInferenceRequestsCallback = std::function<std::list<std::shared_ptr<InferenceRequest>>(int32_t)>;
|
||||
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
|
||||
using PollStopSignalCallback = std::function<std::unordered_set<uint64_t>()>;
|
||||
// json of stats as a string
|
||||
using ReturnBatchManagerStatsCallback = std::function<void(const std::string&)>;
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -27,6 +27,7 @@
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
#include <cstdint>
|
||||
@ -175,8 +176,8 @@ public:
|
||||
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
|
||||
|
||||
KVCacheManager(SizeType numLayers, SizeType numHeads, SizeType numKvHeads, SizeType hiddenSize,
|
||||
SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxBatchSize, nvinfer1::DataType dtype,
|
||||
CudaStreamPtr stream);
|
||||
SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxBatchSize, SizeType maxBeamWidth,
|
||||
SizeType maxBlocksPerSeq, nvinfer1::DataType dtype, CudaStreamPtr stream);
|
||||
|
||||
[[nodiscard]] SizeType getTokensPerBlock() const
|
||||
{
|
||||
@ -221,11 +222,10 @@ public:
|
||||
|
||||
void removeSequence(SizeType batchSlotIdx);
|
||||
|
||||
[[nodiscard]] std::vector<runtime::ITensor::UniquePtr> getBlockPointersOfSlot(
|
||||
SizeType batchSlotIdx, SizeType beamWidth, SizeType maxBlocksPerSeq) const;
|
||||
void getBlockPointersOfBatch(runtime::ITensor::SharedPtr dstPointers, SizeType batchSize, SizeType beamWidth) const;
|
||||
|
||||
[[nodiscard]] runtime::ITensor::UniquePtr getBlockPointersOfBatch(
|
||||
SizeType batchSize, SizeType beamWidth, SizeType maxBlocksPerSeq) const;
|
||||
void copyBlockPointers(runtime::ITensor::SharedPtr dstPointers, SizeType dstSlotOffset, SizeType batchSlotIdx,
|
||||
SizeType beamWidth) const;
|
||||
|
||||
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
|
||||
[[nodiscard]] static SizeType constexpr calculatePageSize(tensorrt_llm::runtime::GptModelConfig const& modelConfig)
|
||||
@ -235,11 +235,15 @@ public:
|
||||
|
||||
// numLayers * 2 * numKvHeads * sizePerHead
|
||||
[[nodiscard]] static SizeType constexpr calculateCacheSizePerToken(
|
||||
tensorrt_llm::runtime::GptModelConfig const& modelConfig)
|
||||
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig)
|
||||
{
|
||||
return modelConfig.getNbLayers() * 2 * modelConfig.getNbKvHeads() * modelConfig.getSizePerHead();
|
||||
return modelConfig.getNbLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads()
|
||||
* modelConfig.getSizePerHead();
|
||||
}
|
||||
|
||||
private:
|
||||
void cacheNewBlockPointer(const GenerationRequest& seq, SizeType batchSlotIdx);
|
||||
|
||||
private:
|
||||
// Number of elements per block
|
||||
SizeType mBlockSize;
|
||||
@ -247,14 +251,20 @@ private:
|
||||
SizeType mTokensPerBlock;
|
||||
// Total maximum number of blocks
|
||||
SizeType mMaxNumBlocks;
|
||||
// Maximum size of batch
|
||||
SizeType mMaxBatchSize;
|
||||
// Maximum beam width
|
||||
SizeType mMaxBeamWidth;
|
||||
// Maximum number of blocks per sequence
|
||||
SizeType mMaxBlocksPerSeq;
|
||||
// Pools
|
||||
std::vector<runtime::ITensor::SharedPtr> mPools;
|
||||
// Block manager
|
||||
BlockManager mBlockManager;
|
||||
// List of all sequences
|
||||
std::vector<SequencesPtr> mSequences;
|
||||
// buffer for block pointers for all batch slots
|
||||
std::vector<runtime::ITensor::UniquePtr> mAllBlockPointers;
|
||||
// buffer for block pointers for all managed sequences
|
||||
runtime::ITensor::SharedPtr mSequenceBlockPointers;
|
||||
|
||||
runtime::BufferManager mManager;
|
||||
};
|
||||
|
||||
@ -37,22 +37,24 @@ enum LlmRequestState_t
|
||||
class LlmRequest
|
||||
{
|
||||
public:
|
||||
using BeamTokens = std::vector<std::vector<int32_t>>;
|
||||
using SizeType = runtime::SizeType;
|
||||
using TokenIdType = runtime::TokenIdType;
|
||||
using RequestIdType = std::uint64_t;
|
||||
using BeamTokens = std::vector<std::vector<TokenIdType>>;
|
||||
|
||||
LlmRequest(uint64_t requestId, int32_t maxNewTokens, std::shared_ptr<std::vector<int32_t>> input_tokens,
|
||||
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> input_tokens,
|
||||
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
|
||||
std::optional<SizeType> padId = std::nullopt)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(input_tokens->size())
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
, mSamplingConfig(samplingConfig)
|
||||
, mPromptLen(input_tokens->size())
|
||||
, mNumGeneratedTokens(0)
|
||||
, mState(REQUEST_STATE_CONTEXT_INIT)
|
||||
, mIsStreaming(isStreaming)
|
||||
, mEndId(endId)
|
||||
, mPadId(padId)
|
||||
, mBatchSlot(-1)
|
||||
, mNumGeneratedTokens(0)
|
||||
{
|
||||
mMaxSentTokenPos = mPromptLen - 1;
|
||||
// Scatter the input tokens to other beam
|
||||
@ -62,7 +64,7 @@ public:
|
||||
/// @brief Get total number of tokens for this req (prompt + generated)
|
||||
/// @param beam The beam index
|
||||
/// @return The number of tokens
|
||||
int32_t getNumTokens(int beam) const
|
||||
SizeType getNumTokens(SizeType beam) const
|
||||
{
|
||||
return mTokens->at(beam).size();
|
||||
}
|
||||
@ -71,7 +73,7 @@ public:
|
||||
/// @param beam The beam index
|
||||
/// @param pos The position of the token relative to beginning of the prompt
|
||||
/// @return The token index
|
||||
int32_t getToken(int beam, int pos) const
|
||||
TokenIdType getToken(SizeType beam, SizeType pos) const
|
||||
{
|
||||
return mTokens->at(beam).at(pos);
|
||||
}
|
||||
@ -79,14 +81,14 @@ public:
|
||||
/// @brief Get the tokens at a given beam index
|
||||
/// @param beam The beam index
|
||||
/// @return A vector of tokens for this beam index, includes the prompt
|
||||
std::vector<int32_t> getTokens(int beam) const
|
||||
std::vector<TokenIdType> getTokens(SizeType beam) const
|
||||
{
|
||||
return mTokens->at(beam);
|
||||
}
|
||||
|
||||
/// @brief Get the number of generated tokens
|
||||
/// @return The number of generated tokens (doesn't include the prompt tokens)
|
||||
int32_t getNumGeneratedTokens() const
|
||||
SizeType getNumGeneratedTokens() const
|
||||
{
|
||||
return mNumGeneratedTokens;
|
||||
}
|
||||
@ -94,10 +96,10 @@ public:
|
||||
/// @brief Add new generated tokens to the vector of tokens
|
||||
/// @param beamTokens A vector containing the tokens to add for each beam index
|
||||
/// beamTokens is expected to be of size beamWidth
|
||||
void addNewTokens(const std::vector<int32_t>& beamTokens)
|
||||
void addNewTokens(const std::vector<TokenIdType>& beamTokens)
|
||||
{
|
||||
assert(mSamplingConfig.beamWidth == beamTokens.size());
|
||||
for (int beam = 0; beam < beamTokens.size(); ++beam)
|
||||
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
|
||||
{
|
||||
mTokens->at(beam).push_back(beamTokens[beam]);
|
||||
}
|
||||
@ -109,7 +111,7 @@ public:
|
||||
void setGeneratedTokens(const BeamTokens& generatedBeamTokens)
|
||||
{
|
||||
assert(generatedBeamTokens.size() == mSamplingConfig.beamWidth);
|
||||
for (int beam = 0; beam < generatedBeamTokens.size(); ++beam)
|
||||
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
|
||||
{
|
||||
auto& beamTokens = (*mTokens)[beam];
|
||||
beamTokens.resize(mPromptLen);
|
||||
@ -151,7 +153,7 @@ public:
|
||||
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to client
|
||||
/// duplicated token positions.
|
||||
/// @return The maximum position of the tokens sent to the client
|
||||
int32_t getMaxSentTokenPos() const
|
||||
SizeType getMaxSentTokenPos() const
|
||||
{
|
||||
return mMaxSentTokenPos;
|
||||
}
|
||||
@ -159,28 +161,26 @@ public:
|
||||
/// @brief Sets the maximum position of the tokens returned to the client. Use to ensure we don't return to client
|
||||
/// duplicated token positions.
|
||||
/// @param pos The maximum position
|
||||
void setMaxSentTokenPos(int32_t pos)
|
||||
void setMaxSentTokenPos(SizeType pos)
|
||||
{
|
||||
mMaxSentTokenPos = pos;
|
||||
}
|
||||
|
||||
uint64_t mRequestId;
|
||||
int32_t mMaxNewTokens;
|
||||
RequestIdType mRequestId;
|
||||
SizeType mPromptLen;
|
||||
SizeType mMaxNewTokens;
|
||||
// Tokens [beam_size, mPromptLen + mNumGeneratedTokens]
|
||||
runtime::SamplingConfig mSamplingConfig;
|
||||
int32_t mPromptLen;
|
||||
LlmRequestState_t mState;
|
||||
bool mIsStreaming;
|
||||
std::optional<SizeType> mEndId;
|
||||
std::optional<SizeType> mPadId;
|
||||
int32_t mBatchSlot;
|
||||
|
||||
~LlmRequest() {}
|
||||
SizeType mBatchSlot;
|
||||
|
||||
private:
|
||||
std::shared_ptr<BeamTokens> mTokens;
|
||||
int32_t mNumGeneratedTokens;
|
||||
int32_t mMaxSentTokenPos;
|
||||
SizeType mNumGeneratedTokens;
|
||||
SizeType mMaxSentTokenPos;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -90,10 +90,22 @@ public:
|
||||
void setZero(IBuffer& buffer) const;
|
||||
|
||||
//! \brief Copy `src` to `dst`.
|
||||
void copy(void const* src, IBuffer& dst) const;
|
||||
void copy(void const* src, IBuffer& dst, MemoryType srcType) const;
|
||||
|
||||
//! \brief Copy `src` to `dst`.
|
||||
void copy(IBuffer const& src, void* dst) const;
|
||||
void copy(IBuffer const& src, void* dst, MemoryType dstType) const;
|
||||
|
||||
//! \brief Copy `src` to `dst`.
|
||||
void copy(void const* src, IBuffer& dst) const
|
||||
{
|
||||
return copy(src, dst, IBuffer::memoryType(src));
|
||||
}
|
||||
|
||||
//! \brief Copy `src` to `dst`.
|
||||
void copy(IBuffer const& src, void* dst) const
|
||||
{
|
||||
return copy(src, dst, IBuffer::memoryType(dst));
|
||||
}
|
||||
|
||||
//! \brief Copy `src` to `dst`.
|
||||
void copy(IBuffer const& src, IBuffer& dst) const;
|
||||
|
||||
cpp/include/tensorrt_llm/runtime/cudaEvent.h (new file, 102 lines)
@@ -0,0 +1,102 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
class CudaEvent
|
||||
{
|
||||
public:
|
||||
using pointer = cudaEvent_t;
|
||||
|
||||
//! Creates a new cuda event. The event will be destroyed in the destructor.
|
||||
//!
|
||||
//! \param flags Flags for event creation. By default, event timing is disabled.
|
||||
explicit CudaEvent(unsigned int flags = cudaEventDisableTiming)
|
||||
{
|
||||
pointer event;
|
||||
TLLM_CUDA_CHECK(::cudaEventCreate(&event, flags));
|
||||
TLLM_LOG_TRACE("Created event %p", event);
|
||||
bool constexpr ownsEvent{true};
|
||||
mEvent = EventPtr{event, Deleter{ownsEvent}};
|
||||
}
|
||||
|
||||
//! Pass an existing cuda event to this object.
|
||||
//!
|
||||
//! \param event The event to pass to this object.
|
||||
//! \param ownsEvent Whether this object owns the event and destroys it in the destructor.
|
||||
explicit CudaEvent(pointer event, bool ownsEvent = true)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(event != nullptr, "event is nullptr");
|
||||
mEvent = EventPtr{event, Deleter{ownsEvent}};
|
||||
}
|
||||
|
||||
//! Returns the event associated with this object.
|
||||
[[nodiscard]] pointer get() const
|
||||
{
|
||||
return mEvent.get();
|
||||
}
|
||||
|
||||
//! \brief Synchronizes the event.
|
||||
void synchronize() const
|
||||
{
|
||||
TLLM_CUDA_CHECK(::cudaEventSynchronize(get()));
|
||||
}
|
||||
|
||||
private:
|
||||
class Deleter
|
||||
{
|
||||
public:
|
||||
explicit Deleter(bool ownsEvent)
|
||||
: mOwnsEvent{ownsEvent}
|
||||
{
|
||||
}
|
||||
|
||||
explicit Deleter()
|
||||
: Deleter{true}
|
||||
{
|
||||
}
|
||||
|
||||
constexpr void operator()(pointer event) const
|
||||
{
|
||||
if (mOwnsEvent && event != nullptr)
|
||||
{
|
||||
TLLM_CUDA_CHECK(::cudaEventDestroy(event));
|
||||
TLLM_LOG_TRACE("Destroyed event %p", event);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool mOwnsEvent;
|
||||
};
|
||||
|
||||
using element_type = std::remove_pointer_t<pointer>;
|
||||
using EventPtr = std::unique_ptr<element_type, Deleter>;
|
||||
|
||||
EventPtr mEvent;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
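To illustrate the ownership flag in the second constructor, the sketch below wraps an externally created event without transferring ownership; this is only an example, not code from the change:
```
#include "tensorrt_llm/runtime/cudaEvent.h"

#include <cuda_runtime_api.h>

using tensorrt_llm::runtime::CudaEvent;

int main()
{
    // Event created and owned by the caller.
    cudaEvent_t raw{};
    cudaEventCreateWithFlags(&raw, cudaEventDisableTiming);

    {
        CudaEvent owned;                              // creates and owns its own event
        CudaEvent borrowed{raw, /*ownsEvent=*/false}; // wraps `raw` without owning it
        owned.synchronize();
        borrowed.synchronize();
    } // `owned` destroys its event here; `borrowed` leaves `raw` untouched

    cudaEventDestroy(raw); // still the caller's responsibility
    return 0;
}
```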
@@ -19,6 +19,7 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/cudaEvent.h"

#include <cuda_runtime_api.h>

@@ -78,17 +79,29 @@ public:
    }

    //! \brief Record an event on the stream.
    void record(tensorrt_llm::common::EventPtr::pointer event)
    void record(CudaEvent::pointer event) const
    {
        TLLM_CUDA_CHECK(::cudaEventRecord(event, get()));
    }

    //! \brief Record an event on the stream.
    void record(CudaEvent const& event) const
    {
        record(event.get());
    }

    //! \brief Wait for an event.
    void wait(tensorrt_llm::common::EventPtr::pointer event)
    void wait(CudaEvent::pointer event) const
    {
        TLLM_CUDA_CHECK(::cudaStreamWaitEvent(get(), event));
    }

    //! \brief Wait for an event.
    void wait(CudaEvent const& event) const
    {
        wait(event.get());
    }

private:
    class Deleter
    {
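A minimal sketch of how the new `record`/`wait` overloads combine with `CudaEvent`; the two streams and the function name are assumptions for illustration:
```
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"

using namespace tensorrt_llm::runtime;

// Make `consumer` wait for all work currently queued on `producer`.
void chainStreams(CudaStream const& producer, CudaStream const& consumer)
{
    CudaEvent done;        // timing disabled by default
    producer.record(done); // cudaEventRecord on the producer stream
    consumer.wait(done);   // cudaStreamWaitEvent on the consumer stream
    done.synchronize();    // optionally block the host until the event completes
}
```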
@@ -43,7 +43,8 @@ public:
    TensorPtr ids; // [batchSize, beamWidth, maxInputLength + maxNewTokens]

    // optional parameters
    TensorPtr logProbs;      // [request_output_length, batch_size * beam_width], must be float*, on gpu
    TensorPtr contextLogits; // [batch_size, max_input_length, vocab_size_padded]

    // callbacks
    Callback onTokenGenerated;
@@ -18,6 +18,7 @@

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/runtime/iGptDecoderBatch.h"
@@ -51,19 +52,27 @@ public:

    void newBatch(GenerationInput const& inputs, SamplingConfig const& samplingConfig) override;

    //! @brief Run one step for all requests.
    //! Note that this method will synchronize with the stream associated with the decoder.
    void forward(decoder_batch::Output& output, decoder_batch::Input const& input) override;
    TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;

    bool forward(decoder::Output& output, decoder::Input const& input) override;
    void forwardSync(decoder_batch::Token const& e) override;

    //! @brief Gather final results for request `batchIdx`.
    void postProcessRequest(SizeType batchIdx) const override;
    void forwardAsync(decoder::Output& output, decoder::Input const& input) override;

    bool isFinishedSync() override;

    //! @return [batchSize], indicators of finished requests
    [[nodiscard]] std::vector<bool> getFinished() const override
    {
        return std::vector<bool>(mFinished.begin(), mFinished.begin() + mActualBatchSize);
        return {mFinished.begin(), mFinished.begin() + mActualBatchSize};
    }

    //! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
    //! padding for request `batchIdx`, on gpu
    [[nodiscard]] TensorPtr getOutputIds(SizeType batchIdx) const override
    {
        auto tensor = ITensor::slice(mJointDecodingOutput->ids, batchIdx, 1);
        tensor->squeeze(0);
        return tensor;
    }

    //! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
@@ -73,6 +82,14 @@ public:
        return ITensor::slice(mJointDecodingOutput->ids, 0, mActualBatchSize);
    }

    //! Execute postProcessRequest and returns OutputIds for request `batchIdx`.
    //! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
    //! padding for request `batchIdx`, on gpu
    [[nodiscard]] TensorPtr getFinalOutputIds(SizeType batchIdx) const override;

    //! Execute postProcessRequest and returns OutputIds.
    //! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
    //! ids without padding, on gpu
    [[nodiscard]] TensorPtr getFinalOutputIds() const override;

    //! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains parent ids collected during beam
@@ -112,15 +129,25 @@ public:
        return std::vector<SizeType>(mNbSteps.begin(), mNbSteps.begin() + mActualBatchSize);
    }

    //! @returns [1], number of finished sequences, in pinned host memory
    [[nodiscard]] TensorPtr getNbFinished() const override
    {
        return mFinishedSum;
    }

private:
    //! @brief Gather final results for request `batchIdx`.
    void postProcessRequest(SizeType batchIdx) const;

private:
    std::size_t const mVocabSize;
    std::size_t const mVocabSizePadded;
    CudaStreamPtr mStream;
    BufferManager mBufferManager;
    tensorrt_llm::common::EventPtr mEventStart, mEventStop;
    TokenPtr mForwardToken;
    CudaEvent mForwardEvent;

    std::vector<CudaStreamPtr> mStreams;
    std::vector<tensorrt_llm::common::EventPtr> mEvents;
    using GptDecoderPtr = std::unique_ptr<IGptDecoder>;
    std::vector<GptDecoderPtr> mDecoders;
    using DecodingInputPtr = std::unique_ptr<DecodingInput>;
@@ -133,6 +160,7 @@ private:

    std::vector<SizeType> mNbSteps;
    std::vector<bool> mFinished;
    TensorPtr mFinishedSum;
    std::vector<SizeType> mMaxNewTokens;
    std::vector<SizeType> mBeamWidths;
    SizeType mMaxSequenceLength{};
@@ -42,6 +42,7 @@ public:
        , mMaxBatchSize(0)
        , mMaxInputLen(0)
        , mMaxOutputLen(0)
        , mComputeContextLogits(false)
    {
    }

@@ -176,6 +177,16 @@ public:
        mMaxOutputLen = maxOutputLen;
    }

    [[nodiscard]] bool constexpr computeContextLogits() const noexcept
    {
        return mComputeContextLogits;
    }

    void constexpr computeContextLogits(bool computeContextLogits) noexcept
    {
        mComputeContextLogits = computeContextLogits;
    }

private:
    SizeType mVocabSize;
    SizeType mNbLayers;
@@ -191,6 +202,7 @@ private:
    SizeType mMaxBatchSize;
    SizeType mMaxInputLen;
    SizeType mMaxOutputLen;
    bool mComputeContextLogits;
};

} // namespace tensorrt_llm::runtime
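The new getter/setter pair flags engines that also return context logits. A small sketch, assuming the class in this header is `GptModelConfig` and that the config object is constructed from engine metadata elsewhere:
```
#include "tensorrt_llm/runtime/gptModelConfig.h"

using tensorrt_llm::runtime::GptModelConfig;

// Assumes `modelConfig` was built from the engine's metadata by the caller.
void enableContextLogits(GptModelConfig& modelConfig)
{
    modelConfig.computeContextLogits(true); // setter overload
    if (modelConfig.computeContextLogits()) // getter overload
    {
        // allocate the [batch_size, max_input_length, vocab_size_padded] output here
    }
}
```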
@@ -18,6 +18,7 @@

#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
@@ -99,19 +100,52 @@ public:
        mCudaGraphMode = value;
    }

    void setup(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, bool decoderPerRequest,
        std::optional<SizeType> maxTokensInPagedKvCache = std::nullopt);
    //! @brief Initialize buffers for the given sizes.
    //! `generate` may be called with batch size and beam width smaller than the setup parameters.
    //! @details `maxBatchSize` will be divided by the number of micro batches to initialize each batch buffer.
    void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, bool decoderPerRequest,
        std::optional<SizeType> maxTokensInPagedKvCache = std::nullopt,
        std::optional<SizeType> numMicroBatches = std::nullopt);

    void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig);
    void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
    {
        if (mNumMicroBatches == 1)
            generateSingleBatch(outputs, inputs, samplingConfig);
        else
            generateMultiBatch(outputs, inputs, samplingConfig);
    }

private:
    void generateSingleBatch(
        GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig);

    void generateMultiBatch(
        GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig);

    using KvCacheManager = batch_manager::kv_cache_manager::KVCacheManager;

    void createContexts();
    void createDecoder(bool decoderPerRequest);
    void createContexts(SizeType numMicroBatches);
    void createBuffers(SizeType numMicroBatches);
    void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
        nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
    void createKvCacheManagers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
        SizeType numMicroBatches, std::optional<SizeType> maxTokensInPagedKvCache);

    bool executeDecoderStep(ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep);
    void finalizeOutputIds(ITensor& outputIds);
    //! @brief Execute decoder on last PP rank, receive decoder output on other PP ranks.
    void decoderStepAsync(
        ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, SizeType microBatchId);

    //! @brief Synchronize with the decoder and return the `shouldStop` flag.
    bool shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType microBatchId);

    //! @brief Collect final output ids on last PP rank and send them to first PP rank.
    //! @details Receives are asynchronous on host, so synchronization is required before access.
    void finalizeOutputIds(ITensor& outputIds, SizeType microBatchId);

    void kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId);

    ITensor::SharedPtr initNewTokens(
        GenerationInput const& inputs, SamplingConfig const& samplingConfig, SizeType microBatchId);

    class CudaGraphExecutor
    {
@@ -153,15 +187,20 @@ private:
    WorldConfig const mWorldConfig;
    int mDevice{-1};
    std::shared_ptr<NcclCommunicator> mPipelineComm;
    std::shared_ptr<CudaStream> mCommStream;
    CudaEvent mCommEvent{};

    SizeType mDecoderMaxSequenceLength{};

    LoggerPtr mLogger;
    std::shared_ptr<TllmRuntime> mRuntime;
    std::shared_ptr<IStatefulGptDecoder> mDecoder;

    std::shared_ptr<RuntimeBuffers> mBuffers;
    std::shared_ptr<KvCacheManager> mKvCacheManager;
    SizeType mNumMicroBatches;
    // for each micro batch
    std::vector<std::shared_ptr<IStatefulGptDecoder>> mDecoders;
    std::vector<std::shared_ptr<RuntimeBuffers>> mBuffers;
    std::vector<std::shared_ptr<KvCacheManager>> mKvCacheManagers;
    std::vector<CudaEvent> mReceivedEvents;

    bool mCudaGraphMode{false};
    // ping-pong instances
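To make the new `setup` overload concrete, here is a hedged usage sketch of the session API shown above; engine loading, the session constructor, and the population of `GenerationInput`/`GenerationOutput` are assumed and not part of this change:
```
#include "tensorrt_llm/runtime/gptSession.h"

#include <optional>

using namespace tensorrt_llm::runtime;

// Assumes `session`, `inputs`, `outputs` and `samplingConfig` are prepared elsewhere.
void runWithMicroBatches(GptSession& session, GenerationOutput& outputs, GenerationInput const& inputs,
    SamplingConfig const& samplingConfig)
{
    SizeType constexpr maxBatchSize{16};
    SizeType constexpr maxBeamWidth{1};
    SizeType constexpr maxSequenceLength{1024};

    // With numMicroBatches = 2, each micro batch gets buffers sized for maxBatchSize / 2 requests.
    session.setup(maxBatchSize, maxBeamWidth, maxSequenceLength, /*decoderPerRequest=*/false,
        /*maxTokensInPagedKvCache=*/std::nullopt, /*numMicroBatches=*/2);

    // Dispatches to generateSingleBatch or generateMultiBatch depending on the micro batch count.
    session.generate(outputs, inputs, samplingConfig);
}
```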
@@ -17,6 +17,7 @@
#pragma once

#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
@@ -41,7 +42,6 @@ public:
        : ids{std::move(ids)}
        , maxNewTokens{maxNewTokens}
        , endId{endId}
        , padId{padId}
    {
    }

@@ -51,7 +51,6 @@ public:
    // optional parameters
    std::optional<SizeType> maxNewTokens; // maximum number of tokens to generate for this request
    std::optional<SizeType> endId;        // end token id
    std::optional<SizeType> padId;        // pad token id
    TensorPtr embeddingBias;              // [vocabSizePadded], on gpu
    TensorPtr badWordsList;               // [2, badWordsLength], on gpu
    TensorPtr stopWordsList;              // [2, stopWordsLength], on gpu
@@ -82,6 +81,19 @@ public:
};

using Output = decoder::Output;

class Token
{
public:
    explicit Token(CudaEvent&& event, std::vector<bool> const& active)
        : event(std::move(event))
        , active(active)
    {
    }

    CudaEvent event;
    std::vector<bool> active;
};
} // namespace decoder_batch

//! GPT decoder class with support for in-flight batching
@@ -90,17 +102,33 @@ class IGptDecoderBatch : public virtual IStatefulGptDecoder
public:
    using CudaStreamPtr = std::shared_ptr<CudaStream>;
    using TensorPtr = std::shared_ptr<ITensor>;
    using TokenPtr = std::unique_ptr<decoder_batch::Token const>;

    //! @brief Initialize the decoder at `batchIdx` with a new `request`.
    virtual void newRequest(
        SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig)
        = 0;

    //! @brief Run one step for all requests.
    virtual void forward(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
    //! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
    virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;

    //! @brief Gather final results for request `batchIdx`.
    virtual void postProcessRequest(SizeType batchIdx) const = 0;
    //! @brief Wait for the call to `forwardAsync` associated with a token to complete.
    virtual void forwardSync(decoder_batch::Token const& token) = 0;

    //! @brief Run one step for all requests and wait for completion on the host.
    virtual void forward(decoder_batch::Output& output, decoder_batch::Input const& input)
    {
        forwardSync(*forwardAsync(output, input));
    }

    //! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
    //! ids without padding for request `batchIdx`, on gpu
    virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;

    //! Execute postProcessRequest and returns OutputIds for request `batchIdx`.
    //! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
    //! padding for request `batchIdx`, on gpu
    virtual TensorPtr getFinalOutputIds(SizeType batchIdx) const = 0;

    //! @returns [batchSize, beamWidth], marks finished requests (per beam), on gpu
    virtual TensorPtr getFinishedBeams() const = 0;
@@ -108,6 +136,9 @@ public:
    //! @returns [batchSize, beamWidth], total sequence lengths (per beam), on gpu
    virtual TensorPtr getOutputLengths() const = 0;

    //! @returns [batchSize (actual)], marks finished requests (per batch)
    virtual std::vector<bool> getFinished() const = 0;

    //! @returns [batchSize, beamWidth], cumulative log probabilities (per beam), on gpu
    virtual TensorPtr getCumLogProbs() const = 0;
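The token returned by `forwardAsync` is what later synchronizes the step. A minimal sketch of the intended pattern; the concrete decoder, input, and output objects are assumed to exist:
```
#include "tensorrt_llm/runtime/iGptDecoderBatch.h"

using namespace tensorrt_llm::runtime;

// Assumes `decoder`, `input` and `output` are prepared elsewhere.
void stepOverlapped(IGptDecoderBatch& decoder, decoder_batch::Output& output, decoder_batch::Input const& input)
{
    // Launch the decoding step asynchronously; the host thread is free immediately.
    auto token = decoder.forwardAsync(output, input);

    // ... do host-side work here, e.g. schedule the next batch ...

    // Block until the step recorded in `token` has finished on the GPU.
    decoder.forwardSync(*token);
}
```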
@@ -80,21 +80,31 @@ public:
    //! @brief Initialize the decoder with new batch of inputs.
    virtual void newBatch(GenerationInput const& inputs, SamplingConfig const& samplingConfig) = 0;

    //! @brief Run one step for all requests without blocking the host thread.
    virtual void forwardAsync(decoder::Output& output, decoder::Input const& input) = 0;

    //! @brief Wait for the last call to `forwardAsync` to complete and return whether all sequences have finished.
    virtual bool isFinishedSync() = 0;

    //! @brief Run one step for all requests.
    virtual bool forward(decoder::Output& output, decoder::Input const& input) = 0;
    virtual bool forward(decoder::Output& output, decoder::Input const& input)
    {
        forwardAsync(output, input);
        return isFinishedSync();
    }

    //! @brief Gather final results for all requests.
    virtual TensorPtr getFinalOutputIds() const = 0;

    // TODO: do we need that?
    virtual std::vector<bool> getFinished() const = 0;

    //! @returns [batchSize, beamWidth, maxSequenceLength], all token ids, on gpu
    virtual TensorPtr getOutputIds() const = 0;

    //! @returns [batchSize, beamWidth], latest generated tokens (per beam), on gpu
    virtual TensorPtr getNewTokens() const = 0;

    //! @returns [1], number of finished sequences, in pinned host memory
    virtual TensorPtr getNbFinished() const = 0;

protected:
    IStatefulGptDecoder() = default;
};
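Since the default `forward` is now just `forwardAsync` followed by `isFinishedSync`, a caller-side generation loop can be sketched as follows (decoder, output, and input again assumed to be set up elsewhere):
```
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"

using namespace tensorrt_llm::runtime;

// Assumes `decoder.newBatch(...)` was already called with the prompt batch.
void decodeUntilDone(IStatefulGptDecoder& decoder, decoder::Output& output, decoder::Input const& input)
{
    // forward() returns true once every sequence in the batch has finished.
    while (!decoder.forward(output, input))
    {
        // e.g. stream back decoder.getNewTokens() after each step
    }
    auto finalIds = decoder.getFinalOutputIds(); // gathered final results, on gpu
}
```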
@@ -30,7 +30,7 @@ class MemoryCounters
{
public:
    using SizeType = std::size_t;
    using DiffType = std::int64_t;
    using DiffType = std::ptrdiff_t;

    MemoryCounters() = default;

@@ -95,6 +95,8 @@ public:

    [[nodiscard]] std::vector<SizeType> getPipelineParallelGroup() const;

    static bool validConfig(nvinfer1::ILogger& logger, SizeType tensorParallelism, SizeType pipelineParallelism);

    static WorldConfig mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode = kDefaultGpusPerNode,
        std::optional<SizeType> tensorParallelism = std::nullopt,
        std::optional<SizeType> pipelineParallelism = std::nullopt);
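A short sketch of how the new `validConfig` helper and the `mpi` factory shown above could be combined; the logger instance, GPU count, and parallelism values are assumptions for illustration only:
```
#include "tensorrt_llm/runtime/worldConfig.h"

using tensorrt_llm::runtime::WorldConfig;

// Assumes `logger` implements nvinfer1::ILogger and MPI was initialized by the caller.
WorldConfig makeWorldConfig(nvinfer1::ILogger& logger)
{
    auto constexpr tp = 2; // tensor parallelism, illustrative value
    auto constexpr pp = 2; // pipeline parallelism, illustrative value

    if (!WorldConfig::validConfig(logger, tp, pp))
    {
        // Fall back to whatever the MPI world size implies.
        return WorldConfig::mpi(logger);
    }
    return WorldConfig::mpi(logger, /*gpusPerNode=*/8, tp, pp);
}
```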
@@ -35,10 +35,11 @@ add_subdirectory(kernels)
add_subdirectory(layers)
add_subdirectory(runtime)

set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static)
if(BUILD_BATCH_MANAGER)
  add_subdirectory(batch_manager)
else()
  add_library(tensorrt_llm_batch_manager_static STATIC IMPORTED)
  add_library(${BATCH_MANAGER_TARGET} STATIC IMPORTED)
  execute_process(
    COMMAND ${Python3_EXECUTABLE} "-c"
            "import torch; print(torch.compiled_with_cxx11_abi(),end='');"
@@ -48,14 +49,14 @@ else()
  message(STATUS "USE_CXX11_ABI: ${USE_CXX11_ABI}")
  if(USE_CXX11_ABI)
    set_property(
      TARGET tensorrt_llm_batch_manager_static
      TARGET ${BATCH_MANAGER_TARGET}
      PROPERTY
        IMPORTED_LOCATION
        "${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/libtensorrt_llm_batch_manager_static.a"
    )
  else()
    set_property(
      TARGET tensorrt_llm_batch_manager_static
      TARGET ${BATCH_MANAGER_TARGET}
      PROPERTY
        IMPORTED_LOCATION
        "${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/libtensorrt_llm_batch_manager_static.pre_cxx11.a"
@@ -69,16 +70,19 @@ set(TRTLLM_LINK_LIBS
    ${CUDNN_LIB}
    ${CMAKE_DL_LIBS}
    ${MPI_CXX_LIBRARIES}
    ${NCCL_LIB}
    ${TRT_LIB}
    common_src
    kernels_src
    layers_src
    runtime_src
    tensorrt_llm_batch_manager_static)
    ${BATCH_MANAGER_TARGET})

# ################################# SHARED LIBRARY
# ##############################################################################

set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)

add_library(${SHARED_TARGET} SHARED)

set_target_properties(
@@ -105,6 +109,9 @@ set_target_properties(

target_link_libraries(${STATIC_TARGET} PUBLIC ${TRTLLM_LINK_LIBS})

# Cyclic dependency of batch manager on TRT-LLM
target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE ${STATIC_TARGET})

if(BUILD_PYT)
  add_subdirectory(thop)
endif()
Binary file not shown.
Binary file not shown.
@ -1,220 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cublasAlgoMap.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
cublasAlgoMap::cublasAlgoMap(const std::string filename, const std::string sp_config_filename)
|
||||
: config_filename_(filename)
|
||||
, sp_config_filename_(sp_config_filename)
|
||||
{
|
||||
loadGemmConfig();
|
||||
loadSpGemmConfig();
|
||||
}
|
||||
|
||||
cublasAlgoMap::cublasAlgoMap(const cublasAlgoMap& algo_map)
|
||||
: config_filename_(algo_map.config_filename_)
|
||||
, sp_config_filename_(algo_map.sp_config_filename_)
|
||||
, algo_map_(algo_map.algo_map_)
|
||||
, sp_algo_map_(algo_map.sp_algo_map_)
|
||||
{
|
||||
}
|
||||
|
||||
cublasAlgoMap::~cublasAlgoMap()
|
||||
{
|
||||
algo_map_.clear();
|
||||
}
|
||||
|
||||
void cublasAlgoMap::loadGemmConfig()
|
||||
{
|
||||
FILE* fd;
|
||||
fd = fopen(config_filename_.c_str(), "r");
|
||||
if (fd == NULL)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val;
|
||||
int batch_size, seq_len, head_num, size_per_head, dataType;
|
||||
int swizzle, reductionScheme, workspaceSize, stages;
|
||||
int inner_shapeId, cluster_shapeId, mma_shapeId, cga_shapeId, sche_mode;
|
||||
float exec_time;
|
||||
char tmp[1024];
|
||||
if (!fgets(tmp, 1024, fd))
|
||||
{
|
||||
printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__);
|
||||
exit(-1);
|
||||
}
|
||||
while (fscanf(fd,
|
||||
"%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d "
|
||||
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
|
||||
"%d %d "
|
||||
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
|
||||
"%d %d %d "
|
||||
#endif
|
||||
"%f\n",
|
||||
&batch_size, &seq_len, &head_num, &size_per_head, &dataType, &batchCount2, &n2, &m2, &k2, &algoId,
|
||||
&customOption, &tile, &splitK_val, &swizzle, &reductionScheme, &workspaceSize, &stages,
|
||||
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
|
||||
&inner_shapeId, &cluster_shapeId,
|
||||
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
|
||||
&mma_shapeId, &cga_shapeId, &sche_mode,
|
||||
#endif
|
||||
&exec_time)
|
||||
!= EOF)
|
||||
{
|
||||
if (dataType != FLOAT_DATATYPE && dataType != HALF_DATATYPE && dataType != BFLOAT16_DATATYPE
|
||||
&& dataType != INT8_DATATYPE && dataType != FP8_DATATYPE)
|
||||
{
|
||||
printf("[WARNING][readAlgoFromConfig] wrong dataType %d!\n", dataType);
|
||||
continue;
|
||||
}
|
||||
char mark[256];
|
||||
sprintf(mark, "%d_%d_%d_%d_%d", batchCount2, m2, n2, k2, dataType);
|
||||
std::string markStr(mark);
|
||||
// workspaceSize should be zero
|
||||
if (algo_map_.find(markStr) == algo_map_.end())
|
||||
{
|
||||
algo_map_[markStr].algoId = algoId;
|
||||
algo_map_[markStr].customOption = customOption;
|
||||
algo_map_[markStr].tile = tile;
|
||||
algo_map_[markStr].splitK_val = splitK_val;
|
||||
algo_map_[markStr].swizzle = swizzle;
|
||||
algo_map_[markStr].reductionScheme = reductionScheme;
|
||||
algo_map_[markStr].workspaceSize = workspaceSize;
|
||||
algo_map_[markStr].stages = stages;
|
||||
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
|
||||
algo_map_[markStr].inner_shapeId = (uint16_t) inner_shapeId;
|
||||
algo_map_[markStr].cluster_shapeId = (uint16_t) cluster_shapeId;
|
||||
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
|
||||
algo_map_[markStr].mma_shapeId = (uint16_t) mma_shapeId;
|
||||
algo_map_[markStr].cga_shapeId = (uint16_t) cga_shapeId;
|
||||
algo_map_[markStr].sche_mode = (uint16_t) sche_mode;
|
||||
#endif
|
||||
algo_map_[markStr].exec_time = exec_time;
|
||||
}
|
||||
}
|
||||
fclose(fd);
|
||||
}
|
||||
|
||||
bool cublasAlgoMap::isExist(
|
||||
const int batch_count, const int m, const int n, const int k, const CublasDataType data_type)
|
||||
{
|
||||
char mark[256];
|
||||
sprintf(mark, "%d_%d_%d_%d_%d", batch_count, n, m, k, data_type);
|
||||
return algo_map_.find(mark) != algo_map_.end();
|
||||
}
|
||||
|
||||
cublasLtMatmulAlgo_info cublasAlgoMap::getAlgo(
|
||||
const int batch_count, const int m, const int n, const int k, const CublasDataType data_type)
|
||||
{
|
||||
char mark[256];
|
||||
sprintf(mark, "%d_%d_%d_%d_%d", batch_count, n, m, k, data_type);
|
||||
if (algo_map_.find(mark) != algo_map_.end())
|
||||
{
|
||||
return algo_map_[mark];
|
||||
}
|
||||
else
|
||||
{
|
||||
cublasLtMatmulAlgo_info tmp_algo;
|
||||
tmp_algo.algoId
|
||||
= static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
tmp_algo.customOption = -1;
|
||||
tmp_algo.tile = -1;
|
||||
tmp_algo.splitK_val = -1;
|
||||
tmp_algo.swizzle = -1;
|
||||
tmp_algo.reductionScheme = -1;
|
||||
tmp_algo.workspaceSize = -1;
|
||||
tmp_algo.stages = -1;
|
||||
tmp_algo.exec_time = -1.0f;
|
||||
return tmp_algo;
|
||||
}
|
||||
}
|
||||
|
||||
void cublasAlgoMap::loadSpGemmConfig()
|
||||
{
|
||||
if (sp_config_filename_.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
FILE* fd = fopen(sp_config_filename_.c_str(), "r");
|
||||
if (fd == NULL)
|
||||
{
|
||||
return;
|
||||
}
|
||||
sp_algo_map_.clear();
|
||||
int batch_size, seq_len, head_num, size_per_head, data_type;
|
||||
int batchCount, m, n, k, algoId;
|
||||
float exec_time;
|
||||
char tmp[1024];
|
||||
if (!fgets(tmp, 1024, fd))
|
||||
{
|
||||
printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__);
|
||||
exit(-1);
|
||||
}
|
||||
while (fscanf(fd, "%d %d %d %d %d ### %d %d %d %d %d %f\n", &batch_size, &seq_len, &head_num, &size_per_head,
|
||||
&data_type, &batchCount, &m, &n, &k, &algoId, &exec_time)
|
||||
!= EOF)
|
||||
{
|
||||
char mark[256];
|
||||
sprintf(mark, "%d_%d_%d_%d", batchCount, m, n, k);
|
||||
std::string markStr(mark);
|
||||
sp_algo_map_[markStr] = algoId;
|
||||
}
|
||||
fclose(fd);
|
||||
}
|
||||
|
||||
int cublasAlgoMap::getSpAlgo(const int batch_count, const int m, const int n, const int k)
|
||||
{
|
||||
char mark[256];
|
||||
sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k);
|
||||
if (sp_algo_map_.find(mark) != sp_algo_map_.end())
|
||||
{
|
||||
return sp_algo_map_[mark];
|
||||
}
|
||||
else
|
||||
{
|
||||
// for remove padding, select algo 1 for simplicity
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool cublasAlgoMap::isUseSparse(const int batch_count, const int m, const int n, const int k)
|
||||
{
|
||||
// not available to use cusparselt.
|
||||
if (m % 8 != 0 || n % 8 != 0 || k % 8 != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
char mark[256];
|
||||
sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k);
|
||||
if (sp_algo_map_.find(mark) != sp_algo_map_.end())
|
||||
{
|
||||
return sp_algo_map_[mark] != -1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// no gemm test case, choose sparse according to sparse flag
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
@ -1,90 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cublasVersionCheck.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include <cublasLt.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
#define GEMM_NUM 6
|
||||
#define GEMM_CONFIG "gemm_config.in"
|
||||
#define IGEMM_CONFIG "igemm_config.in"
|
||||
#define SPGEMM_CONFIG "spgemm_config.in"
|
||||
#define SPIGEMM_CONFIG "spigemm_config.in"
|
||||
|
||||
typedef struct
|
||||
{
|
||||
int algoId, customOption, tile, splitK_val;
|
||||
int swizzle, reductionScheme, workspaceSize;
|
||||
// only used in cublasLt >= 11.0
|
||||
int stages;
|
||||
#if TLLM_CUBLAS_VER_GE(11, 11, 3)
|
||||
uint16_t inner_shapeId, cluster_shapeId;
|
||||
#else
|
||||
uint16_t mma_shapeId, cga_shapeId, sche_mode;
|
||||
#endif
|
||||
float exec_time;
|
||||
} cublasLtMatmulAlgo_info;
|
||||
|
||||
/* Structure to store information about different run trials */
|
||||
typedef struct
|
||||
{
|
||||
cublasLtMatmulAlgo_t algo;
|
||||
cublasStatus_t status;
|
||||
float time;
|
||||
size_t workspaceSize; // actual memory workspace needed
|
||||
cublasMath_t mathMode;
|
||||
cublasLtReductionScheme_t reductionScheme;
|
||||
int customOption;
|
||||
float wavesCount;
|
||||
} customMatmulPerf_t;
|
||||
|
||||
class cublasAlgoMap
|
||||
{
|
||||
private:
|
||||
std::map<std::string, cublasLtMatmulAlgo_info> algo_map_;
|
||||
std::string config_filename_;
|
||||
std::string sp_config_filename_;
|
||||
std::map<std::string, int> sp_algo_map_;
|
||||
|
||||
public:
|
||||
cublasAlgoMap(){};
|
||||
explicit cublasAlgoMap(const std::string filename, const std::string sp_config_filename = "");
|
||||
cublasAlgoMap(const cublasAlgoMap& map);
|
||||
~cublasAlgoMap();
|
||||
void loadGemmConfig();
|
||||
void loadSpGemmConfig();
|
||||
int getSpAlgo(const int batch_count, const int m, const int n, const int k);
|
||||
bool isUseSparse(const int batch_count, const int m, const int n, const int k);
|
||||
|
||||
bool isExist(const int batch_count, const int m, const int n, const int k, const CublasDataType data_type);
|
||||
|
||||
cublasLtMatmulAlgo_info getAlgo(
|
||||
const int batch_count, const int m, const int n, const int k, const CublasDataType data_type);
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
@ -28,299 +28,218 @@ namespace tensorrt_llm
|
||||
namespace common
|
||||
{
|
||||
|
||||
cublasMMWrapper::cublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
|
||||
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, cublasAlgoMap* cublas_algo_map,
|
||||
std::mutex* mu, void* workspace)
|
||||
: cublas_handle_(cublasHandle)
|
||||
, cublaslt_handle_(cublasltHandle)
|
||||
, stream_(stream)
|
||||
, cublas_algo_map_(cublas_algo_map)
|
||||
, mu_(mu)
|
||||
, cublas_workspace_(workspace)
|
||||
CublasMMWrapper::CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
|
||||
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
|
||||
: mCublasHandle(cublasHandle)
|
||||
, mCublasLtHandle(cublasltHandle)
|
||||
, mStream(stream)
|
||||
, mCublasWorkspace(workspace)
|
||||
{
|
||||
}
|
||||
|
||||
cublasMMWrapper::~cublasMMWrapper()
|
||||
CublasMMWrapper::~CublasMMWrapper()
|
||||
{
|
||||
mu_ = nullptr;
|
||||
mMutex = nullptr;
|
||||
}
|
||||
|
||||
cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper& wrapper)
|
||||
: cublas_handle_(wrapper.cublas_handle_)
|
||||
, cublaslt_handle_(wrapper.cublaslt_handle_)
|
||||
, stream_(wrapper.stream_)
|
||||
, cublas_algo_map_(wrapper.cublas_algo_map_)
|
||||
, mu_(wrapper.mu_)
|
||||
CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper)
|
||||
: mCublasHandle(wrapper.mCublasHandle)
|
||||
, mCublasLtHandle(wrapper.mCublasLtHandle)
|
||||
, mStream(wrapper.mStream)
|
||||
, mMutex(wrapper.mMutex)
|
||||
{
|
||||
}
|
||||
|
||||
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* alpha, const void* A, cudaDataType_t Atype, int lda, const void* B, cudaDataType_t Btype, int ldb,
|
||||
const void* beta, void* C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType, cublasGemmAlgo_t algo)
|
||||
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const int lda, const int ldb, const int ldc)
|
||||
{
|
||||
mu_->lock();
|
||||
check_cuda_error(cublasGemmEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta,
|
||||
C, Ctype, ldc, computeType, algo));
|
||||
sync_check_cuda_error();
|
||||
mu_->unlock();
|
||||
// --------------------------------------
|
||||
// Create descriptors for the original matrices
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
|
||||
check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc));
|
||||
check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType));
|
||||
check_cuda_error(cublasLtMatmulDescSetAttribute(
|
||||
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(cublasLtMatmulDescSetAttribute(
|
||||
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
|
||||
}
|
||||
|
||||
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
void CublasMMWrapper::destroyDescriptors()
|
||||
{
|
||||
check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc));
|
||||
mOperationDesc = NULL;
|
||||
mADesc = NULL;
|
||||
mBDesc = NULL;
|
||||
mCDesc = NULL;
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc)
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc,
|
||||
const std::optional<cublasLtMatmulHeuristicResult_t>& heuristic)
|
||||
{
|
||||
if (heuristic)
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, (*heuristic).algo,
|
||||
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE);
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo,
|
||||
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
else
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, false);
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
}
|
||||
|
||||
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta)
|
||||
{
|
||||
bool usingCublasLt = Atype_ == CUDA_R_16F;
|
||||
bool isFp16ComputeType = computeType_ == CUDA_R_16F;
|
||||
bool usingCublasLt = mAType == CUDA_R_16F;
|
||||
|
||||
int batch_count = 1;
|
||||
int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(Atype_));
|
||||
|
||||
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
|
||||
|
||||
cublasLtMatmulAlgo_t algo;
|
||||
void* workSpace = cublas_workspace_;
|
||||
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
if (findAlgo)
|
||||
{
|
||||
if (info.stages != -1)
|
||||
{
|
||||
usingCublasLt = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
usingCublasLt = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (usingCublasLt)
|
||||
{
|
||||
cublasLtMatmulDesc_t operationDesc = NULL;
|
||||
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
|
||||
cudaDataType_t scaleType;
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
cublasComputeType_t computeType;
|
||||
#else
|
||||
cudaDataType_t computeType;
|
||||
#endif
|
||||
|
||||
if (isFp16ComputeType)
|
||||
{
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
#else
|
||||
computeType = CUDA_R_16F;
|
||||
#endif
|
||||
scaleType = CUDA_R_16F;
|
||||
}
|
||||
else
|
||||
{
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
computeType = CUBLAS_COMPUTE_32F;
|
||||
#else
|
||||
computeType = CUDA_R_32F;
|
||||
#endif
|
||||
scaleType = CUDA_R_32F;
|
||||
}
|
||||
|
||||
if (findAlgo)
|
||||
{
|
||||
if (info.workspaceSize > workspaceSize)
|
||||
{
|
||||
findAlgo = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
cublasLtMatmulAlgoInit(
|
||||
*cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo);
|
||||
cublasLtMatmulAlgoConfigSetAttribute(
|
||||
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
|
||||
cublasLtMatmulAlgoConfigSetAttribute(
|
||||
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
|
||||
cublasLtMatmulAlgoConfigSetAttribute(
|
||||
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
|
||||
cublasLtMatmulAlgoConfigSetAttribute(
|
||||
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
|
||||
cublasLtMatmulAlgoConfigSetAttribute(
|
||||
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(int));
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
cublasLtMatmulAlgoConfigSetAttribute(
|
||||
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, algo, findAlgo);
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
|
||||
/* usingCublasLt */ usingCublasLt);
|
||||
}
|
||||
|
||||
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta,
|
||||
const cublasLtMatmulAlgo_t& algo, bool hasAlgo)
|
||||
const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt)
|
||||
{
|
||||
half h_alpha = (half) (f_alpha);
|
||||
half h_beta = (half) (f_beta);
|
||||
|
||||
std::lock_guard<std::mutex> lock(*mu_);
|
||||
std::lock_guard<std::mutex> lock(*mMutex);
|
||||
|
||||
// TODO: default cublas libs
|
||||
bool usingCublasLt = Atype_ == CUDA_R_16F;
|
||||
bool isFp16ComputeType = computeType_ == CUDA_R_16F;
|
||||
usingCublasLt = usingCublasLt && mAType == CUDA_R_16F;
|
||||
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
|
||||
int batch_count = 1;
|
||||
// fp32 use cublas as default
|
||||
// fp16 use cublasLt as default
|
||||
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
|
||||
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
|
||||
if (hasAlgo)
|
||||
{
|
||||
int32_t stages;
|
||||
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL);
|
||||
if (stages != -1)
|
||||
{
|
||||
usingCublasLt = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
usingCublasLt = false;
|
||||
}
|
||||
}
|
||||
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
if (usingCublasLt)
|
||||
{
|
||||
cublasLtMatmulDesc_t operationDesc = NULL;
|
||||
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
|
||||
cudaDataType_t scaleType;
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
cublasComputeType_t computeType;
|
||||
#else
|
||||
cudaDataType_t computeType;
|
||||
#endif
|
||||
|
||||
if (isFp16ComputeType)
|
||||
{
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
#else
|
||||
computeType = CUDA_R_16F;
|
||||
#endif
|
||||
scaleType = CUDA_R_16F;
|
||||
}
|
||||
else
|
||||
{
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
computeType = CUBLAS_COMPUTE_32F;
|
||||
#else
|
||||
computeType = CUDA_R_32F;
|
||||
#endif
|
||||
scaleType = CUDA_R_32F;
|
||||
}
|
||||
// --------------------------------------
|
||||
// Create descriptors for the original matrices
|
||||
cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
|
||||
cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
|
||||
cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
|
||||
#else
|
||||
cublasLtMatmulDescCreate(&operationDesc, computeType);
|
||||
#endif
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
|
||||
|
||||
void* workSpace = cublas_workspace_;
|
||||
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
if (hasAlgo)
|
||||
{
|
||||
cublasLtMatmulHeuristicResult_t heurResult;
|
||||
// We have to check whether the heuristic is valid for the current shape sizes
|
||||
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
|
||||
getCublasLtHandle(), operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &algo, &heurResult);
|
||||
|
||||
if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
|
||||
|| heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
|
||||
{
|
||||
// Rely on the runtime-based heuristic
|
||||
hasAlgo = false;
|
||||
}
|
||||
hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo);
|
||||
}
|
||||
|
||||
check_cuda_error(cublasLtMatmul(*cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C,
|
||||
Cdesc, (hasAlgo ? (&algo) : NULL), workSpace, workspaceSize, stream_));
|
||||
check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C,
|
||||
mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream));
|
||||
|
||||
cublasLtMatmulDescDestroy(operationDesc);
|
||||
cublasLtMatrixLayoutDestroy(Adesc);
|
||||
cublasLtMatrixLayoutDestroy(Bdesc);
|
||||
cublasLtMatrixLayoutDestroy(Cdesc);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
else
|
||||
{
|
||||
check_cuda_error(cublasSetStream(getCublasHandle(), mStream));
|
||||
check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize));
|
||||
// Go with the default heuristic to choose the tactic, as cuBLAS does not allow choosing tactics on Ampere+
|
||||
cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT;
|
||||
check_cuda_error(cublasGemmEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, B, Btype_, ldb,
|
||||
beta, C, Ctype_, ldc, computeType_, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
|
||||
check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb,
|
||||
beta, C, mCType, ldc, mComputeType, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
}
|
||||
|
||||
void cublasMMWrapper::setWorkspace(void* workspace)
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const void* A, const int lda, const int64_t strideA, const void* B, const int ldb,
|
||||
const int64_t strideB, void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha,
|
||||
const float f_beta)
|
||||
{
|
||||
cublas_workspace_ = workspace;
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
|
||||
std::lock_guard<std::mutex> lock(*mMutex);
|
||||
|
||||
int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
|
||||
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
|
||||
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
|
||||
strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
|
||||
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
|
||||
void cublasMMWrapper::setFP32GemmConfig()
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA,
|
||||
const void* B, cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C,
|
||||
cudaDataType_t CType, const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType)
|
||||
{
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
|
||||
std::lock_guard<std::mutex> lock(*mMutex);
|
||||
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
|
||||
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
|
||||
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
|
||||
strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
|
||||
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setWorkspace(void* workspace)
|
||||
{
|
||||
mCublasWorkspace = workspace;
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setFP32GemmConfig()
|
||||
{
|
||||
setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F);
|
||||
}
|
||||
|
||||
void cublasMMWrapper::setFP16GemmConfig()
|
||||
void CublasMMWrapper::setFP16GemmConfig()
|
||||
{
|
||||
setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
void cublasMMWrapper::setBF16GemmConfig()
|
||||
void CublasMMWrapper::setBF16GemmConfig()
|
||||
{
|
||||
setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_32F);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
void cublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
|
||||
void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
|
||||
{
|
||||
setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F);
|
||||
}
|
||||
#endif
|
||||
|
||||
void cublasMMWrapper::setGemmConfig(
|
||||
void CublasMMWrapper::setGemmConfig(
|
||||
cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
|
||||
{
|
||||
Atype_ = aType;
|
||||
Btype_ = bType;
|
||||
Ctype_ = cType;
|
||||
computeType_ = computeType;
|
||||
mAType = aType;
|
||||
mBType = bType;
|
||||
mCType = cType;
|
||||
bool isFp16ComputeType = computeType == CUDA_R_16F;
|
||||
if (isFp16ComputeType)
|
||||
{
|
||||
mComputeType = CUBLAS_COMPUTE_16F;
|
||||
mScaleType = CUDA_R_16F;
|
||||
}
|
||||
else
|
||||
{
|
||||
mComputeType = CUBLAS_COMPUTE_32F;
|
||||
mScaleType = CUDA_R_32F;
|
||||
}
|
||||
}
|
||||
|
||||
CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
|
||||
CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
|
||||
{
|
||||
if (data_type == CUDA_R_16F)
|
||||
{
|
||||
@ -343,279 +262,22 @@ CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
|
||||
return FLOAT_DATATYPE;
|
||||
}
|
||||
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
// input, weight, output are row-major
|
||||
// only works for cublas 11.x
|
||||
void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const void* B, const int ldb, const void* bias, void* C, const int ldc)
|
||||
void CublasMMWrapper::setStream(cudaStream_t stream)
|
||||
{
|
||||
cudaDataType_t Atype, Btype, Ctype;
|
||||
cublasComputeType_t computeType;
|
||||
cudaDataType_t scaleType;
|
||||
float alpha_float = 1.0f;
|
||||
float beta_float = 0.0f;
|
||||
half alpha_half = half(1.0f);
|
||||
half beta_half = half(0.0f);
|
||||
void *alpha, *beta;
|
||||
|
||||
// int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
|
||||
if (Atype_ == CUDA_R_32F)
|
||||
{
|
||||
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
Atype = CUDA_R_32F;
|
||||
Btype = CUDA_R_32F;
|
||||
Ctype = CUDA_R_32F;
|
||||
scaleType = CUDA_R_32F;
|
||||
alpha = &alpha_float;
|
||||
beta = &beta_float;
|
||||
}
|
||||
else if (Atype_ == CUDA_R_16BF)
|
||||
{
|
||||
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
Atype = CUDA_R_16BF;
|
||||
Btype = CUDA_R_16BF;
|
||||
Ctype = CUDA_R_16BF;
|
||||
scaleType = CUDA_R_32F;
|
||||
alpha = &alpha_float;
|
||||
beta = &beta_float;
|
||||
}
|
||||
else
|
||||
{
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
Atype = CUDA_R_16F;
|
||||
Btype = CUDA_R_16F;
|
||||
Ctype = CUDA_R_16F;
|
||||
scaleType = CUDA_R_16F;
|
||||
alpha = &alpha_half;
|
||||
beta = &beta_half;
|
||||
}
|
||||
|
||||
cublasLtMatmulDesc_t operationDesc = NULL;
|
||||
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
|
||||
cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
|
||||
cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda);
|
||||
cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb);
|
||||
cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc);
|
||||
|
||||
cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
|
||||
check_cuda_error(cublasLtMatmul(
|
||||
*cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
|
||||
cublasLtMatrixLayoutDestroy(Adesc);
|
||||
cublasLtMatrixLayoutDestroy(Bdesc);
|
||||
cublasLtMatrixLayoutDestroy(Cdesc);
|
||||
cublasLtMatmulDescDestroy(operationDesc);
|
||||
}
|
||||
#endif
|
||||
void cublasMMWrapper::setStream(cudaStream_t stream)
|
||||
{
|
||||
stream_ = stream;
|
||||
mStream = stream;
|
||||
}
|
||||
|
||||
void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const void* A, const int lda, const int64_t strideA, const void* B, const int ldb,
|
||||
const int64_t strideB, void* C, const int ldc, const int64_t strideC, const int batch_count, const float f_alpha,
|
||||
const float f_beta)
|
||||
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo)
|
||||
{
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
|
||||
|
||||
mu_->lock();
|
||||
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
|
||||
const void* alpha
|
||||
= is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
|
||||
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
|
||||
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda,
|
||||
strideA, B, Btype_, ldb, strideB, beta, C, Ctype_, ldc, strideC, batch_count, computeType_,
|
||||
static_cast<cublasGemmAlgo_t>(info.algoId)));
|
||||
|
||||
mu_->unlock();
|
||||
}
|
||||
|
||||
void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA,
|
||||
const void* B, cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C,
|
||||
cudaDataType_t CType, const int ldc, const int64_t strideC, const int batch_count, cudaDataType_t computeType)
|
||||
{
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
|
||||
mu_->lock();
|
||||
int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0;
|
||||
const void* alpha
|
||||
= is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
|
||||
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
|
||||
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, AType, lda, strideA,
|
||||
B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batch_count, computeType,
|
||||
static_cast<cublasGemmAlgo_t>(info.algoId)));
|
||||
|
||||
mu_->unlock();
|
||||
}
|
||||
|
||||
void cublasMMWrapper::batchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const void* const* A, const int lda, const void* const* B, const int ldb, void* const* C,
|
||||
const int ldc, const int batch_count)
|
||||
{
|
||||
float f_alpha = static_cast<float>(1.0f);
|
||||
float f_beta = static_cast<float>(0.0f);
|
||||
|
||||
half h_alpha = (half) 1.0f;
|
||||
half h_beta = (half) 0.0f;
|
||||
|
||||
mu_->lock();
|
||||
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
|
||||
const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
|
||||
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
|
||||
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
|
||||
|
||||
check_cuda_error(cublasGemmBatchedEx(*cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, B, Btype_,
|
||||
ldb, beta, C, Ctype_, ldc, batch_count, computeType_, static_cast<cublasGemmAlgo_t>(info.algoId)));
|
||||
mu_->unlock();
|
||||
}
|
||||
|
||||
bool cublasMMWrapper::isFuseBatchGemm(const int batch_count, const int m, const int k, const int n)
|
||||
{
|
||||
CublasDataType data_type = getCublasDataType(Atype_);
|
||||
|
||||
if (cublas_algo_map_->isExist(batch_count, m, k, n, data_type) == false
|
||||
|| cublas_algo_map_->isExist(1, m, k, n, data_type) == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
return cublas_algo_map_->getAlgo(batch_count, m, k, n, data_type).exec_time
|
||||
< 3 * cublas_algo_map_->getAlgo(1, m, k, n, data_type).exec_time;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> cublasMMWrapper::getTactics(cublasOperation_t transa,
|
||||
cublasOperation_t transb, const int m, const int n, const int k, const int lda, const int ldb, const int ldc)
|
||||
{
|
||||
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
|
||||
cublasLtMatmulDesc_t operationDesc = NULL;
|
||||
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL;
|
||||
cudaDataType_t scaleType;
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
cublasComputeType_t computeType;
|
||||
#else
|
||||
cudaDataType_t computeType;
|
||||
#endif
|
||||
|
||||
if (is_fp16_computeType)
|
||||
{
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
#else
|
||||
computeType = CUDA_R_16F;
|
||||
#endif
|
||||
scaleType = CUDA_R_16F;
|
||||
}
|
||||
else
|
||||
{
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
computeType = CUBLAS_COMPUTE_32F;
|
||||
#else
|
||||
computeType = CUDA_R_32F;
|
||||
#endif
|
||||
scaleType = CUDA_R_32F;
|
||||
}
|
||||
|
||||
// --------------------------------------
|
||||
// Create descriptors for the original matrices
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
|
||||
check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc));
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
|
||||
#else
|
||||
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType));
|
||||
#endif
|
||||
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
|
||||
|
||||
void* workSpace = cublas_workspace_;
|
||||
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
const auto heuristics = getTactics(getCublasLtHandle(), operationDesc, Adesc, Bdesc, Cdesc, Cdesc);
|
||||
|
||||
check_cuda_error(cublasLtMatmulDescDestroy(operationDesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(Adesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(Bdesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(Cdesc));
|
||||
sync_check_cuda_error();
|
||||
|
||||
return heuristics;
|
||||
}
|
||||
|
||||
bool cublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const int lda, const int ldb, const int ldc, const cublasLtMatmulHeuristicResult_t& heuristic) const
|
||||
{
|
||||
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
|
||||
cublasLtMatmulDesc_t operationDesc = NULL;
|
||||
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL;
|
||||
cudaDataType_t scaleType;
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
cublasComputeType_t computeType;
|
||||
#else
|
||||
cudaDataType_t computeType;
|
||||
#endif
|
||||
|
||||
if (is_fp16_computeType)
|
||||
{
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
#else
|
||||
computeType = CUDA_R_16F;
|
||||
#endif
|
||||
scaleType = CUDA_R_16F;
|
||||
}
|
||||
else
|
||||
{
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
computeType = CUBLAS_COMPUTE_32F;
|
||||
#else
|
||||
computeType = CUDA_R_32F;
|
||||
#endif
|
||||
scaleType = CUDA_R_32F;
|
||||
}
|
||||
|
||||
// --------------------------------------
|
||||
// Create descriptors for the original matrices
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
|
||||
check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc));
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
|
||||
#else
|
||||
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType));
|
||||
#endif
|
||||
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
|
||||
|
||||
void* workSpace = cublas_workspace_;
|
||||
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
cublasLtMatmulHeuristicResult_t heurResult;
|
||||
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
|
||||
getCublasLtHandle(), operationDesc, Adesc, Bdesc, Cdesc, Cdesc, &heuristic.algo, &heurResult);
|
||||
getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult);
|
||||
|
||||
if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
|
||||
|| heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
|
||||
@ -623,16 +285,25 @@ bool cublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t tr
|
||||
return false;
|
||||
}
|
||||
|
||||
check_cuda_error(cublasLtMatmulDescDestroy(operationDesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(Adesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(Bdesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(Cdesc));
|
||||
sync_check_cuda_error();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> cublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
|
||||
cublasOperation_t transb, const int m, const int n, const int k, const int lda, const int ldb, const int ldc)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
|
||||
|
||||
const auto heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
return heuristics;
|
||||
}
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc)
|
||||
{
|
||||
@ -648,7 +319,7 @@ std::vector<cublasLtMatmulHeuristicResult_t> cublasMMWrapper::getTactics(cublasL
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
    // Restrict reduction algorithms for numerical stability and better determinism
    uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_INPLACE;
    uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK;
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask)));
#if TLLM_CUBLAS_VER_LT(12, 0, 0)
@ -666,215 +337,6 @@ std::vector<cublasLtMatmulHeuristicResult_t> cublasMMWrapper::getTactics(cublasL
#endif
}
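
For orientation, here is a minimal sketch of how a caller (for instance a GEMM plugin profiling tactics, which is hypothetical here and not part of this diff) might combine the createDescriptors / getTactics / checkTactic API shown above. The wrapper instance, problem shapes, leading dimensions and include path are assumptions.

```
// Sketch only: include path and shapes are assumed, not taken from this diff.
#include "tensorrt_llm/common/cublasMMWrapper.h"

#include <cublasLt.h>
#include <vector>

std::vector<cublasLtMatmulHeuristicResult_t> collectUsableTactics(
    tensorrt_llm::common::CublasMMWrapper& wrapper, int m, int n, int k)
{
    const auto transa = CUBLAS_OP_N;
    const auto transb = CUBLAS_OP_N;
    const int lda = m, ldb = k, ldc = m; // column-major leading dimensions for this shape

    // Descriptors must exist before getTactics()/checkTactic() may be called.
    wrapper.createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);

    std::vector<cublasLtMatmulHeuristicResult_t> usable;
    for (const auto& heuristic : wrapper.getTactics(transa, transb, m, n, k, lda, ldb, ldc))
    {
        // Keep only heuristics that cublasLtMatmulAlgoCheck accepts for these descriptors.
        if (wrapper.checkTactic(transa, transb, m, n, k, lda, ldb, ldc, heuristic))
        {
            usable.push_back(heuristic);
        }
    }

    wrapper.destroyDescriptors();
    return usable; // the caller would time these candidates to pick the fastest algo
}
```
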
std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t computeDesc, const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B,
|
||||
cublasLtMatrixLayout_t Bdesc, const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D,
|
||||
cublasLtMatrixLayout_t Ddesc, cudaStream_t stream)
|
||||
{
|
||||
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
|
||||
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
|
||||
return {false, cublasLtMatmulAlgo_t{}};
|
||||
#else
|
||||
size_t returnSize;
|
||||
int32_t pointer_mode;
|
||||
cublasLtMatmulDescGetAttribute(
|
||||
computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode), &returnSize);
|
||||
|
||||
const auto heuristics = getTactics(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc);
|
||||
|
||||
std::map<int, std::vector<float>> algo_results;
|
||||
for (const auto& heuristic : heuristics)
|
||||
{
|
||||
cublasLtMatmulAlgo_t algo = heuristic.algo;
|
||||
int32_t algo_id;
|
||||
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
|
||||
|
||||
cudaEvent_t start_event, stop_event;
|
||||
cudaEventCreate(&start_event);
|
||||
cudaEventCreate(&stop_event);
|
||||
|
||||
float my_alpha = 1.0f;
|
||||
float my_beta = 0.0f;
|
||||
|
||||
for (int i = 0; i < 11; i++)
|
||||
{
|
||||
float duration_ms;
|
||||
cudaEventRecord(start_event, stream);
|
||||
check_cuda_error(cublasLtMatmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D,
|
||||
Ddesc, &algo, cublas_workspace_, CUBLAS_WORKSPACE_SIZE, stream));
|
||||
cudaEventRecord(stop_event, stream);
|
||||
cudaEventSynchronize(stop_event);
|
||||
cudaEventElapsedTime(&duration_ms, start_event, stop_event);
|
||||
|
||||
algo_results[algo_id].push_back(duration_ms);
|
||||
}
|
||||
std::sort(algo_results[algo_id].begin(), algo_results[algo_id].end());
|
||||
}
|
||||
|
||||
cublasLtMatmulHeuristicResult_t result;
|
||||
float best_time = INFINITY;
|
||||
for (const auto& heuristic : heuristics)
|
||||
{
|
||||
cublasLtMatmulAlgo_t algo = heuristic.algo;
|
||||
int32_t algo_id;
|
||||
cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
|
||||
const auto& results = algo_results[algo_id];
|
||||
|
||||
if (results.size() > 0 && results[5] < best_time)
|
||||
{
|
||||
best_time = results[5];
|
||||
result = heuristic;
|
||||
}
|
||||
}
|
||||
|
||||
return {best_time != INFINITY, result.algo};
|
||||
#endif
|
||||
}
|
||||
|
||||
cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrixLayout_t Mdesc)
|
||||
{
|
||||
size_t returnSize;
|
||||
MatrixLayout m_layout;
|
||||
|
||||
cublasLtMatrixLayoutGetAttribute(
|
||||
Mdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &std::get<0>(m_layout), sizeof(std::get<0>(m_layout)), &returnSize);
|
||||
cublasLtMatrixLayoutGetAttribute(
|
||||
Mdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &std::get<1>(m_layout), sizeof(std::get<1>(m_layout)), &returnSize);
|
||||
cublasLtMatrixLayoutGetAttribute(
|
||||
Mdesc, CUBLASLT_MATRIX_LAYOUT_ROWS, &std::get<2>(m_layout), sizeof(std::get<2>(m_layout)), &returnSize);
|
||||
cublasLtMatrixLayoutGetAttribute(
|
||||
Mdesc, CUBLASLT_MATRIX_LAYOUT_COLS, &std::get<3>(m_layout), sizeof(std::get<3>(m_layout)), &returnSize);
|
||||
|
||||
return m_layout;
|
||||
}
|
||||
|
||||
cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc,
|
||||
const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B, cublasLtMatrixLayout_t Bdesc,
|
||||
const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D, cublasLtMatrixLayout_t Ddesc,
|
||||
const cublasLtMatmulAlgo_t* algo, void* workspace, size_t workspaceSizeInBytes, cudaStream_t stream)
|
||||
{
|
||||
cache_idx_t cache_idx{computeDesc,
|
||||
{createMatrixLayout(Adesc), createMatrixLayout(Bdesc), createMatrixLayout(Cdesc), createMatrixLayout(Ddesc)}};
|
||||
|
||||
cublasLtMatmulAlgo_t algo_value;
|
||||
bool found_algo = false;
|
||||
if (algo == nullptr)
|
||||
{
|
||||
if (algo_cache.find(cache_idx) == algo_cache.end())
|
||||
{
|
||||
auto result
|
||||
= findBestAlgo(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, stream);
|
||||
if (result.first)
|
||||
{
|
||||
algo_cache[cache_idx] = result.second;
|
||||
algo_value = result.second;
|
||||
found_algo = true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
algo_value = algo_cache[cache_idx];
|
||||
found_algo = true;
|
||||
}
|
||||
}
|
||||
|
||||
return cublasLtMatmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc,
|
||||
found_algo ? &algo_value : algo, workspace, workspaceSizeInBytes, stream);
|
||||
}
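
A minimal sketch of what the algo cache above buys a caller: invoking cublasLtMatmulWrapper with algo == nullptr profiles the heuristics once via findBestAlgo and then reuses the cached winner for matching descriptors. The handles, descriptors, buffers and the FP32 alpha/beta scaling are assumptions, and the surrounding translation unit's includes are assumed to be available.

```
// Sketch only: descriptors and device buffers are assumed to be created by the caller,
// with a compute descriptor that uses FP32 scaling so float alpha/beta pointers are valid.
cublasStatus_t runCachedMatmul(cublasMMWrapper& wrapper, cublasLtHandle_t ltHandle, cublasLtMatmulDesc_t opDesc,
    cublasLtMatrixLayout_t aDesc, cublasLtMatrixLayout_t bDesc, cublasLtMatrixLayout_t cDesc, const void* A,
    const void* B, void* C, void* workspace, size_t workspaceBytes, cudaStream_t stream)
{
    const float alpha = 1.0f;
    const float beta = 0.0f;
    // algo == nullptr -> consult algo_cache; on a miss, findBestAlgo() times the heuristics and
    // stores the winner keyed by the compute descriptor and matrix layouts.
    return wrapper.cublasLtMatmulWrapper(ltHandle, opDesc, &alpha, A, aDesc, B, bDesc, &beta, C, cDesc, C, cDesc,
        /*algo=*/nullptr, workspace, workspaceBytes, stream);
}
```
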
void cublasMMWrapper::_Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
|
||||
const int ldb, void* C, const int ldc, const void* alpha, const int mode, const bool per_column_scaling)
|
||||
{
|
||||
/* mode:
|
||||
* - 0: int8 * int8 -> int32 -> int8
|
||||
* - 1: int8 * int8 -> int32 -> int32
|
||||
*/
|
||||
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
|
||||
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
|
||||
#else
|
||||
|
||||
mu_->lock();
|
||||
const auto op_a = CUBLAS_OP_T;
|
||||
const auto op_b = CUBLAS_OP_N;
|
||||
const auto dataType = CUDA_R_8I;
|
||||
const auto resultType = mode == 0 ? CUDA_R_8I : CUDA_R_32I;
|
||||
const auto computeType = CUBLAS_COMPUTE_32I;
|
||||
const auto scaleType = mode == 0 ? CUDA_R_32F : CUDA_R_32I;
|
||||
const int batch_count = 1;
|
||||
const void* beta;
|
||||
|
||||
int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(dataType));
|
||||
|
||||
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(dataType));
|
||||
|
||||
cublasLtMatmulDesc_t operationDesc = NULL;
|
||||
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
|
||||
|
||||
// --------------------------------------
|
||||
// Create descriptors for the original matrices
|
||||
check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, dataType, k, m, lda));
|
||||
check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, dataType, k, n, ldb));
|
||||
check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, resultType, m, n, ldc));
|
||||
|
||||
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
|
||||
|
||||
auto pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
||||
if (mode == 0)
|
||||
{
|
||||
pointer_mode
|
||||
= per_column_scaling ? CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST : CUBLASLT_POINTER_MODE_DEVICE;
|
||||
}
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSC, &op_b, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(cublasLtMatmulDescSetAttribute(
|
||||
operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
|
||||
|
||||
const int32_t int_one = 1;
|
||||
const int32_t int_zero = 0;
|
||||
const float float_zero = 0;
|
||||
if (mode == 0)
|
||||
{
|
||||
beta = per_column_scaling ? &float_zero : NULL;
|
||||
}
|
||||
else
|
||||
{
|
||||
alpha = &int_one;
|
||||
beta = &int_zero;
|
||||
}
|
||||
|
||||
void* workSpace = cublas_workspace_;
|
||||
int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
sync_check_cuda_error();
|
||||
auto ret = cublasLtMatmulWrapper(*cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C,
|
||||
Cdesc, NULL, workSpace, workspaceSize, stream_);
|
||||
check_cuda_error(ret);
|
||||
sync_check_cuda_error();
|
||||
|
||||
cublasLtMatmulDescDestroy(operationDesc);
|
||||
cublasLtMatrixLayoutDestroy(Adesc);
|
||||
cublasLtMatrixLayoutDestroy(Bdesc);
|
||||
cublasLtMatrixLayoutDestroy(Cdesc);
|
||||
sync_check_cuda_error();
|
||||
mu_->unlock();
|
||||
#endif
|
||||
}
|
||||
|
||||
void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
|
||||
const int ldb, int8_t* C, const int ldc, const float* alpha, const bool per_column_scaling)
|
||||
{
|
||||
return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, alpha, 0, per_column_scaling);
|
||||
}
|
||||
|
||||
void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
|
||||
const int ldb, int32_t* C, const int ldc)
|
||||
{
|
||||
return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, (float*) nullptr, 1, false);
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cublasAlgoMap.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include <cublasLt.h>
|
||||
#include <cublas_v2.h>
|
||||
@ -31,66 +30,44 @@ namespace tensorrt_llm
|
||||
namespace common
|
||||
{
|
||||
|
||||
class cublasMMWrapper
|
||||
class CublasMMWrapper
|
||||
{
|
||||
protected:
|
||||
std::shared_ptr<cublasHandle_t> cublas_handle_;
|
||||
std::shared_ptr<cublasLtHandle_t> cublaslt_handle_;
|
||||
std::shared_ptr<cublasHandle_t> mCublasHandle;
|
||||
std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;
|
||||
|
||||
cudaDataType_t Atype_;
|
||||
cudaDataType_t Btype_;
|
||||
cudaDataType_t Ctype_;
|
||||
cudaDataType_t computeType_;
|
||||
cudaDataType_t mAType{};
|
||||
cudaDataType_t mBType{};
|
||||
cudaDataType_t mCType{};
|
||||
cublasComputeType_t mComputeType{};
|
||||
cudaDataType_t mScaleType{};
|
||||
|
||||
cudaStream_t stream_;
|
||||
cublasAlgoMap* cublas_algo_map_;
|
||||
std::mutex* mu_;
|
||||
cublasLtMatmulDesc_t mOperationDesc{NULL};
|
||||
cublasLtMatrixLayout_t mADesc{NULL};
|
||||
cublasLtMatrixLayout_t mBDesc{NULL};
|
||||
cublasLtMatrixLayout_t mCDesc{NULL};
|
||||
|
||||
void* cublas_workspace_ = nullptr;
|
||||
cudaStream_t mStream;
|
||||
//@fixme: we may not need the mutex if we copy the wrapper instead of sharing in GemmPlugin::clone()
|
||||
std::shared_ptr<std::mutex> mMutex{std::make_shared<std::mutex>()};
|
||||
|
||||
friend class cublasINT8MMWrapper;
|
||||
void* mCublasWorkspace = nullptr;
|
||||
|
||||
void _Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B,
|
||||
const int ldb, void* C, const int ldc, const void* alpha, const int mode, const bool per_column_scaling);
|
||||
private:
|
||||
bool descriptorsCreated() const
|
||||
{
|
||||
return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL;
|
||||
}
|
||||
|
||||
public:
|
||||
cublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
|
||||
cudaStream_t stream, cublasAlgoMap* map, std::mutex* mu, void* workspace);
|
||||
CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
|
||||
cudaStream_t stream, void* workspace);
|
||||
|
||||
~cublasMMWrapper();
|
||||
~CublasMMWrapper();
|
||||
|
||||
cublasMMWrapper(const cublasMMWrapper& wrapper);
|
||||
|
||||
cublasStatus_t cublasLtMatmulWrapper(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc,
|
||||
const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B, cublasLtMatrixLayout_t Bdesc,
|
||||
const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D, cublasLtMatrixLayout_t Ddesc,
|
||||
const cublasLtMatmulAlgo_t* algo, void* workspace, size_t workspaceSizeInBytes, cudaStream_t stream);
|
||||
|
||||
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const int lda, const int ldb, const int ldc, const cublasLtMatmulHeuristicResult_t& algo) const;
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
|
||||
const int m, const int n, const int k, const int lda, const int ldb, const int ldc);
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);
|
||||
|
||||
std::pair<bool, cublasLtMatmulAlgo_t> findBestAlgo(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc,
|
||||
const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B, cublasLtMatrixLayout_t Bdesc,
|
||||
const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D, cublasLtMatrixLayout_t Ddesc,
|
||||
cudaStream_t stream);
|
||||
|
||||
using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
|
||||
using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;
|
||||
std::map<cache_idx_t, cublasLtMatmulAlgo_t> algo_cache;
|
||||
|
||||
MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* alpha, const void* A, cudaDataType_t Atype, int lda, const void* B, cudaDataType_t Btype, int ldb,
|
||||
const void* beta, void* C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType, cublasGemmAlgo_t algo);
|
||||
CublasMMWrapper(const CublasMMWrapper& wrapper);
|
||||
|
||||
/********************** GEMMs **********************/
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
||||
const int lda, const void* B, const int ldb, void* C, const int ldc);
|
||||
|
||||
@ -103,16 +80,37 @@ public:
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
||||
const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta,
|
||||
const cublasLtMatmulAlgo_t& algo, bool hasAlgo);
|
||||
const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt);
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, const int64_t strideB,
|
||||
void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha = 1.0f,
|
||||
const float f_beta = 0.0f);
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, const void* B,
|
||||
cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, cudaDataType_t CType,
|
||||
const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType);
|
||||
|
||||
/********************** Tactic selection helpers **********************/
|
||||
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo);
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
|
||||
const int m, const int n, const int k, const int lda, const int ldb, const int ldc);
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);
|
||||
|
||||
using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
|
||||
using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;
|
||||
|
||||
MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);
|
||||
|
||||
/********************** Utils **********************/
|
||||
void setWorkspace(void* workspace);
|
||||
|
||||
void Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B, const int ldb,
|
||||
int8_t* C, const int ldc, const float* alpha, const bool per_column_scaling = false);
|
||||
|
||||
void Int8Gemm(const int m, const int n, const int k, const int8_t* A, const int lda, const int8_t* B, const int ldb,
|
||||
int32_t* C, const int ldc);
|
||||
|
||||
void setFP32GemmConfig();
|
||||
void setFP16GemmConfig();
|
||||
#ifdef ENABLE_BF16
|
||||
@ -128,35 +126,18 @@ public:
|
||||
|
||||
CublasDataType getCublasDataType(cudaDataType_t data_type);
|
||||
|
||||
#if (CUDART_VERSION >= 11000)
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
||||
const int lda, const void* B, const int ldb, const void* bias, void* C, const int ldc);
|
||||
#endif
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, const int64_t strideB,
|
||||
void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha = 1.0f,
|
||||
const float f_beta = 0.0f);
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, const void* B,
|
||||
cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, cudaDataType_t CType,
|
||||
const int ldc, const int64_t strideC, const int batch_count, cudaDataType_t computeType);
|
||||
|
||||
void batchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* const* A, const int lda, const void* const* B, const int ldb, void* const* C, const int ldc,
|
||||
const int batch_count);
|
||||
|
||||
bool isFuseBatchGemm(const int batch_count, const int m, const int k, const int n);
|
||||
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const int lda, const int ldb, const int ldc);
|
||||
void destroyDescriptors();
|
||||
|
||||
cublasHandle_t getCublasHandle()
|
||||
{
|
||||
return *(this->cublas_handle_);
|
||||
return *(this->mCublasHandle);
|
||||
}
|
||||
|
||||
cublasLtHandle_t getCublasLtHandle() const
|
||||
{
|
||||
return *(this->cublaslt_handle_);
|
||||
return *(this->mCublasLtHandle);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -287,25 +287,6 @@ inline int getMultiProcessorCount()
|
||||
return multi_processor_count;
|
||||
}

class CudaEventDeleter
{
public:
    constexpr void operator()(cudaEvent_t stream) const
    {
        if (stream != nullptr)
            check_cuda_error(::cudaEventDestroy(stream));
    }
};

using EventPtr = std::unique_ptr<std::remove_pointer_t<cudaEvent_t>, CudaEventDeleter>;

inline EventPtr CreateEvent(unsigned int flags = cudaEventDisableTiming)
{
    cudaEvent_t event;
    check_cuda_error(::cudaEventCreate(&event, flags));
    return EventPtr{event};
}
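
A short sketch of using CreateEvent/EventPtr above to time work on a stream; the stream and the measured work are placeholders. Note that timing requires cudaEventDefault rather than the cudaEventDisableTiming default.

```
// Sketch only: `stream` and the enqueued work are placeholders.
inline float elapsedMsOnStream(cudaStream_t stream)
{
    EventPtr start = CreateEvent(cudaEventDefault);
    EventPtr stop = CreateEvent(cudaEventDefault);

    check_cuda_error(::cudaEventRecord(start.get(), stream));
    // ... enqueue the kernels or copies to be measured on `stream` here ...
    check_cuda_error(::cudaEventRecord(stop.get(), stream));
    check_cuda_error(::cudaEventSynchronize(stop.get()));

    float ms = 0.0F;
    check_cuda_error(::cudaEventElapsedTime(&ms, start.get(), stop.get()));
    return ms; // both events are released by CudaEventDeleter when the EventPtrs go out of scope
}
```
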

inline int divUp(int a, int n)
{
    return (a + n - 1) / n;

@ -30,11 +30,10 @@ namespace tensorrt_llm::common
class Logger
{

#if _WIN32
// On Windows, the file wingdi.h is included which has
// #define ERROR 0
// This breaks everywhere ERROR is used in the Level enum
// Alternative, untested solution to #undef: compile with NOGDI flag defined
#if _WIN32
#undef ERROR
#endif // _WIN32

142
cpp/tensorrt_llm/common/mpiUtils.cpp
Normal file
@ -0,0 +1,142 @@
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace mpi
|
||||
{
|
||||
|
||||
MPI_Datatype getMpiDtype(MpiType dtype)
|
||||
{
|
||||
static const std::unordered_map<MpiType, MPI_Datatype> dtype_map{
|
||||
{MPI_TYPE_BYTE, MPI_BYTE},
|
||||
{MPI_TYPE_CHAR, MPI_CHAR},
|
||||
{MPI_TYPE_INT, MPI_INT},
|
||||
{MPI_TYPE_INT64_T, MPI_INT64_T},
|
||||
{MPI_TYPE_UINT32_T, MPI_UINT32_T},
|
||||
{MPI_TYPE_UNSIGNED_LONG_LONG, MPI_UNSIGNED_LONG_LONG},
|
||||
};
|
||||
return dtype_map.at(dtype);
|
||||
}
|
||||
|
||||
MPI_Op getMpiOp(MpiOp op)
|
||||
{
|
||||
static const std::unordered_map<MpiOp, MPI_Op> op_map{
|
||||
{MPI_OP_NULLOP, MPI_OP_NULL},
|
||||
{MPI_OP_MAX, MPI_MAX},
|
||||
{MPI_OP_MIN, MPI_MIN},
|
||||
{MPI_OP_SUM, MPI_SUM},
|
||||
{MPI_OP_PROD, MPI_PROD},
|
||||
{MPI_OP_LAND, MPI_LAND},
|
||||
{MPI_OP_BAND, MPI_BAND},
|
||||
{MPI_OP_LOR, MPI_LOR},
|
||||
{MPI_OP_BOR, MPI_BOR},
|
||||
{MPI_OP_LXOR, MPI_LXOR},
|
||||
{MPI_OP_BXOR, MPI_BXOR},
|
||||
{MPI_OP_MINLOC, MPI_MINLOC},
|
||||
{MPI_OP_MAXLOC, MPI_MAXLOC},
|
||||
{MPI_OP_REPLACE, MPI_REPLACE},
|
||||
};
|
||||
return op_map.at(op);
|
||||
}
|
||||
|
||||
void initialize(int* argc, char*** argv)
|
||||
{
|
||||
MPICHECK(MPI_Init(argc, argv));
|
||||
}
|
||||
|
||||
void finalize()
|
||||
{
|
||||
MPICHECK(MPI_Finalize());
|
||||
}
|
||||
|
||||
bool isInitialized()
|
||||
{
|
||||
int mpi_initialized = 0;
|
||||
MPICHECK(MPI_Initialized(&mpi_initialized));
|
||||
return static_cast<bool>(mpi_initialized);
|
||||
}
|
||||
|
||||
void initThread(int* argc, char*** argv, MpiThreadSupport required, int* provided)
|
||||
{
|
||||
switch (required)
|
||||
{
|
||||
case THREAD_SINGLE: MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_SINGLE, provided)); break;
|
||||
case THREAD_FUNNELED: MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_FUNNELED, provided)); break;
|
||||
case THREAD_SERIALIZED: MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_SERIALIZED, provided)); break;
|
||||
case THREAD_MULTIPLE: MPICHECK(MPI_Init_thread(argc, argv, MPI_THREAD_MULTIPLE, provided)); break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
int getCommWorldRank()
|
||||
{
|
||||
int rank = 0;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
return rank;
|
||||
}
|
||||
|
||||
int getCommWorldSize()
|
||||
{
|
||||
int world_size = 1;
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||
return world_size;
|
||||
}
|
||||
|
||||
void barrier(MpiComm comm)
|
||||
{
|
||||
MPICHECK(MPI_Barrier(comm.group));
|
||||
}
|
||||
|
||||
void barrier()
|
||||
{
|
||||
MPICHECK(MPI_Barrier(MPI_COMM_WORLD));
|
||||
}
|
||||
|
||||
void bcast(void* buffer, size_t size, MpiType dtype, int root, MpiComm comm)
|
||||
{
|
||||
MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, comm.group));
|
||||
}
|
||||
|
||||
void bcast(std::vector<int64_t>& packed, int root, MpiComm comm)
|
||||
{
|
||||
int64_t nWords1;
|
||||
if (getCommWorldRank() == root)
|
||||
{
|
||||
nWords1 = static_cast<int64_t>(packed.size());
|
||||
}
|
||||
bcast(&nWords1, 1, MPI_TYPE_INT64_T, root, comm);
|
||||
if (getCommWorldRank() != root)
|
||||
{
|
||||
packed.resize(nWords1);
|
||||
}
|
||||
bcast(packed.data(), packed.size(), MPI_TYPE_INT64_T, root, comm);
|
||||
}
|
||||
|
||||
void comm_split(MpiComm comm, int color, int key, MpiComm* newcomm)
|
||||
{
|
||||
MPICHECK(MPI_Comm_split(comm.group, color, key, &newcomm->group));
|
||||
}
|
||||
|
||||
void allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op, MpiComm comm)
|
||||
{
|
||||
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), comm.group));
|
||||
}
|
||||
|
||||
} // namespace mpi
|
||||
} // namespace tensorrt_llm
|
||||
105
cpp/tensorrt_llm/common/mpiUtils.h
Normal file
@ -0,0 +1,105 @@
/*
|
||||
* Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <mpi.h>
|
||||
#include <stdio.h>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define COMM_WORLD MpiComm(MPI_COMM_WORLD)
|
||||
#define MPICHECK(cmd) \
|
||||
do \
|
||||
{ \
|
||||
int e = cmd; \
|
||||
if (e != MPI_SUCCESS) \
|
||||
{ \
|
||||
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// A wrapper module of the MPI library.
|
||||
namespace tensorrt_llm::mpi
|
||||
{
|
||||
|
||||
// A wrapper of MPI data type. MPI_TYPE_{data_type}
|
||||
enum MpiType
|
||||
{
|
||||
MPI_TYPE_BYTE,
|
||||
MPI_TYPE_CHAR,
|
||||
MPI_TYPE_INT,
|
||||
MPI_TYPE_INT64_T,
|
||||
MPI_TYPE_UINT32_T,
|
||||
MPI_TYPE_UNSIGNED_LONG_LONG,
|
||||
};
|
||||
|
||||
// A wrapper of MPI_Op type.
|
||||
enum MpiOp
|
||||
{
|
||||
MPI_OP_NULLOP,
|
||||
MPI_OP_MAX,
|
||||
MPI_OP_MIN,
|
||||
MPI_OP_SUM,
|
||||
MPI_OP_PROD,
|
||||
MPI_OP_LAND,
|
||||
MPI_OP_BAND,
|
||||
MPI_OP_LOR,
|
||||
MPI_OP_BOR,
|
||||
MPI_OP_LXOR,
|
||||
MPI_OP_BXOR,
|
||||
MPI_OP_MINLOC,
|
||||
MPI_OP_MAXLOC,
|
||||
MPI_OP_REPLACE,
|
||||
};
|
||||
|
||||
// A wrapper of the level of MPI thread support
|
||||
enum MpiThreadSupport
|
||||
{
|
||||
THREAD_SINGLE,
|
||||
THREAD_FUNNELED,
|
||||
THREAD_SERIALIZED,
|
||||
THREAD_MULTIPLE
|
||||
};
|
||||
|
||||
struct MpiComm
|
||||
{
|
||||
MPI_Comm group;
|
||||
MpiComm(){};
|
||||
MpiComm(MPI_Comm g)
|
||||
: group(g){};
|
||||
};
|
||||
|
||||
MPI_Datatype getMpiDtype(MpiType dtype);

void initialize(int* argc, char*** argv);
void initThread(int* argc, char*** argv, MpiThreadSupport required, int* provided);
void finalize();
bool isInitialized();
void barrier(MpiComm comm);
void barrier();

int getCommWorldRank();
int getCommWorldSize();

void bcast(void* buffer, size_t size, MpiType dtype, int root, MpiComm comm);
void bcast(std::vector<int64_t>& packed, int root, MpiComm comm);
void comm_split(MpiComm comm, int color, int key, MpiComm* newcomm);
void allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op, MpiComm comm);

} // namespace tensorrt_llm::mpi
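
A minimal end-to-end sketch of this wrapper (thread-aware init, packed broadcast, all-reduce, finalize); the program structure, header path and values are assumptions, not part of this diff.

```
// Sketch only: assumes the header above is available as "tensorrt_llm/common/mpiUtils.h".
#include "tensorrt_llm/common/mpiUtils.h"

#include <cstdint>
#include <vector>

using namespace tensorrt_llm::mpi;

int main(int argc, char* argv[])
{
    int provided = 0;
    initThread(&argc, &argv, THREAD_MULTIPLE, &provided);

    const int rank = getCommWorldRank();
    const int64_t worldSize = getCommWorldSize();

    // Rank 0 packs some configuration into int64_t words and broadcasts it; the vector
    // overload first broadcasts the size, then the payload.
    std::vector<int64_t> packed;
    if (rank == 0)
    {
        packed = {worldSize, 42};
    }
    bcast(packed, /*root=*/0, COMM_WORLD);

    // Reduce one value per rank into a sum every rank sees.
    int64_t local = rank;
    int64_t total = 0;
    allreduce(&local, &total, 1, MPI_TYPE_INT64_T, MPI_OP_SUM, COMM_WORLD);

    barrier();
    finalize();
    return 0;
}
```
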
@ -30,7 +30,7 @@ inline nvtx3::color nextColor()
        nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}};
    constexpr auto numColors = kColors.size();

    static thread_local int colorId = 0;
    static thread_local std::size_t colorId = 0;
    auto const color = kColors[colorId];
    colorId = colorId + 1 >= numColors ? 0 : colorId + 1;
    return color;

@ -80,7 +80,7 @@ __inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
        val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); // __shfl_sync on bf16 returns float when sm < 80
        val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); // __shfl_sync on bf16 returns float when sm < 80
    return val;
}

@ -361,6 +361,8 @@ FusedMHARunnerV2::FusedMHARunnerV2(
{
}

FusedMHARunnerV2::~FusedMHARunnerV2() = default;

void FusedMHARunnerV2::setup(
    const int b, const int s, const int total_seqlen, const bool has_alibi, const int tp_size, const int tp_rank)
{

@ -78,7 +78,7 @@ class FusedMHARunnerV2 : public MHARunner
public:
    FusedMHARunnerV2(const Data_type dataType, const int numHeads, const int headSize, const float qScaling);

    ~FusedMHARunnerV2() = default; // for pimpl
    ~FusedMHARunnerV2(); // for pimpl

    void setup(const int b, const int s, const int total_seqlen, const bool has_alibi = false, const int tp_size = 1,
        const int tp_rank = 0) override;

453
cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
Normal file
@ -0,0 +1,453 @@
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "customAllReduceKernels.h"
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using tensorrt_llm::common::hadd2;
|
||||
|
||||
static inline __device__ uint32_t myHadd2(const uint32_t& a, const uint32_t& b)
|
||||
{
|
||||
uint32_t c;
|
||||
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
return c;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline __device__ uint32_t fadd(const uint32_t& a, const uint32_t& b)
|
||||
{
|
||||
uint32_t c;
|
||||
asm volatile("add.f32 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
return c;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr)
|
||||
{
|
||||
#if __CUDA_ARCH__ >= 700
|
||||
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
#else
|
||||
__threadfence_system();
|
||||
asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_addr)
|
||||
{
|
||||
#if __CUDA_ARCH__ >= 700
|
||||
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
|
||||
#else
|
||||
asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Type Converter that packs data format to 128 bits data type
|
||||
template <typename T>
|
||||
struct ARTypeConverter
|
||||
{
|
||||
using Type = uint4;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
struct ARTypeConverter<__nv_bfloat16>
|
||||
{
|
||||
using Type = bf168;
|
||||
};
|
||||
#endif
|
||||
|
||||
// add two 128b data
|
||||
template <typename T_IN, typename T_COMP>
|
||||
inline __device__ T_IN add128b(T_IN a, T_IN b);
|
||||
|
||||
template <>
|
||||
inline __device__ uint4 add128b<uint4, uint16_t>(uint4 a, uint4 b)
|
||||
{
|
||||
uint4 c;
|
||||
c.x = myHadd2(a.x, b.x);
|
||||
c.y = myHadd2(a.y, b.y);
|
||||
c.z = myHadd2(a.z, b.z);
|
||||
c.w = myHadd2(a.w, b.w);
|
||||
return c;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ uint4 add128b<uint4, uint32_t>(uint4 a, uint4 b)
|
||||
{
|
||||
uint4 c;
|
||||
c.x = fadd(a.x, b.x);
|
||||
c.y = fadd(a.y, b.y);
|
||||
c.z = fadd(a.z, b.z);
|
||||
c.w = fadd(a.w, b.w);
|
||||
return c;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ bf168 add128b<bf168, __nv_bfloat16>(bf168 a, bf168 b)
|
||||
{
|
||||
bf168 c;
|
||||
c.x = hadd2(a.x, b.x);
|
||||
c.y = hadd2(a.y, b.y);
|
||||
c.z = hadd2(a.z, b.z);
|
||||
c.w = hadd2(a.w, b.w);
|
||||
return c;
|
||||
}
|
||||
#endif
|
||||
|
||||
// init 128bits data with 0
|
||||
template <typename T>
|
||||
inline __device__ T init_packed_type();
|
||||
|
||||
template <>
|
||||
inline __device__ uint4 init_packed_type()
|
||||
{
|
||||
return make_uint4(0u, 0u, 0u, 0u);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ bf168 init_packed_type()
|
||||
{
|
||||
bf168 val;
|
||||
uint4& val_u = reinterpret_cast<uint4&>(val);
|
||||
val_u = make_uint4(0u, 0u, 0u, 0u);
|
||||
return val;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int RANKS_PER_NODE>
|
||||
static __global__ void oneShotAllReduceKernel(AllReduceParams params)
|
||||
{
|
||||
// The block index.
|
||||
const int bidx = blockIdx.x;
|
||||
// The thread index with the block.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
// The number of elements packed into one for comms
|
||||
static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
|
||||
|
||||
// Packed data type for comms
|
||||
using PackedType = typename ARTypeConverter<T>::Type;
|
||||
|
||||
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
|
||||
size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
|
||||
// The end of the segment computed by that block.
|
||||
size_t max_offset = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
|
||||
|
||||
// Synchronize the ranks.
|
||||
volatile uint32_t* barrier_d = params.peer_barrier_ptrs[params.local_rank];
|
||||
if (tidx < RANKS_PER_NODE)
|
||||
{
|
||||
// The 1st block notifies the other ranks.
|
||||
if (bidx == 0)
|
||||
{
|
||||
params.peer_barrier_ptrs[tidx][params.local_rank] = params.barrier_flag;
|
||||
}
|
||||
|
||||
// Busy-wait until all ranks are ready.
|
||||
while (barrier_d[tidx] != params.barrier_flag)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we can move on...
|
||||
__syncthreads();
|
||||
|
||||
// The source pointers. Distributed round-robin for the different warps.
|
||||
const T* src_d[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
|
||||
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
|
||||
}
|
||||
|
||||
// Each block accumulates the values from the different GPUs on the same node.
|
||||
for (size_t iter_offset = offset; iter_offset < max_offset; iter_offset += blockDim.x * NUM_ELTS)
|
||||
{
|
||||
// Iterate over the different ranks/devices on the node to load the values.
|
||||
PackedType vals[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][iter_offset])[0];
|
||||
}
|
||||
|
||||
// Sum the values from the different ranks.
|
||||
PackedType sums = init_packed_type<PackedType>();
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
sums = add128b<PackedType, T>(sums, vals[ii]);
|
||||
}
|
||||
|
||||
// Store to the destination buffer.
|
||||
reinterpret_cast<PackedType*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset])[0] = sums;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int RANKS_PER_NODE>
|
||||
static __global__ void twoShotAllReduceKernel(AllReduceParams params)
|
||||
{
|
||||
|
||||
// The block index.
|
||||
const int bidx = blockIdx.x;
|
||||
// The thread index with the block.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
// The number of elements packed into one for comms
|
||||
static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
|
||||
|
||||
// Packed data type for comms
|
||||
using PackedType = typename ARTypeConverter<T>::Type;
|
||||
|
||||
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
|
||||
const size_t block_start = params.rank_offset;
|
||||
const size_t block_offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
|
||||
// The end of the segment computed by that block.
|
||||
size_t max_offset = min(block_start + block_offset + params.elts_per_block, params.elts_total);
|
||||
|
||||
// Synchronize the ranks.
|
||||
volatile uint32_t* barrier_d = params.peer_barrier_ptrs[params.local_rank];
|
||||
if (tidx < RANKS_PER_NODE)
|
||||
{
|
||||
// The 1st block notifies the other ranks.
|
||||
if (bidx == 0)
|
||||
{
|
||||
params.peer_barrier_ptrs[tidx][params.local_rank] = params.barrier_flag;
|
||||
}
|
||||
|
||||
// Busy-wait until all ranks are ready.
|
||||
while (barrier_d[tidx] != params.barrier_flag)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we can move on...
|
||||
__syncthreads();
|
||||
|
||||
// The source pointers. Distributed round-robin for the different warps.
|
||||
T* src_d[RANKS_PER_NODE];
|
||||
// The destination ranks for round-robin gathering
|
||||
size_t dst_rank[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
|
||||
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
|
||||
dst_rank[ii] = rank;
|
||||
}
|
||||
|
||||
// Each block accumulates the values from the different GPUs on the same node.
|
||||
for (size_t local_offset = block_start + block_offset; local_offset < max_offset;
|
||||
local_offset += blockDim.x * NUM_ELTS)
|
||||
{
|
||||
|
||||
// Iterate over the different ranks/devices on the node to load the values.
|
||||
PackedType vals[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][local_offset])[0];
|
||||
}
|
||||
|
||||
// Sum the values from the different ranks.
|
||||
PackedType sums = init_packed_type<PackedType>();
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
sums = add128b<PackedType, T>(sums, vals[ii]);
|
||||
}
|
||||
|
||||
// Store to the local buffer.
|
||||
reinterpret_cast<PackedType*>(&src_d[0][local_offset])[0] = sums;
|
||||
}
|
||||
|
||||
// sync threads to make sure all block threads have the sums
|
||||
__syncthreads();
|
||||
|
||||
// barriers among the blocks with the same idx (release-acquire semantics)
|
||||
if (tidx < RANKS_PER_NODE)
|
||||
{
|
||||
// The all blocks notifies the other ranks.
|
||||
uint32_t flag_block_offset = RANKS_PER_NODE + bidx * RANKS_PER_NODE;
|
||||
st_flag_release(params.barrier_flag, params.peer_barrier_ptrs[tidx] + flag_block_offset + params.local_rank);
|
||||
|
||||
// Busy-wait until all ranks are ready.
|
||||
uint32_t rank_barrier = 0;
|
||||
uint32_t* peer_barrier_d = params.peer_barrier_ptrs[params.local_rank] + flag_block_offset + tidx;
|
||||
do
|
||||
{
|
||||
ld_flag_acquire(rank_barrier, peer_barrier_d);
|
||||
} while (rank_barrier != params.barrier_flag);
|
||||
}
|
||||
|
||||
// sync threads to make sure all other ranks has the final partial results
|
||||
__syncthreads();
|
||||
|
||||
// Gather all needed elts from other intra-node ranks
|
||||
for (size_t local_offset = block_offset; local_offset < params.elts_per_rank; local_offset += blockDim.x * NUM_ELTS)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
// use round-robin gathering from other ranks
|
||||
size_t offset_rank = dst_rank[ii] * params.elts_per_rank + local_offset;
|
||||
if (offset_rank >= params.elts_total)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
reinterpret_cast<PackedType*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank])[0]
|
||||
= reinterpret_cast<PackedType*>(&src_d[ii][offset_rank])[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void kernelLaunchConfig(int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo,
|
||||
size_t data_type_bytes, int ranks_per_node)
|
||||
{
|
||||
assert(data_type_bytes == 2 || data_type_bytes == 4);
|
||||
size_t elts_per_thread = 16 / data_type_bytes;
|
||||
size_t elts_per_warp = (16 * WARP_SIZE) / data_type_bytes;
|
||||
switch (kernel_algo)
|
||||
{
|
||||
case 0:
|
||||
{ // one stage all reduce algo
|
||||
assert(elts % elts_per_warp == 0);
|
||||
if (elts < (elts_per_thread * DEFAULT_BLOCK_SIZE))
|
||||
{ // local reduce
|
||||
threads_per_block = ((elts + elts_per_warp - 1) / elts_per_warp) * WARP_SIZE;
|
||||
blocks_per_grid = 1;
|
||||
}
|
||||
else
|
||||
{ // local reduce
|
||||
if (elts % (elts_per_thread * threads_per_block) == 0)
|
||||
{
|
||||
blocks_per_grid
|
||||
= (elts + elts_per_thread * threads_per_block - 1) / (elts_per_thread * threads_per_block);
|
||||
// NOTE: need to adjust here
|
||||
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS)
|
||||
{
|
||||
size_t iter_factor = 1;
|
||||
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor)
|
||||
{
|
||||
iter_factor += 1;
|
||||
}
|
||||
blocks_per_grid /= iter_factor;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t total_threads = elts / elts_per_thread;
|
||||
blocks_per_grid = 1;
|
||||
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE)
|
||||
{
|
||||
blocks_per_grid += 1;
|
||||
}
|
||||
threads_per_block = total_threads / blocks_per_grid;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 1:
|
||||
{ // two stage all reduce algo
|
||||
size_t total_threads = elts / ranks_per_node / elts_per_thread;
|
||||
assert(elts / ranks_per_node % elts_per_thread == 0 && total_threads % WARP_SIZE == 0);
|
||||
|
||||
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE)
|
||||
{
|
||||
blocks_per_grid += 1;
|
||||
}
|
||||
|
||||
threads_per_block = total_threads / blocks_per_grid;
|
||||
|
||||
// NOTE: need to adjust here
|
||||
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS)
|
||||
{
|
||||
size_t iter_factor = 1;
|
||||
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor)
|
||||
{
|
||||
iter_factor += 1;
|
||||
}
|
||||
blocks_per_grid /= iter_factor;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(RANKS_PER_NODE) \
|
||||
\
|
||||
if (kernel_algo == 0) \
|
||||
{ \
|
||||
param.elts_per_rank = elts_total; \
|
||||
param.elts_per_block = param.elts_per_rank / blocks_per_grid; \
|
||||
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
param.elts_per_rank = param.elts_total / RANKS_PER_NODE; \
|
||||
param.elts_per_block = param.elts_per_rank / blocks_per_grid; \
|
||||
param.rank_offset = param.rank * param.elts_per_rank; \
|
||||
twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param); \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream)
|
||||
{
|
||||
size_t elts_total = param.elts_total;
|
||||
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
|
||||
int kernel_algo = 1;
|
||||
if (elts_total * sizeof(T) <= DEFAULT_ALGO_AR_SIZE_THRESHOLD)
|
||||
{
|
||||
kernel_algo = 0;
|
||||
}
|
||||
|
||||
kernelLaunchConfig(blocks_per_grid, threads_per_block, elts_total, kernel_algo, sizeof(T), param.ranks_per_node);
|
||||
switch (param.ranks_per_node)
|
||||
{
|
||||
case 2: CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(2); break;
|
||||
case 4: CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(4); break;
|
||||
case 6: CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(6); break;
|
||||
case 8: CUSTOM_ALL_REDUCE_KERNEL_LAUNCH(8); break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
// Template instantiation
|
||||
template void invokeOneOrTwoShotAllReduceKernel<uint16_t>(AllReduceParams& param, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams& param, cudaStream_t stream);
|
||||
#endif
|
||||
template void invokeOneOrTwoShotAllReduceKernel<uint32_t>(AllReduceParams& param, cudaStream_t stream);
|
||||
|
||||
} // namespace tensorrt_llm::kernels
|
||||
73
cpp/tensorrt_llm/kernels/customAllReduceKernels.h
Normal file
@ -0,0 +1,73 @@
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
#define CUSTOM_AR_SIZE_THRESHOLD 50331648
|
||||
#define MAX_ALL_REDUCE_BLOCKS 24
|
||||
#define FLAG(a) ((uint32_t) ((a) % 0x146))
|
||||
#define MAX_RANKS_PER_NODE 8
|
||||
#define WARP_SIZE 32
|
||||
#define DEFAULT_BLOCK_SIZE 1024
|
||||
#define DEFAULT_ALGO_AR_SIZE_THRESHOLD 393216
|
||||
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
typedef struct bf168
|
||||
{
|
||||
__nv_bfloat162 x;
|
||||
__nv_bfloat162 y;
|
||||
__nv_bfloat162 z;
|
||||
__nv_bfloat162 w;
|
||||
} bf168;
|
||||
#endif
|
||||
|
||||
struct AllReduceIpcMemHandles
|
||||
{
|
||||
cudaIpcMemHandle_t peer_barrier_ipc_handles[MAX_RANKS_PER_NODE];
|
||||
cudaIpcMemHandle_t peer_comm_buffer_ipc_handles[MAX_RANKS_PER_NODE];
|
||||
};
|
||||
|
||||
struct AllReduceParams
|
||||
{
|
||||
size_t elts_total;
|
||||
size_t elts_per_rank;
|
||||
size_t elts_per_block;
|
||||
size_t rank_offset;
|
||||
size_t ranks_per_node, rank, local_rank, node_id;
|
||||
uint32_t barrier_flag;
|
||||
uint32_t* peer_barrier_ptrs[MAX_RANKS_PER_NODE];
|
||||
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
|
||||
void* local_output_buffer_ptr;
|
||||
AllReduceIpcMemHandles ipc_mem_handles;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream);
|
||||
|
||||
void kernelLaunchConfig(int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo);
|
||||
|
||||
} // namespace tensorrt_llm::kernels
|
||||
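For orientation, the sketch below shows how a caller might populate `AllReduceParams` from this header and launch the fp16 path. It assumes the peer barrier and communication buffers have already been exchanged between ranks (for example via `cudaIpcGetMemHandle`/`cudaIpcOpenMemHandle`) and that the input tensor has been staged into the communication buffers; all variable and function names in the sketch are hypothetical.

```
#include <cstddef>
#include <cstdint>

#include "tensorrt_llm/kernels/customAllReduceKernels.h"

// Minimal sketch, not library code: fills the fields declared in AllReduceParams
// above and dispatches the fp16 path (which maps onto the uint16_t instantiation).
void launchCustomAllReduceFp16(std::size_t num_elements, std::size_t world_size, std::size_t my_rank,
    uint32_t iteration, uint32_t** barrier_ptrs, void** comm_buffer_ptrs, void* output_ptr, cudaStream_t stream)
{
    using namespace tensorrt_llm::kernels;

    AllReduceParams params{};
    params.elts_total = num_elements;
    params.ranks_per_node = world_size; // single-node case assumed
    params.rank = my_rank;
    params.local_rank = my_rank;
    params.node_id = 0;
    params.barrier_flag = FLAG(iteration); // FLAG() is defined in this header
    for (std::size_t r = 0; r < world_size; ++r)
    {
        params.peer_barrier_ptrs[r] = barrier_ptrs[r];
        params.peer_comm_buffer_ptrs[r] = comm_buffer_ptrs[r];
    }
    params.local_output_buffer_ptr = output_ptr;

    invokeOneOrTwoShotAllReduceKernel<uint16_t>(params, stream);
}
```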
@ -116,7 +116,9 @@ struct Multihead_attention_params_base
    // The per-head latent space reserved for rotary embeddings.
    int rotary_embedding_dim = 0;
    float rotary_embedding_base = 0.0f;
    RotaryScalingType rotary_embedding_scale_type = RotaryScalingType::kNONE;
    float rotary_embedding_scale = 0.0f;
    int rotary_embedding_max_positions = 0;
    // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
    int timestep = 0;
    // The current timestep of each sentences (support different timestep for different sentences)

@ -1165,8 +1165,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
|
||||
zero(k);
|
||||
zero(q_bias);
|
||||
zero(k_bias);
|
||||
float rotary_embedding_base = params.rotary_embedding_base;
|
||||
float rotary_embedding_scale = params.rotary_embedding_scale;
|
||||
if (is_valid_qk_vec)
|
||||
{
|
||||
mmha::update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale,
|
||||
params.rotary_embedding_scale_type, params.rotary_embedding_dim, params.rotary_embedding_max_positions,
|
||||
tlength);
|
||||
// Query
|
||||
// The stride between tokens. We may be able to always use params.stride.
|
||||
uint32_t q_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads * Dh);
|
||||
@ -1280,7 +1285,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
|
||||
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
|
||||
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim,
|
||||
params.rotary_embedding_base, params.rotary_embedding_scale, tlength);
|
||||
rotary_embedding_base, rotary_embedding_scale, tlength);
|
||||
|
||||
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
mmha::write_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include "tensorrt_llm/kernels/gptKernels.h"
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
@ -122,6 +123,12 @@ struct num_elems<Float8_>
|
||||
static constexpr int value = 8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct num_elems<half>
|
||||
{
|
||||
static constexpr int value = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct num_elems<uint32_t>
|
||||
{
|
||||
@ -141,6 +148,12 @@ struct num_elems<uint4>
|
||||
};
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
struct num_elems<__nv_bfloat16>
|
||||
{
|
||||
static constexpr int value = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct num_elems<__nv_bfloat162>
|
||||
{
|
||||
@ -197,6 +210,12 @@ struct packed_type<T, 1>
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<int8_t, 1>
|
||||
{
|
||||
using type = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<int8_t, 2>
|
||||
{
|
||||
@ -216,6 +235,13 @@ struct packed_type<int8_t, 8>
|
||||
};
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
template <>
|
||||
struct packed_type<__nv_fp8_e4m3, 1>
|
||||
{
|
||||
using type = __nv_fp8_e4m3;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_type<__nv_fp8_e4m3, 2>
|
||||
{
|
||||
@ -1508,6 +1534,32 @@ inline __device__ void zero(T& dst)

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float update_rotary_base(
    const int kv_seq_len, const int max_positions, const int embed_dim, const float base, const float scale)
{
    const float b = (scale * kv_seq_len / max_positions) - (scale - 1);
    const float p = static_cast<float>(embed_dim) / (embed_dim - 2);
    return base * pow(b, p);
}

inline __device__ void update_rotary_base_n_scale(float& base, float& scale, RotaryScalingType const scale_type,
    const int rot_embed_dim, const int max_positions, const int seq_len)
{
    // only update the base and/or scale if needed based on scale_type
    if (scale_type == RotaryScalingType::kDYNAMIC)
    {
        if (seq_len > max_positions)
        {
            base = update_rotary_base(seq_len, max_positions, rot_embed_dim, base, scale);
        }
        scale = 1.0f; // scale is only used in base for dynamic scaling
    }
    else if (scale_type == RotaryScalingType::kLINEAR)
    {
        scale = 1.0f / scale;
    }
}

inline __device__ float2 rotary_embedding_coefficient(
    const int zid, const int rot_embed_dim, const float base, const float scale, const float t_step)
{
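Stated as a formula, `update_rotary_base` implements the dynamic rotary-base rescaling that is applied only when the current sequence length exceeds the trained maximum (the `kDYNAMIC` branch). Using the variable names from the code:

```
% theta = rotary_embedding_base,  s = rotary_embedding_scale,
% L     = kv_seq_len,             L_max = rotary_embedding_max_positions,
% d     = rotary_embedding_dim
\theta' = \theta \cdot \left( \frac{s \, L}{L_{\max}} - (s - 1) \right)^{\frac{d}{d - 2}}
```

For linear scaling (`kLINEAR`), the base is left untouched and the scale is simply inverted, matching the `scale = 1.0f / scale` branch.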
@ -2006,6 +2058,14 @@ inline __device__ float2 convert_to_float(uint32_t u)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
inline __device__ float convert_to_float(half u)
|
||||
{
|
||||
return static_cast<float>(u);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
inline __device__ void convert_from_fp8(uint16_t* v, const __nv_fp8_e4m3 u)
|
||||
{
|
||||
@ -2209,6 +2269,13 @@ inline __device__ void convert_to_fp8(fp8_8_t* v, const bf16_8_t u)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const half u)
|
||||
{
|
||||
v[0] = __nv_fp8_e4m3(u);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const uint16_t u)
|
||||
{
|
||||
v[0] = __nv_fp8_e4m3(reinterpret_cast<const half&>(u));
|
||||
|
||||
@ -111,8 +111,7 @@ __global__ void length_criterion(bool* finished, int* finished_sum, const uint32
    {
        const int batch_idx = index / beam_width;

        // sequence_lengths is updated, so need to minus 1
        finished[index] |= sequence_lengths[index] - 1 >= sequence_limit_length[batch_idx];
        finished[index] |= sequence_lengths[index] >= sequence_limit_length[batch_idx];
        thread_finished_count += finished[index] ? 1 : 0;
    }

@ -1252,8 +1252,9 @@ struct Vec_t<__nv_bfloat16>
|
||||
template <typename T, bool ADD_BIAS, bool USING_CONTEXT_FMHA>
|
||||
__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* __restrict qkv_bias,
|
||||
const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
|
||||
const float rotary_embedding_scale, PositionEmbeddingType const position_embedding_type)
|
||||
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, float rotary_embedding_base,
|
||||
RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions,
|
||||
PositionEmbeddingType const position_embedding_type)
|
||||
{
|
||||
// This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and
|
||||
// QKV split to 3 split buffer q, k, v and transpose them to [batch_size, head_num, seq_len, size_per_head].
|
||||
@ -1323,6 +1324,11 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf,
|
||||
|
||||
Vec_t q, k, v, zero;
|
||||
Vec_t q_bias, k_bias, v_bias;
|
||||
if (valid_seq)
|
||||
{
|
||||
mmha::update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale, rotary_scale_type,
|
||||
rotary_embedding_dim, rotary_embedding_max_positions, actual_seq_len);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < sizeof(Vec_t) / sizeof(uint32_t); i++)
|
||||
@ -1456,14 +1462,16 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf,
|
||||
#define FUSED_QKV_BIAS_ROTARY_TRANSPOSE_LAUNCH(T, ADD_BIAS, USING_CONTEXT_FMHA) \
|
||||
add_fusedQKV_bias_transpose_kernel<T, ADD_BIAS, USING_CONTEXT_FMHA><<<grid, block, smem_size, stream>>>(q_buf, \
|
||||
k_buf, v_buf, QKV, qkv_bias, seq_lens, padding_offset, batch_size, seq_len, head_num, kv_head_num, \
|
||||
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, position_embedding_type);
|
||||
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, \
|
||||
rotary_embedding_max_positions, position_embedding_type);
|
||||
|
||||
template <typename T>
|
||||
void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, const int* seq_lens,
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const bool using_context_fmha, const int rotary_embedding_dim,
|
||||
const float rotary_embedding_base, const float rotary_embedding_scale,
|
||||
const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, cudaStream_t stream)
|
||||
const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
|
||||
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
|
||||
const int int8_mode, cudaStream_t stream)
|
||||
{
|
||||
// [bs, seq_len, 3, head, Dh]
|
||||
if (rotary_embedding_dim == 0)
|
||||
@ -1535,7 +1543,8 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const
|
||||
template void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, \
|
||||
const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int token_num, \
|
||||
const int head_num, const int kv_head_num, const int size_per_head, const bool using_context_fmha, \
|
||||
const int rotary_embedding_dim, const float rotary_embedding_base, const float rotary_embedding_scale, \
|
||||
const int rotary_embedding_dim, const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, \
|
||||
const float rotary_embedding_scale, const int rotary_embedding_max_poisitions, \
|
||||
const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, \
|
||||
cudaStream_t stream)
|
||||
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float);
|
||||
|
||||
@ -76,8 +76,9 @@ template <typename T>
|
||||
void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, const int* seq_lens,
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const bool using_context_fmha, const int rotary_embedding_dim,
|
||||
const float rotary_embedding_base, const float rotary_embedding_scale,
|
||||
PositionEmbeddingType const position_embedding_type, const float* scale, const int int8_mode, cudaStream_t stream);
|
||||
float rotary_embedding_base, const RotaryScalingType rotary_scale_type, float rotary_embedding_scale,
|
||||
const int rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, const float* scale,
|
||||
const int int8_mode, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, const int* seq_lens,
|
||||
@ -92,12 +93,14 @@ template <typename T>
|
||||
void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const int* seq_lens,
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const bool using_context_fmha, const int rotary_embedding_dim,
|
||||
const float rotary_embedding_base, const float rotary_embedding_scale,
|
||||
PositionEmbeddingType const position_embedding_type, const float* scale, const int int8_mode, cudaStream_t stream)
|
||||
float rotary_embedding_base, const RotaryScalingType rotary_scale_type, float rotary_embedding_scale,
|
||||
const int rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, const float* scale,
|
||||
const int int8_mode, cudaStream_t stream)
|
||||
{
|
||||
invokeAddFusedQKVBiasTranspose(q_buf, k_buf, v_buf, QKV, (const T*) nullptr, seq_lens, padding_offset, batch_size,
|
||||
seq_len, token_num, head_num, kv_head_num, size_per_head, using_context_fmha, rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_embedding_scale, position_embedding_type, scale, int8_mode, stream);
|
||||
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
|
||||
position_embedding_type, scale, int8_mode, stream);
|
||||
}
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
@ -105,5 +108,13 @@ void invokeTranspose4dBatchMajor(const T* k_src, const T* v_src, KVCacheBuffer&
|
||||
const int seq_len, const int max_seq_len, const int size_per_head, const int local_head_num,
|
||||
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream);
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens,
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
|
||||
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
|
||||
const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
371  cpp/tensorrt_llm/kernels/unfusedAttentionKernels_2.cu  Normal file
@ -0,0 +1,371 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Separate from unfusedAttentionKernel to accelerate compiling.
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h"
|
||||
#include "tensorrt_llm/kernels/gptKernels.h"
|
||||
#include "tensorrt_llm/kernels/kvCacheUtils.h"
|
||||
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct Vec_t
|
||||
{
|
||||
static constexpr int size = 0;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Vec_t<float>
|
||||
{
|
||||
using Type = float2;
|
||||
static constexpr int size = 2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Vec_t<half>
|
||||
{
|
||||
using Type = uint32_t;
|
||||
static constexpr int size = 2;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
struct Vec_t<__nv_bfloat16>
|
||||
{
|
||||
using Type = __nv_bfloat162;
|
||||
static constexpr int size = 2;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T, typename T_cache, bool ADD_BIAS, typename KVCacheBuffer>
|
||||
__global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer, const T* __restrict qkv_bias,
|
||||
const int* seq_lens, const int* padding_offset, const float* kvScaleOrigQuant, const int batch_size,
|
||||
const int seq_len, const int head_num, const int kv_head_num, const int size_per_head,
|
||||
const int rotary_embedding_dim, float rotary_embedding_base, RotaryScalingType const rotary_scale_type,
|
||||
float rotary_embedding_scale, const int rotary_embedding_max_positions,
|
||||
PositionEmbeddingType const position_embedding_type)
|
||||
{
|
||||
// This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and
|
||||
// QKV split to 3 split buffer q, k, v and transpose them to [batch_size, head_num, seq_len, size_per_head].
|
||||
// For q and k, also apply the rotary embedding.
|
||||
|
||||
// NOTE:
|
||||
// head_num == kv_head_num
|
||||
// QKV src shape (batch_size, seq_len, 3, head_num, size_per_head)
|
||||
// ^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
// m n
|
||||
// QKV dst shape (3, batch_size, head_num, seq_len, size_per_head)
|
||||
// head_num != kv_head_num
|
||||
// QKV src shape: (batch_size, seq_len, head_num * size_per_head + 2 * kv_head_num * size_per_head)
|
||||
// ^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
// m n
|
||||
// Q dst shape: (batch_size, head_num, seq_len, size_per_head)
|
||||
// KV dst shape: (batch_size, kv_head_num, seq_len, size_per_head)
|
||||
extern __shared__ __align__(sizeof(float2)) char smem_[]; // align on largest vector type
|
||||
|
||||
constexpr int vec_size = Vec_t<T>::size;
|
||||
using Vec_t = typename Vec_t<T>::Type;
|
||||
const int token_idx = blockIdx.x;
|
||||
const bool has_padding = padding_offset == nullptr;
|
||||
|
||||
constexpr bool ENABLE_8BITS_CACHE = sizeof(T_cache) == 1;
|
||||
constexpr int X_ELEMS = vec_size;
|
||||
const int sizePerHeadDivX = size_per_head / X_ELEMS;
|
||||
using T_dst = T_cache;
|
||||
|
||||
// The index of the token in the batch. It includes "virtual" padding (even if the input is not padded)
|
||||
// such that the sequence index and the position in the sequence can be obtained using the max.
|
||||
// sequence length as:
|
||||
const int token_padding_offset = has_padding ? 0 : padding_offset[token_idx];
|
||||
const int global_token_idx = token_idx + token_padding_offset;
|
||||
const int batch_idx = global_token_idx / seq_len;
|
||||
const int token_idx_in_seq = global_token_idx % seq_len;
|
||||
const int actual_seq_len = seq_lens[batch_idx];
|
||||
const bool valid_seq = token_idx_in_seq < actual_seq_len || !has_padding;
|
||||
|
||||
const int head_idx = blockIdx.y;
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const bool is_seq_masked = !valid_seq;
|
||||
const bool is_head_size_masked = tidx * vec_size >= size_per_head;
|
||||
const bool is_masked = is_head_size_masked || is_seq_masked;
|
||||
|
||||
const int hidden_size = head_num * size_per_head;
|
||||
const int hidden_idx = head_idx * size_per_head + tidx * vec_size;
|
||||
const int qheads_per_kv_head = head_num / kv_head_num;
|
||||
const int kv_head_idx = head_idx / qheads_per_kv_head;
|
||||
const int hidden_idx_kv = kv_head_idx * size_per_head + tidx * vec_size;
|
||||
const int n = (head_num + 2 * kv_head_num) * size_per_head;
|
||||
|
||||
const int dst_kv_seq_idx = token_idx_in_seq;
|
||||
const int src_k_offset = hidden_size;
|
||||
const int src_v_offset = hidden_size + kv_head_num * size_per_head;
|
||||
|
||||
// NOTE: q has seq len excluding prefix prompt
|
||||
// head_num == kv_head_num:
|
||||
// src QKV: [batch, time, 3, head_num, size_per_head]
|
||||
// head_num != kv_head_num:
|
||||
// src QKV: [batch, time, head_num * size_per_head + 2 * kv_head_num * size_per_head]
|
||||
const int src_q_idx = token_idx * n + hidden_idx;
|
||||
const int src_k_idx = token_idx * n + src_k_offset + hidden_idx_kv;
|
||||
const int src_v_idx = token_idx * n + src_v_offset + hidden_idx_kv;
|
||||
|
||||
Vec_t q, k, v, zero;
|
||||
Vec_t q_bias, k_bias, v_bias;
|
||||
if (valid_seq)
|
||||
{
|
||||
mmha::update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale, rotary_scale_type,
|
||||
rotary_embedding_dim, rotary_embedding_max_positions, actual_seq_len);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < sizeof(Vec_t) / sizeof(uint32_t); i++)
|
||||
{
|
||||
reinterpret_cast<uint32_t*>(&zero)[i] = 0u;
|
||||
}
|
||||
|
||||
// load q,k,v and add bias
|
||||
if (!is_masked)
|
||||
{
|
||||
q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
|
||||
k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
|
||||
v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
|
||||
|
||||
if (ADD_BIAS)
|
||||
{
|
||||
q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
|
||||
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx_kv + src_k_offset]);
|
||||
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx_kv + src_v_offset]);
|
||||
|
||||
q = mmha::add(q, q_bias);
|
||||
k = mmha::add(k, k_bias);
|
||||
v = mmha::add(v, v_bias);
|
||||
}
|
||||
}
|
||||
|
||||
switch (position_embedding_type)
|
||||
{
|
||||
case PositionEmbeddingType::kROPE_GPTJ:
|
||||
{
|
||||
mmha::apply_rotary_embedding(
|
||||
q, k, tidx, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, dst_kv_seq_idx);
|
||||
break;
|
||||
}
|
||||
case PositionEmbeddingType::kROPE_GPT_NEOX:
|
||||
{
|
||||
const bool do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim;
|
||||
|
||||
T* q_smem = reinterpret_cast<T*>(smem_);
|
||||
T* k_smem = q_smem + rotary_embedding_dim;
|
||||
|
||||
const int half_rotary_dim = rotary_embedding_dim / 2;
|
||||
const int half_idx = (tidx * vec_size) / half_rotary_dim;
|
||||
const int intra_half_idx = (tidx * vec_size) % half_rotary_dim;
|
||||
const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts?
|
||||
|
||||
if (do_rotary)
|
||||
{
|
||||
*reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
|
||||
*reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
|
||||
constexpr int tidx_factor = vec_size / 2;
|
||||
if (do_rotary)
|
||||
{
|
||||
mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
||||
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
|
||||
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, rotary_embedding_dim, rotary_embedding_base,
|
||||
rotary_embedding_scale, dst_kv_seq_idx);
|
||||
|
||||
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
||||
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (do_rotary)
|
||||
{
|
||||
q = *reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx);
|
||||
k = *reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const int channelIdx{tidx};
|
||||
auto kDst = reinterpret_cast<T_dst*>(kvCacheBuffer.getKBlockPtr(batch_idx, token_idx_in_seq));
|
||||
auto vDst = reinterpret_cast<T_dst*>(kvCacheBuffer.getVBlockPtr(batch_idx, token_idx_in_seq));
|
||||
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(token_idx_in_seq, kv_head_idx, sizePerHeadDivX, channelIdx);
|
||||
if (!is_masked)
|
||||
{
|
||||
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q;
|
||||
if ((head_num == kv_head_num) || (head_idx == (kv_head_idx * qheads_per_kv_head)))
|
||||
{
|
||||
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
|
||||
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;
|
||||
|
||||
if (ENABLE_8BITS_CACHE)
|
||||
{
|
||||
inBlockIdx = inBlockIdx * vec_size;
|
||||
// Cast float scale to dst data type.
|
||||
using T_scale = typename mmha::kv_cache_scale_type_t<T, T_cache>::Type;
|
||||
T_scale scaleOrigQuant;
|
||||
mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]);
|
||||
// Store 8bits kv cache.
|
||||
mmha::store_8bits_kv_cache_vec(kDst, k, inBlockIdx, scaleOrigQuant);
|
||||
mmha::store_8bits_kv_cache_vec(vDst, v, inBlockIdx, scaleOrigQuant);
|
||||
}
|
||||
else
|
||||
{
|
||||
reinterpret_cast<Vec_t*>(kDst)[inBlockIdx] = k;
|
||||
reinterpret_cast<Vec_t*>(vDst)[inBlockIdx] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (is_seq_masked && !is_head_size_masked)
|
||||
{
|
||||
// Set padding to zero in case of potential nan generated.
|
||||
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = zero;
|
||||
if ((head_num == kv_head_num) || (head_idx == (kv_head_idx * qheads_per_kv_head)))
|
||||
{
|
||||
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = zero;
|
||||
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = zero;
|
||||
|
||||
if (ENABLE_8BITS_CACHE)
|
||||
{
|
||||
inBlockIdx = inBlockIdx * vec_size;
|
||||
// Cast float scale to dst data type.
|
||||
using T_scale = typename mmha::kv_cache_scale_type_t<T, T_cache>::Type;
|
||||
T_scale scaleOrigQuant;
|
||||
mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]);
|
||||
// Store 8bits kv cache.
|
||||
mmha::store_8bits_kv_cache_vec(kDst, zero, inBlockIdx, scaleOrigQuant);
|
||||
mmha::store_8bits_kv_cache_vec(vDst, zero, inBlockIdx, scaleOrigQuant);
|
||||
}
|
||||
else
|
||||
{
|
||||
reinterpret_cast<Vec_t*>(kDst)[inBlockIdx] = zero;
|
||||
reinterpret_cast<Vec_t*>(vDst)[inBlockIdx] = zero;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename T_cache, typename KVCacheBuffer>
|
||||
void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens,
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
|
||||
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
|
||||
const float* kvScaleOrigQuant, const int int8_mode, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with RoPE"); // TODO(mseznec)
|
||||
// To implement rotary embeddings, each thread processes two QKV elems:
|
||||
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
|
||||
dim3 grid(token_num, head_num);
|
||||
size_t smem_size
|
||||
= (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX ? 2 * rotary_embedding_dim * sizeof(T) : 0);
|
||||
// NOTE: add offset for rotary embedding
|
||||
if (qkv_bias != nullptr)
|
||||
{
|
||||
applyBiasRopeUpdateKVCache<T, T_cache, true, KVCacheBuffer><<<grid, block, smem_size, stream>>>(QKV, kvTable,
|
||||
qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, head_num, kv_head_num,
|
||||
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_max_positions, position_embedding_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
applyBiasRopeUpdateKVCache<T, T_cache, false, KVCacheBuffer><<<grid, block, smem_size, stream>>>(QKV, kvTable,
|
||||
qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, head_num, kv_head_num,
|
||||
size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_max_positions, position_embedding_type);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens,
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num,
|
||||
const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale,
|
||||
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale,
|
||||
const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream)
|
||||
{
|
||||
// Block handles both K and V tile.
|
||||
constexpr int x = (sizeof(T) == 4) ? 4 : 8;
|
||||
TLLM_CHECK_WITH_INFO(size_per_head % x == 0, "Size per head is not a multiple of X");
|
||||
|
||||
if (cache_type == KvCacheDataType::INT8)
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, int8_t, KVCacheBuffer>(QKV, kvTable, qkv_bias, seq_lens,
|
||||
padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
|
||||
position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
|
||||
}
|
||||
#ifdef ENABLE_FP8
|
||||
else if (cache_type == KvCacheDataType::FP8)
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, __nv_fp8_e4m3, KVCacheBuffer>(QKV, kvTable, qkv_bias, seq_lens,
|
||||
padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
|
||||
position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
else
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, T, KVCacheBuffer>(QKV, kvTable, qkv_bias, seq_lens, padding_offset,
|
||||
batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions,
|
||||
position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(T, KVCacheBuffer) \
|
||||
template void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, \
|
||||
const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int token_num, \
|
||||
const int head_num, const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, \
|
||||
const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, \
|
||||
const float rotary_embedding_scale, const int rotary_embedding_max_positions, \
|
||||
const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, \
|
||||
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream)
|
||||
|
||||
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float, KVBlockArray);
|
||||
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float, KVLinearBuffer);
|
||||
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(half, KVBlockArray);
|
||||
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(half, KVLinearBuffer);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(__nv_bfloat16, KVBlockArray);
|
||||
INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(__nv_bfloat16, KVLinearBuffer);
|
||||
#endif
|
||||
#undef INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -65,8 +65,8 @@ void update_indir_cache_kernelLauncher(int* tgt_indir_cache, const int* src_indi
|
||||
|
||||
template <typename T>
|
||||
BaseBeamSearchLayer<T>::BaseBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward)
|
||||
: BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr)
|
||||
IAllocator* allocator, bool is_free_buffer_after_forward)
|
||||
: BaseLayer(stream, allocator, is_free_buffer_after_forward, nullptr)
|
||||
, vocab_size_(vocab_size)
|
||||
, vocab_size_padded_(vocab_size_padded)
|
||||
{
|
||||
|
||||
@ -42,8 +42,8 @@ class BaseBeamSearchLayer : public BaseLayer
|
||||
public:
|
||||
using SetupParams = DecodingSetupParams;
|
||||
|
||||
BaseBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
tc::cublasMMWrapper* cublas_wrapper, tc::IAllocator* allocator, bool is_free_buffer_after_forward);
|
||||
BaseBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, tc::IAllocator* allocator,
|
||||
bool is_free_buffer_after_forward);
|
||||
|
||||
BaseBeamSearchLayer(BaseBeamSearchLayer<T> const& beam_search_layer);
|
||||
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/allocator.h"
|
||||
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
||||
#include "tensorrt_llm/common/tensor.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
@ -28,11 +27,9 @@ namespace layers
|
||||
class BaseLayer
|
||||
{
|
||||
public:
|
||||
BaseLayer(cudaStream_t stream, tensorrt_llm::common::cublasMMWrapper* cublas_wrapper,
|
||||
tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward,
|
||||
BaseLayer(cudaStream_t stream, tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward,
|
||||
cudaDeviceProp* cuda_device_prop = nullptr)
|
||||
: stream_(stream)
|
||||
, cublas_wrapper_(cublas_wrapper)
|
||||
, allocator_(allocator)
|
||||
, cuda_device_prop_(cuda_device_prop)
|
||||
, is_free_buffer_after_forward_(is_free_buffer_after_forward){};
|
||||
@ -51,7 +48,6 @@ public:
|
||||
protected:
|
||||
// device environments
|
||||
cudaStream_t stream_;
|
||||
tensorrt_llm::common::cublasMMWrapper* cublas_wrapper_;
|
||||
tensorrt_llm::common::IAllocator* allocator_;
|
||||
cudaDeviceProp* cuda_device_prop_ = nullptr;
|
||||
|
||||
|
||||
@ -69,9 +69,8 @@ void BaseSamplingLayer<T>::freeBuffer()
|
||||
|
||||
template <typename T>
|
||||
BaseSamplingLayer<T>::BaseSamplingLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward,
|
||||
cudaDeviceProp* cuda_device_prop)
|
||||
: BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop)
|
||||
IAllocator* allocator, bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop)
|
||||
: BaseLayer(stream, allocator, is_free_buffer_after_forward, cuda_device_prop)
|
||||
, vocab_size_(vocab_size)
|
||||
, vocab_size_padded_(vocab_size_padded)
|
||||
{
|
||||
|
||||
@ -36,8 +36,8 @@ class BaseSamplingLayer : public BaseLayer
|
||||
{
|
||||
public:
|
||||
BaseSamplingLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
tensorrt_llm::common::cublasMMWrapper* cublas_wrapper, tensorrt_llm::common::IAllocator* allocator,
|
||||
bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop);
|
||||
tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward,
|
||||
cudaDeviceProp* cuda_device_prop);
|
||||
|
||||
BaseSamplingLayer(BaseSamplingLayer const& sampling_layer);
|
||||
|
||||
|
||||
@ -37,20 +37,18 @@ void DynamicDecodeLayer<T>::initialize()
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
mOnlineBeamsearchDecode = std::make_unique<OnlineBeamSearchLayer<T>>(
|
||||
vocab_size_, vocab_size_padded_, stream_, cublas_wrapper_, allocator_, is_free_buffer_after_forward_);
|
||||
vocab_size_, vocab_size_padded_, stream_, allocator_, is_free_buffer_after_forward_);
|
||||
|
||||
mTopKDecode = std::make_unique<TopKSamplingLayer<T>>(
|
||||
vocab_size_, vocab_size_padded_, stream_, cublas_wrapper_, allocator_, false);
|
||||
mTopKDecode = std::make_unique<TopKSamplingLayer<T>>(vocab_size_, vocab_size_padded_, stream_, allocator_, false);
|
||||
|
||||
mTopPDecode = std::make_unique<TopPSamplingLayer<T>>(
|
||||
vocab_size_, vocab_size_padded_, stream_, cublas_wrapper_, allocator_, false, cuda_device_prop_);
|
||||
vocab_size_, vocab_size_padded_, stream_, allocator_, false, cuda_device_prop_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DynamicDecodeLayer<T>::DynamicDecodeLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward,
|
||||
cudaDeviceProp* cuda_device_prop)
|
||||
: BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward)
|
||||
IAllocator* allocator, bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop)
|
||||
: BaseLayer(stream, allocator, is_free_buffer_after_forward)
|
||||
, vocab_size_(vocab_size)
|
||||
, vocab_size_padded_(vocab_size_padded)
|
||||
, cuda_device_prop_(cuda_device_prop)
|
||||
|
||||
@ -43,9 +43,8 @@ template <typename T>
|
||||
class DynamicDecodeLayer : public BaseLayer
|
||||
{
|
||||
public:
|
||||
DynamicDecodeLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
tc::cublasMMWrapper* cublas_wrapper, tc::IAllocator* allocator, bool is_free_buffer_after_forward,
|
||||
cudaDeviceProp* cuda_device_prop);
|
||||
DynamicDecodeLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, tc::IAllocator* allocator,
|
||||
bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop);
|
||||
|
||||
~DynamicDecodeLayer() override;
|
||||
DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_decode_layer);
|
||||
|
||||
@ -154,9 +154,8 @@ void OnlineBeamSearchLayer<T>::allocateBuffer(size_t batch_size, size_t beam_wid
|
||||
|
||||
template <typename T>
|
||||
OnlineBeamSearchLayer<T>::OnlineBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward)
|
||||
: BaseBeamSearchLayer<T>(
|
||||
vocab_size, vocab_size_padded, stream, cublas_wrapper, allocator, is_free_buffer_after_forward)
|
||||
IAllocator* allocator, bool is_free_buffer_after_forward)
|
||||
: BaseBeamSearchLayer<T>(vocab_size, vocab_size_padded, stream, allocator, is_free_buffer_after_forward)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@ -42,8 +42,8 @@ public:
|
||||
std::optional<float> length_penalty;
|
||||
};
|
||||
|
||||
OnlineBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
tc::cublasMMWrapper* cublas_wrapper, tc::IAllocator* allocator, bool is_free_buffer_after_forward);
|
||||
OnlineBeamSearchLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, tc::IAllocator* allocator,
|
||||
bool is_free_buffer_after_forward);
|
||||
|
||||
OnlineBeamSearchLayer(OnlineBeamSearchLayer<T> const& beam_search_layer);
|
||||
|
||||
|
||||
@ -203,9 +203,8 @@ void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
|
||||
|
||||
template <typename T>
|
||||
TopKSamplingLayer<T>::TopKSamplingLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward)
|
||||
: BaseSamplingLayer<T>(
|
||||
vocab_size, vocab_size_padded, stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr)
|
||||
IAllocator* allocator, bool is_free_buffer_after_forward)
|
||||
: BaseSamplingLayer<T>(vocab_size, vocab_size_padded, stream, allocator, is_free_buffer_after_forward, nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@ -34,8 +34,7 @@ public:
|
||||
using SetupParams = typename Base::SetupParams;
|
||||
|
||||
TopKSamplingLayer(size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
tensorrt_llm::common::cublasMMWrapper* cublas_wrapper, tensorrt_llm::common::IAllocator* allocator,
|
||||
bool is_free_buffer_after_forward);
|
||||
tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward);
|
||||
TopKSamplingLayer(TopKSamplingLayer<T> const& top_k_sampling_layer);
|
||||
~TopKSamplingLayer();
|
||||
|
||||
|
||||
@ -266,10 +266,9 @@ void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
|
||||
|
||||
template <typename T>
|
||||
TopPSamplingLayer<T>::TopPSamplingLayer(std::size_t vocab_size, std::size_t vocab_size_padded, cudaStream_t stream,
|
||||
cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward,
|
||||
cudaDeviceProp* cuda_device_prop)
|
||||
: BaseSamplingLayer<T>(vocab_size, vocab_size_padded, stream, cublas_wrapper, allocator,
|
||||
is_free_buffer_after_forward, cuda_device_prop)
|
||||
IAllocator* allocator, bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop)
|
||||
: BaseSamplingLayer<T>(
|
||||
vocab_size, vocab_size_padded, stream, allocator, is_free_buffer_after_forward, cuda_device_prop)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@ -42,8 +42,8 @@ public:
|
||||
};
|
||||
|
||||
TopPSamplingLayer(std::size_t vocab_size, std::size_t vocab_size_padded, cudaStream_t stream,
|
||||
tensorrt_llm::common::cublasMMWrapper* cublas_wrapper, tensorrt_llm::common::IAllocator* allocator,
|
||||
bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop);
|
||||
tensorrt_llm::common::IAllocator* allocator, bool is_free_buffer_after_forward,
|
||||
cudaDeviceProp* cuda_device_prop);
|
||||
TopPSamplingLayer(TopPSamplingLayer<T> const& top_p_sampling_layer);
|
||||
~TopPSamplingLayer();
|
||||
|
||||
|
||||
@ -19,7 +19,8 @@ set(PLUGIN_TARGET_NAME nvinfer_plugin_tensorrt_llm)
set(PLUGIN_SHARED_TARGET ${PLUGIN_TARGET_NAME})

set(TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PLUGIN_EXPORT_MAP ${TARGET_DIR}/exports.map)
set(PLUGIN_EXPORT_MAP ${TARGET_DIR}/exports.map) # Linux
set(PLUGIN_EXPORT_DEF ${TARGET_DIR}/exports.def) # Windows

if(${CMAKE_BUILD_TYPE} MATCHES "Debug")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
@ -89,7 +90,11 @@ set_target_properties(
    LIBRARY_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
    RUNTIME_OUTPUT_DIRECTORY "${TRT_OUT_DIR}")

if(NOT MSVC)
if(MSVC)
  set_target_properties(
    ${PLUGIN_SHARED_TARGET}
    PROPERTIES LINK_FLAGS "/DEF:${PLUGIN_EXPORT_DEF} ${UNDEFINED_FLAG}")
else()
  set_target_properties(
    ${PLUGIN_SHARED_TARGET}
    PROPERTIES

@ -225,7 +225,8 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
|
||||
// Padding offset = nullptr here (remove padding is not supported).
|
||||
invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, const_cast<T*>(attention_input), input_lengths,
|
||||
nullptr, request_batch_size, request_seq_len, batch_size * input_seq_len, mNumHeads, mNumHeads, mHeadSize,
|
||||
mEnableContextFMHA, 0, 0.0f, 0.0f, PositionEmbeddingType::kLEARNED_ABSOLUTE, (float*) nullptr, 0, stream);
|
||||
mEnableContextFMHA, 0, 0.0f, RotaryScalingType::kNONE, 0.0f, 0, PositionEmbeddingType::kLEARNED_ABSOLUTE,
|
||||
(float*) nullptr, 0, stream);
|
||||
|
||||
const auto gemm_data_type = tc::CudaDataType<T>::value;
|
||||
const int attention_seq_len_1 = request_seq_len; // q length
|
||||
@ -363,13 +364,10 @@ int BertAttentionPlugin::initialize() noexcept
|
||||
{
|
||||
auto cublasHandle = getCublasHandle();
|
||||
auto cublasLtHandle = getCublasLtHandle();
|
||||
mCublasAlgoMap = new tc::cublasAlgoMap(GEMM_CONFIG);
|
||||
mCublasWrapperMutex = new std::mutex();
|
||||
mCublasWrapper
|
||||
= new tc::cublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, mCublasAlgoMap, mCublasWrapperMutex, nullptr);
|
||||
mCublasWrapper.reset(new tc::CublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, nullptr));
|
||||
if (mEnableContextFMHA)
|
||||
{
|
||||
mFMHARunner = new FusedMHARunnerV2(DATA_TYPE_FP16, mNumHeads, mHeadSize, mQScaling);
|
||||
mFMHARunner.reset(new FusedMHARunnerV2(DATA_TYPE_FP16, mNumHeads, mHeadSize, mQScaling));
|
||||
// set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads = num_heads
|
||||
mFMHARunner->setup_flags(mFMHAForceFP32Acc, true, false, mNumHeads);
|
||||
}
|
||||
@ -379,18 +377,6 @@ int BertAttentionPlugin::initialize() noexcept
|
||||
|
||||
void BertAttentionPlugin::destroy() noexcept
|
||||
{
|
||||
delete mCublasAlgoMap;
|
||||
delete mCublasWrapperMutex;
|
||||
delete mCublasWrapper;
|
||||
if (mEnableContextFMHA)
|
||||
{
|
||||
delete mFMHARunner;
|
||||
}
|
||||
|
||||
mCublasAlgoMap = nullptr;
|
||||
mCublasWrapperMutex = nullptr;
|
||||
mCublasWrapper = nullptr;
|
||||
mFMHARunner = nullptr;
|
||||
delete this;
|
||||
}
|
||||
|
||||
|
||||
@ -88,11 +88,10 @@ private:
|
||||
bool mEnableContextFMHA = false;
|
||||
bool mFMHAForceFP32Acc = false;
|
||||
bool mSM = tensorrt_llm::common::getSMVersion();
|
||||
tensorrt_llm::kernels::MHARunner* mFMHARunner;
|
||||
|
||||
tensorrt_llm::common::cublasAlgoMap* mCublasAlgoMap;
|
||||
std::mutex* mCublasWrapperMutex;
|
||||
tensorrt_llm::common::cublasMMWrapper* mCublasWrapper;
|
||||
// The default copy constructor will leave them as nullptr. clone() shall initialize it.
|
||||
UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
|
||||
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
|
||||
};
|
||||
|
||||
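The `UniqPtrWNullCopy` members above replace the raw pointers that `destroy()` used to delete by hand. Going by the comment in this header, it behaves like a `std::unique_ptr` whose copy constructor leaves the copy empty so that `clone()` can re-initialize it. A minimal sketch of such a wrapper is shown below; the real definition lives elsewhere in the plugin headers and may differ, which is why the name carries a `Sketch` suffix.

```
#include <memory>

// Sketch only: a unique_ptr-like holder whose "copy" is always empty, so a
// copied plugin starts with null members and clone() re-creates them.
template <typename T, typename TDeleter = std::default_delete<T>>
struct UniqPtrWNullCopySketch : std::unique_ptr<T, TDeleter>
{
    using std::unique_ptr<T, TDeleter>::unique_ptr;

    UniqPtrWNullCopySketch() = default;

    // Copying intentionally drops the payload.
    UniqPtrWNullCopySketch(UniqPtrWNullCopySketch const&)
        : std::unique_ptr<T, TDeleter>(nullptr)
    {
    }
};
```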
class BertAttentionPluginCreator : public BaseCreator
|
||||
|
||||
282  cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp  Normal file
@ -0,0 +1,282 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/plugins/common/gemmPluginProfiler.h"
|
||||
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h"
|
||||
|
||||
namespace tensorrt_llm::plugins
|
||||
{
|
||||
|
||||
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
|
||||
GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::GemmPluginProfiler()
|
||||
{
|
||||
mMNKProfileMap = std::make_shared<MNKProfileMap>();
|
||||
|
||||
// set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings
|
||||
const auto skipEnv = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS");
|
||||
mSkip = (skipEnv != NULL && std::stoi(skipEnv));
|
||||
if (mSkip)
|
||||
{
|
||||
TLLM_LOG_DEBUG(
|
||||
"SKIP_GEMM_PLUGIN_PROFILINGS is set. Skipping GEMM plugin profilings. It could result in runtime error "
|
||||
"if default tactic is not defined.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
|
||||
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::serialize(
|
||||
char*& buffer, const GemmIdType& gemmId) const
|
||||
{
|
||||
auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId);
|
||||
|
||||
// Save number of profiles for given GEMM ID
|
||||
write(buffer, static_cast<int>(mProfileMap->size()));
|
||||
for (const auto& pair : *mProfileMap)
|
||||
{
|
||||
// Save pair of M to the best GEMM config
|
||||
write(buffer, pair);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
|
||||
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::deserialize(
|
||||
const char*& data, GemmDims& dims, const GemmIdType& gemmId)
|
||||
{
|
||||
// NOTE(nkorobov): this mutex is not needed since each thread owns its private map, but will put here for
|
||||
// consistency
|
||||
writer_lock lock(mMNKProfileMap->mutex);
|
||||
|
||||
mDims = dims;
|
||||
|
||||
// GemmId gemmId(dims.n, dims.k);
|
||||
if (!mMNKProfileMap->existsMProfileMap(gemmId))
|
||||
{
|
||||
// Create GEMM with GEMM ID if it does not exist
|
||||
mMNKProfileMap->createMProfileMap(gemmId);
|
||||
}
|
||||
// Populate map with profiles of GEMM ID
|
||||
auto profileMap = mMNKProfileMap->getMProfileMap(gemmId);
|
||||
int selectedMapSize;
|
||||
read(data, selectedMapSize);
|
||||
for (int ii = 0; ii < selectedMapSize; ++ii)
|
||||
{
|
||||
std::pair<int, std::optional<Config>> config;
|
||||
read(data, config);
|
||||
profileMap->insert(config);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
|
||||
size_t GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::getSerializationSize(
|
||||
const GemmIdType& gemmId) const
|
||||
{
|
||||
reader_lock lock(mMNKProfileMap->mutex);
|
||||
return sizeof(int) + // size of the tactics map
|
||||
mMNKProfileMap->getMProfileMap(gemmId)->size()
|
||||
* sizeof(std::pair<int, std::optional<Config>>); // size of the tactics map
|
||||
}
|
||||
|
||||
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
|
||||
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileTactics(
|
||||
const RunnerPtr& runner, const nvinfer1::DataType& type, const GemmDims& dims, const GemmIdType& gemmId)
|
||||
{
|
||||
writer_lock lock(mMNKProfileMap->mutex);
|
||||
|
||||
if (!dims.isInitialized())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
mRunner = runner;
|
||||
mType = type;
|
||||
|
||||
const int maxM = std::min(nextPowerOfTwo(dims.maxM), MAX_PROFILE_M);
|
||||
computeTmpSize(maxM, dims.n, dims.k);
|
||||
|
||||
if (!mMNKProfileMap->existsMProfileMap(gemmId))
|
||||
{
|
||||
// Create map for GEMM ID
|
||||
mMNKProfileMap->createMProfileMap(gemmId);
|
||||
}
|
||||
|
||||
if (mSkip)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId);
|
||||
|
||||
auto profileTactics = [&mProfileMap, this](int m, int n, int k)
|
||||
{
|
||||
if (mProfileMap->count(m) == 0)
|
||||
{
|
||||
const auto tactics = this->getTactics(m, n, k);
|
||||
// Profile different tactics for particular m and insert best config to the map
|
||||
mProfileMap->insert({m, this->profileTacticsForProblem(m, n, k, tactics)});
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate tmp data to run GEMMs
|
||||
allocateTmpData();
|
||||
const int startMinMRounded = nextPowerOfTwo(dims.minM);
|
||||
for (int m = startMinMRounded; m < maxM; m *= 2)
|
||||
{
|
||||
profileTactics(m, dims.n, dims.k);
|
||||
}
|
||||
|
||||
profileTactics(maxM, dims.n, dims.k);
|
||||
// Free tmp data
|
||||
freeTmpData();
|
||||
}
|
||||
|
||||
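The profiling loop above benchmarks only power-of-two values of M, starting at `nextPowerOfTwo(dims.minM)` and capped at `MAX_PROFILE_M`, and `getBestConfig()` later rounds the runtime M the same way before the lookup. One plausible shape for the `nextPowerOfTwo` helper is sketched below; the actual helper is defined elsewhere in the plugin sources and may differ.

```
// Sketch of a bit-twiddling next-power-of-two helper (assumes v > 0 and fits in int).
inline int nextPowerOfTwo(int v)
{
    --v;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    return ++v;
}

// With dims.minM = 3 and dims.maxM = 100, the loop above profiles M = 4, 8, 16,
// 32, 64 and then maxM rounded up to 128; at runtime the same rounding maps any
// M in (64, 128] onto the entry profiled for 128.
```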
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
|
||||
std::optional<Config> GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::getBestConfig(
|
||||
int m, const GemmIdType& gemmId) const
|
||||
{
|
||||
reader_lock lock(mMNKProfileMap->mutex);
|
||||
|
||||
if (mSkip)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const int mRounded = std::min(nextPowerOfTwo(m), MAX_PROFILE_M);
|
||||
    return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded);
}

template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::allocateTmpData()
{
    TLLM_CHECK_WITH_INFO(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0");
    const auto status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes);
    TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling.");
}

template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::freeTmpData()
{
    const auto status = cudaFree(mWorkspaceTmp);
    TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling.");
}

template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
std::optional<Config> GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileTacticsForProblem(
    int m, int n, int k, const std::vector<Config>& tactics)
{
    TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

    float bestTime = std::numeric_limits<float>::max();
    Config bestConfig;
    bool foundOne = false;

    // Iterate over all tactics for given M, N and K
    for (int ii = 0; ii < tactics.size(); ++ii)
    {
        const Config& candidateConfig = tactics[ii];
        float time = std::numeric_limits<float>::max();
        try
        {
            if (!checkTactic(m, n, k, candidateConfig))
            {
                continue;
            }
            // Profile particular tactic for given M, N and K
            time = profileTacticForProblem(m, n, k, candidateConfig);
            foundOne = true;
        }
        catch (const std::exception& e)
        {
            std::ostringstream msg;
            msg << "Cannot profile configuration " << ii << " (for"
                << " m=" << m << ", n=" << n << ", k=" << k << "). Skipped";
            TLLM_LOG_WARNING(msg.str());
            continue;
        }

        // Choose the fastest tactic
        if (time < bestTime)
        {
            bestConfig = candidateConfig;
            bestTime = time;
        }
    }

    if (!foundOne)
    {
        std::ostringstream msg;
        msg << "Have not found any valid GEMM config for shape ("
            << "m=" << m << ", n=" << n << ", k=" << k << "). Will try to use default or fail at runtime";
        TLLM_LOG_WARNING(msg.str());
        return std::nullopt;
    }
    return {bestConfig};
}

template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
float GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileTacticForProblem(
    int m, int n, int k, const Config& tactic)
{
    constexpr int warmup = 5;
    constexpr int runs = 10;

    cudaStream_t stream = cudaStreamDefault;
    // Warmup the execution
    for (int i = 0; i < warmup; ++i)
    {
        runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
    }

    cudaEvent_t start;
    cudaEvent_t stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaDeviceSynchronize();
    cudaEventRecord(start, 0);

    // Profile GEMM
    for (int i = 0; i < runs; ++i)
    {
        runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
    }

    cudaEventRecord(stop, 0);

    cudaEventSynchronize(stop);

    float elapsed;
    cudaEventElapsedTime(&elapsed, start, stop);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);

    return elapsed / runs;
}

template class GemmPluginProfiler<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
    std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassInt8GemmRunnerInterface>, GemmIdCore,
    GemmIdCoreHash>;

template class GemmPluginProfiler<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
    std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface>, GemmIdCore,
    GemmIdCoreHash>;

template class GemmPluginProfiler<cublasLtMatmulHeuristicResult_t,
    std::shared_ptr<tensorrt_llm::common::CublasMMWrapper>, GemmIdCublas, GemmIdCublasHash>;

} // namespace tensorrt_llm::plugins
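
Not part of the commit: a minimal, self-contained sketch of the warmup-plus-CUDA-event timing pattern that `profileTacticForProblem` applies to each candidate tactic above. `dummyKernel` and the buffer size are placeholders; only CUDA runtime calls that also appear in the real code (`cudaMalloc`, `cudaEventCreate`, `cudaEventRecord`, `cudaEventElapsedTime`, ...) are used.

```
// Standalone sketch of the warmup + event timing loop used by the GEMM tactic profiler.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummyKernel(float* data, int n) // stands in for one GEMM tactic
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] = data[i] * 2.0f + 1.0f;
}

int main()
{
    const int n = 1 << 20;
    float* d = nullptr;
    cudaMalloc(&d, n * sizeof(float));

    constexpr int warmup = 5;
    constexpr int runs = 10;

    // Warmup so lazy initialization and cache effects do not pollute the measurement.
    for (int i = 0; i < warmup; ++i)
        dummyKernel<<<(n + 255) / 256, 256>>>(d, n);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaDeviceSynchronize();
    cudaEventRecord(start, 0);

    // Timed region: several back-to-back runs, averaged afterwards.
    for (int i = 0; i < runs; ++i)
        dummyKernel<<<(n + 255) / 256, 256>>>(d, n);

    cudaEventRecord(stop, 0);
    cudaEventSynchronize(stop);

    float elapsedMs = 0.0f;
    cudaEventElapsedTime(&elapsedMs, start, stop);
    printf("average time per run: %f ms\n", elapsedMs / runs);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    cudaFree(d);
    return 0;
}
```

Recording both events around a batch of runs and dividing by the run count is what lets the profiler rank tactics by average time rather than a single noisy sample.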
|
||||
@ -88,7 +88,7 @@ public:
|
||||
|
||||
bool operator==(const GemmIdCore& id) const
|
||||
{
|
||||
return n == id.n && k == id.k && dtype == id.dtype;
|
||||
return isEqual(id);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const GemmIdCore& id)
|
||||
@ -97,6 +97,12 @@ public:
|
||||
out << " type=" << static_cast<int>(id.dtype);
|
||||
return out;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool isEqual(const GemmIdCore& id) const
|
||||
{
|
||||
return n == id.n && k == id.k && dtype == id.dtype;
|
||||
}
|
||||
};
|
||||
|
||||
// Hash of GemmId
|
||||
@ -111,6 +117,50 @@ struct GemmIdCoreHash
|
||||
}
|
||||
};
|
||||
|
||||
class GemmIdCublas : public GemmIdCore
|
||||
{
|
||||
public:
|
||||
bool transA{};
|
||||
bool transB{};
|
||||
|
||||
GemmIdCublas(int n_, int k_, const nvinfer1::DataType& dtype_, bool transA_, bool transB_)
|
||||
: GemmIdCore(n_, k_, dtype_)
|
||||
, transA(transA_)
|
||||
, transB(transB_)
|
||||
{
|
||||
}
|
||||
|
||||
GemmIdCublas() {}
|
||||
|
||||
bool operator==(const GemmIdCublas& id) const
|
||||
{
|
||||
return isEqual(id) && transA == id.transA && transB == id.transB;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const GemmIdCublas& id)
|
||||
{
|
||||
out << "(N;K)=(" << id.n << ";" << id.k << "),";
|
||||
out << " type=" << static_cast<int>(id.dtype);
|
||||
out << " transA=" << id.transA;
|
||||
out << " transB=" << id.transB;
|
||||
return out;
|
||||
}
|
||||
};

// Hash of GemmIdCublas
struct GemmIdCublasHash
{
    std::size_t operator()(const GemmIdCublas& id) const
    {
        auto h1 = std::hash<int>{}(id.n);
        auto h2 = std::hash<int>{}(id.k);
        auto h3 = std::hash<int>{}(static_cast<int>(id.dtype));
        auto h4 = std::hash<bool>{}(id.transA);
        auto h5 = std::hash<bool>{}(id.transB);
        return h1 ^ h2 ^ h3 ^ h4 ^ h5;
    }
};
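
Not part of the commit: a short usage sketch of how this id/hash pair keys an `std::unordered_map`, which is how the profiler caches per-shape results. It assumes the gemm plugin header and TensorRT's `NvInfer.h` are on the include path; the literal shape numbers and the `int` mapped value are purely illustrative.

```
#include <iostream>
#include <unordered_map>

int main()
{
    using tensorrt_llm::plugins::GemmIdCublas;
    using tensorrt_llm::plugins::GemmIdCublasHash;

    // operator== supplies key equality, GemmIdCublasHash supplies the bucket hash.
    std::unordered_map<GemmIdCublas, int, GemmIdCublasHash> profileCache;

    GemmIdCublas id{/*n=*/4096, /*k=*/1024, nvinfer1::DataType::kHALF, /*transA=*/false, /*transB=*/true};
    profileCache[id] = 42;

    std::cout << id << " -> " << profileCache.at(id) << std::endl;
    return 0;
}
```

The plain XOR combine above can collide for ids whose field hashes happen to cancel, but a collision only costs an extra bucket probe; lookups stay correct because `operator==` is the final arbiter.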
|
||||
|
||||
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
|
||||
class GemmPluginProfiler
|
||||
{
|
||||
@ -160,118 +210,15 @@ public:
|
||||
|
||||
using MNKProfileMapPtr = std::shared_ptr<MNKProfileMap>;
|
||||
|
||||
GemmPluginProfiler()
|
||||
{
|
||||
mMNKProfileMap = std::make_shared<MNKProfileMap>();
|
||||
GemmPluginProfiler();
|
||||
|
||||
// set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings
|
||||
const auto skip = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS");
|
||||
mSkip = (skip != NULL && std::stoi(skip));
|
||||
if (mSkip)
|
||||
{
|
||||
TLLM_LOG_DEBUG(
|
||||
"SKIP_GEMM_PLUGIN_PROFILINGS is set. Skipping GEMM plugin profilings. It could result in runtime error "
|
||||
"if default tactic is not defined.");
|
||||
}
|
||||
}
|
||||
void serialize(char*& buffer, const GemmIdType& gemmId) const;
|
||||
|
||||
void serialize(char* buffer, const GemmIdType& gemmId) const
|
||||
{
|
||||
auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId);
|
||||
void deserialize(const char*& data, GemmDims& dims, const GemmIdType& gemmId);
|
||||
size_t getSerializationSize(const GemmIdType& gemmId) const;
|
||||
|
||||
// Save number of profiles for given GEMM ID
|
||||
write(buffer, static_cast<int>(mProfileMap->size()));
|
||||
for (const auto& pair : *mProfileMap)
|
||||
{
|
||||
// Save pair of M to the best GEMM config
|
||||
write(buffer, pair);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(const char*& data, GemmDims& dims, const GemmIdType& gemmId)
|
||||
{
|
||||
// NOTE(nkorobov): this mutex is not needed since each thread owns its own map, but will put here for
|
||||
// consistency
|
||||
writer_lock lock(mMNKProfileMap->mutex);
|
||||
|
||||
mDims = dims;
|
||||
|
||||
// GemmId gemmId(dims.n, dims.k);
|
||||
if (!mMNKProfileMap->existsMProfileMap(gemmId))
|
||||
{
|
||||
// Create GEMM with GEMM ID if it does not exist
|
||||
mMNKProfileMap->createMProfileMap(gemmId);
|
||||
}
|
||||
// Populate map with profiles of GEMM ID
|
||||
auto profileMap = mMNKProfileMap->getMProfileMap(gemmId);
|
||||
int selectedMapSize;
|
||||
read(data, selectedMapSize);
|
||||
for (int ii = 0; ii < selectedMapSize; ++ii)
|
||||
{
|
||||
std::pair<int, std::optional<Config>> config;
|
||||
read(data, config);
|
||||
profileMap->insert(config);
|
||||
}
|
||||
}
|
||||
|
||||
size_t getSerializationSize(const GemmIdType& gemmId) const
|
||||
{
|
||||
reader_lock lock(mMNKProfileMap->mutex);
|
||||
return sizeof(int) + // size of the tactics map
|
||||
mMNKProfileMap->getMProfileMap(gemmId)->size()
|
||||
* sizeof(std::pair<int, std::optional<Config>>); // size of the tactics map
|
||||
}
|
||||
|
||||
void profileTactics(const std::vector<Config>& tactics, const RunnerPtr& runner, const nvinfer1::DataType& type,
|
||||
const GemmDims& dims, const GemmIdType& gemmId)
|
||||
{
|
||||
writer_lock lock(mMNKProfileMap->mutex);
|
||||
|
||||
if (!dims.isInitialized())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
mRunner = runner;
|
||||
mType = type;
|
||||
|
||||
const int maxM = std::min(nextPowerOfTwo(dims.maxM), MAX_PROFILE_M);
|
||||
computeTmpSize(maxM, dims.n, dims.k);
|
||||
|
||||
if (!mMNKProfileMap->existsMProfileMap(gemmId))
|
||||
{
|
||||
// Create map for GEMM ID
|
||||
mMNKProfileMap->createMProfileMap(gemmId);
|
||||
}
|
||||
|
||||
if (mSkip)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId);
|
||||
|
||||
auto profileTactics = [&tactics, &mProfileMap, this](int m, int n, int k)
|
||||
{
|
||||
if (mProfileMap->count(m) == 0)
|
||||
{
|
||||
// Profile different tactics for particular m and insert best config to the map
|
||||
mProfileMap->insert({m, this->profileTacticsForProblem(m, n, k, tactics)});
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate tmp data to run GEMMs
|
||||
allocateTmpData();
|
||||
const int startMinMRounded = nextPowerOfTwo(dims.minM);
|
||||
for (int m = startMinMRounded; m < maxM; m *= 2)
|
||||
{
|
||||
profileTactics(m, dims.n, dims.k);
|
||||
}
|
||||
|
||||
profileTactics(maxM, dims.n, dims.k);
|
||||
// Free tmp data
|
||||
freeTmpData();
|
||||
}
|
||||
void profileTactics(
|
||||
const RunnerPtr& runner, const nvinfer1::DataType& type, const GemmDims& dims, const GemmIdType& gemmId);
|
||||
|
||||
void setSelectionTactics(const MNKProfileMapPtr& map)
|
||||
{
|
||||
@ -283,19 +230,13 @@ public:
|
||||
mTmpWorkspaceSizeInBytes = bytes;
|
||||
}
|
||||
|
||||
std::optional<Config> getBestConfig(int m, const GemmIdType& gemmId) const
|
||||
void setSkip(bool skip)
|
||||
{
|
||||
reader_lock lock(mMNKProfileMap->mutex);
|
||||
|
||||
if (mSkip)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const int mRounded = std::min(nextPowerOfTwo(m), MAX_PROFILE_M);
|
||||
return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded);
|
||||
mSkip = mSkip || skip;
|
||||
}
|
||||
|
||||
std::optional<Config> getBestConfig(int m, const GemmIdType& gemmId) const;
|
||||
|
||||
protected:
|
||||
virtual void runTactic(int m, int n, int k, const Config& tactic, char* workspace, const cudaStream_t& stream) = 0;
|
||||
|
||||
@ -306,108 +247,16 @@ protected:
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual std::vector<Config> getTactics(int m, int n, int k) const = 0;
|
||||
|
||||
private:
|
||||
void allocateTmpData()
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0");
|
||||
const auto status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes);
|
||||
TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling.");
|
||||
}
|
||||
void allocateTmpData();
|
||||
|
||||
void freeTmpData()
|
||||
{
|
||||
const auto status = cudaFree(mWorkspaceTmp);
|
||||
TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling.");
|
||||
}
|
||||
void freeTmpData();
|
||||
|
||||
std::optional<Config> profileTacticsForProblem(int m, int n, int k, const std::vector<Config>& tactics)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
std::optional<Config> profileTacticsForProblem(int m, int n, int k, const std::vector<Config>& tactics);
|
||||
|
||||
float bestTime = std::numeric_limits<float>::max();
|
||||
Config bestConfig;
|
||||
bool foundOne = false;
|
||||
|
||||
// Iterate over all tactics for given M, N and K
|
||||
for (int ii = 0; ii < tactics.size(); ++ii)
|
||||
{
|
||||
const Config& candidateConfig = tactics[ii];
|
||||
float time = std::numeric_limits<float>::max();
|
||||
try
|
||||
{
|
||||
if (!checkTactic(m, n, k, candidateConfig))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
// Profile particular tactic for given M, N and K
|
||||
time = profileTacticForProblem(m, n, k, candidateConfig);
|
||||
foundOne = true;
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << "Cannot profile configuration " << ii << " (for"
|
||||
<< " m=" << m << ", n=" << n << ", k=" << k << "). Skipped";
|
||||
TLLM_LOG_WARNING(msg.str());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Choose the fastest tactic
|
||||
if (time < bestTime)
|
||||
{
|
||||
bestConfig = candidateConfig;
|
||||
bestTime = time;
|
||||
}
|
||||
}
|
||||
|
||||
if (!foundOne)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << "Have not found any valid GEMM config for shape ("
|
||||
<< "m=" << m << ", n=" << n << ", k=" << k << "). Will try to use default or fail at runtime";
|
||||
TLLM_LOG_WARNING(msg.str());
|
||||
return std::nullopt;
|
||||
}
|
||||
return {bestConfig};
|
||||
}
|
||||
|
||||
float profileTacticForProblem(int m, int n, int k, const Config& tactic)
|
||||
{
|
||||
constexpr int warmup = 5;
|
||||
constexpr int runs = 10;
|
||||
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
// Warmup the execution
|
||||
for (int i = 0; i < warmup; ++i)
|
||||
{
|
||||
runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
|
||||
}
|
||||
|
||||
cudaEvent_t start;
|
||||
cudaEvent_t stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
cudaDeviceSynchronize();
|
||||
cudaEventRecord(start, 0);
|
||||
|
||||
// Profile GEMM
|
||||
for (int i = 0; i < runs; ++i)
|
||||
{
|
||||
runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop, 0);
|
||||
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
float elapsed;
|
||||
cudaEventElapsedTime(&elapsed, start, stop);
|
||||
|
||||
cudaEventDestroy(start);
|
||||
cudaEventDestroy(stop);
|
||||
|
||||
return elapsed / runs;
|
||||
}
|
||||
float profileTacticForProblem(int m, int n, int k, const Config& tactic);
|
||||
|
||||
int nextPowerOfTwo(int v) const
|
||||
{
|
||||
@ -450,9 +299,10 @@ public:
|
||||
mMNKProfileMap = std::make_shared<MNKProfileMap>();
|
||||
}
|
||||
|
||||
GemmPluginProfilerPtr createGemmPluginProfiler(bool inference)
|
||||
GemmPluginProfilerPtr createGemmPluginProfiler(bool inference, bool skip = false)
|
||||
{
|
||||
auto profiler = std::make_shared<GemmPluginProfilerType>();
|
||||
profiler->setSkip(skip);
|
||||
// If the profiler is created during the engine build,
|
||||
// mMNKProfileMap is shared between different profilers to minimize the time spent on the profiling
|
||||
// and do not repeat profiling for the GEMMs of the same shape.
|
||||
|
||||
@ -119,6 +119,29 @@ int8_t* nextWorkspacePtrWithAlignment(int8_t* ptr, uintptr_t previousWorkspaceSi

size_t calculateTotalWorkspaceSize(size_t* workspaces, int count, const uintptr_t alignment = kCudaMemAlign);

// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members.
// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and
// your clone() implementation is responsible for initializing such data members.
// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr.
template <typename T, typename Del = std::default_delete<T>>
class UniqPtrWNullCopy : public std::unique_ptr<T, Del>
{
public:
    using std::unique_ptr<T, Del>::unique_ptr;

    // for compatibility with std::make_unique
    explicit UniqPtrWNullCopy(std::unique_ptr<T, Del>&& src)
        : std::unique_ptr<T, Del>::unique_ptr{std::move(src)}
    {
    }

    // copy constructor produces nullptr
    UniqPtrWNullCopy(UniqPtrWNullCopy const&)
        : std::unique_ptr<T, Del>::unique_ptr{}
    {
    }
};

} // namespace tensorrt_llm::plugins
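
Not part of the commit: a small sketch of why `UniqPtrWNullCopy` is convenient for plugin members. It assumes the header defining `UniqPtrWNullCopy` is included; `Widget` and `FakePlugin` are made-up names, and the pattern mirrors how the attention plugin re-creates `mFMHARunner`/`mCublasWrapper` after a copy has left them empty.

```
#include <cassert>
#include <memory>

struct Widget
{
    int value = 0;
};

struct FakePlugin
{
    tensorrt_llm::plugins::UniqPtrWNullCopy<Widget> mWidget;

    FakePlugin* clone() const
    {
        auto* copy = new FakePlugin(*this); // default copy: copy->mWidget starts out empty
        copy->mWidget.reset(new Widget{});  // clone() re-creates the owned resource
        return copy;
    }
};

int main()
{
    FakePlugin original;
    original.mWidget.reset(new Widget{});

    FakePlugin shallow(original); // compiles: the "copy" of the pointer is just nullptr
    assert(!shallow.mWidget);     // no accidental sharing, no double free

    std::unique_ptr<FakePlugin> cloned{original.clone()};
    assert(cloned->mWidget);      // clone() restored it
    return 0;
}
```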

inline bool isBuilding()
@ -128,17 +151,6 @@ inline bool isBuilding()
    return val != nullptr && std::string(val) == "1";
}

#define MPICHECK(cmd)                                                          \
    do                                                                         \
    {                                                                          \
        int e = cmd;                                                           \
        if (e != MPI_SUCCESS)                                                  \
        {                                                                      \
            printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e);   \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    } while (0)

#if ENABLE_MULTI_DEVICE
#define NCCLCHECK(cmd)                                                         \
    do                                                                         \
cpp/tensorrt_llm/plugins/exports.def (new file, 19 lines)
@ -0,0 +1,19 @@
; SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
; SPDX-License-Identifier: Apache-2.0
;
; Licensed under the Apache License, Version 2.0 (the "License");
; you may not use this file except in compliance with the License.
; You may obtain a copy of the License at
;
; http://www.apache.org/licenses/LICENSE-2.0
;
; Unless required by applicable law or agreed to in writing, software
; distributed under the License is distributed on an "AS IS" BASIS,
; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
; See the License for the specific language governing permissions and
; limitations under the License.

LIBRARY nvinfer_plugin_tensorrt_llm
EXPORTS
    getPluginRegistry
    initLibNvInferPlugins
@ -47,15 +47,17 @@ void runGemm(const int M, const int N, const int K, const bool transA, const boo
|
||||
const CublasGemmWrapperPtr& cublasWrapperPtr, const void* act, const void* weight, void* output,
|
||||
const std::optional<cublasLtMatmulHeuristicResult_t>& heuristic, void* workspace, cudaStream_t stream)
|
||||
{
|
||||
auto cublasHandle = cublasWrapperPtr->getCublasHandle();
|
||||
TLLM_CUDA_CHECK(cublasSetStream(cublasHandle, stream));
|
||||
cublasWrapperPtr->setStream(stream);
|
||||
cublasWrapperPtr->setWorkspace(workspace);
|
||||
|
||||
cublasOperation_t transa, transb;
|
||||
int m, n, k;
|
||||
int lda, ldb, ldc;
|
||||
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, transA, transB, M, N, K);
|
||||
|
||||
cublasWrapperPtr->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
|
||||
cublasWrapperPtr->Gemm(transa, transb, m, n, k, weight, lda, act, ldb, output, ldc, heuristic);
|
||||
cublasWrapperPtr->destroyDescriptors();
|
||||
}
|
||||
|
||||
void CublasLtGemmPluginProfiler::runTactic(
|
||||
@ -80,11 +82,17 @@ void CublasLtGemmPluginProfiler::runTactic(
|
||||
bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, const Config& tactic) const
|
||||
{
|
||||
cublasOperation_t transa, transb;
|
||||
int M, N, K;
|
||||
int M = m, N = n, K = k;
|
||||
int lda, ldb, ldc;
|
||||
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, n, m, k);
|
||||
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K);
|
||||
|
||||
return mRunner->checkTactic(transa, transb, m, n, k, lda, ldb, ldc, tactic);
|
||||
mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
|
||||
|
||||
const auto checkResult = mRunner->checkTactic(transa, transb, m, n, k, lda, ldb, ldc, tactic.algo);
|
||||
|
||||
mRunner->destroyDescriptors();
|
||||
|
||||
return checkResult;
|
||||
}
|
||||
|
||||
void CublasLtGemmPluginProfiler::computeTmpSize(int maxM, int n, int k)
|
||||
@ -105,6 +113,20 @@ void CublasLtGemmPluginProfiler::computeTmpSize(int maxM, int n, int k)
|
||||
setTmpWorkspaceSizeInBytes(bytes);
|
||||
}
|
||||
|
||||
std::vector<CublasLtGemmPluginProfiler::Config> CublasLtGemmPluginProfiler::getTactics(int M, int N, int K) const
|
||||
{
|
||||
cublasOperation_t transa, transb;
|
||||
int m, n, k;
|
||||
int lda, ldb, ldc;
|
||||
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K);
|
||||
|
||||
mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
|
||||
const auto heruistics = mRunner->getTactics(transa, transb, m, n, k, lda, ldb, ldc);
|
||||
mRunner->destroyDescriptors();
|
||||
|
||||
return heruistics;
|
||||
}
|
||||
|
||||
GemmPlugin::GemmPlugin(
|
||||
int transA, int transB, nvinfer1::DataType type, bool useFp8, const GemmPlugin::PluginProfilerPtr& pluginProfiler)
|
||||
: mTransA(transA)
|
||||
@ -138,14 +160,11 @@ void GemmPlugin::init()
|
||||
{
|
||||
auto cublasHandle = getCublasHandle();
|
||||
auto cublasLtHandle = getCublasLtHandle();
|
||||
mCublasAlgoMap = std::make_shared<cublasAlgoMap>(GEMM_CONFIG);
|
||||
mCublasWrapperMutex = std::make_shared<std::mutex>();
|
||||
mCublasWrapper = std::make_shared<cublasMMWrapper>(
|
||||
cublasHandle, cublasLtHandle, nullptr, mCublasAlgoMap.get(), mCublasWrapperMutex.get(), nullptr);
|
||||
mCublasWrapper = std::make_shared<CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
|
||||
|
||||
mPluginProfiler->setTranspose(mTransA, mTransB);
|
||||
|
||||
mGemmId = GemmIdCublas(GemmIdCore(mDims.n, mDims.k, mType), mTransA, mTransB);
|
||||
mGemmId = GemmIdCublas(mDims.n, mDims.k, mType, mTransA, mTransB);
|
||||
}
|
||||
|
||||
void GemmPlugin::setGemmConfig()
|
||||
@ -182,18 +201,7 @@ void GemmPlugin::configGemm()
|
||||
|
||||
setGemmConfig();
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> totalHeruistics;
|
||||
for (int mCur = mDims.minM; mCur < mDims.maxM; mCur *= 2)
|
||||
{
|
||||
cublasOperation_t transa, transb;
|
||||
int m, n, k;
|
||||
int lda, ldb, ldc;
|
||||
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, mCur, mDims.n, mDims.k);
|
||||
const auto heruistics = mCublasWrapper->getTactics(transa, transb, m, n, k, lda, ldb, ldc);
|
||||
|
||||
totalHeruistics.insert(totalHeruistics.end(), heruistics.begin(), heruistics.end());
|
||||
}
|
||||
mPluginProfiler->profileTactics(totalHeruistics, mCublasWrapper, mType, mDims, mGemmId);
|
||||
mPluginProfiler->profileTactics(mCublasWrapper, mType, mDims, mGemmId);
|
||||
}
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
@ -313,7 +321,8 @@ void GemmPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, in
|
||||
{
|
||||
mDims = {minM, maxM, N, K};
|
||||
}
|
||||
mGemmId.gemmIdCore = {N, K, mType};
|
||||
mGemmId.n = N;
|
||||
mGemmId.k = K;
|
||||
}
|
||||
|
||||
size_t GemmPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||
@ -339,9 +348,7 @@ int GemmPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
|
||||
const auto N = computeNDimension(mTransB, nbDimsB, inputDesc[1].dims.d);
|
||||
const int K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1];
|
||||
|
||||
// FIXME(nkorobov): enable best config selection
|
||||
// const auto& bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
|
||||
const std::optional<CublasLtGemmPluginProfiler::Config> bestTactic = {};
|
||||
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
|
||||
runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0], bestTactic, workspace,
|
||||
stream);
|
||||
return 0;
|
||||
@ -465,7 +472,8 @@ IPluginV2* GemmPluginCreator::createPlugin(const char* name, const PluginFieldCo
|
||||
{
|
||||
// GemmPluginCreator is unique and shared for an engine generation
|
||||
// Create plugin profiler with shared tactics map
|
||||
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false);
|
||||
// FIXME(nkorobov) enable tactic profiler
|
||||
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false, /* skip */ true);
|
||||
auto* obj = new GemmPlugin(transA, transB, type, useFp8, pluginProfiler);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
@ -485,7 +493,8 @@ IPluginV2* GemmPluginCreator::deserializePlugin(const char* name, const void* se
|
||||
{
|
||||
// GemmPluginCreator is unique and shared for an engine generation
|
||||
// Create plugin profiler with shared tactics map
|
||||
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ true);
|
||||
// FIXME(nkorobov) enable tactic profiler
|
||||
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ true, /* skip */ true);
|
||||
auto* obj = new GemmPlugin(serialData, serialLength, pluginProfiler);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
|
||||
@ -27,51 +27,9 @@
|
||||
namespace tensorrt_llm::plugins
|
||||
{
|
||||
|
||||
using CublasGemmWrapper = tensorrt_llm::common::cublasMMWrapper;
|
||||
using CublasGemmWrapper = tensorrt_llm::common::CublasMMWrapper;
|
||||
using CublasGemmWrapperPtr = std::shared_ptr<CublasGemmWrapper>;
|
||||
|
||||
class GemmIdCublas
|
||||
{
|
||||
public:
|
||||
GemmIdCore gemmIdCore{};
|
||||
bool transA{};
|
||||
bool transB{};
|
||||
|
||||
GemmIdCublas(const GemmIdCore& gemmIdCore_, bool transA_, bool transB_)
|
||||
: gemmIdCore(gemmIdCore_)
|
||||
, transA(transA_)
|
||||
, transB(transB_)
|
||||
{
|
||||
}
|
||||
|
||||
GemmIdCublas() {}
|
||||
|
||||
bool operator==(const GemmIdCublas& id) const
|
||||
{
|
||||
return gemmIdCore == id.gemmIdCore && transA == id.transA && transB == id.transB;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const GemmIdCublas& id)
|
||||
{
|
||||
out << "Core ID = {" << id.gemmIdCore << "}";
|
||||
out << " transA=" << id.transA;
|
||||
out << " transB=" << id.transB;
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// Hash of GemmIdCublas
|
||||
struct GemmIdCublasHash
|
||||
{
|
||||
std::size_t operator()(const GemmIdCublas& id) const
|
||||
{
|
||||
auto h1 = GemmIdCoreHash()(id.gemmIdCore);
|
||||
auto h2 = std::hash<bool>{}(id.transA);
|
||||
auto h3 = std::hash<bool>{}(id.transB);
|
||||
return h1 ^ h2 ^ h3;
|
||||
}
|
||||
};
|
||||
|
||||
class CublasLtGemmPluginProfiler
|
||||
: public GemmPluginProfiler<cublasLtMatmulHeuristicResult_t, CublasGemmWrapperPtr, GemmIdCublas, GemmIdCublasHash>
|
||||
{
|
||||
@ -91,6 +49,8 @@ protected:
|
||||
|
||||
bool checkTactic(int m, int n, int k, const Config& tactic) const override;
|
||||
|
||||
std::vector<Config> getTactics(int m, int n, int k) const override;
|
||||
|
||||
private:
|
||||
bool mTransA;
|
||||
bool mTransB;
|
||||
@ -150,8 +110,8 @@ private:
|
||||
int mTransB;
|
||||
nvinfer1::DataType mType;
|
||||
|
||||
std::shared_ptr<tensorrt_llm::common::cublasAlgoMap> mCublasAlgoMap;
|
||||
std::shared_ptr<std::mutex> mCublasWrapperMutex;
|
||||
// @fixme: seems this is shared across multiple clones.
|
||||
// If we deep copy the wrapper inside clone(), then we may avoid the mutex inside the wrapper?
|
||||
CublasGemmWrapperPtr mCublasWrapper;
|
||||
|
||||
GemmDims mDims{};
|
||||
|
||||
@ -79,7 +79,9 @@ struct FusedQKVMaskedAttentionDispatchParams
|
||||
int size_per_head;
|
||||
int rotary_embedding_dim;
|
||||
float rotary_embedding_base;
|
||||
RotaryScalingType rotary_embedding_scale_type;
|
||||
float rotary_embedding_scale;
|
||||
int rotary_embedding_max_positions;
|
||||
PositionEmbeddingType position_embedding_type;
|
||||
int max_seq_len;
|
||||
const int* input_lengths;
|
||||
@ -160,7 +162,9 @@ void fusedQKV_masked_attention_dispatch(
|
||||
params.hidden_size_per_head = input_params.size_per_head;
|
||||
params.rotary_embedding_dim = input_params.rotary_embedding_dim;
|
||||
params.rotary_embedding_base = input_params.rotary_embedding_base;
|
||||
params.rotary_embedding_scale_type = input_params.rotary_embedding_scale_type;
|
||||
params.rotary_embedding_scale = input_params.rotary_embedding_scale;
|
||||
params.rotary_embedding_max_positions = input_params.rotary_embedding_max_positions;
|
||||
params.position_embedding_type = input_params.position_embedding_type;
|
||||
// Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
|
||||
params.inv_sqrt_dh = 1.F / (sqrtf((float) params.hidden_size_per_head) * input_params.q_scaling);
|
||||
@ -215,17 +219,17 @@ template void fusedQKV_masked_attention_dispatch(
|
||||
template void fusedQKV_masked_attention_dispatch(
|
||||
const FusedQKVMaskedAttentionDispatchParams<half, KVBlockArray>&, cudaStream_t stream);
|
||||
|
||||
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int unidirectional, float q_scaling,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional,
|
||||
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
|
||||
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, int kv_cache_quant_mode,
|
||||
bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, bool paged_kv_cache,
|
||||
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled)
|
||||
int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled)
|
||||
: mNumHeads(num_heads)
|
||||
, mNumKVHeads(num_kv_heads)
|
||||
, mHeadSize(-1)
|
||||
, mHeadSize(head_size)
|
||||
, mUnidirectional(unidirectional)
|
||||
, mQScaling(q_scaling)
|
||||
, mRotaryEmbeddingDim(rotary_embedding_dim)
|
||||
@ -242,6 +246,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_hea
|
||||
, mKVCacheQuantMode(kv_cache_quant_mode)
|
||||
, mRemovePadding(remove_input_padding)
|
||||
, mPagedKVCache(paged_kv_cache)
|
||||
, mTokensPerBlock(tokens_per_block)
|
||||
, mTpSize(tp_size)
|
||||
, mTpRank(tp_rank)
|
||||
, mMaxContextLength(max_context_length)
|
||||
@ -288,6 +293,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(const void* data, size_t leng
|
||||
read(d, mRemovePadding);
|
||||
read(d, mMaskType);
|
||||
read(d, mPagedKVCache);
|
||||
read(d, mTokensPerBlock);
|
||||
read(d, mType);
|
||||
read(d, mMaxContextLength);
|
||||
read(d, mQKVBiasEnabled);
|
||||
@ -313,9 +319,9 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(
|
||||
const int batch_size = nbReq;
|
||||
const size_t attention_mask_size = mEnableContextFMHA ? 0 : size * batch_size * max_input_length * max_input_length;
|
||||
const size_t cu_seqlens_size = sizeof(int) * (batch_size + 1);
|
||||
const size_t q_buf_2_size = size * batch_size * input_seq_length * local_hidden_units_qo;
|
||||
const size_t k_buf_2_size = size * batch_size * input_seq_length * local_hidden_units_kv;
|
||||
const size_t v_buf_2_size = size * batch_size * input_seq_length * local_hidden_units_kv;
|
||||
const size_t q_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_qo;
|
||||
const size_t k_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_kv;
|
||||
const size_t v_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_kv;
|
||||
const size_t qk_buf_size
|
||||
= mEnableContextFMHA ? 0 : size * batch_size * mNumHeads * input_seq_length * input_seq_length;
|
||||
const size_t qkv_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_qo;
|
||||
@ -403,8 +409,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
if (mPagedKVCache)
|
||||
{
|
||||
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
|
||||
kv_cache_buffer = KVCacheBuffer(params.batch_size, params.max_blocks_per_sequence, params.tokens_per_block,
|
||||
num_kv_heads * head_size * elem_size);
|
||||
kv_cache_buffer = KVCacheBuffer(
|
||||
params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock, num_kv_heads * head_size * elem_size);
|
||||
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.block_pointers);
|
||||
}
|
||||
else
|
||||
@ -455,9 +461,12 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
const size_t attention_mask_size
|
||||
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * params.input_seq_length;
|
||||
const size_t cu_seqlens_size = sizeof(int) * (params.batch_size + 1);
|
||||
const size_t q_buf_2_size = sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_qo;
|
||||
const size_t k_buf_2_size = sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_kv;
|
||||
const size_t v_buf_2_size = sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_kv;
|
||||
const size_t q_buf_2_size
|
||||
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_qo;
|
||||
const size_t k_buf_2_size
|
||||
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_kv;
|
||||
const size_t v_buf_2_size
|
||||
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_kv;
|
||||
const size_t qk_buf_size = mEnableContextFMHA
|
||||
? 0
|
||||
: sizeof(T) * params.batch_size * mNumHeads * params.input_seq_length * params.input_seq_length;
|
||||
@ -498,30 +507,9 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
invokeBuildDecoderInfo(decoder_params, stream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
// FIXME(qijun): a temporary solution to make sure the padding part of key/value buffer is 0
|
||||
// NOTE: pointer subtraction is used below since there could be some extra gap due to alignment.
|
||||
// Otherwise, we could do cudaMemsetAsync(k_buf_2_, 0, k_buf_2_size + v_buf_2_size, stream);
|
||||
cudaMemsetAsync(
|
||||
k_buf_2_, 0, reinterpret_cast<int8_t*>(v_buf_2_) - reinterpret_cast<int8_t*>(k_buf_2_) + v_buf_2_size, stream);
|
||||
float rotary_base, rotary_scale;
|
||||
const int32_t kv_seq_len = params.input_seq_length;
|
||||
update_rotary_params(kv_seq_len, rotary_base, rotary_scale);
|
||||
|
||||
invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, const_cast<T*>(params.attention_input),
|
||||
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
|
||||
request_batch_size, request_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
|
||||
mEnableContextFMHA, mRotaryEmbeddingDim, rotary_base, rotary_scale, position_embedding_type, (float*) nullptr,
|
||||
0, stream);
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
const KvCacheDataType cache_type = mKVCacheQuantMode.hasInt8KvCache()
|
||||
? KvCacheDataType::INT8
|
||||
: (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
|
||||
invokeTranspose4dBatchMajor(k_buf_2_, v_buf_2_, kv_cache_buffer, request_batch_size, request_seq_length,
|
||||
params.max_seq_length, getHeadSize(), mNumKVHeads, cache_type, params.kv_scale_orig_quant,
|
||||
params.context_lengths, stream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
const cudaDataType_t gemm_data_type = tc::CudaDataType<T>::value;
|
||||
const int attention_seq_len_1 = request_seq_length; // q length
|
||||
@ -530,11 +518,34 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
|
||||
if (mEnableContextFMHA)
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), kv_cache_buffer,
|
||||
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
|
||||
request_batch_size, request_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
|
||||
mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
|
||||
mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, cache_type,
|
||||
params.kv_scale_orig_quant, stream);
|
||||
mFMHARunner->setup(request_batch_size, request_seq_length, params.num_tokens, isALiBi(), mTpSize, mTpRank);
|
||||
mFMHARunner->run(const_cast<T*>(params.attention_input), cu_seqlens, params.context_buf, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
// FIXME(qijun): a temporary solution to make sure the padding part of key/value buffer is 0
|
||||
// NOTE: pointer subtraction is used below since there could be some extra gap due to alignment.
|
||||
// Otherwise, we could do cudaMemsetAsync(k_buf_2_, 0, k_buf_2_size + v_buf_2_size, stream);
|
||||
cudaMemsetAsync(k_buf_2_, 0,
|
||||
reinterpret_cast<int8_t*>(v_buf_2_) - reinterpret_cast<int8_t*>(k_buf_2_) + v_buf_2_size, stream);
|
||||
invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, const_cast<T*>(params.attention_input),
|
||||
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
|
||||
request_batch_size, request_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
|
||||
mEnableContextFMHA, mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType,
|
||||
mRotaryEmbeddingScale, mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, stream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
invokeTranspose4dBatchMajor(k_buf_2_, v_buf_2_, kv_cache_buffer, request_batch_size, request_seq_length,
|
||||
params.max_seq_length, getHeadSize(), mNumKVHeads, cache_type, params.kv_scale_orig_quant,
|
||||
params.context_lengths, stream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
const T* linear_bias_slopes = isALiBi() ? params.alibi_slopes : nullptr;
|
||||
cudaDataType_t gemm_out_data_type = is_qk_buf_float_ ? CUDA_R_32F : gemm_data_type;
|
||||
void* gemm_out_buf_ = is_qk_buf_float_ ? static_cast<void*>(qk_buf_float_) : static_cast<void*>(qk_buf_);
|
||||
@ -787,7 +798,7 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
{
|
||||
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
|
||||
kv_cache_buffer = KVCacheBuffer(
|
||||
batch_beam, params.max_blocks_per_sequence, params.tokens_per_block, num_kv_heads * head_size * elem_size);
|
||||
batch_beam, params.max_blocks_per_sequence, mTokensPerBlock, num_kv_heads * head_size * elem_size);
|
||||
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.block_pointers);
|
||||
}
|
||||
else
|
||||
@ -841,8 +852,10 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
dispatch_params.kv_scale_quant_orig = params.kv_scale_quant_orig;
|
||||
dispatch_params.kv_block_array = kv_cache_buffer;
|
||||
dispatch_params.multi_processor_count = mMultiProcessorCount;
|
||||
const int32_t kv_seq_len = step;
|
||||
update_rotary_params(kv_seq_len, dispatch_params.rotary_embedding_base, dispatch_params.rotary_embedding_scale);
|
||||
dispatch_params.rotary_embedding_base = mRotaryEmbeddingBase;
|
||||
dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType;
|
||||
dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale;
|
||||
dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions;
|
||||
fusedQKV_masked_attention_dispatch(dispatch_params, stream);
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
@ -875,10 +888,7 @@ int GPTAttentionPluginCommon::initialize() noexcept
|
||||
auto cublasHandle = getCublasHandle();
|
||||
auto cublasLtHandle = getCublasLtHandle();
|
||||
|
||||
mCublasAlgoMap = new tc::cublasAlgoMap(GEMM_CONFIG);
|
||||
mCublasWrapperMutex = new std::mutex();
|
||||
mCublasWrapper
|
||||
= new tc::cublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, mCublasAlgoMap, mCublasWrapperMutex, nullptr);
|
||||
mCublasWrapper.reset(new tc::CublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, nullptr));
|
||||
if (mEnableContextFMHA)
|
||||
{
|
||||
// Pre-checked during constructing.
|
||||
@ -896,7 +906,7 @@ int GPTAttentionPluginCommon::initialize() noexcept
|
||||
TLLM_CHECK_WITH_INFO(false, "GPTAttentionPlugin received wrong data type.");
|
||||
}
|
||||
|
||||
mFMHARunner = new FusedMHARunnerV2(data_type, mNumHeads, getHeadSize(false), mQScaling);
|
||||
mFMHARunner.reset(new FusedMHARunnerV2(data_type, mNumHeads, getHeadSize(false), mQScaling));
|
||||
// set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads.
|
||||
mFMHARunner->setup_flags(mFMHAForceFP32Acc, !mRemovePadding, true, mNumKVHeads);
|
||||
}
|
||||
@ -906,19 +916,6 @@ int GPTAttentionPluginCommon::initialize() noexcept
|
||||
|
||||
void GPTAttentionPluginCommon::destroy() noexcept
|
||||
{
|
||||
delete mCublasAlgoMap;
|
||||
delete mCublasWrapperMutex;
|
||||
delete mCublasWrapper;
|
||||
if (mEnableContextFMHA)
|
||||
{
|
||||
delete mFMHARunner;
|
||||
}
|
||||
|
||||
mCublasAlgoMap = nullptr;
|
||||
mCublasWrapperMutex = nullptr;
|
||||
mCublasWrapper = nullptr;
|
||||
mFMHARunner = nullptr;
|
||||
|
||||
delete this;
|
||||
}
|
||||
|
||||
@ -929,8 +926,8 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
|
||||
+ sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingMaxPositions)
|
||||
+ sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc)
|
||||
+ sizeof(mMultiBlockMode) + sizeof(unsigned int) // mKVCacheQuantMode
|
||||
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mType) + sizeof(mMaxContextLength)
|
||||
+ sizeof(mQKVBiasEnabled);
|
||||
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType)
|
||||
+ sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled);
|
||||
}
|
||||
|
||||
void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
|
||||
@ -956,6 +953,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
|
||||
write(d, mRemovePadding);
|
||||
write(d, mMaskType);
|
||||
write(d, mPagedKVCache);
|
||||
write(d, mTokensPerBlock);
|
||||
write(d, mType);
|
||||
write(d, mMaxContextLength);
|
||||
write(d, mQKVBiasEnabled);
|
||||
@ -975,6 +973,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(PluginField("num_heads", nullptr, PluginFieldType::kINT32, -1));
|
||||
mPluginAttributes.emplace_back(PluginField("num_kv_heads", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("head_size", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("unidirectional", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("q_scaling", nullptr, PluginFieldType::kFLOAT32, 1.0));
|
||||
mPluginAttributes.emplace_back(PluginField("position_embedding_type", nullptr, PluginFieldType::kINT8, 0));
|
||||
@ -991,6 +990,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
|
||||
mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("mask_type", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("paged_kv_cache", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("tokens_per_block", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("max_context_length", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("qkv_bias_enabled", nullptr, PluginFieldType::kINT8, 0));
|
||||
|
||||
@ -35,14 +35,14 @@ class GPTAttentionPluginCommon : public BasePlugin
|
||||
public:
|
||||
GPTAttentionPluginCommon() = delete;
|
||||
|
||||
GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int unidirectional, float q_scaling,
|
||||
GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
|
||||
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, int kv_cache_quant_mode,
|
||||
bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, bool paged_kv_cache,
|
||||
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled);
|
||||
int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled);
|
||||
|
||||
GPTAttentionPluginCommon(const void* data, size_t length);
|
||||
|
||||
@ -93,7 +93,6 @@ protected:
|
||||
void* block_pointers;
|
||||
int32_t batch_size;
|
||||
int32_t num_tokens;
|
||||
int32_t tokens_per_block;
|
||||
int32_t max_blocks_per_sequence;
|
||||
void* workspace;
|
||||
};
|
||||
@ -118,7 +117,6 @@ protected:
|
||||
void* block_pointers;
|
||||
int32_t max_seq_lengths; // cache capacity
|
||||
int32_t num_requests;
|
||||
int32_t tokens_per_block;
|
||||
int32_t max_blocks_per_sequence;
|
||||
int32_t const* cache_indir;
|
||||
void* workspace;
|
||||
@ -138,24 +136,6 @@ protected:
|
||||
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX;
|
||||
}

    inline void update_rotary_params(int32_t kv_seq_len, float& base, float& scale)
    {
        base = mRotaryEmbeddingBase;
        scale = 1.0f / mRotaryEmbeddingScale; // do the division here so that we can avoid it in the kernel
        if (mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX
            && mRotaryEmbeddingScaleType == tensorrt_llm::kernels::RotaryScalingType::kDYNAMIC)
        {
            if (kv_seq_len > mRotaryEmbeddingMaxPositions)
            {
                const float b
                    = (mRotaryEmbeddingScale * kv_seq_len / mRotaryEmbeddingMaxPositions) - (mRotaryEmbeddingScale - 1);
                const float p = static_cast<float>(mRotaryEmbeddingDim) / (mRotaryEmbeddingDim - 2);
                base = mRotaryEmbeddingBase * pow(b, p);
            }
            scale = 1.0f; // scale factor is already used in updated base
        }
    }
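
Not part of the diff: transcribing the helper deleted above into formula form, as a reference for reviewing the kernels that now receive the raw parameters (the surrounding hunks pass `mRotaryEmbeddingBase`, `mRotaryEmbeddingScaleType`, `mRotaryEmbeddingScale`, and `mRotaryEmbeddingMaxPositions` straight to the attention dispatch). With $s$ the configured rotary scale, $L$ the current KV sequence length, $L_{\max}$ the trained maximum positions, and $d$ the rotary dimension, the dynamic ("NTK-aware") branch updates the RoPE base as

```
\text{base}' =
\begin{cases}
    \text{base}\cdot\left(\dfrac{s\,L}{L_{\max}} - (s-1)\right)^{\frac{d}{d-2}}, & \text{dynamic scaling and } L > L_{\max}\\[1ex]
    \text{base}, & \text{otherwise}
\end{cases}
```

and the scale handed to the kernel is $1$ whenever dynamic scaling is active (the correction is folded into the new base), and $1/s$ otherwise (the division is done on the host so the kernel can multiply).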
|
||||
|
||||
protected:
|
||||
const std::string mLayerName;
|
||||
|
||||
@ -173,6 +153,7 @@ protected:
|
||||
bool mRemovePadding = false;
|
||||
tensorrt_llm::kernels::AttentionMaskType mMaskType;
|
||||
bool mPagedKVCache = false;
|
||||
int mTokensPerBlock;
|
||||
tensorrt_llm::common::QuantMode mKVCacheQuantMode;
|
||||
int mTpSize = 1;
|
||||
int mTpRank = 0;
|
||||
@ -186,13 +167,13 @@ protected:
|
||||
bool mFMHAForceFP32Acc = false;
|
||||
int mSM = tensorrt_llm::common::getSMVersion();
|
||||
int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
tensorrt_llm::kernels::MHARunner* mFMHARunner;
|
||||
// The default copy constructor will leave it as nullptr. clone() shall initialize it.
|
||||
UniqPtrWNullCopy<tensorrt_llm::kernels::MHARunner> mFMHARunner;
|
||||
|
||||
bool mMultiBlockMode;
|
||||
int mDeviceId = -1;
|
||||
tensorrt_llm::common::cublasAlgoMap* mCublasAlgoMap;
|
||||
std::mutex* mCublasWrapperMutex;
|
||||
tensorrt_llm::common::cublasMMWrapper* mCublasWrapper;
|
||||
// The default copy constructor will leave it as nullptr. clone() shall initialize it.
|
||||
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
|
||||
};
|
||||
|
||||
class GPTAttentionPluginCreatorCommon : public BaseCreator
|
||||
|
||||
@ -35,18 +35,18 @@ using tensorrt_llm::plugins::GPTAttentionPlugin;
|
||||
static const char* GPT_ATTENTION_PLUGIN_VERSION{"1"};
|
||||
static const char* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
|
||||
|
||||
GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int unidirectional, float q_scaling,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional,
|
||||
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
|
||||
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, int kv_cache_quant_mode,
|
||||
bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, bool paged_kv_cache,
|
||||
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled)
|
||||
: GPTAttentionPluginCommon(num_heads, num_kv_heads, unidirectional, q_scaling, position_embedding_type,
|
||||
int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled)
|
||||
: GPTAttentionPluginCommon(num_heads, num_kv_heads, head_size, unidirectional, q_scaling, position_embedding_type,
|
||||
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_max_positions, tp_size, tp_rank, context_fmha_type, multi_block_mode, kv_cache_quant_mode,
|
||||
remove_input_padding, mask_type, paged_kv_cache, type, max_context_length, qkv_bias_enabled)
|
||||
remove_input_padding, mask_type, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled)
|
||||
{
|
||||
}
|
||||
|
||||
@ -63,17 +63,17 @@ GPTAttentionPlugin* GPTAttentionPlugin::clone() const noexcept
|
||||
|
||||
// outputs
|
||||
// output_tensor [batch_size, seq_len, local_hidden_size]
|
||||
// present_key_value_pool [blocks, 2, local_num_kv_heads, tokens_per_block, head_size] if paged_kv_attention
|
||||
// or [batch_size, 2, local_num_kv_heads, max_seq_len, head_size]
|
||||
// present_key_value_pool (optional if mPagedKVCache is false) [batch_size, 2, local_num_kv_heads, max_seq_len,
|
||||
// head_size]
|
||||
nvinfer1::DimsExprs GPTAttentionPlugin::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
|
||||
{
|
||||
TLLM_CHECK(outputIndex == 0 || outputIndex == 1);
|
||||
TLLM_CHECK(outputIndex == 0 || (!mPagedKVCache && outputIndex == 1));
|
||||
if (outputIndex == 0)
|
||||
{
|
||||
auto ret = inputs[getInputTensorIdx()];
|
||||
ret.d[2] = exprBuilder.operation(
|
||||
DimensionOperation::kPROD, *inputs[getPastKeyValueIdx()].d[4], *exprBuilder.constant(mNumHeads));
|
||||
DimensionOperation::kPROD, *exprBuilder.constant(mHeadSize), *exprBuilder.constant(mNumHeads));
|
||||
return ret;
|
||||
}
|
||||
return inputs[getPastKeyValueIdx()];
|
||||
@ -96,13 +96,20 @@ bool GPTAttentionPlugin::supportsFormatCombination(
|
||||
else if (mPagedKVCache && pos == getKVCacheBlockPointersIdx())
|
||||
{
|
||||
// pointers to kv cache blocks
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT64 && inOut[pos].format == TensorFormat::kLINEAR;
|
||||
}
|
||||
else if (mKVCacheQuantMode.hasInt8KvCache() && (pos == getPastKeyValueIdx() || pos == nbInputs + 1))
|
||||
else if (mKVCacheQuantMode.hasInt8KvCache()
|
||||
&& (!mPagedKVCache && (pos == getPastKeyValueIdx() || pos == nbInputs + 1)))
|
||||
{
|
||||
// If use Int8 K/V cache we require I/O KV values to int8
|
||||
return (inOut[pos].type == nvinfer1::DataType::kINT8) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
}
|
||||
else if (mKVCacheQuantMode.hasFp8KvCache()
|
||||
&& (!mPagedKVCache && (pos == getPastKeyValueIdx() || pos == nbInputs + 1)))
|
||||
{
|
||||
// If use FP8 K/V cache we require I/O KV values to FP8
|
||||
return (inOut[pos].type == nvinfer1::DataType::kFP8) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
}
|
||||
else if (mRemovePadding && (pos == getHostContextLengthsIdx()))
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
|
||||
@ -117,7 +124,6 @@ bool GPTAttentionPlugin::supportsFormatCombination(
|
||||
void GPTAttentionPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
|
||||
{
|
||||
mHeadSize = in[getPastKeyValueIdx()].desc.dims.d[4];
|
||||
TLLM_CHECK(mHeadSize > 0);
|
||||
|
||||
// pre-check whether FMHA is supported in order to save memory allocation
|
||||
@ -242,17 +248,13 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
}
|
||||
|
||||
int max_blocks_per_sequence = 0;
|
||||
int tokens_per_block = 0;
|
||||
void* block_pointers = nullptr;
|
||||
if (mPagedKVCache)
|
||||
{
|
||||
auto& kvCacheBlockPointers = inputDesc[getKVCacheBlockPointersIdx()];
|
||||
auto& kvCacheBlockPointersShape = inputDesc[getKVCacheBlockPointersIdx()].dims;
|
||||
// Div by 2 because we reinterpret int32 input as int64
|
||||
max_blocks_per_sequence = kvCacheBlockPointersShape.d[kvCacheBlockPointersShape.nbDims - 1] / 2;
|
||||
tokens_per_block = inputDesc[getPastKeyValueIdx()].dims.d[3];
|
||||
// Div by 2 because we reinterpret int32 input as int64
|
||||
auto offset = getStride(kvCacheBlockPointersShape, 0) / 2 * seqIdxBeg;
|
||||
max_blocks_per_sequence = kvCacheBlockPointersShape.d[kvCacheBlockPointersShape.nbDims - 1];
|
||||
auto offset = getStride(kvCacheBlockPointersShape, 0) * seqIdxBeg;
|
||||
auto const typed_block_pointers = static_cast<void* const*>(inputs[getKVCacheBlockPointersIdx()]) + offset;
|
||||
block_pointers = const_cast<void*>(static_cast<void const*>(typed_block_pointers));
|
||||
}
|
||||
@ -273,7 +275,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
enqueueContext<T, KVCacheBuffer>(
|
||||
EnqueueContextParams<T, KVCacheBuffer>{attention_input, qkv_bias, max_context_len, maxSeqLen,
|
||||
context_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache,
|
||||
block_pointers, batch_size, localNbTokens, tokens_per_block, max_blocks_per_sequence, workspace},
|
||||
block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace},
|
||||
stream);
|
||||
}
|
||||
else // generation stage; input_seq_len == 1
|
||||
@ -290,8 +292,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
enqueueGeneration<T, KVCacheBuffer>(
|
||||
EnqueueGenerationParams<T, KVCacheBuffer>{attention_input, qkv_bias, sequence_length, past_kv_len,
|
||||
beamWidth, context_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_,
|
||||
key_value_cache, block_pointers, maxSeqLen, num_requests, tokens_per_block, max_blocks_per_sequence,
|
||||
cache_indir, workspace},
|
||||
key_value_cache, block_pointers, maxSeqLen, num_requests, max_blocks_per_sequence, cache_indir,
|
||||
workspace},
|
||||
stream);
|
||||
}
|
||||
|
||||
@ -339,8 +341,15 @@ int GPTAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
|
||||
nvinfer1::DataType GPTAttentionPlugin::getOutputDataType(
|
||||
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
|
||||
{
|
||||
TLLM_CHECK(index == 0 || index == 1);
|
||||
return inputTypes[index];
|
||||
TLLM_CHECK(index == 0 || (!mPagedKVCache && index == 1));
|
||||
if (index == 0)
|
||||
{
|
||||
return inputTypes[getInputTensorIdx()];
|
||||
}
|
||||
else
|
||||
{
|
||||
return inputTypes[getPastKeyValueIdx()];
|
||||
}
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
@ -357,7 +366,7 @@ const char* GPTAttentionPlugin::getPluginVersion() const noexcept
|
||||
|
||||
int GPTAttentionPlugin::getNbOutputs() const noexcept
|
||||
{
|
||||
return 2;
|
||||
return mPagedKVCache ? 1 : 2;
|
||||
}
|
||||
|
||||
size_t GPTAttentionPlugin::getSerializationSize() const noexcept
|
||||
@ -403,8 +412,8 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const Plugi
|
||||
try
|
||||
{
|
||||
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("num_heads").value(),
|
||||
p.getScalar<int32_t>("num_kv_heads").value(), p.getScalar<int32_t>("unidirectional").value(),
|
||||
p.getScalar<float>("q_scaling").value(),
|
||||
p.getScalar<int32_t>("num_kv_heads").value(), p.getScalar<int32_t>("head_size").value(),
|
||||
p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(),
|
||||
static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
|
||||
p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
|
||||
static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),
|
||||
@ -418,6 +427,7 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const Plugi
|
||||
static_cast<bool>(p.getScalar<int8_t>("remove_input_padding").value()),
|
||||
static_cast<AttentionMaskType>(p.getScalar<int32_t>("mask_type").value()),
|
||||
static_cast<bool>(p.getScalar<int32_t>("paged_kv_cache").value()),
|
||||
p.getScalar<int32_t>("tokens_per_block").value(),
|
||||
static_cast<nvinfer1::DataType>(p.getScalar<int32_t>("type_id").value()),
|
||||
p.getScalar<int32_t>("max_context_length").value(),
|
||||
static_cast<bool>(p.getScalar<int8_t>("qkv_bias_enabled").value()));
|
||||
|
||||
@ -38,40 +38,40 @@ namespace tensorrt_llm::plugins
|
||||
// Context sequences have to appear first, generation sequences after
|
||||
|
||||
// inputs
|
||||
// input_tensor [batch_size, seq_len, local_hidden_size + 2 * local_num_kv_heads * head_size]
|
||||
// [1, num_tokens, local_hidden_size + 2 * local_num_kv_heads * head_size] when
|
||||
// enable_remove_input_padding
|
||||
// past_key_value_pool [blocks, 2, local_num_kv_heads, tokens_per_block, head_size] if paged_kv_attention
|
||||
// or [batch_size, 2, local_num_kv_heads, max_seq_len, head_size]
|
||||
// sequence_length [batch_size]
|
||||
// host_past_key_value_lengths [batch_size] (int32)
|
||||
// context_lengths [batch_size]
|
||||
// cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch)
|
||||
// host_request_types [batch_size] int32. 0: context; 1: generation: 2: none. When not in inflight-batching mode,
|
||||
// all elements must be identical.
|
||||
// kv_cache_quantization_scale [1] (optional)
|
||||
// kv_cache_dequantization_scale [1] (optional)
|
||||
// block_pointers [batch_size, 2, max_blocks_per_seq] (optional if paged kv cache)
|
||||
// alibi_slopes [num_heads] (optional for ALiBi position embedding)
|
||||
// host_context_lengths [batch_size] int32. (optional, required when remove_input_padding is true)
|
||||
// qkv_bias (optional) [local_hidden_size * 3]
|
||||
// 0. input_tensor [batch_size, seq_len, local_hidden_size + 2 * local_num_kv_heads * head_size] or
|
||||
// [1, num_tokens, local_hidden_size + 2 * local_num_kv_heads * head_size] when
|
||||
// enable_remove_input_padding
|
||||
// 1. sequence_length [batch_size]
|
||||
// 2. host_past_key_value_lengths [batch_size] (int32)
|
||||
// 3. context_lengths [batch_size]
|
||||
// 4. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch)
|
||||
// 5. host_request_types [batch_size] int32. 0: context; 1: generation; 2: none. When not in inflight-batching
|
||||
// mode,
|
||||
// all elements must be identical.
|
||||
// 6. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or
|
||||
// block_pointers [batch_size, 2, max_blocks_per_seq] if paged kv cache
|
||||
// 7. kv_cache_quantization_scale [1] (optional)
|
||||
// 8. kv_cache_dequantization_scale [1] (optional)
|
||||
// 9. alibi_slopes [num_heads] (optional for ALiBi position embedding)
|
||||
// 10. host_context_lengths [batch_size] int32. (optional, required when remove_input_padding is true)
|
||||
// 11. qkv_bias (optional) [local_hidden_size * 3]
|
||||
//
|
||||
// outputs
|
||||
// output_tensor [batch_size, seq_len, local_hidden_size]
|
||||
// present_key_value_pool [blocks, 2, local_num_kv_heads, tokens_per_block, head_size] if paged_kv_attention
|
||||
// or [batch_size, 2, local_num_kv_heads, max_seq_len, head_size]
|
||||
// present_key_value_pool (optional if not paged kv cache) [batch_size, 2, local_num_kv_heads, max_seq_len,
|
||||
// head_size]
|
||||
|
||||
class GPTAttentionPlugin : public GPTAttentionPluginCommon
|
||||
{
|
||||
public:
|
||||
GPTAttentionPlugin(int num_heads, int num_kv_heads, int unidirectional, float q_scaling,
|
||||
GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
|
||||
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, int kv_cache_quant_mode,
|
||||
bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, bool paged_kv_cache,
|
||||
nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled);
|
||||
int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled);
|
||||
|
||||
GPTAttentionPlugin(const void* data, size_t length);
|
||||
|
||||
@ -134,33 +134,40 @@ private:
|
||||
return 0;
|
||||
}
|
||||
|
||||
IndexType getPastKeyValueIdx() const
|
||||
IndexType getSequenceLengthIdx() const
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
IndexType getSequenceLengthIdx() const
|
||||
IndexType getHostPastKeyValueLengthsIdx() const
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
|
||||
IndexType getHostPastKeyValueLengthsIdx() const
|
||||
IndexType getContextLengthsIdx() const
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
|
||||
IndexType getContextLengthsIdx() const
|
||||
IndexType getCacheIndirIdx() const
|
||||
{
|
||||
return 4;
|
||||
}
|
||||
|
||||
IndexType getCacheIndirIdx() const
|
||||
IndexType getRequestTypesIdx() const
|
||||
{
|
||||
return 5;
|
||||
}
|
||||
|
||||
IndexType getRequestTypesIdx() const
|
||||
IndexType getKVCacheBlockPointersIdx() const
|
||||
{
|
||||
// NOTE We either provide this tensor when mPagedKVCache is true or PastKeyValue otherwise
|
||||
return 6;
|
||||
}
|
||||
|
||||
IndexType getPastKeyValueIdx() const
|
||||
{
|
||||
// NOTE We either provide this tensor when mPagedKVCache is false or KVCacheBlockPointers otherwise
|
||||
return 6;
|
||||
}
|
||||
|
||||
@ -174,27 +181,21 @@ private:
|
||||
return 8;
|
||||
}
|
||||
|
||||
IndexType getKVCacheBlockPointersIdx() const
|
||||
{
|
||||
return mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7;
|
||||
}
|
||||
|
||||
IndexType getAlibiSlopesIdx() const
|
||||
{
|
||||
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (mPagedKVCache ? 1 : 0);
|
||||
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7);
|
||||
}
|
||||
|
||||
IndexType getHostContextLengthsIdx() const
|
||||
{
|
||||
TLLM_CHECK(mRemovePadding);
|
||||
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (mPagedKVCache ? 1 : 0) + (isALiBi() ? 1 : 0);
|
||||
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (isALiBi() ? 1 : 0);
|
||||
}
|
||||
|
||||
IndexType getQKVBiasTensorIdx() const
|
||||
{
|
||||
TLLM_CHECK(mQKVBiasEnabled);
|
||||
return (mKVCacheQuantMode.hasInt8KvCache() ? 9 : 7) + (mPagedKVCache ? 1 : 0) + (isALiBi() ? 1 : 0)
|
||||
+ (mRemovePadding ? 1 : 0);
|
||||
return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (isALiBi() ? 1 : 0) + (mRemovePadding ? 1 : 0);
|
||||
}
|
||||
};
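The index getters above encode the new input layout: with a paged KV cache the block-pointers tensor takes slot 6 in place of past_key_value, so the trailing optional inputs no longer carry an extra paged-KV offset. Below is a minimal sketch of that arithmetic; the struct and member names are illustrative and not part of the plugin.

```cpp
// Sketch of the optional-input index arithmetic, assuming inputs 0..6 are
// fixed, the KV-cache scale tensors occupy slots 7..8 only when KV-cache
// quantization is enabled, and the remaining optional tensors follow in order.
#include <cassert>

struct AttentionInputLayout
{
    bool kvCacheQuant;   // adds kv_cache_{quantization,dequantization}_scale at 7..8
    bool alibi;          // adds alibi_slopes
    bool removePadding;  // adds host_context_lengths
    bool qkvBias;        // adds qkv_bias

    int alibiSlopesIdx() const
    {
        return kvCacheQuant ? 9 : 7;
    }

    int hostContextLengthsIdx() const
    {
        assert(removePadding);
        return alibiSlopesIdx() + (alibi ? 1 : 0);
    }

    int qkvBiasIdx() const
    {
        assert(qkvBias);
        return alibiSlopesIdx() + (alibi ? 1 : 0) + (removePadding ? 1 : 0);
    }
};
```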
|
||||
|
||||
|
||||
@ -169,7 +169,7 @@ int LayernormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* input
|
||||
nvinfer1::DataType LayernormQuantizationPlugin::getOutputDataType(
|
||||
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
|
||||
{
|
||||
assert((mDynActScaling && index < 2) || (~mDynActScaling && index == 0));
|
||||
assert((mDynActScaling && index < 2) || (!mDynActScaling && index == 0));
|
||||
if (index == 0)
|
||||
{
|
||||
// Output 0 quantized output of layer norm
|
||||
|
||||
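The assert fix above (and the identical one in RmsnormQuantizationPlugin further down) replaces a bitwise NOT with a logical NOT. A short illustration of why the original condition could never fail on the non-scaling side:

```cpp
// Illustration only: operator~ integer-promotes the bool, so the old assert
// compared against a nonzero value and always passed when mDynActScaling was false.
static_assert(~false == -1, "bitwise NOT of false is -1 (truthy), not a negation");
static_assert(~true == -2, "bitwise NOT of true is -2 (also truthy)");
static_assert(!false && !!true, "operator! is the logical negation the assert intends");
```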
cpp/tensorrt_llm/plugins/ncclPlugin/FTCustomAR.cpp (new file, 249 lines)
@ -0,0 +1,249 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/plugins/ncclPlugin/FTCustomAR.h"
|
||||
#include "NvInferRuntimeBase.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
|
||||
CustomAllReduceComm::CustomAllReduceComm(size_t TPSize, size_t PPSize, int deviceId, size_t bufferSize)
|
||||
: mTPSize(TPSize)
|
||||
, mPPSize(PPSize)
|
||||
, mTPRank(deviceId % TPSize)
|
||||
, mPPRank(deviceId / TPSize)
|
||||
, mDeviceId(deviceId)
|
||||
, mBufferSize(bufferSize)
|
||||
{
|
||||
if (mPPSize == 0)
|
||||
{
|
||||
group_comm_ = mpi::COMM_WORLD;
|
||||
}
|
||||
else
|
||||
{
|
||||
mpi::comm_split(mpi::COMM_WORLD, mPPRank, mTPRank, &group_comm_);
|
||||
}
|
||||
param_.barrier_flag = 0;
|
||||
// NOTE: assume All Reduce happens within the node (DGX A100)
|
||||
param_.ranks_per_node = mTPSize;
|
||||
param_.rank = mTPRank;
|
||||
param_.local_rank = mTPRank;
|
||||
param_.node_id = 0;
|
||||
allocate();
|
||||
IpcSyncMemHandle();
|
||||
}
|
||||
|
||||
CustomAllReduceComm::~CustomAllReduceComm()
|
||||
{
|
||||
if (is_ipc_handle_opened_)
|
||||
{
|
||||
IpcCloseMemHandle();
|
||||
}
|
||||
mpi::barrier(); // wait for others to stop using resources before freeing them
|
||||
if (mTPRank == 0)
|
||||
{
|
||||
for (int rank = 0; rank < mTPSize; rank++)
|
||||
{
|
||||
size_t device_id = mPPRank * mTPSize + rank;
|
||||
cudaSetDevice(device_id);
|
||||
cudaPointerAttributes comm_buffer_attributes, barrier_attributes;
|
||||
common::check_cuda_error(
|
||||
cudaPointerGetAttributes(&comm_buffer_attributes, param_.peer_comm_buffer_ptrs[rank]));
|
||||
common::check_cuda_error(cudaPointerGetAttributes(&barrier_attributes, param_.peer_barrier_ptrs[rank]));
|
||||
if (comm_buffer_attributes.type == 2)
|
||||
{
|
||||
common::check_cuda_error(cudaFree(param_.peer_comm_buffer_ptrs[rank]));
|
||||
}
|
||||
if (barrier_attributes.type == 2)
|
||||
{
|
||||
common::check_cuda_error(cudaFree(param_.peer_barrier_ptrs[rank]));
|
||||
}
|
||||
}
|
||||
cudaSetDevice(mDeviceId);
|
||||
setP2P(false);
|
||||
}
|
||||
}
|
||||
|
||||
void CustomAllReduceComm::IpcGetMemHandle()
|
||||
{
|
||||
for (int rank = 0; rank < mTPSize; rank++)
|
||||
{
|
||||
common::check_cuda_error(cudaIpcGetMemHandle(
|
||||
&(param_.ipc_mem_handles.peer_barrier_ipc_handles[rank]), param_.peer_barrier_ptrs[rank]));
|
||||
common::check_cuda_error(cudaIpcGetMemHandle(
|
||||
&(param_.ipc_mem_handles.peer_comm_buffer_ipc_handles[rank]), param_.peer_comm_buffer_ptrs[rank]));
|
||||
}
|
||||
}
|
||||
|
||||
void CustomAllReduceComm::IpcSyncMemHandle()
|
||||
{
|
||||
if (mTPRank == 0)
|
||||
{
|
||||
IpcGetMemHandle();
|
||||
}
|
||||
mpi::bcast(reinterpret_cast<char*>(&(param_.ipc_mem_handles)), sizeof(kernels::AllReduceIpcMemHandles),
|
||||
mpi::MPI_TYPE_CHAR, 0, group_comm_);
|
||||
if (mTPRank != 0)
|
||||
{
|
||||
IpcOpenMemHandle();
|
||||
}
|
||||
common::check_cuda_error(cudaSetDevice(mDeviceId));
|
||||
}
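IpcSyncMemHandle implements a simple exchange: TP rank 0, which allocated every peer buffer, exports CUDA IPC handles and broadcasts them over the tensor-parallel communicator, and the remaining ranks open the handles. A reduced sketch of the same pattern for a single buffer, using plain MPI and CUDA runtime calls; the function and variable names are illustrative, not the class's API, and error checks are omitted for brevity.

```cpp
#include <cuda_runtime.h>
#include <mpi.h>

// Rank 0 exports an IPC handle for its allocation; everyone else maps it.
void exchangeIpcHandle(void* localBuf, void** peerBuf, int tpRank, MPI_Comm tpComm)
{
    cudaIpcMemHandle_t handle;
    if (tpRank == 0)
    {
        // Owner of the allocation exports a handle for its buffer.
        cudaIpcGetMemHandle(&handle, localBuf);
    }
    // Every rank receives the same handle bytes from rank 0.
    MPI_Bcast(&handle, sizeof(handle), MPI_BYTE, 0, tpComm);
    if (tpRank == 0)
    {
        *peerBuf = localBuf; // rank 0 already owns the memory
    }
    else
    {
        // Other ranks map rank 0's allocation into their address space.
        cudaIpcOpenMemHandle(peerBuf, handle, cudaIpcMemLazyEnablePeerAccess);
    }
}
```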
|
||||
|
||||
void CustomAllReduceComm::IpcOpenMemHandle()
|
||||
{
|
||||
if (is_ipc_handle_opened_)
|
||||
{
|
||||
IpcCloseMemHandle();
|
||||
is_ipc_handle_opened_ = false;
|
||||
}
|
||||
if (!is_ipc_handle_opened_)
|
||||
{
|
||||
for (int rank = 0; rank < mTPSize; rank++)
|
||||
{
|
||||
common::check_cuda_error(cudaIpcOpenMemHandle((void**) (&(param_.peer_barrier_ptrs[rank])),
|
||||
param_.ipc_mem_handles.peer_barrier_ipc_handles[rank], cudaIpcMemLazyEnablePeerAccess));
|
||||
common::check_cuda_error(cudaIpcOpenMemHandle((void**) (&(param_.peer_comm_buffer_ptrs[rank])),
|
||||
param_.ipc_mem_handles.peer_comm_buffer_ipc_handles[rank], cudaIpcMemLazyEnablePeerAccess));
|
||||
}
|
||||
param_.local_output_buffer_ptr = param_.peer_comm_buffer_ptrs[mTPRank];
|
||||
is_ipc_handle_opened_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void CustomAllReduceComm::IpcCloseMemHandle()
|
||||
{
|
||||
if (is_ipc_handle_opened_)
|
||||
{
|
||||
for (int rank = 0; rank < mTPSize; rank++)
|
||||
{
|
||||
common::check_cuda_error(cudaIpcCloseMemHandle(param_.peer_barrier_ptrs[rank]));
|
||||
common::check_cuda_error(cudaIpcCloseMemHandle(param_.peer_comm_buffer_ptrs[rank]));
|
||||
}
|
||||
is_ipc_handle_opened_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
void CustomAllReduceComm::customAllReduce(
|
||||
void* data, size_t elts, size_t size_per_elem, nvinfer1::DataType dataType, cudaStream_t stream)
|
||||
{
|
||||
param_.local_output_buffer_ptr = data;
|
||||
param_.elts_total = elts;
|
||||
param_.barrier_flag = FLAG(param_.barrier_flag + 1);
|
||||
|
||||
if (dataType == nvinfer1::DataType::kFLOAT)
|
||||
{
|
||||
using T = tensorrt_llm::CustomARCommTypeConverter<float>::Type;
|
||||
kernels::invokeOneOrTwoShotAllReduceKernel<T>(param_, stream);
|
||||
}
|
||||
else if (dataType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
using T = tensorrt_llm::CustomARCommTypeConverter<half>::Type;
|
||||
kernels::invokeOneOrTwoShotAllReduceKernel<T>(param_, stream);
|
||||
}
|
||||
else if (dataType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
using T = tensorrt_llm::CustomARCommTypeConverter<__nv_bfloat16>::Type;
|
||||
kernels::invokeOneOrTwoShotAllReduceKernel<T>(param_, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Unsupported dataType for customAllReduce");
|
||||
}
|
||||
}
|
||||
|
||||
void CustomAllReduceComm::allocate()
|
||||
{
|
||||
if (mTPRank != 0)
|
||||
return;
|
||||
setP2P();
|
||||
for (size_t i = 0; i < mTPSize; i++)
|
||||
{
|
||||
size_t device_id = mPPRank * mTPSize + i;
|
||||
common::check_cuda_error(cudaSetDevice(device_id));
|
||||
common::check_cuda_error(cudaMalloc(&(param_.peer_comm_buffer_ptrs[i]), mBufferSize));
|
||||
common::check_cuda_error(
|
||||
cudaMalloc(&(param_.peer_barrier_ptrs[i]), mTPSize * (MAX_ALL_REDUCE_BLOCKS + 1) * sizeof(uint32_t)));
|
||||
common::check_cuda_error(
|
||||
cudaMemset(param_.peer_barrier_ptrs[i], 0, mTPSize * (MAX_ALL_REDUCE_BLOCKS + 1) * sizeof(uint32_t)));
|
||||
}
|
||||
cudaSetDevice(mDeviceId);
|
||||
}
|
||||
|
||||
bool CustomAllReduceComm::isAvailable()
|
||||
{
|
||||
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
|
||||
if (!mpi::isInitialized())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto worldSize = mpi::getCommWorldSize();
|
||||
auto rank = mpi::getCommWorldRank();
|
||||
if ((worldSize % 2 != 0) || (worldSize > MAX_RANKS_PER_NODE) || (worldSize == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void CustomAllReduceComm::setP2P(bool activate)
|
||||
{
|
||||
int peer_access_available = 0;
|
||||
size_t device_offset = mPPRank * mTPSize;
|
||||
for (int i = 0; i < mTPSize; i++)
|
||||
{
|
||||
cudaSetDevice(device_offset + i);
|
||||
for (int j = 0; j < mTPSize; j++)
|
||||
{
|
||||
if (i == j)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
cudaDeviceCanAccessPeer(&peer_access_available, device_offset + i, device_offset + j);
|
||||
assert(peer_access_available);
|
||||
if (activate)
|
||||
{
|
||||
cudaDeviceEnablePeerAccess(device_offset + j, 0);
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (result == cudaErrorPeerAccessAlreadyEnabled)
|
||||
{
|
||||
result = cudaSuccess;
|
||||
}
|
||||
common::check_cuda_error(result);
|
||||
}
|
||||
else
|
||||
{
|
||||
cudaDeviceDisablePeerAccess(device_offset + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
cudaSetDevice(mDeviceId);
|
||||
}
|
||||
|
||||
void* CustomAllReduceComm::getShareBuffer()
|
||||
{
|
||||
return reinterpret_cast<void*>(param_.peer_comm_buffer_ptrs[mTPRank]);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
cpp/tensorrt_llm/plugins/ncclPlugin/FTCustomAR.h (new file, 87 lines)
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "NvInferRuntimeBase.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/common/tensor.h"
|
||||
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
|
||||
class CustomAllReduceComm
|
||||
{
|
||||
public:
|
||||
CustomAllReduceComm(size_t TPSize, size_t PPSize, int deviceId, size_t bufferSize);
|
||||
~CustomAllReduceComm();
|
||||
|
||||
void customAllReduce(
|
||||
void* data, size_t elts, size_t size_per_elem, nvinfer1::DataType dataType, cudaStream_t stream);
|
||||
void* getShareBuffer();
|
||||
|
||||
static bool isAvailable();
|
||||
|
||||
private:
|
||||
void setP2P(bool activate = true);
|
||||
|
||||
void IpcGetMemHandle();
|
||||
void IpcOpenMemHandle();
|
||||
void IpcCloseMemHandle();
|
||||
void IpcSyncMemHandle();
|
||||
|
||||
void allocate();
|
||||
|
||||
kernels::AllReduceParams param_;
|
||||
bool is_ipc_handle_opened_ = false;
|
||||
size_t mTPSize;
|
||||
size_t mTPRank;
|
||||
size_t mPPSize;
|
||||
size_t mPPRank;
|
||||
size_t mBufferSize;
|
||||
int mDeviceId;
|
||||
mpi::MpiComm group_comm_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CustomARCommTypeConverter
|
||||
{
|
||||
using Type = uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CustomARCommTypeConverter<half>
|
||||
{
|
||||
using Type = uint16_t;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
struct CustomARCommTypeConverter<__nv_bfloat16>
|
||||
{
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
@ -161,7 +161,7 @@ int AllgatherPlugin::initialize() noexcept
|
||||
MPI_Status status;
|
||||
MPICHECK(MPI_Recv(&id, sizeof(id), MPI_BYTE, *mGroup.begin(), 0, MPI_COMM_WORLD, &status));
|
||||
}
|
||||
(*commMap)[mGroup] == nullptr;
|
||||
(*commMap)[mGroup] = nullptr;
|
||||
NCCLCHECK(ncclCommInitRank(&((*commMap)[mGroup]), mGroup.size(), id, groupRank));
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include <cassert>
|
||||
#include <mpi.h>
|
||||
|
||||
@ -15,6 +15,8 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "allreducePlugin.h"
|
||||
#include "mpi.h"
|
||||
#include "plugin.h"
|
||||
|
||||
using namespace nvinfer1;
|
||||
using tensorrt_llm::plugins::AllreducePluginCreator;
|
||||
@ -25,9 +27,11 @@ static const char* ALLREDUCE_PLUGIN_NAME{"AllReduce"};
|
||||
PluginFieldCollection AllreducePluginCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField> AllreducePluginCreator::mPluginAttributes;
|
||||
|
||||
AllreducePlugin::AllreducePlugin(std::set<int> group, nvinfer1::DataType type)
|
||||
AllreducePlugin::AllreducePlugin(std::set<int> group, nvinfer1::DataType type, AllReduceStrategyType strategy)
|
||||
: mGroup(group)
|
||||
, mType(type)
|
||||
, mStrategy(strategy)
|
||||
, mCustomARBufferSize(0)
|
||||
{
|
||||
}
|
||||
|
||||
@ -36,6 +40,8 @@ AllreducePlugin::AllreducePlugin(const void* data, size_t length)
|
||||
{
|
||||
const char *d = reinterpret_cast<const char*>(data), *a = d;
|
||||
read(d, mType);
|
||||
read(d, mStrategy);
|
||||
read(d, mCustomARBufferSize);
|
||||
mGroup.clear();
|
||||
int groupItem = 0;
|
||||
while (d != a + length)
|
||||
@ -69,6 +75,32 @@ bool AllreducePlugin::supportsFormatCombination(
|
||||
void AllreducePlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
|
||||
{
|
||||
if (mStrategy == AllReduceStrategyType::NCCL)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
size_t sizePerElem = 0;
|
||||
switch (mType)
|
||||
{
|
||||
case DataType::kFLOAT: sizePerElem = sizeof(float); break;
|
||||
case DataType::kHALF: sizePerElem = sizeof(half); break;
|
||||
#ifdef ENABLE_BF16
|
||||
case DataType::kBF16: sizePerElem = sizeof(__nv_bfloat16); break;
|
||||
#endif
|
||||
default: break;
|
||||
}
|
||||
|
||||
size_t inputSize = 1;
|
||||
for (int i = 0; i < in[0].max.nbDims; i++)
|
||||
{
|
||||
inputSize *= in[0].max.d[i];
|
||||
}
|
||||
|
||||
if (isBuilding())
|
||||
{
|
||||
mCustomARBufferSize = std::max(mCustomARBufferSize, inputSize * sizePerElem);
|
||||
}
|
||||
}
|
||||
|
||||
size_t AllreducePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||
@ -77,6 +109,13 @@ size_t AllreducePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* input
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t AllreducePlugin::ncclEstimatedThreshold(int worldSize) const noexcept
|
||||
{
|
||||
// returns the message size over which it's more interesting to use NCCL
|
||||
// 0.60 * TP_SIZE * 10MB
|
||||
return 0.60 * (10 * 1000 * 1000 * worldSize);
|
||||
}
|
||||
|
||||
int AllreducePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
|
||||
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
|
||||
{
|
||||
@ -89,9 +128,36 @@ int AllreducePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const
|
||||
{
|
||||
size *= inputDesc[0].dims.d[i];
|
||||
}
|
||||
size_t sizePerElem = 0;
|
||||
switch (mType)
|
||||
{
|
||||
case DataType::kFLOAT: sizePerElem = sizeof(float); break;
|
||||
case DataType::kHALF: sizePerElem = sizeof(half); break;
|
||||
#ifdef ENABLE_BF16
|
||||
case DataType::kBF16: sizePerElem = sizeof(__nv_bfloat16); break;
|
||||
#endif
|
||||
default: break;
|
||||
}
|
||||
|
||||
NCCLCHECK(ncclAllReduce(
|
||||
inputs[0], outputs[0], size, (*getDtypeMap())[inputDesc[0].type], ncclSum, (*getCommMap())[mGroup], stream));
|
||||
auto runtimeStrategy = mStrategy;
|
||||
if (runtimeStrategy == AllReduceStrategyType::AUTO)
|
||||
{
|
||||
runtimeStrategy = size * sizePerElem > ncclEstimatedThreshold(mGroup.size()) ? AllReduceStrategyType::NCCL
|
||||
: AllReduceStrategyType::CUSTOM;
|
||||
}
|
||||
|
||||
if (runtimeStrategy == AllReduceStrategyType::NCCL)
|
||||
{
|
||||
NCCLCHECK(ncclAllReduce(inputs[0], outputs[0], size, (*getDtypeMap())[inputDesc[0].type], ncclSum,
|
||||
(*getCommMap())[mGroup], stream));
|
||||
}
|
||||
else if (runtimeStrategy == AllReduceStrategyType::CUSTOM)
|
||||
{
|
||||
auto shareBuffer = mCustomAllReduceContext->getShareBuffer();
|
||||
cudaMemcpyAsync(shareBuffer, inputs[0], size * sizePerElem, cudaMemcpyDeviceToDevice, stream);
|
||||
|
||||
mCustomAllReduceContext->customAllReduce(outputs[0], size, sizePerElem, mType, stream);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
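The AUTO path above picks a strategy per call: payloads larger than ncclEstimatedThreshold go to NCCL, smaller ones to the custom one-/two-shot kernel. With the 0.60 * 10 MB * TP_SIZE heuristic, the threshold is roughly 12 MB for TP=2, 24 MB for TP=4, and 48 MB for TP=8. A hedged sketch of the same decision outside the plugin; the constants and names here are illustrative.

```cpp
#include <cstddef>

enum class Strategy { NCCL, CUSTOM };

// Mirrors the heuristic: prefer the custom kernel for small messages and
// fall back to NCCL once the payload exceeds ~0.6 * 10 MB * worldSize.
inline Strategy pickAllReduceStrategy(std::size_t messageBytes, int worldSize)
{
    std::size_t const threshold = static_cast<std::size_t>(0.60 * (10 * 1000 * 1000) * worldSize);
    return messageBytes > threshold ? Strategy::NCCL : Strategy::CUSTOM;
}

// e.g. pickAllReduceStrategy(4u << 20, 8)  == Strategy::CUSTOM (4 MB  < 48 MB)
//      pickAllReduceStrategy(64u << 20, 8) == Strategy::NCCL   (64 MB > 48 MB)
```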
|
||||
@ -121,68 +187,105 @@ int AllreducePlugin::getNbOutputs() const noexcept
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool AllreducePlugin::isCustomAllReduceSuported(int ranks_per_node) const noexcept
|
||||
{
|
||||
constexpr bool isCudaVersionSupported =
|
||||
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11020
|
||||
true;
|
||||
#else
|
||||
false;
|
||||
#endif
|
||||
|
||||
return isCudaVersionSupported && (ranks_per_node % 2 == 0) && (ranks_per_node <= MAX_RANKS_PER_NODE)
|
||||
&& (ranks_per_node > 0);
|
||||
}
|
||||
|
||||
int AllreducePlugin::initialize() noexcept
|
||||
{
|
||||
auto* commMap = getCommMap();
|
||||
// [] operator inserts T() if it does not exist
|
||||
if (isBuilding() || (*commMap)[mGroup] != nullptr)
|
||||
if (isBuilding())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
int myRank, nRanks;
|
||||
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
|
||||
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
|
||||
|
||||
int groupRank = 0;
|
||||
for (auto it = mGroup.begin(); it != mGroup.end(); ++it)
|
||||
int deviceId;
|
||||
cudaGetDevice(&deviceId);
|
||||
TLLM_CHECK_WITH_INFO(myRank == deviceId, "MPI rank != cudaDeviceId, check if cudaSetDevice has been called");
|
||||
|
||||
if (mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::AUTO)
|
||||
{
|
||||
if (*it == myRank)
|
||||
auto* commMap = getCommMap();
|
||||
// [] operator inserts T() if it does not exist
|
||||
if ((*commMap)[mGroup] == nullptr)
|
||||
{
|
||||
break;
|
||||
int groupRank = 0;
|
||||
for (auto it = mGroup.begin(); it != mGroup.end(); ++it)
|
||||
{
|
||||
if (*it == myRank)
|
||||
{
|
||||
break;
|
||||
}
|
||||
++groupRank;
|
||||
}
|
||||
|
||||
ncclUniqueId id;
|
||||
if (myRank == *mGroup.begin())
|
||||
{
|
||||
ncclGetUniqueId(&id);
|
||||
for (auto it = std::next(std::begin(mGroup), 1); it != mGroup.end(); ++it)
|
||||
{
|
||||
MPICHECK(MPI_Send(&id, sizeof(id), MPI_BYTE, *it, 0, MPI_COMM_WORLD));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
MPI_Status status;
|
||||
MPICHECK(MPI_Recv(&id, sizeof(id), MPI_BYTE, *mGroup.begin(), 0, MPI_COMM_WORLD, &status));
|
||||
}
|
||||
(*commMap)[mGroup] = nullptr;
|
||||
NCCLCHECK(ncclCommInitRank(&((*commMap)[mGroup]), mGroup.size(), id, groupRank));
|
||||
}
|
||||
++groupRank;
|
||||
}
|
||||
|
||||
ncclUniqueId id;
|
||||
if (myRank == *mGroup.begin())
|
||||
if (mStrategy == AllReduceStrategyType::CUSTOM || mStrategy == AllReduceStrategyType::AUTO)
|
||||
{
|
||||
ncclGetUniqueId(&id);
|
||||
for (auto it = std::next(std::begin(mGroup), 1); it != mGroup.end(); ++it)
|
||||
{
|
||||
MPICHECK(MPI_Send(&id, sizeof(id), MPI_BYTE, *it, 0, MPI_COMM_WORLD));
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(tensorrt_llm::CustomAllReduceComm::isAvailable(), "Custom all reduce isn't available.");
|
||||
auto allocSize = mCustomARBufferSize > 0 ? mCustomARBufferSize : CUSTOM_AR_SIZE_THRESHOLD;
|
||||
mCustomAllReduceContext = std::make_shared<tensorrt_llm::CustomAllReduceComm>(nRanks, 0, myRank, allocSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
MPI_Status status;
|
||||
MPICHECK(MPI_Recv(&id, sizeof(id), MPI_BYTE, *mGroup.begin(), 0, MPI_COMM_WORLD, &status));
|
||||
}
|
||||
(*commMap)[mGroup] == nullptr;
|
||||
NCCLCHECK(ncclCommInitRank(&((*commMap)[mGroup]), mGroup.size(), id, groupRank));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void AllreducePlugin::terminate() noexcept
|
||||
{
|
||||
auto* commMap = getCommMap();
|
||||
// [] operator inserts T() if it does not exist
|
||||
if (isBuilding() || (*commMap)[mGroup] == nullptr)
|
||||
if (mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::AUTO)
|
||||
{
|
||||
return;
|
||||
auto* commMap = getCommMap();
|
||||
// [] operator inserts T() if it does not exist
|
||||
if (isBuilding() || (*commMap)[mGroup] == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
NCCLCHECK(ncclCommDestroy((*commMap)[mGroup]));
|
||||
(*commMap)[mGroup] = nullptr;
|
||||
}
|
||||
NCCLCHECK(ncclCommDestroy((*commMap)[mGroup]));
|
||||
(*commMap)[mGroup] = nullptr;
|
||||
}
|
||||
|
||||
size_t AllreducePlugin::getSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(int) * mGroup.size() + sizeof(mType);
|
||||
return sizeof(int) * mGroup.size() + sizeof(mType) + sizeof(mStrategy) + sizeof(mCustomARBufferSize);
|
||||
}
|
||||
|
||||
void AllreducePlugin::serialize(void* buffer) const noexcept
|
||||
{
|
||||
char *d = static_cast<char*>(buffer), *a = d;
|
||||
write(d, mType);
|
||||
write(d, mStrategy);
|
||||
write(d, mCustomARBufferSize);
|
||||
for (auto it = mGroup.begin(); it != mGroup.end(); ++it)
|
||||
{
|
||||
write(d, *it);
|
||||
@ -204,6 +307,7 @@ AllreducePluginCreator::AllreducePluginCreator()
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(PluginField("group", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("strategy", nullptr, PluginFieldType::kINT8, 1));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
@ -228,6 +332,7 @@ IPluginV2* AllreducePluginCreator::createPlugin(const char* name, const PluginFi
|
||||
const PluginField* fields = fc->fields;
|
||||
std::set<int> group;
|
||||
nvinfer1::DataType type;
|
||||
AllreducePlugin::AllReduceStrategyType strategy;
|
||||
// Read configurations from each fields
|
||||
for (int i = 0; i < fc->nbFields; ++i)
|
||||
{
|
||||
@ -247,11 +352,16 @@ IPluginV2* AllreducePluginCreator::createPlugin(const char* name, const PluginFi
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
type = static_cast<nvinfer1::DataType>(*(static_cast<const nvinfer1::DataType*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "strategy"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
|
||||
strategy = static_cast<AllreducePlugin::AllReduceStrategyType>(*static_cast<const int8_t*>(fields[i].data));
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
auto* obj = new AllreducePlugin(group, type);
|
||||
auto* obj = new AllreducePlugin(group, type, strategy);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
|
||||
@ -16,8 +16,11 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "FTCustomAR.h"
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
#include <mpi.h>
|
||||
#include <nccl.h>
|
||||
#include <set>
|
||||
@ -30,7 +33,15 @@ namespace tensorrt_llm::plugins
|
||||
class AllreducePlugin : public BasePlugin
|
||||
{
|
||||
public:
|
||||
AllreducePlugin(std::set<int> group, nvinfer1::DataType type);
|
||||
enum class AllReduceStrategyType : int8_t
|
||||
{
|
||||
NCCL = 0,
|
||||
CUSTOM = 1,
|
||||
AUTO = 2,
|
||||
};
|
||||
|
||||
AllreducePlugin(
|
||||
std::set<int> group, nvinfer1::DataType type, AllReduceStrategyType strategy = AllReduceStrategyType::NCCL);
|
||||
|
||||
AllreducePlugin(const void* data, size_t length);
|
||||
|
||||
@ -63,10 +74,16 @@ public:
|
||||
void serialize(void* buffer) const noexcept override;
|
||||
void destroy() noexcept override;
|
||||
|
||||
bool isCustomAllReduceSuported(int ranks_per_node) const noexcept;
|
||||
|
||||
private:
|
||||
size_t ncclEstimatedThreshold(int worldSize) const noexcept;
|
||||
const std::string mLayerName;
|
||||
std::set<int> mGroup;
|
||||
nvinfer1::DataType mType;
|
||||
AllReduceStrategyType mStrategy;
|
||||
std::shared_ptr<tensorrt_llm::CustomAllReduceComm> mCustomAllReduceContext;
|
||||
size_t mCustomARBufferSize;
|
||||
};
|
||||
|
||||
class AllreducePluginCreator : public BaseCreator
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include <cassert>
|
||||
#include <mpi.h>
|
||||
|
||||
@ -85,7 +85,6 @@ int SendPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
|
||||
}
|
||||
|
||||
NCCLCHECK(ncclSend(inputs[0], size, (*getDtypeMap())[inputDesc[0].type], 1, mComm, stream));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include <cassert>
|
||||
#include <mpi.h>
|
||||
|
||||
@ -164,7 +164,7 @@ int RmsnormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDe
|
||||
nvinfer1::DataType RmsnormQuantizationPlugin::getOutputDataType(
|
||||
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
|
||||
{
|
||||
assert((mDynActScaling && index < 2) || (~mDynActScaling && index == 0));
|
||||
assert((mDynActScaling && index < 2) || (!mDynActScaling && index == 0));
|
||||
if (index == 0)
|
||||
{
|
||||
// Output 0 quantized output of layer norm
|
||||
|
||||
@ -64,6 +64,11 @@ void SmoothQuantGemmPluginProfiler::computeTmpSize(int maxM, int n, int k)
|
||||
setTmpWorkspaceSizeInBytes(bytes);
|
||||
}
|
||||
|
||||
std::vector<SmoothQuantGemmPluginProfiler::Config> SmoothQuantGemmPluginProfiler::getTactics(int m, int n, int k) const
|
||||
{
|
||||
return mRunner->getConfigs();
|
||||
}
|
||||
|
||||
SmoothQuantGemmPlugin::SmoothQuantGemmPlugin(
|
||||
QuantMode quantMode, nvinfer1::DataType type, const SmoothQuantGemmPlugin::PluginProfilerPtr& pluginProfiler)
|
||||
: mQuantMode(quantMode)
|
||||
@ -302,7 +307,7 @@ void SmoothQuantGemmPlugin::destroy() noexcept
|
||||
|
||||
void SmoothQuantGemmPlugin::configGemm()
|
||||
{
|
||||
mPluginProfiler->profileTactics(m_sqGemmRunner->getConfigs(), m_sqGemmRunner, mType, mDims, mGemmId);
|
||||
mPluginProfiler->profileTactics(m_sqGemmRunner, mType, mDims, mGemmId);
|
||||
}
|
||||
|
||||
///////////////
|
||||
|
||||
@ -48,6 +48,8 @@ protected:
|
||||
|
||||
void computeTmpSize(int maxM, int n, int k) override;
|
||||
|
||||
std::vector<Config> getTactics(int m, int n, int k) const override;
|
||||
|
||||
private:
|
||||
tensorrt_llm::common::QuantMode mQuantMode;
|
||||
};
|
||||
@ -103,7 +105,7 @@ private:
|
||||
|
||||
SqGemmRunnerPtr m_sqGemmRunner;
|
||||
tensorrt_llm::common::QuantMode mQuantMode;
|
||||
int m_workspaceMaxSize;
|
||||
size_t m_workspaceMaxSize;
|
||||
|
||||
GemmDims mDims{};
|
||||
GemmIdCore mGemmId{};
|
||||
|
||||
@ -86,6 +86,12 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(int maxM, int n,
|
||||
setTmpWorkspaceSizeInBytes(bytes);
|
||||
}
|
||||
|
||||
std::vector<WeightOnlyGroupwiseQuantGemmPluginProfiler::Config> WeightOnlyGroupwiseQuantGemmPluginProfiler::getTactics(
|
||||
int m, int n, int k) const
|
||||
{
|
||||
return mRunner->getConfigs();
|
||||
}
|
||||
|
||||
WeightOnlyGroupwiseQuantMatmulPlugin::WeightOnlyGroupwiseQuantMatmulPlugin(nvinfer1::DataType type, int quant_algo,
|
||||
int group_size, const WeightOnlyGroupwiseQuantMatmulPlugin::PluginProfilerPtr& pluginProfiler)
|
||||
: mPluginProfiler(pluginProfiler)
|
||||
@ -164,8 +170,7 @@ nvinfer1::IPluginV2DynamicExt* WeightOnlyGroupwiseQuantMatmulPlugin::clone() con
|
||||
|
||||
void WeightOnlyGroupwiseQuantMatmulPlugin::configGemm()
|
||||
{
|
||||
mPluginProfiler->profileTactics(
|
||||
m_weightOnlyGroupwiseGemmRunner->getConfigs(), m_weightOnlyGroupwiseGemmRunner, mType, mDims, mGemmId);
|
||||
mPluginProfiler->profileTactics(m_weightOnlyGroupwiseGemmRunner, mType, mDims, mGemmId);
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDimensions(
|
||||
@ -250,7 +255,8 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::configurePlugin(const nvinfer1::Dynam
|
||||
}
|
||||
mGemmId = {N, K, mType};
|
||||
|
||||
int smoothedActSize = maxM * maxK * (in[0].desc.type == nvinfer1::DataType::kFLOAT ? 4 : 2);
|
||||
size_t smoothedActSize = static_cast<size_t>(maxM) * static_cast<size_t>(maxK)
|
||||
* (in[0].desc.type == nvinfer1::DataType::kFLOAT ? 4 : 2);
|
||||
m_workspaceMaxSize = smoothedActSize + m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(maxM, maxN, maxK);
|
||||
}
|
||||
|
||||
|
||||
@ -63,6 +63,8 @@ protected:
|
||||
|
||||
void computeTmpSize(int maxM, int n, int k) override;
|
||||
|
||||
std::vector<Config> getTactics(int m, int n, int k) const override;
|
||||
|
||||
private:
|
||||
int mQuantAlgo;
|
||||
int mGroupSize;
|
||||
@ -119,7 +121,7 @@ private:
|
||||
const std::string mLayerName;
|
||||
|
||||
WeightOnlyGemmRunnerPtr m_weightOnlyGroupwiseGemmRunner;
|
||||
int m_workspaceMaxSize;
|
||||
size_t m_workspaceMaxSize;
|
||||
nvinfer1::DataType mType;
|
||||
int mSM = tensorrt_llm::common::getSMVersion();
|
||||
|
||||
|
||||
@ -71,6 +71,12 @@ void WeightOnlyQuantGemmPluginProfiler::computeTmpSize(int maxM, int n, int k)
|
||||
setTmpWorkspaceSizeInBytes(bytes);
|
||||
}
|
||||
|
||||
std::vector<WeightOnlyQuantGemmPluginProfiler::Config> WeightOnlyQuantGemmPluginProfiler::getTactics(
|
||||
int m, int n, int k) const
|
||||
{
|
||||
return mRunner->getConfigs();
|
||||
}
|
||||
|
||||
WeightOnlyQuantMatmulPlugin::WeightOnlyQuantMatmulPlugin(
|
||||
nvinfer1::DataType type, int weightTypeId, const WeightOnlyQuantMatmulPlugin::PluginProfilerPtr& pluginProfiler)
|
||||
: mPluginProfiler(pluginProfiler)
|
||||
@ -130,8 +136,7 @@ nvinfer1::IPluginV2DynamicExt* WeightOnlyQuantMatmulPlugin::clone() const noexce
|
||||
|
||||
void WeightOnlyQuantMatmulPlugin::configGemm()
|
||||
{
|
||||
mPluginProfiler->profileTactics(
|
||||
m_weightOnlyGemmRunner->getConfigs(), m_weightOnlyGemmRunner, mType, mDims, mGemmId);
|
||||
mPluginProfiler->profileTactics(m_weightOnlyGemmRunner, mType, mDims, mGemmId);
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs WeightOnlyQuantMatmulPlugin::getOutputDimensions(
|
||||
|
||||
@ -55,6 +55,8 @@ protected:
|
||||
|
||||
void computeTmpSize(int maxM, int n, int k) override;
|
||||
|
||||
std::vector<Config> getTactics(int m, int n, int k) const override;
|
||||
|
||||
private:
|
||||
int mWeightTypeId;
|
||||
};
|
||||
@ -111,7 +113,7 @@ private:
|
||||
const std::string mLayerName;
|
||||
|
||||
WeightOnlyGemmRunnerPtr m_weightOnlyGemmRunner;
|
||||
int m_workspaceMaxSize;
|
||||
size_t m_workspaceMaxSize;
|
||||
nvinfer1::DataType mType;
|
||||
int mWeightTypeId;
|
||||
int mSM = tensorrt_llm::common::getSMVersion();
|
||||
|
||||
@ -35,19 +35,20 @@ set(SRCS
|
||||
|
||||
include_directories(${API_INCLUDE_DIR}/tensorrt_llm/runtime)
|
||||
|
||||
add_compile_options(-Wall)
|
||||
if(NOT MSVC)
|
||||
# additional warnings
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall")
|
||||
else() # Windows
|
||||
# warning level 4
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4")
|
||||
endif()
|
||||
|
||||
add_library(runtime_src OBJECT ${SRCS})
|
||||
set_property(TARGET runtime_src PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET runtime_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
|
||||
target_include_directories(runtime_src PRIVATE ${MPI_INCLUDE_PATH})
|
||||
|
||||
set(JSON_SRC_DIR ${PROJECT_SOURCE_DIR}/../3rdparty/json)
|
||||
add_subdirectory(${JSON_SRC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/json)
|
||||
|
||||
if(ENABLE_MULTI_DEVICE EQUAL 1)
|
||||
target_link_libraries(runtime_src PUBLIC nlohmann_json::nlohmann_json
|
||||
${NCCL_LIB})
|
||||
else()
|
||||
target_link_libraries(runtime_src PUBLIC nlohmann_json::nlohmann_json)
|
||||
target_link_libraries(runtime_src PUBLIC ${NCCL_LIB})
|
||||
endif()
|
||||
|
||||
@ -82,11 +82,11 @@ void BufferManager::setZero(IBuffer& buffer) const
|
||||
}
|
||||
}
|
||||
|
||||
void BufferManager::copy(void const* src, IBuffer& dst) const
|
||||
void BufferManager::copy(void const* src, IBuffer& dst, MemoryType srcType) const
|
||||
{
|
||||
if (dst.getSizeInBytes() > 0)
|
||||
{
|
||||
if (IBuffer::memoryType(src) != MemoryType::kGPU && dst.getMemoryType() != MemoryType::kGPU)
|
||||
if (srcType != MemoryType::kGPU && dst.getMemoryType() != MemoryType::kGPU)
|
||||
{
|
||||
std::memcpy(dst.data(), src, dst.getSizeInBytes());
|
||||
}
|
||||
@ -97,11 +97,11 @@ void BufferManager::copy(void const* src, IBuffer& dst) const
|
||||
}
|
||||
}
|
||||
|
||||
void BufferManager::copy(IBuffer const& src, void* dst) const
|
||||
void BufferManager::copy(IBuffer const& src, void* dst, MemoryType dstType) const
|
||||
{
|
||||
if (src.getSizeInBytes() > 0)
|
||||
{
|
||||
if (IBuffer::memoryType(dst) != MemoryType::kGPU && src.getMemoryType() != MemoryType::kGPU)
|
||||
if (src.getMemoryType() != MemoryType::kGPU && dstType != MemoryType::kGPU)
|
||||
{
|
||||
std::memcpy(dst, src.data(), src.getSizeInBytes());
|
||||
}
|
||||
@ -117,7 +117,7 @@ void BufferManager::copy(IBuffer const& src, IBuffer& dst) const
|
||||
TLLM_CHECK_WITH_INFO(src.getDataType() == dst.getDataType(), "Incompatible data types");
|
||||
TLLM_CHECK_WITH_INFO(src.getSizeInBytes() == dst.getSizeInBytes(),
|
||||
tc::fmtstr("Incompatible buffer sizes: %lu != %lu", src.getSizeInBytes(), dst.getSizeInBytes()));
|
||||
copy(src, dst.data());
|
||||
copy(src, dst.data(), dst.getMemoryType());
|
||||
}
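With this change the raw-pointer overloads of copy no longer probe the pointer through the CUDA runtime; the caller states the memory space explicitly. A hedged usage sketch, assuming the MemoryType enum exposes kCPU alongside the kGPU/kPINNED values used above:

```cpp
#include "tensorrt_llm/runtime/bufferManager.h"
#include <vector>

// Copies a device tensor to host and back, declaring the raw pointer's memory space.
void roundTrip(tensorrt_llm::runtime::BufferManager& manager, tensorrt_llm::runtime::ITensor& deviceTensor)
{
    using tensorrt_llm::runtime::MemoryType;
    std::vector<float> hostData(deviceTensor.getSize());

    manager.copy(deviceTensor, hostData.data(), MemoryType::kCPU); // device -> host
    manager.copy(hostData.data(), deviceTensor, MemoryType::kCPU); // host -> device
}
```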
|
||||
|
||||
BufferManager::IBufferPtr BufferManager::allocate(
|
||||
|
||||
@ -35,13 +35,12 @@ GptDecoder<T>::GptDecoder(size_t vocabSize, size_t vocabSizePadded, CudaStreamPt
|
||||
: mManager{stream}
|
||||
, mAllocator{mManager}
|
||||
{
|
||||
tc::cublasMMWrapper* cublasWrapper = nullptr;
|
||||
bool isFreeBufferAfterForward{false};
|
||||
cudaDeviceProp prop;
|
||||
tc::check_cuda_error(cudaGetDeviceProperties(&prop, 0));
|
||||
|
||||
mDynamicDecodeLayer = std::make_shared<tensorrt_llm::layers::DynamicDecodeLayer<T>>(
|
||||
vocabSize, vocabSizePadded, stream->get(), cublasWrapper, &mAllocator, isFreeBufferAfterForward, &prop);
|
||||
vocabSize, vocabSizePadded, stream->get(), &mAllocator, isFreeBufferAfterForward, &prop);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
@ -73,8 +74,6 @@ GptDecoderBatch::GptDecoderBatch(
|
||||
, mVocabSizePadded{vocabSizePadded}
|
||||
, mStream{std::move(stream)}
|
||||
, mBufferManager{mStream}
|
||||
, mEventStart(tc::CreateEvent())
|
||||
, mEventStop(tc::CreateEvent())
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
|
||||
@ -98,6 +97,7 @@ GptDecoderBatch::GptDecoderBatch(
|
||||
dOutput->finished = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
|
||||
// use batchSize many entries instead of the usual 1
|
||||
dOutput->finishedSum = mBufferManager.emptyTensor(MemoryType::kPINNED, nvSizeType);
|
||||
mFinishedSum = mBufferManager.pinned(ITensor::makeShape({1}), nvSizeType);
|
||||
dOutput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
dOutput->cumLogProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
dOutput->beamHypotheses.empty(mBufferManager);
|
||||
@ -155,7 +155,6 @@ void GptDecoderBatch::setup(
|
||||
}
|
||||
|
||||
mStreams.resize(maxBatchSize);
|
||||
mEvents.resize(maxBatchSize);
|
||||
mDecoders.resize(maxBatchSize);
|
||||
mDecodingInputs.resize(maxBatchSize);
|
||||
mDecodingOutputs.resize(maxBatchSize);
|
||||
@ -168,7 +167,6 @@ void GptDecoderBatch::setup(
|
||||
{
|
||||
mStreams[i] = std::make_shared<CudaStream>();
|
||||
TLLM_CHECK(mStreams[i]->getDevice() == device);
|
||||
mEvents[i] = tc::CreateEvent();
|
||||
mDecoders[i] = IGptDecoder::create(dtype, mVocabSize, mVocabSizePadded, mStreams[i]);
|
||||
mDecodingInputs[i].reset();
|
||||
mDecodingOutputs[i].reset();
|
||||
@ -201,7 +199,6 @@ void GptDecoderBatch::newRequest(
|
||||
maxNewTokens, mMaxSequenceLength));
|
||||
TLLM_CHECK(requestIds->getDataType() == TRTDataType<TokenIdType>::value);
|
||||
auto const endId = request.endId.value_or(mVocabSize - 1);
|
||||
auto const padId = request.padId.value_or(mVocabSize - 1);
|
||||
|
||||
auto constexpr localBatchSize = 1;
|
||||
|
||||
@ -273,7 +270,8 @@ void GptDecoderBatch::newRequest(
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::forward(decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
|
||||
decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto& logits = input.logits;
|
||||
@ -298,14 +296,15 @@ void GptDecoderBatch::forward(decoder_batch::Output& output, decoder_batch::Inpu
|
||||
TLLM_CHECK(sequenceLengths);
|
||||
auto constexpr singleRequest = 1;
|
||||
|
||||
mStream->record(mEventStart.get());
|
||||
CudaEvent eventStart{};
|
||||
mStream->record(eventStart);
|
||||
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
|
||||
{
|
||||
if (mFinished[i] || !input.active.at(i))
|
||||
continue;
|
||||
|
||||
auto& stream = mStreams[i];
|
||||
stream->wait(mEventStart.get());
|
||||
stream->wait(eventStart.get());
|
||||
auto& dInput = *mDecodingInputs[i];
|
||||
auto& dOutput = *mDecodingOutputs[i];
|
||||
auto logitsView = std::shared_ptr(ITensor::slice(logits, i, singleRequest));
|
||||
@ -349,21 +348,34 @@ void GptDecoderBatch::forward(decoder_batch::Output& output, decoder_batch::Inpu
|
||||
manager.copy(*dOutput.parentIds, *jointOutputParentIdsView);
|
||||
}
|
||||
|
||||
auto& event = mEvents[i];
|
||||
stream->record(event.get());
|
||||
mStream->wait(event.get());
|
||||
CudaEvent event{};
|
||||
stream->record(event);
|
||||
mStream->wait(event);
|
||||
dInput.step += 1;
|
||||
mNbSteps[i] += 1;
|
||||
mFinished[i] = mNbSteps[i] >= mMaxNewTokens[i];
|
||||
}
|
||||
mStream->record(mEventStop.get());
|
||||
TLLM_CUDA_CHECK(::cudaEventSynchronize(mEventStop.get()));
|
||||
|
||||
CudaEvent eventStop{};
|
||||
mStream->record(eventStop);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return std::make_unique<decoder_batch::Token>(std::move(eventStop), input.active);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
token.event.synchronize();
|
||||
|
||||
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
|
||||
{
|
||||
auto& dOutput = *mDecodingOutputs[i];
|
||||
mFinished[i] = mNbSteps[i] >= mMaxNewTokens[i]
|
||||
// This condition requires the synchronization above
|
||||
|| *bufferCast<SizeType>(*dOutput.finishedSum) == static_cast<SizeType>(dOutput.finished->getSize());
|
||||
if (token.active[i] && !mFinished[i])
|
||||
{
|
||||
auto& dOutput = *mDecodingOutputs[i];
|
||||
mFinished[i] = mFinished[i]
|
||||
// This condition requires the synchronization above
|
||||
|| *bufferCast<SizeType>(*dOutput.finishedSum) == static_cast<SizeType>(dOutput.finished->getSize());
|
||||
}
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
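forwardSync complements the new forwardAsync: the async call records an event, returns it in a token together with the active mask, and the sync call waits on that event before folding the finishedSum results into the finished flags. A hedged usage sketch of one decoding step; the surrounding objects (decoderBatch, logits, activeRequests, sequenceLengths) are assumed to exist.

```cpp
// One decoding step split into an asynchronous launch and an explicit wait.
decoder_batch::Input stepInput{logits};   // logits for the current step
stepInput.active = activeRequests;        // per-request activity mask

decoder_batch::Output stepOutput;
stepOutput.sequenceLengths = sequenceLengths;

auto token = decoderBatch.forwardAsync(stepOutput, stepInput); // enqueues per-request work
// ... other host-side work can overlap here ...
decoderBatch.forwardSync(*token);                              // waits and updates finished flags
```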
|
||||
@ -375,7 +387,6 @@ void GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
|
||||
auto& stream = mStreams[batchIdx];
|
||||
auto manager = BufferManager{stream};
|
||||
|
||||
stream->wait(mEventStart.get());
|
||||
auto& dInput = *mDecodingInputs[batchIdx];
|
||||
auto& dOutput = *mDecodingOutputs[batchIdx];
|
||||
|
||||
@ -385,9 +396,9 @@ void GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
|
||||
IGptDecoder::gatherTree(*finalOutputIds, dOutput, dInput, manager);
|
||||
manager.copy(*finalOutputIds, *outputIds);
|
||||
|
||||
auto& event = mEvents[batchIdx];
|
||||
stream->record(event.get());
|
||||
mStream->wait(event.get());
|
||||
CudaEvent event{};
|
||||
stream->record(event);
|
||||
mStream->wait(event);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
@ -434,7 +445,7 @@ void GptDecoderBatch::newBatch(GenerationInput const& inputs, SamplingConfig con
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
bool GptDecoderBatch::forward(decoder::Output& output, decoder::Input const& input)
|
||||
void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const& input)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
decoder_batch::Input batchInput{input.logits};
|
||||
@ -444,12 +455,24 @@ bool GptDecoderBatch::forward(decoder::Output& output, decoder::Input const& inp
|
||||
batchOutput.cacheIndirection = output.cacheIndirection;
|
||||
batchOutput.sequenceLengths = output.sequenceLengths;
|
||||
|
||||
forward(batchOutput, batchInput);
|
||||
|
||||
auto finished = getFinished();
|
||||
mForwardToken = forwardAsync(batchOutput, batchInput);
|
||||
mBufferManager.setZero(*mFinishedSum);
|
||||
kernels::reduce(*mFinishedSum, *ITensor::slice(mJointDecodingOutput->finishedSum, 0, mActualBatchSize), *mStream);
|
||||
mStream->record(mForwardEvent);
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return std::all_of(finished.begin(), finished.end(), [](bool x) { return x; });
|
||||
}
|
||||
|
||||
bool GptDecoderBatch::isFinishedSync()
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
forwardSync(*mForwardToken);
|
||||
auto const finished
|
||||
= std::all_of(mFinished.begin(), mFinished.begin() + mActualBatchSize, [](bool x) { return x; });
|
||||
// wait for mFinishedSum to be updated
|
||||
mStream->wait(mForwardEvent);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return finished;
|
||||
}
|
||||
|
||||
IStatefulGptDecoder::TensorPtr GptDecoderBatch::getFinalOutputIds() const
|
||||
@ -460,5 +483,13 @@ IStatefulGptDecoder::TensorPtr GptDecoderBatch::getFinalOutputIds() const
|
||||
postProcessRequest(batchIdx);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return ITensor::slice(getOutputIds(), 0, mActualBatchSize);
|
||||
return getOutputIds();
|
||||
}
|
||||
|
||||
IStatefulGptDecoder::TensorPtr GptDecoderBatch::getFinalOutputIds(SizeType batchIdx) const
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
postProcessRequest(batchIdx);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return getOutputIds(batchIdx);
|
||||
}
|
||||
|
||||
@ -72,8 +72,6 @@ GptJsonConfig parseJson(InputType&& i)
|
||||
else
|
||||
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));
|
||||
|
||||
auto const pagedKvCache = parseJsonFieldOr(builderConfig, "paged_kv_cache", false);
|
||||
auto const tokensPerBlock = parseJsonFieldOr(builderConfig, "tokens_per_block", 0);
|
||||
auto const quantMode = tc::QuantMode(parseJsonFieldOr(builderConfig, "quant_mode", tc::QuantMode::none().value()));
|
||||
auto const numKvHeads
|
||||
= parseJsonFieldOr(builderConfig, "num_kv_heads", numHeads * tensorParallelism) / tensorParallelism;
|
||||
@ -81,7 +79,11 @@ GptJsonConfig parseJson(InputType&& i)
|
||||
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
|
||||
auto const maxOutputLen = parseJsonFieldOr(builderConfig, "max_output_len", 0);
|
||||
|
||||
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
|
||||
|
||||
auto const& pluginConfig = json.at("plugin_config");
|
||||
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
|
||||
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
|
||||
auto const& gptAttentionPlugin = pluginConfig.at("gpt_attention_plugin");
|
||||
auto const useGptAttentionPlugin = !gptAttentionPlugin.is_boolean() || gptAttentionPlugin.template get<bool>();
|
||||
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
|
||||
@ -93,6 +95,7 @@ GptJsonConfig parseJson(InputType&& i)
|
||||
modelConfig.setTokensPerBlock(tokensPerBlock);
|
||||
modelConfig.setQuantMode(quantMode);
|
||||
modelConfig.setNbKvHeads(numKvHeads);
|
||||
modelConfig.computeContextLogits(computeContextLogits);
|
||||
|
||||
modelConfig.setMaxBatchSize(maxBatchSize);
|
||||
modelConfig.setMaxInputLen(maxInputLen);
|
||||
|
||||
@ -21,7 +21,6 @@
|
||||
|
||||
#include "iBuffer.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/decodingKernels.h"
|
||||
#include "tensorrt_llm/runtime/gptDecoderBatch.h"
|
||||
#include "tensorrt_llm/runtime/ncclCommunicator.h"
|
||||
@ -32,6 +31,7 @@
|
||||
#include "tensorrt_llm/runtime/tllmRuntime.h"
|
||||
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
|
||||
@ -47,16 +47,15 @@ GptSession::GptSession(GptModelConfig const& modelConfig, WorldConfig const& wor
|
||||
, mDevice{utils::initDevice(worldConfig)}
|
||||
, mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
|
||||
, mRuntime{std::make_shared<TllmRuntime>(engineBuffer, engineSize, *mLogger)}
|
||||
, mDecoder{}
|
||||
, mBuffers{std::make_shared<RuntimeBuffers>()}
|
||||
, mNumMicroBatches{worldConfig.getPipelineParallelism()}
|
||||
, mDecoders{}
|
||||
, mBuffers{}
|
||||
, mCudaGraphInstances{}
|
||||
{
|
||||
createContexts();
|
||||
mBuffers->create(*mRuntime, mModelConfig, mWorldConfig);
|
||||
|
||||
if (mWorldConfig.isPipelineParallel())
|
||||
{
|
||||
mPipelineComm = NcclCommunicator::createPipelineComm(mWorldConfig, *mLogger);
|
||||
mCommStream = std::make_shared<CudaStream>();
|
||||
}
|
||||
|
||||
// TODO compare expected and runtime tensor names?
|
||||
@ -72,47 +71,114 @@ BufferManager& GptSession::getBufferManager() const
|
||||
return mRuntime->getBufferManager();
|
||||
}
|
||||
|
||||
void GptSession::createContexts()
|
||||
void GptSession::createContexts(SizeType numMicroBatches)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
mRuntime->clearContexts();
|
||||
auto numProfiles = mRuntime->getNbProfiles();
|
||||
// Instantiate multiple contexts for flip-flopping
|
||||
auto const numContextsPerPhase = std::max(2, numMicroBatches);
|
||||
auto const numProfiles = mRuntime->getNbProfiles();
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
numProfiles == 1 || numProfiles == 2, "GPT only expects one optimization profile or two optimization profiles");
|
||||
// Instantiate two contexts for flip-flopping
|
||||
if (numProfiles == 1)
|
||||
auto constexpr ctxContextId = 0;
|
||||
auto constexpr genContextId = 1;
|
||||
if (numProfiles == 2)
|
||||
{
|
||||
mRuntime->addContext(0);
|
||||
mRuntime->addContext(0);
|
||||
for (auto i = 0; i < numContextsPerPhase; ++i)
|
||||
mRuntime->addContext(genContextId);
|
||||
}
|
||||
else
|
||||
for (auto i = 0; i < numContextsPerPhase; ++i)
|
||||
mRuntime->addContext(ctxContextId);
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptSession::createBuffers(SizeType numMicroBatches)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
mBuffers.clear();
|
||||
|
||||
for (SizeType i = 0; i < numMicroBatches; ++i)
|
||||
{
|
||||
mRuntime->addContext(1);
|
||||
mRuntime->addContext(1);
|
||||
mRuntime->addContext(0);
|
||||
mBuffers.emplace_back(std::make_shared<RuntimeBuffers>());
|
||||
mBuffers.back()->create(*mRuntime, mModelConfig, mWorldConfig);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptSession::createDecoder(bool decoderPerRequest)
|
||||
void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
|
||||
nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto const vocabSize = mModelConfig.getVocabSize();
|
||||
auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());
|
||||
auto const& stream = mRuntime->getStreamPtr();
|
||||
|
||||
if (decoderPerRequest)
|
||||
mDecoder = std::make_shared<GptDecoderBatch>(vocabSize, vocabSizePadded, stream);
|
||||
else
|
||||
mDecoder = std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream);
|
||||
mDecoders.clear();
|
||||
|
||||
for (SizeType i = 0; i < numMicroBatches; ++i)
|
||||
{
|
||||
if (decoderPerRequest)
|
||||
mDecoders.emplace_back(std::make_shared<GptDecoderBatch>(vocabSize, vocabSizePadded, stream));
|
||||
else
|
||||
mDecoders.emplace_back(std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream));
|
||||
mDecoders.back()->setup(batchSize, beamWidth, maxSequenceLength, logitsType);
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptSession::setup(SizeType const batchSize, SizeType const beamWidth, SizeType const maxSequenceLength,
|
||||
bool decoderPerRequest, std::optional<SizeType> maxTokensInPagedKvCache)
|
||||
void GptSession::createKvCacheManagers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
|
||||
SizeType numMicroBatches, std::optional<SizeType> maxTokensInPagedKvCache)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism());
|
||||
auto const nbHeads = mModelConfig.getNbHeads();
|
||||
auto const nbKvHeads = mModelConfig.getNbKvHeads();
|
||||
auto const hiddenSize = mModelConfig.getHiddenSize();
|
||||
auto const tokensPerBlock = mModelConfig.getTokensPerBlock();
|
||||
|
||||
auto const maxBlocksPerSeq = tc::divUp(maxSequenceLength, tokensPerBlock);
|
||||
auto const maxNumTokens
|
||||
= maxTokensInPagedKvCache.value_or(batchSize * beamWidth * maxBlocksPerSeq * tokensPerBlock);
|
||||
auto const maxNumBlocks = tc::divUp(maxNumTokens, tokensPerBlock);
|
||||
|
||||
nvinfer1::DataType kvDtype;
|
||||
if (mModelConfig.getQuantMode().hasFp8KvCache())
|
||||
{
|
||||
kvDtype = nvinfer1::DataType::kFP8;
|
||||
}
|
||||
else if (mModelConfig.getQuantMode().hasInt8KvCache())
|
||||
{
|
||||
kvDtype = nvinfer1::DataType::kINT8;
|
||||
}
|
||||
else
|
||||
{
|
||||
kvDtype = mModelConfig.getDataType();
|
||||
}
|
||||
|
||||
mKvCacheManagers.clear();
|
||||
|
||||
for (SizeType i = 0; i < numMicroBatches; ++i)
|
||||
{
|
||||
mKvCacheManagers.emplace_back(
|
||||
std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbHeads, nbKvHeads, hiddenSize, tokensPerBlock,
|
||||
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, kvDtype, mRuntime->getStreamPtr()));
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptSession::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, bool decoderPerRequest,
std::optional<SizeType> maxTokensInPagedKvCache, std::optional<SizeType> numMicroBatches)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

if (numMicroBatches)
mNumMicroBatches = numMicroBatches.value();
createContexts(mNumMicroBatches);
createBuffers(mNumMicroBatches);

auto const microBatchSize = tc::ceilDiv(maxBatchSize, mNumMicroBatches);
// Store this param related to decoder buffer size and kv cache manager to check against
// the input shape with the params given in generate().
// gptDecoderBatch does not resize buffers, but allows smaller batchSize and beamWidth.
@ -121,55 +187,51 @@ void GptSession::setup(SizeType const batchSize, SizeType const beamWidth, SizeT

if (mModelConfig.usePagedKvCache())
{
auto const numLayers = mModelConfig.getNbLayers();
auto const nbHeads = mModelConfig.getNbHeads();
auto const nbKvHeads = mModelConfig.getNbKvHeads();
auto const hiddenSize = mModelConfig.getHiddenSize();
auto const tokensPerBlock = mModelConfig.getTokensPerBlock();

auto const maxBlocksPerSeq = tc::divUp(maxSequenceLength, tokensPerBlock);
auto const maxNumTokens
= maxTokensInPagedKvCache.value_or(batchSize * beamWidth * maxBlocksPerSeq * tokensPerBlock);
auto const maxNumBlocks = tc::divUp(maxNumTokens, tokensPerBlock);
auto const kvDtype = mBuffers->presentKeysVals.at(0)->getDataType();

// init KV cache block manager
mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(numLayers, nbHeads, nbKvHeads, hiddenSize,
tokensPerBlock, maxNumBlocks, batchSize, kvDtype, mRuntime->getStreamPtr());
createKvCacheManagers(
microBatchSize, maxBeamWidth, maxSequenceLength, mNumMicroBatches, maxTokensInPagedKvCache);
}

if (mWorldConfig.isLastPipelineParallelRank())
{
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
createDecoder(decoderPerRequest);
mDecoder->setup(batchSize, beamWidth, maxSequenceLength, logitsType);
createDecoders(
microBatchSize, maxBeamWidth, maxSequenceLength, logitsType, decoderPerRequest, mNumMicroBatches);
}

// reshape does not care about maxInputLength or maxNewTokens
auto const generationConfig = RuntimeBuffers::GenerationConfig{batchSize, beamWidth, 0, 0, maxSequenceLength};
mBuffers->reshape(generationConfig, mModelConfig, mWorldConfig);
if (mWorldConfig.isPipelineParallel())
{
mReceivedEvents.clear();
for (SizeType i = 0; i < mNumMicroBatches; ++i)
mReceivedEvents.emplace_back();
}

// we don't know maxInputLength and maxNewTokens yet and ignore those for pre-allocation
auto const generationConfig
= RuntimeBuffers::GenerationConfig{microBatchSize, maxBeamWidth, 0, 0, maxSequenceLength};

for (auto& buffers : mBuffers)
buffers->reshape(generationConfig, mModelConfig, mWorldConfig);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void GptSession::generate(
void GptSession::generateSingleBatch(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

TLLM_CHECK_WITH_INFO(inputs.packed == mModelConfig.usePackedInput(),
"The chosen model requires a packed input tensor (did you set packed?).");
TLLM_CHECK_WITH_INFO(inputs.lengths->getShape().nbDims == 1, "Input lengths tensor must be one-dimensional.");
auto const& inputLengths = inputs.lengths;
TLLM_CHECK_WITH_INFO(inputLengths->getShape().nbDims == 1, "Input lengths tensor must be one-dimensional.");

auto constexpr microBatchId = 0;
auto& manager = mRuntime->getBufferManager();
auto& stream = mRuntime->getStream();
auto& buffers = *mBuffers;

buffers.contextLengthsDevice = inputs.lengths;
buffers.contextLengthsHost->reshape(inputs.lengths->getShape());
manager.copy(*buffers.contextLengthsDevice, *buffers.contextLengthsHost);
manager.getStream().synchronize();

auto const generationConfig = RuntimeBuffers::GenerationConfig::fromInput(inputs.ids, buffers.contextLengthsHost,
// Initialize and reshape buffers
auto& buffers = *mBuffers.at(microBatchId);
TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
buffers.initContextLengths(inputLengths, manager);
auto const generationConfig = RuntimeBuffers::GenerationConfig::fromInput(*inputs.ids, *buffers.contextLengthsHost,
inputs.packed, samplingConfig.beamWidth, mDecoderMaxSequenceLength, inputs.maxNewTokens, manager);

auto const batchSize = generationConfig.batchSize;
@ -177,49 +239,42 @@ void GptSession::generate(
auto const maxInputLength = generationConfig.maxInputLength;
auto const maxNewTokens = generationConfig.maxNewTokens;

TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.computeContextLogits())
{
auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());

TLLM_CHECK_WITH_INFO(outputs.contextLogits,
"outputs.contextLogits is nullptr. It must be allocated when computeContextLogits() is enabled.");
outputs.contextLogits->reshape(ITensor::makeShape({batchSize, maxInputLength, vocabSizePadded}));
auto const contextLogitsShape = outputs.contextLogits->getShape();
TLLM_CHECK_WITH_INFO(contextLogitsShape.d[0] == batchSize, "Invalid dim[0]");
TLLM_CHECK_WITH_INFO(contextLogitsShape.d[1] == maxInputLength, "Invalid dim[1]");
TLLM_CHECK_WITH_INFO(contextLogitsShape.d[2] == vocabSizePadded, "Invalid dim[2]");

buffers.logits = outputs.contextLogits;
}

buffers.reshape(generationConfig, mModelConfig, mWorldConfig);
kvCacheAddSequences(beamWidth, microBatchId);
ITensor::SharedPtr newTokens{initNewTokens(inputs, samplingConfig, microBatchId)};

if (mModelConfig.usePagedKvCache())
{
auto const contextLengthsHost = bufferCast<SizeType const>(*buffers.contextLengthsHost);
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
mKvCacheManager->addSequence(batchIdx, contextLengthsHost[batchIdx], beamWidth);
}
}

RuntimeBuffers::TensorMap inputBuffers[2];
RuntimeBuffers::TensorMap outputBuffers[2];
auto& onTokenGenerated = outputs.onTokenGenerated;
outputs.ids->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));
ITensor::SharedPtr newTokens;
if (mWorldConfig.isLastPipelineParallelRank())
{
mDecoder->newBatch(inputs, samplingConfig);
newTokens = mDecoder->getNewTokens();
}
else if (mWorldConfig.isFirstPipelineParallelRank())
{
newTokens = manager.gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
}

auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
RuntimeBuffers::TensorMap inputBuffers[2];
RuntimeBuffers::TensorMap outputBuffers[2];
for (SizeType step = 0; step < maxNewTokens; ++step)
{
auto const contextId = step % 2;
bool enqueueSuccessful = false;
if (step == 0)
{
SizeType contextIdForContextPhase = 0;
if (mRuntime->getNbProfiles() == 2)
{
contextIdForContextPhase = 2;
}
SizeType const contextIdForContextPhase
= mRuntime->getNbProfiles() == 2 ? mRuntime->getNbContexts() / 2 : 0;
buffers.prepareContextStep(
inputs.ids, inputs.padId, manager, *mKvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids,
*mKvCacheManager, mModelConfig, mWorldConfig);
inputs.ids, inputs.padId, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextIdForContextPhase, inputBuffers[contextId]);
mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffers[contextId]);

@ -230,22 +285,22 @@ void GptSession::generate(
instance.clear();
}
}
enqueueSuccessful = mRuntime->executeContext(contextIdForContextPhase);
TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine in context phase failed!");
}
else
{
if (isCudaGraphMode() && mCudaGraphInstances[contextId].hasInstance())
{
mCudaGraphInstances[contextId].launch(stream);
enqueueSuccessful = true;
mCudaGraphInstances[contextId].launch(mRuntime->getStream());
}
else
{
enqueueSuccessful = mRuntime->executeContext(contextId);
TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextId), "Executing TRT engine in generation phase failed!");
}
}

TLLM_CHECK_WITH_INFO(enqueueSuccessful, "Executing TRT engine failed!");
sync_check_cuda_error();

if (step == 0)
@ -255,14 +310,14 @@ void GptSession::generate(

std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);

if (step < maxNewTokens - 1)
{
if (step < maxNewTokens - 1) // this is not the last step
{ // preparing the next step
auto const nextStep = step + 1;
auto const nextContextId = nextStep % 2;
auto nextInputIds = buffers.prepareNextStep(
step, newTokens, manager, *mKvCacheManager, generationConfig, mModelConfig, mWorldConfig);
step, newTokens, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(inputBuffers[nextContextId], outputBuffers[nextContextId], nextStep, nextInputIds,
*mKvCacheManager, mModelConfig, mWorldConfig);
mModelConfig, mWorldConfig);
mRuntime->setInputTensors(nextContextId, inputBuffers[nextContextId]);
mRuntime->setOutputTensors(nextContextId, outputBuffers[nextContextId]);

@ -277,7 +332,8 @@ void GptSession::generate(
// FIXME(nkorobov): this synchronize is important to get logits right
// manager.getStream().synchronize();

auto shouldStop = executeDecoderStep(outputs.ids, newTokens, maxInputLength + step);
decoderStepAsync(outputs.ids, newTokens, maxInputLength + step, microBatchId);
auto const shouldStop = shouldStopSync(batchSize, beamWidth, microBatchId);

if (mWorldConfig.isFirstPipelineParallelRank())
{
@ -285,7 +341,7 @@ void GptSession::generate(
{
// TODO(rkobus) use getNewTokens(), remove step from Callback?
ITensor::SharedPtr outputIds
= mWorldConfig.isPipelineParallel() ? outputs.ids : mDecoder->getOutputIds();
= mWorldConfig.isPipelineParallel() ? outputs.ids : mDecoders.at(microBatchId)->getOutputIds();
onTokenGenerated(outputIds, step, shouldStop || step == maxNewTokens - 1);
}
}
@ -301,97 +357,407 @@ void GptSession::generate(
{
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
mKvCacheManager->removeSequence(batchIdx);
kvCacheManager->removeSequence(batchIdx);
}
}

finalizeOutputIds(*outputs.ids);
finalizeOutputIds(*outputs.ids, microBatchId);
manager.getStream().synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

bool GptSession::executeDecoderStep(ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep)
void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId)
{
if (mModelConfig.usePagedKvCache())
{
auto& kvCacheManager = mKvCacheManagers.at(microBatchId);
TLLM_CHECK(kvCacheManager);
auto contextLengthsHost = mBuffers.at(microBatchId)->contextLengthsHost;
TLLM_CHECK(contextLengthsHost);
auto const contextLengthsPtr = bufferCast<SizeType const>(*contextLengthsHost);
auto const contextLengthsSize = static_cast<SizeType>(contextLengthsHost->getSize());
for (SizeType batchIdx = 0; batchIdx < contextLengthsSize; ++batchIdx)
{
kvCacheManager->addSequence(batchIdx, contextLengthsPtr[batchIdx], beamWidth);
}
}
}

ITensor::SharedPtr GptSession::initNewTokens(
GenerationInput const& inputs, SamplingConfig const& samplingConfig, SizeType microBatchId)
{
if (mWorldConfig.isLastPipelineParallelRank())
{
auto& decoder = mDecoders.at(microBatchId);
decoder->newBatch(inputs, samplingConfig);
return decoder->getNewTokens();
}
else if (mWorldConfig.isFirstPipelineParallelRank())
{
auto const beamWidth = samplingConfig.beamWidth;
auto const batchSize = static_cast<SizeType>(inputs.lengths->getSize());
return mRuntime->getBufferManager().gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
}
else
{
return ITensor::SharedPtr{};
}
}

namespace
{
std::vector<GenerationInput> splitInputs(
GenerationInput const& inputs, SizeType numMicroBatches, BufferManager& manager)
{
std::vector<GenerationInput> inputBatches;
auto const numRequests = inputs.lengths->getShape().d[0];
auto const microBatchSize = tc::ceilDiv(numRequests, numMicroBatches);

if (inputs.packed)
{
auto contextLengthsHost = manager.copyFrom(*inputs.lengths, MemoryType::kCPU);
ITensor::SharedPtr inputIdsView = ITensor::view(inputs.ids);
inputIdsView->squeeze(0);
auto contextLengthsRange = BufferRange<SizeType>(*contextLengthsHost);

auto tokensBegin = 0;
for (auto offset = 0; offset < numRequests; offset += microBatchSize)
{
auto batchSize = std::min(microBatchSize, numRequests - offset);
auto numTokens = std::accumulate(
contextLengthsRange.begin() + offset, contextLengthsRange.begin() + offset + batchSize, 0);

ITensor::SharedPtr batchInputs = ITensor::slice(inputIdsView, tokensBegin, numTokens);
batchInputs->reshape(ITensor::makeShape({1, numTokens}));

inputBatches.emplace_back(inputs.endId, inputs.padId, batchInputs,
ITensor::slice(inputs.lengths, offset, batchSize), inputs.packed);

tokensBegin += numTokens;
}
}
else
{
for (auto offset = 0; offset < numRequests; offset += microBatchSize)
{
auto batchSize = std::min(microBatchSize, numRequests - offset);
inputBatches.emplace_back(inputs.endId, inputs.padId, ITensor::slice(inputs.ids, offset, batchSize),
ITensor::slice(inputs.lengths, offset, batchSize), inputs.packed);
}
}

for (auto& batch : inputBatches)
{
if (inputs.embeddingBiasOpt)
batch.embeddingBiasOpt = inputs.embeddingBiasOpt;
if (inputs.badWordsList)
batch.badWordsList = inputs.badWordsList;
if (inputs.stopWordsList)
batch.stopWordsList = inputs.stopWordsList;
if (inputs.maxNewTokens)
batch.maxNewTokens = inputs.maxNewTokens;
}

return inputBatches;
}

void updateOutputIds(
ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, CudaStream const& stream)
{ // assemble outputIds of all micro batches
auto const& newTokensShape = newTokens->getShape();
auto newTokensView = ITensor::view(newTokens, ITensor::makeShape({1, newTokensShape.d[0] * newTokensShape.d[1]}));
auto const& outputIdsShape = outputIds->getShape();
auto outputIdsView = ITensor::view(
outputIds, ITensor::makeShape({outputIdsShape.d[0] * outputIdsShape.d[1], outputIdsShape.d[2]}));
kernels::invokeTransposeWithOutputOffset(*outputIdsView, *newTokensView, decoderStep, stream);
}
} // namespace

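For reference, a minimal standalone sketch of the micro-batch split arithmetic that `splitInputs` applies above (ceiling division into `numMicroBatches` slices, with the last slice possibly smaller). The constants are hypothetical and nothing below uses TensorRT-LLM types:

```cpp
// Illustrative sketch only: mirrors the offset/batchSize bookkeeping of splitInputs above.
#include <algorithm>
#include <cstdio>

int main()
{
    int const numRequests = 5;     // hypothetical total batch size
    int const numMicroBatches = 2; // hypothetical micro batch count
    int const microBatchSize = (numRequests + numMicroBatches - 1) / numMicroBatches; // tc::ceilDiv
    for (int offset = 0; offset < numRequests; offset += microBatchSize)
    {
        int const batchSize = std::min(microBatchSize, numRequests - offset);
        std::printf("micro batch covers requests [%d, %d)\n", offset, offset + batchSize);
    }
    return 0;
}
```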
void GptSession::generateMultiBatch(
GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

TLLM_CHECK_WITH_INFO(inputs.packed == mModelConfig.usePackedInput(),
"The chosen model requires a packed input tensor (did you set packed?).");
auto const& inputLengths = inputs.lengths;
TLLM_CHECK_WITH_INFO(inputLengths->getShape().nbDims == 1, "Input lengths tensor must be one-dimensional.");

auto& manager = mRuntime->getBufferManager();

auto const batchSize = static_cast<SizeType>(inputLengths->getSize());
auto const beamWidth = samplingConfig.beamWidth;
outputs.ids->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));
auto& onTokenGenerated = outputs.onTokenGenerated;

auto const numMicroBatches = std::min(batchSize, mNumMicroBatches);
auto microBatches = splitInputs(inputs, numMicroBatches, manager);

std::vector<RuntimeBuffers::GenerationConfig> generationConfigs;
std::vector<ITensor::SharedPtr> newTokensPerBatch;
std::vector<ITensor::SharedPtr> outputIdsPerBatch;

for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto const& microBatchInputs = microBatches.at(microBatchId);

// Initialize and reshape buffers
auto& buffers = *mBuffers.at(microBatchId);
TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
buffers.initContextLengths(microBatchInputs.lengths, manager);
generationConfigs.emplace_back(RuntimeBuffers::GenerationConfig::fromInput(*microBatchInputs.ids,
*buffers.contextLengthsHost, microBatchInputs.packed, samplingConfig.beamWidth, mDecoderMaxSequenceLength,
microBatchInputs.maxNewTokens, manager));
auto const& generationConfig = generationConfigs.back();

auto const beamWidth = generationConfig.beamWidth;

buffers.reshape(generationConfig, mModelConfig, mWorldConfig);
kvCacheAddSequences(beamWidth, microBatchId);
newTokensPerBatch.emplace_back(initNewTokens(microBatchInputs, samplingConfig, microBatchId));
}

auto maxNewTokens = generationConfigs.front().maxNewTokens;
auto microBatchSize = generationConfigs.front().batchSize;
auto offset = 0;
outputIdsPerBatch.emplace_back(ITensor::slice(outputs.ids, offset, microBatchSize));
offset += microBatchSize;
for (auto microBatchId = 1; microBatchId < numMicroBatches; ++microBatchId)
{
maxNewTokens = std::min(maxNewTokens, generationConfigs.at(microBatchId).maxNewTokens);
auto microBatchSize = generationConfigs.at(microBatchId).batchSize;
outputIdsPerBatch.emplace_back(ITensor::slice(outputs.ids, offset, microBatchSize));
offset += microBatchSize;
}

// TODO(micro batching) do we need 1 or 2 per micro batch?
std::vector<RuntimeBuffers::TensorMap> inputBuffers(numMicroBatches * 2);
std::vector<RuntimeBuffers::TensorMap> outputBuffers(numMicroBatches * 2);
std::vector<bool> microBatchesFinished(numMicroBatches, false);
for (SizeType step = 0; step < maxNewTokens; ++step)
{
if (std::all_of(microBatchesFinished.begin(), microBatchesFinished.end(), [](bool x) { return x; }))
break;

for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto& buffers = *mBuffers.at(microBatchId);
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
auto& newTokens = newTokensPerBatch.at(microBatchId);
auto& generationConfig = generationConfigs.at(microBatchId);
auto& outputIds = outputIdsPerBatch.at(microBatchId);

if (microBatchesFinished.at(microBatchId))
continue;

if (step > 0)
{
auto const microBatchSize = generationConfig.batchSize;
auto const beamWidth = generationConfig.beamWidth;
auto const shouldStop = shouldStopSync(microBatchSize, beamWidth, microBatchId);

if (mWorldConfig.isFirstPipelineParallelRank() && onTokenGenerated
&& microBatchId == numMicroBatches - 1)
{
onTokenGenerated(outputs.ids, step - 1, shouldStop);
}

if (shouldStop)
{
mLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, "GPT decoding finished early");
microBatchesFinished.at(microBatchId) = true;
continue;
}
}

auto const contextId = microBatchId;
if (step == 0)
{
SizeType const contextIdForContextPhase
= contextId + (mRuntime->getNbProfiles() == 2 ? mNumMicroBatches : 0);
auto const& inputs = microBatches.at(microBatchId);
buffers.prepareContextStep(
inputs.ids, inputs.padId, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffers[contextId], outputBuffers[contextId], step, inputs.ids, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextIdForContextPhase, inputBuffers[contextId]);
mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffers[contextId]);

TLLM_CHECK_WITH_INFO(
mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine failed!");

buffers.postContextStep(manager, generationConfig, mModelConfig, mWorldConfig);
}
else
{
auto nextInputIds = buffers.prepareNextStep(
step - 1, newTokens, manager, kvCacheManager, generationConfig, mModelConfig, mWorldConfig);
buffers.getRuntimeBuffers(
inputBuffers[contextId], outputBuffers[contextId], step, nextInputIds, mModelConfig, mWorldConfig);
mRuntime->setInputTensors(contextId, inputBuffers[contextId]);
mRuntime->setOutputTensors(contextId, outputBuffers[contextId]);

TLLM_CHECK_WITH_INFO(mRuntime->executeContext(contextId), "Executing TRT engine failed!");
}
sync_check_cuda_error();

std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);

auto const maxInputLength = generationConfigs.at(microBatchId).maxInputLength;
auto const decoderStep = maxInputLength + step;
decoderStepAsync(outputIds, newTokens, decoderStep, microBatchId);
if (!mWorldConfig.isPipelineParallel() && mNumMicroBatches > 1)
{
updateOutputIds(outputIds, newTokens, decoderStep, mRuntime->getStream());
}
}
}

// TODO(micro batching) move into loop above?
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
auto const& generationConfig = generationConfigs.at(microBatchId);
auto const microBatchSize = generationConfig.batchSize;

auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManagers.at(microBatchId).get() : nullptr;
auto& outputIds = outputIdsPerBatch.at(microBatchId);

// TODO(micro batching) sync receive event
if (mWorldConfig.isFirstPipelineParallelRank() && onTokenGenerated && microBatchId == numMicroBatches - 1)
{
onTokenGenerated(outputs.ids, maxNewTokens - 1, true);
}

if (mModelConfig.usePagedKvCache())
{
for (auto batchIdx = 0; batchIdx < microBatchSize; ++batchIdx)
{
kvCacheManager->removeSequence(batchIdx);
}
}

// TODO(micro batching) use mCommStream?
finalizeOutputIds(*outputIds, microBatchId);
}
manager.getStream().synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void GptSession::decoderStepAsync(
ITensor::SharedPtr& outputIds, ITensor::SharedPtr& newTokens, SizeType decoderStep, SizeType microBatchId)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& stream = mRuntime->getStream();
auto& buffers = *mBuffers;
auto& buffers = *mBuffers.at(microBatchId);

auto shouldStopPtr = bufferCast<std::uint8_t>(*buffers.shouldStop);
auto& shouldStop = *shouldStopPtr;
shouldStop = false;
if (mWorldConfig.isLastPipelineParallelRank())
{
auto& decoder = *mDecoders.at(microBatchId);

decoder::Input decodingInput{buffers.logits};
decoder::Output decodingOutput{};
decodingInput.cacheIndirection = buffers.cacheIndirectionDecoderInput;
decodingOutput.cacheIndirection = buffers.cacheIndirectionDecoderOutput;
decodingOutput.sequenceLengths = buffers.sequenceLengths;

shouldStop = mDecoder->forward(decodingOutput, decodingInput);
decoder.forwardAsync(decodingOutput, decodingInput);
if (mWorldConfig.isPipelineParallel())
{ // send shouldStop to all previous ranks and newTokens to the first rank
stream.record(mCommEvent.get());
mCommStream->wait(mCommEvent.get());
auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();

auto& cacheIndirection = *buffers.cacheIndirectionDecoderOutput;
auto& sequenceLengths = *buffers.sequenceLengths;
auto const beamWidth = cacheIndirection.getShape().d[1];
for (auto peerIdx = 0; peerIdx < mWorldConfig.getPipelineParallelism() - 1; ++peerIdx)
{
mPipelineComm->send<SizeType>(*decoder.getNbFinished(), pipelineGroup[peerIdx], *mCommStream, *mLogger);
if (beamWidth > 1)
{
mPipelineComm->send<SizeType>(cacheIndirection, pipelineGroup[peerIdx], *mCommStream, *mLogger);
}
mPipelineComm->send<SizeType>(sequenceLengths, pipelineGroup[peerIdx], *mCommStream, *mLogger);
}
mPipelineComm->send<TokenIdType>(*decoder.getNewTokens(), pipelineGroup.front(), *mCommStream, *mLogger);
}
}
else // pipeline parallel mode
{ // receive shouldStop from the last rank on a separate stream
stream.record(mCommEvent.get());
mCommStream->wait(mCommEvent.get());
auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();
auto const peer = pipelineGroup.back();
mPipelineComm->receive<SizeType>(*buffers.nbFinished, peer, *mCommStream, *mLogger);

if (mWorldConfig.isPipelineParallel())
{
if (mWorldConfig.isLastPipelineParallelRank())
auto& cacheIndirection = *buffers.cacheIndirectionDecoderOutput;
auto& sequenceLengths = *buffers.sequenceLengths;
auto const beamWidth = cacheIndirection.getShape().d[1];
if (beamWidth > 1)
{
for (auto peer = 0; peer < mWorldConfig.getPipelineParallelism() - 1; ++peer)
{
mPipelineComm->send(shouldStopPtr, 1, peer, stream, *mLogger);
}
mPipelineComm->send(bufferCast<std::int32_t>(*newTokens), newTokens->getSize(), 0, stream, *mLogger);
mPipelineComm->receive<SizeType>(cacheIndirection, peer, *mCommStream, *mLogger);
}
else
{
auto const peer = mWorldConfig.getPipelineParallelism() - 1;
mPipelineComm->receive(shouldStopPtr, 1, peer, stream, *mLogger);

if (mWorldConfig.isFirstPipelineParallelRank())
{
mPipelineComm->receive(
bufferCast<std::int32_t>(*newTokens), newTokens->getSize(), peer, stream, *mLogger);

auto const& newTokensShape = newTokens->getShape();
auto newTokensView
= ITensor::view(outputIds, ITensor::makeShape({1, newTokensShape.d[0] * newTokensShape.d[1]}));
auto const& outputIdsShape = outputIds->getShape();
auto outputIdsView = ITensor::view(
outputIds, ITensor::makeShape({outputIdsShape.d[0] * outputIdsShape.d[1], outputIdsShape.d[2]}));
kernels::invokeTransposeWithOutputOffset(*outputIdsView, *newTokensView, decoderStep, stream);
}
mPipelineComm->receive<SizeType>(sequenceLengths, peer, *mCommStream, *mLogger);
if (mWorldConfig.isFirstPipelineParallelRank())
{ // receive newTokens from last rank on a separate stream
mPipelineComm->receive<TokenIdType>(*newTokens, peer, *mCommStream, *mLogger);
updateOutputIds(outputIds, newTokens, decoderStep, *mCommStream);
}
mCommStream->record(mReceivedEvents.at(microBatchId).get());
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

bool GptSession::shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType microBatchId)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

SizeType nbFinished = 0;

if (mWorldConfig.isLastPipelineParallelRank())
{ // read the Finished flag from the decoder
auto& decoder = *mDecoders.at(microBatchId);
decoder.isFinishedSync();
nbFinished = *bufferCast<SizeType>(*decoder.getNbFinished());
}
else
{ // ensure all information has been received
TLLM_CUDA_CHECK(cudaEventSynchronize(mReceivedEvents.at(microBatchId).get()));
nbFinished = *bufferCast<SizeType>(*mBuffers.at(microBatchId)->nbFinished);
}
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return shouldStop;
return nbFinished == batchSize * beamWidth;
}

void GptSession::finalizeOutputIds(ITensor& outputIds)
void GptSession::finalizeOutputIds(ITensor& outputIds, SizeType microBatchId)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& manager = mRuntime->getBufferManager();
auto& stream = mRuntime->getStream();

ITensor::SharedPtr finalOutputIds;
if (mWorldConfig.isLastPipelineParallelRank())
if (mWorldConfig.isPipelineParallel())
{
finalOutputIds = mDecoder->getFinalOutputIds();
if (mWorldConfig.isPipelineParallel())
{
auto& stream = mRuntime->getStream();
auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();

if (mWorldConfig.isLastPipelineParallelRank())
{ // send ids from last to first
auto const peer = pipelineGroup.front();
auto const finalOutputIds = mDecoders.at(microBatchId)->getFinalOutputIds();
mPipelineComm->send(
bufferCast<std::int32_t>(*finalOutputIds), finalOutputIds->getSize(), 0, stream, *mLogger);
bufferCast<std::int32_t>(*finalOutputIds), finalOutputIds->getSize(), peer, stream, *mLogger);
}
}
if (mWorldConfig.isFirstPipelineParallelRank())
{
if (mWorldConfig.isPipelineParallel())
{
auto const peer = mWorldConfig.getPipelineParallelism() - 1;
else if (mWorldConfig.isFirstPipelineParallelRank())
{ // receive ids from last on first
auto const peer = pipelineGroup.back();
mPipelineComm->receive(bufferCast<std::int32_t>(outputIds), outputIds.getSize(), peer, stream, *mLogger);
}
else
{
manager.copy(*finalOutputIds, outputIds);
}
}
else
{
manager.copy(*mDecoders.at(microBatchId)->getFinalOutputIds(), outputIds);
}

sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
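Taken together, the changes above split the old `generate()` path into `generateSingleBatch()` and `generateMultiBatch()` behind a micro-batch-aware `setup()`. A minimal calling sketch follows; the include path, namespace, and the concrete values are assumptions, and only the `setup()` and `generateMultiBatch()` signatures are taken from this diff:

```cpp
// Hypothetical usage sketch; header path and surrounding setup are assumptions.
#include "tensorrt_llm/runtime/gptSession.h"

#include <optional>

using namespace tensorrt_llm::runtime;

void runWithMicroBatches(
    GptSession& session, GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig)
{
    SizeType const maxBatchSize = 8;
    SizeType const maxBeamWidth = 1;
    SizeType const maxSequenceLength = 1024;
    bool const decoderPerRequest = false;

    // New optional arguments in this commit: a cap on tokens in the paged KV cache
    // and the number of micro batches to pipeline through the engine.
    session.setup(maxBatchSize, maxBeamWidth, maxSequenceLength, decoderPerRequest,
        /*maxTokensInPagedKvCache=*/std::nullopt, /*numMicroBatches=*/2);

    // generateMultiBatch() splits `inputs` across micro batches internally (see splitInputs above);
    // generateSingleBatch() keeps the previous single-batch behavior.
    session.generateMultiBatch(outputs, inputs, samplingConfig);
}
```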